1use std::fs::{File, OpenOptions};
2use std::path::PathBuf;
3
4use memmap2::{Mmap, MmapMut};
5use rand::Rng;
6use rand::SeedableRng;
7use rand_chacha::ChaCha8Rng;
8
9use crate::error::RloxError;
10
11use super::ringbuf::{ReplayBuffer, SampledBatch};
12use super::ExperienceRecord;
13
/// Byte size of the cold-file header: three little-endian `u64`s
/// (obs_dim, act_dim, cold_count) — written by `write_cold_header_mmap`.
const HEADER_SIZE: usize = 24;
pub struct MmapReplayBuffer {
    /// In-memory ring buffer holding the most recent experiences.
    hot: ReplayBuffer,
    /// Filesystem path of the cold-storage spill file.
    cold_path: PathBuf,
    /// Backing file for cold storage; `None` until the first spill.
    cold_file: Option<File>,
    /// Read-only mapping used while sampling; rebuilt when `mmap_stale`.
    cold_mmap: Option<Mmap>,
    /// Writable mapping used when spilling records into the cold file.
    cold_mmap_mut: Option<MmapMut>,
    /// Number of valid records currently in cold storage.
    cold_count: usize,
    /// Next ring slot to write in the cold file (wraps at cold capacity).
    cold_write_pos: usize,
    /// Length of each observation vector.
    obs_dim: usize,
    /// Length of each action vector.
    act_dim: usize,
    /// Maximum number of records held in the hot tier.
    hot_capacity: usize,
    /// Maximum number of records across hot + cold tiers.
    total_capacity: usize,
    /// True when the writable mapping holds data newer than `cold_mmap`.
    mmap_stale: bool,
}
45
// Manual Debug impl: prints only the small configuration/counter fields and
// skips the hot buffer contents and the file/mmap handles (presumably to keep
// output small — the handles carry no useful diagnostic state here).
impl std::fmt::Debug for MmapReplayBuffer {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("MmapReplayBuffer")
            .field("obs_dim", &self.obs_dim)
            .field("act_dim", &self.act_dim)
            .field("hot_capacity", &self.hot_capacity)
            .field("total_capacity", &self.total_capacity)
            .field("cold_count", &self.cold_count)
            .field("cold_write_pos", &self.cold_write_pos)
            .field("cold_path", &self.cold_path)
            .finish_non_exhaustive()
    }
}
59
60impl MmapReplayBuffer {
61 pub fn new(
67 hot_capacity: usize,
68 total_capacity: usize,
69 obs_dim: usize,
70 act_dim: usize,
71 cold_path: PathBuf,
72 ) -> Result<Self, RloxError> {
73 if hot_capacity == 0 {
74 return Err(RloxError::BufferError(
75 "hot_capacity must be > 0".to_string(),
76 ));
77 }
78 if total_capacity < hot_capacity {
79 return Err(RloxError::BufferError(
80 "total_capacity must be >= hot_capacity".to_string(),
81 ));
82 }
83 if obs_dim == 0 {
84 return Err(RloxError::BufferError("obs_dim must be > 0".to_string()));
85 }
86 if act_dim == 0 {
87 return Err(RloxError::BufferError("act_dim must be > 0".to_string()));
88 }
89
90 Ok(Self {
91 hot: ReplayBuffer::new(hot_capacity, obs_dim, act_dim),
92 cold_path,
93 cold_file: None,
94 cold_mmap: None,
95 cold_mmap_mut: None,
96 cold_count: 0,
97 cold_write_pos: 0,
98 obs_dim,
99 act_dim,
100 hot_capacity,
101 total_capacity,
102 mmap_stale: false,
103 })
104 }
105
106 fn record_byte_size(&self) -> usize {
108 (self.obs_dim * 2 + self.act_dim + 1) * 4 + 2
110 }
111
112 pub fn len(&self) -> usize {
114 self.hot.len() + self.cold_count
115 }
116
117 pub fn is_empty(&self) -> bool {
119 self.len() == 0
120 }
121
122 pub fn push(&mut self, record: ExperienceRecord) -> Result<(), RloxError> {
125 if record.obs.len() != self.obs_dim {
126 return Err(RloxError::ShapeMismatch {
127 expected: format!("obs_dim={}", self.obs_dim),
128 got: format!("obs.len()={}", record.obs.len()),
129 });
130 }
131 if record.next_obs.len() != self.obs_dim {
132 return Err(RloxError::ShapeMismatch {
133 expected: format!("obs_dim={}", self.obs_dim),
134 got: format!("next_obs.len()={}", record.next_obs.len()),
135 });
136 }
137 if record.action.len() != self.act_dim {
138 return Err(RloxError::ShapeMismatch {
139 expected: format!("act_dim={}", self.act_dim),
140 got: format!("action.len()={}", record.action.len()),
141 });
142 }
143
144 if self.hot.len() == self.hot_capacity {
146 let oldest = self.read_oldest_hot_record();
147 self.write_to_cold(&oldest)?;
148 }
149
150 self.hot.push(record)?;
151 Ok(())
152 }
153
154 pub fn push_batch(
165 &mut self,
166 obs_batch: &[f32],
167 next_obs_batch: &[f32],
168 actions_batch: &[f32],
169 rewards: &[f32],
170 terminated: &[f32],
171 truncated: &[f32],
172 ) -> Result<(), RloxError> {
173 let n = rewards.len();
174 if obs_batch.len() != n * self.obs_dim
175 || next_obs_batch.len() != n * self.obs_dim
176 || actions_batch.len() != n * self.act_dim
177 || terminated.len() != n
178 || truncated.len() != n
179 {
180 return Err(RloxError::ShapeMismatch {
181 expected: format!("n={n}, obs_dim={}, act_dim={}", self.obs_dim, self.act_dim),
182 got: format!(
183 "obs={}, next_obs={}, act={}, rew={}, term={}, trunc={}",
184 obs_batch.len(),
185 next_obs_batch.len(),
186 actions_batch.len(),
187 rewards.len(),
188 terminated.len(),
189 truncated.len()
190 ),
191 });
192 }
193 for i in 0..n {
194 let obs = obs_batch[i * self.obs_dim..(i + 1) * self.obs_dim].to_vec();
195 let next_obs = next_obs_batch[i * self.obs_dim..(i + 1) * self.obs_dim].to_vec();
196 let action = actions_batch[i * self.act_dim..(i + 1) * self.act_dim].to_vec();
197 let record = ExperienceRecord {
198 obs,
199 next_obs,
200 action,
201 reward: rewards[i],
202 terminated: terminated[i] != 0.0,
203 truncated: truncated[i] != 0.0,
204 };
205 self.push(record)?;
206 }
207 Ok(())
208 }
209
210 pub fn sample(&mut self, batch_size: usize, seed: u64) -> Result<SampledBatch, RloxError> {
213 let total = self.len();
214 if batch_size > total {
215 return Err(RloxError::BufferError(format!(
216 "batch_size {} > buffer len {}",
217 batch_size, total,
218 )));
219 }
220
221 let mut rng = ChaCha8Rng::seed_from_u64(seed);
222 let mut batch = SampledBatch::with_capacity(batch_size, self.obs_dim, self.act_dim);
223
224 let cold_len = self.cold_count;
225
226 if cold_len > 0 {
228 self.refresh_read_mmap()?;
229 }
230
231 for _ in 0..batch_size {
232 let idx = rng.random_range(0..total);
233 if idx < cold_len {
234 self.read_cold_record_into(idx, &mut batch);
236 } else {
237 let hot_idx = idx - cold_len;
239 self.read_hot_record_into(hot_idx, &mut batch);
240 }
241 }
242 batch.batch_size = batch_size;
243 Ok(batch)
244 }
245
246 pub fn close(&mut self) -> Result<(), RloxError> {
248 self.cold_mmap.take();
250 self.cold_mmap_mut.take();
251 self.cold_file.take();
252 if self.cold_path.exists() {
253 std::fs::remove_file(&self.cold_path)?;
254 }
255 self.cold_count = 0;
256 self.cold_write_pos = 0;
257 Ok(())
258 }
259
260 fn read_hot_record_into(&self, hot_idx: usize, batch: &mut SampledBatch) {
264 let (obs, next_obs, act, reward, terminated, truncated) = self.hot.get(hot_idx);
265 batch.observations.extend_from_slice(obs);
266 batch.next_observations.extend_from_slice(next_obs);
267 batch.actions.extend_from_slice(act);
268 batch.rewards.push(reward);
269 batch.terminated.push(terminated);
270 batch.truncated.push(truncated);
271 }
272
    /// Copies the hot buffer's oldest record into an owned `ExperienceRecord`.
    ///
    /// NOTE(review): this assumes the hot ring is full, so `write_pos()` —
    /// the slot about to be overwritten next — is the oldest entry, and that
    /// `hot.get` indexes physical slots. The only caller (`push`) invokes it
    /// when `hot.len() == hot_capacity`; confirm against `ReplayBuffer`.
    fn read_oldest_hot_record(&self) -> ExperienceRecord {
        let oldest_idx = self.hot.write_pos();
        let (obs, next_obs, act, reward, terminated, truncated) = self.hot.get(oldest_idx);
        ExperienceRecord {
            obs: obs.to_vec(),
            next_obs: next_obs.to_vec(),
            action: act.to_vec(),
            reward,
            terminated,
            truncated,
        }
    }
287
288 fn write_to_cold(&mut self, record: &ExperienceRecord) -> Result<(), RloxError> {
293 self.ensure_cold_file()?;
294
295 let cold_capacity = self.total_capacity - self.hot_capacity;
296 let record_size = self.record_byte_size();
297
298 let file_offset = HEADER_SIZE + self.cold_write_pos * record_size;
300 let mmap = self
301 .cold_mmap_mut
302 .as_mut()
303 .expect("cold_mmap_mut must be set after ensure_cold_file");
304
305 let dst = &mut mmap[file_offset..file_offset + record_size];
306 let mut pos = 0;
307 for &v in &record.obs {
308 dst[pos..pos + 4].copy_from_slice(&v.to_le_bytes());
309 pos += 4;
310 }
311 for &v in &record.next_obs {
312 dst[pos..pos + 4].copy_from_slice(&v.to_le_bytes());
313 pos += 4;
314 }
315 for &v in &record.action {
316 dst[pos..pos + 4].copy_from_slice(&v.to_le_bytes());
317 pos += 4;
318 }
319 dst[pos..pos + 4].copy_from_slice(&record.reward.to_le_bytes());
320 pos += 4;
321 dst[pos] = record.terminated as u8;
322 dst[pos + 1] = record.truncated as u8;
323
324 self.cold_write_pos = (self.cold_write_pos + 1) % cold_capacity;
326 if self.cold_count < cold_capacity {
327 self.cold_count += 1;
328 }
329
330 self.write_cold_header_mmap();
332
333 self.mmap_stale = true;
335
336 Ok(())
337 }
338
    /// Lazily creates the cold spill file and its writable mapping.
    ///
    /// Idempotent: returns immediately if the file is already open. The file
    /// is sized up-front for the full cold capacity plus the header, so later
    /// record writes never need to grow the mapping.
    fn ensure_cold_file(&mut self) -> Result<(), RloxError> {
        if self.cold_file.is_some() {
            return Ok(());
        }

        // `truncate(true)` discards any stale file left at this path.
        let file = OpenOptions::new()
            .create(true)
            .read(true)
            .write(true)
            .truncate(true)
            .open(&self.cold_path)?;

        let cold_capacity = self.total_capacity - self.hot_capacity;
        let total_file_size = HEADER_SIZE + cold_capacity * self.record_byte_size();
        file.set_len(total_file_size as u64)?;

        // SAFETY: the file was just created and sized by this process and is
        // only mutated through this struct. As with all file-backed mappings,
        // soundness still assumes no external process truncates or remaps the
        // file concurrently (memmap2's documented caveat).
        let mmap_mut = unsafe { MmapMut::map_mut(&file)? };

        self.cold_file = Some(file);
        self.cold_mmap_mut = Some(mmap_mut);

        // Stamp the header so the file is self-describing from the start.
        self.write_cold_header_mmap();

        Ok(())
    }
373
374 fn write_cold_header_mmap(&mut self) {
376 let mmap = self
377 .cold_mmap_mut
378 .as_mut()
379 .expect("cold_mmap_mut must be set");
380 mmap[0..8].copy_from_slice(&(self.obs_dim as u64).to_le_bytes());
381 mmap[8..16].copy_from_slice(&(self.act_dim as u64).to_le_bytes());
382 mmap[16..24].copy_from_slice(&(self.cold_count as u64).to_le_bytes());
383 }
384
    /// Rebuilds the read-only mapping when writes have happened since the
    /// last refresh (`mmap_stale`) or when no read mapping exists yet.
    fn refresh_read_mmap(&mut self) -> Result<(), RloxError> {
        if !self.mmap_stale && self.cold_mmap.is_some() {
            return Ok(());
        }
        // Flush pending writes so the fresh read mapping observes them.
        if let Some(ref mmap_mut) = self.cold_mmap_mut {
            mmap_mut.flush()?;
        }
        self.cold_mmap.take();
        let file = self.cold_file.as_ref().expect("cold_file must be open");
        // SAFETY: read-only mapping of a file this struct owns. NOTE(review):
        // this coexists with the writable mapping of the same file; sampling
        // never writes while reading, but confirm this aliasing is acceptable
        // under memmap2's safety contract.
        let mmap = unsafe { Mmap::map(file)? };
        self.cold_mmap = Some(mmap);
        self.mmap_stale = false;
        Ok(())
    }
408
409 fn read_cold_record_into(&self, idx: usize, batch: &mut SampledBatch) {
411 let mmap = self.cold_mmap.as_ref().expect("cold_mmap must exist");
412 let record_size = self.record_byte_size();
413 let offset = HEADER_SIZE + idx * record_size;
414 let data = &mmap[offset..offset + record_size];
415
416 let obs_bytes = self.obs_dim * 4;
417 let act_bytes = self.act_dim * 4;
418
419 for i in 0..self.obs_dim {
421 let start = i * 4;
422 let val = f32::from_le_bytes([
423 data[start],
424 data[start + 1],
425 data[start + 2],
426 data[start + 3],
427 ]);
428 batch.observations.push(val);
429 }
430
431 let next_obs_base = obs_bytes;
433 for i in 0..self.obs_dim {
434 let start = next_obs_base + i * 4;
435 let val = f32::from_le_bytes([
436 data[start],
437 data[start + 1],
438 data[start + 2],
439 data[start + 3],
440 ]);
441 batch.next_observations.push(val);
442 }
443
444 let act_base = obs_bytes * 2;
446 for i in 0..self.act_dim {
447 let start = act_base + i * 4;
448 let val = f32::from_le_bytes([
449 data[start],
450 data[start + 1],
451 data[start + 2],
452 data[start + 3],
453 ]);
454 batch.actions.push(val);
455 }
456
457 let reward_base = obs_bytes * 2 + act_bytes;
459 let reward = f32::from_le_bytes([
460 data[reward_base],
461 data[reward_base + 1],
462 data[reward_base + 2],
463 data[reward_base + 3],
464 ]);
465 batch.rewards.push(reward);
466
467 batch.terminated.push(data[reward_base + 4] != 0);
469 batch.truncated.push(data[reward_base + 5] != 0);
470 }
471}
472
impl Drop for MmapReplayBuffer {
    fn drop(&mut self) {
        // Best-effort cleanup of the spill file; errors cannot be surfaced
        // from `drop`, so they are deliberately ignored.
        let _ = self.close();
    }
}
479
480#[cfg(test)]
481mod tests {
482 use super::*;
483 use crate::buffer::{sample_record, sample_record_multidim};
484
    /// Returns a temp file plus its path; the `NamedTempFile` guard must be
    /// kept alive by the caller so the path stays reserved for the test.
    fn temp_cold_path() -> (tempfile::NamedTempFile, PathBuf) {
        let tmp = tempfile::NamedTempFile::new().unwrap();
        let path = tmp.path().to_path_buf();
        (tmp, path)
    }
492
    // A freshly constructed buffer reports zero length and emptiness.
    #[test]
    fn test_mmap_buffer_new_is_empty() {
        let (_tmp, path) = temp_cold_path();
        let buf = MmapReplayBuffer::new(100, 1000, 4, 1, path).unwrap();
        assert_eq!(buf.len(), 0);
        assert!(buf.is_empty());
    }
500
    // Pushes below hot capacity stay entirely in the hot tier.
    #[test]
    fn test_mmap_buffer_push_within_hot_capacity() {
        let (_tmp, path) = temp_cold_path();
        let mut buf = MmapReplayBuffer::new(100, 1000, 4, 1, path).unwrap();
        for _ in 0..50 {
            buf.push(sample_record(4)).unwrap();
        }
        assert_eq!(buf.len(), 50);
        assert_eq!(buf.cold_count, 0);
    }
512
    // Exceeding the hot capacity spills exactly the overflow into cold.
    #[test]
    fn test_mmap_buffer_push_exceeds_hot_spills_to_cold() {
        let (_tmp, path) = temp_cold_path();
        let mut buf = MmapReplayBuffer::new(10, 100, 4, 1, path).unwrap();

        for i in 0..15 {
            let mut rec = sample_record(4);
            rec.reward = i as f32;
            buf.push(rec).unwrap();
        }

        assert_eq!(buf.hot.len(), 10);
        assert_eq!(buf.cold_count, 5);
        assert_eq!(buf.len(), 15);
    }
530
    // Sampling works and has the right shapes when only the hot tier is used.
    #[test]
    fn test_mmap_buffer_sample_from_hot_only() {
        let (_tmp, path) = temp_cold_path();
        let mut buf = MmapReplayBuffer::new(100, 1000, 4, 1, path).unwrap();
        for i in 0..50 {
            let mut rec = sample_record(4);
            rec.reward = i as f32;
            buf.push(rec).unwrap();
        }
        let batch = buf.sample(10, 42).unwrap();
        assert_eq!(batch.batch_size, 10);
        assert_eq!(batch.observations.len(), 10 * 4);
        assert_eq!(batch.rewards.len(), 10);
    }
545
    // A full-buffer sample draws from both tiers (rewards < 10 were spilled
    // to cold, >= 10 are still hot).
    #[test]
    fn test_mmap_buffer_sample_from_hot_and_cold() {
        let (_tmp, path) = temp_cold_path();
        let mut buf = MmapReplayBuffer::new(10, 100, 4, 1, path).unwrap();

        for i in 0..20 {
            let mut rec = sample_record(4);
            rec.reward = i as f32;
            buf.push(rec).unwrap();
        }

        assert_eq!(buf.cold_count, 10);
        assert_eq!(buf.hot.len(), 10);

        let batch = buf.sample(20, 99).unwrap();
        assert_eq!(batch.batch_size, 20);
        assert_eq!(batch.rewards.len(), 20);

        let has_cold = batch.rewards.iter().any(|&r| r < 10.0);
        let has_hot = batch.rewards.iter().any(|&r| r >= 10.0);
        assert!(has_cold, "expected some samples from cold storage");
        assert!(has_hot, "expected some samples from hot storage");
    }
573
    // len() is the sum of hot and cold record counts.
    #[test]
    fn test_mmap_buffer_total_count() {
        let (_tmp, path) = temp_cold_path();
        let mut buf = MmapReplayBuffer::new(5, 50, 4, 1, path).unwrap();
        for _ in 0..30 {
            buf.push(sample_record(4)).unwrap();
        }
        assert_eq!(buf.len(), 30);
        assert_eq!(buf.hot.len(), 5);
        assert_eq!(buf.cold_count, 25);
    }
586
    // Two buffers fed identical data sample identically for the same seed.
    #[test]
    fn test_mmap_buffer_deterministic_sampling() {
        let (_tmp, path1) = temp_cold_path();
        let (_tmp2, path2) = temp_cold_path();
        let mut buf1 = MmapReplayBuffer::new(10, 100, 4, 1, path1).unwrap();
        let mut buf2 = MmapReplayBuffer::new(10, 100, 4, 1, path2).unwrap();

        for i in 0..20 {
            let mut rec = sample_record(4);
            rec.reward = i as f32;
            buf1.push(rec.clone()).unwrap();
            buf2.push(rec).unwrap();
        }

        let b1 = buf1.sample(15, 42).unwrap();
        let b2 = buf2.sample(15, 42).unwrap();

        assert_eq!(b1.observations, b2.observations);
        assert_eq!(b1.rewards, b2.rewards);
        assert_eq!(b1.terminated, b2.terminated);
    }
608
    // close() deletes the on-disk spill file.
    #[test]
    fn test_mmap_buffer_cleanup_removes_file() {
        let (_tmp, path) = temp_cold_path();
        let cold_path = path.clone();
        let mut buf = MmapReplayBuffer::new(5, 50, 4, 1, path).unwrap();

        for _ in 0..10 {
            buf.push(sample_record(4)).unwrap();
        }
        assert!(cold_path.exists(), "cold file should exist after spill");

        buf.close().unwrap();
        assert!(
            !cold_path.exists(),
            "cold file should be removed after close()"
        );
    }
627
    // Exercises a large (image-sized) observation dimension end to end.
    #[test]
    fn test_mmap_buffer_large_obs_dim() {
        let (_tmp, path) = temp_cold_path();
        let obs_dim = 28224;
        let act_dim = 1;
        let mut buf = MmapReplayBuffer::new(5, 20, obs_dim, act_dim, path).unwrap();

        for i in 0..10 {
            let mut rec = sample_record_multidim(obs_dim, act_dim);
            rec.reward = i as f32;
            rec.obs[0] = i as f32;
            buf.push(rec).unwrap();
        }

        assert_eq!(buf.len(), 10);
        assert_eq!(buf.hot.len(), 5);
        assert_eq!(buf.cold_count, 5);

        let batch = buf.sample(5, 42).unwrap();
        assert_eq!(batch.observations.len(), 5 * obs_dim);
        assert_eq!(batch.obs_dim, obs_dim);
    }
652
    // push_batch with n records below hot capacity stores all of them hot.
    #[test]
    fn test_push_batch_fills_hot_correctly() {
        let (_tmp, path) = temp_cold_path();
        let mut buf = MmapReplayBuffer::new(100, 1000, 4, 1, path).unwrap();

        let n = 10;
        let obs: Vec<f32> = (0..n * 4).map(|i| i as f32).collect();
        let next_obs: Vec<f32> = (0..n * 4).map(|i| i as f32 + 100.0).collect();
        let actions: Vec<f32> = (0..n).map(|i| i as f32 * 0.1).collect();
        let rewards: Vec<f32> = (0..n).map(|i| i as f32).collect();
        let terminated = vec![0.0f32; n];
        let truncated = vec![0.0f32; n];

        buf.push_batch(&obs, &next_obs, &actions, &rewards, &terminated, &truncated)
            .unwrap();

        assert_eq!(buf.len(), 10);
        assert_eq!(buf.cold_count, 0);

        let batch = buf.sample(10, 42).unwrap();
        assert_eq!(batch.batch_size, 10);
        assert_eq!(batch.rewards.len(), 10);
    }
679
    // push_batch past hot capacity spills the overflow to cold storage.
    #[test]
    fn test_push_batch_triggers_spill_to_cold() {
        let (_tmp, path) = temp_cold_path();
        let mut buf = MmapReplayBuffer::new(5, 100, 2, 1, path).unwrap();

        let n = 12;
        let obs: Vec<f32> = (0..n * 2).map(|i| i as f32).collect();
        let next_obs: Vec<f32> = (0..n * 2).map(|i| i as f32 + 100.0).collect();
        let actions: Vec<f32> = (0..n).map(|i| i as f32 * 0.1).collect();
        let rewards: Vec<f32> = (0..n).map(|i| i as f32).collect();
        let terminated = vec![0.0f32; n];
        let truncated = vec![0.0f32; n];

        buf.push_batch(&obs, &next_obs, &actions, &rewards, &terminated, &truncated)
            .unwrap();

        assert_eq!(buf.hot.len(), 5);
        assert_eq!(buf.cold_count, 7);
        assert_eq!(buf.len(), 12);
    }
700
    // Inconsistent slice lengths are rejected with an error.
    #[test]
    fn test_push_batch_shape_mismatch() {
        let (_tmp, path) = temp_cold_path();
        let mut buf = MmapReplayBuffer::new(100, 1000, 4, 1, path).unwrap();

        let result = buf.push_batch(
            &[1.0, 2.0, 3.0],
            &[1.0, 2.0, 3.0, 4.0],
            &[0.0],
            &[1.0],
            &[0.0],
            &[0.0],
        );
        assert!(result.is_err());
    }
717
    // Terminated/truncated float flags round-trip through push_batch/sample.
    #[test]
    fn test_push_batch_terminated_truncated_flags() {
        let (_tmp, path) = temp_cold_path();
        let mut buf = MmapReplayBuffer::new(100, 1000, 2, 1, path).unwrap();

        let obs = vec![1.0, 2.0, 3.0, 4.0];
        let next_obs = vec![5.0, 6.0, 7.0, 8.0];
        let actions = vec![0.0, 1.0];
        let rewards = vec![1.0, 2.0];
        let terminated = vec![1.0, 0.0];
        let truncated = vec![0.0, 1.0];
        buf.push_batch(&obs, &next_obs, &actions, &rewards, &terminated, &truncated)
            .unwrap();

        let batch = buf.sample(2, 42).unwrap();
        assert_eq!(batch.terminated.len(), 2);
        assert_eq!(batch.truncated.len(), 2);
    }
738
    // With total capacity 6 (3 hot + 3 cold) and 10 pushes, only records
    // 4..=9 survive: cold holds {4,5,6}, hot holds {7,8,9}.
    #[test]
    fn test_cold_ring_buffer_eviction_overwrites_oldest() {
        let (_tmp, path) = temp_cold_path();
        let mut buf = MmapReplayBuffer::new(3, 6, 2, 1, path).unwrap();

        for i in 0..10 {
            let rec = ExperienceRecord {
                obs: vec![i as f32; 2],
                next_obs: vec![i as f32 + 100.0; 2],
                action: vec![i as f32 * 0.1],
                reward: i as f32,
                terminated: false,
                truncated: false,
            };
            buf.push(rec).unwrap();
        }

        assert_eq!(buf.len(), 6);
        assert_eq!(buf.hot.len(), 3);
        assert_eq!(buf.cold_count, 3);

        let batch = buf.sample(6, 42).unwrap();
        assert_eq!(batch.batch_size, 6);
        assert_eq!(batch.rewards.len(), 6);

        for &r in &batch.rewards {
            assert!(
                r >= 4.0,
                "expected only recent records (reward >= 4.0), got {r}"
            );
        }
    }
782
    // After enough spills, the cold write cursor wraps back to slot 0.
    #[test]
    fn test_cold_eviction_cold_write_pos_wraps() {
        let (_tmp, path) = temp_cold_path();
        let mut buf = MmapReplayBuffer::new(2, 4, 1, 1, path).unwrap();

        for i in 0..8 {
            let rec = ExperienceRecord {
                obs: vec![i as f32],
                next_obs: vec![i as f32 + 10.0],
                action: vec![0.0],
                reward: i as f32,
                terminated: false,
                truncated: false,
            };
            buf.push(rec).unwrap();
        }

        assert_eq!(buf.len(), 4);
        assert_eq!(buf.cold_count, 2);
        assert_eq!(buf.cold_write_pos, 0);
    }
807
    // After heavy eviction every sampled record is internally consistent:
    // obs values were constructed equal to the record's reward.
    #[test]
    fn test_sampling_after_cold_eviction_returns_valid_data() {
        let (_tmp, path) = temp_cold_path();
        let mut buf = MmapReplayBuffer::new(5, 8, 3, 1, path).unwrap();

        for i in 0..20 {
            let rec = ExperienceRecord {
                obs: vec![i as f32; 3],
                next_obs: vec![i as f32 + 0.5; 3],
                action: vec![i as f32 * 0.01],
                reward: i as f32,
                terminated: i % 3 == 0,
                truncated: i % 5 == 0,
            };
            buf.push(rec).unwrap();
        }

        assert_eq!(buf.len(), 8);

        let batch = buf.sample(8, 123).unwrap();
        assert_eq!(batch.batch_size, 8);
        assert_eq!(batch.observations.len(), 8 * 3);
        assert_eq!(batch.next_observations.len(), 8 * 3);
        assert_eq!(batch.actions.len(), 8);
        assert_eq!(batch.rewards.len(), 8);

        for i in 0..8 {
            let r = batch.rewards[i];
            let obs_slice = &batch.observations[i * 3..(i + 1) * 3];
            for &v in obs_slice {
                assert!((v - r).abs() < 1e-6, "obs {v} should match reward {r}");
            }
        }
    }
846
    // Records built so that each obs pair differs by 0.1; verifies the pairs
    // survive batch pushes, spilling and sampling intact.
    #[test]
    fn test_push_batch_then_sample_consistency() {
        let (_tmp, path) = temp_cold_path();
        let mut buf = MmapReplayBuffer::new(5, 15, 2, 1, path).unwrap();

        for batch_idx in 0..3 {
            let base = batch_idx as f32 * 5.0;
            let n = 5;
            let obs: Vec<f32> = (0..n)
                .flat_map(|i| {
                    let v = base + i as f32;
                    vec![v, v + 0.1]
                })
                .collect();
            let next_obs: Vec<f32> = (0..n)
                .flat_map(|i| {
                    let v = base + i as f32 + 100.0;
                    vec![v, v + 0.1]
                })
                .collect();
            let actions: Vec<f32> = (0..n).map(|i| (base + i as f32) * 0.01).collect();
            let rewards: Vec<f32> = (0..n).map(|i| base + i as f32).collect();
            let terminated = vec![0.0f32; n];
            let truncated = vec![0.0f32; n];
            buf.push_batch(&obs, &next_obs, &actions, &rewards, &terminated, &truncated)
                .unwrap();
        }

        assert_eq!(buf.len(), 15);
        assert_eq!(buf.hot.len(), 5);
        assert_eq!(buf.cold_count, 10);

        let batch = buf.sample(15, 7).unwrap();
        assert_eq!(batch.batch_size, 15);
        assert_eq!(batch.observations.len(), 15 * 2);

        for i in 0..15 {
            let o0 = batch.observations[i * 2];
            let o1 = batch.observations[i * 2 + 1];
            assert!(
                (o1 - o0 - 0.1).abs() < 1e-5,
                "obs pair mismatch at sample {i}: {o0}, {o1}"
            );
        }
    }
895}