1use crate::error::RloxError;
7
8pub trait WeightUpdate: Send + Sync {
15 fn apply(&self, target: &mut [f32], source: &[f32], lr: f32) -> Result<(), RloxError>;
22
23 fn name(&self) -> &str;
25}
26
27pub struct ReptileUpdate;
29
30impl WeightUpdate for ReptileUpdate {
31 #[inline]
32 fn apply(&self, target: &mut [f32], source: &[f32], lr: f32) -> Result<(), RloxError> {
33 reptile_update(target, source, lr)
34 }
35
36 fn name(&self) -> &str {
37 "Reptile"
38 }
39}
40
41pub struct PolyakUpdate;
43
44impl WeightUpdate for PolyakUpdate {
45 #[inline]
46 fn apply(&self, target: &mut [f32], source: &[f32], lr: f32) -> Result<(), RloxError> {
47 polyak_update(target, source, lr)
48 }
49
50 fn name(&self) -> &str {
51 "Polyak"
52 }
53}
54
55#[inline]
61pub fn reptile_update(
62 meta_params: &mut [f32],
63 task_params: &[f32],
64 meta_lr: f32,
65) -> Result<(), RloxError> {
66 if meta_params.len() != task_params.len() {
67 return Err(RloxError::ShapeMismatch {
68 expected: format!("target.len()={}", meta_params.len()),
69 got: format!("source.len()={}", task_params.len()),
70 });
71 }
72
73 for (m, &t) in meta_params.iter_mut().zip(task_params.iter()) {
74 *m += meta_lr * (t - *m);
75 }
76 Ok(())
77}
78
79#[inline]
84pub fn polyak_update(target: &mut [f32], source: &[f32], tau: f32) -> Result<(), RloxError> {
85 if target.len() != source.len() {
86 return Err(RloxError::ShapeMismatch {
87 expected: format!("target.len()={}", target.len()),
88 got: format!("source.len()={}", source.len()),
89 });
90 }
91
92 let one_minus_tau = 1.0 - tau;
93 for (t, &s) in target.iter_mut().zip(source.iter()) {
94 *t = tau * s + one_minus_tau * *t;
95 }
96 Ok(())
97}
98
99pub fn average_weight_vectors(vectors: &[&[f32]]) -> Result<Vec<f32>, RloxError> {
103 if vectors.is_empty() {
104 return Err(RloxError::BufferError(
105 "cannot average zero weight vectors".into(),
106 ));
107 }
108
109 let dim = vectors[0].len();
110 for (i, v) in vectors.iter().enumerate().skip(1) {
111 if v.len() != dim {
112 return Err(RloxError::ShapeMismatch {
113 expected: format!("all vectors length {dim}"),
114 got: format!("vectors[{i}].len()={}", v.len()),
115 });
116 }
117 }
118
119 let n = vectors.len() as f32;
120 let mut result = vec![0.0f32; dim];
121 for v in vectors {
122 for (r, &val) in result.iter_mut().zip(v.iter()) {
123 *r += val;
124 }
125 }
126 for r in &mut result {
127 *r /= n;
128 }
129 Ok(result)
130}
131
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_reptile_update_lr_one_copies() {
        let mut meta = vec![1.0f32, 2.0, 3.0];
        let task = vec![4.0f32, 5.0, 6.0];
        reptile_update(&mut meta, &task, 1.0).unwrap();
        assert_eq!(meta, vec![4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_reptile_update_lr_zero_no_change() {
        let mut meta = vec![1.0f32, 2.0, 3.0];
        let task = vec![4.0f32, 5.0, 6.0];
        reptile_update(&mut meta, &task, 0.0).unwrap();
        assert_eq!(meta, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_reptile_update_lr_half() {
        let mut meta = vec![0.0f32, 0.0];
        let task = vec![2.0f32, 4.0];
        reptile_update(&mut meta, &task, 0.5).unwrap();
        assert!((meta[0] - 1.0).abs() < 1e-6);
        assert!((meta[1] - 2.0).abs() < 1e-6);
    }

    #[test]
    fn test_reptile_length_mismatch() {
        let mut meta = vec![1.0f32, 2.0, 3.0];
        let task = vec![4.0f32, 5.0];
        let result = reptile_update(&mut meta, &task, 0.5);
        assert!(matches!(result, Err(RloxError::ShapeMismatch { .. })));
    }

    #[test]
    fn test_average_weight_vectors_mean() {
        let v1 = [1.0f32, 2.0, 3.0];
        let v2 = [4.0f32, 5.0, 6.0];
        let result = average_weight_vectors(&[&v1, &v2]).unwrap();
        assert!((result[0] - 2.5).abs() < 1e-6);
        assert!((result[1] - 3.5).abs() < 1e-6);
        assert!((result[2] - 4.5).abs() < 1e-6);
    }

    #[test]
    fn test_average_single_vector() {
        let v = [7.0f32, 8.0, 9.0];
        let result = average_weight_vectors(&[&v]).unwrap();
        assert_eq!(result, vec![7.0, 8.0, 9.0]);
    }

    #[test]
    fn test_average_empty_errors() {
        let result = average_weight_vectors(&[]);
        assert!(result.is_err());
    }

    /// Vectors of different lengths must be rejected rather than silently truncated.
    #[test]
    fn test_average_length_mismatch() {
        let v1 = [1.0f32, 2.0];
        let v2 = [3.0f32];
        let result = average_weight_vectors(&[&v1[..], &v2[..]]);
        assert!(matches!(result, Err(RloxError::ShapeMismatch { .. })));
    }

    #[test]
    fn test_polyak_update_tau_one_copies() {
        let mut target = vec![1.0f32, 2.0];
        let source = vec![3.0f32, 4.0];
        polyak_update(&mut target, &source, 1.0).unwrap();
        assert_eq!(target, vec![3.0, 4.0]);
    }

    #[test]
    fn test_polyak_update_tau_zero_no_change() {
        let mut target = vec![1.0f32, 2.0];
        let source = vec![3.0f32, 4.0];
        polyak_update(&mut target, &source, 0.0).unwrap();
        assert_eq!(target, vec![1.0, 2.0]);
    }

    /// An intermediate tau must weight source by tau and target by (1 - tau).
    #[test]
    fn test_polyak_update_tau_quarter() {
        let mut target = vec![0.0f32, 4.0];
        let source = vec![4.0f32, 0.0];
        polyak_update(&mut target, &source, 0.25).unwrap();
        assert!((target[0] - 1.0).abs() < 1e-6);
        assert!((target[1] - 3.0).abs() < 1e-6);
    }

    #[test]
    fn test_polyak_length_mismatch() {
        let mut target = vec![1.0f32, 2.0];
        let source = vec![3.0f32];
        let result = polyak_update(&mut target, &source, 0.5);
        assert!(matches!(result, Err(RloxError::ShapeMismatch { .. })));
    }

    #[test]
    fn test_trait_object_safety() {
        let update: Box<dyn WeightUpdate> = Box::new(ReptileUpdate);
        assert_eq!(update.name(), "Reptile");
        let update: Box<dyn WeightUpdate> = Box::new(PolyakUpdate);
        assert_eq!(update.name(), "Polyak");
    }

    /// `apply` through a trait object must reach the underlying update rule.
    #[test]
    fn test_trait_apply_dispatches() {
        let updates: Vec<Box<dyn WeightUpdate>> =
            vec![Box::new(ReptileUpdate), Box::new(PolyakUpdate)];
        for update in &updates {
            // At rate 0.5 both rules land exactly on the midpoint.
            let mut target = vec![0.0f32, 2.0];
            update.apply(&mut target, &[2.0, 0.0], 0.5).unwrap();
            assert!((target[0] - 1.0).abs() < 1e-6, "{}", update.name());
            assert!((target[1] - 1.0).abs() < 1e-6, "{}", update.name());

            let mut short = vec![0.0f32];
            let result = update.apply(&mut short, &[1.0, 2.0], 0.5);
            assert!(
                matches!(result, Err(RloxError::ShapeMismatch { .. })),
                "{}",
                update.name()
            );
        }
    }

    mod proptests {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            #[test]
            fn prop_reptile_interpolates(
                dim in 1usize..100,
                lr in 0.0f32..1.0,
            ) {
                let meta: Vec<f32> = (0..dim).map(|i| i as f32).collect();
                let task: Vec<f32> = (0..dim).map(|i| (i as f32) * 2.0 + 1.0).collect();
                let mut result = meta.clone();
                reptile_update(&mut result, &task, lr).unwrap();
                for i in 0..dim {
                    let lo = meta[i].min(task[i]);
                    let hi = meta[i].max(task[i]);
                    prop_assert!(
                        result[i] >= lo - 1e-6 && result[i] <= hi + 1e-6,
                        "result[{i}]={} not in [{lo}, {hi}]", result[i]
                    );
                }
            }

            #[test]
            fn prop_polyak_interpolates(
                dim in 1usize..100,
                tau in 0.0f32..1.0,
            ) {
                let target: Vec<f32> = (0..dim).map(|i| i as f32).collect();
                let source: Vec<f32> = (0..dim).map(|i| (i as f32) * 3.0).collect();
                let mut result = target.clone();
                polyak_update(&mut result, &source, tau).unwrap();
                for i in 0..dim {
                    let lo = target[i].min(source[i]);
                    let hi = target[i].max(source[i]);
                    prop_assert!(
                        result[i] >= lo - 1e-6 && result[i] <= hi + 1e-6,
                        "result[{i}]={} not in [{lo}, {hi}]", result[i]
                    );
                }
            }

            #[test]
            fn prop_average_idempotent(
                dim in 1usize..50,
                n in 1usize..10,
            ) {
                let v: Vec<f32> = (0..dim).map(|i| i as f32 * 0.7).collect();
                let refs: Vec<&[f32]> = (0..n).map(|_| v.as_slice()).collect();
                let result = average_weight_vectors(&refs).unwrap();
                for i in 0..dim {
                    prop_assert!(
                        (result[i] - v[i]).abs() < 1e-5,
                        "averaging {n} copies: result[{i}]={}, expected {}",
                        result[i], v[i]
                    );
                }
            }
        }
    }
}