rlox_core/training/
normalization.rs

/// Online running statistics using Welford's algorithm.
///
/// Computes running mean, population variance, and standard deviation
/// in a numerically stable, single-pass manner.
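///
/// # Examples
///
/// A minimal usage sketch. The import path is an assumption based on this
/// file's location (`rlox_core/training/normalization.rs`); adjust it to
/// match how the module is actually exported.
///
/// ```
/// use rlox_core::training::normalization::RunningStats; // assumed path
///
/// let mut stats = RunningStats::new();
/// for &x in &[2.0_f64, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0] {
///     stats.update(x);
/// }
/// assert!((stats.mean() - 5.0).abs() < 1e-10); // mean = 40 / 8 = 5
/// assert!((stats.std() - 2.0).abs() < 1e-10); // population std
/// ```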
#[derive(Debug, Clone)]
pub struct RunningStats {
    count: u64,
    mean: f64,
    m2: f64,
}

impl RunningStats {
    /// Create a new empty statistics accumulator.
    pub fn new() -> Self {
        Self {
            count: 0,
            mean: 0.0,
            m2: 0.0,
        }
    }

    /// Update with a single observation (Welford's online algorithm).
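    ///
    /// For reference, the recurrence maintained here is
    /// `mean_n = mean_{n-1} + (x_n - mean_{n-1}) / n` and
    /// `m2_n = m2_{n-1} + (x_n - mean_{n-1}) * (x_n - mean_n)`,
    /// from which the population variance is `m2_n / n`.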
    pub fn update(&mut self, value: f64) {
        self.count += 1;
        let delta = value - self.mean;
        self.mean += delta / self.count as f64;
        let delta2 = value - self.mean;
        self.m2 += delta * delta2;
    }

    /// Update with a batch of observations.
    pub fn batch_update(&mut self, values: &[f64]) {
        for &v in values {
            self.update(v);
        }
    }

    /// Current running mean.
    pub fn mean(&self) -> f64 {
        self.mean
    }

    /// Population variance (divide by n).
    pub fn var(&self) -> f64 {
        if self.count < 1 {
            return 0.0;
        }
        self.m2 / self.count as f64
    }

    /// Population standard deviation.
    pub fn std(&self) -> f64 {
        self.var().sqrt()
    }

    /// Normalize a value to a z-score using current mean and std.
    /// Returns 0.0 if std is near zero to avoid division by zero.
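    ///
    /// # Examples
    ///
    /// A sketch under the same assumed import path as above:
    ///
    /// ```
    /// use rlox_core::training::normalization::RunningStats; // assumed path
    ///
    /// let mut stats = RunningStats::new();
    /// stats.batch_update(&[2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]);
    /// // mean = 5 and std = 2, so 7.0 sits one standard deviation above the mean.
    /// assert!((stats.normalize(7.0) - 1.0).abs() < 1e-10);
    /// ```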
    pub fn normalize(&self, value: f64) -> f64 {
        let s = self.std();
        if s < 1e-8 {
            return 0.0;
        }
        (value - self.mean) / s
    }

    /// Number of observations seen.
    pub fn count(&self) -> u64 {
        self.count
    }

    /// Reset all accumulated statistics.
    pub fn reset(&mut self) {
        self.count = 0;
        self.mean = 0.0;
        self.m2 = 0.0;
    }
}

impl Default for RunningStats {
    fn default() -> Self {
        Self::new()
    }
}

/// Per-dimension online running statistics using Welford's algorithm.
///
/// Maintains independent mean/variance accumulators for each dimension,
/// enabling proper per-feature observation normalization as in SB3's
/// `RunningMeanStd`.
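///
/// # Examples
///
/// A minimal sketch (import path assumed, as above):
///
/// ```
/// use rlox_core::training::normalization::RunningStatsVec; // assumed path
///
/// let mut stats = RunningStatsVec::new(2);
/// stats.update(&[0.0, 0.0]);
/// stats.update(&[10.0, 20.0]);
/// // Per dimension: mean = [5, 10], population std = [5, 10].
/// let z = stats.normalize(&[10.0, 20.0]);
/// assert!((z[0] - 1.0).abs() < 1e-10);
/// assert!((z[1] - 1.0).abs() < 1e-10);
/// ```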
#[derive(Debug, Clone)]
pub struct RunningStatsVec {
    dim: usize,
    count: u64,
    mean: Vec<f64>,
    m2: Vec<f64>,
}

impl RunningStatsVec {
    /// Create a new accumulator for vectors of the given dimensionality.
    pub fn new(dim: usize) -> Self {
        Self {
            dim,
            count: 0,
            mean: vec![0.0; dim],
            m2: vec![0.0; dim],
        }
    }

    /// Update with a single sample of length `dim` (Welford per dimension).
    ///
    /// # Panics
    ///
    /// Panics if `values.len() != dim`.
    #[inline]
    pub fn update(&mut self, values: &[f64]) {
        assert_eq!(
            values.len(),
            self.dim,
            "expected {} dimensions, got {}",
            self.dim,
            values.len()
        );
        self.count += 1;
        let n = self.count as f64;
        for (i, &val) in values.iter().enumerate().take(self.dim) {
            let delta = val - self.mean[i];
            self.mean[i] += delta / n;
            let delta2 = val - self.mean[i];
            self.m2[i] += delta * delta2;
        }
    }

    /// Update with a flat batch of `batch_size` samples, each of `dim` dimensions.
    ///
    /// `data` must have length `batch_size * dim`, laid out as
    /// `[sample0_dim0, sample0_dim1, ..., sample1_dim0, ...]`.
    ///
    /// # Panics
    ///
    /// Panics if `data.len() != batch_size * dim`.
    #[inline]
    pub fn batch_update(&mut self, data: &[f64], batch_size: usize) {
        assert_eq!(
            data.len(),
            batch_size * self.dim,
            "expected {} elements (batch_size={} * dim={}), got {}",
            batch_size * self.dim,
            batch_size,
            self.dim,
            data.len()
        );
        for sample in data.chunks_exact(self.dim) {
            self.update(sample);
        }
    }

    /// Return the current per-dimension mean vector (clone).
    #[inline]
    pub fn mean(&self) -> Vec<f64> {
        self.mean.clone()
    }

    /// Borrow the per-dimension mean as a slice (zero-cost).
    #[inline]
    pub fn mean_ref(&self) -> &[f64] {
        &self.mean
    }

    /// Return the per-dimension population variance vector.
    #[inline]
    pub fn var(&self) -> Vec<f64> {
        if self.count < 1 {
            return vec![0.0; self.dim];
        }
        let n = self.count as f64;
        self.m2.iter().map(|&m| m / n).collect()
    }

    /// Return the per-dimension population standard deviation vector.
    #[inline]
    pub fn std(&self) -> Vec<f64> {
        self.var().iter().map(|&v| v.sqrt()).collect()
    }

    /// Normalize a single sample: `(values - mean) / max(std, 1e-8)` per dimension.
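    ///
    /// The `1e-8` floor keeps zero-variance (constant) features from
    /// producing infinite or NaN z-scores.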
    ///
    /// # Panics
    ///
    /// Panics if `values.len() != dim`.
    #[inline]
    pub fn normalize(&self, values: &[f64]) -> Vec<f64> {
        assert_eq!(
            values.len(),
            self.dim,
            "expected {} dimensions, got {}",
            self.dim,
            values.len()
        );
        let std = self.std();
        values
            .iter()
            .zip(self.mean.iter())
            .zip(std.iter())
            .map(|((&v, &m), &s)| (v - m) / s.max(1e-8))
            .collect()
    }

    /// Normalize a flat batch of `batch_size` samples.
    ///
    /// # Panics
    ///
    /// Panics if `data.len() != batch_size * dim`.
    #[inline]
    pub fn normalize_batch(&self, data: &[f64], batch_size: usize) -> Vec<f64> {
        assert_eq!(
            data.len(),
            batch_size * self.dim,
            "expected {} elements (batch_size={} * dim={}), got {}",
            batch_size * self.dim,
            batch_size,
            self.dim,
            data.len()
        );
        let std = self.std();
        let mut out = Vec::with_capacity(data.len());
        for sample in data.chunks_exact(self.dim) {
            for i in 0..self.dim {
                out.push((sample[i] - self.mean[i]) / std[i].max(1e-8));
            }
        }
        out
    }

    /// Number of samples seen so far.
    #[inline]
    pub fn count(&self) -> u64 {
        self.count
    }

    /// Dimensionality of the tracked vectors.
    #[inline]
    pub fn dim(&self) -> usize {
        self.dim
    }

    /// Reset all accumulated statistics, keeping the dimensionality.
    pub fn reset(&mut self) {
        self.count = 0;
        self.mean.fill(0.0);
        self.m2.fill(0.0);
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn running_stats_new_is_empty() {
        let stats = RunningStats::new();
        assert_eq!(stats.count(), 0);
    }

    #[test]
    fn running_stats_single_sample() {
        let mut stats = RunningStats::new();
        stats.update(5.0);
        assert!((stats.mean() - 5.0).abs() < 1e-10);
        assert_eq!(stats.count(), 1);
        let _ = stats.var();
        let _ = stats.std();
    }

    #[test]
    fn running_stats_welford_known_values() {
        let mut stats = RunningStats::new();
        for &x in &[2.0_f64, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0] {
            stats.update(x);
        }
        assert!(
            (stats.mean() - 5.0).abs() < 1e-10,
            "mean should be 5.0, got {}",
            stats.mean()
        );
        assert!(
            (stats.var() - 4.0).abs() < 1e-10,
            "variance should be 4.0, got {}",
            stats.var()
        );
        assert!(
            (stats.std() - 2.0).abs() < 1e-10,
            "std should be 2.0, got {}",
            stats.std()
        );
    }

    #[test]
    fn running_stats_normalize_produces_z_score() {
        let mut stats = RunningStats::new();
        for &x in &[2.0_f64, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0] {
            stats.update(x);
        }
        let z = stats.normalize(5.0);
        assert!(z.abs() < 1e-10, "normalize(mean) should be ~0, got {z}");
        let z2 = stats.normalize(7.0);
        assert!(
            (z2 - 1.0).abs() < 1e-10,
            "normalize(mean+std) should be ~1, got {z2}"
        );
    }

    #[test]
    fn running_stats_normalize_with_zero_std_does_not_panic() {
        let mut stats = RunningStats::new();
        stats.update(5.0);
        stats.update(5.0);
        stats.update(5.0);
        let z = stats.normalize(5.0);
        assert!(z.is_finite(), "normalize with zero std must be finite");
    }

    #[test]
    fn running_stats_large_stream_numerically_stable() {
        let mut stats = RunningStats::new();
        let base = 1_000_000.0f64;
        for i in 0..10_000 {
            stats.update(base + (i as f64) * 0.001);
        }
        let expected_mean = base + 5.0 - 0.001 / 2.0;
        assert!(
            (stats.mean() - expected_mean).abs() < 0.01,
            "mean imprecise for large offset: got {}, expected ~{expected_mean}",
            stats.mean()
        );
    }

    #[test]
    fn running_stats_reset_clears_state() {
        let mut stats = RunningStats::new();
        for &x in &[1.0f64, 2.0, 3.0] {
            stats.update(x);
        }
        stats.reset();
        assert_eq!(stats.count(), 0);
    }

    #[test]
    fn running_stats_nan_input_does_not_silently_corrupt() {
        let mut stats = RunningStats::new();
        stats.update(1.0);
        stats.update(2.0);
        stats.update(f64::NAN);
        // A NaN observation must not be absorbed into a plausible-looking
        // finite mean. Welford's update propagates the NaN, keeping the
        // corruption visible to callers.
        assert!(
            stats.mean().is_nan(),
            "NaN input should propagate to a NaN mean, not a finite one"
        );
    }

    #[test]
    fn running_stats_batch_update() {
        let mut stats = RunningStats::new();
        stats.batch_update(&[2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]);
        assert!((stats.mean() - 5.0).abs() < 1e-10);
        assert_eq!(stats.count(), 8);
    }

    // -----------------------------------------------------------------------
    // RunningStatsVec tests
    // -----------------------------------------------------------------------

    #[test]
    fn stats_vec_new_is_empty() {
        let stats = RunningStatsVec::new(3);
        assert_eq!(stats.count(), 0);
        assert_eq!(stats.dim(), 3);
        assert_eq!(stats.mean(), vec![0.0; 3]);
        assert_eq!(stats.var(), vec![0.0; 3]);
    }

    #[test]
    fn stats_vec_single_sample() {
        let mut stats = RunningStatsVec::new(3);
        stats.update(&[1.0, 2.0, 3.0]);
        assert_eq!(stats.count(), 1);
        assert_eq!(stats.mean(), vec![1.0, 2.0, 3.0]);
        // Variance of a single sample is 0
        assert_eq!(stats.var(), vec![0.0, 0.0, 0.0]);
    }

    #[test]
    fn stats_vec_known_values_per_dim() {
        // Dim 0: [2, 4, 4, 4, 5, 5, 7, 9] -> mean=5.0, var=4.0
        // Dim 1: [1, 1, 1, 1, 1, 1, 1, 1] -> mean=1.0, var=0.0
        let mut stats = RunningStatsVec::new(2);
        let samples: &[&[f64]] = &[
            &[2.0, 1.0],
            &[4.0, 1.0],
            &[4.0, 1.0],
            &[4.0, 1.0],
            &[5.0, 1.0],
            &[5.0, 1.0],
            &[7.0, 1.0],
            &[9.0, 1.0],
        ];
        for s in samples {
            stats.update(s);
        }
        assert_eq!(stats.count(), 8);
        let mean = stats.mean();
        assert!((mean[0] - 5.0).abs() < 1e-10, "dim0 mean: {}", mean[0]);
        assert!((mean[1] - 1.0).abs() < 1e-10, "dim1 mean: {}", mean[1]);
        let var = stats.var();
        assert!((var[0] - 4.0).abs() < 1e-10, "dim0 var: {}", var[0]);
        assert!(var[1].abs() < 1e-10, "dim1 var: {}", var[1]);
        let std = stats.std();
        assert!((std[0] - 2.0).abs() < 1e-10, "dim0 std: {}", std[0]);
        assert!(std[1].abs() < 1e-10, "dim1 std: {}", std[1]);
    }

    #[test]
    fn stats_vec_batch_update_matches_sequential() {
        let mut seq = RunningStatsVec::new(3);
        let mut batch = RunningStatsVec::new(3);

        let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        for sample in data.chunks(3) {
            seq.update(sample);
        }
        batch.batch_update(&data, 3);

        assert_eq!(seq.count(), batch.count());
        for i in 0..3 {
            assert!(
                (seq.mean()[i] - batch.mean()[i]).abs() < 1e-10,
                "dim {i} mean mismatch"
            );
            assert!(
                (seq.var()[i] - batch.var()[i]).abs() < 1e-10,
                "dim {i} var mismatch"
            );
        }
    }

    #[test]
    fn stats_vec_normalize_produces_z_scores() {
        let mut stats = RunningStatsVec::new(2);
        // Dim 0: [2, 4, 4, 4, 5, 5, 7, 9] -> mean=5, std=2
        // Dim 1: [10, 20, 30, 40, 50, 60, 70, 80] -> mean=45, std=sqrt(525)~=22.91
        let dim0 = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0];
        let dim1 = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0];
        for i in 0..8 {
            stats.update(&[dim0[i], dim1[i]]);
        }

        // Normalizing the mean should give ~0
        let z = stats.normalize(&[5.0, 45.0]);
        assert!(z[0].abs() < 1e-10, "z[0] should be ~0, got {}", z[0]);
        assert!(z[1].abs() < 1e-10, "z[1] should be ~0, got {}", z[1]);

        // Normalizing mean+std should give ~1
        let std = stats.std();
        let z2 = stats.normalize(&[5.0 + std[0], 45.0 + std[1]]);
        assert!(
            (z2[0] - 1.0).abs() < 1e-10,
            "z2[0] should be ~1, got {}",
            z2[0]
        );
        assert!(
            (z2[1] - 1.0).abs() < 1e-10,
            "z2[1] should be ~1, got {}",
            z2[1]
        );
    }

    #[test]
    fn stats_vec_normalize_with_zero_std_clamps() {
        let mut stats = RunningStatsVec::new(2);
        stats.update(&[5.0, 3.0]);
        stats.update(&[5.0, 3.0]);
        // Both dims have zero variance
        let z = stats.normalize(&[6.0, 4.0]);
        assert!(z[0].is_finite(), "dim0 normalize must be finite");
        assert!(z[1].is_finite(), "dim1 normalize must be finite");
        // (6 - 5) / max(0, 1e-8) = 1e8
        assert!((z[0] - 1e8).abs() < 1.0, "dim0: {}", z[0]);
    }

    #[test]
    fn stats_vec_normalize_batch() {
        let mut stats = RunningStatsVec::new(2);
        stats.update(&[0.0, 0.0]);
        stats.update(&[10.0, 20.0]);
        // mean=[5, 10], var=[25, 100], std=[5, 10]

        let data = [5.0, 10.0, 10.0, 20.0]; // 2 samples
        let out = stats.normalize_batch(&data, 2);
        assert!(out[0].abs() < 1e-10, "sample0 dim0 should be 0");
        assert!(out[1].abs() < 1e-10, "sample0 dim1 should be 0");
        assert!((out[2] - 1.0).abs() < 1e-10, "sample1 dim0 should be 1");
        assert!((out[3] - 1.0).abs() < 1e-10, "sample1 dim1 should be 1");
    }

    #[test]
    fn stats_vec_reset_clears_state() {
        let mut stats = RunningStatsVec::new(2);
        stats.update(&[1.0, 2.0]);
        stats.update(&[3.0, 4.0]);
        stats.reset();
        assert_eq!(stats.count(), 0);
        assert_eq!(stats.dim(), 2);
        assert_eq!(stats.mean(), vec![0.0, 0.0]);
    }

    #[test]
    #[should_panic(expected = "expected 3 dimensions, got 2")]
    fn stats_vec_update_wrong_dim_panics() {
        let mut stats = RunningStatsVec::new(3);
        stats.update(&[1.0, 2.0]);
    }

    #[test]
    #[should_panic(expected = "expected 6 elements")]
    fn stats_vec_batch_update_wrong_len_panics() {
        let mut stats = RunningStatsVec::new(3);
        stats.batch_update(&[1.0, 2.0, 3.0, 4.0], 2);
    }

    #[test]
    #[should_panic(expected = "expected 2 dimensions, got 3")]
    fn stats_vec_normalize_wrong_dim_panics() {
        let stats = RunningStatsVec::new(2);
        stats.normalize(&[1.0, 2.0, 3.0]);
    }

    #[test]
    fn stats_vec_large_stream_numerically_stable() {
        let mut stats = RunningStatsVec::new(2);
        let base = 1_000_000.0f64;
        for i in 0..10_000 {
            let v = i as f64 * 0.001;
            stats.update(&[base + v, -base - v]);
        }
        let expected_mean = base + 5.0 - 0.001 / 2.0;
        let mean = stats.mean();
        assert!(
            (mean[0] - expected_mean).abs() < 0.01,
            "dim0 mean imprecise: got {}, expected ~{expected_mean}",
            mean[0]
        );
        assert!(
            (mean[1] + expected_mean).abs() < 0.01,
            "dim1 mean imprecise: got {}, expected ~{}",
            mean[1],
            -expected_mean
        );
    }

    #[test]
    fn stats_vec_hopper_like_multi_scale() {
        // Simulate Hopper-like observations: dim0 in [-1,1], dim1 in [-10,10]
        let mut stats = RunningStatsVec::new(2);
        for i in 0..1000 {
            let t = i as f64 / 999.0;
            let pos = -1.0 + 2.0 * t; // [-1, 1]
            let vel = -10.0 + 20.0 * t; // [-10, 10]
            stats.update(&[pos, vel]);
        }
        let std = stats.std();
        // The stds should reflect the different scales
        assert!(
            std[1] > std[0] * 5.0,
            "velocity std ({}) should be much larger than position std ({})",
            std[1],
            std[0]
        );
        // Normalizing should bring both dims to similar scale
        let z = stats.normalize(&[0.5, 5.0]);
        assert!(
            z[0].abs() < 5.0 && z[1].abs() < 5.0,
            "normalized values should be moderate z-scores, got {:?}",
            z
        );
    }

    mod proptests {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            #[test]
            fn running_stats_mean_matches_batch_mean(
                values in proptest::collection::vec(-1000.0f64..1000.0, 2..200)
            ) {
                let mut stats = RunningStats::new();
                for &v in &values {
                    stats.update(v);
                }
                let batch_mean = values.iter().sum::<f64>() / values.len() as f64;
                prop_assert!(
                    (stats.mean() - batch_mean).abs() < 1e-8,
                    "running mean {:.10} != batch mean {:.10}",
                    stats.mean(), batch_mean
                );
            }

            #[test]
            fn running_stats_variance_non_negative(
                values in proptest::collection::vec(-1000.0f64..1000.0, 2..200)
            ) {
                let mut stats = RunningStats::new();
                for &v in &values {
                    stats.update(v);
                }
                prop_assert!(stats.var() >= 0.0, "variance must be non-negative");
                prop_assert!(stats.std() >= 0.0, "std must be non-negative");
            }

            #[test]
            fn running_stats_std_equals_sqrt_var(
                values in proptest::collection::vec(-100.0f64..100.0, 2..100)
            ) {
                let mut stats = RunningStats::new();
                for &v in &values {
                    stats.update(v);
                }
                let computed_std = stats.var().sqrt();
                prop_assert!(
                    (stats.std() - computed_std).abs() < 1e-10,
                    "std {} != sqrt(var) {}",
                    stats.std(), computed_std
                );
            }

            #[test]
            fn running_stats_count_matches_updates(
                values in proptest::collection::vec(-100.0f64..100.0, 0..200)
            ) {
                let mut stats = RunningStats::new();
                for &v in &values {
                    stats.update(v);
                }
                prop_assert_eq!(stats.count() as usize, values.len());
            }

            #[test]
            fn stats_vec_per_dim_mean_matches_naive(
                dim in 1usize..8,
                n_samples in 2usize..50,
            ) {
                // Generate deterministic data from dim and n_samples
                let mut data = Vec::with_capacity(n_samples * dim);
                for s in 0..n_samples {
                    for d in 0..dim {
                        data.push((s as f64) * 0.1 + (d as f64) * 10.0);
                    }
                }

                let mut stats = RunningStatsVec::new(dim);
                stats.batch_update(&data, n_samples);

                // Compute naive per-dim mean
                for d in 0..dim {
                    let sum: f64 = (0..n_samples).map(|s| data[s * dim + d]).sum();
                    let naive_mean = sum / n_samples as f64;
                    prop_assert!(
                        (stats.mean()[d] - naive_mean).abs() < 1e-8,
                        "dim {d}: running mean {} != naive mean {}",
                        stats.mean()[d], naive_mean
                    );
                }
            }

            #[test]
            fn stats_vec_variance_non_negative(
                dim in 1usize..6,
                n_samples in 2usize..50,
            ) {
                let mut data = Vec::with_capacity(n_samples * dim);
                for s in 0..n_samples {
                    for d in 0..dim {
                        data.push((s as f64) * 0.7 - (d as f64) * 3.0);
                    }
                }

                let mut stats = RunningStatsVec::new(dim);
                stats.batch_update(&data, n_samples);

                for d in 0..dim {
                    prop_assert!(
                        stats.var()[d] >= 0.0,
                        "dim {d} variance must be non-negative, got {}",
                        stats.var()[d]
                    );
                }
            }

            #[test]
            fn stats_vec_normalize_roundtrip_z_mean_zero(
                dim in 1usize..6,
                n_samples in 5usize..50,
            ) {
                let mut data = Vec::with_capacity(n_samples * dim);
                for s in 0..n_samples {
                    for d in 0..dim {
                        data.push((s as f64) * 1.3 + (d as f64) * 7.0);
                    }
                }

                let mut stats = RunningStatsVec::new(dim);
                stats.batch_update(&data, n_samples);

                // Normalizing the mean vector should give zeros
                let z = stats.normalize(&stats.mean());
                for (d, &val) in z.iter().enumerate().take(dim) {
                    prop_assert!(
                        val.abs() < 1e-8,
                        "normalize(mean)[{d}] should be ~0, got {}",
                        val
                    );
                }
            }
        }
    }
}