1use crate::error::RloxError;
8
9use super::episode::EpisodeTracker;
10use super::ringbuf::ReplayBuffer;
11
/// Replay buffer that can sample contiguous sequences of transitions that stay
/// within a single episode.
///
/// Combines a flat [`ReplayBuffer`] ring with an [`EpisodeTracker`] that is
/// notified of every push (by ring position and done flag), so sequence
/// sampling can respect episode boundaries.
#[derive(Debug)]
pub struct SequenceReplayBuffer {
    // Flat ring buffer that stores the transitions themselves.
    buffer: ReplayBuffer,
    // Episode bookkeeping; updated on every push alongside `buffer`.
    tracker: EpisodeTracker,
    // Number of `f32` values per observation.
    obs_dim: usize,
    // Number of `f32` values per action.
    act_dim: usize,
    // Maximum number of stored transitions; also the modulus used to wrap ring
    // indices when gathering sequences.
    capacity: usize,
}
25
/// A batch of fixed-length transition sequences, stored as flat row-major vectors.
///
/// Per-step fields (`rewards`, `terminated`, `truncated`) hold
/// `batch_size * seq_len` entries, with step `t` of sequence `b` at index
/// `b * seq_len + t`. `observations` and `next_observations` hold
/// `batch_size * seq_len * obs_dim` values, and `actions` holds
/// `batch_size * seq_len * act_dim` values, in the same sequence-major order.
///
/// `PartialEq` is derived so whole batches can be compared directly (for example
/// when checking that sampling with a fixed seed is deterministic).
#[derive(Debug, Clone, PartialEq)]
pub struct SequenceBatch {
    /// Observations, logically shaped `[batch_size, seq_len, obs_dim]`.
    pub observations: Vec<f32>,
    /// Next observations, logically shaped `[batch_size, seq_len, obs_dim]`.
    pub next_observations: Vec<f32>,
    /// Actions, logically shaped `[batch_size, seq_len, act_dim]`.
    pub actions: Vec<f32>,
    /// Rewards, logically shaped `[batch_size, seq_len]`.
    pub rewards: Vec<f32>,
    /// Termination flags, logically shaped `[batch_size, seq_len]`.
    pub terminated: Vec<bool>,
    /// Truncation flags, logically shaped `[batch_size, seq_len]`.
    pub truncated: Vec<bool>,
    /// Width of one observation.
    pub obs_dim: usize,
    /// Width of one action.
    pub act_dim: usize,
    /// Number of sequences in the batch.
    pub batch_size: usize,
    /// Number of transitions in each sequence.
    pub seq_len: usize,
}
46
47impl SequenceReplayBuffer {
48 pub fn new(capacity: usize, obs_dim: usize, act_dim: usize) -> Self {
50 Self {
51 buffer: ReplayBuffer::new(capacity, obs_dim, act_dim),
52 tracker: EpisodeTracker::new(capacity),
53 obs_dim,
54 act_dim,
55 capacity,
56 }
57 }
58
59 pub fn push_slices(
61 &mut self,
62 obs: &[f32],
63 next_obs: &[f32],
64 action: &[f32],
65 reward: f32,
66 terminated: bool,
67 truncated: bool,
68 ) -> Result<(), RloxError> {
69 let write_pos = self.buffer.write_pos();
70 let was_full = self.buffer.len() == self.capacity;
71
72 if was_full {
74 self.tracker.invalidate_overwritten(write_pos, 1);
75 }
76
77 self.buffer
78 .push_slices(obs, next_obs, action, reward, terminated, truncated)?;
79
80 let done = terminated || truncated;
81 self.tracker.notify_push(write_pos, done);
82
83 Ok(())
84 }
85
86 pub fn sample_sequences(
90 &self,
91 batch_size: usize,
92 seq_len: usize,
93 seed: u64,
94 ) -> Result<SequenceBatch, RloxError> {
95 let windows = self.tracker.sample_windows(batch_size, seq_len, seed)?;
96
97 let total_obs = batch_size * seq_len * self.obs_dim;
98 let total_act = batch_size * seq_len * self.act_dim;
99 let total_flat = batch_size * seq_len;
100
101 let mut batch = SequenceBatch {
102 observations: Vec::with_capacity(total_obs),
103 next_observations: Vec::with_capacity(total_obs),
104 actions: Vec::with_capacity(total_act),
105 rewards: Vec::with_capacity(total_flat),
106 terminated: Vec::with_capacity(total_flat),
107 truncated: Vec::with_capacity(total_flat),
108 obs_dim: self.obs_dim,
109 act_dim: self.act_dim,
110 batch_size,
111 seq_len,
112 };
113
114 for window in &windows {
115 for offset in 0..seq_len {
116 let idx = (window.ring_start + offset) % self.capacity;
117 let (obs, next_obs, action, reward, terminated, truncated) = self.buffer.get(idx);
118 batch.observations.extend_from_slice(obs);
119 batch.next_observations.extend_from_slice(next_obs);
120 batch.actions.extend_from_slice(action);
121 batch.rewards.push(reward);
122 batch.terminated.push(terminated);
123 batch.truncated.push(truncated);
124 }
125 }
126
127 Ok(batch)
128 }
129
130 pub fn sample(
132 &self,
133 batch_size: usize,
134 seed: u64,
135 ) -> Result<super::ringbuf::SampledBatch, RloxError> {
136 self.buffer.sample(batch_size, seed)
137 }
138
139 pub fn len(&self) -> usize {
141 self.buffer.len()
142 }
143
144 pub fn is_empty(&self) -> bool {
146 self.buffer.is_empty()
147 }
148
149 pub fn num_complete_episodes(&self) -> usize {
151 self.tracker.num_complete_episodes()
152 }
153}
154
#[cfg(test)]
mod tests {
    use super::*;

    /// Pushes one complete episode of `length` transitions into `buf`.
    ///
    /// Step `i` stores observation value `obs_base + i`, next-observation value
    /// `obs_base + i + 1`, and reward `obs_base + i`. Only the final step is
    /// terminated. As a result, `next_obs` of step `i` equals `obs` of step `i + 1`,
    /// and rewards identify both the episode and the step within it.
    fn push_episode(buf: &mut SequenceReplayBuffer, length: usize, obs_base: f32) {
        let obs_dim = buf.obs_dim;
        let act_dim = buf.act_dim;
        for i in 0..length {
            let val = obs_base + i as f32;
            let obs = vec![val; obs_dim];
            let next_obs = vec![val + 1.0; obs_dim];
            let action = vec![0.0; act_dim];
            let reward = val;
            // Written as `i + 1 == length` so no `length - 1` is ever computed.
            let done = i + 1 == length;
            buf.push_slices(&obs, &next_obs, &action, reward, done, false)
                .unwrap();
        }
    }

    #[test]
    fn test_new_is_empty() {
        let buf = SequenceReplayBuffer::new(100, 4, 1);
        assert_eq!(buf.len(), 0);
        assert!(buf.is_empty());
        assert_eq!(buf.num_complete_episodes(), 0);
    }

    #[test]
    fn test_push_increments_len() {
        let mut buf = SequenceReplayBuffer::new(100, 4, 1);
        push_episode(&mut buf, 5, 0.0);
        assert_eq!(buf.len(), 5);
        assert_eq!(buf.num_complete_episodes(), 1);
    }

    #[test]
    fn test_single_episode_sequence_sample() {
        let mut buf = SequenceReplayBuffer::new(100, 4, 1);
        push_episode(&mut buf, 10, 0.0);
        let batch = buf.sample_sequences(1, 3, 42).unwrap();
        assert_eq!(batch.batch_size, 1);
        assert_eq!(batch.seq_len, 3);
        assert_eq!(batch.observations.len(), 3 * 4);
        assert_eq!(batch.rewards.len(), 3);
    }

    #[test]
    fn test_sequences_dont_cross_episodes() {
        let mut buf = SequenceReplayBuffer::new(100, 4, 1);
        // Rewards 0..=4 belong to the first episode and 100..=104 to the second.
        push_episode(&mut buf, 5, 0.0);
        push_episode(&mut buf, 5, 100.0);
        assert_eq!(buf.num_complete_episodes(), 2);

        let batch = buf.sample_sequences(20, 4, 42).unwrap();

        for seq_idx in 0..20 {
            // A window that mixed both episodes would contain rewards on both sides of 50.
            let rewards: Vec<f32> = (0..4).map(|t| batch.rewards[seq_idx * 4 + t]).collect();
            let all_low = rewards.iter().all(|&r| r < 50.0);
            let all_high = rewards.iter().all(|&r| r >= 50.0);
            assert!(
                all_low || all_high,
                "sequence {seq_idx} crosses episode boundary: {rewards:?}"
            );
        }
    }

    #[test]
    fn test_sequence_contiguity() {
        let mut buf = SequenceReplayBuffer::new(100, 4, 1);
        push_episode(&mut buf, 10, 0.0);

        let batch = buf.sample_sequences(5, 5, 42).unwrap();
        let obs_dim = 4;
        let seq_len = 5;

        for seq_idx in 0..5 {
            for t in 0..(seq_len - 1) {
                let next_obs_start = (seq_idx * seq_len + t) * obs_dim;
                let obs_next_start = (seq_idx * seq_len + t + 1) * obs_dim;

                let next_obs = &batch.next_observations[next_obs_start..next_obs_start + obs_dim];
                let obs_t1 = &batch.observations[obs_next_start..obs_next_start + obs_dim];

                assert_eq!(
                    next_obs,
                    obs_t1,
                    "next_obs[{t}] != obs[{t_plus_1}] in seq {seq_idx}",
                    t_plus_1 = t + 1
                );
            }
        }
    }

    #[test]
    fn test_sequence_deterministic() {
        let mut buf = SequenceReplayBuffer::new(100, 4, 1);
        push_episode(&mut buf, 10, 0.0);
        let b1 = buf.sample_sequences(5, 3, 42).unwrap();
        let b2 = buf.sample_sequences(5, 3, 42).unwrap();
        assert_eq!(b1.observations, b2.observations);
        assert_eq!(b1.rewards, b2.rewards);
    }

    #[test]
    fn test_reject_too_long_sequence() {
        let mut buf = SequenceReplayBuffer::new(100, 4, 1);
        push_episode(&mut buf, 3, 0.0);
        let result = buf.sample_sequences(1, 5, 42);
        assert!(result.is_err());
    }

    #[test]
    fn test_capacity_respected() {
        let mut buf = SequenceReplayBuffer::new(100, 4, 1);
        // 20 episodes x 10 steps = 200 pushes into a 100-slot ring.
        for i in 0..20 {
            push_episode(&mut buf, 10, i as f32 * 100.0);
        }
        assert_eq!(buf.len(), 100);
    }

    #[test]
    fn test_batch_shape_correct() {
        let mut buf = SequenceReplayBuffer::new(200, 8, 2);
        push_episode(&mut buf, 20, 0.0);

        let batch = buf.sample_sequences(4, 3, 42).unwrap();
        assert_eq!(batch.observations.len(), 4 * 3 * 8);
        assert_eq!(batch.next_observations.len(), 4 * 3 * 8);
        assert_eq!(batch.actions.len(), 4 * 3 * 2);
        assert_eq!(batch.rewards.len(), 4 * 3);
        assert_eq!(batch.terminated.len(), 4 * 3);
        assert_eq!(batch.truncated.len(), 4 * 3);
    }

    #[test]
    fn test_empty_buffer_sample_errors() {
        let buf = SequenceReplayBuffer::new(100, 4, 1);
        let result = buf.sample_sequences(1, 1, 42);
        assert!(result.is_err());
    }

    #[test]
    fn test_push_slices_validates_dims() {
        let mut buf = SequenceReplayBuffer::new(100, 4, 1);
        let result = buf.push_slices(
            &[1.0, 2.0, 3.0], // obs has 3 values but obs_dim is 4
            &[1.0, 2.0, 3.0, 4.0],
            &[0.0],
            1.0,
            false,
            false,
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_multiple_episodes_mixed_lengths() {
        let mut buf = SequenceReplayBuffer::new(200, 4, 1);
        push_episode(&mut buf, 3, 0.0); // rewards 0..=2   (too short)
        push_episode(&mut buf, 7, 100.0); // rewards 100..=106
        push_episode(&mut buf, 2, 200.0); // rewards 200..=201 (too short)
        push_episode(&mut buf, 10, 300.0); // rewards 300..=309

        let batch = buf.sample_sequences(10, 4, 42).unwrap();
        // Only the length-7 and length-10 episodes can hold a 4-step window, so
        // every window must start at reward 100..=103 or 300..=306. A simple
        // lower bound such as `>= 90` would miss windows from the length-2 episode.
        for seq_idx in 0..10 {
            let first_reward = batch.rewards[seq_idx * 4];
            assert!(
                (100.0..=103.0).contains(&first_reward)
                    || (300.0..=306.0).contains(&first_reward),
                "seq {seq_idx} sampled from too-short episode: reward={first_reward}"
            );
        }
    }

    #[test]
    fn test_sequence_rewards_match_buffer() {
        let mut buf = SequenceReplayBuffer::new(100, 2, 1);
        // A single 5-step episode with rewards 10, 20, 30, 40, 50.
        for i in 0..5 {
            let val = (i + 1) as f32 * 10.0;
            buf.push_slices(
                &[val, val],
                &[val + 1.0, val + 1.0],
                &[0.0],
                val,
                i == 4,
                false,
            )
            .unwrap();
        }

        // The only 5-step window is the whole episode, in push order.
        let batch = buf.sample_sequences(1, 5, 42).unwrap();
        let expected = vec![10.0, 20.0, 30.0, 40.0, 50.0];
        assert_eq!(batch.rewards, expected);
    }

    #[test]
    fn test_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<SequenceReplayBuffer>();
    }

    mod proptests {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            #[test]
            fn prop_batch_size_matches_request(
                batch_size in 1usize..10,
                seq_len in 1usize..5,
                ep_len in 5usize..20,
            ) {
                let cap = ep_len * 5;
                let mut buf = SequenceReplayBuffer::new(cap, 4, 1);
                push_episode(&mut buf, ep_len, 0.0);
                push_episode(&mut buf, ep_len, 100.0);
                let batch = buf.sample_sequences(batch_size, seq_len, 42).unwrap();
                prop_assert_eq!(batch.batch_size, batch_size);
                prop_assert_eq!(batch.seq_len, seq_len);
            }

            #[test]
            fn prop_len_never_exceeds_capacity(
                cap in 10usize..100,
                n_pushes in 1usize..300,
            ) {
                let mut buf = SequenceReplayBuffer::new(cap, 4, 1);
                for i in 0..n_pushes {
                    let done = i % 7 == 6;
                    buf.push_slices(
                        &[i as f32; 4],
                        &[(i + 1) as f32; 4],
                        &[0.0],
                        i as f32,
                        done,
                        false,
                    ).unwrap();
                }
                prop_assert!(buf.len() <= cap);
            }

            #[test]
            fn prop_sequence_obs_contiguous(
                ep_len in 5usize..20,
                seq_len in 2usize..5,
                batch_size in 1usize..5,
            ) {
                let mut buf = SequenceReplayBuffer::new(ep_len * 3, 4, 1);
                push_episode(&mut buf, ep_len, 0.0);
                let batch = buf.sample_sequences(batch_size, seq_len, 42).unwrap();
                let obs_dim = 4;
                for seq_idx in 0..batch_size {
                    for t in 0..(seq_len - 1) {
                        let next_start = (seq_idx * seq_len + t) * obs_dim;
                        let obs_next_start = (seq_idx * seq_len + t + 1) * obs_dim;
                        let next_obs = &batch.next_observations[next_start..next_start + obs_dim];
                        let obs_t1 = &batch.observations[obs_next_start..obs_next_start + obs_dim];
                        prop_assert_eq!(next_obs, obs_t1);
                    }
                }
            }
        }
    }
}