1use crate::error::RloxError;
2
/// Lightweight, copyable identifier for one column of an `ExtraColumns` store.
///
/// Handles are handed out by `ExtraColumns::register`; the wrapped value is the
/// zero-based position of the column in registration order.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ColumnHandle(usize);

impl ColumnHandle {
    /// Returns the zero-based column position wrapped by this handle.
    pub fn index(self) -> usize {
        let Self(position) = self;
        position
    }
}
13
14#[derive(Debug)]
20pub struct ExtraColumns {
21 names: Vec<String>,
22 dims: Vec<usize>,
23 data: Vec<Vec<f32>>,
24 capacity: usize,
25}
26
27impl ExtraColumns {
28 pub fn new() -> Self {
30 Self {
31 names: Vec::new(),
32 dims: Vec::new(),
33 data: Vec::new(),
34 capacity: 0,
35 }
36 }
37
38 pub fn register(&mut self, name: &str, dim: usize) -> ColumnHandle {
43 let handle = ColumnHandle(self.names.len());
44 self.names.push(name.to_owned());
45 self.dims.push(dim);
46 let col = if self.capacity > 0 {
47 vec![0.0; self.capacity * dim]
48 } else {
49 Vec::new()
50 };
51 self.data.push(col);
52 handle
53 }
54
55 pub fn num_columns(&self) -> usize {
57 self.names.len()
58 }
59
60 pub fn column_info(&self, handle: ColumnHandle) -> (&str, usize) {
62 (&self.names[handle.0], self.dims[handle.0])
63 }
64
65 pub fn allocate(&mut self, capacity: usize) {
67 self.capacity = capacity;
68 for (col, &dim) in self.data.iter_mut().zip(self.dims.iter()) {
69 col.resize(capacity * dim, 0.0);
70 }
71 }
72
73 pub fn push(
77 &mut self,
78 handle: ColumnHandle,
79 pos: usize,
80 values: &[f32],
81 ) -> Result<(), RloxError> {
82 let dim = self.dims[handle.0];
83 if values.len() != dim {
84 return Err(RloxError::ShapeMismatch {
85 expected: format!("extra column '{}' dim={}", self.names[handle.0], dim),
86 got: format!("values.len()={}", values.len()),
87 });
88 }
89 let col = &mut self.data[handle.0];
90 let start = pos * dim;
91 if start + dim > col.len() {
92 return Err(RloxError::BufferError(format!(
93 "extra column position {} out of bounds (allocated for {})",
94 pos,
95 col.len() / dim
96 )));
97 }
98 col[start..start + dim].copy_from_slice(values);
99 Ok(())
100 }
101
102 pub fn sample(&self, handle: ColumnHandle, indices: &[usize]) -> Vec<f32> {
106 let dim = self.dims[handle.0];
107 let col = &self.data[handle.0];
108 let mut out = Vec::with_capacity(indices.len() * dim);
109 for &idx in indices {
110 let start = idx * dim;
111 out.extend_from_slice(&col[start..start + dim]);
112 }
113 out
114 }
115
116 pub fn sample_all(&self, indices: &[usize]) -> Vec<(String, Vec<f32>)> {
121 self.names
122 .iter()
123 .enumerate()
124 .map(|(i, name)| {
125 let handle = ColumnHandle(i);
126 (name.clone(), self.sample(handle, indices))
127 })
128 .collect()
129 }
130
131 pub fn clear(&mut self) {
133 for col in &mut self.data {
134 col.fill(0.0);
135 }
136 }
137}
138
139impl Default for ExtraColumns {
140 fn default() -> Self {
141 Self::new()
142 }
143}
144
#[cfg(test)]
mod tests {
    use super::*;

    /// Registering a column and writing rows leaves its metadata intact.
    #[test]
    fn test_extra_columns_register_and_push() {
        let mut store = ExtraColumns::new();
        let log_prob = store.register("log_prob", 1);
        store.allocate(10);

        store.push(log_prob, 0, &[0.5]).unwrap();
        store.push(log_prob, 1, &[-0.3]).unwrap();

        assert_eq!(store.num_columns(), 1);
        let info = store.column_info(log_prob);
        assert_eq!(info.0, "log_prob");
        assert_eq!(info.1, 1);
    }

    /// Values written row by row come back in the order the indices request them.
    #[test]
    fn test_extra_columns_sample_roundtrip() {
        let mut store = ExtraColumns::new();
        let column = store.register("values", 2);
        store.allocate(5);

        for row in 0..5 {
            let base = row as f32;
            store.push(column, row, &[base, base * 10.0]).unwrap();
        }

        let picked = store.sample(column, &[0, 2, 4]);
        assert_eq!(picked, vec![0.0, 0.0, 2.0, 20.0, 4.0, 40.0]);
    }

    /// A fresh store holds no columns and no per-column bookkeeping.
    #[test]
    fn test_extra_columns_zero_overhead_when_empty() {
        let store = ExtraColumns::new();
        assert_eq!(store.num_columns(), 0);

        let ExtraColumns {
            names, dims, data, ..
        } = store;
        assert!(names.is_empty());
        assert!(dims.is_empty());
        assert!(data.is_empty());
    }

    /// Columns of different widths are stored independently of each other.
    #[test]
    fn test_extra_columns_multiple_columns() {
        let mut store = ExtraColumns::new();
        let scalar = store.register("log_prob", 1);
        let vector = store.register("action_mean", 3);
        store.allocate(4);

        store.push(scalar, 0, &[0.1]).unwrap();
        store.push(vector, 0, &[1.0, 2.0, 3.0]).unwrap();
        store.push(scalar, 1, &[0.2]).unwrap();
        store.push(vector, 1, &[4.0, 5.0, 6.0]).unwrap();

        assert_eq!(store.sample(scalar, &[0, 1]), vec![0.1, 0.2]);
        assert_eq!(
            store.sample(vector, &[0, 1]),
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
        );
    }

    /// Rows that are too short or too long for the column are rejected.
    #[test]
    fn test_extra_columns_dim_mismatch_errors() {
        let mut store = ExtraColumns::new();
        let column = store.register("test", 2);
        store.allocate(4);

        let too_short = store.push(column, 0, &[1.0]);
        assert!(too_short.is_err());
        let err = too_short.unwrap_err().to_string();
        assert!(
            err.contains("dim=2"),
            "error should mention dim, got: {err}"
        );

        let too_long = store.push(column, 0, &[1.0, 2.0, 3.0]);
        assert!(too_long.is_err());
    }

    /// Writing one row past the allocated capacity fails.
    #[test]
    fn test_extra_columns_out_of_bounds_errors() {
        let mut store = ExtraColumns::new();
        let column = store.register("test", 1);
        store.allocate(2);

        store.push(column, 0, &[1.0]).unwrap();
        store.push(column, 1, &[2.0]).unwrap();

        let past_end = store.push(column, 2, &[3.0]);
        assert!(past_end.is_err());
    }

    /// `sample_all` yields every column, by name, in registration order.
    #[test]
    fn test_extra_columns_sample_all() {
        let mut store = ExtraColumns::new();
        let alpha = store.register("alpha", 1);
        let beta = store.register("beta", 2);
        store.allocate(3);

        store.push(alpha, 0, &[0.1]).unwrap();
        store.push(alpha, 1, &[0.2]).unwrap();
        store.push(beta, 0, &[1.0, 2.0]).unwrap();
        store.push(beta, 1, &[3.0, 4.0]).unwrap();

        let gathered = store.sample_all(&[0, 1]);
        assert_eq!(gathered.len(), 2);

        let (first_name, first_values) = &gathered[0];
        assert_eq!(first_name, "alpha");
        assert_eq!(*first_values, vec![0.1, 0.2]);

        let (second_name, second_values) = &gathered[1];
        assert_eq!(second_name, "beta");
        assert_eq!(*second_values, vec![1.0, 2.0, 3.0, 4.0]);
    }

    /// Clearing zeroes the stored data but keeps the registered columns.
    #[test]
    fn test_extra_columns_clear_preserves_registrations() {
        let mut store = ExtraColumns::new();
        let column = store.register("test", 1);
        store.allocate(3);
        store.push(column, 0, &[42.0]).unwrap();

        store.clear();

        assert_eq!(store.num_columns(), 1);
        assert_eq!(store.sample(column, &[0]), vec![0.0]);
    }
}