/// Streaming mean/variance accumulator using Welford's online algorithm.
///
/// Keeps only the sample count, the running mean, and the sum of squared
/// deviations (`m2`), so mean / population variance / std are available at
/// any point in the stream without retaining the samples themselves.
#[derive(Debug, Clone)]
pub struct RunningStats {
    count: u64, // number of samples folded in so far
    mean: f64,  // running mean of all samples
    m2: f64,    // running sum of squared deviations from the mean
}

impl RunningStats {
    /// Creates an empty accumulator (zero samples, zero mean/variance).
    pub fn new() -> Self {
        Self {
            count: 0,
            mean: 0.0,
            m2: 0.0,
        }
    }

    /// Folds a single sample into the running statistics.
    ///
    /// Welford update: shift the mean toward the sample by `delta / n`,
    /// then accumulate the product of the pre- and post-update deviations
    /// into `m2` (numerically stabler than summing squares directly).
    pub fn update(&mut self, value: f64) {
        self.count += 1;
        let before = value - self.mean;
        self.mean += before / self.count as f64;
        let after = value - self.mean;
        self.m2 += before * after;
    }

    /// Folds every sample in `values` into the statistics, in order.
    pub fn batch_update(&mut self, values: &[f64]) {
        values.iter().for_each(|&v| self.update(v));
    }

    /// Current running mean (0.0 before any sample has been seen).
    pub fn mean(&self) -> f64 {
        self.mean
    }

    /// Population variance `m2 / count`; 0.0 for an empty stream.
    pub fn var(&self) -> f64 {
        match self.count {
            0 => 0.0,
            n => self.m2 / n as f64,
        }
    }

    /// Population standard deviation (square root of [`Self::var`]).
    pub fn std(&self) -> f64 {
        self.var().sqrt()
    }

    /// Z-score of `value` under the current statistics.
    ///
    /// Returns 0.0 when the standard deviation is (near) zero so a
    /// constant stream never produces an infinite or NaN score.
    pub fn normalize(&self, value: f64) -> f64 {
        let sd = self.std();
        if sd < 1e-8 {
            0.0
        } else {
            (value - self.mean) / sd
        }
    }

    /// Number of samples observed so far.
    pub fn count(&self) -> u64 {
        self.count
    }

    /// Discards all accumulated state, returning to the empty accumulator.
    pub fn reset(&mut self) {
        self.count = 0;
        self.mean = 0.0;
        self.m2 = 0.0;
    }
}

impl Default for RunningStats {
    fn default() -> Self {
        RunningStats::new()
    }
}
84
/// Per-dimension streaming mean/variance over fixed-length samples.
///
/// Applies Welford's online algorithm independently to each of `dim`
/// dimensions; every sample must be a slice of exactly `dim` values.
#[derive(Debug, Clone)]
pub struct RunningStatsVec {
    dim: usize,     // required length of every sample
    count: u64,     // number of samples folded in so far
    mean: Vec<f64>, // per-dimension running mean
    m2: Vec<f64>,   // per-dimension sum of squared deviations
}

impl RunningStatsVec {
    /// Creates an empty accumulator for `dim`-dimensional samples.
    pub fn new(dim: usize) -> Self {
        Self {
            dim,
            count: 0,
            mean: vec![0.0; dim],
            m2: vec![0.0; dim],
        }
    }

    /// Folds one `dim`-length sample into the per-dimension statistics.
    ///
    /// # Panics
    /// Panics if `values.len() != self.dim()`.
    #[inline]
    pub fn update(&mut self, values: &[f64]) {
        assert_eq!(
            values.len(),
            self.dim,
            "expected {} dimensions, got {}",
            self.dim,
            values.len()
        );
        self.count += 1;
        let n = self.count as f64;
        // Zipped iteration replaces `self.mean[i]` / `self.m2[i]` indexing:
        // the assert above pins all three lengths, so the bounds checks the
        // indexed form paid per element are provably dead — and the old
        // `.take(self.dim)` was redundant for the same reason.
        for ((mean, m2), &val) in self.mean.iter_mut().zip(self.m2.iter_mut()).zip(values) {
            let delta = val - *mean;
            *mean += delta / n;
            let delta2 = val - *mean;
            *m2 += delta * delta2;
        }
    }

    /// Folds `batch_size` consecutive samples stored row-major in `data`.
    ///
    /// # Panics
    /// Panics if `data.len() != batch_size * self.dim()`.
    #[inline]
    pub fn batch_update(&mut self, data: &[f64], batch_size: usize) {
        assert_eq!(
            data.len(),
            batch_size * self.dim,
            "expected {} elements (batch_size={} * dim={}), got {}",
            batch_size * self.dim,
            batch_size,
            self.dim,
            data.len()
        );
        for sample in data.chunks_exact(self.dim) {
            self.update(sample);
        }
    }

    /// Per-dimension running means as a freshly allocated `Vec`.
    #[inline]
    pub fn mean(&self) -> Vec<f64> {
        self.mean.clone()
    }

    /// Borrowed view of the per-dimension running means (no allocation).
    #[inline]
    pub fn mean_ref(&self) -> &[f64] {
        &self.mean
    }

    /// Per-dimension population variances; all zeros for an empty stream.
    #[inline]
    pub fn var(&self) -> Vec<f64> {
        if self.count == 0 {
            return vec![0.0; self.dim];
        }
        let n = self.count as f64;
        self.m2.iter().map(|&m| m / n).collect()
    }

    /// Per-dimension population standard deviations.
    #[inline]
    pub fn std(&self) -> Vec<f64> {
        self.var().iter().map(|&v| v.sqrt()).collect()
    }

    /// Per-dimension z-scores of `values`, clamping each std at `1e-8`
    /// so constant dimensions yield large-but-finite scores.
    ///
    /// # Panics
    /// Panics if `values.len() != self.dim()`.
    #[inline]
    pub fn normalize(&self, values: &[f64]) -> Vec<f64> {
        assert_eq!(
            values.len(),
            self.dim,
            "expected {} dimensions, got {}",
            self.dim,
            values.len()
        );
        let std = self.std();
        values
            .iter()
            .zip(self.mean.iter())
            .zip(std.iter())
            .map(|((&v, &m), &s)| (v - m) / s.max(1e-8))
            .collect()
    }

    /// Normalizes `batch_size` row-major samples at once; stds are computed
    /// a single time and reused for every row.
    ///
    /// # Panics
    /// Panics if `data.len() != batch_size * self.dim()`.
    #[inline]
    pub fn normalize_batch(&self, data: &[f64], batch_size: usize) -> Vec<f64> {
        assert_eq!(
            data.len(),
            batch_size * self.dim,
            "expected {} elements (batch_size={} * dim={}), got {}",
            batch_size * self.dim,
            batch_size,
            self.dim,
            data.len()
        );
        let std = self.std();
        let mut out = Vec::with_capacity(data.len());
        for sample in data.chunks_exact(self.dim) {
            // Zip instead of `sample[i]` indexing: chunks_exact guarantees
            // `sample.len() == self.dim`, matching `mean` and `std`.
            for ((&v, &m), &s) in sample.iter().zip(&self.mean).zip(&std) {
                out.push((v - m) / s.max(1e-8));
            }
        }
        out
    }

    /// Number of samples observed so far.
    #[inline]
    pub fn count(&self) -> u64 {
        self.count
    }

    /// Dimensionality every sample must have.
    #[inline]
    pub fn dim(&self) -> usize {
        self.dim
    }

    /// Discards all accumulated state; the dimensionality is kept.
    pub fn reset(&mut self) {
        self.count = 0;
        self.mean.fill(0.0);
        self.m2.fill(0.0);
    }
}
253
// Unit tests (exact-value pins for the Welford accumulators) plus
// proptest-based property tests in the nested `proptests` module.
// NOTE: assertion tolerances, panic-message substrings, and the 1e8
// zero-std clamp value below are deliberate behavioral pins — changing
// library behavior must be reflected here and vice versa.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn running_stats_new_is_empty() {
        let stats = RunningStats::new();
        assert_eq!(stats.count(), 0);
    }

    #[test]
    fn running_stats_single_sample() {
        let mut stats = RunningStats::new();
        stats.update(5.0);
        assert!((stats.mean() - 5.0).abs() < 1e-10);
        assert_eq!(stats.count(), 1);
        // var/std with a single sample must not panic or divide by zero.
        let _ = stats.var();
        let _ = stats.std();
    }

    #[test]
    fn running_stats_welford_known_values() {
        let mut stats = RunningStats::new();
        // Classic Welford reference dataset: mean 5, population variance 4.
        for &x in &[2.0_f64, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0] {
            stats.update(x);
        }
        assert!(
            (stats.mean() - 5.0).abs() < 1e-10,
            "mean should be 5.0, got {}",
            stats.mean()
        );
        assert!(
            (stats.var() - 4.0).abs() < 1e-10,
            "variance should be 4.0, got {}",
            stats.var()
        );
        assert!(
            (stats.std() - 2.0).abs() < 1e-10,
            "std should be 2.0, got {}",
            stats.std()
        );
    }

    #[test]
    fn running_stats_normalize_produces_z_score() {
        let mut stats = RunningStats::new();
        for &x in &[2.0_f64, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0] {
            stats.update(x);
        }
        // mean = 5, std = 2 for this dataset.
        let z = stats.normalize(5.0);
        assert!(z.abs() < 1e-10, "normalize(mean) should be ~0, got {z}");
        let z2 = stats.normalize(7.0);
        assert!(
            (z2 - 1.0).abs() < 1e-10,
            "normalize(mean+std) should be ~1, got {z2}"
        );
    }

    #[test]
    fn running_stats_normalize_with_zero_std_does_not_panic() {
        // A constant stream has zero std; normalize must stay finite
        // (the scalar accumulator returns 0.0 in this case).
        let mut stats = RunningStats::new();
        stats.update(5.0);
        stats.update(5.0);
        stats.update(5.0);
        let z = stats.normalize(5.0);
        assert!(z.is_finite(), "normalize with zero std must be finite");
    }

    #[test]
    fn running_stats_large_stream_numerically_stable() {
        // Welford's algorithm should keep the mean accurate even when the
        // values sit on a large constant offset (naive sum-of-values would
        // lose precision here).
        let mut stats = RunningStats::new();
        let base = 1_000_000.0f64;
        for i in 0..10_000 {
            stats.update(base + (i as f64) * 0.001);
        }
        let expected_mean = base + 5.0 - 0.001 / 2.0;
        assert!(
            (stats.mean() - expected_mean).abs() < 0.01,
            "mean imprecise for large offset: got {}, expected ~{expected_mean}",
            stats.mean()
        );
    }

    #[test]
    fn running_stats_reset_clears_state() {
        let mut stats = RunningStats::new();
        for &x in &[1.0f64, 2.0, 3.0] {
            stats.update(x);
        }
        stats.reset();
        assert_eq!(stats.count(), 0);
    }

    #[test]
    fn running_stats_nan_input_does_not_silently_corrupt() {
        // Documents current NaN behavior: a NaN sample may poison the mean
        // (making it NaN), but must never yield a *different finite* mean.
        let mut stats = RunningStats::new();
        stats.update(1.0);
        stats.update(2.0);
        let mean_before = stats.mean();
        stats.update(f64::NAN);
        let mean_after = stats.mean();
        if mean_after.is_finite() {
            assert!(
                (mean_after - mean_before).abs() < 1e-10 || mean_after.is_nan(),
                "NaN input corrupted finite mean: was {mean_before}, now {mean_after}"
            );
        }
    }

    #[test]
    fn running_stats_batch_update() {
        let mut stats = RunningStats::new();
        stats.batch_update(&[2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]);
        assert!((stats.mean() - 5.0).abs() < 1e-10);
        assert_eq!(stats.count(), 8);
    }

    #[test]
    fn stats_vec_new_is_empty() {
        let stats = RunningStatsVec::new(3);
        assert_eq!(stats.count(), 0);
        assert_eq!(stats.dim(), 3);
        assert_eq!(stats.mean(), vec![0.0; 3]);
        assert_eq!(stats.var(), vec![0.0; 3]);
    }

    #[test]
    fn stats_vec_single_sample() {
        let mut stats = RunningStatsVec::new(3);
        stats.update(&[1.0, 2.0, 3.0]);
        assert_eq!(stats.count(), 1);
        assert_eq!(stats.mean(), vec![1.0, 2.0, 3.0]);
        assert_eq!(stats.var(), vec![0.0, 0.0, 0.0]);
    }

    #[test]
    fn stats_vec_known_values_per_dim() {
        let mut stats = RunningStatsVec::new(2);
        // dim0 is the Welford reference dataset (mean 5, var 4);
        // dim1 is constant (mean 1, var 0) — checks per-dim independence.
        let samples: &[&[f64]] = &[
            &[2.0, 1.0],
            &[4.0, 1.0],
            &[4.0, 1.0],
            &[4.0, 1.0],
            &[5.0, 1.0],
            &[5.0, 1.0],
            &[7.0, 1.0],
            &[9.0, 1.0],
        ];
        for s in samples {
            stats.update(s);
        }
        assert_eq!(stats.count(), 8);
        let mean = stats.mean();
        assert!((mean[0] - 5.0).abs() < 1e-10, "dim0 mean: {}", mean[0]);
        assert!((mean[1] - 1.0).abs() < 1e-10, "dim1 mean: {}", mean[1]);
        let var = stats.var();
        assert!((var[0] - 4.0).abs() < 1e-10, "dim0 var: {}", var[0]);
        assert!(var[1].abs() < 1e-10, "dim1 var: {}", var[1]);
        let std = stats.std();
        assert!((std[0] - 2.0).abs() < 1e-10, "dim0 std: {}", std[0]);
        assert!(std[1].abs() < 1e-10, "dim1 std: {}", std[1]);
    }

    #[test]
    fn stats_vec_batch_update_matches_sequential() {
        let mut seq = RunningStatsVec::new(3);
        let mut batch = RunningStatsVec::new(3);

        let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        for sample in data.chunks(3) {
            seq.update(sample);
        }
        batch.batch_update(&data, 3);

        assert_eq!(seq.count(), batch.count());
        for i in 0..3 {
            assert!(
                (seq.mean()[i] - batch.mean()[i]).abs() < 1e-10,
                "dim {i} mean mismatch"
            );
            assert!(
                (seq.var()[i] - batch.var()[i]).abs() < 1e-10,
                "dim {i} var mismatch"
            );
        }
    }

    #[test]
    fn stats_vec_normalize_produces_z_scores() {
        let mut stats = RunningStatsVec::new(2);
        let dim0 = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0];
        let dim1 = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0];
        for i in 0..8 {
            stats.update(&[dim0[i], dim1[i]]);
        }

        // [5.0, 45.0] is the per-dimension mean, so both z-scores are ~0.
        let z = stats.normalize(&[5.0, 45.0]);
        assert!(z[0].abs() < 1e-10, "z[0] should be ~0, got {}", z[0]);
        assert!(z[1].abs() < 1e-10, "z[1] should be ~0, got {}", z[1]);

        // One std above the mean in each dimension gives z ~ 1.
        let std = stats.std();
        let z2 = stats.normalize(&[5.0 + std[0], 45.0 + std[1]]);
        assert!(
            (z2[0] - 1.0).abs() < 1e-10,
            "z2[0] should be ~1, got {}",
            z2[0]
        );
        assert!(
            (z2[1] - 1.0).abs() < 1e-10,
            "z2[1] should be ~1, got {}",
            z2[1]
        );
    }

    #[test]
    fn stats_vec_normalize_with_zero_std_clamps() {
        // Unlike the scalar normalize (which returns 0.0), the vector
        // version clamps std at 1e-8, so a unit deviation maps to ~1e8.
        let mut stats = RunningStatsVec::new(2);
        stats.update(&[5.0, 3.0]);
        stats.update(&[5.0, 3.0]);
        let z = stats.normalize(&[6.0, 4.0]);
        assert!(z[0].is_finite(), "dim0 normalize must be finite");
        assert!(z[1].is_finite(), "dim1 normalize must be finite");
        assert!((z[0] - 1e8).abs() < 1.0, "dim0: {}", z[0]);
    }

    #[test]
    fn stats_vec_normalize_batch() {
        let mut stats = RunningStatsVec::new(2);
        stats.update(&[0.0, 0.0]);
        stats.update(&[10.0, 20.0]);
        // mean = [5, 10], population std = [5, 10].
        let data = [5.0, 10.0, 10.0, 20.0]; let out = stats.normalize_batch(&data, 2);
        assert!(out[0].abs() < 1e-10, "sample0 dim0 should be 0");
        assert!(out[1].abs() < 1e-10, "sample0 dim1 should be 0");
        assert!((out[2] - 1.0).abs() < 1e-10, "sample1 dim0 should be 1");
        assert!((out[3] - 1.0).abs() < 1e-10, "sample1 dim1 should be 1");
    }

    #[test]
    fn stats_vec_reset_clears_state() {
        let mut stats = RunningStatsVec::new(2);
        stats.update(&[1.0, 2.0]);
        stats.update(&[3.0, 4.0]);
        stats.reset();
        assert_eq!(stats.count(), 0);
        // reset keeps the dimensionality, only the statistics are zeroed.
        assert_eq!(stats.dim(), 2);
        assert_eq!(stats.mean(), vec![0.0, 0.0]);
    }

    #[test]
    #[should_panic(expected = "expected 3 dimensions, got 2")]
    fn stats_vec_update_wrong_dim_panics() {
        let mut stats = RunningStatsVec::new(3);
        stats.update(&[1.0, 2.0]);
    }

    #[test]
    #[should_panic(expected = "expected 6 elements")]
    fn stats_vec_batch_update_wrong_len_panics() {
        let mut stats = RunningStatsVec::new(3);
        stats.batch_update(&[1.0, 2.0, 3.0, 4.0], 2);
    }

    #[test]
    #[should_panic(expected = "expected 2 dimensions, got 3")]
    fn stats_vec_normalize_wrong_dim_panics() {
        let stats = RunningStatsVec::new(2);
        stats.normalize(&[1.0, 2.0, 3.0]);
    }

    #[test]
    fn stats_vec_large_stream_numerically_stable() {
        // Same large-offset stability check as the scalar version, with the
        // second dimension mirrored to exercise negative offsets too.
        let mut stats = RunningStatsVec::new(2);
        let base = 1_000_000.0f64;
        for i in 0..10_000 {
            let v = i as f64 * 0.001;
            stats.update(&[base + v, -base - v]);
        }
        let expected_mean = base + 5.0 - 0.001 / 2.0;
        let mean = stats.mean();
        assert!(
            (mean[0] - expected_mean).abs() < 0.01,
            "dim0 mean imprecise: got {}, expected ~{expected_mean}",
            mean[0]
        );
        assert!(
            (mean[1] + expected_mean).abs() < 0.01,
            "dim1 mean imprecise: got {}, expected ~{}",
            mean[1],
            -expected_mean
        );
    }

    #[test]
    fn stats_vec_hopper_like_multi_scale() {
        // Two dimensions on very different scales (position-like in [-1, 1],
        // velocity-like in [-10, 10]); normalization should bring both onto
        // comparable z-score ranges.
        let mut stats = RunningStatsVec::new(2);
        for i in 0..1000 {
            let t = i as f64 / 999.0;
            let pos = -1.0 + 2.0 * t; let vel = -10.0 + 20.0 * t; stats.update(&[pos, vel]);
        }
        let std = stats.std();
        assert!(
            std[1] > std[0] * 5.0,
            "velocity std ({}) should be much larger than position std ({})",
            std[1],
            std[0]
        );
        let z = stats.normalize(&[0.5, 5.0]);
        assert!(
            z[0].abs() < 5.0 && z[1].abs() < 5.0,
            "normalized values should be moderate z-scores, got {:?}",
            z
        );
    }

    // Property-based tests (proptest): invariants that must hold for
    // arbitrary bounded inputs rather than hand-picked datasets.
    mod proptests {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            #[test]
            fn running_stats_mean_matches_batch_mean(
                values in proptest::collection::vec(-1000.0f64..1000.0, 2..200)
            ) {
                let mut stats = RunningStats::new();
                for &v in &values {
                    stats.update(v);
                }
                let batch_mean = values.iter().sum::<f64>() / values.len() as f64;
                prop_assert!(
                    (stats.mean() - batch_mean).abs() < 1e-8,
                    "running mean {:.10} != batch mean {:.10}",
                    stats.mean(), batch_mean
                );
            }

            #[test]
            fn running_stats_variance_non_negative(
                values in proptest::collection::vec(-1000.0f64..1000.0, 2..200)
            ) {
                let mut stats = RunningStats::new();
                for &v in &values {
                    stats.update(v);
                }
                prop_assert!(stats.var() >= 0.0, "variance must be non-negative");
                prop_assert!(stats.std() >= 0.0, "std must be non-negative");
            }

            #[test]
            fn running_stats_std_equals_sqrt_var(
                values in proptest::collection::vec(-100.0f64..100.0, 2..100)
            ) {
                let mut stats = RunningStats::new();
                for &v in &values {
                    stats.update(v);
                }
                let computed_std = stats.var().sqrt();
                prop_assert!(
                    (stats.std() - computed_std).abs() < 1e-10,
                    "std {} != sqrt(var) {}",
                    stats.std(), computed_std
                );
            }

            #[test]
            fn running_stats_count_matches_updates(
                values in proptest::collection::vec(-100.0f64..100.0, 0..200)
            ) {
                let mut stats = RunningStats::new();
                for &v in &values {
                    stats.update(v);
                }
                prop_assert_eq!(stats.count() as usize, values.len());
            }

            #[test]
            fn stats_vec_per_dim_mean_matches_naive(
                dim in 1usize..8,
                n_samples in 2usize..50,
            ) {
                // Deterministic row-major data: sample s, dim d -> s*0.1 + d*10.
                let mut data = Vec::with_capacity(n_samples * dim);
                for s in 0..n_samples {
                    for d in 0..dim {
                        data.push((s as f64) * 0.1 + (d as f64) * 10.0);
                    }
                }

                let mut stats = RunningStatsVec::new(dim);
                stats.batch_update(&data, n_samples);

                for d in 0..dim {
                    let sum: f64 = (0..n_samples).map(|s| data[s * dim + d]).sum();
                    let naive_mean = sum / n_samples as f64;
                    prop_assert!(
                        (stats.mean()[d] - naive_mean).abs() < 1e-8,
                        "dim {d}: running mean {} != naive mean {}",
                        stats.mean()[d], naive_mean
                    );
                }
            }

            #[test]
            fn stats_vec_variance_non_negative(
                dim in 1usize..6,
                n_samples in 2usize..50,
            ) {
                let mut data = Vec::with_capacity(n_samples * dim);
                for s in 0..n_samples {
                    for d in 0..dim {
                        data.push((s as f64) * 0.7 - (d as f64) * 3.0);
                    }
                }

                let mut stats = RunningStatsVec::new(dim);
                stats.batch_update(&data, n_samples);

                for d in 0..dim {
                    prop_assert!(
                        stats.var()[d] >= 0.0,
                        "dim {d} variance must be non-negative, got {}",
                        stats.var()[d]
                    );
                }
            }

            #[test]
            fn stats_vec_normalize_roundtrip_z_mean_zero(
                dim in 1usize..6,
                n_samples in 5usize..50,
            ) {
                let mut data = Vec::with_capacity(n_samples * dim);
                for s in 0..n_samples {
                    for d in 0..dim {
                        data.push((s as f64) * 1.3 + (d as f64) * 7.0);
                    }
                }

                let mut stats = RunningStatsVec::new(dim);
                stats.batch_update(&data, n_samples);

                // Normalizing the mean vector itself must give ~0 per dim.
                let z = stats.normalize(&stats.mean());
                for (d, &val) in z.iter().enumerate().take(dim) {
                    prop_assert!(
                        val.abs() < 1e-8,
                        "normalize(mean)[{d}] should be ~0, got {}",
                        val
                    );
                }
            }
        }
    }
}