rlox_core/buffer/
mmap.rs

1use std::fs::{File, OpenOptions};
2use std::path::PathBuf;
3
4use memmap2::{Mmap, MmapMut};
5use rand::Rng;
6use rand::SeedableRng;
7use rand_chacha::ChaCha8Rng;
8
9use crate::error::RloxError;
10
11use super::ringbuf::{ReplayBuffer, SampledBatch};
12use super::ExperienceRecord;
13
/// Binary header: obs_dim(u64) + act_dim(u64) + record_count(u64) = 24 bytes.
/// All three fields are written little-endian at the start of the cold file,
/// ahead of the record payload (see `write_cold_header_mmap`).
const HEADER_SIZE: usize = 24;
16
17/// When replay buffer exceeds RAM capacity, transparently spill to disk.
18/// Uses mmap for lazy loading -- only pages actually accessed are read.
19///
20/// Architecture:
21/// - `hot`: in-memory `ReplayBuffer` for recent data (fast sampling)
22/// - `cold`: memory-mapped file for overflow data (ring-buffer with wrapping)
23/// - Sampling prefers hot data but can reach into cold
pub struct MmapReplayBuffer {
    /// In-memory ring buffer holding the most recent records (fast path).
    hot: ReplayBuffer,
    /// Location of the on-disk cold file; created lazily on first spill.
    cold_path: PathBuf,
    /// Open handle to the cold file; `None` until the first spill occurs.
    cold_file: Option<File>,
    /// Read-only mmap for sampling. Created lazily on first sample after cold
    /// file is initialized.
    cold_mmap: Option<Mmap>,
    /// Mutable mmap for writes. Pre-allocated to full cold capacity so no
    /// remap is ever needed during push.
    cold_mmap_mut: Option<MmapMut>,
    /// Number of valid records in cold storage (saturates at cold_capacity).
    cold_count: usize,
    /// Next write position in cold ring-buffer (wraps at cold_capacity).
    cold_write_pos: usize,
    /// Number of f32 elements per observation.
    obs_dim: usize,
    /// Number of f32 elements per action.
    act_dim: usize,
    /// Maximum record count of the in-memory `hot` buffer.
    hot_capacity: usize,
    /// Cap on hot + cold records; cold capacity is `total - hot`.
    total_capacity: usize,
    /// Whether the read mmap needs to be refreshed before the next sample.
    mmap_stale: bool,
}
45
46impl std::fmt::Debug for MmapReplayBuffer {
47    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48        f.debug_struct("MmapReplayBuffer")
49            .field("obs_dim", &self.obs_dim)
50            .field("act_dim", &self.act_dim)
51            .field("hot_capacity", &self.hot_capacity)
52            .field("total_capacity", &self.total_capacity)
53            .field("cold_count", &self.cold_count)
54            .field("cold_write_pos", &self.cold_write_pos)
55            .field("cold_path", &self.cold_path)
56            .finish_non_exhaustive()
57    }
58}
59
60impl MmapReplayBuffer {
61    /// Create a new mmap-backed replay buffer.
62    ///
63    /// `hot_capacity` records are kept in memory. When exceeded, the oldest
64    /// hot records spill to the cold file at `cold_path`. The total number of
65    /// records stored (hot + cold) is capped at `total_capacity`.
66    pub fn new(
67        hot_capacity: usize,
68        total_capacity: usize,
69        obs_dim: usize,
70        act_dim: usize,
71        cold_path: PathBuf,
72    ) -> Result<Self, RloxError> {
73        if hot_capacity == 0 {
74            return Err(RloxError::BufferError(
75                "hot_capacity must be > 0".to_string(),
76            ));
77        }
78        if total_capacity < hot_capacity {
79            return Err(RloxError::BufferError(
80                "total_capacity must be >= hot_capacity".to_string(),
81            ));
82        }
83        if obs_dim == 0 {
84            return Err(RloxError::BufferError("obs_dim must be > 0".to_string()));
85        }
86        if act_dim == 0 {
87            return Err(RloxError::BufferError("act_dim must be > 0".to_string()));
88        }
89
90        Ok(Self {
91            hot: ReplayBuffer::new(hot_capacity, obs_dim, act_dim),
92            cold_path,
93            cold_file: None,
94            cold_mmap: None,
95            cold_mmap_mut: None,
96            cold_count: 0,
97            cold_write_pos: 0,
98            obs_dim,
99            act_dim,
100            hot_capacity,
101            total_capacity,
102            mmap_stale: false,
103        })
104    }
105
106    /// Byte size of a single serialized record (no padding).
107    fn record_byte_size(&self) -> usize {
108        // obs(f32 * obs_dim) + next_obs(f32 * obs_dim) + action(f32 * act_dim) + reward(f32) + terminated(u8) + truncated(u8)
109        (self.obs_dim * 2 + self.act_dim + 1) * 4 + 2
110    }
111
112    /// Total number of records stored (hot + cold).
113    pub fn len(&self) -> usize {
114        self.hot.len() + self.cold_count
115    }
116
117    /// Whether the buffer contains no records.
118    pub fn is_empty(&self) -> bool {
119        self.len() == 0
120    }
121
122    /// Push a record. If the hot buffer is full, the oldest hot record is
123    /// spilled to cold storage before the new record is inserted.
124    pub fn push(&mut self, record: ExperienceRecord) -> Result<(), RloxError> {
125        if record.obs.len() != self.obs_dim {
126            return Err(RloxError::ShapeMismatch {
127                expected: format!("obs_dim={}", self.obs_dim),
128                got: format!("obs.len()={}", record.obs.len()),
129            });
130        }
131        if record.next_obs.len() != self.obs_dim {
132            return Err(RloxError::ShapeMismatch {
133                expected: format!("obs_dim={}", self.obs_dim),
134                got: format!("next_obs.len()={}", record.next_obs.len()),
135            });
136        }
137        if record.action.len() != self.act_dim {
138            return Err(RloxError::ShapeMismatch {
139                expected: format!("act_dim={}", self.act_dim),
140                got: format!("action.len()={}", record.action.len()),
141            });
142        }
143
144        // If hot is full, spill the oldest hot record to cold before pushing.
145        if self.hot.len() == self.hot_capacity {
146            let oldest = self.read_oldest_hot_record();
147            self.write_to_cold(&oldest)?;
148        }
149
150        self.hot.push(record)?;
151        Ok(())
152    }
153
154    /// Push multiple transitions at once from flat arrays.
155    ///
156    /// `obs_batch` shape: `[n * obs_dim]`, `next_obs_batch`: same,
157    /// `actions_batch`: `[n * act_dim]`, others: `[n]`.
158    ///
159    /// # `terminated` / `truncated` convention
160    ///
161    /// These take `&[f32]` (not `bool`) for compatibility with numpy arrays
162    /// from the Python side. Non-zero values are treated as `true`.
163    /// This differs from [`push`](Self::push) which accepts native `bool`.
164    pub fn push_batch(
165        &mut self,
166        obs_batch: &[f32],
167        next_obs_batch: &[f32],
168        actions_batch: &[f32],
169        rewards: &[f32],
170        terminated: &[f32],
171        truncated: &[f32],
172    ) -> Result<(), RloxError> {
173        let n = rewards.len();
174        if obs_batch.len() != n * self.obs_dim
175            || next_obs_batch.len() != n * self.obs_dim
176            || actions_batch.len() != n * self.act_dim
177            || terminated.len() != n
178            || truncated.len() != n
179        {
180            return Err(RloxError::ShapeMismatch {
181                expected: format!("n={n}, obs_dim={}, act_dim={}", self.obs_dim, self.act_dim),
182                got: format!(
183                    "obs={}, next_obs={}, act={}, rew={}, term={}, trunc={}",
184                    obs_batch.len(),
185                    next_obs_batch.len(),
186                    actions_batch.len(),
187                    rewards.len(),
188                    terminated.len(),
189                    truncated.len()
190                ),
191            });
192        }
193        for i in 0..n {
194            let obs = obs_batch[i * self.obs_dim..(i + 1) * self.obs_dim].to_vec();
195            let next_obs = next_obs_batch[i * self.obs_dim..(i + 1) * self.obs_dim].to_vec();
196            let action = actions_batch[i * self.act_dim..(i + 1) * self.act_dim].to_vec();
197            let record = ExperienceRecord {
198                obs,
199                next_obs,
200                action,
201                reward: rewards[i],
202                terminated: terminated[i] != 0.0,
203                truncated: truncated[i] != 0.0,
204            };
205            self.push(record)?;
206        }
207        Ok(())
208    }
209
210    /// Sample `batch_size` records uniformly from hot + cold storage.
211    /// Uses ChaCha8Rng seeded with `seed` for deterministic reproducibility.
212    pub fn sample(&mut self, batch_size: usize, seed: u64) -> Result<SampledBatch, RloxError> {
213        let total = self.len();
214        if batch_size > total {
215            return Err(RloxError::BufferError(format!(
216                "batch_size {} > buffer len {}",
217                batch_size, total,
218            )));
219        }
220
221        let mut rng = ChaCha8Rng::seed_from_u64(seed);
222        let mut batch = SampledBatch::with_capacity(batch_size, self.obs_dim, self.act_dim);
223
224        let cold_len = self.cold_count;
225
226        // Refresh the read mmap if any writes occurred since last sample.
227        if cold_len > 0 {
228            self.refresh_read_mmap()?;
229        }
230
231        for _ in 0..batch_size {
232            let idx = rng.random_range(0..total);
233            if idx < cold_len {
234                // Read from cold mmap.
235                self.read_cold_record_into(idx, &mut batch);
236            } else {
237                // Read from hot. Hot indices are [0..hot_len).
238                let hot_idx = idx - cold_len;
239                self.read_hot_record_into(hot_idx, &mut batch);
240            }
241        }
242        batch.batch_size = batch_size;
243        Ok(batch)
244    }
245
246    /// Unmap the cold file and delete it from disk.
247    pub fn close(&mut self) -> Result<(), RloxError> {
248        // Drop mmaps first so the file can be removed.
249        self.cold_mmap.take();
250        self.cold_mmap_mut.take();
251        self.cold_file.take();
252        if self.cold_path.exists() {
253            std::fs::remove_file(&self.cold_path)?;
254        }
255        self.cold_count = 0;
256        self.cold_write_pos = 0;
257        Ok(())
258    }
259
260    // ---- private helpers ------------------------------------------------
261
262    /// Read the record at position `hot_idx` from the hot buffer into `batch`.
263    fn read_hot_record_into(&self, hot_idx: usize, batch: &mut SampledBatch) {
264        let (obs, next_obs, act, reward, terminated, truncated) = self.hot.get(hot_idx);
265        batch.observations.extend_from_slice(obs);
266        batch.next_observations.extend_from_slice(next_obs);
267        batch.actions.extend_from_slice(act);
268        batch.rewards.push(reward);
269        batch.terminated.push(terminated);
270        batch.truncated.push(truncated);
271    }
272
273    /// Read the oldest record from the hot ring buffer (the one at write_pos,
274    /// which is about to be overwritten).
275    fn read_oldest_hot_record(&self) -> ExperienceRecord {
276        let oldest_idx = self.hot.write_pos();
277        let (obs, next_obs, act, reward, terminated, truncated) = self.hot.get(oldest_idx);
278        ExperienceRecord {
279            obs: obs.to_vec(),
280            next_obs: next_obs.to_vec(),
281            action: act.to_vec(),
282            reward,
283            terminated,
284            truncated,
285        }
286    }
287
288    /// Write a record to the cold ring-buffer at `cold_write_pos`, then advance.
289    ///
290    /// Uses a pre-allocated `MmapMut` so no remap is needed per push.
291    /// The file is pre-allocated to full cold capacity via `ftruncate`.
292    fn write_to_cold(&mut self, record: &ExperienceRecord) -> Result<(), RloxError> {
293        self.ensure_cold_file()?;
294
295        let cold_capacity = self.total_capacity - self.hot_capacity;
296        let record_size = self.record_byte_size();
297
298        // Serialize record into the mmap directly.
299        let file_offset = HEADER_SIZE + self.cold_write_pos * record_size;
300        let mmap = self
301            .cold_mmap_mut
302            .as_mut()
303            .expect("cold_mmap_mut must be set after ensure_cold_file");
304
305        let dst = &mut mmap[file_offset..file_offset + record_size];
306        let mut pos = 0;
307        for &v in &record.obs {
308            dst[pos..pos + 4].copy_from_slice(&v.to_le_bytes());
309            pos += 4;
310        }
311        for &v in &record.next_obs {
312            dst[pos..pos + 4].copy_from_slice(&v.to_le_bytes());
313            pos += 4;
314        }
315        for &v in &record.action {
316            dst[pos..pos + 4].copy_from_slice(&v.to_le_bytes());
317            pos += 4;
318        }
319        dst[pos..pos + 4].copy_from_slice(&record.reward.to_le_bytes());
320        pos += 4;
321        dst[pos] = record.terminated as u8;
322        dst[pos + 1] = record.truncated as u8;
323
324        // Advance ring-buffer position.
325        self.cold_write_pos = (self.cold_write_pos + 1) % cold_capacity;
326        if self.cold_count < cold_capacity {
327            self.cold_count += 1;
328        }
329
330        // Update header in the mmap.
331        self.write_cold_header_mmap();
332
333        // Mark read mmap as stale so it refreshes before next sample.
334        self.mmap_stale = true;
335
336        Ok(())
337    }
338
339    /// Ensure the cold file is open, pre-allocated to full capacity, and mapped.
340    ///
341    /// The file is sized to `HEADER_SIZE + cold_capacity * record_byte_size`
342    /// via `ftruncate` at creation time. A single `MmapMut` covers the entire
343    /// file, eliminating per-push remap overhead.
344    fn ensure_cold_file(&mut self) -> Result<(), RloxError> {
345        if self.cold_file.is_some() {
346            return Ok(());
347        }
348
349        let file = OpenOptions::new()
350            .create(true)
351            .read(true)
352            .write(true)
353            .truncate(true)
354            .open(&self.cold_path)?;
355
356        let cold_capacity = self.total_capacity - self.hot_capacity;
357        let total_file_size = HEADER_SIZE + cold_capacity * self.record_byte_size();
358        file.set_len(total_file_size as u64)?;
359
360        // SAFETY: The file is exclusively owned by this buffer. No other
361        // process or thread accesses it. We write through `cold_mmap_mut`
362        // and read through a separate `cold_mmap` (refreshed lazily).
363        let mmap_mut = unsafe { MmapMut::map_mut(&file)? };
364
365        self.cold_file = Some(file);
366        self.cold_mmap_mut = Some(mmap_mut);
367
368        // Write initial header (cold_count = 0).
369        self.write_cold_header_mmap();
370
371        Ok(())
372    }
373
374    /// Write the 24-byte header directly into the mutable mmap.
375    fn write_cold_header_mmap(&mut self) {
376        let mmap = self
377            .cold_mmap_mut
378            .as_mut()
379            .expect("cold_mmap_mut must be set");
380        mmap[0..8].copy_from_slice(&(self.obs_dim as u64).to_le_bytes());
381        mmap[8..16].copy_from_slice(&(self.act_dim as u64).to_le_bytes());
382        mmap[16..24].copy_from_slice(&(self.cold_count as u64).to_le_bytes());
383    }
384
385    /// Refresh the read-only mmap from the file.
386    ///
387    /// Called lazily before sampling when writes have occurred since the
388    /// last refresh. Since the file is pre-allocated, this does not change
389    /// the file size — it just picks up the bytes written via `cold_mmap_mut`.
390    fn refresh_read_mmap(&mut self) -> Result<(), RloxError> {
391        if !self.mmap_stale && self.cold_mmap.is_some() {
392            return Ok(());
393        }
394        // Flush the mutable mmap so reads see the latest data.
395        if let Some(ref mmap_mut) = self.cold_mmap_mut {
396            mmap_mut.flush()?;
397        }
398        // Drop old read mmap.
399        self.cold_mmap.take();
400        let file = self.cold_file.as_ref().expect("cold_file must be open");
401        // SAFETY: The file is exclusively owned by this buffer. The mutable
402        // mmap has been flushed above, so all writes are visible.
403        let mmap = unsafe { Mmap::map(file)? };
404        self.cold_mmap = Some(mmap);
405        self.mmap_stale = false;
406        Ok(())
407    }
408
409    /// Read record at `idx` from the cold mmap into `batch`.
410    fn read_cold_record_into(&self, idx: usize, batch: &mut SampledBatch) {
411        let mmap = self.cold_mmap.as_ref().expect("cold_mmap must exist");
412        let record_size = self.record_byte_size();
413        let offset = HEADER_SIZE + idx * record_size;
414        let data = &mmap[offset..offset + record_size];
415
416        let obs_bytes = self.obs_dim * 4;
417        let act_bytes = self.act_dim * 4;
418
419        // Parse obs.
420        for i in 0..self.obs_dim {
421            let start = i * 4;
422            let val = f32::from_le_bytes([
423                data[start],
424                data[start + 1],
425                data[start + 2],
426                data[start + 3],
427            ]);
428            batch.observations.push(val);
429        }
430
431        // Parse next_obs.
432        let next_obs_base = obs_bytes;
433        for i in 0..self.obs_dim {
434            let start = next_obs_base + i * 4;
435            let val = f32::from_le_bytes([
436                data[start],
437                data[start + 1],
438                data[start + 2],
439                data[start + 3],
440            ]);
441            batch.next_observations.push(val);
442        }
443
444        // Parse actions.
445        let act_base = obs_bytes * 2;
446        for i in 0..self.act_dim {
447            let start = act_base + i * 4;
448            let val = f32::from_le_bytes([
449                data[start],
450                data[start + 1],
451                data[start + 2],
452                data[start + 3],
453            ]);
454            batch.actions.push(val);
455        }
456
457        // Parse reward.
458        let reward_base = obs_bytes * 2 + act_bytes;
459        let reward = f32::from_le_bytes([
460            data[reward_base],
461            data[reward_base + 1],
462            data[reward_base + 2],
463            data[reward_base + 3],
464        ]);
465        batch.rewards.push(reward);
466
467        // Parse terminated, truncated.
468        batch.terminated.push(data[reward_base + 4] != 0);
469        batch.truncated.push(data[reward_base + 5] != 0);
470    }
471}
472
473impl Drop for MmapReplayBuffer {
474    fn drop(&mut self) {
475        // Best-effort cleanup.
476        let _ = self.close();
477    }
478}
479
#[cfg(test)]
mod tests {
    use super::*;
    use crate::buffer::{sample_record, sample_record_multidim};

    /// Create a fresh temp path for a cold file.
    ///
    /// The `NamedTempFile` guard is returned (callers bind it as `_tmp`) so
    /// it stays alive for the duration of the test and the file is removed
    /// on drop if the buffer has not already deleted it. `MmapReplayBuffer`
    /// reopens the path itself with create+truncate, so sharing the path
    /// with the live guard is fine.
    fn temp_cold_path() -> (tempfile::NamedTempFile, PathBuf) {
        let tmp = tempfile::NamedTempFile::new().unwrap();
        let path = tmp.path().to_path_buf();
        (tmp, path)
    }

    #[test]
    fn test_mmap_buffer_new_is_empty() {
        let (_tmp, path) = temp_cold_path();
        let buf = MmapReplayBuffer::new(100, 1000, 4, 1, path).unwrap();
        assert_eq!(buf.len(), 0);
        assert!(buf.is_empty());
    }

    #[test]
    fn test_mmap_buffer_push_within_hot_capacity() {
        let (_tmp, path) = temp_cold_path();
        let mut buf = MmapReplayBuffer::new(100, 1000, 4, 1, path).unwrap();
        for _ in 0..50 {
            buf.push(sample_record(4)).unwrap();
        }
        assert_eq!(buf.len(), 50);
        // No cold records should exist.
        assert_eq!(buf.cold_count, 0);
    }

    #[test]
    fn test_mmap_buffer_push_exceeds_hot_spills_to_cold() {
        let (_tmp, path) = temp_cold_path();
        let mut buf = MmapReplayBuffer::new(10, 100, 4, 1, path).unwrap();

        // Push 15 records: 10 fit in hot, then pushes 11-15 spill records to cold.
        for i in 0..15 {
            let mut rec = sample_record(4);
            rec.reward = i as f32;
            buf.push(rec).unwrap();
        }

        // Hot should be full (10), cold should have 5 spilled records.
        assert_eq!(buf.hot.len(), 10);
        assert_eq!(buf.cold_count, 5);
        assert_eq!(buf.len(), 15);
    }

    #[test]
    fn test_mmap_buffer_sample_from_hot_only() {
        let (_tmp, path) = temp_cold_path();
        let mut buf = MmapReplayBuffer::new(100, 1000, 4, 1, path).unwrap();
        for i in 0..50 {
            let mut rec = sample_record(4);
            rec.reward = i as f32;
            buf.push(rec).unwrap();
        }
        let batch = buf.sample(10, 42).unwrap();
        assert_eq!(batch.batch_size, 10);
        assert_eq!(batch.observations.len(), 10 * 4);
        assert_eq!(batch.rewards.len(), 10);
    }

    #[test]
    fn test_mmap_buffer_sample_from_hot_and_cold() {
        let (_tmp, path) = temp_cold_path();
        let mut buf = MmapReplayBuffer::new(10, 100, 4, 1, path).unwrap();

        // Push 20 records with distinct rewards.
        for i in 0..20 {
            let mut rec = sample_record(4);
            rec.reward = i as f32;
            buf.push(rec).unwrap();
        }

        assert_eq!(buf.cold_count, 10);
        assert_eq!(buf.hot.len(), 10);

        // Sample a large batch -- should draw from both hot and cold.
        let batch = buf.sample(20, 99).unwrap();
        assert_eq!(batch.batch_size, 20);
        assert_eq!(batch.rewards.len(), 20);

        // Verify we see some rewards from the cold range [0..10)
        // and some from the hot range [10..20).
        let has_cold = batch.rewards.iter().any(|&r| r < 10.0);
        let has_hot = batch.rewards.iter().any(|&r| r >= 10.0);
        assert!(has_cold, "expected some samples from cold storage");
        assert!(has_hot, "expected some samples from hot storage");
    }

    #[test]
    fn test_mmap_buffer_total_count() {
        let (_tmp, path) = temp_cold_path();
        let mut buf = MmapReplayBuffer::new(5, 50, 4, 1, path).unwrap();
        for _ in 0..30 {
            buf.push(sample_record(4)).unwrap();
        }
        // hot: 5, cold: 25
        assert_eq!(buf.len(), 30);
        assert_eq!(buf.hot.len(), 5);
        assert_eq!(buf.cold_count, 25);
    }

    #[test]
    fn test_mmap_buffer_deterministic_sampling() {
        // Two independent buffers fed identical data and sampled with the
        // same seed must return identical batches (ChaCha8 determinism).
        let (_tmp, path1) = temp_cold_path();
        let (_tmp2, path2) = temp_cold_path();
        let mut buf1 = MmapReplayBuffer::new(10, 100, 4, 1, path1).unwrap();
        let mut buf2 = MmapReplayBuffer::new(10, 100, 4, 1, path2).unwrap();

        for i in 0..20 {
            let mut rec = sample_record(4);
            rec.reward = i as f32;
            buf1.push(rec.clone()).unwrap();
            buf2.push(rec).unwrap();
        }

        let b1 = buf1.sample(15, 42).unwrap();
        let b2 = buf2.sample(15, 42).unwrap();

        assert_eq!(b1.observations, b2.observations);
        assert_eq!(b1.rewards, b2.rewards);
        assert_eq!(b1.terminated, b2.terminated);
    }

    #[test]
    fn test_mmap_buffer_cleanup_removes_file() {
        let (_tmp, path) = temp_cold_path();
        let cold_path = path.clone();
        let mut buf = MmapReplayBuffer::new(5, 50, 4, 1, path).unwrap();

        // Push enough to create the cold file.
        for _ in 0..10 {
            buf.push(sample_record(4)).unwrap();
        }
        assert!(cold_path.exists(), "cold file should exist after spill");

        buf.close().unwrap();
        assert!(
            !cold_path.exists(),
            "cold file should be removed after close()"
        );
    }

    #[test]
    fn test_mmap_buffer_large_obs_dim() {
        let (_tmp, path) = temp_cold_path();
        let obs_dim = 28224; // Atari-scale (84*84*4)
        let act_dim = 1;
        let mut buf = MmapReplayBuffer::new(5, 20, obs_dim, act_dim, path).unwrap();

        for i in 0..10 {
            let mut rec = sample_record_multidim(obs_dim, act_dim);
            rec.reward = i as f32;
            // Set first obs element to distinguish records.
            rec.obs[0] = i as f32;
            buf.push(rec).unwrap();
        }

        assert_eq!(buf.len(), 10);
        assert_eq!(buf.hot.len(), 5);
        assert_eq!(buf.cold_count, 5);

        // Sample and verify obs_dim is preserved.
        let batch = buf.sample(5, 42).unwrap();
        assert_eq!(batch.observations.len(), 5 * obs_dim);
        assert_eq!(batch.obs_dim, obs_dim);
    }

    // ---- push_batch tests ----

    #[test]
    fn test_push_batch_fills_hot_correctly() {
        let (_tmp, path) = temp_cold_path();
        let mut buf = MmapReplayBuffer::new(100, 1000, 4, 1, path).unwrap();

        let n = 10;
        let obs: Vec<f32> = (0..n * 4).map(|i| i as f32).collect();
        let next_obs: Vec<f32> = (0..n * 4).map(|i| i as f32 + 100.0).collect();
        let actions: Vec<f32> = (0..n).map(|i| i as f32 * 0.1).collect();
        let rewards: Vec<f32> = (0..n).map(|i| i as f32).collect();
        let terminated = vec![0.0f32; n];
        let truncated = vec![0.0f32; n];

        buf.push_batch(&obs, &next_obs, &actions, &rewards, &terminated, &truncated)
            .unwrap();

        assert_eq!(buf.len(), 10);
        assert_eq!(buf.cold_count, 0);

        // Sample all and verify rewards round-trip.
        let batch = buf.sample(10, 42).unwrap();
        assert_eq!(batch.batch_size, 10);
        assert_eq!(batch.rewards.len(), 10);
    }

    #[test]
    fn test_push_batch_triggers_spill_to_cold() {
        let (_tmp, path) = temp_cold_path();
        let mut buf = MmapReplayBuffer::new(5, 100, 2, 1, path).unwrap();

        let n = 12;
        let obs: Vec<f32> = (0..n * 2).map(|i| i as f32).collect();
        let next_obs: Vec<f32> = (0..n * 2).map(|i| i as f32 + 100.0).collect();
        let actions: Vec<f32> = (0..n).map(|i| i as f32 * 0.1).collect();
        let rewards: Vec<f32> = (0..n).map(|i| i as f32).collect();
        let terminated = vec![0.0f32; n];
        let truncated = vec![0.0f32; n];

        buf.push_batch(&obs, &next_obs, &actions, &rewards, &terminated, &truncated)
            .unwrap();

        assert_eq!(buf.hot.len(), 5);
        assert_eq!(buf.cold_count, 7);
        assert_eq!(buf.len(), 12);
    }

    #[test]
    fn test_push_batch_shape_mismatch() {
        let (_tmp, path) = temp_cold_path();
        let mut buf = MmapReplayBuffer::new(100, 1000, 4, 1, path).unwrap();

        // Wrong obs length.
        let result = buf.push_batch(
            &[1.0, 2.0, 3.0], // 3 instead of 4
            &[1.0, 2.0, 3.0, 4.0],
            &[0.0],
            &[1.0],
            &[0.0],
            &[0.0],
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_push_batch_terminated_truncated_flags() {
        let (_tmp, path) = temp_cold_path();
        let mut buf = MmapReplayBuffer::new(100, 1000, 2, 1, path).unwrap();

        let obs = vec![1.0, 2.0, 3.0, 4.0]; // 2 records, obs_dim=2
        let next_obs = vec![5.0, 6.0, 7.0, 8.0];
        let actions = vec![0.0, 1.0];
        let rewards = vec![1.0, 2.0];
        let terminated = vec![1.0, 0.0]; // first terminated
        let truncated = vec![0.0, 1.0]; // second truncated

        buf.push_batch(&obs, &next_obs, &actions, &rewards, &terminated, &truncated)
            .unwrap();

        let batch = buf.sample(2, 42).unwrap();
        // Both records should be present; verify bools are meaningful.
        assert_eq!(batch.terminated.len(), 2);
        assert_eq!(batch.truncated.len(), 2);
    }

    // ---- cold ring-buffer eviction tests ----

    #[test]
    fn test_cold_ring_buffer_eviction_overwrites_oldest() {
        let (_tmp, path) = temp_cold_path();
        // hot_capacity=3, total_capacity=6 => cold_capacity=3
        let mut buf = MmapReplayBuffer::new(3, 6, 2, 1, path).unwrap();

        // Push 10 records. After 6, cold is full (3 hot + 3 cold).
        // Records 7-10 should evict oldest cold entries via ring-buffer.
        for i in 0..10 {
            let rec = ExperienceRecord {
                obs: vec![i as f32; 2],
                next_obs: vec![i as f32 + 100.0; 2],
                action: vec![i as f32 * 0.1],
                reward: i as f32,
                terminated: false,
                truncated: false,
            };
            buf.push(rec).unwrap();
        }

        // Total should be capped at total_capacity=6.
        assert_eq!(buf.len(), 6);
        assert_eq!(buf.hot.len(), 3);
        assert_eq!(buf.cold_count, 3);

        // Sample all 6 records. The cold ring-buffer should contain records
        // with rewards from the spill sequence, not the very oldest ones.
        let batch = buf.sample(6, 42).unwrap();
        assert_eq!(batch.batch_size, 6);
        assert_eq!(batch.rewards.len(), 6);

        // Rewards 0, 1, 2 were the earliest cold entries but should have been
        // overwritten. Hot has the last 3 pushed (7, 8, 9).
        // Cold should have the 3 most recent spills.
        for &r in &batch.rewards {
            assert!(
                r >= 4.0,
                "expected only recent records (reward >= 4.0), got {r}"
            );
        }
    }

    #[test]
    fn test_cold_eviction_cold_write_pos_wraps() {
        let (_tmp, path) = temp_cold_path();
        // hot_capacity=2, total_capacity=4 => cold_capacity=2
        let mut buf = MmapReplayBuffer::new(2, 4, 1, 1, path).unwrap();

        // Push 8 records to force multiple cold wrap-arounds.
        for i in 0..8 {
            let rec = ExperienceRecord {
                obs: vec![i as f32],
                next_obs: vec![i as f32 + 10.0],
                action: vec![0.0],
                reward: i as f32,
                terminated: false,
                truncated: false,
            };
            buf.push(rec).unwrap();
        }

        assert_eq!(buf.len(), 4);
        assert_eq!(buf.cold_count, 2);
        // cold_write_pos should have wrapped: 6 spills into capacity 2 => pos 0
        assert_eq!(buf.cold_write_pos, 0);
    }

    #[test]
    fn test_sampling_after_cold_eviction_returns_valid_data() {
        let (_tmp, path) = temp_cold_path();
        // hot=5, total=8, cold_capacity=3
        let mut buf = MmapReplayBuffer::new(5, 8, 3, 1, path).unwrap();

        // Push 20 records to cause many evictions.
        for i in 0..20 {
            let rec = ExperienceRecord {
                obs: vec![i as f32; 3],
                next_obs: vec![i as f32 + 0.5; 3],
                action: vec![i as f32 * 0.01],
                reward: i as f32,
                terminated: i % 3 == 0,
                truncated: i % 5 == 0,
            };
            buf.push(rec).unwrap();
        }

        assert_eq!(buf.len(), 8);

        // Sampling should not panic and data should be internally consistent.
        let batch = buf.sample(8, 123).unwrap();
        assert_eq!(batch.batch_size, 8);
        assert_eq!(batch.observations.len(), 8 * 3);
        assert_eq!(batch.next_observations.len(), 8 * 3);
        assert_eq!(batch.actions.len(), 8);
        assert_eq!(batch.rewards.len(), 8);

        // Each sampled obs should match its reward (obs was [r; 3]).
        for i in 0..8 {
            let r = batch.rewards[i];
            let obs_slice = &batch.observations[i * 3..(i + 1) * 3];
            for &v in obs_slice {
                assert!((v - r).abs() < 1e-6, "obs {v} should match reward {r}");
            }
        }
    }

    #[test]
    fn test_push_batch_then_sample_consistency() {
        let (_tmp, path) = temp_cold_path();
        let mut buf = MmapReplayBuffer::new(5, 15, 2, 1, path).unwrap();

        // Push 3 batches of 5 to fill past hot and into cold.
        for batch_idx in 0..3 {
            let base = batch_idx as f32 * 5.0;
            let n = 5;
            let obs: Vec<f32> = (0..n)
                .flat_map(|i| {
                    let v = base + i as f32;
                    vec![v, v + 0.1]
                })
                .collect();
            let next_obs: Vec<f32> = (0..n)
                .flat_map(|i| {
                    let v = base + i as f32 + 100.0;
                    vec![v, v + 0.1]
                })
                .collect();
            let actions: Vec<f32> = (0..n).map(|i| (base + i as f32) * 0.01).collect();
            let rewards: Vec<f32> = (0..n).map(|i| base + i as f32).collect();
            let terminated = vec![0.0f32; n];
            let truncated = vec![0.0f32; n];
            buf.push_batch(&obs, &next_obs, &actions, &rewards, &terminated, &truncated)
                .unwrap();
        }

        assert_eq!(buf.len(), 15);
        assert_eq!(buf.hot.len(), 5);
        assert_eq!(buf.cold_count, 10);

        // Sample and check structural validity.
        let batch = buf.sample(15, 7).unwrap();
        assert_eq!(batch.batch_size, 15);
        assert_eq!(batch.observations.len(), 15 * 2);

        // Each observation pair should match: obs[0] and obs[1] differ by 0.1.
        for i in 0..15 {
            let o0 = batch.observations[i * 2];
            let o1 = batch.observations[i * 2 + 1];
            assert!(
                (o1 - o0 - 0.1).abs() < 1e-5,
                "obs pair mismatch at sample {i}: {o0}, {o1}"
            );
        }
    }
}
895}