/// Numerically stable log-softmax via the log-sum-exp trick: subtracting the
/// max logit before exponentiating prevents overflow, since
/// log_softmax(x_i) = x_i - (max + ln Σ_j exp(x_j - max)).
pub fn log_softmax(logits: &[f32]) -> Vec<f32> {
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let sum_exp: f32 = logits.iter().map(|&x| (x - max).exp()).sum();
    let log_sum_exp = max + sum_exp.ln();
    logits.iter().map(|&x| x - log_sum_exp).collect()
}

/// Sample an index from the categorical distribution defined by `logits`,
/// via inverse-CDF sampling. `uniform_rand` must be drawn uniformly from [0, 1).
pub fn categorical_sample(logits: &[f32], uniform_rand: f32) -> usize {
    let log_probs = log_softmax(logits);
    let probs: Vec<f32> = log_probs.iter().map(|&lp| lp.exp()).collect();

    let mut cumsum = 0.0;
    for (i, &p) in probs.iter().enumerate() {
        cumsum += p;
        if uniform_rand < cumsum {
            return i;
        }
    }
    // Round-off can leave the final cumsum slightly below 1.0; fall back to
    // the last index rather than panicking.
    probs.len() - 1
}
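
// Usage sketch (an assumption, not part of this module): pair
// `categorical_sample` with any RNG that yields uniform floats in [0, 1),
// e.g. with the `rand` crate (0.8 API):
//
//     use rand::Rng;
//     let action = categorical_sample(&logits, rand::thread_rng().gen::<f32>());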

/// Log-probability of `action` under the categorical distribution defined by `logits`.
pub fn categorical_log_prob(logits: &[f32], action: usize) -> f32 {
    let log_probs = log_softmax(logits);
    log_probs[action]
}

/// Entropy of the categorical distribution defined by `logits`:
/// H = -Σ_i p_i ln p_i, with the 0 · ln 0 terms treated as 0.
pub fn categorical_entropy(logits: &[f32]) -> f32 {
    let log_probs = log_softmax(logits);
    let probs: Vec<f32> = log_probs.iter().map(|&lp| lp.exp()).collect();
    -probs
        .iter()
        .zip(log_probs.iter())
        .map(|(&p, &lp)| if p > 0.0 { p * lp } else { 0.0 })
        .sum::<f32>()
}
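
// Illustrative sketch only: these primitives typically combine into a
// per-sample policy-gradient loss of the form -log π(a) · advantage - β · H(π).
// The name `pg_loss_term` and the entropy coefficient 0.01 are assumptions
// for illustration, not part of the original API.
pub fn pg_loss_term(logits: &[f32], action: usize, advantage: f32) -> f32 {
    let entropy_coef = 0.01;
    -categorical_log_prob(logits, action) * advantage
        - entropy_coef * categorical_entropy(logits)
}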

/// Log-density of `x` under the normal distribution N(`mean`, `std`²).
pub fn normal_log_prob(x: f32, mean: f32, std: f32) -> f32 {
    let var = std * std;
    let log_std = std.ln();
    -0.5 * ((x - mean) * (x - mean) / var + 2.0 * log_std + (2.0 * std::f32::consts::PI).ln())
}

/// Entropy of a normal distribution with standard deviation `std`:
/// H = 0.5 · ln(2πe · σ²).
pub fn normal_entropy(std: f32) -> f32 {
    0.5 * (2.0 * std::f32::consts::PI * std::f32::consts::E * std * std).ln()
}

/// Change-of-variables correction for a tanh-squashed action: returns
/// -ln(1 - tanh²(u) + ε), the term added to the pre-squash log-density.
/// The small epsilon guards against ln(0) when tanh saturates.
pub fn tanh_log_prob_correction(pre_tanh: f32) -> f32 {
    let tanh_x = pre_tanh.tanh();
    -(1.0 - tanh_x * tanh_x + 1e-6).ln()
}
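
// Sketch of how the correction is typically applied (a SAC-style tanh-squashed
// Gaussian policy): the log-density of the squashed action is the Gaussian
// log-density of the pre-tanh value plus the change-of-variables term.
// `tanh_normal_log_prob` is an illustrative name, not part of the original API.
pub fn tanh_normal_log_prob(pre_tanh: f32, mean: f32, std: f32) -> f32 {
    normal_log_prob(pre_tanh, mean, std) + tanh_log_prob_correction(pre_tanh)
}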

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_log_softmax_sums_to_one() {
        let logits = vec![1.0, 2.0, 3.0];
        let log_probs = log_softmax(&logits);
        let sum: f32 = log_probs.iter().map(|&lp| lp.exp()).sum();
        assert!(
            (sum - 1.0).abs() < 1e-5,
            "softmax should sum to 1, got {sum}"
        );
    }

    #[test]
    fn test_log_softmax_with_large_logits() {
        let logits = vec![1000.0, 1001.0, 1002.0];
        let log_probs = log_softmax(&logits);
        let sum: f32 = log_probs.iter().map(|&lp| lp.exp()).sum();
        assert!((sum - 1.0).abs() < 1e-3, "numerical stability: {sum}");
    }

    #[test]
    fn test_log_softmax_uniform() {
        let logits = vec![0.0, 0.0, 0.0, 0.0];
        let log_probs = log_softmax(&logits);
        let expected = -(4.0_f32).ln();
        for &lp in &log_probs {
            assert!((lp - expected).abs() < 1e-6);
        }
    }

    #[test]
    fn test_categorical_sample_boundaries() {
        let logits = vec![0.0, 0.0];
        assert_eq!(categorical_sample(&logits, 0.0), 0);
        assert_eq!(categorical_sample(&logits, 0.49), 0);
        assert_eq!(categorical_sample(&logits, 0.51), 1);
        assert_eq!(categorical_sample(&logits, 0.99), 1);
    }

    #[test]
    fn test_categorical_sample_skewed() {
        let logits = vec![10.0, 0.0];
        // p(0) = 1 / (1 + e^{-10}) ≈ 0.99995, so 0.5 falls well inside class 0.
        assert_eq!(categorical_sample(&logits, 0.5), 0);
    }

    #[test]
    fn test_categorical_log_prob() {
        let logits = vec![0.0, 0.0, 0.0];
        let lp = categorical_log_prob(&logits, 1);
        let expected = -(3.0_f32).ln();
        assert!((lp - expected).abs() < 1e-5);
    }

    #[test]
    fn test_categorical_entropy_uniform() {
        let logits = vec![0.0, 0.0, 0.0, 0.0];
        let ent = categorical_entropy(&logits);
        let expected = (4.0_f32).ln();
        assert!(
            (ent - expected).abs() < 1e-5,
            "expected {expected}, got {ent}"
        );
    }

    #[test]
    fn test_categorical_entropy_deterministic() {
        let logits = vec![100.0, -100.0, -100.0];
        let ent = categorical_entropy(&logits);
        assert!(
            ent < 0.01,
            "near-deterministic should have low entropy: {ent}"
        );
    }

    #[test]
    fn test_normal_log_prob() {
        // The standard normal peaks at 1/√(2π), so its log-density at the
        // mean is -0.5 · ln(2π).
        let lp = normal_log_prob(0.0, 0.0, 1.0);
        let expected = -0.5 * (2.0 * std::f32::consts::PI).ln();
        assert!((lp - expected).abs() < 1e-5);
    }

    #[test]
    fn test_normal_log_prob_shifted() {
        let lp = normal_log_prob(2.0, 2.0, 1.0);
        let lp_center = normal_log_prob(0.0, 0.0, 1.0);
        assert!(
            (lp - lp_center).abs() < 1e-5,
            "log-prob at the mean should be invariant to a mean shift"
        );
    }

    #[test]
    fn test_normal_entropy() {
        let ent = normal_entropy(1.0);
        let expected = 0.5 * (2.0 * std::f32::consts::PI * std::f32::consts::E).ln();
        assert!((ent - expected).abs() < 1e-5);
    }

    #[test]
    fn test_normal_entropy_wider_is_larger() {
        let ent1 = normal_entropy(1.0);
        let ent2 = normal_entropy(2.0);
        assert!(ent2 > ent1, "wider distribution should have higher entropy");
    }

    #[test]
    fn test_tanh_correction_at_zero() {
        // tanh is locally the identity at 0 (tanh'(0) = 1), so the Jacobian
        // correction is ≈ 0 there (exactly -ln(1 + ε) with the epsilon guard).
        let correction = tanh_log_prob_correction(0.0);
        assert!(correction.abs() < 0.01);
    }

    #[test]
    fn test_tanh_correction_increases_at_extremes() {
        let c_small = tanh_log_prob_correction(0.5);
        let c_large = tanh_log_prob_correction(3.0);
        assert!(
            c_large > c_small,
            "correction should increase at extremes: {c_small} vs {c_large}"
        );
    }
}