// rlox_nn/distributions.rs
//! Pure-Rust distribution utilities for CPU-based sampling and log-prob.
//! These are backend-independent helpers that can be used by any backend
//! or for testing without a NN framework.

/// Numerically stable `log(softmax(logits))`.
///
/// Subtracts the maximum logit before exponentiating so large logits do not
/// overflow `exp`. Returns one log-probability per input logit (empty input
/// yields an empty vector).
pub fn log_softmax(logits: &[f32]) -> Vec<f32> {
    let peak = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let mut total = 0.0_f32;
    for &logit in logits {
        total += (logit - peak).exp();
    }
    // log(sum(exp(x_i))) = peak + log(sum(exp(x_i - peak)))
    let log_sum_exp = peak + total.ln();
    logits.iter().map(|&logit| logit - log_sum_exp).collect()
}
13
14/// Sample from a categorical distribution given logits.
15/// Uses the Gumbel-max trick for differentiable-friendly sampling.
16pub fn categorical_sample(logits: &[f32], uniform_rand: f32) -> usize {
17    let log_probs = log_softmax(logits);
18    let probs: Vec<f32> = log_probs.iter().map(|&lp| lp.exp()).collect();
19
20    let mut cumsum = 0.0;
21    for (i, &p) in probs.iter().enumerate() {
22        cumsum += p;
23        if uniform_rand < cumsum {
24            return i;
25        }
26    }
27    probs.len() - 1
28}
29
30/// Compute log_prob for a categorical distribution.
31pub fn categorical_log_prob(logits: &[f32], action: usize) -> f32 {
32    let log_probs = log_softmax(logits);
33    log_probs[action]
34}
35
36/// Compute entropy of a categorical distribution from logits.
37pub fn categorical_entropy(logits: &[f32]) -> f32 {
38    let log_probs = log_softmax(logits);
39    let probs: Vec<f32> = log_probs.iter().map(|&lp| lp.exp()).collect();
40    -probs
41        .iter()
42        .zip(log_probs.iter())
43        .map(|(&p, &lp)| if p > 0.0 { p * lp } else { 0.0 })
44        .sum::<f32>()
45}
46
/// Log-density of a normal distribution `N(mean, std²)` evaluated at `x`.
///
/// Computed as `-0.5 * ((x - mean)²/var + 2·ln(std) + ln(2π))`.
/// NOTE(review): assumes `std > 0` — zero or negative `std` yields
/// NaN/infinite results; confirm callers guarantee this.
pub fn normal_log_prob(x: f32, mean: f32, std: f32) -> f32 {
    let diff = x - mean;
    let variance = std * std;
    -0.5 * (diff * diff / variance + 2.0 * std.ln() + (2.0 * std::f32::consts::PI).ln())
}
53
/// Differential entropy of a normal distribution: `0.5 · ln(2πe·std²)`.
pub fn normal_entropy(std: f32) -> f32 {
    use std::f32::consts::{E, PI};
    0.5 * (2.0 * PI * E * std * std).ln()
}
58
/// Log-prob correction term for a tanh-squashed sample.
///
/// Returns `-ln(1 - tanh(pre_tanh)² + 1e-6)` — the non-negative quantity the
/// caller adds to a log-prob to implement
/// `log_prob -= ln(1 - tanh(x)² + eps)`. The `1e-6` epsilon guards against
/// `ln(0)` when `tanh` saturates at ±1.
pub fn tanh_log_prob_correction(pre_tanh: f32) -> f32 {
    let t = pre_tanh.tanh();
    let jacobian = 1.0 - t * t + 1e-6;
    -jacobian.ln()
}
64
#[cfg(test)]
mod tests {
    use super::*;

    /// Sum of `exp(lp)` over a log-softmax output; ~1.0 when valid.
    fn total_mass(log_probs: &[f32]) -> f32 {
        log_probs.iter().map(|&lp| lp.exp()).sum()
    }

    #[test]
    fn test_log_softmax_sums_to_one() {
        let sum = total_mass(&log_softmax(&[1.0, 2.0, 3.0]));
        assert!(
            (sum - 1.0).abs() < 1e-5,
            "softmax should sum to 1, got {sum}"
        );
    }

    #[test]
    fn test_log_softmax_with_large_logits() {
        // Would overflow exp() without the max-subtraction trick.
        let sum = total_mass(&log_softmax(&[1000.0, 1001.0, 1002.0]));
        assert!((sum - 1.0).abs() < 1e-3, "numerical stability: {sum}");
    }

    #[test]
    fn test_log_softmax_uniform() {
        // Equal logits -> every log-prob is -ln(n).
        let expected = -(4.0_f32).ln();
        for &lp in &log_softmax(&[0.0, 0.0, 0.0, 0.0]) {
            assert!((lp - expected).abs() < 1e-6);
        }
    }

    #[test]
    fn test_categorical_sample_boundaries() {
        // Two equal logits -> probs [0.5, 0.5].
        let logits = [0.0, 0.0];
        for (u, want) in [(0.0, 0), (0.49, 0), (0.51, 1), (0.99, 1)] {
            assert_eq!(categorical_sample(&logits, u), want);
        }
    }

    #[test]
    fn test_categorical_sample_skewed() {
        // logits [10, 0] -> prob ~[1, 0]
        assert_eq!(categorical_sample(&[10.0, 0.0], 0.5), 0);
    }

    #[test]
    fn test_categorical_log_prob() {
        let lp = categorical_log_prob(&[0.0, 0.0, 0.0], 1);
        assert!((lp - -(3.0_f32).ln()).abs() < 1e-5);
    }

    #[test]
    fn test_categorical_entropy_uniform() {
        let ent = categorical_entropy(&[0.0, 0.0, 0.0, 0.0]);
        let expected = (4.0_f32).ln(); // max entropy for 4 categories
        assert!(
            (ent - expected).abs() < 1e-5,
            "expected {expected}, got {ent}"
        );
    }

    #[test]
    fn test_categorical_entropy_deterministic() {
        let ent = categorical_entropy(&[100.0, -100.0, -100.0]);
        assert!(
            ent < 0.01,
            "near-deterministic should have low entropy: {ent}"
        );
    }

    #[test]
    fn test_normal_log_prob() {
        // Standard normal: log_prob(0) = -0.5 * ln(2π)
        let expected = -0.5 * (2.0 * std::f32::consts::PI).ln();
        assert!((normal_log_prob(0.0, 0.0, 1.0) - expected).abs() < 1e-5);
    }

    #[test]
    fn test_normal_log_prob_shifted() {
        // Evaluating at the mean is invariant under shifting the mean.
        let at_mean = normal_log_prob(2.0, 2.0, 1.0);
        let lp_center = normal_log_prob(0.0, 0.0, 1.0);
        assert!((at_mean - lp_center).abs() < 1e-5, "shifted mean at center");
    }

    #[test]
    fn test_normal_entropy() {
        let expected = 0.5 * (2.0 * std::f32::consts::PI * std::f32::consts::E).ln();
        assert!((normal_entropy(1.0) - expected).abs() < 1e-5);
    }

    #[test]
    fn test_normal_entropy_wider_is_larger() {
        assert!(
            normal_entropy(2.0) > normal_entropy(1.0),
            "wider distribution should have higher entropy"
        );
    }

    #[test]
    fn test_tanh_correction_at_zero() {
        // tanh(0) = 0, so correction = -ln(1 + eps) ≈ 0
        assert!(tanh_log_prob_correction(0.0).abs() < 0.01);
    }

    #[test]
    fn test_tanh_correction_increases_at_extremes() {
        let c_small = tanh_log_prob_correction(0.5);
        let c_large = tanh_log_prob_correction(3.0);
        assert!(
            c_large > c_small,
            "correction should increase at extremes: {c_small} vs {c_large}"
        );
    }
}