1use rand::Rng;
2use rand::SeedableRng;
3use rand_chacha::ChaCha8Rng;
4
5use crate::error::RloxError;
6
7use super::extra_columns::{ColumnHandle, ExtraColumns};
8use super::ExperienceRecord;
9
10#[derive(Debug)]
19pub struct ReplayBuffer {
20 obs_dim: usize,
21 act_dim: usize,
22 capacity: usize,
23 observations: Vec<f32>,
24 next_observations: Vec<f32>,
25 actions: Vec<f32>,
26 rewards: Vec<f32>,
27 terminated: Vec<bool>,
28 truncated: Vec<bool>,
29 write_pos: usize,
30 count: usize,
31 extra: ExtraColumns,
32}
33
/// A batch of transitions copied out of a replay buffer.
///
/// Row-shaped fields are flattened: row `k` of `observations` is
/// `observations[k * obs_dim..(k + 1) * obs_dim]`, and likewise for
/// `next_observations` (with `obs_dim`) and `actions` (with `act_dim`).
#[derive(Debug, Clone)]
pub struct SampledBatch {
    /// Flattened observations, `batch_size * obs_dim` values.
    pub observations: Vec<f32>,
    /// Flattened next-step observations, `batch_size * obs_dim` values.
    pub next_observations: Vec<f32>,
    /// Flattened actions, `batch_size * act_dim` values.
    pub actions: Vec<f32>,
    /// One reward per sampled row.
    pub rewards: Vec<f32>,
    /// One `terminated` flag per sampled row.
    pub terminated: Vec<bool>,
    /// One `truncated` flag per sampled row.
    pub truncated: Vec<bool>,
    /// Number of values in one observation row.
    pub obs_dim: usize,
    /// Number of values in one action row.
    pub act_dim: usize,
    /// Number of rows currently held by the batch.
    pub batch_size: usize,
    /// Extra columns as `(column name, flattened values)` pairs; empty when
    /// the source buffer has no extra columns.
    pub extra: Vec<(String, Vec<f32>)>,
}

impl SampledBatch {
    /// Creates an empty batch whose vectors are pre-sized for `batch_size` rows.
    ///
    /// Only capacity is reserved: every vector starts empty and the
    /// `batch_size` field starts at `0` until the batch is actually filled.
    pub fn with_capacity(batch_size: usize, obs_dim: usize, act_dim: usize) -> Self {
        let obs_len = batch_size * obs_dim;
        let act_len = batch_size * act_dim;
        Self {
            observations: Vec::with_capacity(obs_len),
            next_observations: Vec::with_capacity(obs_len),
            actions: Vec::with_capacity(act_len),
            rewards: Vec::with_capacity(batch_size),
            terminated: Vec::with_capacity(batch_size),
            truncated: Vec::with_capacity(batch_size),
            obs_dim,
            act_dim,
            batch_size: 0,
            extra: Vec::new(),
        }
    }

    /// Empties the batch so it can be refilled.
    ///
    /// The vectors keep their allocations; `obs_dim` and `act_dim` are left
    /// untouched.
    pub fn clear(&mut self) {
        self.batch_size = 0;
        self.extra.clear();
        for floats in [
            &mut self.observations,
            &mut self.next_observations,
            &mut self.actions,
            &mut self.rewards,
        ] {
            floats.clear();
        }
        for flags in [&mut self.terminated, &mut self.truncated] {
            flags.clear();
        }
    }
}
85
86impl ReplayBuffer {
87 pub fn new(capacity: usize, obs_dim: usize, act_dim: usize) -> Self {
89 Self {
90 obs_dim,
91 act_dim,
92 capacity,
93 observations: vec![0.0; capacity * obs_dim],
94 next_observations: vec![0.0; capacity * obs_dim],
95 actions: vec![0.0; capacity * act_dim],
96 rewards: vec![0.0; capacity],
97 terminated: vec![false; capacity],
98 truncated: vec![false; capacity],
99 write_pos: 0,
100 count: 0,
101 extra: ExtraColumns::new(),
102 }
103 }
104
105 pub fn register_column(&mut self, name: &str, dim: usize) -> ColumnHandle {
111 let handle = self.extra.register(name, dim);
112 self.extra.allocate(self.capacity);
113 handle
114 }
115
116 pub fn push_extra(&mut self, handle: ColumnHandle, values: &[f32]) -> Result<(), RloxError> {
121 if self.count == 0 {
122 return Err(RloxError::BufferError(
123 "push_extra called before any push()".into(),
124 ));
125 }
126 let pos = if self.write_pos == 0 {
128 self.capacity - 1
129 } else {
130 self.write_pos - 1
131 };
132 self.extra.push(handle, pos, values)
133 }
134
135 pub fn obs_dim(&self) -> usize {
137 self.obs_dim
138 }
139
140 pub fn act_dim(&self) -> usize {
142 self.act_dim
143 }
144
145 pub fn len(&self) -> usize {
147 self.count
148 }
149
150 pub fn is_empty(&self) -> bool {
152 self.count == 0
153 }
154
155 pub(crate) fn write_pos(&self) -> usize {
157 self.write_pos
158 }
159
160 pub(crate) fn get(&self, idx: usize) -> (&[f32], &[f32], &[f32], f32, bool, bool) {
168 assert!(
169 idx < self.count,
170 "index {idx} out of bounds (count={})",
171 self.count
172 );
173 let obs_start = idx * self.obs_dim;
174 let act_start = idx * self.act_dim;
175 (
176 &self.observations[obs_start..obs_start + self.obs_dim],
177 &self.next_observations[obs_start..obs_start + self.obs_dim],
178 &self.actions[act_start..act_start + self.act_dim],
179 self.rewards[idx],
180 self.terminated[idx],
181 self.truncated[idx],
182 )
183 }
184
185 pub fn push_slices(
187 &mut self,
188 obs: &[f32],
189 next_obs: &[f32],
190 action: &[f32],
191 reward: f32,
192 terminated: bool,
193 truncated: bool,
194 ) -> Result<(), RloxError> {
195 if obs.len() != self.obs_dim {
196 return Err(RloxError::ShapeMismatch {
197 expected: format!("obs_dim={}", self.obs_dim),
198 got: format!("obs.len()={}", obs.len()),
199 });
200 }
201 if next_obs.len() != self.obs_dim {
202 return Err(RloxError::ShapeMismatch {
203 expected: format!("obs_dim={}", self.obs_dim),
204 got: format!("next_obs.len()={}", next_obs.len()),
205 });
206 }
207 if action.len() != self.act_dim {
208 return Err(RloxError::ShapeMismatch {
209 expected: format!("act_dim={}", self.act_dim),
210 got: format!("action.len()={}", action.len()),
211 });
212 }
213 let idx = self.write_pos;
214 let obs_start = idx * self.obs_dim;
215 self.observations[obs_start..obs_start + self.obs_dim].copy_from_slice(obs);
216 self.next_observations[obs_start..obs_start + self.obs_dim].copy_from_slice(next_obs);
217 let act_start = idx * self.act_dim;
218 self.actions[act_start..act_start + self.act_dim].copy_from_slice(action);
219 self.rewards[idx] = reward;
220 self.terminated[idx] = terminated;
221 self.truncated[idx] = truncated;
222
223 self.write_pos = (self.write_pos + 1) % self.capacity;
224 if self.count < self.capacity {
225 self.count += 1;
226 }
227 Ok(())
228 }
229
230 pub fn push_batch(
235 &mut self,
236 obs_batch: &[f32],
237 next_obs_batch: &[f32],
238 actions_batch: &[f32],
239 rewards: &[f32],
240 terminated: &[bool],
241 truncated: &[bool],
242 ) -> Result<(), RloxError> {
243 let n = rewards.len();
244 if obs_batch.len() != n * self.obs_dim
245 || next_obs_batch.len() != n * self.obs_dim
246 || actions_batch.len() != n * self.act_dim
247 || terminated.len() != n
248 || truncated.len() != n
249 {
250 return Err(RloxError::ShapeMismatch {
251 expected: format!("n={n}, obs_dim={}, act_dim={}", self.obs_dim, self.act_dim),
252 got: format!(
253 "obs={}, next_obs={}, act={}, rew={}, term={}, trunc={}",
254 obs_batch.len(),
255 next_obs_batch.len(),
256 actions_batch.len(),
257 rewards.len(),
258 terminated.len(),
259 truncated.len()
260 ),
261 });
262 }
263 for i in 0..n {
264 let obs = &obs_batch[i * self.obs_dim..(i + 1) * self.obs_dim];
265 let next_obs = &next_obs_batch[i * self.obs_dim..(i + 1) * self.obs_dim];
266 let action = &actions_batch[i * self.act_dim..(i + 1) * self.act_dim];
267 self.push_slices(
268 obs,
269 next_obs,
270 action,
271 rewards[i],
272 terminated[i],
273 truncated[i],
274 )?;
275 }
276 Ok(())
277 }
278
279 pub fn push(&mut self, record: ExperienceRecord) -> Result<(), RloxError> {
284 self.push_slices(
285 &record.obs,
286 &record.next_obs,
287 &record.action,
288 record.reward,
289 record.terminated,
290 record.truncated,
291 )
292 }
293
294 pub fn sample(&self, batch_size: usize, seed: u64) -> Result<SampledBatch, RloxError> {
302 if batch_size > self.count {
303 return Err(RloxError::BufferError(format!(
304 "batch_size {} > buffer len {}",
305 batch_size, self.count
306 )));
307 }
308 let mut rng = ChaCha8Rng::seed_from_u64(seed);
309 let mut batch = SampledBatch::with_capacity(batch_size, self.obs_dim, self.act_dim);
310
311 let has_extra = self.extra.num_columns() > 0;
312 let mut indices = if has_extra {
313 Vec::with_capacity(batch_size)
314 } else {
315 Vec::new()
316 };
317
318 for _ in 0..batch_size {
319 let idx = rng.random_range(0..self.count);
320 let obs_start = idx * self.obs_dim;
321 batch
322 .observations
323 .extend_from_slice(&self.observations[obs_start..obs_start + self.obs_dim]);
324 batch
325 .next_observations
326 .extend_from_slice(&self.next_observations[obs_start..obs_start + self.obs_dim]);
327 let act_start = idx * self.act_dim;
328 batch
329 .actions
330 .extend_from_slice(&self.actions[act_start..act_start + self.act_dim]);
331 batch.rewards.push(self.rewards[idx]);
332 batch.terminated.push(self.terminated[idx]);
333 batch.truncated.push(self.truncated[idx]);
334
335 if has_extra {
336 indices.push(idx);
337 }
338 }
339 batch.batch_size = batch_size;
340
341 if has_extra {
342 batch.extra = self.extra.sample_all(&indices);
343 }
344
345 Ok(batch)
346 }
347
348 pub fn sample_into(
352 &self,
353 batch: &mut SampledBatch,
354 batch_size: usize,
355 seed: u64,
356 ) -> Result<(), RloxError> {
357 if batch_size > self.count {
358 return Err(RloxError::BufferError(format!(
359 "batch_size {} > buffer len {}",
360 batch_size, self.count
361 )));
362 }
363 batch.clear();
364 batch.obs_dim = self.obs_dim;
365 batch.act_dim = self.act_dim;
366
367 let mut rng = ChaCha8Rng::seed_from_u64(seed);
368 let has_extra = self.extra.num_columns() > 0;
369 let mut indices = if has_extra {
370 Vec::with_capacity(batch_size)
371 } else {
372 Vec::new()
373 };
374
375 for _ in 0..batch_size {
376 let idx = rng.random_range(0..self.count);
377 let obs_start = idx * self.obs_dim;
378 batch
379 .observations
380 .extend_from_slice(&self.observations[obs_start..obs_start + self.obs_dim]);
381 batch
382 .next_observations
383 .extend_from_slice(&self.next_observations[obs_start..obs_start + self.obs_dim]);
384 let act_start = idx * self.act_dim;
385 batch
386 .actions
387 .extend_from_slice(&self.actions[act_start..act_start + self.act_dim]);
388 batch.rewards.push(self.rewards[idx]);
389 batch.terminated.push(self.terminated[idx]);
390 batch.truncated.push(self.truncated[idx]);
391 if has_extra {
392 indices.push(idx);
393 }
394 }
395 batch.batch_size = batch_size;
396 if has_extra {
397 batch.extra = self.extra.sample_all(&indices);
398 }
399 Ok(())
400 }
401
402 pub fn extra_columns(&self) -> &ExtraColumns {
404 &self.extra
405 }
406}
407
#[cfg(test)]
mod tests {
    use super::*;
    use crate::buffer::sample_record;

    /// Builds a buffer with `act_dim = 1` and pushes `pushes` sample records
    /// of width `obs_dim` into it.
    fn filled_buffer(capacity: usize, obs_dim: usize, pushes: usize) -> ReplayBuffer {
        let mut buf = ReplayBuffer::new(capacity, obs_dim, 1);
        for _ in 0..pushes {
            buf.push(sample_record(obs_dim)).unwrap();
        }
        buf
    }

    #[test]
    fn ring_buffer_respects_capacity() {
        let buf = filled_buffer(100, 4, 200);
        assert_eq!(buf.len(), 100);
    }

    #[test]
    fn ring_buffer_overwrites_oldest() {
        let mut buf = ReplayBuffer::new(3, 4, 1);
        for i in 0..5 {
            let mut r = sample_record(4);
            r.reward = i as f32;
            buf.push(r).unwrap();
        }
        let batch = buf.sample(3, 42).unwrap();
        // Rewards 0 and 1 belonged to the two rows that were overwritten.
        assert!(!batch.rewards.contains(&0.0));
        assert!(!batch.rewards.contains(&1.0));
        // Everything that can be sampled must be one of the surviving rows.
        assert!(batch.rewards.iter().all(|r| (2.0..=4.0).contains(r)));
    }

    #[test]
    fn sample_returns_requested_size() {
        let buf = filled_buffer(1000, 4, 1000);
        let batch = buf.sample(64, 42).unwrap();
        assert_eq!(batch.batch_size, 64);
        assert_eq!(batch.observations.len(), 64 * 4);
    }

    #[test]
    fn sample_errors_when_too_few() {
        let buf = filled_buffer(100, 4, 1);
        assert!(buf.sample(32, 42).is_err());
    }

    #[test]
    fn sample_is_deterministic_with_same_seed() {
        let buf = filled_buffer(1000, 4, 1000);
        let b1 = buf.sample(32, 42).unwrap();
        let b2 = buf.sample(32, 42).unwrap();
        assert_eq!(b1.observations, b2.observations);
        assert_eq!(b1.rewards, b2.rewards);
    }

    #[test]
    fn replay_buffer_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<ReplayBuffer>();
    }

    #[test]
    fn empty_buffer_has_zero_len() {
        let buf = ReplayBuffer::new(100, 4, 1);
        assert_eq!(buf.len(), 0);
        assert!(buf.is_empty());
    }

    #[test]
    fn test_replay_buffer_next_obs_roundtrip() {
        let obs_dim = 4;
        let mut buf = ReplayBuffer::new(100, obs_dim, 1);
        let record = ExperienceRecord {
            obs: vec![1.0; obs_dim],
            next_obs: vec![2.0, 3.0, 4.0, 5.0],
            action: vec![0.0],
            reward: 1.0,
            terminated: false,
            truncated: false,
        };
        buf.push(record).unwrap();
        // With a single stored row, any sample of size 1 must return it.
        let batch = buf.sample(1, 42).unwrap();
        assert_eq!(&batch.next_observations, &[2.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_replay_buffer_next_obs_shape() {
        let obs_dim = 4;
        let buf = filled_buffer(1000, obs_dim, 100);
        let batch = buf.sample(32, 42).unwrap();
        assert_eq!(batch.next_observations.len(), 32 * obs_dim);
    }

    #[test]
    fn test_replay_buffer_next_obs_dim_mismatch_errors() {
        let mut buf = ReplayBuffer::new(100, 4, 1);
        let record = ExperienceRecord {
            obs: vec![1.0; 4],
            // Deliberately one value short of `obs_dim`.
            next_obs: vec![2.0; 3],
            action: vec![0.0],
            reward: 1.0,
            terminated: false,
            truncated: false,
        };
        let result = buf.push(record);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("next_obs"));
    }

    #[test]
    fn test_replay_buffer_with_extra_columns_roundtrip() {
        let mut buf = ReplayBuffer::new(100, 4, 1);
        let lp = buf.register_column("log_prob", 1);
        let val = buf.register_column("value", 1);

        for i in 0..10 {
            buf.push(sample_record(4)).unwrap();
            buf.push_extra(lp, &[i as f32 * 0.1]).unwrap();
            buf.push_extra(val, &[i as f32]).unwrap();
        }

        let batch = buf.sample(5, 42).unwrap();
        assert_eq!(batch.extra.len(), 2);
        assert_eq!(batch.extra[0].0, "log_prob");
        // 5 rows x 1 value per row.
        assert_eq!(batch.extra[0].1.len(), 5);
        assert_eq!(batch.extra[1].0, "value");
        assert_eq!(batch.extra[1].1.len(), 5);
    }

    #[test]
    fn test_replay_buffer_no_extra_columns_has_empty_extra() {
        let buf = filled_buffer(100, 4, 10);
        let batch = buf.sample(5, 42).unwrap();
        assert!(batch.extra.is_empty());
    }

    #[test]
    fn test_push_extra_before_push_errors() {
        let mut buf = ReplayBuffer::new(100, 4, 1);
        let h = buf.register_column("test", 1);
        let result = buf.push_extra(h, &[1.0]);
        assert!(result.is_err());
    }

    #[test]
    fn test_extra_columns_multidim_roundtrip() {
        let mut buf = ReplayBuffer::new(100, 4, 1);
        let h = buf.register_column("action_mean", 3);

        for i in 0..5 {
            buf.push(sample_record(4)).unwrap();
            let v = i as f32;
            buf.push_extra(h, &[v, v + 1.0, v + 2.0]).unwrap();
        }

        let batch = buf.sample(3, 42).unwrap();
        assert_eq!(batch.extra.len(), 1);
        assert_eq!(batch.extra[0].0, "action_mean");
        // 3 rows x 3 values per row.
        assert_eq!(batch.extra[0].1.len(), 9);
    }

    #[test]
    fn test_sample_into_matches_sample() {
        let buf = filled_buffer(100, 4, 50);

        let batch1 = buf.sample(16, 42).unwrap();
        let mut reusable = SampledBatch::with_capacity(16, 4, 1);
        buf.sample_into(&mut reusable, 16, 42).unwrap();

        // Every field must agree, not just a subset of them.
        assert_eq!(batch1.observations, reusable.observations);
        assert_eq!(batch1.next_observations, reusable.next_observations);
        assert_eq!(batch1.actions, reusable.actions);
        assert_eq!(batch1.rewards, reusable.rewards);
        assert_eq!(batch1.terminated, reusable.terminated);
        assert_eq!(batch1.truncated, reusable.truncated);
        assert_eq!(batch1.extra, reusable.extra);
        assert_eq!(batch1.obs_dim, reusable.obs_dim);
        assert_eq!(batch1.act_dim, reusable.act_dim);
        assert_eq!(batch1.batch_size, reusable.batch_size);
    }

    #[test]
    fn test_sample_into_reuses_capacity() {
        let buf = filled_buffer(100, 4, 50);

        let mut batch = SampledBatch::with_capacity(16, 4, 1);
        buf.sample_into(&mut batch, 16, 1).unwrap();
        let obs_cap = batch.observations.capacity();

        buf.sample_into(&mut batch, 16, 2).unwrap();
        // A same-sized refill must fit in the existing allocation. `>=` would
        // pass even after a reallocation, so require the capacity unchanged.
        assert_eq!(batch.observations.capacity(), obs_cap);
        assert_eq!(batch.observations.len(), 16 * 4);
    }

    mod proptests {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            #[test]
            fn ring_buffer_never_exceeds_capacity(capacity in 1..500usize, num_pushes in 0..2000usize) {
                let buf = filled_buffer(capacity, 4, num_pushes);
                prop_assert!(buf.len() <= capacity);
            }

            #[test]
            fn ring_buffer_len_is_min_of_pushes_and_capacity(capacity in 1..500usize, num_pushes in 0..2000usize) {
                let buf = filled_buffer(capacity, 4, num_pushes);
                prop_assert_eq!(buf.len(), num_pushes.min(capacity));
            }

            #[test]
            fn sample_returns_requested_size_prop(capacity in 10..500usize, num_pushes in 10..2000usize, batch_size in 1..50usize) {
                let buf = filled_buffer(capacity, 4, num_pushes);
                // Never ask for more rows than the buffer currently holds.
                let effective_batch = batch_size.min(buf.len());
                let batch = buf.sample(effective_batch, 42).unwrap();
                prop_assert_eq!(batch.batch_size, effective_batch);
            }
        }
    }
}