rlox_core/training/
weight_ops.rs

1//! Weight vector operations for meta-learning and target network updates.
2//!
3//! Provides Reptile-style meta updates, Polyak (exponential moving average)
4//! updates for SAC/TD3 target networks, and weight vector averaging.
5
6use crate::error::RloxError;
7
8/// Trait for weight vector update strategies.
9///
10/// All operations are in-place on flat f32 weight vectors.
11/// This enables adding new meta-learning update rules (MAML outer step,
12/// Lookahead, exponential moving average variants) without modifying
13/// existing code.
14pub trait WeightUpdate: Send + Sync {
15    /// Apply the update rule in-place.
16    ///
17    /// `target` is modified. `source` is read-only. `lr` controls step size.
18    ///
19    /// # Errors
20    /// Returns `ShapeMismatch` if `target.len() != source.len()`.
21    fn apply(&self, target: &mut [f32], source: &[f32], lr: f32) -> Result<(), RloxError>;
22
23    /// Human-readable name.
24    fn name(&self) -> &str;
25}
26
27/// Reptile meta-learning update strategy.
28pub struct ReptileUpdate;
29
30impl WeightUpdate for ReptileUpdate {
31    #[inline]
32    fn apply(&self, target: &mut [f32], source: &[f32], lr: f32) -> Result<(), RloxError> {
33        reptile_update(target, source, lr)
34    }
35
36    fn name(&self) -> &str {
37        "Reptile"
38    }
39}
40
41/// Polyak (EMA) target network update strategy.
42pub struct PolyakUpdate;
43
44impl WeightUpdate for PolyakUpdate {
45    #[inline]
46    fn apply(&self, target: &mut [f32], source: &[f32], lr: f32) -> Result<(), RloxError> {
47        polyak_update(target, source, lr)
48    }
49
50    fn name(&self) -> &str {
51        "Polyak"
52    }
53}
54
55/// Reptile weight update: `params += lr * (task_params - params)`
56///
57/// Operates in-place on `meta_params`. Both slices must have the same length.
58/// When `meta_lr == 1.0`, this copies `task_params` into `meta_params`.
59/// When `meta_lr == 0.0`, `meta_params` is unchanged.
60#[inline]
61pub fn reptile_update(
62    meta_params: &mut [f32],
63    task_params: &[f32],
64    meta_lr: f32,
65) -> Result<(), RloxError> {
66    if meta_params.len() != task_params.len() {
67        return Err(RloxError::ShapeMismatch {
68            expected: format!("target.len()={}", meta_params.len()),
69            got: format!("source.len()={}", task_params.len()),
70        });
71    }
72
73    for (m, &t) in meta_params.iter_mut().zip(task_params.iter()) {
74        *m += meta_lr * (t - *m);
75    }
76    Ok(())
77}
78
79/// Exponential moving average (Polyak update):
80///   `target[i] = tau * source[i] + (1 - tau) * target[i]`
81///
82/// Used by SAC/TD3 for target network updates. Operates in-place on `target`.
83#[inline]
84pub fn polyak_update(target: &mut [f32], source: &[f32], tau: f32) -> Result<(), RloxError> {
85    if target.len() != source.len() {
86        return Err(RloxError::ShapeMismatch {
87            expected: format!("target.len()={}", target.len()),
88            got: format!("source.len()={}", source.len()),
89        });
90    }
91
92    let one_minus_tau = 1.0 - tau;
93    for (t, &s) in target.iter_mut().zip(source.iter()) {
94        *t = tau * s + one_minus_tau * *t;
95    }
96    Ok(())
97}
98
99/// Average N weight vectors element-wise: `result[i] = mean(vectors[j][i] for all j)`
100///
101/// All vectors must have the same length.
102pub fn average_weight_vectors(vectors: &[&[f32]]) -> Result<Vec<f32>, RloxError> {
103    if vectors.is_empty() {
104        return Err(RloxError::BufferError(
105            "cannot average zero weight vectors".into(),
106        ));
107    }
108
109    let dim = vectors[0].len();
110    for (i, v) in vectors.iter().enumerate().skip(1) {
111        if v.len() != dim {
112            return Err(RloxError::ShapeMismatch {
113                expected: format!("all vectors length {dim}"),
114                got: format!("vectors[{i}].len()={}", v.len()),
115            });
116        }
117    }
118
119    let n = vectors.len() as f32;
120    let mut result = vec![0.0f32; dim];
121    for v in vectors {
122        for (r, &val) in result.iter_mut().zip(v.iter()) {
123            *r += val;
124        }
125    }
126    for r in &mut result {
127        *r /= n;
128    }
129    Ok(result)
130}
131
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_reptile_update_lr_one_copies() {
        let mut meta = vec![1.0f32, 2.0, 3.0];
        let task = vec![4.0f32, 5.0, 6.0];
        reptile_update(&mut meta, &task, 1.0).unwrap();
        assert_eq!(meta, vec![4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_reptile_update_lr_zero_no_change() {
        let mut meta = vec![1.0f32, 2.0, 3.0];
        let task = vec![4.0f32, 5.0, 6.0];
        reptile_update(&mut meta, &task, 0.0).unwrap();
        assert_eq!(meta, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_reptile_update_lr_half() {
        let mut meta = vec![0.0f32, 0.0];
        let task = vec![2.0f32, 4.0];
        reptile_update(&mut meta, &task, 0.5).unwrap();
        assert!((meta[0] - 1.0).abs() < 1e-6);
        assert!((meta[1] - 2.0).abs() < 1e-6);
    }

    #[test]
    fn test_reptile_length_mismatch() {
        let mut meta = vec![1.0f32, 2.0, 3.0];
        let task = vec![4.0f32, 5.0];
        let result = reptile_update(&mut meta, &task, 0.5);
        assert!(matches!(result, Err(RloxError::ShapeMismatch { .. })));
    }

    #[test]
    fn test_average_weight_vectors_mean() {
        let v1 = [1.0f32, 2.0, 3.0];
        let v2 = [4.0f32, 5.0, 6.0];
        let result = average_weight_vectors(&[&v1, &v2]).unwrap();
        assert!((result[0] - 2.5).abs() < 1e-6);
        assert!((result[1] - 3.5).abs() < 1e-6);
        assert!((result[2] - 4.5).abs() < 1e-6);
    }

    #[test]
    fn test_average_single_vector() {
        let v = [7.0f32, 8.0, 9.0];
        let result = average_weight_vectors(&[&v]).unwrap();
        assert_eq!(result, vec![7.0, 8.0, 9.0]);
    }

    #[test]
    fn test_average_empty_errors() {
        let result = average_weight_vectors(&[]);
        // Pin the specific variant, not just "some error".
        assert!(matches!(result, Err(RloxError::BufferError(_))));
    }

    #[test]
    fn test_average_length_mismatch() {
        let v1: &[f32] = &[1.0, 2.0, 3.0];
        let v2: &[f32] = &[4.0, 5.0];
        let result = average_weight_vectors(&[v1, v2]);
        assert!(matches!(result, Err(RloxError::ShapeMismatch { .. })));
    }

    #[test]
    fn test_polyak_update_tau_one_copies() {
        let mut target = vec![1.0f32, 2.0];
        let source = vec![3.0f32, 4.0];
        polyak_update(&mut target, &source, 1.0).unwrap();
        assert_eq!(target, vec![3.0, 4.0]);
    }

    #[test]
    fn test_polyak_update_tau_zero_no_change() {
        let mut target = vec![1.0f32, 2.0];
        let source = vec![3.0f32, 4.0];
        polyak_update(&mut target, &source, 0.0).unwrap();
        assert_eq!(target, vec![1.0, 2.0]);
    }

    #[test]
    fn test_polyak_update_tau_half() {
        let mut target = vec![1.0f32, 2.0];
        let source = vec![3.0f32, 4.0];
        polyak_update(&mut target, &source, 0.5).unwrap();
        assert!((target[0] - 2.0).abs() < 1e-6);
        assert!((target[1] - 3.0).abs() < 1e-6);
    }

    #[test]
    fn test_polyak_length_mismatch() {
        let mut target = vec![1.0f32, 2.0];
        let source = vec![3.0f32];
        let result = polyak_update(&mut target, &source, 0.5);
        assert!(matches!(result, Err(RloxError::ShapeMismatch { .. })));
    }

    #[test]
    fn test_trait_object_safety() {
        let update: Box<dyn WeightUpdate> = Box::new(ReptileUpdate);
        assert_eq!(update.name(), "Reptile");
        let update: Box<dyn WeightUpdate> = Box::new(PolyakUpdate);
        assert_eq!(update.name(), "Polyak");
    }

    #[test]
    fn test_trait_apply_dispatches() {
        let updates: [Box<dyn WeightUpdate>; 2] = [Box::new(ReptileUpdate), Box::new(PolyakUpdate)];
        for update in &updates {
            // With a step of 0.5 both rules land exactly on the midpoint.
            let mut target = vec![0.0f32, 2.0];
            update.apply(&mut target, &[2.0, 4.0], 0.5).unwrap();
            assert_eq!(target, vec![1.0, 3.0], "{}", update.name());

            let mut short = vec![0.0f32];
            assert!(
                matches!(
                    update.apply(&mut short, &[1.0, 2.0], 0.5),
                    Err(RloxError::ShapeMismatch { .. })
                ),
                "{}",
                update.name()
            );
        }
    }

    mod proptests {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            #[test]
            fn prop_reptile_interpolates(
                dim in 1usize..100,
                lr in 0.0f32..1.0,
            ) {
                let meta: Vec<f32> = (0..dim).map(|i| i as f32).collect();
                let task: Vec<f32> = (0..dim).map(|i| (i as f32) * 2.0 + 1.0).collect();
                let mut result = meta.clone();
                reptile_update(&mut result, &task, lr).unwrap();
                for i in 0..dim {
                    let lo = meta[i].min(task[i]);
                    let hi = meta[i].max(task[i]);
                    prop_assert!(
                        result[i] >= lo - 1e-6 && result[i] <= hi + 1e-6,
                        "result[{i}]={} not in [{lo}, {hi}]", result[i]
                    );
                }
            }

            #[test]
            fn prop_polyak_interpolates(
                dim in 1usize..100,
                tau in 0.0f32..1.0,
            ) {
                let target: Vec<f32> = (0..dim).map(|i| i as f32).collect();
                let source: Vec<f32> = (0..dim).map(|i| (i as f32) * 3.0).collect();
                let mut result = target.clone();
                polyak_update(&mut result, &source, tau).unwrap();
                for i in 0..dim {
                    let lo = target[i].min(source[i]);
                    let hi = target[i].max(source[i]);
                    prop_assert!(
                        result[i] >= lo - 1e-6 && result[i] <= hi + 1e-6,
                        "result[{i}]={} not in [{lo}, {hi}]", result[i]
                    );
                }
            }

            #[test]
            fn prop_average_idempotent(
                dim in 1usize..50,
                n in 1usize..10,
            ) {
                let v: Vec<f32> = (0..dim).map(|i| i as f32 * 0.7).collect();
                let refs: Vec<&[f32]> = (0..n).map(|_| v.as_slice()).collect();
                let result = average_weight_vectors(&refs).unwrap();
                for i in 0..dim {
                    prop_assert!(
                        (result[i] - v[i]).abs() < 1e-5,
                        "averaging {n} copies: result[{i}]={}, expected {}",
                        result[i], v[i]
                    );
                }
            }
        }
    }
}