rlox_core/buffer/
offline.rs

1//! Read-only offline dataset buffer for offline RL algorithms.
2//!
3//! Unlike [`ReplayBuffer`], this buffer is loaded once from a static dataset
4//! and never modified. It supports:
5//! - Uniform i.i.d. transition sampling (for TD3+BC, IQL, CQL, BC)
6//! - Trajectory subsequence sampling (for Decision Transformer)
7//! - Return-conditioned sampling (for return-conditioned methods)
8//! - Dataset normalization statistics
9//!
10//! Designed for D4RL/Minari-scale datasets (1M+ transitions).
11
12use rand::Rng;
13use rand::SeedableRng;
14use rand_chacha::ChaCha8Rng;
15
16use crate::error::RloxError;
17
/// Summary statistics about a loaded dataset, as returned by
/// [`OfflineDatasetBuffer::stats`].
#[derive(Debug, Clone, PartialEq)]
pub struct DatasetStats {
    /// Total number of transitions.
    pub n_transitions: usize,
    /// Number of episodes.
    pub n_episodes: usize,
    /// Observation dimensionality.
    pub obs_dim: usize,
    /// Action dimensionality.
    pub act_dim: usize,
    /// Mean undiscounted episode return (`0.0` if there are no episodes).
    pub mean_return: f32,
    /// Sample standard deviation of episode returns (`0.0` with fewer than
    /// two episodes).
    pub std_return: f32,
    /// Smallest episode return (`0.0` if there are no episodes).
    pub min_return: f32,
    /// Largest episode return (`0.0` if there are no episodes).
    pub max_return: f32,
    /// Mean episode length, in transitions.
    pub mean_episode_length: f32,
}
31
/// A batch of i.i.d. sampled transitions.
///
/// All arrays are flat and row-major with `batch_size` rows; row `i` of every
/// array belongs to the same transition.
#[derive(Debug, Clone, PartialEq)]
pub struct OfflineBatch {
    /// Observations, `[batch_size * obs_dim]`.
    pub obs: Vec<f32>,
    /// Next observations, `[batch_size * obs_dim]`.
    pub next_obs: Vec<f32>,
    /// Actions, `[batch_size * act_dim]`.
    pub actions: Vec<f32>,
    /// Rewards, `[batch_size]`.
    pub rewards: Vec<f32>,
    /// Termination flags (non-zero means terminal), `[batch_size]`.
    pub terminated: Vec<u8>,
    /// Width of one observation row.
    pub obs_dim: usize,
    /// Width of one action row.
    pub act_dim: usize,
}
43
/// A batch of contiguous trajectory subsequences.
///
/// All arrays are flat and row-major with `batch_size * seq_len` positions;
/// sample `b` occupies positions `b * seq_len .. (b + 1) * seq_len`.
#[derive(Debug, Clone, PartialEq)]
pub struct TrajectoryBatch {
    /// Observations, `[batch_size * seq_len * obs_dim]`.
    pub obs: Vec<f32>,
    /// Actions, `[batch_size * seq_len * act_dim]`.
    pub actions: Vec<f32>,
    /// Per-step rewards, `[batch_size * seq_len]`.
    pub rewards: Vec<f32>,
    /// Returns-to-go, `[batch_size * seq_len]`.
    pub returns_to_go: Vec<f32>,
    /// Step index of each position within its episode, `[batch_size * seq_len]`.
    pub timesteps: Vec<u32>,
    /// Validity mask, `[batch_size * seq_len]` (1 = valid, 0 = padding).
    pub mask: Vec<u8>,
    /// Number of positions per sample.
    pub seq_len: usize,
    /// Width of one observation row.
    pub obs_dim: usize,
    /// Width of one action row.
    pub act_dim: usize,
}
57
/// Read-only offline dataset buffer.
///
/// Transitions are stored as flat row-major arrays, alongside episode
/// boundaries that are inferred once when the buffer is built.
#[derive(Debug, Clone)]
pub struct OfflineDatasetBuffer {
    /// Observations, `[len * obs_dim]`.
    obs: Vec<f32>,
    /// Next observations, `[len * obs_dim]`.
    next_obs: Vec<f32>,
    /// Actions, `[len * act_dim]`.
    actions: Vec<f32>,
    /// Rewards, `[len]`.
    rewards: Vec<f32>,
    /// Termination flags, `[len]`.
    terminated: Vec<u8>,
    /// Truncation flags, `[len]`. Only consulted while episode boundaries are
    /// detected during construction; never read afterwards, hence the
    /// `dead_code` allowance.
    #[allow(dead_code)]
    truncated: Vec<u8>,

    // Episode boundary tracking: one entry per episode in each vector.
    /// Index of the first transition of each episode.
    episode_starts: Vec<usize>,
    /// Number of transitions in each episode.
    episode_lengths: Vec<usize>,
    /// Undiscounted sum of rewards of each episode.
    episode_returns: Vec<f32>,

    /// Observation dimensionality.
    obs_dim: usize,
    /// Action dimensionality.
    act_dim: usize,
    /// Number of transitions.
    len: usize,

    // Normalization statistics: `None` until `compute_normalization` is
    // called explicitly (they are not computed on demand).
    obs_mean: Option<Vec<f32>>,
    obs_std: Option<Vec<f32>>,
    reward_mean: Option<f32>,
    reward_std: Option<f32>,
}
83
84impl OfflineDatasetBuffer {
85    /// Create from flat arrays.
86    ///
87    /// Arrays must be row-major: obs has length `n * obs_dim`, etc.
88    pub fn from_arrays(
89        obs: Vec<f32>,
90        next_obs: Vec<f32>,
91        actions: Vec<f32>,
92        rewards: Vec<f32>,
93        terminated: Vec<u8>,
94        truncated: Vec<u8>,
95        obs_dim: usize,
96        act_dim: usize,
97    ) -> Result<Self, RloxError> {
98        let n = rewards.len();
99
100        if obs.len() != n * obs_dim {
101            return Err(RloxError::ShapeMismatch {
102                expected: format!("obs length = {} * {} = {}", n, obs_dim, n * obs_dim),
103                got: format!("{}", obs.len()),
104            });
105        }
106        if next_obs.len() != n * obs_dim {
107            return Err(RloxError::ShapeMismatch {
108                expected: format!("next_obs length = {}", n * obs_dim),
109                got: format!("{}", next_obs.len()),
110            });
111        }
112        if actions.len() != n * act_dim {
113            return Err(RloxError::ShapeMismatch {
114                expected: format!("actions length = {} * {} = {}", n, act_dim, n * act_dim),
115                got: format!("{}", actions.len()),
116            });
117        }
118        if terminated.len() != n || truncated.len() != n {
119            return Err(RloxError::ShapeMismatch {
120                expected: format!("terminated/truncated length = {}", n),
121                got: format!(
122                    "terminated={}, truncated={}",
123                    terminated.len(),
124                    truncated.len()
125                ),
126            });
127        }
128
129        // Detect episode boundaries
130        let mut episode_starts = vec![0usize];
131        let mut episode_returns = Vec::new();
132        let mut ep_return = 0.0f32;
133
134        for i in 0..n {
135            ep_return += rewards[i];
136            let done = terminated[i] != 0 || truncated[i] != 0;
137            if done || i == n - 1 {
138                episode_returns.push(ep_return);
139                if i + 1 < n {
140                    episode_starts.push(i + 1);
141                }
142                ep_return = 0.0;
143            }
144        }
145
146        let episode_lengths: Vec<usize> = episode_starts
147            .windows(2)
148            .map(|w| w[1] - w[0])
149            .chain(std::iter::once(n - episode_starts.last().unwrap_or(&0)))
150            .collect();
151
152        Ok(Self {
153            obs,
154            next_obs,
155            actions,
156            rewards,
157            terminated,
158            truncated,
159            episode_starts,
160            episode_lengths,
161            episode_returns,
162            obs_dim,
163            act_dim,
164            len: n,
165            obs_mean: None,
166            obs_std: None,
167            reward_mean: None,
168            reward_std: None,
169        })
170    }
171
172    /// Number of transitions in the dataset.
173    pub fn len(&self) -> usize {
174        self.len
175    }
176
177    pub fn is_empty(&self) -> bool {
178        self.len == 0
179    }
180
181    /// Number of episodes in the dataset.
182    pub fn n_episodes(&self) -> usize {
183        self.episode_starts.len()
184    }
185
186    pub fn obs_dim(&self) -> usize {
187        self.obs_dim
188    }
189
190    pub fn act_dim(&self) -> usize {
191        self.act_dim
192    }
193
194    /// Compute and cache normalization statistics.
195    #[allow(clippy::needless_range_loop)]
196    pub fn compute_normalization(&mut self) {
197        let n = self.len;
198        let d = self.obs_dim;
199
200        // Obs mean and std
201        let mut mean = vec![0.0f64; d];
202        for i in 0..n {
203            for j in 0..d {
204                mean[j] += self.obs[i * d + j] as f64;
205            }
206        }
207        for m in &mut mean {
208            *m /= n as f64;
209        }
210
211        let mut var = vec![0.0f64; d];
212        for i in 0..n {
213            for j in 0..d {
214                let diff = self.obs[i * d + j] as f64 - mean[j];
215                var[j] += diff * diff;
216            }
217        }
218        for v in &mut var {
219            *v = (*v / n as f64).sqrt().max(1e-8);
220        }
221
222        self.obs_mean = Some(mean.iter().map(|&x| x as f32).collect());
223        self.obs_std = Some(var.iter().map(|&x| x as f32).collect());
224
225        // Reward mean and std
226        let r_mean = self.rewards.iter().map(|&r| r as f64).sum::<f64>() / n as f64;
227        let r_var = self
228            .rewards
229            .iter()
230            .map(|&r| {
231                let d = r as f64 - r_mean;
232                d * d
233            })
234            .sum::<f64>()
235            / n as f64;
236        self.reward_mean = Some(r_mean as f32);
237        self.reward_std = Some((r_var.sqrt().max(1e-8)) as f32);
238    }
239
240    /// Sample i.i.d. transitions uniformly.
241    pub fn sample(&self, batch_size: usize, seed: u64) -> OfflineBatch {
242        let mut rng = ChaCha8Rng::seed_from_u64(seed);
243        let d = self.obs_dim;
244        let a = self.act_dim;
245
246        let mut obs = Vec::with_capacity(batch_size * d);
247        let mut next_obs = Vec::with_capacity(batch_size * d);
248        let mut actions = Vec::with_capacity(batch_size * a);
249        let mut rewards = Vec::with_capacity(batch_size);
250        let mut terminated = Vec::with_capacity(batch_size);
251
252        for _ in 0..batch_size {
253            let idx = rng.random_range(0..self.len);
254
255            obs.extend_from_slice(&self.obs[idx * d..(idx + 1) * d]);
256            next_obs.extend_from_slice(&self.next_obs[idx * d..(idx + 1) * d]);
257            actions.extend_from_slice(&self.actions[idx * a..(idx + 1) * a]);
258            rewards.push(self.rewards[idx]);
259            terminated.push(self.terminated[idx]);
260        }
261
262        // Apply normalization if available
263        if let (Some(mean), Some(std)) = (&self.obs_mean, &self.obs_std) {
264            for i in 0..batch_size {
265                for j in 0..d {
266                    obs[i * d + j] = (obs[i * d + j] - mean[j]) / std[j];
267                    next_obs[i * d + j] = (next_obs[i * d + j] - mean[j]) / std[j];
268                }
269            }
270        }
271
272        OfflineBatch {
273            obs,
274            next_obs,
275            actions,
276            rewards,
277            terminated,
278            obs_dim: d,
279            act_dim: a,
280        }
281    }
282
283    /// Sample contiguous trajectory subsequences.
284    ///
285    /// Each sample is a contiguous window of `seq_len` transitions from a
286    /// single episode. If the episode is shorter than `seq_len`, the sequence
287    /// is right-padded with zeros and the mask indicates valid positions.
288    pub fn sample_trajectories(
289        &self,
290        batch_size: usize,
291        seq_len: usize,
292        seed: u64,
293    ) -> TrajectoryBatch {
294        let mut rng = ChaCha8Rng::seed_from_u64(seed);
295        let d = self.obs_dim;
296        let a = self.act_dim;
297        let n_eps = self.n_episodes();
298
299        let total = batch_size * seq_len;
300        let mut obs = vec![0.0f32; total * d];
301        let mut actions = vec![0.0f32; total * a];
302        let mut rewards = vec![0.0f32; total];
303        let mut returns_to_go = vec![0.0f32; total];
304        let mut timesteps = vec![0u32; total];
305        let mut mask = vec![0u8; total];
306
307        for b in 0..batch_size {
308            let ep_idx = rng.random_range(0..n_eps);
309            let ep_start = self.episode_starts[ep_idx];
310            let ep_len = self.episode_lengths[ep_idx];
311
312            // Random start within episode
313            let max_start = ep_len.saturating_sub(seq_len);
314            let start_offset = rng.random_range(0..=max_start);
315            let actual_len = seq_len.min(ep_len - start_offset);
316
317            // Compute returns-to-go for this episode segment
318            let mut rtg = vec![0.0f32; actual_len];
319            if actual_len > 0 {
320                rtg[actual_len - 1] = self.rewards[ep_start + start_offset + actual_len - 1];
321                for t in (0..actual_len - 1).rev() {
322                    rtg[t] = self.rewards[ep_start + start_offset + t] + rtg[t + 1];
323                }
324            }
325
326            for (t, rtg_val) in rtg.iter().enumerate() {
327                let src_idx = ep_start + start_offset + t;
328                let dst_idx = b * seq_len + t;
329
330                obs[dst_idx * d..(dst_idx + 1) * d]
331                    .copy_from_slice(&self.obs[src_idx * d..(src_idx + 1) * d]);
332                actions[dst_idx * a..(dst_idx + 1) * a]
333                    .copy_from_slice(&self.actions[src_idx * a..(src_idx + 1) * a]);
334                rewards[dst_idx] = self.rewards[src_idx];
335                returns_to_go[dst_idx] = *rtg_val;
336                timesteps[dst_idx] = (start_offset + t) as u32;
337                mask[dst_idx] = 1;
338            }
339        }
340
341        TrajectoryBatch {
342            obs,
343            actions,
344            rewards,
345            returns_to_go,
346            timesteps,
347            mask,
348            seq_len,
349            obs_dim: d,
350            act_dim: a,
351        }
352    }
353
354    /// Get dataset statistics.
355    pub fn stats(&self) -> DatasetStats {
356        let returns = &self.episode_returns;
357        let n_eps = returns.len();
358
359        let mean_return = if n_eps > 0 {
360            returns.iter().sum::<f32>() / n_eps as f32
361        } else {
362            0.0
363        };
364
365        let std_return = if n_eps > 1 {
366            let var: f32 = returns
367                .iter()
368                .map(|&r| (r - mean_return).powi(2))
369                .sum::<f32>()
370                / (n_eps - 1) as f32;
371            var.sqrt()
372        } else {
373            0.0
374        };
375
376        let min_return = returns.iter().cloned().reduce(f32::min).unwrap_or(0.0);
377        let max_return = returns.iter().cloned().reduce(f32::max).unwrap_or(0.0);
378
379        let mean_ep_len = if n_eps > 0 {
380            self.episode_lengths.iter().sum::<usize>() as f32 / n_eps as f32
381        } else {
382            0.0
383        };
384
385        DatasetStats {
386            n_transitions: self.len,
387            n_episodes: n_eps,
388            obs_dim: self.obs_dim,
389            act_dim: self.act_dim,
390            mean_return,
391            std_return,
392            min_return,
393            max_return,
394            mean_episode_length: mean_ep_len,
395        }
396    }
397}
398
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a dataset of `n` transitions, terminating every `ep_len`-th step.
    ///
    /// If `ep_len` does not divide `n`, the trailing steps form a shorter final
    /// episode. Observations are all `0.1`, next observations `0.2`, actions
    /// `0.0`, and every reward is `1.0`.
    fn make_test_dataset(
        n: usize,
        obs_dim: usize,
        act_dim: usize,
        ep_len: usize,
    ) -> OfflineDatasetBuffer {
        let terminated: Vec<u8> = (0..n)
            .map(|i| u8::from((i + 1).is_multiple_of(ep_len)))
            .collect();

        OfflineDatasetBuffer::from_arrays(
            vec![0.1f32; n * obs_dim],
            vec![0.2f32; n * obs_dim],
            vec![0.0f32; n * act_dim],
            vec![1.0f32; n],
            terminated,
            vec![0u8; n],
            obs_dim,
            act_dim,
        )
        .unwrap()
    }

    #[test]
    fn test_load_from_arrays() {
        let buf = make_test_dataset(100, 4, 1, 10);
        assert_eq!(buf.len(), 100);
        assert!(!buf.is_empty());
        assert_eq!(buf.obs_dim(), 4);
        assert_eq!(buf.act_dim(), 1);
    }

    #[test]
    fn test_episode_boundary_detection() {
        let buf = make_test_dataset(100, 4, 1, 10);
        assert_eq!(buf.n_episodes(), 10);
        assert_eq!(buf.episode_lengths, vec![10; 10]);
    }

    #[test]
    fn test_episode_returns() {
        let buf = make_test_dataset(100, 4, 1, 10);
        // Each episode has 10 steps with reward 1.0, so its return is 10.0.
        assert_eq!(buf.episode_returns.len(), 10);
        for &ret in &buf.episode_returns {
            assert!((ret - 10.0).abs() < 1e-5);
        }
    }

    #[test]
    fn test_truncated_flag_ends_episode() {
        let n = 6;
        let mut truncated = vec![0u8; n];
        truncated[2] = 1;
        let buf = OfflineDatasetBuffer::from_arrays(
            vec![0.0; n],
            vec![0.0; n],
            vec![0.0; n],
            vec![1.0; n],
            vec![0; n],
            truncated,
            1,
            1,
        )
        .unwrap();

        assert_eq!(buf.n_episodes(), 2);
        assert_eq!(buf.episode_starts, vec![0, 3]);
        assert_eq!(buf.episode_lengths, vec![3, 3]);
    }

    #[test]
    fn test_trailing_episode_without_done_flag() {
        let n = 7;
        let mut terminated = vec![0u8; n];
        terminated[2] = 1; // steps 3..7 never get a done flag
        let buf = OfflineDatasetBuffer::from_arrays(
            vec![0.0; n],
            vec![0.0; n],
            vec![0.0; n],
            vec![1.0; n],
            terminated,
            vec![0; n],
            1,
            1,
        )
        .unwrap();

        assert_eq!(buf.n_episodes(), 2);
        assert_eq!(buf.episode_lengths, vec![3, 4]);
        assert_eq!(buf.episode_returns, vec![3.0, 4.0]);
    }

    #[test]
    fn test_variable_episode_lengths() {
        // Episodes: 5 + 8 + 12 = 25 transitions.
        let n = 25;
        let obs_dim = 2;
        let act_dim = 1;
        let mut terminated = vec![0u8; n];
        terminated[4] = 1; // episode 1: steps 0-4
        terminated[12] = 1; // episode 2: steps 5-12
        terminated[24] = 1; // episode 3: steps 13-24

        let buf = OfflineDatasetBuffer::from_arrays(
            vec![0.0; n * obs_dim],
            vec![0.0; n * obs_dim],
            vec![0.0; n * act_dim],
            vec![1.0; n],
            terminated,
            vec![0; n],
            obs_dim,
            act_dim,
        )
        .unwrap();

        assert_eq!(buf.n_episodes(), 3);
        assert_eq!(buf.episode_lengths, vec![5, 8, 12]);
    }

    #[test]
    fn test_sample_uniform_shapes() {
        let buf = make_test_dataset(1000, 4, 2, 100);
        let batch = buf.sample(32, 42);
        assert_eq!(batch.obs.len(), 32 * 4);
        assert_eq!(batch.next_obs.len(), 32 * 4);
        assert_eq!(batch.actions.len(), 32 * 2);
        assert_eq!(batch.rewards.len(), 32);
        assert_eq!(batch.terminated.len(), 32);
    }

    #[test]
    fn test_sample_deterministic() {
        let buf = make_test_dataset(1000, 4, 1, 100);
        let b1 = buf.sample(32, 42);
        let b2 = buf.sample(32, 42);
        assert_eq!(b1.obs, b2.obs);
        assert_eq!(b1.actions, b2.actions);
        assert_eq!(b1.rewards, b2.rewards);
        assert_eq!(b1.terminated, b2.terminated);
    }

    #[test]
    fn test_sample_different_seeds() {
        // Varying obs, so different indices produce different data.
        let n = 1000;
        let obs_dim = 4;
        let obs: Vec<f32> = (0..n * obs_dim).map(|i| i as f32 * 0.001).collect();
        let mut terminated = vec![0u8; n];
        for i in (99..n).step_by(100) {
            terminated[i] = 1;
        }
        let buf = OfflineDatasetBuffer::from_arrays(
            obs.clone(),
            obs,
            vec![0.0; n],
            vec![1.0; n],
            terminated,
            vec![0; n],
            obs_dim,
            1,
        )
        .unwrap();

        let b1 = buf.sample(32, 42);
        let b2 = buf.sample(32, 99);
        assert_ne!(
            b1.obs, b2.obs,
            "Different seeds should produce different samples"
        );
    }

    #[test]
    fn test_normalization() {
        let mut buf = make_test_dataset(1000, 4, 1, 100);
        buf.compute_normalization();
        assert!(buf.obs_std.is_some());
        let mean = buf.obs_mean.as_ref().expect("obs_mean should be set");
        assert!(mean.iter().all(|&m| (m - 0.1).abs() < 1e-6));

        // Every observation equals the per-dimension mean, so it standardizes to 0.
        let batch = buf.sample(32, 42);
        for &x in &batch.obs {
            assert!(x.abs() < 1e-6, "Normalized constant obs should be 0, got {x}");
        }
    }

    #[test]
    fn test_normalization_statistics() {
        // Dim 0 takes values 0, 2, 4, 6 (mean 3, population var 5); dim 1 is constant.
        let obs = vec![0.0, 10.0, 2.0, 10.0, 4.0, 10.0, 6.0, 10.0];
        let mut buf = OfflineDatasetBuffer::from_arrays(
            obs.clone(),
            obs,
            vec![0.0; 4],
            vec![1.0, 2.0, 3.0, 4.0],
            vec![0, 0, 0, 1],
            vec![0; 4],
            2,
            1,
        )
        .unwrap();
        buf.compute_normalization();

        let mean = buf.obs_mean.clone().unwrap();
        let std = buf.obs_std.clone().unwrap();
        assert!((mean[0] - 3.0).abs() < 1e-6);
        assert!((mean[1] - 10.0).abs() < 1e-6);
        assert!((std[0] - 5.0f32.sqrt()).abs() < 1e-5);
        assert!(std[1] > 0.0, "Constant-dimension std must be clamped above zero");

        assert!((buf.reward_mean.unwrap() - 2.5).abs() < 1e-6);
        assert!((buf.reward_std.unwrap() - 1.25f32.sqrt()).abs() < 1e-5);
    }

    #[test]
    fn test_sample_trajectories_shapes() {
        let buf = make_test_dataset(1000, 4, 2, 100);
        let batch = buf.sample_trajectories(8, 20, 42);
        assert_eq!(batch.obs.len(), 8 * 20 * 4);
        assert_eq!(batch.actions.len(), 8 * 20 * 2);
        assert_eq!(batch.rewards.len(), 8 * 20);
        assert_eq!(batch.returns_to_go.len(), 8 * 20);
        assert_eq!(batch.timesteps.len(), 8 * 20);
        assert_eq!(batch.mask.len(), 8 * 20);
    }

    #[test]
    fn test_sample_trajectories_mask() {
        // 10 episodes of length 5, but seq_len = 10 is requested, so each
        // sample is a whole episode followed by 5 padding positions.
        let buf = make_test_dataset(50, 4, 1, 5);
        let batch = buf.sample_trajectories(4, 10, 42);

        for b in 0..4 {
            let row = &batch.mask[b * 10..(b + 1) * 10];
            assert_eq!(row, &[1, 1, 1, 1, 1, 0, 0, 0, 0, 0], "sample {b} is not right-padded");
        }
    }

    #[test]
    fn test_sample_trajectories_timesteps_contiguous() {
        let buf = make_test_dataset(1000, 4, 1, 100);
        let seq_len = 20;
        let batch = buf.sample_trajectories(8, seq_len, 7);

        for b in 0..8 {
            let row = &batch.timesteps[b * seq_len..(b + 1) * seq_len];
            let first = row[0] as usize;
            assert!(first + seq_len <= 100, "Window must stay inside its episode");
            for (t, &ts) in row.iter().enumerate() {
                assert_eq!(ts as usize, first + t);
                assert_eq!(batch.mask[b * seq_len + t], 1);
            }
        }
    }

    #[test]
    fn test_sample_trajectories_returns_to_go() {
        let buf = make_test_dataset(100, 4, 1, 10);
        let batch = buf.sample_trajectories(1, 10, 42);

        // The window covers a whole episode, so the first RTG is the episode return.
        assert!((batch.returns_to_go[0] - 10.0).abs() < 1e-5);

        // Returns-to-go should be non-increasing within the valid region.
        let mut prev_rtg = f32::MAX;
        for t in 0..10 {
            if batch.mask[t] == 1 {
                assert!(
                    batch.returns_to_go[t] <= prev_rtg + 1e-5,
                    "RTG should be non-increasing, got {} after {}",
                    batch.returns_to_go[t],
                    prev_rtg
                );
                prev_rtg = batch.returns_to_go[t];
            }
        }
    }

    #[test]
    fn test_stats() {
        let buf = make_test_dataset(100, 4, 1, 10);
        let stats = buf.stats();
        assert_eq!(stats.n_transitions, 100);
        assert_eq!(stats.n_episodes, 10);
        assert_eq!(stats.obs_dim, 4);
        assert_eq!(stats.act_dim, 1);
        assert!((stats.mean_return - 10.0).abs() < 1e-5);
        assert!(stats.std_return.abs() < 1e-6);
        assert!((stats.min_return - 10.0).abs() < 1e-5);
        assert!((stats.max_return - 10.0).abs() < 1e-5);
        assert!((stats.mean_episode_length - 10.0).abs() < 1e-5);
    }

    #[test]
    fn test_empty_dataset_is_valid() {
        // An empty dataset (0 transitions) is accepted by the constructor.
        let result =
            OfflineDatasetBuffer::from_arrays(vec![], vec![], vec![], vec![], vec![], vec![], 4, 1);
        assert!(result.is_ok());
        let buf = result.unwrap();
        assert_eq!(buf.len(), 0);
        assert!(buf.is_empty());
    }

    #[test]
    fn test_mismatched_lengths_error() {
        let result = OfflineDatasetBuffer::from_arrays(
            vec![0.0; 40], // 10 * 4
            vec![0.0; 40],
            vec![0.0; 10], // 10 * 1
            vec![0.0; 10],
            vec![0; 5], // WRONG: should be 10
            vec![0; 10],
            4,
            1,
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_mismatched_obs_length_error() {
        let result = OfflineDatasetBuffer::from_arrays(
            vec![0.0; 39], // WRONG: should be 10 * 4
            vec![0.0; 40],
            vec![0.0; 10],
            vec![0.0; 10],
            vec![0; 10],
            vec![0; 10],
            4,
            1,
        );
        assert!(matches!(result, Err(RloxError::ShapeMismatch { .. })));
    }
}