rlox_core/buffer/
mod.rs

//! Experience replay buffers for reinforcement learning.
//!
//! This module provides several buffer implementations, each suited to
//! different training scenarios:
//!
//! | Buffer | Use case |
//! |--------|----------|
//! | [`ringbuf::ReplayBuffer`] | Standard uniform off-policy replay (DQN, SAC). |
//! | [`priority::PrioritizedReplayBuffer`] | Proportional PER with importance-sampling weights. |
//! | [`sequence::SequenceReplayBuffer`] | Contiguous subsequence sampling for recurrent policies. |
//! | [`her::HERBuffer`] | Hindsight Experience Replay for goal-conditioned RL. |
//! | [`mmap::MmapReplayBuffer`] | Disk-backed spill for large replay capacities. |
//! | [`columnar::ExperienceTable`] | Append-only on-policy table (PPO, A2C). |
//! | [`offline::OfflineDatasetBuffer`] | Read-only offline RL datasets. |
//! | [`concurrent::ConcurrentReplayBuffer`] | Lock-free multi-producer buffer. |
//!
//! All buffers store observations, actions, and rewards as `f32`, together
//! with boolean done flags (`terminated` and `truncated`). The
//! [`ExperienceRecord`] struct is a convenience for the `push()` API; prefer
//! `push_slices()` to avoid intermediate `Vec` allocations in hot paths.
21
22pub mod columnar;
23pub mod concurrent;
24pub mod episode;
25pub mod extra_columns;
26#[cfg(feature = "gpu")]
27pub mod flat;
28pub mod her;
29pub mod mixed;
30pub mod mmap;
31pub mod offline;
32pub mod priority;
33pub mod provenance;
34pub mod ringbuf;
35pub mod sequence;
36pub mod varlen;
37
/// A single experience record to push into a buffer.
///
/// Every numeric field is `f32` for numpy compatibility. This owned form is a
/// convenience for the buffers' `push()` API; hot paths should prefer
/// `push_slices()`, which avoids allocating the three `Vec`s below.
///
/// `PartialEq` is derived so records can be compared directly (for example in
/// tests). Because the fields are `f32`, `Eq` and `Hash` are intentionally not
/// derived.
#[derive(Debug, Clone, PartialEq)]
pub struct ExperienceRecord {
    /// Observation at the current step, as a flat vector.
    pub obs: Vec<f32>,
    /// Observation after taking `action`; expected to have the same length as `obs`.
    pub next_obs: Vec<f32>,
    /// Action taken, as a flat vector. Its length must match the buffer's
    /// `act_dim` (e.g. `ExperienceTable::push` rejects a mismatch with an error).
    pub action: Vec<f32>,
    /// Scalar reward received for this transition.
    pub reward: f32,
    /// Done flag: the episode reached a terminal state at this step
    /// (Gymnasium-style `terminated`).
    pub terminated: bool,
    /// Done flag: the episode was cut off without reaching a terminal state
    /// (Gymnasium-style `truncated`, e.g. a time limit).
    pub truncated: bool,
}
49
/// Builds a test transition with a scalar (length-1) action.
///
/// Produces exactly the same record as `sample_record_multidim(obs_dim, 1)`:
/// - `obs` is all `1.0` and `next_obs` is all `2.0`, each of length `obs_dim`;
/// - the action is `[0.0]` and the reward is `1.0`;
/// - neither done flag is set.
#[cfg(test)]
pub(crate) fn sample_record(obs_dim: usize) -> ExperienceRecord {
    // Delegate instead of duplicating the literal so the two fixtures cannot drift apart.
    sample_record_multidim(obs_dim, 1)
}
61
/// Builds a test transition whose action has `act_dim` components.
///
/// The record is a plain mid-episode step:
/// - the observation is filled with `1.0` and the next observation with `2.0`,
///   both of length `obs_dim`;
/// - every action component is `0.0`;
/// - the reward is `1.0` and both done flags are `false`.
#[cfg(test)]
pub(crate) fn sample_record_multidim(obs_dim: usize, act_dim: usize) -> ExperienceRecord {
    let obs = vec![1.0; obs_dim];
    let next_obs = vec![2.0; obs_dim];
    let action = vec![0.0; act_dim];
    ExperienceRecord {
        obs,
        next_obs,
        action,
        reward: 1.0,
        terminated: false,
        truncated: false,
    }
}
73
#[cfg(test)]
mod fix_verification_tests {
    //! Regression tests pinning multi-dimensional action support in
    //! `ExperienceRecord`, `ExperienceTable`, and `ReplayBuffer`.

    use super::*;
    use crate::buffer::columnar::ExperienceTable;
    use crate::buffer::ringbuf::ReplayBuffer;

    #[test]
    fn experience_record_action_is_vec() {
        let record = ExperienceRecord {
            obs: vec![0.0f32; 17],
            next_obs: vec![0.0f32; 17],
            action: vec![0.1, -0.2, 0.3, -0.4, 0.5, -0.6],
            reward: 1.0,
            terminated: false,
            truncated: false,
        };
        assert_eq!(record.action.len(), 6);
        assert_eq!(record.obs.len(), 17);
    }

    #[test]
    fn experience_table_stores_multi_dim_action() {
        let obs_dim = 17;
        let act_dim = 6;
        let mut table = ExperienceTable::new(obs_dim, act_dim);
        let action = vec![0.1f32, -0.2, 0.3, -0.4, 0.5, -0.6];
        // Only the action and reward differ from the shared fixture.
        let record = ExperienceRecord {
            action: action.clone(),
            reward: 5.0,
            ..sample_record(obs_dim)
        };
        table.push(record).unwrap();
        assert_eq!(table.actions_raw().len(), act_dim);
        assert_eq!(&table.actions_raw()[..act_dim], action.as_slice());
    }

    #[test]
    fn replay_buffer_multi_dim_action_roundtrip() {
        let obs_dim = 4;
        let act_dim = 3;
        let mut buf = ReplayBuffer::new(100, obs_dim, act_dim);
        let action = vec![0.5f32, -0.5, 1.0];
        let record = ExperienceRecord {
            obs: vec![0.1; obs_dim],
            next_obs: vec![0.2; obs_dim],
            action: action.clone(),
            reward: 1.0,
            terminated: false,
            truncated: false,
        };
        buf.push(record).unwrap();
        // With a single stored transition, any sample must return it.
        let batch = buf.sample(1, 42).unwrap();
        assert_eq!(batch.act_dim, act_dim);
        assert_eq!(batch.actions.len(), act_dim);
        assert_eq!(&batch.actions[..act_dim], action.as_slice());
    }

    #[test]
    fn experience_table_action_dim_mismatch_returns_error() {
        let mut table = ExperienceTable::new(4, 2);
        let record = ExperienceRecord {
            action: vec![0.1, 0.2, 0.3], // 3 dims, table expects 2
            ..sample_record(4)
        };
        let result = table.push(record);
        assert!(result.is_err(), "action dim mismatch must return Err");
        let err_str = result.unwrap_err().to_string();
        assert!(
            err_str.contains("act_dim"),
            "error must mention act_dim, got: {err_str}"
        );
    }

    #[test]
    fn experience_table_scalar_action_dim_one() {
        let mut table = ExperienceTable::new(4, 1);
        table.push(sample_record(4)).unwrap();
        assert_eq!(table.len(), 1);
        assert_eq!(table.actions_raw().len(), 1);
        // Check the stored value too: a length check alone would pass on garbage data.
        assert_eq!(&table.actions_raw()[..1], &[0.0f32][..]);
    }
}