// rlox_core/training/simd_ops.rs

//! SIMD-accelerated operations for training hot loops.
//!
//! Provides vectorized versions of weight updates, reward shaping, and
//! priority computation. Gated behind the `simd` feature flag.
//!
//! Strategy: use `chunks_exact` with manual unrolling to give LLVM strong
//! auto-vectorization hints. On x86_64, the compiler will emit AVX2/SSE
//! instructions for these patterns. No nightly features or `std::simd` required.

// ---------------------------------------------------------------------------
// Weight operations (f32)
// ---------------------------------------------------------------------------

/// SIMD-friendly Reptile update: `target[i] += lr * (source[i] - target[i])`
///
/// Processes 8 f32s at a time via `chunks_exact`, giving LLVM a clear
/// vectorization opportunity (256-bit AVX2 = 8 x f32).
///
/// # Panics
/// Panics if `target.len() != source.len()`.
#[inline]
pub fn reptile_update_simd(target: &mut [f32], source: &[f32], lr: f32) {
    assert_eq!(target.len(), source.len(), "length mismatch");

    // Walk both slices in fixed 8-wide lanes; the guaranteed lane width lets
    // LLVM hoist bounds checks and emit packed instructions.
    let mut t_lanes = target.chunks_exact_mut(8);
    for (t_lane, s_lane) in t_lanes.by_ref().zip(source.chunks_exact(8)) {
        for (t, &s) in t_lane.iter_mut().zip(s_lane) {
            *t += lr * (s - *t);
        }
    }

    // Scalar tail: the final len % 8 elements.
    for (t, &s) in t_lanes
        .into_remainder()
        .iter_mut()
        .zip(source.chunks_exact(8).remainder())
    {
        *t += lr * (s - *t);
    }
}

/// SIMD-friendly Polyak update: `target[i] = tau * source[i] + (1 - tau) * target[i]`
///
/// Processes 8 f32s at a time.
///
/// # Panics
/// Panics if `target.len() != source.len()`.
#[inline]
pub fn polyak_update_simd(target: &mut [f32], source: &[f32], tau: f32) {
    assert_eq!(target.len(), source.len(), "length mismatch");
    // Hoist the constant factor out of the loops.
    let one_minus_tau = 1.0 - tau;

    // Fixed 8-wide lanes give LLVM a clean auto-vectorization target.
    let mut t_lanes = target.chunks_exact_mut(8);
    for (t_lane, s_lane) in t_lanes.by_ref().zip(source.chunks_exact(8)) {
        for (t, &s) in t_lane.iter_mut().zip(s_lane) {
            *t = tau * s + one_minus_tau * *t;
        }
    }

    // Scalar tail: the final len % 8 elements.
    for (t, &s) in t_lanes
        .into_remainder()
        .iter_mut()
        .zip(source.chunks_exact(8).remainder())
    {
        *t = tau * s + one_minus_tau * *t;
    }
}

// ---------------------------------------------------------------------------
// Reward shaping (f64)
// ---------------------------------------------------------------------------

/// SIMD-friendly PBRS: `result[i] = rewards[i] + gamma * phi_next[i] - phi_current[i]`
///
/// At episode boundaries (`dones[i] == 1.0`), returns the raw reward.
/// The shaping term is applied branchlessly so LLVM can vectorize the whole
/// loop (256-bit AVX2 = 4 x f64 per iteration).
///
/// # Panics
/// Panics if all slices are not the same length.
#[inline]
pub fn pbrs_simd(
    rewards: &[f64],
    phi_current: &[f64],
    phi_next: &[f64],
    gamma: f64,
    dones: &[f64],
) -> Vec<f64> {
    let n = rewards.len();
    assert_eq!(phi_current.len(), n);
    assert_eq!(phi_next.len(), n);
    assert_eq!(dones.len(), n);

    // Zipped slice iterators avoid per-element bounds checks entirely and
    // carry an exact size hint, so `collect` allocates the output once.
    rewards
        .iter()
        .zip(phi_current)
        .zip(phi_next)
        .zip(dones)
        .map(|(((&r, &pc), &pn), &d)| {
            // Branchless mask: done==1.0 -> shaping=0,
            // done==0.0 -> shaping = gamma*phi_next - phi_current.
            let not_done = 1.0 - d;
            r + not_done * (gamma * pn - pc)
        })
        .collect()
}

// ---------------------------------------------------------------------------
// Priority computation (f64)
// ---------------------------------------------------------------------------

/// SIMD-friendly LAP priority: `priority[i] = |loss[i]| + epsilon`
///
/// The `abs + add` portion is trivially vectorizable. The subsequent
/// `powf(alpha)` is left to the caller (not SIMD-friendly).
///
/// A plain `map` + `collect` compiles to the same packed abs/add loop as a
/// manually chunked version, with no per-element bounds checks and a single
/// exact-size allocation (the slice iterator's size hint is preserved).
#[inline]
pub fn compute_priorities_simd(losses: &[f64], epsilon: f64) -> Vec<f64> {
    losses.iter().map(|l| l.abs() + epsilon).collect()
}

/// SIMD-friendly weight vector averaging: `result[i] = sum(vectors[j][i]) / n`
///
/// Accumulates across vectors using chunks of 8 f32s for the inner loop,
/// then scales the accumulator once at the end.
///
/// # Panics
/// Panics if `vectors` is empty or vectors have different lengths.
#[inline]
pub fn average_weights_simd(vectors: &[&[f32]]) -> Vec<f32> {
    assert!(!vectors.is_empty(), "cannot average zero vectors");
    let dim = vectors[0].len();
    for v in vectors.iter().skip(1) {
        assert_eq!(v.len(), dim, "all vectors must have the same length");
    }

    let n = vectors.len() as f32;
    let mut result = vec![0.0f32; dim];

    // Accumulate each vector into `result`, 8 lanes at a time. The fixed
    // lane width lets LLVM drop bounds checks and emit packed adds.
    for v in vectors {
        let mut r_lanes = result.chunks_exact_mut(8);
        for (r_lane, v_lane) in r_lanes.by_ref().zip(v.chunks_exact(8)) {
            for (r, &x) in r_lane.iter_mut().zip(v_lane) {
                *r += x;
            }
        }
        // Scalar tail: the final dim % 8 elements.
        for (r, &x) in r_lanes
            .into_remainder()
            .iter_mut()
            .zip(v.chunks_exact(8).remainder())
        {
            *r += x;
        }
    }

    // Divide by n once; a straight slice iteration auto-vectorizes.
    for r in &mut result {
        *r /= n;
    }

    result
}

// ---------------------------------------------------------------------------
// Image augmentation helper
// ---------------------------------------------------------------------------

/// Copy a contiguous row of pixels using `copy_from_slice` (auto-vectorizes
/// to SIMD memcpy on all targets).
///
/// This is a thin wrapper that makes the intent explicit for the compiler.
///
/// # Panics
/// Panics if `dst.len() != src.len()`. `copy_from_slice` enforces this in
/// all build profiles; the `debug_assert_eq!` only provides an earlier,
/// clearer failure in debug builds.
#[inline]
pub fn copy_pixel_row(dst: &mut [f32], src: &[f32]) {
    debug_assert_eq!(dst.len(), src.len());
    dst.copy_from_slice(src);
}

#[cfg(test)]
mod tests {
    use super::*;

    // -----------------------------------------------------------------------
    // Reptile SIMD tests
    // -----------------------------------------------------------------------

    #[test]
    fn reptile_simd_lr_one_copies() {
        // lr == 1.0 makes the update a straight copy: t += 1.0 * (s - t) == s.
        // 9 elements: one full 8-lane chunk plus a 1-element scalar remainder.
        let mut target = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        let source = vec![10.0f32, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0];
        reptile_update_simd(&mut target, &source, 1.0);
        assert_eq!(target, source);
    }

    #[test]
    fn reptile_simd_lr_zero_no_change() {
        // lr == 0.0 must leave the target untouched (remainder-only path: 3 < 8).
        let original = vec![1.0f32, 2.0, 3.0];
        let mut target = original.clone();
        let source = vec![10.0f32, 20.0, 30.0];
        reptile_update_simd(&mut target, &source, 0.0);
        assert_eq!(target, original);
    }

    #[test]
    fn reptile_simd_matches_scalar() {
        // 10 elements: exercises one full chunk and a 2-element remainder.
        let mut target_simd = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let mut target_scalar = target_simd.clone();
        let source: Vec<f32> = (0..10).map(|i| i as f32 * 3.0 + 0.5).collect();
        let lr = 0.3;

        reptile_update_simd(&mut target_simd, &source, lr);

        // Scalar reference
        for (t, &s) in target_scalar.iter_mut().zip(source.iter()) {
            *t += lr * (s - *t);
        }

        for (i, (a, b)) in target_simd.iter().zip(target_scalar.iter()).enumerate() {
            assert!(
                (a - b).abs() < 1e-6,
                "mismatch at index {i}: simd={a}, scalar={b}"
            );
        }
    }

    #[test]
    fn reptile_simd_empty_slices() {
        // Degenerate case: zero chunks, zero remainder — must not panic.
        let mut target: Vec<f32> = vec![];
        let source: Vec<f32> = vec![];
        reptile_update_simd(&mut target, &source, 0.5);
        assert!(target.is_empty());
    }

    // -----------------------------------------------------------------------
    // Polyak SIMD tests
    // -----------------------------------------------------------------------

    #[test]
    fn polyak_simd_tau_one_copies() {
        // tau == 1.0 replaces target with source entirely.
        let mut target = vec![1.0f32; 9];
        let source = vec![5.0f32; 9];
        polyak_update_simd(&mut target, &source, 1.0);
        assert_eq!(target, source);
    }

    #[test]
    fn polyak_simd_tau_zero_no_change() {
        // tau == 0.0 leaves target untouched.
        let original = vec![3.0f32; 5];
        let mut target = original.clone();
        let source = vec![99.0f32; 5];
        polyak_update_simd(&mut target, &source, 0.0);
        assert_eq!(target, original);
    }

    #[test]
    fn polyak_simd_matches_scalar() {
        // 17 elements: two full 8-lane chunks plus a 1-element remainder.
        // tau = 0.005 mimics a typical soft-update coefficient.
        let mut target_simd: Vec<f32> = (0..17).map(|i| i as f32 * 0.7).collect();
        let mut target_scalar = target_simd.clone();
        let source: Vec<f32> = (0..17).map(|i| i as f32 * 2.1 + 0.3).collect();
        let tau = 0.005;

        polyak_update_simd(&mut target_simd, &source, tau);

        let one_minus_tau = 1.0 - tau;
        for (t, &s) in target_scalar.iter_mut().zip(source.iter()) {
            *t = tau * s + one_minus_tau * *t;
        }

        for (i, (a, b)) in target_simd.iter().zip(target_scalar.iter()).enumerate() {
            assert!(
                (a - b).abs() < 1e-5,
                "mismatch at index {i}: simd={a}, scalar={b}"
            );
        }
    }

    // -----------------------------------------------------------------------
    // PBRS SIMD tests
    // -----------------------------------------------------------------------

    #[test]
    fn pbrs_simd_known_values() {
        // Hand-computed: r + gamma*phi_next - phi_current
        // [0]: 1.0 + 0.99*0.3 - 0.5 = 0.797; [1]: 2.0 + 0.99*0.8 - 0.3 = 2.492
        let rewards = &[1.0, 2.0];
        let phi = &[0.5, 0.3];
        let phi_next = &[0.3, 0.8];
        let gamma = 0.99;
        let dones = &[0.0, 0.0];
        let result = pbrs_simd(rewards, phi, phi_next, gamma, dones);
        assert!((result[0] - 0.797).abs() < 1e-10, "got {}", result[0]);
        assert!((result[1] - 2.492).abs() < 1e-10, "got {}", result[1]);
    }

    #[test]
    fn pbrs_simd_done_zeroes_shaping() {
        // done == 1.0 at index 1 must return the raw reward there.
        let rewards = &[1.0, 2.0];
        let phi = &[0.5, 0.3];
        let phi_next = &[0.3, 0.8];
        let gamma = 0.99;
        let dones = &[0.0, 1.0];
        let result = pbrs_simd(rewards, phi, phi_next, gamma, dones);
        assert!((result[1] - 2.0).abs() < 1e-10, "got {}", result[1]);
    }

    #[test]
    fn pbrs_simd_matches_scalar() {
        // Compares against the scalar reference implementation elsewhere in
        // the crate (reward_shaping module).
        use crate::training::reward_shaping::shape_rewards_pbrs;

        let n = 13; // non-multiple of 4 to test remainder
        let rewards: Vec<f64> = (0..n).map(|i| i as f64 * 0.3).collect();
        let phi: Vec<f64> = (0..n).map(|i| i as f64 * 0.1 + 0.05).collect();
        let phi_next: Vec<f64> = (0..n).map(|i| i as f64 * 0.2 - 0.1).collect();
        let gamma = 0.99;
        let dones: Vec<f64> = (0..n).map(|i| if i == 5 { 1.0 } else { 0.0 }).collect();

        let simd_result = pbrs_simd(&rewards, &phi, &phi_next, gamma, &dones);
        let scalar_result = shape_rewards_pbrs(&rewards, &phi, &phi_next, gamma, &dones).unwrap();

        for (i, (a, b)) in simd_result.iter().zip(scalar_result.iter()).enumerate() {
            assert!(
                (a - b).abs() < 1e-10,
                "mismatch at index {i}: simd={a}, scalar={b}"
            );
        }
    }

    // -----------------------------------------------------------------------
    // Priority SIMD tests
    // -----------------------------------------------------------------------

    #[test]
    fn priorities_simd_abs_plus_epsilon() {
        // Covers negative, zero, positive, and small-magnitude losses.
        let losses = &[-3.0, 0.0, 2.5, -0.1];
        let eps = 0.01;
        let result = compute_priorities_simd(losses, eps);
        assert!((result[0] - 3.01).abs() < 1e-10);
        assert!((result[1] - 0.01).abs() < 1e-10);
        assert!((result[2] - 2.51).abs() < 1e-10);
        assert!((result[3] - 0.11).abs() < 1e-10);
    }

    #[test]
    fn priorities_simd_matches_scalar() {
        // 11 elements (non-multiple of 4) spanning negative and positive losses.
        let losses: Vec<f64> = (0..11).map(|i| (i as f64 - 5.0) * 1.7).collect();
        let eps = 0.001;

        let simd_result = compute_priorities_simd(&losses, eps);
        let scalar_result: Vec<f64> = losses.iter().map(|l| l.abs() + eps).collect();

        for (i, (a, b)) in simd_result.iter().zip(scalar_result.iter()).enumerate() {
            assert!(
                (a - b).abs() < 1e-12,
                "mismatch at index {i}: simd={a}, scalar={b}"
            );
        }
    }

    // -----------------------------------------------------------------------
    // Average weights SIMD tests
    // -----------------------------------------------------------------------

    #[test]
    fn average_weights_simd_single_vector() {
        // Averaging one vector is the identity.
        let v = vec![1.0f32, 2.0, 3.0];
        let result = average_weights_simd(&[&v]);
        assert_eq!(result, v);
    }

    #[test]
    fn average_weights_simd_two_vectors() {
        // Each position sums to 10.0, so every averaged element is 5.0.
        // 9 elements exercises one full 8-lane chunk plus the remainder.
        let v1 = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        let v2 = [9.0f32, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0];
        let result = average_weights_simd(&[&v1, &v2]);
        for &r in &result {
            assert!((r - 5.0).abs() < 1e-5, "expected 5.0, got {r}");
        }
    }

    // -----------------------------------------------------------------------
    // copy_pixel_row test
    // -----------------------------------------------------------------------

    #[test]
    fn copy_pixel_row_works() {
        let src = [1.0f32, 2.0, 3.0, 4.0];
        let mut dst = [0.0f32; 4];
        copy_pixel_row(&mut dst, &src);
        assert_eq!(dst, src);
    }

    // -----------------------------------------------------------------------
    // Proptests: SIMD == scalar for random inputs
    // -----------------------------------------------------------------------

    // Property-based equivalence checks: for randomized dimensions and
    // coefficients, every SIMD routine must agree with a straightforward
    // scalar implementation within floating-point tolerance.
    mod proptests {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            #[test]
            fn prop_reptile_simd_matches_scalar(
                dim in 1usize..500,
                lr in 0.0f32..1.0,
            ) {
                let source: Vec<f32> = (0..dim).map(|i| (i as f32) * 2.1 + 0.3).collect();
                let original: Vec<f32> = (0..dim).map(|i| (i as f32) * 0.7).collect();

                let mut target_simd = original.clone();
                reptile_update_simd(&mut target_simd, &source, lr);

                let mut target_scalar = original.clone();
                for (t, &s) in target_scalar.iter_mut().zip(source.iter()) {
                    *t += lr * (s - *t);
                }

                for (i, (a, b)) in target_simd.iter().zip(target_scalar.iter()).enumerate() {
                    prop_assert!(
                        (a - b).abs() < 1e-4,
                        "reptile mismatch at {i}: simd={a}, scalar={b}"
                    );
                }
            }

            #[test]
            fn prop_polyak_simd_matches_scalar(
                dim in 1usize..500,
                tau in 0.0f32..1.0,
            ) {
                let source: Vec<f32> = (0..dim).map(|i| (i as f32) * 3.0).collect();
                let original: Vec<f32> = (0..dim).map(|i| i as f32).collect();

                let mut target_simd = original.clone();
                polyak_update_simd(&mut target_simd, &source, tau);

                let one_minus_tau = 1.0 - tau;
                let mut target_scalar = original.clone();
                for (t, &s) in target_scalar.iter_mut().zip(source.iter()) {
                    *t = tau * s + one_minus_tau * *t;
                }

                for (i, (a, b)) in target_simd.iter().zip(target_scalar.iter()).enumerate() {
                    prop_assert!(
                        (a - b).abs() < 1e-4,
                        "polyak mismatch at {i}: simd={a}, scalar={b}"
                    );
                }
            }

            #[test]
            fn prop_pbrs_simd_matches_scalar(n in 1usize..200) {
                let rewards: Vec<f64> = (0..n).map(|i| i as f64 * 0.3).collect();
                let phi: Vec<f64> = (0..n).map(|i| i as f64 * 0.1).collect();
                let phi_next: Vec<f64> = (0..n).map(|i| i as f64 * 0.2).collect();
                let gamma = 0.99;
                let dones: Vec<f64> = (0..n).map(|i| if i % 7 == 0 { 1.0 } else { 0.0 }).collect();

                let simd_result = pbrs_simd(&rewards, &phi, &phi_next, gamma, &dones);

                // Scalar reference
                let mut scalar_result = vec![0.0f64; n];
                for i in 0..n {
                    if dones[i] == 1.0 {
                        scalar_result[i] = rewards[i];
                    } else {
                        scalar_result[i] = rewards[i] + gamma * phi_next[i] - phi[i];
                    }
                }

                for (i, (a, b)) in simd_result.iter().zip(scalar_result.iter()).enumerate() {
                    prop_assert!(
                        (a - b).abs() < 1e-10,
                        "pbrs mismatch at {i}: simd={a}, scalar={b}"
                    );
                }
            }

            #[test]
            fn prop_priorities_simd_matches_scalar(n in 1usize..200) {
                let losses: Vec<f64> = (0..n).map(|i| (i as f64 - 50.0) * 1.3).collect();
                let eps = 0.01;

                let simd_result = compute_priorities_simd(&losses, eps);
                let scalar_result: Vec<f64> = losses.iter().map(|l| l.abs() + eps).collect();

                for (i, (a, b)) in simd_result.iter().zip(scalar_result.iter()).enumerate() {
                    prop_assert!(
                        (a - b).abs() < 1e-12,
                        "priority mismatch at {i}: simd={a}, scalar={b}"
                    );
                }
            }

            #[test]
            fn prop_average_weights_simd_matches_scalar(
                dim in 1usize..100,
                num_vecs in 1usize..10,
            ) {
                let vectors: Vec<Vec<f32>> = (0..num_vecs)
                    .map(|v| (0..dim).map(|i| (v * dim + i) as f32 * 0.1).collect())
                    .collect();
                let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();

                let simd_result = average_weights_simd(&refs);

                // Scalar reference
                let n = num_vecs as f32;
                let mut scalar_result = vec![0.0f32; dim];
                for v in &vectors {
                    for (r, &val) in scalar_result.iter_mut().zip(v.iter()) {
                        *r += val;
                    }
                }
                for r in &mut scalar_result {
                    *r /= n;
                }

                for (i, (a, b)) in simd_result.iter().zip(scalar_result.iter()).enumerate() {
                    prop_assert!(
                        (a - b).abs() < 1e-3,
                        "average mismatch at {i}: simd={a}, scalar={b}"
                    );
                }
            }
        }
    }
}