/// Moves `target` toward `source` in place, element by element:
/// `target[i] += lr * (source[i] - target[i])`.
///
/// The slices are processed as fixed blocks of eight elements followed by a
/// scalar tail. No explicit SIMD intrinsics are used; the fixed block width
/// is presumably there to encourage compiler auto-vectorization.
///
/// # Panics
///
/// Panics with "length mismatch" if `target` and `source` differ in length.
#[inline]
pub fn reptile_update_simd(target: &mut [f32], source: &[f32], lr: f32) {
    const LANES: usize = 8;

    assert_eq!(target.len(), source.len(), "length mismatch");

    let mut dst_blocks = target.chunks_exact_mut(LANES);
    let src_blocks = source.chunks_exact(LANES);
    // Grab the source tail before `src_blocks` is moved into the zip below.
    let src_tail = src_blocks.remainder();

    // Full-width blocks.
    for (dst, src) in dst_blocks.by_ref().zip(src_blocks) {
        for (t, &s) in dst.iter_mut().zip(src) {
            *t += lr * (s - *t);
        }
    }

    // Fewer than `LANES` trailing elements.
    for (t, &s) in dst_blocks.into_remainder().iter_mut().zip(src_tail) {
        *t += lr * (s - *t);
    }
}
46
/// Blends `source` into `target` in place, element by element:
/// `target[i] = tau * source[i] + (1 - tau) * target[i]`.
///
/// The bulk of the data is handled in blocks of eight elements, with the
/// leftover elements handled one at a time. No explicit SIMD intrinsics are
/// used; the fixed block width is presumably a hint for auto-vectorization.
///
/// # Panics
///
/// Panics with "length mismatch" if `target` and `source` differ in length.
#[inline]
pub fn polyak_update_simd(target: &mut [f32], source: &[f32], tau: f32) {
    const LANES: usize = 8;

    assert_eq!(target.len(), source.len(), "length mismatch");

    // Weight applied to the existing target value; computed once up front.
    let retain = 1.0 - tau;

    // Largest multiple of `LANES` that fits in the slices.
    let bulk_len = target.len() - target.len() % LANES;
    let (dst_bulk, dst_tail) = target.split_at_mut(bulk_len);
    let (src_bulk, src_tail) = source.split_at(bulk_len);

    dst_bulk
        .chunks_exact_mut(LANES)
        .zip(src_bulk.chunks_exact(LANES))
        .for_each(|(dst, src)| {
            for (t, &s) in dst.iter_mut().zip(src) {
                *t = tau * s + retain * *t;
            }
        });

    for (t, &s) in dst_tail.iter_mut().zip(src_tail) {
        *t = tau * s + retain * *t;
    }
}
78
/// Computes shaped rewards element by element:
/// `rewards[i] + (1 - dones[i]) * (gamma * phi_next[i] - phi_current[i])`.
///
/// `dones` is used purely as a numeric multiplier, so a value of `1.0`
/// removes the shaping term for that element and `0.0` keeps it in full.
///
/// Returns a freshly allocated vector with one entry per element of
/// `rewards`.
///
/// # Panics
///
/// Panics if `phi_current`, `phi_next` or `dones` does not have the same
/// length as `rewards`.
#[inline]
pub fn pbrs_simd(
    rewards: &[f64],
    phi_current: &[f64],
    phi_next: &[f64],
    gamma: f64,
    dones: &[f64],
) -> Vec<f64> {
    let n = rewards.len();
    assert_eq!(phi_current.len(), n);
    assert_eq!(phi_next.len(), n);
    assert_eq!(dones.len(), n);

    // All four slices have length `n` here, so the zip visits every element.
    rewards
        .iter()
        .zip(phi_current)
        .zip(phi_next)
        .zip(dones)
        .map(|(((&reward, &phi_now), &phi_after), &done)| {
            let keep = 1.0 - done;
            let shaping = keep * (gamma * phi_after - phi_now);
            reward + shaping
        })
        .collect()
}
132
/// Maps each loss to `|loss| + epsilon`.
///
/// Returns a freshly allocated vector with the same length as `losses`; an
/// empty input yields an empty vector.
#[inline]
pub fn compute_priorities_simd(losses: &[f64], epsilon: f64) -> Vec<f64> {
    losses.iter().map(|loss| loss.abs() + epsilon).collect()
}
165
/// Returns the element-wise arithmetic mean of `vectors`.
///
/// Every input is added, in the order given, into a zero-initialised `f32`
/// accumulator; each accumulated element is then divided by the number of
/// inputs. Accumulation walks the data in blocks of eight elements followed
/// by a scalar tail.
///
/// # Panics
///
/// * With "cannot average zero vectors" if `vectors` is empty.
/// * With "all vectors must have the same length" if any input's length
///   differs from that of the first input.
#[inline]
pub fn average_weights_simd(vectors: &[&[f32]]) -> Vec<f32> {
    const LANES: usize = 8;

    assert!(!vectors.is_empty(), "cannot average zero vectors");
    let dim = vectors[0].len();
    for other in &vectors[1..] {
        assert_eq!(other.len(), dim, "all vectors must have the same length");
    }

    // Sum phase: `mean` holds running sums until the final division.
    let mut mean = vec![0.0f32; dim];
    for input in vectors {
        let mut acc_blocks = mean.chunks_exact_mut(LANES);
        let in_blocks = input.chunks_exact(LANES);
        // Grab the input tail before `in_blocks` is moved into the zip below.
        let in_tail = in_blocks.remainder();

        for (acc, vals) in acc_blocks.by_ref().zip(in_blocks) {
            for (a, &v) in acc.iter_mut().zip(vals) {
                *a += v;
            }
        }

        for (a, &v) in acc_blocks.into_remainder().iter_mut().zip(in_tail) {
            *a += v;
        }
    }

    // Divide phase: turn the sums into means.
    let count = vectors.len() as f32;
    for a in &mut mean {
        *a /= count;
    }

    mean
}
219
/// Copies every element of `src` into `dst`.
///
/// # Panics
///
/// Panics if the two slices differ in length: the debug assertion fires in
/// builds with debug assertions enabled, and `copy_from_slice` itself panics
/// on a length mismatch otherwise.
#[inline]
pub fn copy_pixel_row(dst: &mut [f32], src: &[f32]) {
    let (dst_len, src_len) = (dst.len(), src.len());
    debug_assert_eq!(dst_len, src_len);
    dst.copy_from_slice(src);
}
233
#[cfg(test)]
mod tests {
    use super::*;

    // ---- reptile_update_simd -------------------------------------------

    /// With `lr = 1.0` the update lands exactly on `source`. Nine elements
    /// exercise one full block of eight plus a one-element tail.
    #[test]
    fn reptile_simd_lr_one_copies() {
        let source: Vec<f32> = vec![10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0];
        let mut target: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        reptile_update_simd(&mut target, &source, 1.0);
        assert_eq!(target, source);
    }

    /// With `lr = 0.0` the target must come back untouched.
    #[test]
    fn reptile_simd_lr_zero_no_change() {
        let source = vec![10.0f32, 20.0, 30.0];
        let mut target = vec![1.0f32, 2.0, 3.0];
        let before = target.clone();
        reptile_update_simd(&mut target, &source, 0.0);
        assert_eq!(target, before);
    }

    /// The blocked implementation agrees with a plain element-wise loop.
    #[test]
    fn reptile_simd_matches_scalar() {
        let lr = 0.3;
        let start = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let source: Vec<f32> = (0..10).map(|i| i as f32 * 3.0 + 0.5).collect();

        let mut target_simd = start.clone();
        reptile_update_simd(&mut target_simd, &source, lr);

        // Straight-line reference computation.
        let mut target_scalar = start;
        target_scalar
            .iter_mut()
            .zip(source.iter())
            .for_each(|(t, &s)| *t += lr * (s - *t));

        for (i, (a, b)) in target_simd.iter().zip(target_scalar.iter()).enumerate() {
            assert!(
                (a - b).abs() < 1e-6,
                "mismatch at index {i}: simd={a}, scalar={b}"
            );
        }
    }

    /// Empty slices are accepted and stay empty.
    #[test]
    fn reptile_simd_empty_slices() {
        let source: Vec<f32> = Vec::new();
        let mut target: Vec<f32> = Vec::new();
        reptile_update_simd(&mut target, &source, 0.5);
        assert!(target.is_empty());
    }

    // ---- polyak_update_simd --------------------------------------------

    /// With `tau = 1.0` the target becomes a copy of `source`.
    #[test]
    fn polyak_simd_tau_one_copies() {
        let source = vec![5.0f32; 9];
        let mut target = vec![1.0f32; 9];
        polyak_update_simd(&mut target, &source, 1.0);
        assert_eq!(target, source);
    }

    /// With `tau = 0.0` the target must come back untouched.
    #[test]
    fn polyak_simd_tau_zero_no_change() {
        let source = vec![99.0f32; 5];
        let mut target = vec![3.0f32; 5];
        let before = target.clone();
        polyak_update_simd(&mut target, &source, 0.0);
        assert_eq!(target, before);
    }

    /// The blocked implementation agrees with a plain element-wise loop
    /// (17 elements: two full blocks plus a one-element tail).
    #[test]
    fn polyak_simd_matches_scalar() {
        let tau = 0.005;
        let start: Vec<f32> = (0..17).map(|i| i as f32 * 0.7).collect();
        let source: Vec<f32> = (0..17).map(|i| i as f32 * 2.1 + 0.3).collect();

        let mut target_simd = start.clone();
        polyak_update_simd(&mut target_simd, &source, tau);

        // Straight-line reference computation.
        let one_minus_tau = 1.0 - tau;
        let mut target_scalar = start;
        target_scalar
            .iter_mut()
            .zip(source.iter())
            .for_each(|(t, &s)| *t = tau * s + one_minus_tau * *t);

        for (i, (a, b)) in target_simd.iter().zip(target_scalar.iter()).enumerate() {
            assert!(
                (a - b).abs() < 1e-5,
                "mismatch at index {i}: simd={a}, scalar={b}"
            );
        }
    }

    // ---- pbrs_simd -----------------------------------------------------

    /// Hand-computed expectations with no terminal transitions:
    /// `1.0 + 0.99 * 0.3 - 0.5 = 0.797` and `2.0 + 0.99 * 0.8 - 0.3 = 2.492`.
    #[test]
    fn pbrs_simd_known_values() {
        let rewards = [1.0, 2.0];
        let phi = [0.5, 0.3];
        let phi_next = [0.3, 0.8];
        let dones = [0.0, 0.0];
        let gamma = 0.99;
        let result = pbrs_simd(&rewards, &phi, &phi_next, gamma, &dones);
        assert!((result[0] - 0.797).abs() < 1e-10, "got {}", result[0]);
        assert!((result[1] - 2.492).abs() < 1e-10, "got {}", result[1]);
    }

    /// A `done` flag of `1.0` removes the shaping term, leaving the raw reward.
    #[test]
    fn pbrs_simd_done_zeroes_shaping() {
        let rewards = [1.0, 2.0];
        let phi = [0.5, 0.3];
        let phi_next = [0.3, 0.8];
        let dones = [0.0, 1.0];
        let gamma = 0.99;
        let result = pbrs_simd(&rewards, &phi, &phi_next, gamma, &dones);
        assert!((result[1] - 2.0).abs() < 1e-10, "got {}", result[1]);
    }

    /// Cross-checks against the crate's `shape_rewards_pbrs` on 13 elements
    /// (three full blocks of four plus a one-element tail).
    #[test]
    fn pbrs_simd_matches_scalar() {
        use crate::training::reward_shaping::shape_rewards_pbrs;

        let n = 13;
        let gamma = 0.99;
        let rewards: Vec<f64> = (0..n).map(|i| i as f64 * 0.3).collect();
        let phi: Vec<f64> = (0..n).map(|i| i as f64 * 0.1 + 0.05).collect();
        let phi_next: Vec<f64> = (0..n).map(|i| i as f64 * 0.2 - 0.1).collect();
        // Exactly one terminal transition, at index 5.
        let dones: Vec<f64> = (0..n).map(|i| if i == 5 { 1.0 } else { 0.0 }).collect();

        let simd_result = pbrs_simd(&rewards, &phi, &phi_next, gamma, &dones);
        let scalar_result = shape_rewards_pbrs(&rewards, &phi, &phi_next, gamma, &dones).unwrap();

        for (i, (a, b)) in simd_result.iter().zip(scalar_result.iter()).enumerate() {
            assert!(
                (a - b).abs() < 1e-10,
                "mismatch at index {i}: simd={a}, scalar={b}"
            );
        }
    }

    // ---- compute_priorities_simd ---------------------------------------

    /// Each priority is `|loss| + epsilon`.
    #[test]
    fn priorities_simd_abs_plus_epsilon() {
        let losses = [-3.0, 0.0, 2.5, -0.1];
        let eps = 0.01;
        let result = compute_priorities_simd(&losses, eps);
        assert!((result[0] - 3.01).abs() < 1e-10);
        assert!((result[1] - 0.01).abs() < 1e-10);
        assert!((result[2] - 2.51).abs() < 1e-10);
        assert!((result[3] - 0.11).abs() < 1e-10);
    }

    /// The blocked implementation agrees with a plain `map` over the losses.
    #[test]
    fn priorities_simd_matches_scalar() {
        let eps = 0.001;
        let losses: Vec<f64> = (0..11).map(|i| (i as f64 - 5.0) * 1.7).collect();

        let simd_result = compute_priorities_simd(&losses, eps);
        let scalar_result: Vec<f64> = losses.iter().map(|l| l.abs() + eps).collect();

        for (i, (a, b)) in simd_result.iter().zip(scalar_result.iter()).enumerate() {
            assert!(
                (a - b).abs() < 1e-12,
                "mismatch at index {i}: simd={a}, scalar={b}"
            );
        }
    }

    // ---- average_weights_simd ------------------------------------------

    /// Averaging a single vector returns that vector.
    #[test]
    fn average_weights_simd_single_vector() {
        let v = vec![1.0f32, 2.0, 3.0];
        let result = average_weights_simd(&[v.as_slice()]);
        assert_eq!(result, v);
    }

    /// Two mirrored ramps average to 5.0 in every position.
    #[test]
    fn average_weights_simd_two_vectors() {
        let v1 = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        let v2 = [9.0f32, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0];
        let result = average_weights_simd(&[&v1[..], &v2[..]]);
        for r in result.iter().copied() {
            assert!((r - 5.0).abs() < 1e-5, "expected 5.0, got {r}");
        }
    }

    // ---- copy_pixel_row ------------------------------------------------

    /// The destination ends up equal to the source.
    #[test]
    fn copy_pixel_row_works() {
        let src = [1.0f32, 2.0, 3.0, 4.0];
        let mut dst = [0.0f32; 4];
        copy_pixel_row(&mut dst, &src);
        assert_eq!(dst, src);
    }

    /// Property tests comparing each kernel with a straightforward scalar
    /// reference over a range of input lengths.
    mod proptests {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            #[test]
            fn prop_reptile_simd_matches_scalar(
                dim in 1usize..500,
                lr in 0.0f32..1.0,
            ) {
                let start: Vec<f32> = (0..dim).map(|i| (i as f32) * 0.7).collect();
                let source: Vec<f32> = (0..dim).map(|i| (i as f32) * 2.1 + 0.3).collect();

                let mut target_simd = start.clone();
                reptile_update_simd(&mut target_simd, &source, lr);

                let target_scalar: Vec<f32> = start
                    .iter()
                    .zip(source.iter())
                    .map(|(&t, &s)| t + lr * (s - t))
                    .collect();

                for (i, (a, b)) in target_simd.iter().zip(target_scalar.iter()).enumerate() {
                    prop_assert!(
                        (a - b).abs() < 1e-4,
                        "reptile mismatch at {i}: simd={a}, scalar={b}"
                    );
                }
            }

            #[test]
            fn prop_polyak_simd_matches_scalar(
                dim in 1usize..500,
                tau in 0.0f32..1.0,
            ) {
                let start: Vec<f32> = (0..dim).map(|i| i as f32).collect();
                let source: Vec<f32> = (0..dim).map(|i| (i as f32) * 3.0).collect();

                let mut target_simd = start.clone();
                polyak_update_simd(&mut target_simd, &source, tau);

                let one_minus_tau = 1.0 - tau;
                let target_scalar: Vec<f32> = start
                    .iter()
                    .zip(source.iter())
                    .map(|(&t, &s)| tau * s + one_minus_tau * t)
                    .collect();

                for (i, (a, b)) in target_simd.iter().zip(target_scalar.iter()).enumerate() {
                    prop_assert!(
                        (a - b).abs() < 1e-4,
                        "polyak mismatch at {i}: simd={a}, scalar={b}"
                    );
                }
            }

            #[test]
            fn prop_pbrs_simd_matches_scalar(n in 1usize..200) {
                let gamma = 0.99;
                let rewards: Vec<f64> = (0..n).map(|i| i as f64 * 0.3).collect();
                let phi: Vec<f64> = (0..n).map(|i| i as f64 * 0.1).collect();
                let phi_next: Vec<f64> = (0..n).map(|i| i as f64 * 0.2).collect();
                // Every seventh transition (starting at index 0) is terminal.
                let dones: Vec<f64> = (0..n).map(|i| if i % 7 == 0 { 1.0 } else { 0.0 }).collect();

                let simd_result = pbrs_simd(&rewards, &phi, &phi_next, gamma, &dones);

                // Branching reference: terminal steps keep the raw reward.
                let scalar_result: Vec<f64> = (0..n)
                    .map(|i| {
                        if dones[i] == 1.0 {
                            rewards[i]
                        } else {
                            rewards[i] + gamma * phi_next[i] - phi[i]
                        }
                    })
                    .collect();

                for (i, (a, b)) in simd_result.iter().zip(scalar_result.iter()).enumerate() {
                    prop_assert!(
                        (a - b).abs() < 1e-10,
                        "pbrs mismatch at {i}: simd={a}, scalar={b}"
                    );
                }
            }

            #[test]
            fn prop_priorities_simd_matches_scalar(n in 1usize..200) {
                let eps = 0.01;
                let losses: Vec<f64> = (0..n).map(|i| (i as f64 - 50.0) * 1.3).collect();

                let simd_result = compute_priorities_simd(&losses, eps);
                let scalar_result: Vec<f64> = losses.iter().map(|l| l.abs() + eps).collect();

                for (i, (a, b)) in simd_result.iter().zip(scalar_result.iter()).enumerate() {
                    prop_assert!(
                        (a - b).abs() < 1e-12,
                        "priority mismatch at {i}: simd={a}, scalar={b}"
                    );
                }
            }

            #[test]
            fn prop_average_weights_simd_matches_scalar(
                dim in 1usize..100,
                num_vecs in 1usize..10,
            ) {
                let vectors: Vec<Vec<f32>> = (0..num_vecs)
                    .map(|v| (0..dim).map(|i| (v * dim + i) as f32 * 0.1).collect())
                    .collect();
                let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();

                let simd_result = average_weights_simd(&refs);

                // Reference: sum every vector into a zeroed buffer, then divide.
                let mut scalar_result = vec![0.0f32; dim];
                for v in &vectors {
                    scalar_result
                        .iter_mut()
                        .zip(v.iter())
                        .for_each(|(r, &val)| *r += val);
                }
                let n = num_vecs as f32;
                scalar_result.iter_mut().for_each(|r| *r /= n);

                for (i, (a, b)) in simd_result.iter().zip(scalar_result.iter()).enumerate() {
                    prop_assert!(
                        (a - b).abs() < 1e-3,
                        "average mismatch at {i}: simd={a}, scalar={b}"
                    );
                }
            }
        }
    }
}