1use rand::Rng;
7use rand::SeedableRng;
8use rand_chacha::ChaCha8Rng;
9
10use crate::error::RloxError;
11
12use super::ExperienceRecord;
13
/// Binary sum tree (with a parallel min tree) over a power-of-two number of leaves.
///
/// Node 1 is the root; node `i` has children `2i` and `2i + 1`; leaves occupy
/// indices `capacity..2 * capacity`; index 0 of both arrays is unused. `tree`
/// holds subtree sums and `min_tree` subtree minima. Unset leaves hold `0.0` in
/// `tree` and `f64::INFINITY` in `min_tree`, so they add no mass and never
/// become the minimum.
#[derive(Debug)]
pub struct SumTree {
    capacity: usize,
    tree: Vec<f64>,
    min_tree: Vec<f64>,
}

impl SumTree {
    /// Creates a tree with every leaf unset.
    ///
    /// `capacity` is rounded up to the next power of two; [`capacity`](Self::capacity)
    /// reports the rounded value.
    pub fn new(capacity: usize) -> Self {
        debug_assert!(capacity > 0, "SumTree capacity must be > 0");
        let capacity = capacity.next_power_of_two();
        Self {
            capacity,
            tree: vec![0.0; 2 * capacity],
            min_tree: vec![f64::INFINITY; 2 * capacity],
        }
    }

    /// Number of leaves (the requested capacity rounded up to a power of two).
    pub fn capacity(&self) -> usize {
        self.capacity
    }

    /// Sum of all leaf values.
    pub fn total(&self) -> f64 {
        self.tree[1]
    }

    /// Smallest value among the leaves that have been set, or `f64::INFINITY`
    /// if none has been set.
    pub fn min(&self) -> f64 {
        self.min_tree[1]
    }

    /// Sets leaf `index` to `priority` and refreshes the sums and minima on
    /// the path to the root (O(log capacity)).
    pub fn set(&mut self, index: usize, priority: f64) {
        debug_assert!(index < self.capacity, "SumTree index out of bounds");
        let mut pos = index + self.capacity;
        self.tree[pos] = priority;
        self.min_tree[pos] = priority;
        while pos > 1 {
            pos /= 2;
            self.tree[pos] = self.tree[2 * pos] + self.tree[2 * pos + 1];
            self.min_tree[pos] = self.min_tree[2 * pos].min(self.min_tree[2 * pos + 1]);
        }
    }

    /// Returns the value stored at leaf `index`.
    pub fn get(&self, index: usize) -> f64 {
        debug_assert!(index < self.capacity, "SumTree index out of bounds");
        self.tree[index + self.capacity]
    }

    /// Maps a cumulative value in `[0, total)` to the leaf whose prefix-sum
    /// interval contains it.
    ///
    /// The descent never enters a subtree with zero mass while the current node
    /// has positive mass. So for a positive total the returned leaf always holds
    /// a positive value, even when `value` sits at (or rounding pushes it just
    /// past) the total. If the total is zero, leaf 0 is returned.
    pub fn sample(&self, value: f64) -> usize {
        debug_assert!(value >= 0.0 && value < self.total() + 1e-12);
        let mut pos = 1;
        let mut remaining = value;
        while pos < self.capacity {
            let left = 2 * pos;
            let right = left + 1;
            // Going right into an empty subtree would land on an unset or
            // zero-priority leaf. That happens when `remaining` reaches the left
            // sum only through floating-point rounding, so stay left instead.
            if remaining < self.tree[left] || self.tree[right] <= 0.0 {
                pos = left;
            } else {
                remaining -= self.tree[left];
                pos = right;
            }
        }
        pos - self.capacity
    }
}
110
/// Fixed-capacity ring buffer of transitions, sampled in proportion to
/// `priority^alpha`, with importance-sampling weights controlled by `beta`.
///
/// Transition data lives in flat arrays preallocated at construction. Slot `i`
/// of `observations`/`next_observations` spans `i * obs_dim..(i + 1) * obs_dim`,
/// and slot `i` of `actions` spans `i * act_dim..(i + 1) * act_dim`.
#[derive(Debug)]
pub struct PrioritizedReplayBuffer {
    /// Length of each observation and next observation.
    obs_dim: usize,
    /// Length of each action.
    act_dim: usize,
    /// Maximum number of stored transitions; the oldest slot is overwritten once full.
    capacity: usize,
    /// Exponent applied to every priority before it is written to `tree`.
    alpha: f64,
    /// Importance-sampling exponent used when computing sample weights.
    beta: f64,
    /// Per-slot `priority^alpha` values (sum and min over all slots).
    tree: SumTree,
    /// `capacity * obs_dim` observation values.
    observations: Vec<f32>,
    /// `capacity * obs_dim` next-observation values.
    next_observations: Vec<f32>,
    /// `capacity * act_dim` action values.
    actions: Vec<f32>,
    /// One reward per slot.
    rewards: Vec<f32>,
    /// One `terminated` flag per slot.
    terminated: Vec<bool>,
    /// One `truncated` flag per slot.
    truncated: Vec<bool>,
    /// Slot the next push writes to; wraps modulo `capacity`.
    write_pos: usize,
    /// Number of valid slots, saturating at `capacity`.
    count: usize,
    /// Largest `priority^alpha` written so far (starts at 1.0).
    /// NOTE(review): updated on push/priority updates but not read by any
    /// method visible in this file.
    max_priority: f64,
}
136
/// A batch drawn by `PrioritizedReplayBuffer::sample`.
///
/// Per-transition data is flattened in sample order: `observations` and
/// `next_observations` hold `batch_size * obs_dim` values, and `actions` holds
/// `batch_size * act_dim` values.
#[derive(Debug, Clone)]
pub struct PrioritizedSampledBatch {
    /// Flattened observations, `obs_dim` values per sampled transition.
    pub observations: Vec<f32>,
    /// Flattened next observations, `obs_dim` values per sampled transition.
    pub next_observations: Vec<f32>,
    /// Flattened actions, `act_dim` values per sampled transition.
    pub actions: Vec<f32>,
    /// One reward per sampled transition.
    pub rewards: Vec<f32>,
    /// One `terminated` flag per sampled transition.
    pub terminated: Vec<bool>,
    /// One `truncated` flag per sampled transition.
    pub truncated: Vec<bool>,
    /// Observation length used to slice `observations`/`next_observations`.
    pub obs_dim: usize,
    /// Action length used to slice `actions`.
    pub act_dim: usize,
    /// Number of sampled transitions.
    pub batch_size: usize,
    /// Importance-sampling weight of each sampled transition.
    pub weights: Vec<f64>,
    /// Buffer slot of each sampled transition; pass these back to
    /// `update_priorities` to refresh their priorities.
    pub indices: Vec<usize>,
}
152
153impl PrioritizedReplayBuffer {
154 pub fn new(capacity: usize, obs_dim: usize, act_dim: usize, alpha: f64, beta: f64) -> Self {
159 Self {
160 obs_dim,
161 act_dim,
162 capacity,
163 alpha,
164 beta,
165 tree: SumTree::new(capacity),
166 observations: vec![0.0; capacity * obs_dim],
167 next_observations: vec![0.0; capacity * obs_dim],
168 actions: vec![0.0; capacity * act_dim],
169 rewards: vec![0.0; capacity],
170 terminated: vec![false; capacity],
171 truncated: vec![false; capacity],
172 write_pos: 0,
173 count: 0,
174 max_priority: 1.0,
175 }
176 }
177
178 pub fn len(&self) -> usize {
180 self.count
181 }
182
183 pub fn is_empty(&self) -> bool {
185 self.count == 0
186 }
187
188 #[allow(clippy::too_many_arguments)]
190 pub fn push_slices(
191 &mut self,
192 obs: &[f32],
193 next_obs: &[f32],
194 action: &[f32],
195 reward: f32,
196 terminated: bool,
197 truncated: bool,
198 priority: f64,
199 ) -> Result<(), RloxError> {
200 if priority < 0.0 {
201 return Err(RloxError::BufferError(
202 "priority must be non-negative".into(),
203 ));
204 }
205 if obs.len() != self.obs_dim {
206 return Err(RloxError::ShapeMismatch {
207 expected: format!("obs_dim={}", self.obs_dim),
208 got: format!("obs.len()={}", obs.len()),
209 });
210 }
211 if next_obs.len() != self.obs_dim {
212 return Err(RloxError::ShapeMismatch {
213 expected: format!("obs_dim={}", self.obs_dim),
214 got: format!("next_obs.len()={}", next_obs.len()),
215 });
216 }
217 if action.len() != self.act_dim {
218 return Err(RloxError::ShapeMismatch {
219 expected: format!("act_dim={}", self.act_dim),
220 got: format!("action.len()={}", action.len()),
221 });
222 }
223
224 let idx = self.write_pos;
225 let obs_start = idx * self.obs_dim;
226 self.observations[obs_start..obs_start + self.obs_dim].copy_from_slice(obs);
227 self.next_observations[obs_start..obs_start + self.obs_dim].copy_from_slice(next_obs);
228 let act_start = idx * self.act_dim;
229 self.actions[act_start..act_start + self.act_dim].copy_from_slice(action);
230 self.rewards[idx] = reward;
231 self.terminated[idx] = terminated;
232 self.truncated[idx] = truncated;
233
234 let p_alpha = priority.powf(self.alpha);
235 self.tree.set(idx, p_alpha);
236 if p_alpha > self.max_priority {
237 self.max_priority = p_alpha;
238 }
239
240 self.write_pos = (self.write_pos + 1) % self.capacity;
241 if self.count < self.capacity {
242 self.count += 1;
243 }
244 Ok(())
245 }
246
247 pub fn push(&mut self, record: ExperienceRecord, priority: f64) -> Result<(), RloxError> {
248 self.push_slices(
249 &record.obs,
250 &record.next_obs,
251 &record.action,
252 record.reward,
253 record.terminated,
254 record.truncated,
255 priority,
256 )
257 }
258
259 pub fn set_beta(&mut self, beta: f64) {
261 self.beta = beta;
262 }
263
264 pub fn sample(
266 &self,
267 batch_size: usize,
268 seed: u64,
269 ) -> Result<PrioritizedSampledBatch, RloxError> {
270 if self.count == 0 {
271 return Err(RloxError::BufferError(
272 "cannot sample from empty buffer".into(),
273 ));
274 }
275 if batch_size > self.count {
276 return Err(RloxError::BufferError(format!(
277 "batch_size {} > buffer len {}",
278 batch_size, self.count
279 )));
280 }
281
282 let mut rng = ChaCha8Rng::seed_from_u64(seed);
283 let total = self.tree.total();
284 let segment = total / batch_size as f64;
285
286 let mut batch = PrioritizedSampledBatch {
287 observations: Vec::with_capacity(batch_size * self.obs_dim),
288 next_observations: Vec::with_capacity(batch_size * self.obs_dim),
289 actions: Vec::with_capacity(batch_size * self.act_dim),
290 rewards: Vec::with_capacity(batch_size),
291 terminated: Vec::with_capacity(batch_size),
292 truncated: Vec::with_capacity(batch_size),
293 obs_dim: self.obs_dim,
294 act_dim: self.act_dim,
295 batch_size,
296 weights: Vec::with_capacity(batch_size),
297 indices: Vec::with_capacity(batch_size),
298 };
299
300 let min_prob = self.tree_min_prob();
301 let max_weight = (self.count as f64 * min_prob).powf(-self.beta);
302
303 for i in 0..batch_size {
304 let lo = segment * i as f64;
305 let hi = segment * (i + 1) as f64;
306 let value = rng.random_range(lo..hi);
307 let idx = self.tree.sample(value.min(total - 1e-12));
308
309 debug_assert!(
313 idx < self.count,
314 "SumTree sampled index {idx} >= count {}, total={total}, value={value}",
315 self.count
316 );
317 let idx = idx.min(self.count - 1);
318
319 let obs_start = idx * self.obs_dim;
320 batch
321 .observations
322 .extend_from_slice(&self.observations[obs_start..obs_start + self.obs_dim]);
323 batch
324 .next_observations
325 .extend_from_slice(&self.next_observations[obs_start..obs_start + self.obs_dim]);
326 let act_start = idx * self.act_dim;
327 batch
328 .actions
329 .extend_from_slice(&self.actions[act_start..act_start + self.act_dim]);
330 batch.rewards.push(self.rewards[idx]);
331 batch.terminated.push(self.terminated[idx]);
332 batch.truncated.push(self.truncated[idx]);
333
334 let prob = self.tree.get(idx) / total;
335 let weight = (self.count as f64 * prob).powf(-self.beta);
336 batch.weights.push(weight / max_weight);
337 batch.indices.push(idx);
338 }
339
340 Ok(batch)
341 }
342
343 pub fn update_priorities(
345 &mut self,
346 indices: &[usize],
347 priorities: &[f64],
348 ) -> Result<(), RloxError> {
349 if indices.len() != priorities.len() {
350 return Err(RloxError::BufferError(
351 "indices and priorities must have same length".into(),
352 ));
353 }
354 for (&idx, &p) in indices.iter().zip(priorities.iter()) {
355 if p < 0.0 {
356 return Err(RloxError::BufferError(
357 "priority must be non-negative".into(),
358 ));
359 }
360 if idx >= self.count {
361 return Err(RloxError::BufferError(format!(
362 "index {} >= buffer len {}",
363 idx, self.count
364 )));
365 }
366 let p_alpha = p.powf(self.alpha);
367 self.tree.set(idx, p_alpha);
368 if p_alpha > self.max_priority {
369 self.max_priority = p_alpha;
370 }
371 }
372 Ok(())
373 }
374
375 fn tree_min_prob(&self) -> f64 {
377 let total = self.tree.total();
378 if total == 0.0 {
379 return 1.0;
380 }
381 let min_p = self.tree.min();
382 if min_p <= 0.0 || min_p == f64::INFINITY {
383 1.0 / self.count as f64
384 } else {
385 min_p / total
386 }
387 }
388}
389
/// Hyper-parameters for [`compute_lap_priorities`].
///
/// Each priority is `|td_error| + eta * loss`, floored at `min_priority`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct LAPConfig {
    /// Weight applied to the per-sample loss term.
    pub eta: f64,
    /// Lower bound applied to every computed priority.
    pub min_priority: f64,
}

impl Default for LAPConfig {
    /// `eta = 1.0`, `min_priority = 1e-6`.
    fn default() -> Self {
        Self {
            eta: 1.0,
            min_priority: 1e-6,
        }
    }
}
416
417#[inline]
421pub fn compute_lap_priorities(
422 td_errors: &[f64],
423 losses: &[f64],
424 config: &LAPConfig,
425) -> Result<Vec<f64>, RloxError> {
426 if td_errors.len() != losses.len() {
427 return Err(RloxError::ShapeMismatch {
428 expected: format!("td_errors.len()={}", td_errors.len()),
429 got: format!("losses.len()={}", losses.len()),
430 });
431 }
432
433 let priorities = td_errors
434 .iter()
435 .zip(losses.iter())
436 .map(|(&td, &loss)| (td.abs() + config.eta * loss).max(config.min_priority))
437 .collect();
438 Ok(priorities)
439}
440
/// Converts TD errors to priorities: `|td|`, floored at `min_priority`.
#[inline]
pub fn compute_td_priorities(td_errors: &[f64], min_priority: f64) -> Vec<f64> {
    let mut priorities = Vec::with_capacity(td_errors.len());
    priorities.extend(td_errors.iter().map(|td| f64::max(td.abs(), min_priority)));
    priorities
}
451
452impl PrioritizedReplayBuffer {
453 pub fn update_priorities_from_loss(
458 &mut self,
459 indices: &[usize],
460 losses: &[f64],
461 epsilon: f64,
462 ) -> Result<(), RloxError> {
463 if indices.len() != losses.len() {
464 return Err(RloxError::ShapeMismatch {
465 expected: format!("indices.len()={}", indices.len()),
466 got: format!("losses.len()={}", losses.len()),
467 });
468 }
469 let priorities: Vec<f64> = losses.iter().map(|&l| l.abs() + epsilon).collect();
470 self.update_priorities(indices, &priorities)
471 }
472}
473
#[cfg(test)]
mod tests {
    use super::*;
    use crate::buffer::{sample_record, sample_record_multidim};

    // ---------------------------------------------------------------- SumTree

    #[test]
    fn sum_tree_new_has_zero_total() {
        let tree = SumTree::new(8);
        assert_eq!(tree.total(), 0.0);
        assert_eq!(tree.capacity(), 8);
    }

    #[test]
    fn sum_tree_set_and_get() {
        let mut tree = SumTree::new(4);
        tree.set(0, 1.0);
        tree.set(1, 2.0);
        tree.set(2, 3.0);
        tree.set(3, 4.0);
        assert_eq!(tree.get(0), 1.0);
        assert_eq!(tree.get(1), 2.0);
        assert_eq!(tree.get(2), 3.0);
        assert_eq!(tree.get(3), 4.0);
        assert_eq!(tree.total(), 10.0);
    }

    #[test]
    fn sum_tree_update_propagates() {
        let mut tree = SumTree::new(4);
        tree.set(0, 1.0);
        tree.set(1, 1.0);
        tree.set(2, 1.0);
        tree.set(3, 1.0);
        assert_eq!(tree.total(), 4.0);

        tree.set(2, 5.0);
        assert_eq!(tree.total(), 8.0);
        assert_eq!(tree.get(2), 5.0);
    }

    #[test]
    fn sum_tree_sample_returns_correct_leaf() {
        let mut tree = SumTree::new(4);
        tree.set(0, 1.0);
        tree.set(1, 2.0);
        tree.set(2, 3.0);
        tree.set(3, 4.0);
        // Prefix-sum intervals: [0,1) -> 0, [1,3) -> 1, [3,6) -> 2, [6,10) -> 3.
        assert_eq!(tree.sample(0.0), 0);
        assert_eq!(tree.sample(0.5), 0);
        assert_eq!(tree.sample(1.0), 1);
        assert_eq!(tree.sample(2.9), 1);
        assert_eq!(tree.sample(3.0), 2);
        assert_eq!(tree.sample(5.9), 2);
        assert_eq!(tree.sample(6.0), 3);
        assert_eq!(tree.sample(9.9), 3);
    }

    #[test]
    fn sum_tree_single_leaf() {
        let mut tree = SumTree::new(1);
        tree.set(0, 5.0);
        assert_eq!(tree.total(), 5.0);
        assert_eq!(tree.sample(0.0), 0);
        assert_eq!(tree.sample(4.9), 0);
    }

    #[test]
    fn sum_tree_min_tracks_minimum() {
        let mut tree = SumTree::new(4);
        tree.set(0, 3.0);
        tree.set(1, 1.0);
        tree.set(2, 5.0);
        tree.set(3, 2.0);
        assert!((tree.min() - 1.0).abs() < 1e-12);

        // Raising the old minimum exposes the next smallest leaf.
        tree.set(1, 10.0);
        assert!((tree.min() - 2.0).abs() < 1e-12);

        // Lowering a leaf below the current minimum replaces it.
        tree.set(3, 0.5);
        assert!((tree.min() - 0.5).abs() < 1e-12);
    }

    #[test]
    fn sum_tree_min_empty_is_infinity() {
        let tree = SumTree::new(4);
        assert!(tree.min().is_infinite());
    }

    // ------------------------------------------------ PrioritizedReplayBuffer

    #[test]
    fn prb_new_is_empty() {
        let buf = PrioritizedReplayBuffer::new(100, 4, 1, 0.6, 0.4);
        assert_eq!(buf.len(), 0);
        assert!(buf.is_empty());
    }

    #[test]
    fn prb_push_increments_len() {
        let mut buf = PrioritizedReplayBuffer::new(100, 4, 1, 0.6, 0.4);
        buf.push(sample_record(4), 1.0).unwrap();
        assert_eq!(buf.len(), 1);
    }

    #[test]
    fn prb_negative_priority_errors() {
        let mut buf = PrioritizedReplayBuffer::new(100, 4, 1, 0.6, 0.4);
        let result = buf.push(sample_record(4), -1.0);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("non-negative"));
    }

    #[test]
    fn prb_sample_empty_errors() {
        let buf = PrioritizedReplayBuffer::new(100, 4, 1, 0.6, 0.4);
        let result = buf.sample(1, 42);
        assert!(result.is_err());
    }

    #[test]
    fn prb_sample_too_large_errors() {
        let mut buf = PrioritizedReplayBuffer::new(100, 4, 1, 0.6, 0.4);
        buf.push(sample_record(4), 1.0).unwrap();
        let result = buf.sample(10, 42);
        assert!(result.is_err());
    }

    #[test]
    fn prb_sample_returns_correct_size() {
        let mut buf = PrioritizedReplayBuffer::new(100, 4, 1, 0.6, 0.4);
        for _ in 0..50 {
            buf.push(sample_record(4), 1.0).unwrap();
        }
        let batch = buf.sample(16, 42).unwrap();
        assert_eq!(batch.batch_size, 16);
        assert_eq!(batch.observations.len(), 16 * 4);
        assert_eq!(batch.actions.len(), 16);
        assert_eq!(batch.rewards.len(), 16);
        assert_eq!(batch.weights.len(), 16);
        assert_eq!(batch.indices.len(), 16);
    }

    #[test]
    fn prb_weights_are_in_zero_one() {
        let mut buf = PrioritizedReplayBuffer::new(100, 4, 1, 0.6, 0.4);
        for i in 0..50 {
            buf.push(sample_record(4), (i + 1) as f64).unwrap();
        }
        let batch = buf.sample(16, 42).unwrap();
        for &w in &batch.weights {
            assert!(w > 0.0, "weight must be positive, got {w}");
            assert!(w <= 1.0 + 1e-10, "weight must be <= 1.0, got {w}");
        }
    }

    #[test]
    fn prb_high_priority_sampled_more_often() {
        let mut buf = PrioritizedReplayBuffer::new(100, 4, 1, 1.0, 0.4);
        let mut rec = sample_record(4);
        // Tag the high-priority transition so it is recognizable in slot 0.
        rec.reward = 99.0;
        buf.push(rec, 100.0).unwrap();
        for _ in 1..50 {
            buf.push(sample_record(4), 1.0).unwrap();
        }

        // Slot 0 holds 100 / 149 of the mass with alpha = 1; count it over
        // 100 seeds x 10 draws.
        let mut count_high = 0;
        for seed in 0..100 {
            let batch = buf.sample(10, seed).unwrap();
            for &idx in &batch.indices {
                if idx == 0 {
                    count_high += 1;
                }
            }
        }
        // Uniform sampling would give ~20/1000; require a much larger share.
        assert!(
            count_high > 200,
            "high priority item should be sampled frequently, got {count_high}/1000"
        );
    }

    #[test]
    fn prb_update_priorities() {
        let mut buf = PrioritizedReplayBuffer::new(100, 4, 1, 1.0, 0.4);
        for _ in 0..10 {
            buf.push(sample_record(4), 1.0).unwrap();
        }
        buf.update_priorities(&[0], &[100.0]).unwrap();
        // Slot 0 now carries 100 / 109 of the mass.

        let mut count_idx0 = 0;
        for seed in 0..50 {
            let batch = buf.sample(5, seed).unwrap();
            for &idx in &batch.indices {
                if idx == 0 {
                    count_idx0 += 1;
                }
            }
        }
        assert!(
            count_idx0 > 100,
            "updated high-priority item should be sampled frequently"
        );
    }

    #[test]
    fn prb_update_priorities_negative_errors() {
        let mut buf = PrioritizedReplayBuffer::new(100, 4, 1, 0.6, 0.4);
        buf.push(sample_record(4), 1.0).unwrap();
        let result = buf.update_priorities(&[0], &[-1.0]);
        assert!(result.is_err());
    }

    #[test]
    fn prb_update_priorities_oob_errors() {
        let mut buf = PrioritizedReplayBuffer::new(100, 4, 1, 0.6, 0.4);
        buf.push(sample_record(4), 1.0).unwrap();
        let result = buf.update_priorities(&[5], &[1.0]);
        assert!(result.is_err());
    }

    #[test]
    fn prb_update_priorities_length_mismatch_errors() {
        let mut buf = PrioritizedReplayBuffer::new(100, 4, 1, 0.6, 0.4);
        buf.push(sample_record(4), 1.0).unwrap();
        let result = buf.update_priorities(&[0], &[1.0, 2.0]);
        assert!(result.is_err());
    }

    #[test]
    fn prb_set_beta() {
        let mut buf = PrioritizedReplayBuffer::new(100, 4, 1, 0.6, 0.4);
        buf.set_beta(1.0);
        assert_eq!(buf.beta, 1.0);

        // beta = 0 disables importance correction: every weight is exactly 1.
        for i in 0..20 {
            buf.push(sample_record(4), (i + 1) as f64).unwrap();
        }
        buf.set_beta(0.0);
        let batch = buf.sample(8, 7).unwrap();
        assert!(
            batch.weights.iter().all(|&w| w == 1.0),
            "weights with beta = 0 should all be 1, got {:?}",
            batch.weights
        );
    }

    #[test]
    fn prb_wraps_around() {
        let mut buf = PrioritizedReplayBuffer::new(5, 4, 1, 0.6, 0.4);
        for i in 0..10 {
            let mut rec = sample_record(4);
            rec.reward = i as f32;
            buf.push(rec, 1.0).unwrap();
        }
        assert_eq!(buf.len(), 5);
        let batch = buf.sample(5, 42).unwrap();
        // Only the last five pushes (rewards 5..=9) survive the wrap.
        for &r in &batch.rewards {
            assert!(r >= 5.0, "old data should be overwritten, got reward {r}");
        }
    }

    #[test]
    fn prb_multidim_actions() {
        let mut buf = PrioritizedReplayBuffer::new(100, 4, 3, 0.6, 0.4);
        buf.push(sample_record_multidim(4, 3), 1.0).unwrap();
        let batch = buf.sample(1, 42).unwrap();
        assert_eq!(batch.act_dim, 3);
        assert_eq!(batch.actions.len(), 3);
    }

    #[test]
    fn prb_deterministic_with_same_seed() {
        let mut buf = PrioritizedReplayBuffer::new(100, 4, 1, 0.6, 0.4);
        for _ in 0..50 {
            buf.push(sample_record(4), 1.0).unwrap();
        }
        let b1 = buf.sample(16, 42).unwrap();
        let b2 = buf.sample(16, 42).unwrap();
        assert_eq!(b1.indices, b2.indices);
        assert_eq!(b1.weights, b2.weights);
    }

    #[test]
    fn prb_next_obs_roundtrip() {
        let obs_dim = 4;
        let mut buf = PrioritizedReplayBuffer::new(100, obs_dim, 1, 0.6, 0.4);
        let record = ExperienceRecord {
            obs: vec![1.0; obs_dim],
            next_obs: vec![2.0, 3.0, 4.0, 5.0],
            action: vec![0.0],
            reward: 1.0,
            terminated: false,
            truncated: false,
        };
        buf.push(record, 1.0).unwrap();
        let batch = buf.sample(1, 42).unwrap();
        assert_eq!(&batch.next_observations, &[2.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn prb_next_obs_shape() {
        let obs_dim = 4;
        let mut buf = PrioritizedReplayBuffer::new(200, obs_dim, 1, 0.6, 0.4);
        for _ in 0..100 {
            buf.push(sample_record(obs_dim), 1.0).unwrap();
        }
        let batch = buf.sample(32, 42).unwrap();
        assert_eq!(batch.next_observations.len(), 32 * obs_dim);
    }

    #[test]
    fn prb_obs_dim_mismatch_errors() {
        let mut buf = PrioritizedReplayBuffer::new(100, 4, 1, 0.6, 0.4);
        let result = buf.push(sample_record(8), 1.0);
        assert!(result.is_err());
    }

    mod proptests {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            #[test]
            fn sum_tree_total_equals_sum_of_leaves(
                priorities in proptest::collection::vec(0.0f64..100.0, 1..64)
            ) {
                let n = priorities.len();
                let mut tree = SumTree::new(n);
                let mut expected = 0.0;
                for (i, &p) in priorities.iter().enumerate() {
                    tree.set(i, p);
                    expected += p;
                }
                let diff = (tree.total() - expected).abs();
                prop_assert!(diff < 1e-6, "total {} != expected {}", tree.total(), expected);
            }

            #[test]
            fn sum_tree_sample_in_range(
                priorities in proptest::collection::vec(1.0f64..100.0, 1..64)
            ) {
                let n = priorities.len();
                let mut tree = SumTree::new(n);
                for (i, &p) in priorities.iter().enumerate() {
                    tree.set(i, p);
                }
                let total = tree.total();
                // Probe quantiles of the mass, including one just below the total.
                for v in [0.0, total * 0.25, total * 0.5, total * 0.75, total * 0.999] {
                    let idx = tree.sample(v);
                    prop_assert!(idx < n, "sampled index {} >= capacity {}", idx, n);
                }
            }

            #[test]
            fn prb_never_exceeds_capacity(
                capacity in 1..100usize,
                num_pushes in 0..300usize,
            ) {
                let mut buf = PrioritizedReplayBuffer::new(capacity, 4, 1, 0.6, 0.4);
                for _ in 0..num_pushes {
                    buf.push(sample_record(4), 1.0).unwrap();
                }
                prop_assert!(buf.len() <= capacity);
                prop_assert_eq!(buf.len(), num_pushes.min(capacity));
            }

            #[test]
            fn prb_weights_are_valid(
                num_pushes in 10..100usize,
                batch_size in 1..10usize,
            ) {
                let mut buf = PrioritizedReplayBuffer::new(200, 4, 1, 0.6, 0.4);
                for i in 0..num_pushes {
                    buf.push(sample_record(4), (i + 1) as f64).unwrap();
                }
                let effective_batch = batch_size.min(buf.len());
                let batch = buf.sample(effective_batch, 42).unwrap();
                for &w in &batch.weights {
                    prop_assert!(w > 0.0, "weight must be positive");
                    prop_assert!(w <= 1.0 + 1e-10, "weight must be <= 1.0");
                }
                for &idx in &batch.indices {
                    prop_assert!(idx < buf.len(), "index must be < len");
                }
            }
        }
    }

    // ------------------------------------------------------- LAP priorities

    #[test]
    fn test_lap_known_values() {
        // priority = |td| + eta * loss: 1.0 + 0.5 and 2.0 + 0.3.
        let td = &[1.0, -2.0];
        let loss = &[0.5, 0.3];
        let config = LAPConfig {
            eta: 1.0,
            min_priority: 1e-6,
        };
        let result = compute_lap_priorities(td, loss, &config).unwrap();
        assert!(
            (result[0] - 1.5).abs() < 1e-10,
            "expected 1.5, got {}",
            result[0]
        );
        assert!(
            (result[1] - 2.3).abs() < 1e-10,
            "expected 2.3, got {}",
            result[1]
        );
    }

    #[test]
    fn test_lap_eta_zero_is_standard_per() {
        let td = &[1.0, -2.0];
        let loss = &[999.0, 999.0];
        let config = LAPConfig {
            eta: 0.0,
            min_priority: 1e-6,
        };
        let result = compute_lap_priorities(td, loss, &config).unwrap();
        assert!((result[0] - 1.0).abs() < 1e-10);
        assert!((result[1] - 2.0).abs() < 1e-10);
    }

    #[test]
    fn test_lap_min_priority_floor() {
        let td = &[0.0];
        let loss = &[0.0];
        let config = LAPConfig {
            eta: 1.0,
            min_priority: 1e-6,
        };
        let result = compute_lap_priorities(td, loss, &config).unwrap();
        assert!(
            (result[0] - 1e-6).abs() < 1e-12,
            "expected 1e-6, got {}",
            result[0]
        );
    }

    #[test]
    fn test_lap_negative_td_uses_abs() {
        let td = &[-3.0];
        let loss = &[0.0];
        let config = LAPConfig::default();
        let result = compute_lap_priorities(td, loss, &config).unwrap();
        assert!((result[0] - 3.0).abs() < 1e-10);
    }

    #[test]
    fn test_lap_length_mismatch() {
        let result = compute_lap_priorities(&[1.0, 2.0, 3.0], &[0.5, 0.5], &LAPConfig::default());
        assert!(matches!(result, Err(RloxError::ShapeMismatch { .. })));
    }

    #[test]
    fn test_td_priorities_known_values() {
        let td = &[0.5, -1.0, 0.0];
        let result = compute_td_priorities(td, 0.01);
        assert!((result[0] - 0.5).abs() < 1e-10);
        assert!((result[1] - 1.0).abs() < 1e-10);
        assert!((result[2] - 0.01).abs() < 1e-10);
    }

    #[test]
    fn test_lap_integration_with_per_buffer() {
        let mut buf = PrioritizedReplayBuffer::new(100, 4, 1, 0.6, 0.4);
        let config = LAPConfig::default();
        for i in 0..10 {
            let td = &[(i + 1) as f64];
            let loss = &[i as f64 * 0.5];
            let priorities = compute_lap_priorities(td, loss, &config).unwrap();
            buf.push(sample_record(4), priorities[0]).unwrap();
        }
        let batch = buf.sample(5, 42).unwrap();
        assert_eq!(batch.batch_size, 5);
        for &w in &batch.weights {
            assert!(w > 0.0 && w <= 1.0 + 1e-10);
        }
    }

    #[test]
    fn test_lap_large_eta_emphasizes_loss() {
        let td = &[0.01];
        let loss = &[10.0];
        let config = LAPConfig {
            eta: 100.0,
            min_priority: 1e-6,
        };
        let result = compute_lap_priorities(td, loss, &config).unwrap();
        let expected = 0.01 + 100.0 * 10.0;
        assert!((result[0] - expected).abs() < 1e-10);
    }

    #[test]
    fn test_update_priorities_from_loss_correct() {
        let mut buf = PrioritizedReplayBuffer::new(100, 4, 1, 0.6, 0.4);
        for _ in 0..10 {
            buf.push(sample_record(4), 1.0).unwrap();
        }
        buf.update_priorities_from_loss(&[0, 1, 2], &[5.0, 0.0, 3.0], 0.01)
            .unwrap();
        // Each updated leaf stores (|loss| + epsilon)^alpha with alpha = 0.6.
        for (idx, loss) in [(0usize, 5.0f64), (1, 0.0), (2, 3.0)] {
            let expected = (loss + 0.01f64).powf(0.6);
            let got = buf.tree.get(idx);
            assert!(
                (got - expected).abs() < 1e-12,
                "slot {idx}: expected {expected}, got {got}"
            );
        }
        // Untouched slots keep their pushed priority 1.0^0.6 = 1.0.
        assert_eq!(buf.tree.get(3), 1.0);
    }

    #[test]
    fn test_update_priorities_from_loss_mismatched_lengths() {
        let mut buf = PrioritizedReplayBuffer::new(100, 4, 1, 0.6, 0.4);
        for _ in 0..10 {
            buf.push(sample_record(4), 1.0).unwrap();
        }
        let result = buf.update_priorities_from_loss(&[0, 1], &[5.0], 0.01);
        assert!(result.is_err());
    }

    #[test]
    fn test_update_priorities_from_loss_zero_gets_epsilon() {
        let mut buf = PrioritizedReplayBuffer::new(100, 4, 1, 1.0, 0.4);
        for _ in 0..5 {
            buf.push(sample_record(4), 1.0).unwrap();
        }
        let epsilon = 0.001;
        buf.update_priorities_from_loss(&[0], &[0.0], epsilon)
            .unwrap();
        // With alpha = 1 the stored leaf is exactly |0| + epsilon.
        assert!(
            (buf.tree.get(0) - epsilon).abs() < 1e-15,
            "expected {epsilon}, got {}",
            buf.tree.get(0)
        );
    }

    #[test]
    fn test_update_priorities_from_loss_high_loss_favored() {
        let mut buf = PrioritizedReplayBuffer::new(100, 4, 1, 1.0, 0.4);
        for _ in 0..10 {
            buf.push(sample_record(4), 1.0).unwrap();
        }
        // One high-loss transition...
        buf.update_priorities_from_loss(&[0], &[100.0], 0.01)
            .unwrap();
        // ...and nine near-zero-loss ones.
        for i in 1..10 {
            buf.update_priorities_from_loss(&[i], &[0.001], 0.01)
                .unwrap();
        }
        // Over 100 seeds x 5 draws, slot 0 should dominate.
        let mut count_idx0 = 0;
        for seed in 0..100 {
            let batch = buf.sample(5, seed).unwrap();
            for &idx in &batch.indices {
                if idx == 0 {
                    count_idx0 += 1;
                }
            }
        }
        assert!(
            count_idx0 > 100,
            "high-loss item should be sampled frequently, got {count_idx0}/500"
        );
    }

    mod lap_proptests {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            #[test]
            fn prop_lap_priorities_non_negative(
                n in 1usize..100,
            ) {
                let td: Vec<f64> = (0..n).map(|i| (i as f64) - 50.0).collect();
                let loss: Vec<f64> = (0..n).map(|i| (i as f64) * 0.1).collect();
                let config = LAPConfig::default();
                let priorities = compute_lap_priorities(&td, &loss, &config).unwrap();
                for (i, &p) in priorities.iter().enumerate() {
                    prop_assert!(p >= config.min_priority,
                        "priority[{i}]={p} < min_priority={}", config.min_priority);
                }
            }

            #[test]
            fn prop_lap_monotone_in_td(
                td_a in 0.0f64..100.0,
                td_b in 0.0f64..100.0,
                loss in 0.0f64..10.0,
            ) {
                // No floor, so the floor cannot mask the comparison.
                let config = LAPConfig {
                    eta: 1.0,
                    min_priority: 0.0,
                };
                let p_a = compute_lap_priorities(&[td_a], &[loss], &config).unwrap()[0];
                let p_b = compute_lap_priorities(&[td_b], &[loss], &config).unwrap()[0];
                if td_a.abs() > td_b.abs() {
                    prop_assert!(p_a >= p_b,
                        "|td_a|={} > |td_b|={} but p_a={} < p_b={}",
                        td_a.abs(), td_b.abs(), p_a, p_b);
                }
            }
        }
    }
}