1use rand::Rng;
13use rand::SeedableRng;
14use rand_chacha::ChaCha8Rng;
15
16use crate::error::RloxError;
17
/// Summary statistics of an offline dataset, as produced by
/// `OfflineDatasetBuffer::stats`.
#[derive(Debug, Clone, PartialEq)]
pub struct DatasetStats {
    /// Total number of stored transitions.
    pub n_transitions: usize,
    /// Number of episodes detected from the terminated/truncated flags.
    pub n_episodes: usize,
    /// Width of one observation vector.
    pub obs_dim: usize,
    /// Width of one action vector.
    pub act_dim: usize,
    /// Mean of the per-episode (undiscounted) reward sums; 0.0 with no episodes.
    pub mean_return: f32,
    /// Sample standard deviation (n - 1 denominator) of the episode returns;
    /// 0.0 when there are fewer than two episodes.
    pub std_return: f32,
    /// Smallest episode return; 0.0 with no episodes.
    pub min_return: f32,
    /// Largest episode return; 0.0 with no episodes.
    pub max_return: f32,
    /// Average number of transitions per episode; 0.0 with no episodes.
    pub mean_episode_length: f32,
}
31
/// A batch of independently sampled transitions, as produced by
/// `OfflineDatasetBuffer::sample`.
///
/// All per-transition arrays are flattened row-major: row `i` of `obs` is
/// `obs[i * obs_dim..(i + 1) * obs_dim]`, and likewise for `next_obs` and
/// `actions` (with `act_dim`).
#[derive(Debug, Clone, PartialEq)]
pub struct OfflineBatch {
    /// Observations, `batch_size * obs_dim` values.
    pub obs: Vec<f32>,
    /// Next observations, `batch_size * obs_dim` values.
    pub next_obs: Vec<f32>,
    /// Actions, `batch_size * act_dim` values.
    pub actions: Vec<f32>,
    /// One reward per sampled transition.
    pub rewards: Vec<f32>,
    /// One terminated flag per sampled transition (non-zero = terminated).
    pub terminated: Vec<u8>,
    /// Width of one observation row.
    pub obs_dim: usize,
    /// Width of one action row.
    pub act_dim: usize,
}
43
/// A batch of fixed-length trajectory windows, as produced by
/// `OfflineDatasetBuffer::sample_trajectories`.
///
/// Every per-step array is laid out as `[batch_size, seq_len, ...]` and
/// flattened row-major. Window `b` occupies steps `b * seq_len..(b + 1) * seq_len`.
/// Slots beyond the end of a short episode are zero padding with `mask == 0`.
#[derive(Debug, Clone, PartialEq)]
pub struct TrajectoryBatch {
    /// Observations, `batch_size * seq_len * obs_dim` values.
    pub obs: Vec<f32>,
    /// Actions, `batch_size * seq_len * act_dim` values.
    pub actions: Vec<f32>,
    /// Per-step rewards, `batch_size * seq_len` values.
    pub rewards: Vec<f32>,
    /// Per-step return-to-go as computed by `sample_trajectories`
    /// (0.0 in padded slots).
    pub returns_to_go: Vec<f32>,
    /// Index of each step within its source episode (0 in padded slots).
    pub timesteps: Vec<u32>,
    /// 1 for real steps, 0 for padding.
    pub mask: Vec<u8>,
    /// Window length used when sampling.
    pub seq_len: usize,
    /// Width of one observation vector.
    pub obs_dim: usize,
    /// Width of one action vector.
    pub act_dim: usize,
}
57
/// In-memory store of offline RL transitions with a precomputed episode index
/// and optional normalization statistics.
pub struct OfflineDatasetBuffer {
    // --- shape ---------------------------------------------------------
    /// Number of stored transitions (set to `rewards.len()` at construction).
    len: usize,
    /// Width of one observation row in `obs` / `next_obs`.
    obs_dim: usize,
    /// Width of one action row in `actions`.
    act_dim: usize,

    // --- flat, row-major transition arrays ------------------------------
    /// `len * obs_dim` observation values.
    obs: Vec<f32>,
    /// `len * obs_dim` next-observation values.
    next_obs: Vec<f32>,
    /// `len * act_dim` action values.
    actions: Vec<f32>,
    /// One reward per transition.
    rewards: Vec<f32>,
    /// One terminated flag per transition (non-zero = terminated).
    terminated: Vec<u8>,
    /// One truncated flag per transition. Only consulted while building the
    /// episode index; no method in this file reads the stored copy.
    #[allow(dead_code)]
    truncated: Vec<u8>,

    // --- episode index (parallel vectors, one entry per episode) --------
    /// Index of each episode's first transition.
    episode_starts: Vec<usize>,
    /// Number of transitions in each episode.
    episode_lengths: Vec<usize>,
    /// Undiscounted sum of rewards of each episode.
    episode_returns: Vec<f32>,

    // --- normalization statistics, filled by `compute_normalization` ----
    /// Per-dimension observation mean.
    obs_mean: Option<Vec<f32>>,
    /// Per-dimension observation standard deviation.
    obs_std: Option<Vec<f32>>,
    /// Scalar reward mean.
    reward_mean: Option<f32>,
    /// Scalar reward standard deviation.
    reward_std: Option<f32>,
}
83
84impl OfflineDatasetBuffer {
85 pub fn from_arrays(
89 obs: Vec<f32>,
90 next_obs: Vec<f32>,
91 actions: Vec<f32>,
92 rewards: Vec<f32>,
93 terminated: Vec<u8>,
94 truncated: Vec<u8>,
95 obs_dim: usize,
96 act_dim: usize,
97 ) -> Result<Self, RloxError> {
98 let n = rewards.len();
99
100 if obs.len() != n * obs_dim {
101 return Err(RloxError::ShapeMismatch {
102 expected: format!("obs length = {} * {} = {}", n, obs_dim, n * obs_dim),
103 got: format!("{}", obs.len()),
104 });
105 }
106 if next_obs.len() != n * obs_dim {
107 return Err(RloxError::ShapeMismatch {
108 expected: format!("next_obs length = {}", n * obs_dim),
109 got: format!("{}", next_obs.len()),
110 });
111 }
112 if actions.len() != n * act_dim {
113 return Err(RloxError::ShapeMismatch {
114 expected: format!("actions length = {} * {} = {}", n, act_dim, n * act_dim),
115 got: format!("{}", actions.len()),
116 });
117 }
118 if terminated.len() != n || truncated.len() != n {
119 return Err(RloxError::ShapeMismatch {
120 expected: format!("terminated/truncated length = {}", n),
121 got: format!(
122 "terminated={}, truncated={}",
123 terminated.len(),
124 truncated.len()
125 ),
126 });
127 }
128
129 let mut episode_starts = vec![0usize];
131 let mut episode_returns = Vec::new();
132 let mut ep_return = 0.0f32;
133
134 for i in 0..n {
135 ep_return += rewards[i];
136 let done = terminated[i] != 0 || truncated[i] != 0;
137 if done || i == n - 1 {
138 episode_returns.push(ep_return);
139 if i + 1 < n {
140 episode_starts.push(i + 1);
141 }
142 ep_return = 0.0;
143 }
144 }
145
146 let episode_lengths: Vec<usize> = episode_starts
147 .windows(2)
148 .map(|w| w[1] - w[0])
149 .chain(std::iter::once(n - episode_starts.last().unwrap_or(&0)))
150 .collect();
151
152 Ok(Self {
153 obs,
154 next_obs,
155 actions,
156 rewards,
157 terminated,
158 truncated,
159 episode_starts,
160 episode_lengths,
161 episode_returns,
162 obs_dim,
163 act_dim,
164 len: n,
165 obs_mean: None,
166 obs_std: None,
167 reward_mean: None,
168 reward_std: None,
169 })
170 }
171
172 pub fn len(&self) -> usize {
174 self.len
175 }
176
177 pub fn is_empty(&self) -> bool {
178 self.len == 0
179 }
180
181 pub fn n_episodes(&self) -> usize {
183 self.episode_starts.len()
184 }
185
186 pub fn obs_dim(&self) -> usize {
187 self.obs_dim
188 }
189
190 pub fn act_dim(&self) -> usize {
191 self.act_dim
192 }
193
194 #[allow(clippy::needless_range_loop)]
196 pub fn compute_normalization(&mut self) {
197 let n = self.len;
198 let d = self.obs_dim;
199
200 let mut mean = vec![0.0f64; d];
202 for i in 0..n {
203 for j in 0..d {
204 mean[j] += self.obs[i * d + j] as f64;
205 }
206 }
207 for m in &mut mean {
208 *m /= n as f64;
209 }
210
211 let mut var = vec![0.0f64; d];
212 for i in 0..n {
213 for j in 0..d {
214 let diff = self.obs[i * d + j] as f64 - mean[j];
215 var[j] += diff * diff;
216 }
217 }
218 for v in &mut var {
219 *v = (*v / n as f64).sqrt().max(1e-8);
220 }
221
222 self.obs_mean = Some(mean.iter().map(|&x| x as f32).collect());
223 self.obs_std = Some(var.iter().map(|&x| x as f32).collect());
224
225 let r_mean = self.rewards.iter().map(|&r| r as f64).sum::<f64>() / n as f64;
227 let r_var = self
228 .rewards
229 .iter()
230 .map(|&r| {
231 let d = r as f64 - r_mean;
232 d * d
233 })
234 .sum::<f64>()
235 / n as f64;
236 self.reward_mean = Some(r_mean as f32);
237 self.reward_std = Some((r_var.sqrt().max(1e-8)) as f32);
238 }
239
240 pub fn sample(&self, batch_size: usize, seed: u64) -> OfflineBatch {
242 let mut rng = ChaCha8Rng::seed_from_u64(seed);
243 let d = self.obs_dim;
244 let a = self.act_dim;
245
246 let mut obs = Vec::with_capacity(batch_size * d);
247 let mut next_obs = Vec::with_capacity(batch_size * d);
248 let mut actions = Vec::with_capacity(batch_size * a);
249 let mut rewards = Vec::with_capacity(batch_size);
250 let mut terminated = Vec::with_capacity(batch_size);
251
252 for _ in 0..batch_size {
253 let idx = rng.random_range(0..self.len);
254
255 obs.extend_from_slice(&self.obs[idx * d..(idx + 1) * d]);
256 next_obs.extend_from_slice(&self.next_obs[idx * d..(idx + 1) * d]);
257 actions.extend_from_slice(&self.actions[idx * a..(idx + 1) * a]);
258 rewards.push(self.rewards[idx]);
259 terminated.push(self.terminated[idx]);
260 }
261
262 if let (Some(mean), Some(std)) = (&self.obs_mean, &self.obs_std) {
264 for i in 0..batch_size {
265 for j in 0..d {
266 obs[i * d + j] = (obs[i * d + j] - mean[j]) / std[j];
267 next_obs[i * d + j] = (next_obs[i * d + j] - mean[j]) / std[j];
268 }
269 }
270 }
271
272 OfflineBatch {
273 obs,
274 next_obs,
275 actions,
276 rewards,
277 terminated,
278 obs_dim: d,
279 act_dim: a,
280 }
281 }
282
283 pub fn sample_trajectories(
289 &self,
290 batch_size: usize,
291 seq_len: usize,
292 seed: u64,
293 ) -> TrajectoryBatch {
294 let mut rng = ChaCha8Rng::seed_from_u64(seed);
295 let d = self.obs_dim;
296 let a = self.act_dim;
297 let n_eps = self.n_episodes();
298
299 let total = batch_size * seq_len;
300 let mut obs = vec![0.0f32; total * d];
301 let mut actions = vec![0.0f32; total * a];
302 let mut rewards = vec![0.0f32; total];
303 let mut returns_to_go = vec![0.0f32; total];
304 let mut timesteps = vec![0u32; total];
305 let mut mask = vec![0u8; total];
306
307 for b in 0..batch_size {
308 let ep_idx = rng.random_range(0..n_eps);
309 let ep_start = self.episode_starts[ep_idx];
310 let ep_len = self.episode_lengths[ep_idx];
311
312 let max_start = ep_len.saturating_sub(seq_len);
314 let start_offset = rng.random_range(0..=max_start);
315 let actual_len = seq_len.min(ep_len - start_offset);
316
317 let mut rtg = vec![0.0f32; actual_len];
319 if actual_len > 0 {
320 rtg[actual_len - 1] = self.rewards[ep_start + start_offset + actual_len - 1];
321 for t in (0..actual_len - 1).rev() {
322 rtg[t] = self.rewards[ep_start + start_offset + t] + rtg[t + 1];
323 }
324 }
325
326 for (t, rtg_val) in rtg.iter().enumerate() {
327 let src_idx = ep_start + start_offset + t;
328 let dst_idx = b * seq_len + t;
329
330 obs[dst_idx * d..(dst_idx + 1) * d]
331 .copy_from_slice(&self.obs[src_idx * d..(src_idx + 1) * d]);
332 actions[dst_idx * a..(dst_idx + 1) * a]
333 .copy_from_slice(&self.actions[src_idx * a..(src_idx + 1) * a]);
334 rewards[dst_idx] = self.rewards[src_idx];
335 returns_to_go[dst_idx] = *rtg_val;
336 timesteps[dst_idx] = (start_offset + t) as u32;
337 mask[dst_idx] = 1;
338 }
339 }
340
341 TrajectoryBatch {
342 obs,
343 actions,
344 rewards,
345 returns_to_go,
346 timesteps,
347 mask,
348 seq_len,
349 obs_dim: d,
350 act_dim: a,
351 }
352 }
353
354 pub fn stats(&self) -> DatasetStats {
356 let returns = &self.episode_returns;
357 let n_eps = returns.len();
358
359 let mean_return = if n_eps > 0 {
360 returns.iter().sum::<f32>() / n_eps as f32
361 } else {
362 0.0
363 };
364
365 let std_return = if n_eps > 1 {
366 let var: f32 = returns
367 .iter()
368 .map(|&r| (r - mean_return).powi(2))
369 .sum::<f32>()
370 / (n_eps - 1) as f32;
371 var.sqrt()
372 } else {
373 0.0
374 };
375
376 let min_return = returns.iter().cloned().reduce(f32::min).unwrap_or(0.0);
377 let max_return = returns.iter().cloned().reduce(f32::max).unwrap_or(0.0);
378
379 let mean_ep_len = if n_eps > 0 {
380 self.episode_lengths.iter().sum::<usize>() as f32 / n_eps as f32
381 } else {
382 0.0
383 };
384
385 DatasetStats {
386 n_transitions: self.len,
387 n_episodes: n_eps,
388 obs_dim: self.obs_dim,
389 act_dim: self.act_dim,
390 mean_return,
391 std_return,
392 min_return,
393 max_return,
394 mean_episode_length: mean_ep_len,
395 }
396 }
397}
398
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `n`-transition dataset of fixed-length episodes. Every reward
    /// is 1.0, `obs` / `next_obs` are constant (0.1 / 0.2), actions are zero,
    /// and every `ep_len`-th transition is marked terminated.
    fn make_test_dataset(
        n: usize,
        obs_dim: usize,
        act_dim: usize,
        ep_len: usize,
    ) -> OfflineDatasetBuffer {
        let rewards = vec![1.0f32; n];
        let terminated: Vec<u8> = (1..=n)
            .map(|step| u8::from(step.is_multiple_of(ep_len)))
            .collect();
        let truncated = vec![0u8; n];

        OfflineDatasetBuffer::from_arrays(
            vec![0.1f32; n * obs_dim],
            vec![0.2f32; n * obs_dim],
            vec![0.0f32; n * act_dim],
            rewards,
            terminated,
            truncated,
            obs_dim,
            act_dim,
        )
        .unwrap()
    }

    #[test]
    fn test_load_from_arrays() {
        let buf = make_test_dataset(100, 4, 1, 10);
        assert_eq!(buf.len(), 100);
        assert_eq!(buf.obs_dim(), 4);
        assert_eq!(buf.act_dim(), 1);
    }

    #[test]
    fn test_episode_boundary_detection() {
        let buf = make_test_dataset(100, 4, 1, 10);
        assert_eq!(buf.n_episodes(), 10);
        assert_eq!(buf.episode_lengths, vec![10; 10]);
    }

    #[test]
    fn test_episode_returns() {
        // Ten unit rewards per episode.
        let buf = make_test_dataset(100, 4, 1, 10);
        for &ret in &buf.episode_returns {
            assert!((ret - 10.0).abs() < 1e-5);
        }
    }

    #[test]
    fn test_sample_uniform_shapes() {
        let buf = make_test_dataset(1000, 4, 2, 100);
        let batch = buf.sample(32, 42);
        assert_eq!(batch.obs.len(), 32 * 4);
        assert_eq!(batch.next_obs.len(), 32 * 4);
        assert_eq!(batch.actions.len(), 32 * 2);
        assert_eq!(batch.rewards.len(), 32);
        assert_eq!(batch.terminated.len(), 32);
    }

    #[test]
    fn test_sample_deterministic() {
        let buf = make_test_dataset(1000, 4, 1, 100);
        let b1 = buf.sample(32, 42);
        let b2 = buf.sample(32, 42);
        assert_eq!(b1.obs, b2.obs);
        assert_eq!(b1.next_obs, b2.next_obs);
        assert_eq!(b1.actions, b2.actions);
        assert_eq!(b1.rewards, b2.rewards);
        assert_eq!(b1.terminated, b2.terminated);
    }

    #[test]
    fn test_sample_different_seeds() {
        // Distinct observation values so different indices are distinguishable.
        let n = 1000;
        let obs_dim = 4;
        let obs: Vec<f32> = (0..n * obs_dim).map(|i| i as f32 * 0.001).collect();
        let mut terminated = vec![0u8; n];
        for flag in terminated.iter_mut().skip(99).step_by(100) {
            *flag = 1;
        }
        let buf = OfflineDatasetBuffer::from_arrays(
            obs.clone(),
            obs,
            vec![0.0; n],
            vec![1.0; n],
            terminated,
            vec![0; n],
            obs_dim,
            1,
        )
        .unwrap();

        let b1 = buf.sample(32, 42);
        let b2 = buf.sample(32, 99);
        assert_ne!(
            b1.obs, b2.obs,
            "Different seeds should produce different samples"
        );
    }

    #[test]
    fn test_normalization() {
        let mut buf = make_test_dataset(1000, 4, 1, 100);
        buf.compute_normalization();
        assert!(buf.obs_mean.is_some());
        assert!(buf.obs_std.is_some());

        // Constant observations normalize to (approximately) zero.
        let batch = buf.sample(32, 42);
        let mean: f32 = batch.obs.iter().sum::<f32>() / batch.obs.len() as f32;
        assert!(
            mean.abs() < 1.0,
            "Normalized mean should be near 0, got {mean}"
        );
    }

    #[test]
    fn test_sample_trajectories_shapes() {
        let buf = make_test_dataset(1000, 4, 2, 100);
        let batch = buf.sample_trajectories(8, 20, 42);
        assert_eq!(batch.obs.len(), 8 * 20 * 4);
        assert_eq!(batch.actions.len(), 8 * 20 * 2);
        assert_eq!(batch.rewards.len(), 8 * 20);
        assert_eq!(batch.returns_to_go.len(), 8 * 20);
        assert_eq!(batch.timesteps.len(), 8 * 20);
        assert_eq!(batch.mask.len(), 8 * 20);
    }

    #[test]
    fn test_sample_trajectories_mask() {
        // Episodes of 5 steps sampled into windows of 10: each window must be
        // partially padded.
        let buf = make_test_dataset(50, 4, 1, 5);
        let batch = buf.sample_trajectories(4, 10, 42);
        for b in 0..4 {
            let valid: usize = batch.mask[b * 10..(b + 1) * 10]
                .iter()
                .map(|&m| m as usize)
                .sum();
            assert!(
                valid <= 5,
                "Valid mask count should be <= ep_len=5, got {valid}"
            );
            assert!(valid > 0, "Should have at least 1 valid step");
        }
    }

    #[test]
    fn test_sample_trajectories_returns_to_go() {
        let buf = make_test_dataset(100, 4, 1, 10);
        let batch = buf.sample_trajectories(1, 10, 42);

        // With non-negative rewards, returns-to-go never increase along a window.
        let mut prev_rtg = f32::MAX;
        for t in 0..10 {
            if batch.mask[t] == 1 {
                assert!(
                    batch.returns_to_go[t] <= prev_rtg + 1e-5,
                    "RTG should be non-increasing, got {} after {}",
                    batch.returns_to_go[t],
                    prev_rtg
                );
                prev_rtg = batch.returns_to_go[t];
            }
        }
    }

    #[test]
    fn test_stats() {
        let buf = make_test_dataset(100, 4, 1, 10);
        let stats = buf.stats();
        assert_eq!(stats.n_transitions, 100);
        assert_eq!(stats.n_episodes, 10);
        assert_eq!(stats.obs_dim, 4);
        assert_eq!(stats.act_dim, 1);
        assert!((stats.mean_return - 10.0).abs() < 1e-5);
        assert!((stats.mean_episode_length - 10.0).abs() < 1e-5);
    }

    #[test]
    fn test_empty_dataset_is_accepted() {
        // Empty input is not an error: it builds an empty buffer.
        let result =
            OfflineDatasetBuffer::from_arrays(vec![], vec![], vec![], vec![], vec![], vec![], 4, 1);
        assert!(result.is_ok());
        assert_eq!(result.unwrap().len(), 0);
    }

    #[test]
    fn test_mismatched_lengths_error() {
        // `terminated` has 5 entries but there are 10 rewards.
        let result = OfflineDatasetBuffer::from_arrays(
            vec![0.0; 40],
            vec![0.0; 40],
            vec![0.0; 10],
            vec![0.0; 10],
            vec![0; 5],
            vec![0; 10],
            4,
            1,
        );
        assert!(matches!(result, Err(RloxError::ShapeMismatch { .. })));
    }

    #[test]
    fn test_variable_episode_lengths() {
        // Episodes end at indices 4, 12 and 24: lengths 5, 8 and 12.
        let n = 25;
        let obs_dim = 2;
        let act_dim = 1;
        let mut terminated = vec![0u8; n];
        terminated[4] = 1;
        terminated[12] = 1;
        terminated[24] = 1;
        let buf = OfflineDatasetBuffer::from_arrays(
            vec![0.0; n * obs_dim],
            vec![0.0; n * obs_dim],
            vec![0.0; n * act_dim],
            vec![1.0; n],
            terminated,
            vec![0; n],
            obs_dim,
            act_dim,
        )
        .unwrap();

        assert_eq!(buf.n_episodes(), 3);
        assert_eq!(buf.episode_lengths, vec![5, 8, 12]);
    }
}