// rlox_core/buffer/columnar.rs

use crate::error::RloxError;

use super::ExperienceRecord;

/// Append-only columnar table for RL transitions.
///
/// Every column is a flat, contiguous `Vec`. Observations, next observations,
/// actions and rewards are stored as `f32` for efficient export to numpy; the
/// `terminated` / `truncated` episode flags are stored as `bool`.
///
/// Transitions are only ever appended (or all dropped at once via `clear`);
/// stored values are never modified in place. The backing `Vec`s may still
/// reallocate and move their contents when they grow, so raw pointers into a
/// column are not stable across pushes.
#[derive(Debug, Clone)]
pub struct ExperienceTable {
    /// Length of one observation row (also used for next observations).
    obs_dim: usize,
    /// Length of one action row.
    act_dim: usize,
    /// Row-major observations: `count * obs_dim` values.
    observations: Vec<f32>,
    /// Row-major next observations: `count * obs_dim` values.
    next_observations: Vec<f32>,
    /// Row-major actions: `count * act_dim` values.
    actions: Vec<f32>,
    /// One reward per transition.
    rewards: Vec<f32>,
    /// One `terminated` flag per transition.
    terminated: Vec<bool>,
    /// One `truncated` flag per transition.
    truncated: Vec<bool>,
    /// Number of transitions stored.
    count: usize,
}

22impl ExperienceTable {
23    /// Create a new table with the given observation and action dimensions.
24    pub fn new(obs_dim: usize, act_dim: usize) -> Self {
25        Self {
26            obs_dim,
27            act_dim,
28            observations: Vec::new(),
29            next_observations: Vec::new(),
30            actions: Vec::new(),
31            rewards: Vec::new(),
32            terminated: Vec::new(),
33            truncated: Vec::new(),
34            count: 0,
35        }
36    }
37
38    /// Observation dimensionality.
39    pub fn obs_dim(&self) -> usize {
40        self.obs_dim
41    }
42
43    /// Action dimensionality.
44    pub fn act_dim(&self) -> usize {
45        self.act_dim
46    }
47
48    /// Number of transitions stored.
49    pub fn len(&self) -> usize {
50        self.count
51    }
52
53    /// Whether the table is empty.
54    pub fn is_empty(&self) -> bool {
55        self.count == 0
56    }
57
58    /// Append a transition from borrowed slices, avoiding intermediate allocation.
59    pub fn push_slices(
60        &mut self,
61        obs: &[f32],
62        next_obs: &[f32],
63        action: &[f32],
64        reward: f32,
65        terminated: bool,
66        truncated: bool,
67    ) -> Result<(), RloxError> {
68        if obs.len() != self.obs_dim {
69            return Err(RloxError::ShapeMismatch {
70                expected: format!("obs_dim={}", self.obs_dim),
71                got: format!("obs.len()={}", obs.len()),
72            });
73        }
74        if next_obs.len() != self.obs_dim {
75            return Err(RloxError::ShapeMismatch {
76                expected: format!("obs_dim={}", self.obs_dim),
77                got: format!("next_obs.len()={}", next_obs.len()),
78            });
79        }
80        if action.len() != self.act_dim {
81            return Err(RloxError::ShapeMismatch {
82                expected: format!("act_dim={}", self.act_dim),
83                got: format!("action.len()={}", action.len()),
84            });
85        }
86        self.observations.extend_from_slice(obs);
87        self.next_observations.extend_from_slice(next_obs);
88        self.actions.extend_from_slice(action);
89        self.rewards.push(reward);
90        self.terminated.push(terminated);
91        self.truncated.push(truncated);
92        self.count += 1;
93        Ok(())
94    }
95
96    /// Append a single transition. Returns error on dimension mismatch.
97    pub fn push(&mut self, record: ExperienceRecord) -> Result<(), RloxError> {
98        self.push_slices(
99            &record.obs,
100            &record.next_obs,
101            &record.action,
102            record.reward,
103            record.terminated,
104            record.truncated,
105        )
106    }
107
108    /// Raw slice of all observation data. Shape: [count * obs_dim].
109    pub fn observations_raw(&self) -> &[f32] {
110        &self.observations
111    }
112
113    /// Raw slice of all action data. Shape: [count * act_dim].
114    pub fn actions_raw(&self) -> &[f32] {
115        &self.actions
116    }
117
118    /// Raw slice of all rewards.
119    pub fn rewards_raw(&self) -> &[f32] {
120        &self.rewards
121    }
122
123    /// Slice of terminated flags.
124    pub fn terminated(&self) -> &[bool] {
125        &self.terminated
126    }
127
128    /// Slice of truncated flags.
129    pub fn truncated(&self) -> &[bool] {
130        &self.truncated
131    }
132
133    /// Drop all stored data.
134    pub fn clear(&mut self) {
135        self.observations.clear();
136        self.next_observations.clear();
137        self.actions.clear();
138        self.rewards.clear();
139        self.terminated.clear();
140        self.truncated.clear();
141        self.count = 0;
142    }
143}
144
#[cfg(test)]
mod tests {
    use super::*;
    use crate::buffer::sample_record;

    /// Build an (obs_dim = 4, act_dim = 1) table holding `n` copies of
    /// `sample_record(4)`.
    fn filled_table(n: usize) -> ExperienceTable {
        let mut table = ExperienceTable::new(4, 1);
        for _ in 0..n {
            table.push(sample_record(4)).unwrap();
        }
        table
    }

    #[test]
    fn empty_table_has_zero_len() {
        let table = ExperienceTable::new(4, 1);
        assert_eq!(table.len(), 0);
        assert!(table.is_empty());
    }

    #[test]
    fn push_single_transition_increments_len() {
        let table = filled_table(1);
        assert_eq!(table.len(), 1);
        assert!(!table.is_empty());
    }

    #[test]
    fn push_many_transitions() {
        assert_eq!(filled_table(1000).len(), 1000);
    }

    #[test]
    fn observations_column_correct_length() {
        let table = filled_table(10);
        // Flat columns hold `len * dim` values; per-transition columns hold `len`.
        assert_eq!(table.observations_raw().len(), 40);
        assert_eq!(table.actions_raw().len(), 10);
        assert_eq!(table.rewards_raw().len(), 10);
        assert_eq!(table.terminated().len(), 10);
        assert_eq!(table.truncated().len(), 10);
    }

    #[test]
    fn rewards_column_correct_values() {
        let mut table = ExperienceTable::new(4, 1);
        let mut r = sample_record(4);
        r.reward = 42.0;
        table.push(r).unwrap();
        assert_eq!(table.rewards_raw()[0], 42.0);
    }

    #[test]
    fn push_slices_writes_every_column_in_order() {
        let mut table = ExperienceTable::new(2, 1);
        table
            .push_slices(&[1.0, 2.0], &[3.0, 4.0], &[0.5], 1.5, true, false)
            .unwrap();
        table
            .push_slices(&[5.0, 6.0], &[7.0, 8.0], &[-0.5], -1.0, false, true)
            .unwrap();
        assert_eq!(table.len(), 2);
        assert_eq!(table.observations_raw(), &[1.0_f32, 2.0, 5.0, 6.0][..]);
        assert_eq!(table.actions_raw(), &[0.5_f32, -0.5][..]);
        assert_eq!(table.rewards_raw(), &[1.5_f32, -1.0][..]);
        assert_eq!(table.terminated(), &[true, false][..]);
        assert_eq!(table.truncated(), &[false, true][..]);
    }

    #[test]
    fn obs_dim_mismatch_returns_error() {
        let mut table = ExperienceTable::new(4, 1);
        let bad = sample_record(8);
        assert!(table.push(bad).is_err());
        // Validation runs before any column is written.
        assert!(table.is_empty());
        assert!(table.observations_raw().is_empty());
    }

    #[test]
    fn next_obs_dim_mismatch_returns_error() {
        let mut table = ExperienceTable::new(4, 1);
        let err = table
            .push_slices(&[0.0; 4], &[0.0; 3], &[0.0], 0.0, false, false)
            .unwrap_err();
        assert!(matches!(
            &err,
            RloxError::ShapeMismatch { got, .. } if got == "next_obs.len()=3"
        ));
        // `obs` was valid, but nothing may be written when a later check fails.
        assert!(table.is_empty());
        assert!(table.observations_raw().is_empty());
    }

    #[test]
    fn action_dim_mismatch_returns_error() {
        let mut table = ExperienceTable::new(4, 1);
        let err = table
            .push_slices(&[0.0; 4], &[0.0; 4], &[0.0, 0.0], 0.0, false, false)
            .unwrap_err();
        assert!(matches!(
            &err,
            RloxError::ShapeMismatch { got, .. } if got == "action.len()=2"
        ));
        assert!(table.is_empty());
        assert!(table.observations_raw().is_empty());
        assert!(table.actions_raw().is_empty());
    }

    #[test]
    fn clear_empties_all_columns() {
        let mut table = filled_table(100);
        table.clear();
        assert_eq!(table.len(), 0);
        assert!(table.is_empty());
        assert!(table.observations_raw().is_empty());
        assert!(table.actions_raw().is_empty());
        assert!(table.rewards_raw().is_empty());
        assert!(table.terminated().is_empty());
        assert!(table.truncated().is_empty());
    }

    #[test]
    fn obs_dim_getter() {
        let table = ExperienceTable::new(4, 1);
        assert_eq!(table.obs_dim(), 4);
        assert_eq!(table.act_dim(), 1);
    }
}