/// Pair of per-step advantage and return vectors.
///
/// Generic over the element type `T`, so the same container can hold `f32`
/// or `f64` data. The type itself does not enforce that the two vectors have
/// the same length.
#[derive(Clone, Debug)]
pub struct GaeResult<T> {
    /// Advantage estimates.
    pub advantages: Vec<T>,
    /// Return targets (presumably index-aligned with `advantages`; verify
    /// against the code that constructs this struct).
    pub returns: Vec<T>,
}
9
/// Computes Generalized Advantage Estimation (GAE) advantages and returns for
/// one trajectory.
///
/// The trajectory is swept from the last step to the first, applying
///
/// ```text
/// delta[t] = rewards[t] + gamma * V_next * (1 - dones[t]) - values[t]
/// A[t]     = delta[t] + gamma * gae_lambda * (1 - dones[t]) * A[t + 1]
/// ```
///
/// where `V_next` is `values[t + 1]`, or `last_value` for the final step, and
/// the final step has no `A[t + 1]` term.
///
/// # Arguments
///
/// * `rewards` - reward at each step; its length defines the trajectory length.
/// * `values` - value estimate at each step.
/// * `dones` - termination mask per step, used as the multiplier
///   `1.0 - dones[t]` (so `1.0` cuts both the bootstrap and the carried
///   advantage at step `t`, `0.0` leaves them intact).
/// * `last_value` - bootstrap value for the state after the final step.
/// * `gamma` - discount factor.
/// * `gae_lambda` - GAE smoothing parameter.
///
/// # Returns
///
/// `(advantages, returns)` with `returns[t] = advantages[t] + values[t]`.
/// Both vectors are empty when `rewards` is empty.
///
/// # Panics
///
/// In debug builds, panics if `values` or `dones` differ in length from
/// `rewards`. In any build, indexing panics if either is shorter than
/// `rewards`.
pub fn compute_gae(
    rewards: &[f64],
    values: &[f64],
    dones: &[f64],
    last_value: f64,
    gamma: f64,
    gae_lambda: f64,
) -> (Vec<f64>, Vec<f64>) {
    let len = rewards.len();
    debug_assert_eq!(values.len(), len, "values.len() must equal rewards.len()");
    debug_assert_eq!(dones.len(), len, "dones.len() must equal rewards.len()");
    if rewards.is_empty() {
        return (Vec::new(), Vec::new());
    }

    let tail = len - 1;
    // Loop-invariant factor of the carried advantage term.
    let decay = gamma * gae_lambda;

    let mut advantages = vec![0.0; len];
    let mut returns = vec![0.0; len];

    // The final step bootstraps from `last_value` and has nothing to carry.
    let tail_mask = 1.0 - dones[tail];
    let mut running = rewards[tail] + gamma * last_value * tail_mask - values[tail];
    advantages[tail] = running;
    returns[tail] = running + values[tail];

    // Remaining steps, newest to oldest; `next` is the index of the successor.
    for next in (1..=tail).rev() {
        let t = next - 1;
        let mask = 1.0 - dones[t];
        let td_error = rewards[t] + gamma * values[next] * mask - values[t];
        running = td_error + decay * mask * running;
        advantages[t] = running;
        returns[t] = running + values[t];
    }

    (advantages, returns)
}
62
63pub fn compute_gae_batched(
76 rewards: &[f64],
77 values: &[f64],
78 dones: &[f64],
79 last_values: &[f64],
80 n_steps: usize,
81 gamma: f64,
82 gae_lambda: f64,
83) -> (Vec<f64>, Vec<f64>) {
84 let n_envs = last_values.len();
85 if n_envs == 0 || n_steps == 0 {
86 return (Vec::new(), Vec::new());
87 }
88 let expected_len = n_envs * n_steps;
89 debug_assert_eq!(
90 rewards.len(),
91 expected_len,
92 "rewards.len() must equal n_envs * n_steps"
93 );
94 debug_assert_eq!(
95 values.len(),
96 expected_len,
97 "values.len() must equal n_envs * n_steps"
98 );
99 debug_assert_eq!(
100 dones.len(),
101 expected_len,
102 "dones.len() must equal n_envs * n_steps"
103 );
104
105 use rayon::prelude::*;
106
107 let mut all_advantages = vec![0.0; n_envs * n_steps];
108 let mut all_returns = vec![0.0; n_envs * n_steps];
109
110 all_advantages
111 .par_chunks_mut(n_steps)
112 .zip(all_returns.par_chunks_mut(n_steps))
113 .enumerate()
114 .for_each(|(env_idx, (adv_chunk, ret_chunk))| {
115 let offset = env_idx * n_steps;
116 let r = &rewards[offset..offset + n_steps];
117 let v = &values[offset..offset + n_steps];
118 let d = &dones[offset..offset + n_steps];
119 let lv = last_values[env_idx];
120
121 let mut last_gae = 0.0;
122 for t in (0..n_steps).rev() {
123 let next_non_terminal = 1.0 - d[t];
124 let next_value = if t == n_steps - 1 { lv } else { v[t + 1] };
125 let delta = r[t] + gamma * next_value * next_non_terminal - v[t];
126 last_gae = delta + gamma * gae_lambda * next_non_terminal * last_gae;
127 adv_chunk[t] = last_gae;
128 ret_chunk[t] = last_gae + v[t];
129 }
130 });
131
132 (all_advantages, all_returns)
133}
134
135pub fn compute_gae_batched_f32(
144 rewards: &[f32],
145 values: &[f32],
146 dones: &[f32],
147 last_values: &[f32],
148 n_steps: usize,
149 gamma: f32,
150 gae_lambda: f32,
151) -> (Vec<f32>, Vec<f32>) {
152 let n_envs = last_values.len();
153 if n_envs == 0 || n_steps == 0 {
154 return (Vec::new(), Vec::new());
155 }
156 let expected_len = n_envs * n_steps;
157 debug_assert_eq!(
158 rewards.len(),
159 expected_len,
160 "rewards.len() must equal n_envs * n_steps"
161 );
162 debug_assert_eq!(
163 values.len(),
164 expected_len,
165 "values.len() must equal n_envs * n_steps"
166 );
167 debug_assert_eq!(
168 dones.len(),
169 expected_len,
170 "dones.len() must equal n_envs * n_steps"
171 );
172
173 use rayon::prelude::*;
174
175 let mut all_advantages = vec![0.0f32; n_envs * n_steps];
176 let mut all_returns = vec![0.0f32; n_envs * n_steps];
177
178 all_advantages
179 .par_chunks_mut(n_steps)
180 .zip(all_returns.par_chunks_mut(n_steps))
181 .enumerate()
182 .for_each(|(env_idx, (adv_chunk, ret_chunk))| {
183 let offset = env_idx * n_steps;
184 let r = &rewards[offset..offset + n_steps];
185 let v = &values[offset..offset + n_steps];
186 let d = &dones[offset..offset + n_steps];
187 let lv = last_values[env_idx];
188
189 let last_nt = 1.0 - d[n_steps - 1];
191 let last_delta = r[n_steps - 1] + gamma * lv * last_nt - v[n_steps - 1];
192 let mut last_gae = last_delta;
193 adv_chunk[n_steps - 1] = last_gae;
194 ret_chunk[n_steps - 1] = last_gae + v[n_steps - 1];
195
196 for t in (0..n_steps - 1).rev() {
197 let next_non_terminal = 1.0 - d[t];
198 let delta = r[t] + gamma * v[t + 1] * next_non_terminal - v[t];
199 last_gae = delta + gamma * gae_lambda * next_non_terminal * last_gae;
200 adv_chunk[t] = last_gae;
201 ret_chunk[t] = last_gae + v[t];
202 }
203 });
204
205 (all_advantages, all_returns)
206}
207
#[cfg(test)]
mod tests {
    use super::*;

    /// Converts boolean termination flags into the `0.0` / `1.0` mask values
    /// that the GAE functions take as `dones`.
    fn bools_to_f64(bools: &[bool]) -> Vec<f64> {
        bools.iter().map(|&b| if b { 1.0 } else { 0.0 }).collect()
    }

    #[test]
    fn gae_single_step_episode() {
        let rewards = &[1.0];
        let values = &[0.5];
        let dones = bools_to_f64(&[true]);
        let last_value = 0.0;
        let gamma = 0.99;
        let gae_lambda = 0.95;
        let (advantages, returns) =
            compute_gae(rewards, values, &dones, last_value, gamma, gae_lambda);
        assert_eq!(advantages.len(), 1);
        // delta = 1.0 + 0.99 * 0.0 * (1 - 1) - 0.5 = 0.5
        assert!((advantages[0] - 0.5).abs() < 1e-6);
        // return = advantage + value = 0.5 + 0.5
        assert_eq!(returns.len(), 1);
        assert!((returns[0] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn gae_multi_step_no_termination() {
        let rewards = &[1.0, 1.0, 1.0];
        let values = &[0.0, 0.0, 0.0];
        let dones = bools_to_f64(&[false, false, false]);
        let last_value = 0.0;
        let gamma = 0.99;
        let gae_lambda = 0.95;
        let (advantages, _returns) =
            compute_gae(rewards, values, &dones, last_value, gamma, gae_lambda);
        assert_eq!(advantages.len(), 3);
        // With zero values every delta is 1.0; the carry factor is
        // gamma * lambda = 0.9405.
        assert!((advantages[2] - 1.0).abs() < 1e-6);
        // 1.0 + 0.9405 * 1.0
        assert!((advantages[1] - 1.9405).abs() < 1e-4);
        assert!(advantages[0] > advantages[1]);
        // 1.0 + 0.9405 * 1.9405
        assert!((advantages[0] - 2.82504025).abs() < 1e-6);
    }

    #[test]
    fn gae_resets_at_episode_boundary() {
        let rewards = &[1.0, 1.0, 1.0];
        let values = &[0.0, 0.0, 0.0];
        let dones = bools_to_f64(&[false, true, false]);
        let last_value = 0.0;
        let gamma = 0.99;
        let gae_lambda = 0.95;
        let (advantages, _) = compute_gae(rewards, values, &dones, last_value, gamma, gae_lambda);
        // The done flag at step 1 blocks the advantage carried from step 2.
        assert!((advantages[1] - 1.0).abs() < 1e-6);
        // Step 0 is not terminal, so it still accumulates step 1's advantage:
        // 1.0 + 0.9405 * 1.0
        assert!((advantages[0] - 1.9405).abs() < 1e-6);

        // Non-zero values additionally check that the bootstrap term
        // gamma * values[2] is masked at the boundary:
        // delta[1] = 1.0 + 0.99 * 0.5 * 0 - 0.5 = 0.5
        let values = &[0.5, 0.5, 0.5];
        let (advantages, _) = compute_gae(rewards, values, &dones, 0.5, gamma, gae_lambda);
        assert!((advantages[1] - 0.5).abs() < 1e-6);
    }

    #[test]
    fn gae_returns_are_advantages_plus_values() {
        let rewards = &[1.0, 2.0, 3.0];
        let values = &[0.5, 1.0, 1.5];
        let dones = bools_to_f64(&[false, false, true]);
        let last_value = 0.0;
        let (advantages, returns) = compute_gae(rewards, values, &dones, last_value, 0.99, 0.95);
        for i in 0..3 {
            assert!((returns[i] - (advantages[i] + values[i])).abs() < 1e-6);
        }
    }

    #[test]
    fn gae_empty_input() {
        let (advantages, returns) = compute_gae(&[], &[], &[], 0.0, 0.99, 0.95);
        assert!(advantages.is_empty());
        assert!(returns.is_empty());
    }

    #[test]
    fn gae_lambda_zero_is_one_step_td() {
        let rewards = &[1.0, 1.0];
        let values = &[0.5, 0.5];
        let dones = bools_to_f64(&[false, false]);
        let last_value = 0.5;
        let (advantages, _) = compute_gae(rewards, values, &dones, last_value, 0.99, 0.0);
        // With lambda = 0 nothing is carried, so every advantage is its own
        // TD error: 1.0 + 0.99 * 0.5 - 0.5 = 0.995.
        assert!((advantages[1] - 0.995).abs() < 1e-6);
        assert!((advantages[0] - 0.995).abs() < 1e-6);
    }

    #[test]
    fn gae_lambda_one_is_monte_carlo() {
        let rewards = &[1.0, 1.0, 1.0];
        let values = &[0.0, 0.0, 0.0];
        let dones = bools_to_f64(&[false, false, true]);
        let (advantages, _) = compute_gae(rewards, values, &dones, 0.0, 0.99, 1.0);
        // Discounted reward sum: 1 + 0.99 + 0.99^2 = 2.9701
        assert!((advantages[0] - 2.9701).abs() < 1e-3);
    }

    #[test]
    fn gae_batched_matches_unbatched() {
        let gamma = 0.99;
        let lam = 0.95;
        // Two environments, three steps each, stored environment-major.
        let rewards = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let values = vec![0.5, 1.0, 1.5, 2.0, 2.5, 3.0];
        let dones = vec![0.0, 0.0, 1.0, 0.0, 1.0, 0.0];
        let last_values = vec![0.0, 0.5];

        let (adv_b, ret_b) =
            compute_gae_batched(&rewards, &values, &dones, &last_values, 3, gamma, lam);
        assert_eq!(adv_b.len(), 6);
        assert_eq!(ret_b.len(), 6);

        let (adv0, ret0) = compute_gae(
            &rewards[..3],
            &values[..3],
            &dones[..3],
            last_values[0],
            gamma,
            lam,
        );
        let (adv1, ret1) = compute_gae(
            &rewards[3..],
            &values[3..],
            &dones[3..],
            last_values[1],
            gamma,
            lam,
        );

        for i in 0..3 {
            assert!(
                (adv_b[i] - adv0[i]).abs() < 1e-12,
                "env0 adv mismatch at {i}"
            );
            assert!(
                (ret_b[i] - ret0[i]).abs() < 1e-12,
                "env0 ret mismatch at {i}"
            );
            assert!(
                (adv_b[3 + i] - adv1[i]).abs() < 1e-12,
                "env1 adv mismatch at {i}"
            );
            assert!(
                (ret_b[3 + i] - ret1[i]).abs() < 1e-12,
                "env1 ret mismatch at {i}"
            );
        }
    }

    #[test]
    fn gae_batched_empty() {
        // No environments and no steps.
        let (adv, ret) = compute_gae_batched(&[], &[], &[], &[], 0, 0.99, 0.95);
        assert!(adv.is_empty());
        assert!(ret.is_empty());

        // No environments, non-zero step count.
        let (adv, ret) = compute_gae_batched(&[], &[], &[], &[], 3, 0.99, 0.95);
        assert!(adv.is_empty());
        assert!(ret.is_empty());

        // One environment, zero steps.
        let (adv, ret) = compute_gae_batched(&[], &[], &[], &[0.0], 0, 0.99, 0.95);
        assert!(adv.is_empty());
        assert!(ret.is_empty());
    }

    #[test]
    fn gae_batched_f32_empty() {
        let (adv, ret) = compute_gae_batched_f32(&[], &[], &[], &[], 0, 0.99, 0.95);
        assert!(adv.is_empty());
        assert!(ret.is_empty());
    }

    #[test]
    fn gae_batched_f32_matches_f64() {
        let gamma = 0.99f32;
        let lam = 0.95f32;
        let rewards_f32: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let values_f32: Vec<f32> = vec![0.5, 1.0, 1.5, 2.0, 2.5, 3.0];
        let dones_f32: Vec<f32> = vec![0.0, 0.0, 1.0, 0.0, 1.0, 0.0];
        let last_values_f32: Vec<f32> = vec![0.0, 0.5];

        let (adv_f32, ret_f32) = compute_gae_batched_f32(
            &rewards_f32,
            &values_f32,
            &dones_f32,
            &last_values_f32,
            3,
            gamma,
            lam,
        );
        assert_eq!(adv_f32.len(), 6);
        assert_eq!(ret_f32.len(), 6);

        // Widen the same inputs to f64 for the reference computation.
        let rewards_f64: Vec<f64> = rewards_f32.iter().map(|&x| x as f64).collect();
        let values_f64: Vec<f64> = values_f32.iter().map(|&x| x as f64).collect();
        let dones_f64: Vec<f64> = dones_f32.iter().map(|&x| x as f64).collect();
        let last_values_f64: Vec<f64> = last_values_f32.iter().map(|&x| x as f64).collect();

        let (adv_f64, ret_f64) = compute_gae_batched(
            &rewards_f64,
            &values_f64,
            &dones_f64,
            &last_values_f64,
            3,
            0.99,
            0.95,
        );

        // Tolerance is loose enough to absorb single-precision rounding.
        for i in 0..6 {
            assert!(
                (adv_f32[i] as f64 - adv_f64[i]).abs() < 1e-5,
                "adv mismatch at {i}"
            );
            assert!(
                (ret_f32[i] as f64 - ret_f64[i]).abs() < 1e-5,
                "ret mismatch at {i}"
            );
        }
    }

    mod proptests {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            /// For any trajectory length, `returns` must equal
            /// `advantages + values` element-wise.
            #[test]
            fn gae_returns_equal_advantages_plus_values(n in 1..500usize) {
                let rewards: Vec<f64> = (0..n).map(|i| (i as f64) * 0.1).collect();
                let values: Vec<f64> = (0..n).map(|i| (i as f64) * 0.05).collect();
                // Every tenth step ends an episode.
                let dones: Vec<f64> = (0..n).map(|i| if i % 10 == 9 { 1.0 } else { 0.0 }).collect();
                let (advantages, returns) = compute_gae(&rewards, &values, &dones, 0.0, 0.99, 0.95);
                for i in 0..n {
                    let diff = (returns[i] - (advantages[i] + values[i])).abs();
                    prop_assert!(diff < 1e-10, "mismatch at index {}: returns={}, adv+val={}", i, returns[i], advantages[i] + values[i]);
                }
            }

            /// Output lengths always match the input length, including zero.
            #[test]
            fn gae_length_matches_input(n in 0..500usize) {
                let rewards = vec![1.0; n];
                let values = vec![0.5; n];
                let dones = vec![0.0; n];
                let (advantages, returns) = compute_gae(&rewards, &values, &dones, 0.0, 0.99, 0.95);
                prop_assert_eq!(advantages.len(), n);
                prop_assert_eq!(returns.len(), n);
            }
        }
    }
}