rlox_core/training/
vtrace.rs

1use crate::error::RloxError;
2
3/// Compute V-trace targets and policy gradient advantages (Espeholt et al. 2018).
4///
5/// Processes backwards from t=n-1 to t=0:
6///   rho_t = min(rho_bar, exp(log_rhos[t]))
7///   c_t   = min(c_bar,   exp(log_rhos[t]))
8///   non_terminal = 1.0 - dones[t]
9///   delta_t = rho_t * (rewards[t] + gamma * non_terminal * values[t+1] - values[t])
10///   vs[t]   = values[t] + delta_t + gamma * non_terminal * c_t * (vs[t+1] - values[t+1])
11///   pg_advantages[t] = rho_t * (rewards[t] + gamma * non_terminal * vs[t+1] - values[t])
12///
13/// Uses `bootstrap_value` for values[n] and vs[n], zeroed when the last step
14/// is terminal (`dones[n-1] == 1.0`).
15///
16/// Returns `(vs, pg_advantages)`.
17pub fn compute_vtrace(
18    log_rhos: &[f32],
19    rewards: &[f32],
20    values: &[f32],
21    dones: &[f32],
22    bootstrap_value: f32,
23    gamma: f32,
24    rho_bar: f32,
25    c_bar: f32,
26) -> Result<(Vec<f32>, Vec<f32>), RloxError> {
27    let n = log_rhos.len();
28
29    if rewards.len() != n || values.len() != n || dones.len() != n {
30        return Err(RloxError::ShapeMismatch {
31            expected: format!("all slices length {n}"),
32            got: format!(
33                "log_rhos={}, rewards={}, values={}, dones={}",
34                n,
35                rewards.len(),
36                values.len(),
37                dones.len()
38            ),
39        });
40    }
41
42    if n == 0 {
43        return Ok((Vec::new(), Vec::new()));
44    }
45
46    let mut vs = vec![0.0f32; n];
47    let mut pg_advantages = vec![0.0f32; n];
48
49    // Handle last step (t = n-1) outside the loop
50    let last = n - 1;
51    {
52        let ratio = log_rhos[last].exp();
53        let rho_t = rho_bar.min(ratio);
54        let non_terminal = 1.0 - dones[last];
55        let next_value = bootstrap_value * non_terminal;
56
57        let delta_t = rho_t * (rewards[last] + gamma * next_value - values[last]);
58        // vs_next for the last step is bootstrap_value (zeroed if terminal)
59        let vs_next_val = bootstrap_value * non_terminal;
60        vs[last] = values[last]
61            + delta_t
62            + gamma * non_terminal * rho_bar.min(ratio).min(c_bar) * (vs_next_val - next_value);
63        pg_advantages[last] = rho_t * (rewards[last] + gamma * vs_next_val - values[last]);
64    }
65
66    // Iterate backwards for remaining steps
67    let mut vs_next = vs[last];
68
69    for t in (0..last).rev() {
70        let ratio = log_rhos[t].exp();
71        let rho_t = rho_bar.min(ratio);
72        let c_t = c_bar.min(ratio);
73        let non_terminal = 1.0 - dones[t];
74
75        let next_value = values[t + 1];
76
77        let delta_t = rho_t * (rewards[t] + gamma * non_terminal * next_value - values[t]);
78        vs[t] = values[t] + delta_t + gamma * non_terminal * c_t * (vs_next - next_value);
79        pg_advantages[t] = rho_t * (rewards[t] + gamma * non_terminal * vs_next - values[t]);
80
81        vs_next = vs[t];
82    }
83
84    Ok((vs, pg_advantages))
85}
86
#[cfg(test)]
mod tests {
    use super::*;

    // Shared 3-step fixture for the hand-computed reference tests.
    const REF_GAMMA: f32 = 0.9;
    const REF_RHO_BAR: f32 = 1.5;
    const REF_C_BAR: f32 = 1.2;
    const REF_LOG_RHOS: [f32; 3] = [0.2, -0.3, 0.8];
    const REF_REWARDS: [f32; 3] = [1.0, 2.0, 3.0];
    const REF_VALUES: [f32; 3] = [0.5, 1.0, 1.5];
    const REF_BOOTSTRAP: f32 = 2.0;

    /// Assert that `got` and `want` have equal length and agree element-wise
    /// within `tol`; the failure message names the offending index.
    fn assert_all_close(label: &str, got: &[f32], want: &[f32], tol: f32) {
        assert_eq!(got.len(), want.len(), "{label}: length mismatch");
        for (i, (&g, &w)) in got.iter().zip(want).enumerate() {
            assert!((g - w).abs() < tol, "{label}[{i}]: got {g}, expected {w}");
        }
    }

    /// Run `compute_vtrace` on the shared 3-step fixture with the given `dones`.
    fn run_reference(dones: &[f32]) -> (Vec<f32>, Vec<f32>) {
        compute_vtrace(
            &REF_LOG_RHOS,
            &REF_REWARDS,
            &REF_VALUES,
            dones,
            REF_BOOTSTRAP,
            REF_GAMMA,
            REF_RHO_BAR,
            REF_C_BAR,
        )
        .unwrap()
    }

    /// Hand-computed `(vs, pg_advantages)` for the shared fixture with no
    /// terminal steps, worked backwards from t=2 to t=0.
    fn reference_expected() -> (Vec<f32>, Vec<f32>) {
        // t=2: rho = min(1.5, exp(0.8) ~ 2.2255) = 1.5, c = min(1.2, 2.2255) = 1.2.
        //      values[3] and vs[3] are both the bootstrap (2.0), so the
        //      correction term 0.9 * c * (2.0 - 2.0) vanishes.
        let rho_2 = 1.5_f32;
        let c_2 = 1.2_f32;
        let delta_2 = rho_2 * (3.0 + 0.9 * 2.0 - 1.5); // 1.5 * 3.3 = 4.95
        let vs_2 = 1.5 + delta_2 + 0.9 * c_2 * (2.0 - 2.0); // 6.45
        let pg_2 = rho_2 * (3.0 + 0.9 * 2.0 - 1.5); // 4.95

        // t=1: rho = c = exp(-0.3) ~ 0.7408, below both clip thresholds.
        //      next value = values[2] = 1.5, vs_next = vs[2] = 6.45.
        let rho_1 = (-0.3_f32).exp();
        let c_1 = REF_C_BAR.min(rho_1);
        let delta_1 = rho_1 * (2.0 + 0.9 * 1.5 - 1.0);
        let vs_1 = 1.0 + delta_1 + 0.9 * c_1 * (vs_2 - 1.5);
        let pg_1 = rho_1 * (2.0 + 0.9 * vs_2 - 1.0);

        // t=0: rho = exp(0.2) ~ 1.2214 (under rho_bar), c = min(1.2, 1.2214) = 1.2.
        //      next value = values[1] = 1.0, vs_next = vs[1].
        let rho_0 = (0.2_f32).exp();
        let c_0 = REF_C_BAR.min(rho_0);
        let delta_0 = rho_0 * (1.0 + 0.9 * 1.0 - 0.5);
        let vs_0 = 0.5 + delta_0 + 0.9 * c_0 * (vs_1 - 1.0);
        let pg_0 = rho_0 * (1.0 + 0.9 * vs_1 - 0.5);

        (vec![vs_0, vs_1, vs_2], vec![pg_0, pg_1, pg_2])
    }

    #[test]
    fn vtrace_empty_input() {
        let (vs, adv) = compute_vtrace(&[], &[], &[], &[], 0.0, 0.99, 1.0, 1.0).unwrap();
        assert!(vs.is_empty());
        assert!(adv.is_empty());
    }

    #[test]
    fn vtrace_mismatched_lengths() {
        let result = compute_vtrace(&[0.0], &[1.0, 2.0], &[0.5], &[0.0], 0.0, 0.99, 1.0, 1.0);
        assert!(matches!(result, Err(RloxError::ShapeMismatch { .. })));
    }

    #[test]
    fn vtrace_on_policy_matches_gae_like() {
        // With log_rhos = 0 (on-policy), rho = c = 1 and V-trace reduces to the
        // lambda = 1 (discounted, bootstrapped) return:
        //   t=2: delta = 1, vs[2] = 0 + 1 + 0.99 * (0 - 0)    = 1
        //   t=1: delta = 1, vs[1] = 0 + 1 + 0.99 * (1 - 0)    = 1.99
        //   t=0: delta = 1, vs[0] = 0 + 1 + 0.99 * (1.99 - 0) = 2.9701
        // With zero values the advantages equal the targets.
        let log_rhos = vec![0.0; 3];
        let rewards = vec![1.0; 3];
        let values = vec![0.0; 3];
        let dones = vec![0.0; 3];
        let (vs, adv) =
            compute_vtrace(&log_rhos, &rewards, &values, &dones, 0.0, 0.99, 1.0, 1.0).unwrap();

        assert_all_close("vs", &vs, &[2.9701, 1.99, 1.0], 1e-4);
        assert_all_close("adv", &adv, &[2.9701, 1.99, 1.0], 1e-4);
    }

    #[test]
    fn vtrace_single_step() {
        // rho_bar / c_bar are large enough that clipping never engages.
        let log_rho = 0.5_f32;
        let values = vec![0.5];
        let dones = vec![0.0];
        let (vs, adv) =
            compute_vtrace(&[log_rho], &[1.0], &values, &dones, 0.0, 0.99, 10.0, 10.0).unwrap();

        // Only step (t = n-1): next value and vs_next are both the bootstrap (0):
        //   delta     = rho * (1.0 + 0.99 * 0 - 0.5) = rho * 0.5
        //   vs[0]     = 0.5 + delta + 0.99 * c * (0 - 0) = 0.5 + rho * 0.5
        //   pg_adv[0] = rho * (1.0 + 0.99 * 0 - 0.5)     = rho * 0.5
        let rho = log_rho.exp(); // ~1.6487
        assert_all_close("vs", &vs, &[0.5 + rho * 0.5], 1e-5);
        assert_all_close("adv", &adv, &[rho * 0.5], 1e-5);
    }

    #[test]
    fn vtrace_clipping_reduces_correction() {
        // exp(5) ~ 148, far above rho_bar = 1, so clipping must cap the correction.
        let log_rhos = vec![5.0];
        let rewards = vec![1.0];
        let values = vec![0.0];
        let dones = vec![0.0];
        let (vs_clipped, _) =
            compute_vtrace(&log_rhos, &rewards, &values, &dones, 0.0, 0.99, 1.0, 1.0).unwrap();
        let (vs_unclipped, _) =
            compute_vtrace(&log_rhos, &rewards, &values, &dones, 0.0, 0.99, 200.0, 200.0)
                .unwrap();

        // rho clamped to 1.0 => delta = 1 * (1 - 0) = 1 => vs = 1.0
        assert!((vs_clipped[0] - 1.0).abs() < 1e-5);
        // rho ~ 148 unclipped => delta ~ 148 => vs ~ 148
        assert!(vs_unclipped[0] > 100.0);
    }

    #[test]
    fn vtrace_output_lengths_match_input() {
        let n = 10;
        let log_rhos = vec![0.0; n];
        let rewards = vec![1.0; n];
        let values = vec![0.5; n];
        let dones = vec![0.0; n];
        let (vs, adv) =
            compute_vtrace(&log_rhos, &rewards, &values, &dones, 0.0, 0.99, 1.0, 1.0).unwrap();
        assert_eq!(vs.len(), n);
        assert_eq!(adv.len(), n);
    }

    #[test]
    fn vtrace_reference_implementation() {
        let (vs, adv) = run_reference(&[0.0; 3]);
        let (want_vs, want_adv) = reference_expected();
        assert_all_close("vs", &vs, &want_vs, 1e-4);
        assert_all_close("adv", &adv, &want_adv, 1e-4);
    }

    #[test]
    fn vtrace_with_dones_resets_at_boundary() {
        // 4-step trajectory whose first episode ends at t=1. The boundary must
        // keep rewards from t=2.. out of the targets for t=0..=1.
        let gamma = 0.99_f32;
        let log_rhos = vec![0.0; 4]; // on-policy
        let rewards = vec![1.0; 4];
        let values = vec![0.0; 4];
        let dones = vec![0.0, 1.0, 0.0, 0.0]; // done at t=1
        let bootstrap = 0.0;

        let (vs_with_dones, _) = compute_vtrace(
            &log_rhos, &rewards, &values, &dones, bootstrap, gamma, 1.0, 1.0,
        )
        .unwrap();

        // Same trajectory without terminals: rewards flow across t=1.
        let no_dones = vec![0.0; 4];
        let (vs_no_dones, _) = compute_vtrace(
            &log_rhos, &rewards, &values, &no_dones, bootstrap, gamma, 1.0, 1.0,
        )
        .unwrap();

        // Each episode accumulates only its own rewards:
        //   second episode: t=3 -> 1, t=2 -> 1 + 0.99 * 1 = 1.99
        //   first episode:  t=1 -> 1 (terminal), t=0 -> 1 + 0.99 * 1 = 1.99
        assert_all_close("vs_with_dones", &vs_with_dones, &[1.99, 1.0, 1.99, 1.0], 1e-5);

        // Before the boundary (t=0) the done-aware target is lower because
        // rewards from the next episode are excluded.
        assert!(
            vs_with_dones[0] < vs_no_dones[0],
            "vs_with_dones[0]={} should be < vs_no_dones[0]={}",
            vs_with_dones[0],
            vs_no_dones[0]
        );

        // Steps after the boundary (t=2, t=3) never see the done flag.
        assert!(
            (vs_with_dones[2] - vs_no_dones[2]).abs() < 1e-5,
            "t=2 should be identical"
        );
        assert!(
            (vs_with_dones[3] - vs_no_dones[3]).abs() < 1e-5,
            "t=3 should be identical"
        );
    }

    #[test]
    fn vtrace_without_dones_matches_old_behavior() {
        // All-zero dones must reproduce the pre-dones (never-terminal) results.
        let (vs, adv) = run_reference(&[0.0; 3]);
        let (want_vs, want_adv) = reference_expected();
        assert_all_close("vs", &vs, &want_vs, 1e-4);
        assert_all_close("adv", &adv, &want_adv, 1e-4);
    }

    #[test]
    fn vtrace_dones_at_last_step_zeros_bootstrap() {
        // When the last step is terminal, the bootstrap must be ignored.
        let gamma = 0.99_f32;
        let log_rhos = vec![0.0]; // on-policy, single step
        let rewards = vec![1.0];
        let values = vec![0.5];
        let bootstrap = 10.0; // large bootstrap to make the effect visible

        let dones_terminal = vec![1.0];
        let (vs_term, adv_term) = compute_vtrace(
            &log_rhos,
            &rewards,
            &values,
            &dones_terminal,
            bootstrap,
            gamma,
            1.0,
            1.0,
        )
        .unwrap();

        let dones_none = vec![0.0];
        let (vs_cont, adv_cont) = compute_vtrace(
            &log_rhos,
            &rewards,
            &values,
            &dones_none,
            bootstrap,
            gamma,
            1.0,
            1.0,
        )
        .unwrap();

        // Terminal:     delta = 1 * (1.0 + 0 - 0.5)         = 0.5,  vs = 0.5 + 0.5  = 1.0
        // Non-terminal: delta = 1 * (1.0 + 0.99 * 10 - 0.5) = 10.4, vs = 0.5 + 10.4 = 10.9
        assert_all_close("terminal vs", &vs_term, &[1.0], 1e-5);
        assert_all_close("terminal adv", &adv_term, &[0.5], 1e-5);
        assert_all_close("continuing vs", &vs_cont, &[10.9], 1e-4);
        assert_all_close("continuing adv", &adv_cont, &[10.4], 1e-4);
    }

    mod proptests {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            #[test]
            fn vtrace_output_length_matches_input(n in 0..200usize) {
                let log_rhos = vec![0.0; n];
                let rewards = vec![1.0; n];
                let values = vec![0.5; n];
                let dones = vec![0.0; n];
                let (vs, adv) = compute_vtrace(&log_rhos, &rewards, &values, &dones, 0.0, 0.99, 1.0, 1.0).unwrap();
                prop_assert_eq!(vs.len(), n);
                prop_assert_eq!(adv.len(), n);
            }

            #[test]
            fn vtrace_on_policy_vs_are_finite(n in 1..100usize) {
                let log_rhos = vec![0.0; n];
                let rewards: Vec<f32> = (0..n).map(|i| (i as f32) * 0.1).collect();
                let values: Vec<f32> = (0..n).map(|i| (i as f32) * 0.05).collect();
                let dones = vec![0.0; n];
                let (vs, adv) = compute_vtrace(&log_rhos, &rewards, &values, &dones, 0.0, 0.99, 1.0, 1.0).unwrap();
                for i in 0..n {
                    prop_assert!(vs[i].is_finite(), "vs[{}] is not finite: {}", i, vs[i]);
                    prop_assert!(adv[i].is_finite(), "adv[{}] is not finite: {}", i, adv[i]);
                }
            }
        }
    }
}