rlox_core/training/
gae.rs

/// Result of a GAE computation: per-step advantages and returns.
///
/// `compute_gae` and its batched variants return `(advantages, returns)`
/// tuples; `GaeResult::from(tuple)` wraps such a pair in this named struct.
#[derive(Debug, Clone, PartialEq)]
pub struct GaeResult<T> {
    /// Per-step advantage estimates.
    pub advantages: Vec<T>,
    /// Per-step return estimates (advantages + values).
    pub returns: Vec<T>,
}

impl<T> From<(Vec<T>, Vec<T>)> for GaeResult<T> {
    /// Wraps an `(advantages, returns)` pair, taken in that order.
    fn from((advantages, returns): (Vec<T>, Vec<T>)) -> Self {
        Self {
            advantages,
            returns,
        }
    }
}
9
/// Compute Generalized Advantage Estimation (GAE) for a single rollout.
///
/// Iterates backwards over the rollout, computing:
///   delta_t = reward_t + gamma * V(t+1) * (1 - done_t) - V(t)
///   A_t = delta_t + gamma * lambda * (1 - done_t) * A(t+1)
///   return_t = A_t + V(t)
///
/// For the final step, `V(t+1)` is `last_value` and `A(t+1)` is 0.
///
/// The `dones` slice uses `f64` where 0.0 = not done, 1.0 = done,
/// matching the common Python/numpy convention. A done at step `t` zeroes
/// both the bootstrap `V(t+1)` and the carried `A(t+1)` for that step.
///
/// Returns `(advantages, returns)`, each of length `rewards.len()`.
///
/// # Panics
///
/// Panics if `rewards`, `values`, and `dones` have different lengths.
pub fn compute_gae(
    rewards: &[f64],
    values: &[f64],
    dones: &[f64],
    last_value: f64,
    gamma: f64,
    gae_lambda: f64,
) -> (Vec<f64>, Vec<f64>) {
    let n = rewards.len();
    // Checked in every build profile: these are O(1), and without them a
    // mismatch either panics with an opaque index error or silently ignores
    // trailing elements of `values` / `dones`.
    assert_eq!(values.len(), n, "values.len() must equal rewards.len()");
    assert_eq!(dones.len(), n, "dones.len() must equal rewards.len()");
    if n == 0 {
        return (Vec::new(), Vec::new());
    }

    let mut advantages = vec![0.0; n];
    let mut returns = vec![0.0; n];

    // Peel the last step: it bootstraps from `last_value` and has no A(t+1)
    // term, which keeps the branch out of the inner loop.
    let last = n - 1;
    let last_nt = 1.0 - dones[last];
    let mut last_gae = rewards[last] + gamma * last_value * last_nt - values[last];
    advantages[last] = last_gae;
    returns[last] = last_gae + values[last];

    for t in (0..last).rev() {
        let next_non_terminal = 1.0 - dones[t];
        let delta = rewards[t] + gamma * values[t + 1] * next_non_terminal - values[t];
        last_gae = delta + gamma * gae_lambda * next_non_terminal * last_gae;
        advantages[t] = last_gae;
        // Computed here rather than in a second pass over `advantages`.
        returns[t] = last_gae + values[t];
    }

    (advantages, returns)
}
62
63/// Batched GAE: compute GAE for multiple environments in a single call.
64///
65/// All inputs are flat slices of length `n_envs * n_steps`, laid out as
66/// `[env0_step0, env0_step1, ..., env1_step0, env1_step1, ...]`.
67/// `last_values` has length `n_envs`.
68///
69/// Returns `(advantages, returns)` each of length `n_envs * n_steps`.
70///
71/// # Panics
72///
73/// Panics in debug builds if input slice lengths do not match
74/// `n_envs * n_steps`.
75pub fn compute_gae_batched(
76    rewards: &[f64],
77    values: &[f64],
78    dones: &[f64],
79    last_values: &[f64],
80    n_steps: usize,
81    gamma: f64,
82    gae_lambda: f64,
83) -> (Vec<f64>, Vec<f64>) {
84    let n_envs = last_values.len();
85    if n_envs == 0 || n_steps == 0 {
86        return (Vec::new(), Vec::new());
87    }
88    let expected_len = n_envs * n_steps;
89    debug_assert_eq!(
90        rewards.len(),
91        expected_len,
92        "rewards.len() must equal n_envs * n_steps"
93    );
94    debug_assert_eq!(
95        values.len(),
96        expected_len,
97        "values.len() must equal n_envs * n_steps"
98    );
99    debug_assert_eq!(
100        dones.len(),
101        expected_len,
102        "dones.len() must equal n_envs * n_steps"
103    );
104
105    use rayon::prelude::*;
106
107    let mut all_advantages = vec![0.0; n_envs * n_steps];
108    let mut all_returns = vec![0.0; n_envs * n_steps];
109
110    all_advantages
111        .par_chunks_mut(n_steps)
112        .zip(all_returns.par_chunks_mut(n_steps))
113        .enumerate()
114        .for_each(|(env_idx, (adv_chunk, ret_chunk))| {
115            let offset = env_idx * n_steps;
116            let r = &rewards[offset..offset + n_steps];
117            let v = &values[offset..offset + n_steps];
118            let d = &dones[offset..offset + n_steps];
119            let lv = last_values[env_idx];
120
121            let mut last_gae = 0.0;
122            for t in (0..n_steps).rev() {
123                let next_non_terminal = 1.0 - d[t];
124                let next_value = if t == n_steps - 1 { lv } else { v[t + 1] };
125                let delta = r[t] + gamma * next_value * next_non_terminal - v[t];
126                last_gae = delta + gamma * gae_lambda * next_non_terminal * last_gae;
127                adv_chunk[t] = last_gae;
128                ret_chunk[t] = last_gae + v[t];
129            }
130        });
131
132    (all_advantages, all_returns)
133}
134
135/// Batched GAE in f32 — avoids f64 conversion overhead from Python.
136///
137/// Same layout as `compute_gae_batched` but operates on f32.
138///
139/// # Panics
140///
141/// Panics in debug builds if input slice lengths do not match
142/// `n_envs * n_steps`.
143pub fn compute_gae_batched_f32(
144    rewards: &[f32],
145    values: &[f32],
146    dones: &[f32],
147    last_values: &[f32],
148    n_steps: usize,
149    gamma: f32,
150    gae_lambda: f32,
151) -> (Vec<f32>, Vec<f32>) {
152    let n_envs = last_values.len();
153    if n_envs == 0 || n_steps == 0 {
154        return (Vec::new(), Vec::new());
155    }
156    let expected_len = n_envs * n_steps;
157    debug_assert_eq!(
158        rewards.len(),
159        expected_len,
160        "rewards.len() must equal n_envs * n_steps"
161    );
162    debug_assert_eq!(
163        values.len(),
164        expected_len,
165        "values.len() must equal n_envs * n_steps"
166    );
167    debug_assert_eq!(
168        dones.len(),
169        expected_len,
170        "dones.len() must equal n_envs * n_steps"
171    );
172
173    use rayon::prelude::*;
174
175    let mut all_advantages = vec![0.0f32; n_envs * n_steps];
176    let mut all_returns = vec![0.0f32; n_envs * n_steps];
177
178    all_advantages
179        .par_chunks_mut(n_steps)
180        .zip(all_returns.par_chunks_mut(n_steps))
181        .enumerate()
182        .for_each(|(env_idx, (adv_chunk, ret_chunk))| {
183            let offset = env_idx * n_steps;
184            let r = &rewards[offset..offset + n_steps];
185            let v = &values[offset..offset + n_steps];
186            let d = &dones[offset..offset + n_steps];
187            let lv = last_values[env_idx];
188
189            // Peel last step
190            let last_nt = 1.0 - d[n_steps - 1];
191            let last_delta = r[n_steps - 1] + gamma * lv * last_nt - v[n_steps - 1];
192            let mut last_gae = last_delta;
193            adv_chunk[n_steps - 1] = last_gae;
194            ret_chunk[n_steps - 1] = last_gae + v[n_steps - 1];
195
196            for t in (0..n_steps - 1).rev() {
197                let next_non_terminal = 1.0 - d[t];
198                let delta = r[t] + gamma * v[t + 1] * next_non_terminal - v[t];
199                last_gae = delta + gamma * gae_lambda * next_non_terminal * last_gae;
200                adv_chunk[t] = last_gae;
201                ret_chunk[t] = last_gae + v[t];
202            }
203        });
204
205    (all_advantages, all_returns)
206}
207
#[cfg(test)]
mod tests {
    use super::*;

    // Helper to convert bool slices to f64 for the test interface
    fn bools_to_f64(bools: &[bool]) -> Vec<f64> {
        bools.iter().map(|&b| if b { 1.0 } else { 0.0 }).collect()
    }

    /// Asserts `actual` is within `tol` of `expected`, reporting both on failure.
    fn assert_close(actual: f64, expected: f64, tol: f64) {
        assert!(
            (actual - expected).abs() < tol,
            "expected {expected}, got {actual} (tol {tol})"
        );
    }

    #[test]
    fn gae_single_step_episode() {
        let rewards = &[1.0];
        let values = &[0.5];
        let dones = bools_to_f64(&[true]);
        let last_value = 0.0;
        let gamma = 0.99;
        let gae_lambda = 0.95;
        let (advantages, returns) =
            compute_gae(rewards, values, &dones, last_value, gamma, gae_lambda);
        assert_eq!(advantages.len(), 1);
        // Terminal: delta = 1.0 - 0.5, no bootstrap.
        assert_close(advantages[0], 0.5, 1e-12);
        assert_close(returns[0], 1.0, 1e-12);
    }

    #[test]
    fn gae_multi_step_no_termination() {
        let rewards = &[1.0, 1.0, 1.0];
        let values = &[0.0, 0.0, 0.0];
        let dones = bools_to_f64(&[false, false, false]);
        let last_value = 0.0;
        let gamma = 0.99;
        let gae_lambda = 0.95;
        let (advantages, _returns) =
            compute_gae(rewards, values, &dones, last_value, gamma, gae_lambda);
        assert_eq!(advantages.len(), 3);
        // Last step: delta = 1.0 + 0.99*0 - 0 = 1.0, A = 1.0
        assert_close(advantages[2], 1.0, 1e-9);
        // Second step: delta = 1.0, A = 1.0 + 0.99*0.95*1.0 = 1.9405
        assert_close(advantages[1], 1.9405, 1e-9);
        // First step: A = 1.0 + 0.99*0.95*1.9405 = 2.82504025
        assert_close(advantages[0], 2.82504025, 1e-9);
    }

    #[test]
    fn gae_resets_at_episode_boundary() {
        let rewards = &[1.0, 1.0, 1.0];
        let values = &[0.0, 0.0, 0.0];
        let dones = bools_to_f64(&[false, true, false]);
        let last_value = 0.0;
        let gamma = 0.99;
        let gae_lambda = 0.95;
        let (advantages, _) = compute_gae(rewards, values, &dones, last_value, gamma, gae_lambda);
        // Step 2 (not done, bootstraps from last_value = 0): A = 1.0
        assert_close(advantages[2], 1.0, 1e-9);
        // Step 1 (terminal): delta = 1.0, and A(2) must not leak in.
        assert_close(advantages[1], 1.0, 1e-9);
        // Step 0 accumulates normally from step 1: 1.0 + 0.9405 * 1.0
        assert_close(advantages[0], 1.9405, 1e-9);
    }

    #[test]
    fn gae_returns_are_advantages_plus_values() {
        let rewards = &[1.0, 2.0, 3.0];
        let values = &[0.5, 1.0, 1.5];
        let dones = bools_to_f64(&[false, false, true]);
        let last_value = 0.0;
        let (advantages, returns) = compute_gae(rewards, values, &dones, last_value, 0.99, 0.95);
        for i in 0..3 {
            assert_close(returns[i], advantages[i] + values[i], 1e-12);
        }
    }

    #[test]
    fn gae_empty_input() {
        let (advantages, returns) = compute_gae(&[], &[], &[], 0.0, 0.99, 0.95);
        assert!(advantages.is_empty());
        assert!(returns.is_empty());
    }

    #[test]
    fn gae_lambda_zero_is_one_step_td() {
        let rewards = &[1.0, 1.0];
        let values = &[0.5, 0.5];
        let dones = bools_to_f64(&[false, false]);
        let last_value = 0.5;
        let (advantages, _) = compute_gae(rewards, values, &dones, last_value, 0.99, 0.0);
        // With lambda = 0 every advantage is its own TD error:
        // delta_t = 1.0 + 0.99*0.5 - 0.5 = 0.995
        assert_close(advantages[1], 0.995, 1e-12);
        assert_close(advantages[0], 0.995, 1e-12);
    }

    #[test]
    fn gae_lambda_one_is_monte_carlo() {
        let rewards = &[1.0, 1.0, 1.0];
        let values = &[0.0, 0.0, 0.0];
        let dones = bools_to_f64(&[false, false, true]);
        let (advantages, _) = compute_gae(rewards, values, &dones, 0.0, 0.99, 1.0);
        // Monte Carlo returns (values are zero): 1, 1 + 0.99, 1 + 0.99 + 0.99^2
        assert_close(advantages[2], 1.0, 1e-9);
        assert_close(advantages[1], 1.99, 1e-9);
        assert_close(advantages[0], 2.9701, 1e-9);
    }

    #[test]
    #[should_panic]
    fn gae_rejects_mismatched_lengths() {
        // `values` is shorter than `rewards`.
        let _ = compute_gae(&[1.0, 2.0], &[0.5], &[0.0, 0.0], 0.0, 0.99, 0.95);
    }

    #[test]
    fn gae_batched_matches_unbatched() {
        let gamma = 0.99;
        let lam = 0.95;
        // Two envs, 3 steps each
        let rewards = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let values = vec![0.5, 1.0, 1.5, 2.0, 2.5, 3.0];
        let dones = vec![0.0, 0.0, 1.0, 0.0, 1.0, 0.0];
        let last_values = vec![0.0, 0.5];

        let (adv_b, ret_b) =
            compute_gae_batched(&rewards, &values, &dones, &last_values, 3, gamma, lam);

        let (adv0, ret0) = compute_gae(
            &rewards[..3],
            &values[..3],
            &dones[..3],
            last_values[0],
            gamma,
            lam,
        );
        let (adv1, ret1) = compute_gae(
            &rewards[3..],
            &values[3..],
            &dones[3..],
            last_values[1],
            gamma,
            lam,
        );

        for i in 0..3 {
            assert!(
                (adv_b[i] - adv0[i]).abs() < 1e-12,
                "env0 adv mismatch at {i}"
            );
            assert!(
                (ret_b[i] - ret0[i]).abs() < 1e-12,
                "env0 ret mismatch at {i}"
            );
            assert!(
                (adv_b[3 + i] - adv1[i]).abs() < 1e-12,
                "env1 adv mismatch at {i}"
            );
            assert!(
                (ret_b[3 + i] - ret1[i]).abs() < 1e-12,
                "env1 ret mismatch at {i}"
            );
        }
    }

    #[test]
    fn gae_batched_empty() {
        let (adv, ret) = compute_gae_batched(&[], &[], &[], &[], 0, 0.99, 0.95);
        assert!(adv.is_empty());
        assert!(ret.is_empty());
    }

    #[test]
    #[should_panic]
    fn gae_batched_rejects_mismatched_lengths() {
        // `rewards` is one element short of n_envs * n_steps = 6.
        let _ = compute_gae_batched(
            &[1.0; 5],
            &[0.5; 6],
            &[0.0; 6],
            &[0.0, 0.0],
            3,
            0.99,
            0.95,
        );
    }

    #[test]
    fn gae_batched_f32_empty() {
        let (adv, ret) = compute_gae_batched_f32(&[], &[], &[], &[], 0, 0.99, 0.95);
        assert!(adv.is_empty());
        assert!(ret.is_empty());
    }

    #[test]
    fn gae_batched_f32_matches_f64() {
        let gamma = 0.99f32;
        let lam = 0.95f32;
        let rewards_f32: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let values_f32: Vec<f32> = vec![0.5, 1.0, 1.5, 2.0, 2.5, 3.0];
        let dones_f32: Vec<f32> = vec![0.0, 0.0, 1.0, 0.0, 1.0, 0.0];
        let last_values_f32: Vec<f32> = vec![0.0, 0.5];

        let (adv_f32, ret_f32) = compute_gae_batched_f32(
            &rewards_f32,
            &values_f32,
            &dones_f32,
            &last_values_f32,
            3,
            gamma,
            lam,
        );

        let rewards_f64: Vec<f64> = rewards_f32.iter().map(|&x| x as f64).collect();
        let values_f64: Vec<f64> = values_f32.iter().map(|&x| x as f64).collect();
        let dones_f64: Vec<f64> = dones_f32.iter().map(|&x| x as f64).collect();
        let last_values_f64: Vec<f64> = last_values_f32.iter().map(|&x| x as f64).collect();

        let (adv_f64, ret_f64) = compute_gae_batched(
            &rewards_f64,
            &values_f64,
            &dones_f64,
            &last_values_f64,
            3,
            0.99,
            0.95,
        );

        for i in 0..6 {
            assert!(
                (adv_f32[i] as f64 - adv_f64[i]).abs() < 1e-5,
                "adv mismatch at {i}"
            );
            assert!(
                (ret_f32[i] as f64 - ret_f64[i]).abs() < 1e-5,
                "ret mismatch at {i}"
            );
        }
    }

    mod proptests {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            #[test]
            fn gae_returns_equal_advantages_plus_values(n in 1..500usize) {
                let rewards: Vec<f64> = (0..n).map(|i| (i as f64) * 0.1).collect();
                let values: Vec<f64> = (0..n).map(|i| (i as f64) * 0.05).collect();
                let dones: Vec<f64> = (0..n).map(|i| if i % 10 == 9 { 1.0 } else { 0.0 }).collect();
                let (advantages, returns) = compute_gae(&rewards, &values, &dones, 0.0, 0.99, 0.95);
                for i in 0..n {
                    let diff = (returns[i] - (advantages[i] + values[i])).abs();
                    prop_assert!(diff < 1e-10, "mismatch at index {}: returns={}, adv+val={}", i, returns[i], advantages[i] + values[i]);
                }
            }

            #[test]
            fn gae_length_matches_input(n in 0..500usize) {
                let rewards = vec![1.0; n];
                let values = vec![0.5; n];
                let dones = vec![0.0; n];
                let (advantages, returns) = compute_gae(&rewards, &values, &dones, 0.0, 0.99, 0.95);
                prop_assert_eq!(advantages.len(), n);
                prop_assert_eq!(returns.len(), n);
            }

            #[test]
            fn gae_batched_matches_per_env(n_envs in 1..6usize, n_steps in 1..64usize) {
                let len = n_envs * n_steps;
                let rewards: Vec<f64> = (0..len).map(|i| ((i * 7) % 11) as f64 * 0.1 - 0.5).collect();
                let values: Vec<f64> = (0..len).map(|i| ((i * 3) % 5) as f64 * 0.2).collect();
                let dones: Vec<f64> = (0..len).map(|i| if i % 13 == 12 { 1.0 } else { 0.0 }).collect();
                let last_values: Vec<f64> = (0..n_envs).map(|e| e as f64 * 0.25).collect();

                let (adv_b, ret_b) =
                    compute_gae_batched(&rewards, &values, &dones, &last_values, n_steps, 0.99, 0.95);
                prop_assert_eq!(adv_b.len(), len);
                prop_assert_eq!(ret_b.len(), len);

                for env in 0..n_envs {
                    let span = env * n_steps..(env + 1) * n_steps;
                    let (adv, ret) = compute_gae(
                        &rewards[span.clone()],
                        &values[span.clone()],
                        &dones[span.clone()],
                        last_values[env],
                        0.99,
                        0.95,
                    );
                    for t in 0..n_steps {
                        let idx = env * n_steps + t;
                        prop_assert!((adv_b[idx] - adv[t]).abs() < 1e-9, "adv mismatch env {} step {}", env, t);
                        prop_assert!((ret_b[idx] - ret[t]).abs() < 1e-9, "ret mismatch env {} step {}", env, t);
                    }
                }
            }
        }
    }
}