rlox_core/buffer/
columnar.rs1use crate::error::RloxError;
2
3use super::ExperienceRecord;
4
/// Column-oriented storage for experience transitions.
///
/// Each component of a transition lives in its own contiguous `Vec`.
/// Observation and action vectors are stored flattened: every transition
/// contributes exactly `obs_dim` values to each observation column and
/// `act_dim` values to the action column.
pub struct ExperienceTable {
    /// Number of `f32` values in a single observation.
    obs_dim: usize,
    /// Number of `f32` values in a single action.
    act_dim: usize,
    /// Flattened observations, `obs_dim` values per transition.
    observations: Vec<f32>,
    /// Flattened next observations, `obs_dim` values per transition.
    next_observations: Vec<f32>,
    /// Flattened actions, `act_dim` values per transition.
    actions: Vec<f32>,
    /// One reward per transition.
    rewards: Vec<f32>,
    /// One `terminated` flag per transition.
    terminated: Vec<bool>,
    /// One `truncated` flag per transition.
    truncated: Vec<bool>,
    /// Number of transitions currently stored.
    count: usize,
}
21
22impl ExperienceTable {
23 pub fn new(obs_dim: usize, act_dim: usize) -> Self {
25 Self {
26 obs_dim,
27 act_dim,
28 observations: Vec::new(),
29 next_observations: Vec::new(),
30 actions: Vec::new(),
31 rewards: Vec::new(),
32 terminated: Vec::new(),
33 truncated: Vec::new(),
34 count: 0,
35 }
36 }
37
38 pub fn obs_dim(&self) -> usize {
40 self.obs_dim
41 }
42
43 pub fn act_dim(&self) -> usize {
45 self.act_dim
46 }
47
48 pub fn len(&self) -> usize {
50 self.count
51 }
52
53 pub fn is_empty(&self) -> bool {
55 self.count == 0
56 }
57
58 pub fn push_slices(
60 &mut self,
61 obs: &[f32],
62 next_obs: &[f32],
63 action: &[f32],
64 reward: f32,
65 terminated: bool,
66 truncated: bool,
67 ) -> Result<(), RloxError> {
68 if obs.len() != self.obs_dim {
69 return Err(RloxError::ShapeMismatch {
70 expected: format!("obs_dim={}", self.obs_dim),
71 got: format!("obs.len()={}", obs.len()),
72 });
73 }
74 if next_obs.len() != self.obs_dim {
75 return Err(RloxError::ShapeMismatch {
76 expected: format!("obs_dim={}", self.obs_dim),
77 got: format!("next_obs.len()={}", next_obs.len()),
78 });
79 }
80 if action.len() != self.act_dim {
81 return Err(RloxError::ShapeMismatch {
82 expected: format!("act_dim={}", self.act_dim),
83 got: format!("action.len()={}", action.len()),
84 });
85 }
86 self.observations.extend_from_slice(obs);
87 self.next_observations.extend_from_slice(next_obs);
88 self.actions.extend_from_slice(action);
89 self.rewards.push(reward);
90 self.terminated.push(terminated);
91 self.truncated.push(truncated);
92 self.count += 1;
93 Ok(())
94 }
95
96 pub fn push(&mut self, record: ExperienceRecord) -> Result<(), RloxError> {
98 self.push_slices(
99 &record.obs,
100 &record.next_obs,
101 &record.action,
102 record.reward,
103 record.terminated,
104 record.truncated,
105 )
106 }
107
108 pub fn observations_raw(&self) -> &[f32] {
110 &self.observations
111 }
112
113 pub fn actions_raw(&self) -> &[f32] {
115 &self.actions
116 }
117
118 pub fn rewards_raw(&self) -> &[f32] {
120 &self.rewards
121 }
122
123 pub fn terminated(&self) -> &[bool] {
125 &self.terminated
126 }
127
128 pub fn truncated(&self) -> &[bool] {
130 &self.truncated
131 }
132
133 pub fn clear(&mut self) {
135 self.observations.clear();
136 self.next_observations.clear();
137 self.actions.clear();
138 self.rewards.clear();
139 self.terminated.clear();
140 self.truncated.clear();
141 self.count = 0;
142 }
143}
144
#[cfg(test)]
mod tests {
    use super::*;
    use crate::buffer::sample_record;

    /// Builds a 4-dim observation / 1-dim action table holding `rows`
    /// sample transitions.
    fn table_with_rows(rows: usize) -> ExperienceTable {
        let mut table = ExperienceTable::new(4, 1);
        for _ in 0..rows {
            table.push(sample_record(4)).unwrap();
        }
        table
    }

    #[test]
    fn empty_table_has_zero_len() {
        let table = table_with_rows(0);
        assert_eq!(table.len(), 0);
        assert!(table.is_empty());
    }

    #[test]
    fn push_single_transition_increments_len() {
        let table = table_with_rows(1);
        assert_eq!(table.len(), 1);
    }

    #[test]
    fn push_many_transitions() {
        let table = table_with_rows(1000);
        assert_eq!(table.len(), 1000);
    }

    #[test]
    fn observations_column_correct_length() {
        // 10 transitions of 4 values each.
        let table = table_with_rows(10);
        assert_eq!(table.observations_raw().len(), 40);
    }

    #[test]
    fn rewards_column_correct_values() {
        let mut table = table_with_rows(0);
        let mut record = sample_record(4);
        record.reward = 42.0;
        table.push(record).unwrap();
        assert_eq!(table.rewards_raw()[0], 42.0);
    }

    #[test]
    fn obs_dim_mismatch_returns_error() {
        let mut table = table_with_rows(0);
        let outcome = table.push(sample_record(8));
        assert!(outcome.is_err());
    }

    #[test]
    fn clear_empties_all_columns() {
        let mut table = table_with_rows(100);
        table.clear();
        assert_eq!(table.len(), 0);
        assert!(table.observations_raw().is_empty());
    }

    #[test]
    fn obs_dim_getter() {
        let table = table_with_rows(0);
        assert_eq!(table.obs_dim(), 4);
    }
}