// rlox_core/buffer/ringbuf.rs

1use rand::Rng;
2use rand::SeedableRng;
3use rand_chacha::ChaCha8Rng;
4
5use crate::error::RloxError;
6
7use super::extra_columns::{ColumnHandle, ExtraColumns};
8use super::ExperienceRecord;
9
/// Fixed-capacity ring buffer with uniform random sampling.
///
/// Pre-allocates all arrays at construction for zero-allocation push.
/// Oldest transitions are overwritten when capacity is reached.
///
/// Supports optional extra f32 columns (e.g. log-probs, value estimates)
/// via [`ColumnHandle`]. When no extra columns are registered, there is
/// zero overhead — no allocations and no branches in the hot push/sample path.
#[derive(Debug)]
pub struct ReplayBuffer {
    /// Length of one observation vector.
    obs_dim: usize,
    /// Length of one action vector.
    act_dim: usize,
    /// Maximum number of stored transitions before overwriting begins.
    capacity: usize,
    /// Flat row-major storage, `capacity * obs_dim` f32s; slot `i` occupies
    /// `[i * obs_dim, (i + 1) * obs_dim)`.
    observations: Vec<f32>,
    /// Same layout as `observations`, holding each slot's successor observation.
    next_observations: Vec<f32>,
    /// Flat row-major storage, `capacity * act_dim` f32s.
    actions: Vec<f32>,
    /// One reward per slot (`capacity` entries).
    rewards: Vec<f32>,
    /// Per-slot episode-terminated flag.
    terminated: Vec<bool>,
    /// Per-slot episode-truncated flag.
    truncated: Vec<bool>,
    /// Ring cursor: index of the slot the next push will write (and overwrite
    /// once the buffer has wrapped).
    write_pos: usize,
    /// Number of valid transitions currently stored; saturates at `capacity`.
    count: usize,
    /// Optional named f32 columns, sampled alongside the core arrays.
    extra: ExtraColumns,
}
33
/// A sampled batch of transitions. Owns its data (copied from the ring buffer).
#[derive(Debug, Clone)]
pub struct SampledBatch {
    pub observations: Vec<f32>,
    pub next_observations: Vec<f32>,
    pub actions: Vec<f32>,
    pub rewards: Vec<f32>,
    pub terminated: Vec<bool>,
    pub truncated: Vec<bool>,
    pub obs_dim: usize,
    pub act_dim: usize,
    pub batch_size: usize,
    /// Extra column data, populated only when columns are registered.
    /// Each entry is `(column_name, flat_data)` where `flat_data` has
    /// length `batch_size * column_dim`.
    pub extra: Vec<(String, Vec<f32>)>,
}

impl SampledBatch {
    /// Allocate an empty batch able to hold `batch_size` transitions of the
    /// given dimensionalities without reallocating.
    ///
    /// `batch_size` is only a capacity hint here: the returned batch starts
    /// empty (`batch_size == 0`) until a sampler fills it.
    pub fn with_capacity(batch_size: usize, obs_dim: usize, act_dim: usize) -> Self {
        let obs_len = batch_size * obs_dim;
        let act_len = batch_size * act_dim;
        Self {
            observations: Vec::with_capacity(obs_len),
            next_observations: Vec::with_capacity(obs_len),
            actions: Vec::with_capacity(act_len),
            rewards: Vec::with_capacity(batch_size),
            terminated: Vec::with_capacity(batch_size),
            truncated: Vec::with_capacity(batch_size),
            obs_dim,
            act_dim,
            batch_size: 0,
            extra: Vec::new(),
        }
    }

    /// Clear all data but retain allocated capacity for reuse.
    ///
    /// Note: `extra` is cleared entirely (the outer Vec). If you alternate
    /// between buffers with different extra-column schemas, the inner Vecs'
    /// capacity is lost. This is acceptable because cross-buffer reuse of
    /// extra columns is uncommon.
    pub fn clear(&mut self) {
        self.batch_size = 0;
        self.extra.clear();
        self.rewards.clear();
        self.terminated.clear();
        self.truncated.clear();
        self.actions.clear();
        self.next_observations.clear();
        self.observations.clear();
    }
}
85
86impl ReplayBuffer {
87    /// Create a ring buffer with fixed capacity. All arrays are pre-allocated.
88    pub fn new(capacity: usize, obs_dim: usize, act_dim: usize) -> Self {
89        Self {
90            obs_dim,
91            act_dim,
92            capacity,
93            observations: vec![0.0; capacity * obs_dim],
94            next_observations: vec![0.0; capacity * obs_dim],
95            actions: vec![0.0; capacity * act_dim],
96            rewards: vec![0.0; capacity],
97            terminated: vec![false; capacity],
98            truncated: vec![false; capacity],
99            write_pos: 0,
100            count: 0,
101            extra: ExtraColumns::new(),
102        }
103    }
104
105    /// Register an extra f32 column with the given name and dimensionality.
106    ///
107    /// Returns a [`ColumnHandle`] for O(1) push/sample access.
108    /// Must be called before any `push()` — the column is pre-allocated to
109    /// match the buffer's capacity.
110    pub fn register_column(&mut self, name: &str, dim: usize) -> ColumnHandle {
111        let handle = self.extra.register(name, dim);
112        self.extra.allocate(self.capacity);
113        handle
114    }
115
116    /// Push extra column data for the most recently pushed transition.
117    ///
118    /// Must be called *after* `push()` and before the next `push()`.
119    /// The `values` slice length must match the column's registered dim.
120    pub fn push_extra(&mut self, handle: ColumnHandle, values: &[f32]) -> Result<(), RloxError> {
121        if self.count == 0 {
122            return Err(RloxError::BufferError(
123                "push_extra called before any push()".into(),
124            ));
125        }
126        // The most recently written position is one step behind write_pos
127        let pos = if self.write_pos == 0 {
128            self.capacity - 1
129        } else {
130            self.write_pos - 1
131        };
132        self.extra.push(handle, pos, values)
133    }
134
135    /// Observation dimensionality.
136    pub fn obs_dim(&self) -> usize {
137        self.obs_dim
138    }
139
140    /// Action dimensionality.
141    pub fn act_dim(&self) -> usize {
142        self.act_dim
143    }
144
145    /// Number of valid transitions currently stored.
146    pub fn len(&self) -> usize {
147        self.count
148    }
149
150    /// Whether the buffer is empty.
151    pub fn is_empty(&self) -> bool {
152        self.count == 0
153    }
154
155    /// Current write position in the ring buffer.
156    pub(crate) fn write_pos(&self) -> usize {
157        self.write_pos
158    }
159
160    /// Access the record at `idx` by reference.
161    ///
162    /// Returns `(obs_slice, next_obs_slice, action_slice, reward, terminated, truncated)`.
163    ///
164    /// # Panics
165    ///
166    /// Panics if `idx >= self.count`.
167    pub(crate) fn get(&self, idx: usize) -> (&[f32], &[f32], &[f32], f32, bool, bool) {
168        assert!(
169            idx < self.count,
170            "index {idx} out of bounds (count={})",
171            self.count
172        );
173        let obs_start = idx * self.obs_dim;
174        let act_start = idx * self.act_dim;
175        (
176            &self.observations[obs_start..obs_start + self.obs_dim],
177            &self.next_observations[obs_start..obs_start + self.obs_dim],
178            &self.actions[act_start..act_start + self.act_dim],
179            self.rewards[idx],
180            self.terminated[idx],
181            self.truncated[idx],
182        )
183    }
184
185    /// Push a transition from borrowed slices, avoiding intermediate allocation.
186    pub fn push_slices(
187        &mut self,
188        obs: &[f32],
189        next_obs: &[f32],
190        action: &[f32],
191        reward: f32,
192        terminated: bool,
193        truncated: bool,
194    ) -> Result<(), RloxError> {
195        if obs.len() != self.obs_dim {
196            return Err(RloxError::ShapeMismatch {
197                expected: format!("obs_dim={}", self.obs_dim),
198                got: format!("obs.len()={}", obs.len()),
199            });
200        }
201        if next_obs.len() != self.obs_dim {
202            return Err(RloxError::ShapeMismatch {
203                expected: format!("obs_dim={}", self.obs_dim),
204                got: format!("next_obs.len()={}", next_obs.len()),
205            });
206        }
207        if action.len() != self.act_dim {
208            return Err(RloxError::ShapeMismatch {
209                expected: format!("act_dim={}", self.act_dim),
210                got: format!("action.len()={}", action.len()),
211            });
212        }
213        let idx = self.write_pos;
214        let obs_start = idx * self.obs_dim;
215        self.observations[obs_start..obs_start + self.obs_dim].copy_from_slice(obs);
216        self.next_observations[obs_start..obs_start + self.obs_dim].copy_from_slice(next_obs);
217        let act_start = idx * self.act_dim;
218        self.actions[act_start..act_start + self.act_dim].copy_from_slice(action);
219        self.rewards[idx] = reward;
220        self.terminated[idx] = terminated;
221        self.truncated[idx] = truncated;
222
223        self.write_pos = (self.write_pos + 1) % self.capacity;
224        if self.count < self.capacity {
225            self.count += 1;
226        }
227        Ok(())
228    }
229
230    /// Push multiple transitions at once from flat arrays.
231    ///
232    /// `obs_batch` shape: `[n * obs_dim]`, `next_obs_batch`: same,
233    /// `actions_batch`: `[n * act_dim]`, others: `[n]`.
234    pub fn push_batch(
235        &mut self,
236        obs_batch: &[f32],
237        next_obs_batch: &[f32],
238        actions_batch: &[f32],
239        rewards: &[f32],
240        terminated: &[bool],
241        truncated: &[bool],
242    ) -> Result<(), RloxError> {
243        let n = rewards.len();
244        if obs_batch.len() != n * self.obs_dim
245            || next_obs_batch.len() != n * self.obs_dim
246            || actions_batch.len() != n * self.act_dim
247            || terminated.len() != n
248            || truncated.len() != n
249        {
250            return Err(RloxError::ShapeMismatch {
251                expected: format!("n={n}, obs_dim={}, act_dim={}", self.obs_dim, self.act_dim),
252                got: format!(
253                    "obs={}, next_obs={}, act={}, rew={}, term={}, trunc={}",
254                    obs_batch.len(),
255                    next_obs_batch.len(),
256                    actions_batch.len(),
257                    rewards.len(),
258                    terminated.len(),
259                    truncated.len()
260                ),
261            });
262        }
263        for i in 0..n {
264            let obs = &obs_batch[i * self.obs_dim..(i + 1) * self.obs_dim];
265            let next_obs = &next_obs_batch[i * self.obs_dim..(i + 1) * self.obs_dim];
266            let action = &actions_batch[i * self.act_dim..(i + 1) * self.act_dim];
267            self.push_slices(
268                obs,
269                next_obs,
270                action,
271                rewards[i],
272                terminated[i],
273                truncated[i],
274            )?;
275        }
276        Ok(())
277    }
278
279    /// Push a transition, overwriting the oldest if at capacity.
280    ///
281    /// Prefer [`push_slices`](Self::push_slices) to avoid the intermediate
282    /// `Vec<f32>` allocations inside `ExperienceRecord`.
283    pub fn push(&mut self, record: ExperienceRecord) -> Result<(), RloxError> {
284        self.push_slices(
285            &record.obs,
286            &record.next_obs,
287            &record.action,
288            record.reward,
289            record.terminated,
290            record.truncated,
291        )
292    }
293
294    /// Sample a batch of transitions uniformly at random.
295    ///
296    /// Uses ChaCha8Rng seeded with `seed` for deterministic cross-platform
297    /// reproducibility. Returns owned `SampledBatch`.
298    ///
299    /// If extra columns have been registered, their data is included in
300    /// `SampledBatch::extra`.
301    pub fn sample(&self, batch_size: usize, seed: u64) -> Result<SampledBatch, RloxError> {
302        if batch_size > self.count {
303            return Err(RloxError::BufferError(format!(
304                "batch_size {} > buffer len {}",
305                batch_size, self.count
306            )));
307        }
308        let mut rng = ChaCha8Rng::seed_from_u64(seed);
309        let mut batch = SampledBatch::with_capacity(batch_size, self.obs_dim, self.act_dim);
310
311        let has_extra = self.extra.num_columns() > 0;
312        let mut indices = if has_extra {
313            Vec::with_capacity(batch_size)
314        } else {
315            Vec::new()
316        };
317
318        for _ in 0..batch_size {
319            let idx = rng.random_range(0..self.count);
320            let obs_start = idx * self.obs_dim;
321            batch
322                .observations
323                .extend_from_slice(&self.observations[obs_start..obs_start + self.obs_dim]);
324            batch
325                .next_observations
326                .extend_from_slice(&self.next_observations[obs_start..obs_start + self.obs_dim]);
327            let act_start = idx * self.act_dim;
328            batch
329                .actions
330                .extend_from_slice(&self.actions[act_start..act_start + self.act_dim]);
331            batch.rewards.push(self.rewards[idx]);
332            batch.terminated.push(self.terminated[idx]);
333            batch.truncated.push(self.truncated[idx]);
334
335            if has_extra {
336                indices.push(idx);
337            }
338        }
339        batch.batch_size = batch_size;
340
341        if has_extra {
342            batch.extra = self.extra.sample_all(&indices);
343        }
344
345        Ok(batch)
346    }
347
348    /// Sample into a pre-allocated batch, reusing its capacity.
349    ///
350    /// Same as `sample()` but avoids allocation by reusing `batch`.
351    pub fn sample_into(
352        &self,
353        batch: &mut SampledBatch,
354        batch_size: usize,
355        seed: u64,
356    ) -> Result<(), RloxError> {
357        if batch_size > self.count {
358            return Err(RloxError::BufferError(format!(
359                "batch_size {} > buffer len {}",
360                batch_size, self.count
361            )));
362        }
363        batch.clear();
364        batch.obs_dim = self.obs_dim;
365        batch.act_dim = self.act_dim;
366
367        let mut rng = ChaCha8Rng::seed_from_u64(seed);
368        let has_extra = self.extra.num_columns() > 0;
369        let mut indices = if has_extra {
370            Vec::with_capacity(batch_size)
371        } else {
372            Vec::new()
373        };
374
375        for _ in 0..batch_size {
376            let idx = rng.random_range(0..self.count);
377            let obs_start = idx * self.obs_dim;
378            batch
379                .observations
380                .extend_from_slice(&self.observations[obs_start..obs_start + self.obs_dim]);
381            batch
382                .next_observations
383                .extend_from_slice(&self.next_observations[obs_start..obs_start + self.obs_dim]);
384            let act_start = idx * self.act_dim;
385            batch
386                .actions
387                .extend_from_slice(&self.actions[act_start..act_start + self.act_dim]);
388            batch.rewards.push(self.rewards[idx]);
389            batch.terminated.push(self.terminated[idx]);
390            batch.truncated.push(self.truncated[idx]);
391            if has_extra {
392                indices.push(idx);
393            }
394        }
395        batch.batch_size = batch_size;
396        if has_extra {
397            batch.extra = self.extra.sample_all(&indices);
398        }
399        Ok(())
400    }
401
402    /// Access the extra columns storage (for advanced use / testing).
403    pub fn extra_columns(&self) -> &ExtraColumns {
404        &self.extra
405    }
406}
407
#[cfg(test)]
mod tests {
    use super::*;
    use crate::buffer::sample_record;

    // Pushing twice the capacity must leave len() clamped at capacity.
    #[test]
    fn ring_buffer_respects_capacity() {
        let mut buf = ReplayBuffer::new(100, 4, 1);
        for _ in 0..200 {
            buf.push(sample_record(4)).unwrap();
        }
        assert_eq!(buf.len(), 100);
    }

    // After wrapping, the earliest transitions are gone: any sample can
    // only ever see the most recent `capacity` rewards.
    #[test]
    fn ring_buffer_overwrites_oldest() {
        let mut buf = ReplayBuffer::new(3, 4, 1);
        for i in 0..5 {
            let mut r = sample_record(4);
            r.reward = i as f32;
            buf.push(r).unwrap();
        }
        // Should contain rewards 2.0, 3.0, 4.0
        let batch = buf.sample(3, 42).unwrap();
        assert!(!batch.rewards.contains(&0.0));
        assert!(!batch.rewards.contains(&1.0));
    }

    // Flat output arrays must be sized batch_size * dim.
    #[test]
    fn sample_returns_requested_size() {
        let mut buf = ReplayBuffer::new(1000, 4, 1);
        for _ in 0..1000 {
            buf.push(sample_record(4)).unwrap();
        }
        let batch = buf.sample(64, 42).unwrap();
        assert_eq!(batch.batch_size, 64);
        assert_eq!(batch.observations.len(), 64 * 4);
    }

    // Requesting more transitions than stored is an error, not a panic.
    #[test]
    fn sample_errors_when_too_few() {
        let mut buf = ReplayBuffer::new(100, 4, 1);
        buf.push(sample_record(4)).unwrap();
        assert!(buf.sample(32, 42).is_err());
    }

    // Sampling is driven by a seeded ChaCha8 stream, so equal seeds must
    // reproduce identical batches.
    #[test]
    fn sample_is_deterministic_with_same_seed() {
        let mut buf = ReplayBuffer::new(1000, 4, 1);
        for _ in 0..1000 {
            buf.push(sample_record(4)).unwrap();
        }
        let b1 = buf.sample(32, 42).unwrap();
        let b2 = buf.sample(32, 42).unwrap();
        assert_eq!(b1.observations, b2.observations);
        assert_eq!(b1.rewards, b2.rewards);
    }

    // Compile-time check that the buffer can cross thread boundaries.
    #[test]
    fn replay_buffer_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<ReplayBuffer>();
    }

    #[test]
    fn empty_buffer_has_zero_len() {
        let buf = ReplayBuffer::new(100, 4, 1);
        assert_eq!(buf.len(), 0);
        assert!(buf.is_empty());
    }

    // next_obs must round-trip through push/sample unchanged.
    #[test]
    fn test_replay_buffer_next_obs_roundtrip() {
        let obs_dim = 4;
        let mut buf = ReplayBuffer::new(100, obs_dim, 1);
        let record = ExperienceRecord {
            obs: vec![1.0; obs_dim],
            next_obs: vec![2.0, 3.0, 4.0, 5.0],
            action: vec![0.0],
            reward: 1.0,
            terminated: false,
            truncated: false,
        };
        buf.push(record).unwrap();
        let batch = buf.sample(1, 42).unwrap();
        assert_eq!(&batch.next_observations, &[2.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_replay_buffer_next_obs_shape() {
        let obs_dim = 4;
        let mut buf = ReplayBuffer::new(1000, obs_dim, 1);
        for _ in 0..100 {
            buf.push(sample_record(obs_dim)).unwrap();
        }
        let batch = buf.sample(32, 42).unwrap();
        assert_eq!(batch.next_observations.len(), 32 * obs_dim);
    }

    // The shape error must name the offending slice ("next_obs") so the
    // caller can tell which input was mis-sized.
    #[test]
    fn test_replay_buffer_next_obs_dim_mismatch_errors() {
        let mut buf = ReplayBuffer::new(100, 4, 1);
        let record = ExperienceRecord {
            obs: vec![1.0; 4],
            next_obs: vec![2.0; 3], // wrong dim
            action: vec![0.0],
            reward: 1.0,
            terminated: false,
            truncated: false,
        };
        let result = buf.push(record);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("next_obs"));
    }

    // Registered columns must come back in sample output, one entry per
    // column, sized batch_size * dim.
    #[test]
    fn test_replay_buffer_with_extra_columns_roundtrip() {
        let mut buf = ReplayBuffer::new(100, 4, 1);
        let lp = buf.register_column("log_prob", 1);
        let val = buf.register_column("value", 1);

        for i in 0..10 {
            buf.push(sample_record(4)).unwrap();
            buf.push_extra(lp, &[i as f32 * 0.1]).unwrap();
            buf.push_extra(val, &[i as f32]).unwrap();
        }

        let batch = buf.sample(5, 42).unwrap();
        assert_eq!(batch.extra.len(), 2);
        assert_eq!(batch.extra[0].0, "log_prob");
        assert_eq!(batch.extra[0].1.len(), 5); // batch_size * dim(1)
        assert_eq!(batch.extra[1].0, "value");
        assert_eq!(batch.extra[1].1.len(), 5);
    }

    // Zero-overhead path: no registered columns means no extra entries.
    #[test]
    fn test_replay_buffer_no_extra_columns_has_empty_extra() {
        let mut buf = ReplayBuffer::new(100, 4, 1);
        for _ in 0..10 {
            buf.push(sample_record(4)).unwrap();
        }
        let batch = buf.sample(5, 42).unwrap();
        assert!(batch.extra.is_empty());
    }

    // push_extra has no target slot before the first push().
    #[test]
    fn test_push_extra_before_push_errors() {
        let mut buf = ReplayBuffer::new(100, 4, 1);
        let h = buf.register_column("test", 1);
        let result = buf.push_extra(h, &[1.0]);
        assert!(result.is_err());
    }

    // Columns with dim > 1 store dim values per transition.
    #[test]
    fn test_extra_columns_multidim_roundtrip() {
        let mut buf = ReplayBuffer::new(100, 4, 1);
        let h = buf.register_column("action_mean", 3);

        for i in 0..5 {
            buf.push(sample_record(4)).unwrap();
            let v = i as f32;
            buf.push_extra(h, &[v, v + 1.0, v + 2.0]).unwrap();
        }

        let batch = buf.sample(3, 42).unwrap();
        assert_eq!(batch.extra.len(), 1);
        assert_eq!(batch.extra[0].0, "action_mean");
        assert_eq!(batch.extra[0].1.len(), 9); // 3 * 3
    }

    // sample() and sample_into() must be interchangeable: same seed, same
    // contents.
    #[test]
    fn test_sample_into_matches_sample() {
        let mut buf = ReplayBuffer::new(100, 4, 1);
        for _ in 0..50 {
            buf.push(sample_record(4)).unwrap();
        }

        let batch1 = buf.sample(16, 42).unwrap();
        let mut reusable = SampledBatch::with_capacity(16, 4, 1);
        buf.sample_into(&mut reusable, 16, 42).unwrap();

        assert_eq!(batch1.observations, reusable.observations);
        assert_eq!(batch1.next_observations, reusable.next_observations);
        assert_eq!(batch1.actions, reusable.actions);
        assert_eq!(batch1.rewards, reusable.rewards);
        assert_eq!(batch1.terminated, reusable.terminated);
        assert_eq!(batch1.batch_size, reusable.batch_size);
    }

    // sample_into's clear() must retain allocations between calls.
    #[test]
    fn test_sample_into_reuses_capacity() {
        let mut buf = ReplayBuffer::new(100, 4, 1);
        for _ in 0..50 {
            buf.push(sample_record(4)).unwrap();
        }

        let mut batch = SampledBatch::with_capacity(16, 4, 1);
        buf.sample_into(&mut batch, 16, 1).unwrap();
        let obs_cap = batch.observations.capacity();

        // Second sample_into should reuse capacity, not shrink
        buf.sample_into(&mut batch, 16, 2).unwrap();
        assert!(batch.observations.capacity() >= obs_cap);
    }

    // Property-based coverage of the ring invariants across random
    // capacity / push-count combinations.
    mod proptests {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            #[test]
            fn ring_buffer_never_exceeds_capacity(capacity in 1..500usize, num_pushes in 0..2000usize) {
                let mut buf = ReplayBuffer::new(capacity, 4, 1);
                for _ in 0..num_pushes {
                    buf.push(sample_record(4)).unwrap();
                }
                prop_assert!(buf.len() <= capacity);
            }

            #[test]
            fn ring_buffer_len_is_min_of_pushes_and_capacity(capacity in 1..500usize, num_pushes in 0..2000usize) {
                let mut buf = ReplayBuffer::new(capacity, 4, 1);
                for _ in 0..num_pushes {
                    buf.push(sample_record(4)).unwrap();
                }
                prop_assert_eq!(buf.len(), num_pushes.min(capacity));
            }

            #[test]
            fn sample_returns_requested_size_prop(capacity in 10..500usize, num_pushes in 10..2000usize, batch_size in 1..50usize) {
                let mut buf = ReplayBuffer::new(capacity, 4, 1);
                for _ in 0..num_pushes {
                    buf.push(sample_record(4)).unwrap();
                }
                let effective_batch = batch_size.min(buf.len());
                let batch = buf.sample(effective_batch, 42).unwrap();
                prop_assert_eq!(batch.batch_size, effective_batch);
            }
        }
    }
}