1use crate::error::RloxError;
2
3pub fn compute_vtrace(
18 log_rhos: &[f32],
19 rewards: &[f32],
20 values: &[f32],
21 dones: &[f32],
22 bootstrap_value: f32,
23 gamma: f32,
24 rho_bar: f32,
25 c_bar: f32,
26) -> Result<(Vec<f32>, Vec<f32>), RloxError> {
27 let n = log_rhos.len();
28
29 if rewards.len() != n || values.len() != n || dones.len() != n {
30 return Err(RloxError::ShapeMismatch {
31 expected: format!("all slices length {n}"),
32 got: format!(
33 "log_rhos={}, rewards={}, values={}, dones={}",
34 n,
35 rewards.len(),
36 values.len(),
37 dones.len()
38 ),
39 });
40 }
41
42 if n == 0 {
43 return Ok((Vec::new(), Vec::new()));
44 }
45
46 let mut vs = vec![0.0f32; n];
47 let mut pg_advantages = vec![0.0f32; n];
48
49 let last = n - 1;
51 {
52 let ratio = log_rhos[last].exp();
53 let rho_t = rho_bar.min(ratio);
54 let non_terminal = 1.0 - dones[last];
55 let next_value = bootstrap_value * non_terminal;
56
57 let delta_t = rho_t * (rewards[last] + gamma * next_value - values[last]);
58 let vs_next_val = bootstrap_value * non_terminal;
60 vs[last] = values[last]
61 + delta_t
62 + gamma * non_terminal * rho_bar.min(ratio).min(c_bar) * (vs_next_val - next_value);
63 pg_advantages[last] = rho_t * (rewards[last] + gamma * vs_next_val - values[last]);
64 }
65
66 let mut vs_next = vs[last];
68
69 for t in (0..last).rev() {
70 let ratio = log_rhos[t].exp();
71 let rho_t = rho_bar.min(ratio);
72 let c_t = c_bar.min(ratio);
73 let non_terminal = 1.0 - dones[t];
74
75 let next_value = values[t + 1];
76
77 let delta_t = rho_t * (rewards[t] + gamma * non_terminal * next_value - values[t]);
78 vs[t] = values[t] + delta_t + gamma * non_terminal * c_t * (vs_next - next_value);
79 pg_advantages[t] = rho_t * (rewards[t] + gamma * non_terminal * vs_next - values[t]);
80
81 vs_next = vs[t];
82 }
83
84 Ok((vs, pg_advantages))
85}
86
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts `|got[i] - expected[i]| < tol` for every index, naming the
    /// offending element (prefixed by `what`) on failure.
    fn assert_all_close(got: &[f32], expected: &[f32], tol: f32, what: &str) {
        assert_eq!(got.len(), expected.len(), "{what}: length mismatch");
        for (i, (g, e)) in got.iter().zip(expected).enumerate() {
            assert!((g - e).abs() < tol, "{what}[{i}]: got {g}, expected {e}");
        }
    }

    /// Independent O(n^2) oracle that evaluates V-trace from its closed-form sum
    ///
    /// `vs[s] = V[s] + sum_{t=s}^{n-1} (prod_{i=s}^{t-1} gamma_i * c_i) * delta_t`
    ///
    /// with `gamma_i = gamma * (1 - dones[i])`, `delta_t` as in `compute_vtrace`,
    /// and `V[n] = vs[n] = bootstrap`, rather than through the backward
    /// recursion. Only meant for short test rollouts.
    #[allow(clippy::too_many_arguments)]
    fn closed_form_vtrace(
        log_rhos: &[f32],
        rewards: &[f32],
        values: &[f32],
        dones: &[f32],
        bootstrap: f32,
        gamma: f32,
        rho_bar: f32,
        c_bar: f32,
    ) -> (Vec<f32>, Vec<f32>) {
        let n = log_rhos.len();
        let value_at = |t: usize| if t < n { values[t] } else { bootstrap };
        let rho = |t: usize| rho_bar.min(log_rhos[t].exp());
        let c = |t: usize| c_bar.min(log_rhos[t].exp());
        let discount = |t: usize| gamma * (1.0 - dones[t]);
        let delta = |t: usize| rho(t) * (rewards[t] + discount(t) * value_at(t + 1) - values[t]);

        let vs: Vec<f32> = (0..n)
            .map(|s| {
                let mut acc = 0.0_f32;
                let mut trace = 1.0_f32;
                for t in s..n {
                    acc += trace * delta(t);
                    trace *= discount(t) * c(t);
                }
                values[s] + acc
            })
            .collect();

        let vs_at = |t: usize| if t < n { vs[t] } else { bootstrap };
        let pg: Vec<f32> = (0..n)
            .map(|s| rho(s) * (rewards[s] + discount(s) * vs_at(s + 1) - values[s]))
            .collect();
        (vs, pg)
    }

    #[test]
    fn vtrace_empty_input() {
        let (vs, adv) = compute_vtrace(&[], &[], &[], &[], 0.0, 0.99, 1.0, 1.0).unwrap();
        assert!(vs.is_empty());
        assert!(adv.is_empty());
    }

    #[test]
    fn vtrace_mismatched_lengths() {
        let result = compute_vtrace(&[0.0], &[1.0, 2.0], &[0.5], &[0.0], 0.0, 0.99, 1.0, 1.0);
        assert!(result.is_err());
    }

    #[test]
    fn vtrace_on_policy_matches_gae_like() {
        // On-policy (rho = c = 1) with zero values: vs is the discounted return.
        let log_rhos = vec![0.0; 3];
        let rewards = vec![1.0, 1.0, 1.0];
        let values = vec![0.0, 0.0, 0.0];
        let bootstrap = 0.0;
        let gamma = 0.99;

        let dones = vec![0.0; 3];
        let (vs, _adv) = compute_vtrace(
            &log_rhos, &rewards, &values, &dones, bootstrap, gamma, 1.0, 1.0,
        )
        .unwrap();

        assert!((vs[2] - 1.0).abs() < 1e-5);
        assert!((vs[1] - 1.99).abs() < 1e-5);
        assert!((vs[0] - 2.9701).abs() < 1e-4);
    }

    #[test]
    fn vtrace_single_step() {
        let log_rho = 0.5_f32;
        let log_rhos = vec![log_rho];
        let rewards = vec![1.0];
        let values = vec![0.5];
        let bootstrap = 0.0;
        let gamma = 0.99;
        // Large enough that exp(0.5) ~= 1.65 is not truncated.
        let rho_bar = 10.0;
        let c_bar = 10.0;

        let dones = vec![0.0];
        let (vs, adv) = compute_vtrace(
            &log_rhos, &rewards, &values, &dones, bootstrap, gamma, rho_bar, c_bar,
        )
        .unwrap();

        // delta = rho * (r + gamma * 0 - V) = rho * 0.5 and vs = V + delta.
        let rho = log_rho.exp();
        let expected_vs = 0.5 + rho * 0.5;
        let expected_adv = rho * 0.5;

        assert!(
            (vs[0] - expected_vs).abs() < 1e-5,
            "vs[0]={}, expected={}",
            vs[0],
            expected_vs
        );
        assert!(
            (adv[0] - expected_adv).abs() < 1e-5,
            "adv[0]={}, expected={}",
            adv[0],
            expected_adv
        );
    }

    #[test]
    fn vtrace_clipping_reduces_correction() {
        let log_rhos = vec![5.0];
        let rewards = vec![1.0];
        let values = vec![0.0];
        let bootstrap = 0.0;
        let gamma = 0.99;

        let dones = vec![0.0];
        let (vs_clipped, _) = compute_vtrace(
            &log_rhos, &rewards, &values, &dones, bootstrap, gamma, 1.0, 1.0,
        )
        .unwrap();
        let (vs_unclipped, _) = compute_vtrace(
            &log_rhos, &rewards, &values, &dones, bootstrap, gamma, 200.0, 200.0,
        )
        .unwrap();

        // exp(5) ~= 148.4: truncated to 1.0 in the first call, kept in the second.
        assert!((vs_clipped[0] - 1.0).abs() < 1e-5);
        assert!(vs_unclipped[0] > 100.0);
    }

    #[test]
    fn vtrace_output_lengths_match_input() {
        let n = 10;
        let log_rhos = vec![0.0; n];
        let rewards = vec![1.0; n];
        let values = vec![0.5; n];
        let dones = vec![0.0; n];
        let (vs, adv) =
            compute_vtrace(&log_rhos, &rewards, &values, &dones, 0.0, 0.99, 1.0, 1.0).unwrap();
        assert_eq!(vs.len(), n);
        assert_eq!(adv.len(), n);
    }

    #[test]
    fn vtrace_reference_implementation() {
        let gamma = 0.9_f32;
        let rho_bar = 1.5_f32;
        let c_bar = 1.2_f32;

        let log_rhos = vec![0.2, -0.3, 0.8];
        let rewards = vec![1.0, 2.0, 3.0];
        let values = vec![0.5, 1.0, 1.5];
        let bootstrap = 2.0;

        // Recursion unrolled by hand, last step first. At t = 2 the trace term
        // vanishes because vs[3] = V[3] = bootstrap.
        let ratio_2 = (0.8_f32).exp();
        let rho_2 = rho_bar.min(ratio_2); // truncated to 1.5
        let c_2 = c_bar.min(ratio_2); // truncated to 1.2
        let delta_2 = rho_2 * (3.0 + gamma * 2.0 - 1.5);
        let vs_2 = 1.5 + delta_2 + gamma * c_2 * (2.0 - 2.0);
        let pg_2 = rho_2 * (3.0 + gamma * 2.0 - 1.5);

        let ratio_1 = (-0.3_f32).exp();
        let rho_1 = rho_bar.min(ratio_1);
        let c_1 = c_bar.min(ratio_1);
        let delta_1 = rho_1 * (2.0 + gamma * 1.5 - 1.0);
        let vs_1 = 1.0 + delta_1 + gamma * c_1 * (vs_2 - 1.5);
        let pg_1 = rho_1 * (2.0 + gamma * vs_2 - 1.0);

        let ratio_0 = (0.2_f32).exp();
        let rho_0 = rho_bar.min(ratio_0);
        let c_0 = c_bar.min(ratio_0);
        let delta_0 = rho_0 * (1.0 + gamma * 1.0 - 0.5);
        let vs_0 = 0.5 + delta_0 + gamma * c_0 * (vs_1 - 1.0);
        let pg_0 = rho_0 * (1.0 + gamma * vs_1 - 0.5);

        let dones = vec![0.0; 3];
        let (vs, adv) = compute_vtrace(
            &log_rhos, &rewards, &values, &dones, bootstrap, gamma, rho_bar, c_bar,
        )
        .unwrap();

        assert_all_close(&vs, &[vs_0, vs_1, vs_2], 1e-4, "vs");
        assert_all_close(&adv, &[pg_0, pg_1, pg_2], 1e-4, "adv");
    }

    #[test]
    fn vtrace_with_dones_resets_at_boundary() {
        let gamma = 0.99_f32;
        let log_rhos = vec![0.0; 4];
        let rewards = vec![1.0, 1.0, 1.0, 1.0];
        let values = vec![0.0; 4];
        let dones = vec![0.0, 1.0, 0.0, 0.0];
        let bootstrap = 0.0;

        let (vs_with_dones, _) = compute_vtrace(
            &log_rhos, &rewards, &values, &dones, bootstrap, gamma, 1.0, 1.0,
        )
        .unwrap();

        let no_dones = vec![0.0; 4];
        let (vs_no_dones, _) = compute_vtrace(
            &log_rhos, &rewards, &values, &no_dones, bootstrap, gamma, 1.0, 1.0,
        )
        .unwrap();

        // The episode end at t = 1 stops the return from flowing back into t <= 1.
        assert_all_close(&vs_with_dones, &[1.99, 1.0, 1.99, 1.0], 1e-5, "vs_with_dones");
        assert!(
            vs_with_dones[0] < vs_no_dones[0],
            "vs_with_dones[0]={} should be < vs_no_dones[0]={}",
            vs_with_dones[0],
            vs_no_dones[0]
        );

        assert!(
            (vs_with_dones[3] - vs_no_dones[3]).abs() < 1e-5,
            "t=3 should be identical"
        );
    }

    #[test]
    fn vtrace_without_dones_matches_old_behavior() {
        // With every `done` at 0.0 the masked recursion must reduce to plain
        // V-trace, evaluated here through the independent closed-form sum.
        let gamma = 0.9_f32;
        let rho_bar = 1.5_f32;
        let c_bar = 1.2_f32;
        let log_rhos = [0.2_f32, -0.3, 0.8];
        let rewards = [1.0_f32, 2.0, 3.0];
        let values = [0.5_f32, 1.0, 1.5];
        let bootstrap = 2.0_f32;
        let dones = [0.0_f32; 3];

        let (vs, adv) = compute_vtrace(
            &log_rhos, &rewards, &values, &dones, bootstrap, gamma, rho_bar, c_bar,
        )
        .unwrap();
        let (ref_vs, ref_adv) = closed_form_vtrace(
            &log_rhos, &rewards, &values, &dones, bootstrap, gamma, rho_bar, c_bar,
        );

        assert_all_close(&vs, &ref_vs, 1e-4, "vs");
        assert_all_close(&adv, &ref_adv, 1e-4, "adv");
    }

    #[test]
    fn vtrace_matches_closed_form_with_dones_and_clipping() {
        // Importance weights on both sides of both thresholds, an episode end
        // mid-rollout, and a non-zero bootstrap.
        let log_rhos = [0.3_f32, -0.5, 1.2, 0.0, -1.0, 0.7];
        let rewards = [0.5_f32, -1.0, 2.0, 0.0, 1.5, -0.5];
        let values = [0.1_f32, 0.4, -0.2, 0.8, 0.0, 0.3];
        let dones = [0.0_f32, 0.0, 1.0, 0.0, 0.0, 0.0];
        let bootstrap = 1.7_f32;
        let gamma = 0.95_f32;
        let rho_bar = 1.0_f32;
        let c_bar = 0.9_f32;

        let (vs, adv) = compute_vtrace(
            &log_rhos, &rewards, &values, &dones, bootstrap, gamma, rho_bar, c_bar,
        )
        .unwrap();
        let (ref_vs, ref_adv) = closed_form_vtrace(
            &log_rhos, &rewards, &values, &dones, bootstrap, gamma, rho_bar, c_bar,
        );

        assert_all_close(&vs, &ref_vs, 1e-4, "vs");
        assert_all_close(&adv, &ref_adv, 1e-4, "adv");
    }

    #[test]
    fn vtrace_dones_at_last_step_zeros_bootstrap() {
        let gamma = 0.99_f32;
        let log_rhos = vec![0.0];
        let rewards = vec![1.0];
        let values = vec![0.5];
        let bootstrap = 10.0;

        let dones_terminal = vec![1.0];
        let (vs_term, adv_term) = compute_vtrace(
            &log_rhos,
            &rewards,
            &values,
            &dones_terminal,
            bootstrap,
            gamma,
            1.0,
            1.0,
        )
        .unwrap();

        let dones_none = vec![0.0];
        let (vs_cont, adv_cont) = compute_vtrace(
            &log_rhos,
            &rewards,
            &values,
            &dones_none,
            bootstrap,
            gamma,
            1.0,
            1.0,
        )
        .unwrap();

        // Terminal: vs = V + (r - V) = r, ignoring the bootstrap.
        assert!(
            (vs_term[0] - 1.0).abs() < 1e-5,
            "terminal vs[0]={}, expected 1.0",
            vs_term[0]
        );
        assert!(
            vs_cont[0] > vs_term[0],
            "non-terminal vs should be larger due to bootstrap"
        );

        // Terminal: adv = r - V = 0.5.
        assert!(
            (adv_term[0] - 0.5).abs() < 1e-5,
            "terminal adv[0]={}, expected 0.5",
            adv_term[0]
        );
        assert!(
            adv_cont[0] > adv_term[0],
            "non-terminal adv should be larger"
        );
    }

    mod proptests {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            #[test]
            fn vtrace_output_length_matches_input(n in 0..200usize) {
                let log_rhos = vec![0.0; n];
                let rewards = vec![1.0; n];
                let values = vec![0.5; n];
                let dones = vec![0.0; n];
                let (vs, adv) = compute_vtrace(&log_rhos, &rewards, &values, &dones, 0.0, 0.99, 1.0, 1.0).unwrap();
                prop_assert_eq!(vs.len(), n);
                prop_assert_eq!(adv.len(), n);
            }

            #[test]
            fn vtrace_on_policy_vs_are_finite(n in 1..100usize) {
                let log_rhos = vec![0.0; n];
                let rewards: Vec<f32> = (0..n).map(|i| (i as f32) * 0.1).collect();
                let values: Vec<f32> = (0..n).map(|i| (i as f32) * 0.05).collect();
                let dones = vec![0.0; n];
                let (vs, adv) = compute_vtrace(&log_rhos, &rewards, &values, &dones, 0.0, 0.99, 1.0, 1.0).unwrap();
                for i in 0..n {
                    prop_assert!(vs[i].is_finite(), "vs[{}] is not finite: {}", i, vs[i]);
                    prop_assert!(adv[i].is_finite(), "adv[{}] is not finite: {}", i, adv[i]);
                }
            }
        }
    }
}