rlox_core/buffer/extra_columns.rs

1use crate::error::RloxError;
2
/// Opaque, copyable handle that identifies one registered extra column.
///
/// It wraps the column's position in the storage vectors, so every lookup
/// through a handle is a single `Vec` index (O(1)).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ColumnHandle(usize);

impl ColumnHandle {
    /// Returns the raw position of this column in the backing `Vec`s.
    pub fn index(self) -> usize {
        let ColumnHandle(raw) = self;
        raw
    }
}
13
14/// Storage for user-defined extra columns on a replay buffer.
15///
16/// Each column is a flat `Vec<f32>` with a fixed dimensionality, stored
17/// contiguously for cache-friendly sampling. When no columns are registered,
18/// this struct has zero overhead — no allocations occur.
19#[derive(Debug)]
20pub struct ExtraColumns {
21    names: Vec<String>,
22    dims: Vec<usize>,
23    data: Vec<Vec<f32>>,
24    capacity: usize,
25}
26
27impl ExtraColumns {
28    /// Create empty extra-column storage. No allocations until `register()`.
29    pub fn new() -> Self {
30        Self {
31            names: Vec::new(),
32            dims: Vec::new(),
33            data: Vec::new(),
34            capacity: 0,
35        }
36    }
37
38    /// Register a new column. Returns a handle for O(1) access.
39    ///
40    /// Must be called before any data is pushed. If storage has already been
41    /// allocated (via `allocate()`), the new column is pre-allocated too.
42    pub fn register(&mut self, name: &str, dim: usize) -> ColumnHandle {
43        let handle = ColumnHandle(self.names.len());
44        self.names.push(name.to_owned());
45        self.dims.push(dim);
46        let col = if self.capacity > 0 {
47            vec![0.0; self.capacity * dim]
48        } else {
49            Vec::new()
50        };
51        self.data.push(col);
52        handle
53    }
54
55    /// Number of registered columns.
56    pub fn num_columns(&self) -> usize {
57        self.names.len()
58    }
59
60    /// Get column name and dim by handle.
61    pub fn column_info(&self, handle: ColumnHandle) -> (&str, usize) {
62        (&self.names[handle.0], self.dims[handle.0])
63    }
64
65    /// Pre-allocate storage for a given capacity.
66    pub fn allocate(&mut self, capacity: usize) {
67        self.capacity = capacity;
68        for (col, &dim) in self.data.iter_mut().zip(self.dims.iter()) {
69            col.resize(capacity * dim, 0.0);
70        }
71    }
72
73    /// Write values for one column at a given buffer position.
74    ///
75    /// The `values` slice must have length equal to the column's dimensionality.
76    pub fn push(
77        &mut self,
78        handle: ColumnHandle,
79        pos: usize,
80        values: &[f32],
81    ) -> Result<(), RloxError> {
82        let dim = self.dims[handle.0];
83        if values.len() != dim {
84            return Err(RloxError::ShapeMismatch {
85                expected: format!("extra column '{}' dim={}", self.names[handle.0], dim),
86                got: format!("values.len()={}", values.len()),
87            });
88        }
89        let col = &mut self.data[handle.0];
90        let start = pos * dim;
91        if start + dim > col.len() {
92            return Err(RloxError::BufferError(format!(
93                "extra column position {} out of bounds (allocated for {})",
94                pos,
95                col.len() / dim
96            )));
97        }
98        col[start..start + dim].copy_from_slice(values);
99        Ok(())
100    }
101
102    /// Gather values for one column at the given sampled indices.
103    ///
104    /// Returns a flat `Vec<f32>` of length `indices.len() * dim`.
105    pub fn sample(&self, handle: ColumnHandle, indices: &[usize]) -> Vec<f32> {
106        let dim = self.dims[handle.0];
107        let col = &self.data[handle.0];
108        let mut out = Vec::with_capacity(indices.len() * dim);
109        for &idx in indices {
110            let start = idx * dim;
111            out.extend_from_slice(&col[start..start + dim]);
112        }
113        out
114    }
115
116    /// Gather all columns at the given sampled indices.
117    ///
118    /// Returns `(column_name, flat_data)` pairs. Only called when
119    /// `num_columns() > 0`.
120    pub fn sample_all(&self, indices: &[usize]) -> Vec<(String, Vec<f32>)> {
121        self.names
122            .iter()
123            .enumerate()
124            .map(|(i, name)| {
125                let handle = ColumnHandle(i);
126                (name.clone(), self.sample(handle, indices))
127            })
128            .collect()
129    }
130
131    /// Clear all data (keep column registrations and capacity).
132    pub fn clear(&mut self) {
133        for col in &mut self.data {
134            col.fill(0.0);
135        }
136    }
137}
138
139impl Default for ExtraColumns {
140    fn default() -> Self {
141        Self::new()
142    }
143}
144
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extra_columns_register_and_push() {
        let mut cols = ExtraColumns::new();
        let log_prob = cols.register("log_prob", 1);
        cols.allocate(10);

        for (pos, value) in [(0, 0.5), (1, -0.3)] {
            cols.push(log_prob, pos, &[value]).unwrap();
        }

        assert_eq!(cols.num_columns(), 1);
        assert_eq!(cols.column_info(log_prob), ("log_prob", 1));
    }

    #[test]
    fn test_extra_columns_sample_roundtrip() {
        let mut cols = ExtraColumns::new();
        let values = cols.register("values", 2);
        cols.allocate(5);

        (0..5).for_each(|pos| {
            let base = pos as f32;
            cols.push(values, pos, &[base, base * 10.0]).unwrap();
        });

        let expected = vec![0.0, 0.0, 2.0, 20.0, 4.0, 40.0];
        assert_eq!(cols.sample(values, &[0, 2, 4]), expected);
    }

    #[test]
    fn test_extra_columns_zero_overhead_when_empty() {
        let cols = ExtraColumns::new();
        assert_eq!(cols.num_columns(), 0);
        // A fresh instance holds only empty vectors: nothing has been allocated.
        assert!(cols.names.is_empty() && cols.dims.is_empty() && cols.data.is_empty());
    }

    #[test]
    fn test_extra_columns_multiple_columns() {
        let mut cols = ExtraColumns::new();
        let scalar = cols.register("log_prob", 1);
        let vector = cols.register("action_mean", 3);
        cols.allocate(4);

        cols.push(scalar, 0, &[0.1]).unwrap();
        cols.push(vector, 0, &[1.0, 2.0, 3.0]).unwrap();
        cols.push(scalar, 1, &[0.2]).unwrap();
        cols.push(vector, 1, &[4.0, 5.0, 6.0]).unwrap();

        let both = [0, 1];
        assert_eq!(cols.sample(scalar, &both), vec![0.1, 0.2]);
        assert_eq!(
            cols.sample(vector, &both),
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
        );
    }

    #[test]
    fn test_extra_columns_dim_mismatch_errors() {
        let mut cols = ExtraColumns::new();
        let pair = cols.register("test", 2);
        cols.allocate(4);

        // Too short: one value for a two-wide column.
        let err = cols
            .push(pair, 0, &[1.0])
            .expect_err("short slice must be rejected")
            .to_string();
        assert!(
            err.contains("dim=2"),
            "error should mention dim, got: {err}"
        );

        // Too long: three values for a two-wide column.
        assert!(cols.push(pair, 0, &[1.0, 2.0, 3.0]).is_err());
    }

    #[test]
    fn test_extra_columns_out_of_bounds_errors() {
        let mut cols = ExtraColumns::new();
        let handle = cols.register("test", 1);
        cols.allocate(2);

        for (pos, value) in [(0, 1.0), (1, 2.0)] {
            cols.push(handle, pos, &[value]).unwrap();
        }

        // Capacity is 2, so position 2 is one past the end.
        assert!(cols.push(handle, 2, &[3.0]).is_err());
    }

    #[test]
    fn test_extra_columns_sample_all() {
        let mut cols = ExtraColumns::new();
        let alpha = cols.register("alpha", 1);
        let beta = cols.register("beta", 2);
        cols.allocate(3);

        cols.push(alpha, 0, &[0.1]).unwrap();
        cols.push(alpha, 1, &[0.2]).unwrap();
        cols.push(beta, 0, &[1.0, 2.0]).unwrap();
        cols.push(beta, 1, &[3.0, 4.0]).unwrap();

        let gathered = cols.sample_all(&[0, 1]);
        let expected: Vec<(String, Vec<f32>)> = vec![
            ("alpha".to_string(), vec![0.1, 0.2]),
            ("beta".to_string(), vec![1.0, 2.0, 3.0, 4.0]),
        ];
        assert_eq!(gathered, expected);
    }

    #[test]
    fn test_extra_columns_clear_preserves_registrations() {
        let mut cols = ExtraColumns::new();
        let handle = cols.register("test", 1);
        cols.allocate(3);
        cols.push(handle, 0, &[42.0]).unwrap();

        cols.clear();

        assert_eq!(cols.num_columns(), 1);
        // The written value is gone; storage was zeroed in place.
        assert_eq!(cols.sample(handle, &[0]), vec![0.0]);
    }
}