// rlox_core/buffer/priority.rs

//! Sum-tree backed prioritized experience replay.
//!
//! Provides O(log N) prefix-sum sampling and priority updates for
//! Prioritized Experience Replay (Schaul et al., 2015).
5
6use rand::Rng;
7use rand::SeedableRng;
8use rand_chacha::ChaCha8Rng;
9
10use crate::error::RloxError;
11
12use super::ExperienceRecord;
13
// ---------------------------------------------------------------------------
// SumTree
// ---------------------------------------------------------------------------
17
/// Binary sum-tree for O(log N) prefix-sum queries.
///
/// Internally stores `2 * capacity` nodes where leaves occupy indices
/// `[capacity .. 2*capacity)` and internal nodes hold partial sums.
/// A second array with the same layout (`min_tree`) tracks the minimum
/// leaf value. The root of both trees lives at index `1`; index `0` is unused.
#[derive(Debug)]
pub struct SumTree {
    /// Number of leaves (the requested capacity rounded up to a power of two).
    capacity: usize,
    /// Sum nodes: `tree[i] = tree[2i] + tree[2i + 1]` for internal `i`.
    tree: Vec<f64>,
    /// Min nodes: `min_tree[i] = min(min_tree[2i], min_tree[2i + 1])`.
    min_tree: Vec<f64>,
}

impl SumTree {
    /// Create a sum-tree with `capacity` leaves, all initialised to zero.
    ///
    /// The capacity is rounded up to the next power of two for efficient
    /// binary-tree indexing. Unused leaves beyond the logical capacity
    /// hold priority `0.0` (sum-tree) and `f64::INFINITY` (min-tree).
    /// The `min()` method may therefore return `INFINITY` when no leaf has
    /// been set; callers should handle this (see [`PrioritizedReplayBuffer::tree_min_prob`]).
    ///
    /// # Panics (debug only)
    ///
    /// Panics in debug builds if `capacity == 0`.
    pub fn new(capacity: usize) -> Self {
        debug_assert!(capacity > 0, "SumTree capacity must be > 0");
        let capacity = capacity.next_power_of_two();
        Self {
            capacity,
            tree: vec![0.0; 2 * capacity],
            min_tree: vec![f64::INFINITY; 2 * capacity],
        }
    }

    /// Number of leaves (after rounding up to a power of two).
    pub fn capacity(&self) -> usize {
        self.capacity
    }

    /// Sum of all leaf priorities.
    pub fn total(&self) -> f64 {
        self.tree[1]
    }

    /// Minimum leaf priority (O(1) via the min-tree).
    ///
    /// Leaves that were never set count as `f64::INFINITY`.
    pub fn min(&self) -> f64 {
        self.min_tree[1]
    }

    /// Set the priority of leaf `index` and refresh every ancestor node.
    ///
    /// The value is stored as-is; callers are expected to pass a finite,
    /// non-negative priority (this is not checked here).
    ///
    /// # Panics (debug only)
    ///
    /// Panics in debug builds if `index >= capacity`.
    pub fn set(&mut self, index: usize, priority: f64) {
        debug_assert!(index < self.capacity, "SumTree index out of bounds");
        let mut pos = index + self.capacity;
        self.tree[pos] = priority;
        self.min_tree[pos] = priority;
        // Walk up to the root, recomputing each parent from its two children.
        while pos > 1 {
            pos /= 2;
            self.tree[pos] = self.tree[2 * pos] + self.tree[2 * pos + 1];
            self.min_tree[pos] = self.min_tree[2 * pos].min(self.min_tree[2 * pos + 1]);
        }
    }

    /// Get the priority of leaf `index`.
    ///
    /// # Panics (debug only)
    ///
    /// Panics in debug builds if `index >= capacity`.
    pub fn get(&self, index: usize) -> f64 {
        debug_assert!(index < self.capacity, "SumTree index out of bounds");
        self.tree[index + self.capacity]
    }

    /// Sample a leaf index such that the probability of choosing leaf `i` is
    /// `priority[i] / total()`.
    ///
    /// `value` should be in `[0, total())`. If `value` reaches the upper
    /// boundary (for example through floating-point rounding), the descent
    /// still ends on a leaf with positive priority whenever `total() > 0`:
    /// a right subtree whose sum is zero is never entered. When every leaf
    /// is zero, leaf `0` is returned.
    ///
    /// # Panics (debug only)
    ///
    /// Panics in debug builds if `value` is negative or exceeds `total()`.
    pub fn sample(&self, value: f64) -> usize {
        debug_assert!(value >= 0.0 && value < self.total() + 1e-12);
        let mut pos = 1;
        let mut remaining = value;
        while pos < self.capacity {
            let left = 2 * pos;
            let right = left + 1;
            let left_sum = self.tree[left];
            // Going right into a zero-sum subtree can only happen through
            // rounding or an out-of-range `value`; it would yield an unused
            // leaf, so stay on the left where all the mass is.
            if remaining < left_sum || self.tree[right] <= 0.0 {
                pos = left;
            } else {
                remaining -= left_sum;
                pos = right;
            }
        }
        pos - self.capacity
    }
}
110
// ---------------------------------------------------------------------------
// PrioritizedReplayBuffer
// ---------------------------------------------------------------------------
114
115/// Prioritized experience replay buffer backed by a sum-tree.
116///
117/// Implements proportional prioritization with importance-sampling weights.
118#[derive(Debug)]
119pub struct PrioritizedReplayBuffer {
120    obs_dim: usize,
121    act_dim: usize,
122    capacity: usize,
123    alpha: f64,
124    beta: f64,
125    tree: SumTree,
126    observations: Vec<f32>,
127    next_observations: Vec<f32>,
128    actions: Vec<f32>,
129    rewards: Vec<f32>,
130    terminated: Vec<bool>,
131    truncated: Vec<bool>,
132    write_pos: usize,
133    count: usize,
134    max_priority: f64,
135}
136
/// A sampled batch with importance-sampling weights.
///
/// Per-sample data is stored flattened: sample `k` occupies
/// `observations[k * obs_dim .. (k + 1) * obs_dim]` (likewise for
/// `next_observations`) and `actions[k * act_dim .. (k + 1) * act_dim]`.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct PrioritizedSampledBatch {
    /// Flattened observations, `batch_size * obs_dim` values.
    pub observations: Vec<f32>,
    /// Flattened next observations, `batch_size * obs_dim` values.
    pub next_observations: Vec<f32>,
    /// Flattened actions, `batch_size * act_dim` values.
    pub actions: Vec<f32>,
    /// One reward per sample.
    pub rewards: Vec<f32>,
    /// One `terminated` flag per sample.
    pub terminated: Vec<bool>,
    /// One `truncated` flag per sample.
    pub truncated: Vec<bool>,
    /// Length of one observation vector.
    pub obs_dim: usize,
    /// Length of one action vector.
    pub act_dim: usize,
    /// Number of samples in the batch.
    pub batch_size: usize,
    /// Importance-sampling weight of each sample.
    pub weights: Vec<f64>,
    /// Buffer slot each sample was drawn from (usable for priority updates).
    pub indices: Vec<usize>,
}
152
153impl PrioritizedReplayBuffer {
154    /// Create a new prioritized replay buffer.
155    ///
156    /// * `alpha` — prioritization exponent (0 = uniform, 1 = full prioritization)
157    /// * `beta` — importance-sampling correction exponent (1 = full correction)
158    pub fn new(capacity: usize, obs_dim: usize, act_dim: usize, alpha: f64, beta: f64) -> Self {
159        Self {
160            obs_dim,
161            act_dim,
162            capacity,
163            alpha,
164            beta,
165            tree: SumTree::new(capacity),
166            observations: vec![0.0; capacity * obs_dim],
167            next_observations: vec![0.0; capacity * obs_dim],
168            actions: vec![0.0; capacity * act_dim],
169            rewards: vec![0.0; capacity],
170            terminated: vec![false; capacity],
171            truncated: vec![false; capacity],
172            write_pos: 0,
173            count: 0,
174            max_priority: 1.0,
175        }
176    }
177
178    /// Number of valid transitions stored.
179    pub fn len(&self) -> usize {
180        self.count
181    }
182
183    /// Whether the buffer is empty.
184    pub fn is_empty(&self) -> bool {
185        self.count == 0
186    }
187
188    /// Push a transition with the given TD-error priority.
189    #[allow(clippy::too_many_arguments)]
190    pub fn push_slices(
191        &mut self,
192        obs: &[f32],
193        next_obs: &[f32],
194        action: &[f32],
195        reward: f32,
196        terminated: bool,
197        truncated: bool,
198        priority: f64,
199    ) -> Result<(), RloxError> {
200        if priority < 0.0 {
201            return Err(RloxError::BufferError(
202                "priority must be non-negative".into(),
203            ));
204        }
205        if obs.len() != self.obs_dim {
206            return Err(RloxError::ShapeMismatch {
207                expected: format!("obs_dim={}", self.obs_dim),
208                got: format!("obs.len()={}", obs.len()),
209            });
210        }
211        if next_obs.len() != self.obs_dim {
212            return Err(RloxError::ShapeMismatch {
213                expected: format!("obs_dim={}", self.obs_dim),
214                got: format!("next_obs.len()={}", next_obs.len()),
215            });
216        }
217        if action.len() != self.act_dim {
218            return Err(RloxError::ShapeMismatch {
219                expected: format!("act_dim={}", self.act_dim),
220                got: format!("action.len()={}", action.len()),
221            });
222        }
223
224        let idx = self.write_pos;
225        let obs_start = idx * self.obs_dim;
226        self.observations[obs_start..obs_start + self.obs_dim].copy_from_slice(obs);
227        self.next_observations[obs_start..obs_start + self.obs_dim].copy_from_slice(next_obs);
228        let act_start = idx * self.act_dim;
229        self.actions[act_start..act_start + self.act_dim].copy_from_slice(action);
230        self.rewards[idx] = reward;
231        self.terminated[idx] = terminated;
232        self.truncated[idx] = truncated;
233
234        let p_alpha = priority.powf(self.alpha);
235        self.tree.set(idx, p_alpha);
236        if p_alpha > self.max_priority {
237            self.max_priority = p_alpha;
238        }
239
240        self.write_pos = (self.write_pos + 1) % self.capacity;
241        if self.count < self.capacity {
242            self.count += 1;
243        }
244        Ok(())
245    }
246
247    pub fn push(&mut self, record: ExperienceRecord, priority: f64) -> Result<(), RloxError> {
248        self.push_slices(
249            &record.obs,
250            &record.next_obs,
251            &record.action,
252            record.reward,
253            record.terminated,
254            record.truncated,
255            priority,
256        )
257    }
258
259    /// Set the importance-sampling beta parameter.
260    pub fn set_beta(&mut self, beta: f64) {
261        self.beta = beta;
262    }
263
264    /// Sample a batch with importance-sampling weights.
265    pub fn sample(
266        &self,
267        batch_size: usize,
268        seed: u64,
269    ) -> Result<PrioritizedSampledBatch, RloxError> {
270        if self.count == 0 {
271            return Err(RloxError::BufferError(
272                "cannot sample from empty buffer".into(),
273            ));
274        }
275        if batch_size > self.count {
276            return Err(RloxError::BufferError(format!(
277                "batch_size {} > buffer len {}",
278                batch_size, self.count
279            )));
280        }
281
282        let mut rng = ChaCha8Rng::seed_from_u64(seed);
283        let total = self.tree.total();
284        let segment = total / batch_size as f64;
285
286        let mut batch = PrioritizedSampledBatch {
287            observations: Vec::with_capacity(batch_size * self.obs_dim),
288            next_observations: Vec::with_capacity(batch_size * self.obs_dim),
289            actions: Vec::with_capacity(batch_size * self.act_dim),
290            rewards: Vec::with_capacity(batch_size),
291            terminated: Vec::with_capacity(batch_size),
292            truncated: Vec::with_capacity(batch_size),
293            obs_dim: self.obs_dim,
294            act_dim: self.act_dim,
295            batch_size,
296            weights: Vec::with_capacity(batch_size),
297            indices: Vec::with_capacity(batch_size),
298        };
299
300        let min_prob = self.tree_min_prob();
301        let max_weight = (self.count as f64 * min_prob).powf(-self.beta);
302
303        for i in 0..batch_size {
304            let lo = segment * i as f64;
305            let hi = segment * (i + 1) as f64;
306            let value = rng.random_range(lo..hi);
307            let idx = self.tree.sample(value.min(total - 1e-12));
308
309            // Unused tree leaves have priority 0, so sample() should never
310            // return an index beyond self.count. Assert this rather than
311            // silently clamping (which would bias sampling toward the last entry).
312            debug_assert!(
313                idx < self.count,
314                "SumTree sampled index {idx} >= count {}, total={total}, value={value}",
315                self.count
316            );
317            let idx = idx.min(self.count - 1);
318
319            let obs_start = idx * self.obs_dim;
320            batch
321                .observations
322                .extend_from_slice(&self.observations[obs_start..obs_start + self.obs_dim]);
323            batch
324                .next_observations
325                .extend_from_slice(&self.next_observations[obs_start..obs_start + self.obs_dim]);
326            let act_start = idx * self.act_dim;
327            batch
328                .actions
329                .extend_from_slice(&self.actions[act_start..act_start + self.act_dim]);
330            batch.rewards.push(self.rewards[idx]);
331            batch.terminated.push(self.terminated[idx]);
332            batch.truncated.push(self.truncated[idx]);
333
334            let prob = self.tree.get(idx) / total;
335            let weight = (self.count as f64 * prob).powf(-self.beta);
336            batch.weights.push(weight / max_weight);
337            batch.indices.push(idx);
338        }
339
340        Ok(batch)
341    }
342
343    /// Update priorities for previously sampled indices.
344    pub fn update_priorities(
345        &mut self,
346        indices: &[usize],
347        priorities: &[f64],
348    ) -> Result<(), RloxError> {
349        if indices.len() != priorities.len() {
350            return Err(RloxError::BufferError(
351                "indices and priorities must have same length".into(),
352            ));
353        }
354        for (&idx, &p) in indices.iter().zip(priorities.iter()) {
355            if p < 0.0 {
356                return Err(RloxError::BufferError(
357                    "priority must be non-negative".into(),
358                ));
359            }
360            if idx >= self.count {
361                return Err(RloxError::BufferError(format!(
362                    "index {} >= buffer len {}",
363                    idx, self.count
364                )));
365            }
366            let p_alpha = p.powf(self.alpha);
367            self.tree.set(idx, p_alpha);
368            if p_alpha > self.max_priority {
369                self.max_priority = p_alpha;
370            }
371        }
372        Ok(())
373    }
374
375    /// Find the minimum non-zero probability among valid leaves (O(1) via min-tree).
376    fn tree_min_prob(&self) -> f64 {
377        let total = self.tree.total();
378        if total == 0.0 {
379            return 1.0;
380        }
381        let min_p = self.tree.min();
382        if min_p <= 0.0 || min_p == f64::INFINITY {
383            1.0 / self.count as f64
384        } else {
385            min_p / total
386        }
387    }
388}
389
// ---------------------------------------------------------------------------
// Loss-Adjusted Prioritization (LAP)
// ---------------------------------------------------------------------------
393
/// Loss-Adjusted Prioritization (LAP) configuration.
///
/// LAP uses `priority = |TD_error| + eta * loss` where `eta` is a
/// configurable scale factor. This provides better prioritization
/// for actor-critic methods where TD error alone is insufficient.
/// The result is floored at `min_priority` (see [`compute_lap_priorities`]).
///
/// Reference: Schaul et al. (2015) extended with Fujimoto et al. (2020) LAP.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct LAPConfig {
    /// Scale factor for the loss component.
    pub eta: f64,
    /// Minimum priority to avoid zero-probability sampling.
    pub min_priority: f64,
}

impl Default for LAPConfig {
    /// `eta = 1.0`, `min_priority = 1e-6`.
    fn default() -> Self {
        Self {
            eta: 1.0,
            min_priority: 1e-6,
        }
    }
}
416
417/// Compute LAP priorities from TD errors and losses.
418///
419/// `priority[i] = max(|td_errors[i]| + eta * losses[i], min_priority)`
420#[inline]
421pub fn compute_lap_priorities(
422    td_errors: &[f64],
423    losses: &[f64],
424    config: &LAPConfig,
425) -> Result<Vec<f64>, RloxError> {
426    if td_errors.len() != losses.len() {
427        return Err(RloxError::ShapeMismatch {
428            expected: format!("td_errors.len()={}", td_errors.len()),
429            got: format!("losses.len()={}", losses.len()),
430        });
431    }
432
433    let priorities = td_errors
434        .iter()
435        .zip(losses.iter())
436        .map(|(&td, &loss)| (td.abs() + config.eta * loss).max(config.min_priority))
437        .collect();
438    Ok(priorities)
439}
440
/// Convenience: compute priorities from TD errors only (standard PER).
///
/// `priority[i] = max(|td_errors[i]|, min_priority)`
///
/// Returns one priority per TD error, in input order; an empty slice yields
/// an empty vector. A NaN TD error maps to `min_priority`, because
/// `f64::max` returns the non-NaN operand.
#[inline]
#[must_use]
pub fn compute_td_priorities(td_errors: &[f64], min_priority: f64) -> Vec<f64> {
    td_errors
        .iter()
        .map(|&td| td.abs().max(min_priority))
        .collect()
}
451
452impl PrioritizedReplayBuffer {
453    /// Update priorities from raw loss values using LAP.
454    ///
455    /// Convenience method: computes `priority = |loss| + epsilon` for each
456    /// index, then calls `update_priorities`.
457    pub fn update_priorities_from_loss(
458        &mut self,
459        indices: &[usize],
460        losses: &[f64],
461        epsilon: f64,
462    ) -> Result<(), RloxError> {
463        if indices.len() != losses.len() {
464            return Err(RloxError::ShapeMismatch {
465                expected: format!("indices.len()={}", indices.len()),
466                got: format!("losses.len()={}", losses.len()),
467            });
468        }
469        let priorities: Vec<f64> = losses.iter().map(|&l| l.abs() + epsilon).collect();
470        self.update_priorities(indices, &priorities)
471    }
472}
473
474#[cfg(test)]
475mod tests {
476    use super::*;
477    use crate::buffer::{sample_record, sample_record_multidim};
478
479    // ---- SumTree tests ----
480
481    #[test]
482    fn sum_tree_new_has_zero_total() {
483        let tree = SumTree::new(8);
484        assert_eq!(tree.total(), 0.0);
485        assert_eq!(tree.capacity(), 8);
486    }
487
488    #[test]
489    fn sum_tree_set_and_get() {
490        let mut tree = SumTree::new(4);
491        tree.set(0, 1.0);
492        tree.set(1, 2.0);
493        tree.set(2, 3.0);
494        tree.set(3, 4.0);
495        assert_eq!(tree.get(0), 1.0);
496        assert_eq!(tree.get(1), 2.0);
497        assert_eq!(tree.get(2), 3.0);
498        assert_eq!(tree.get(3), 4.0);
499        assert_eq!(tree.total(), 10.0);
500    }
501
502    #[test]
503    fn sum_tree_update_propagates() {
504        let mut tree = SumTree::new(4);
505        tree.set(0, 1.0);
506        tree.set(1, 1.0);
507        tree.set(2, 1.0);
508        tree.set(3, 1.0);
509        assert_eq!(tree.total(), 4.0);
510
511        tree.set(2, 5.0);
512        assert_eq!(tree.total(), 8.0);
513        assert_eq!(tree.get(2), 5.0);
514    }
515
516    #[test]
517    fn sum_tree_sample_returns_correct_leaf() {
518        let mut tree = SumTree::new(4);
519        tree.set(0, 1.0);
520        tree.set(1, 2.0);
521        tree.set(2, 3.0);
522        tree.set(3, 4.0);
523        // total = 10; prefix sums: [0,1) -> 0, [1,3) -> 1, [3,6) -> 2, [6,10) -> 3
524        assert_eq!(tree.sample(0.0), 0);
525        assert_eq!(tree.sample(0.5), 0);
526        assert_eq!(tree.sample(1.0), 1);
527        assert_eq!(tree.sample(2.9), 1);
528        assert_eq!(tree.sample(3.0), 2);
529        assert_eq!(tree.sample(5.9), 2);
530        assert_eq!(tree.sample(6.0), 3);
531        assert_eq!(tree.sample(9.9), 3);
532    }
533
534    #[test]
535    fn sum_tree_single_leaf() {
536        let mut tree = SumTree::new(1);
537        tree.set(0, 5.0);
538        assert_eq!(tree.total(), 5.0);
539        assert_eq!(tree.sample(0.0), 0);
540        assert_eq!(tree.sample(4.9), 0);
541    }
542
543    #[test]
544    fn sum_tree_min_tracks_minimum() {
545        let mut tree = SumTree::new(4);
546        tree.set(0, 3.0);
547        tree.set(1, 1.0);
548        tree.set(2, 5.0);
549        tree.set(3, 2.0);
550        assert!((tree.min() - 1.0).abs() < 1e-12);
551
552        // Update minimum
553        tree.set(1, 10.0);
554        assert!((tree.min() - 2.0).abs() < 1e-12);
555
556        // Set new minimum
557        tree.set(3, 0.5);
558        assert!((tree.min() - 0.5).abs() < 1e-12);
559    }
560
561    #[test]
562    fn sum_tree_min_empty_is_infinity() {
563        let tree = SumTree::new(4);
564        assert!(tree.min().is_infinite());
565    }
566
567    // ---- PrioritizedReplayBuffer tests ----
568
569    #[test]
570    fn prb_new_is_empty() {
571        let buf = PrioritizedReplayBuffer::new(100, 4, 1, 0.6, 0.4);
572        assert_eq!(buf.len(), 0);
573        assert!(buf.is_empty());
574    }
575
576    #[test]
577    fn prb_push_increments_len() {
578        let mut buf = PrioritizedReplayBuffer::new(100, 4, 1, 0.6, 0.4);
579        buf.push(sample_record(4), 1.0).unwrap();
580        assert_eq!(buf.len(), 1);
581    }
582
583    #[test]
584    fn prb_negative_priority_errors() {
585        let mut buf = PrioritizedReplayBuffer::new(100, 4, 1, 0.6, 0.4);
586        let result = buf.push(sample_record(4), -1.0);
587        assert!(result.is_err());
588        assert!(result.unwrap_err().to_string().contains("non-negative"));
589    }
590
591    #[test]
592    fn prb_sample_empty_errors() {
593        let buf = PrioritizedReplayBuffer::new(100, 4, 1, 0.6, 0.4);
594        let result = buf.sample(1, 42);
595        assert!(result.is_err());
596    }
597
598    #[test]
599    fn prb_sample_too_large_errors() {
600        let mut buf = PrioritizedReplayBuffer::new(100, 4, 1, 0.6, 0.4);
601        buf.push(sample_record(4), 1.0).unwrap();
602        let result = buf.sample(10, 42);
603        assert!(result.is_err());
604    }
605
606    #[test]
607    fn prb_sample_returns_correct_size() {
608        let mut buf = PrioritizedReplayBuffer::new(100, 4, 1, 0.6, 0.4);
609        for _ in 0..50 {
610            buf.push(sample_record(4), 1.0).unwrap();
611        }
612        let batch = buf.sample(16, 42).unwrap();
613        assert_eq!(batch.batch_size, 16);
614        assert_eq!(batch.observations.len(), 16 * 4);
615        assert_eq!(batch.actions.len(), 16);
616        assert_eq!(batch.rewards.len(), 16);
617        assert_eq!(batch.weights.len(), 16);
618        assert_eq!(batch.indices.len(), 16);
619    }
620
621    #[test]
622    fn prb_weights_are_in_zero_one() {
623        let mut buf = PrioritizedReplayBuffer::new(100, 4, 1, 0.6, 0.4);
624        for i in 0..50 {
625            buf.push(sample_record(4), (i + 1) as f64).unwrap();
626        }
627        let batch = buf.sample(16, 42).unwrap();
628        for &w in &batch.weights {
629            assert!(w > 0.0, "weight must be positive, got {w}");
630            assert!(w <= 1.0 + 1e-10, "weight must be <= 1.0, got {w}");
631        }
632    }
633
634    #[test]
635    fn prb_high_priority_sampled_more_often() {
636        let mut buf = PrioritizedReplayBuffer::new(100, 4, 1, 1.0, 0.4);
637        // Index 0 gets priority 100, rest get priority 1
638        let mut rec = sample_record(4);
639        rec.reward = 99.0;
640        buf.push(rec, 100.0).unwrap();
641        for _ in 1..50 {
642            buf.push(sample_record(4), 1.0).unwrap();
643        }
644
645        // Sample many times and count how often index 0 appears
646        let mut count_high = 0;
647        for seed in 0..100 {
648            let batch = buf.sample(10, seed).unwrap();
649            for &idx in &batch.indices {
650                if idx == 0 {
651                    count_high += 1;
652                }
653            }
654        }
655        // With priority 100 vs 49*1 = 149 total, idx 0 should appear ~67% of time
656        assert!(
657            count_high > 200,
658            "high priority item should be sampled frequently, got {count_high}/1000"
659        );
660    }
661
662    #[test]
663    fn prb_update_priorities() {
664        let mut buf = PrioritizedReplayBuffer::new(100, 4, 1, 1.0, 0.4);
665        for _ in 0..10 {
666            buf.push(sample_record(4), 1.0).unwrap();
667        }
668        // Update index 0 to very high priority
669        buf.update_priorities(&[0], &[100.0]).unwrap();
670
671        let mut count_idx0 = 0;
672        for seed in 0..50 {
673            let batch = buf.sample(5, seed).unwrap();
674            for &idx in &batch.indices {
675                if idx == 0 {
676                    count_idx0 += 1;
677                }
678            }
679        }
680        assert!(
681            count_idx0 > 100,
682            "updated high-priority item should be sampled frequently"
683        );
684    }
685
686    #[test]
687    fn prb_update_priorities_negative_errors() {
688        let mut buf = PrioritizedReplayBuffer::new(100, 4, 1, 0.6, 0.4);
689        buf.push(sample_record(4), 1.0).unwrap();
690        let result = buf.update_priorities(&[0], &[-1.0]);
691        assert!(result.is_err());
692    }
693
694    #[test]
695    fn prb_update_priorities_oob_errors() {
696        let mut buf = PrioritizedReplayBuffer::new(100, 4, 1, 0.6, 0.4);
697        buf.push(sample_record(4), 1.0).unwrap();
698        let result = buf.update_priorities(&[5], &[1.0]);
699        assert!(result.is_err());
700    }
701
702    #[test]
703    fn prb_update_priorities_length_mismatch_errors() {
704        let mut buf = PrioritizedReplayBuffer::new(100, 4, 1, 0.6, 0.4);
705        buf.push(sample_record(4), 1.0).unwrap();
706        let result = buf.update_priorities(&[0], &[1.0, 2.0]);
707        assert!(result.is_err());
708    }
709
710    #[test]
711    fn prb_set_beta() {
712        let mut buf = PrioritizedReplayBuffer::new(100, 4, 1, 0.6, 0.4);
713        buf.set_beta(1.0);
714        // Just ensure it doesn't panic; beta affects weight computation
715    }
716
717    #[test]
718    fn prb_wraps_around() {
719        let mut buf = PrioritizedReplayBuffer::new(5, 4, 1, 0.6, 0.4);
720        for i in 0..10 {
721            let mut rec = sample_record(4);
722            rec.reward = i as f32;
723            buf.push(rec, 1.0).unwrap();
724        }
725        assert_eq!(buf.len(), 5);
726        // Should contain rewards 5..10
727        let batch = buf.sample(5, 42).unwrap();
728        for &r in &batch.rewards {
729            assert!(r >= 5.0, "old data should be overwritten, got reward {r}");
730        }
731    }
732
733    #[test]
734    fn prb_multidim_actions() {
735        let mut buf = PrioritizedReplayBuffer::new(100, 4, 3, 0.6, 0.4);
736        buf.push(sample_record_multidim(4, 3), 1.0).unwrap();
737        let batch = buf.sample(1, 42).unwrap();
738        assert_eq!(batch.act_dim, 3);
739        assert_eq!(batch.actions.len(), 3);
740    }
741
742    #[test]
743    fn prb_deterministic_with_same_seed() {
744        let mut buf = PrioritizedReplayBuffer::new(100, 4, 1, 0.6, 0.4);
745        for _ in 0..50 {
746            buf.push(sample_record(4), 1.0).unwrap();
747        }
748        let b1 = buf.sample(16, 42).unwrap();
749        let b2 = buf.sample(16, 42).unwrap();
750        assert_eq!(b1.indices, b2.indices);
751        assert_eq!(b1.weights, b2.weights);
752    }
753
754    #[test]
755    fn prb_next_obs_roundtrip() {
756        let obs_dim = 4;
757        let mut buf = PrioritizedReplayBuffer::new(100, obs_dim, 1, 0.6, 0.4);
758        let record = ExperienceRecord {
759            obs: vec![1.0; obs_dim],
760            next_obs: vec![2.0, 3.0, 4.0, 5.0],
761            action: vec![0.0],
762            reward: 1.0,
763            terminated: false,
764            truncated: false,
765        };
766        buf.push(record, 1.0).unwrap();
767        let batch = buf.sample(1, 42).unwrap();
768        assert_eq!(&batch.next_observations, &[2.0, 3.0, 4.0, 5.0]);
769    }
770
771    #[test]
772    fn prb_next_obs_shape() {
773        let obs_dim = 4;
774        let mut buf = PrioritizedReplayBuffer::new(200, obs_dim, 1, 0.6, 0.4);
775        for _ in 0..100 {
776            buf.push(sample_record(obs_dim), 1.0).unwrap();
777        }
778        let batch = buf.sample(32, 42).unwrap();
779        assert_eq!(batch.next_observations.len(), 32 * obs_dim);
780    }
781
782    #[test]
783    fn prb_obs_dim_mismatch_errors() {
784        let mut buf = PrioritizedReplayBuffer::new(100, 4, 1, 0.6, 0.4);
785        let result = buf.push(sample_record(8), 1.0);
786        assert!(result.is_err());
787    }
788
789    // ---- Proptests ----
790
791    mod proptests {
792        use super::*;
793        use proptest::prelude::*;
794
795        proptest! {
796            #[test]
797            fn sum_tree_total_equals_sum_of_leaves(
798                priorities in proptest::collection::vec(0.0f64..100.0, 1..64)
799            ) {
800                let n = priorities.len();
801                let mut tree = SumTree::new(n);
802                let mut expected = 0.0;
803                for (i, &p) in priorities.iter().enumerate() {
804                    tree.set(i, p);
805                    expected += p;
806                }
807                let diff = (tree.total() - expected).abs();
808                prop_assert!(diff < 1e-6, "total {} != expected {}", tree.total(), expected);
809            }
810
811            #[test]
812            fn sum_tree_sample_in_range(
813                priorities in proptest::collection::vec(1.0f64..100.0, 1..64)
814            ) {
815                let n = priorities.len();
816                let mut tree = SumTree::new(n);
817                for (i, &p) in priorities.iter().enumerate() {
818                    tree.set(i, p);
819                }
820                // Sample several values
821                let total = tree.total();
822                for v in [0.0, total * 0.25, total * 0.5, total * 0.75, total * 0.999] {
823                    let idx = tree.sample(v);
824                    prop_assert!(idx < n, "sampled index {} >= capacity {}", idx, n);
825                }
826            }
827
828            #[test]
829            fn prb_never_exceeds_capacity(
830                capacity in 1..100usize,
831                num_pushes in 0..300usize,
832            ) {
833                let mut buf = PrioritizedReplayBuffer::new(capacity, 4, 1, 0.6, 0.4);
834                for _ in 0..num_pushes {
835                    buf.push(sample_record(4), 1.0).unwrap();
836                }
837                prop_assert!(buf.len() <= capacity);
838                prop_assert_eq!(buf.len(), num_pushes.min(capacity));
839            }
840
841            #[test]
842            fn prb_weights_are_valid(
843                num_pushes in 10..100usize,
844                batch_size in 1..10usize,
845            ) {
846                let mut buf = PrioritizedReplayBuffer::new(200, 4, 1, 0.6, 0.4);
847                for i in 0..num_pushes {
848                    buf.push(sample_record(4), (i + 1) as f64).unwrap();
849                }
850                let effective_batch = batch_size.min(buf.len());
851                let batch = buf.sample(effective_batch, 42).unwrap();
852                for &w in &batch.weights {
853                    prop_assert!(w > 0.0, "weight must be positive");
854                    prop_assert!(w <= 1.0 + 1e-10, "weight must be <= 1.0");
855                }
856                for &idx in &batch.indices {
857                    prop_assert!(idx < buf.len(), "index must be < len");
858                }
859            }
860        }
861    }
862
863    // ---- LAP tests ----
864
865    #[test]
866    fn test_lap_known_values() {
867        let td = &[1.0, -2.0];
868        let loss = &[0.5, 0.3];
869        let config = LAPConfig {
870            eta: 1.0,
871            min_priority: 1e-6,
872        };
873        let result = compute_lap_priorities(td, loss, &config).unwrap();
874        assert!(
875            (result[0] - 1.5).abs() < 1e-10,
876            "expected 1.5, got {}",
877            result[0]
878        );
879        assert!(
880            (result[1] - 2.3).abs() < 1e-10,
881            "expected 2.3, got {}",
882            result[1]
883        );
884    }
885
886    #[test]
887    fn test_lap_eta_zero_is_standard_per() {
888        let td = &[1.0, -2.0];
889        let loss = &[999.0, 999.0];
890        let config = LAPConfig {
891            eta: 0.0,
892            min_priority: 1e-6,
893        };
894        let result = compute_lap_priorities(td, loss, &config).unwrap();
895        assert!((result[0] - 1.0).abs() < 1e-10);
896        assert!((result[1] - 2.0).abs() < 1e-10);
897    }
898
// A zero TD error combined with a zero loss is expected to come back as
// the configured `min_priority` rather than zero.
#[test]
fn test_lap_min_priority_floor() {
    let floor_cfg = LAPConfig {
        eta: 1.0,
        min_priority: 1e-6,
    };
    let priorities = compute_lap_priorities(&[0.0], &[0.0], &floor_cfg).unwrap();

    let floored = priorities[0];
    assert!(
        (floored - 1e-6).abs() < 1e-12,
        "expected 1e-6, got {}",
        floored
    );
}
914
// A negative TD error with zero loss is expected to yield its absolute
// value as the priority under the default configuration.
#[test]
fn test_lap_negative_td_uses_abs() {
    let priorities = compute_lap_priorities(&[-3.0], &[0.0], &LAPConfig::default()).unwrap();
    let deviation = (priorities[0] - 3.0).abs();
    assert!(deviation < 1e-10);
}
923
// Three TD errors against two loss values must be rejected with
// `RloxError::ShapeMismatch`.
#[test]
fn test_lap_length_mismatch() {
    let td_errors = [1.0, 2.0, 3.0];
    let losses = [0.5, 0.5];
    let outcome = compute_lap_priorities(&td_errors, &losses, &LAPConfig::default());
    assert!(matches!(outcome, Err(RloxError::ShapeMismatch { .. })));
}
929
// TD-only priorities with a second argument of 0.01: the expected results
// are 0.5, 1.0 and, for the zero TD error, 0.01.
#[test]
fn test_td_priorities_known_values() {
    let td_errors = [0.5, -1.0, 0.0];
    let priorities = compute_td_priorities(&td_errors, 0.01);

    let expected: [f64; 3] = [0.5, 1.0, 0.01];
    for (i, &want) in expected.iter().enumerate() {
        assert!((priorities[i] - want).abs() < 1e-10);
    }
}
938
// End-to-end check: priorities produced by `compute_lap_priorities` are
// pushed into a `PrioritizedReplayBuffer`, then a batch of 5 is sampled and
// every importance weight must lie in (0, 1] (with a 1e-10 tolerance).
#[test]
fn test_lap_integration_with_per_buffer() {
    let lap_cfg = LAPConfig::default();
    let mut buf = PrioritizedReplayBuffer::new(100, 4, 1, 0.6, 0.4);

    // Step k uses a TD error of k and a loss of (k - 1) * 0.5, for k = 1..=10.
    for k in 1..=10 {
        let td_error = [k as f64];
        let loss = [(k - 1) as f64 * 0.5];
        let priority = compute_lap_priorities(&td_error, &loss, &lap_cfg).unwrap()[0];
        buf.push(sample_record(4), priority).unwrap();
    }

    let batch = buf.sample(5, 42).unwrap();
    assert_eq!(batch.batch_size, 5);
    for w in batch.weights.iter() {
        assert!(*w > 0.0 && *w <= 1.0 + 1e-10);
    }
}
955
// With `eta = 100.0` the loss term dominates: the expected priority is
// 0.01 + 100.0 * 10.0.
#[test]
fn test_lap_large_eta_emphasizes_loss() {
    let td_error = 0.01;
    let loss = 10.0;
    let cfg = LAPConfig {
        eta: 100.0,
        min_priority: 1e-6,
    };
    let priorities = compute_lap_priorities(&[td_error], &[loss], &cfg).unwrap();

    let expected = td_error + 100.0 * loss;
    let deviation = (priorities[0] - expected).abs();
    assert!(deviation < 1e-10);
}
968
// Updates three of ten stored transitions from loss values and checks that
// the buffer is still consistent afterwards.
//
// The intended priorities are |loss| + epsilon:
//   index 0 -> 5.01, index 1 -> 0.01, index 2 -> 3.01.
// NOTE(review): those exact values are not asserted because no per-index
// priority accessor is visible from this test; if the buffer exposes one,
// assert the three values directly. Until then the test at least verifies
// observable invariants instead of asserting nothing.
#[test]
fn test_update_priorities_from_loss_correct() {
    let mut buf = PrioritizedReplayBuffer::new(100, 4, 1, 0.6, 0.4);
    for _ in 0..10 {
        buf.push(sample_record(4), 1.0).unwrap();
    }
    buf.update_priorities_from_loss(&[0, 1, 2], &[5.0, 0.0, 3.0], 0.01)
        .unwrap();

    // Updating priorities must not add or remove transitions.
    assert_eq!(buf.len(), 10, "priority update must not change len");

    // Sampling must still work and produce in-range indices and weights.
    let batch = buf.sample(5, 42).unwrap();
    assert_eq!(batch.batch_size, 5);
    for &idx in &batch.indices {
        assert!(idx < buf.len(), "index {idx} must be < len");
    }
    for &w in &batch.weights {
        assert!(
            w > 0.0 && w <= 1.0 + 1e-10,
            "weight {w} must be in (0, 1] after priority update"
        );
    }
}
980
// Passing two indices but only one loss value must produce an error.
#[test]
fn test_update_priorities_from_loss_mismatched_lengths() {
    let mut buf = PrioritizedReplayBuffer::new(100, 4, 1, 0.6, 0.4);
    (0..10).for_each(|_| {
        buf.push(sample_record(4), 1.0).unwrap();
    });

    let indices = [0, 1];
    let losses = [5.0];
    let outcome = buf.update_priorities_from_loss(&indices, &losses, 0.01);
    assert!(outcome.is_err());
}
990
// A zero loss must still leave the transition with a usable, non-zero
// priority thanks to `epsilon`.
//
// The intended stored priority is epsilon^alpha = 0.001^1.0 = 0.001.
// NOTE(review): the exact value is not asserted because no per-index
// priority accessor is visible from this test; add a direct assertion if the
// buffer exposes one. The checks below look for the observable symptom
// instead: presumably a zero priority would show up as a non-finite or zero
// importance weight when sampling — confirm against `sample`.
#[test]
fn test_update_priorities_from_loss_zero_gets_epsilon() {
    let mut buf = PrioritizedReplayBuffer::new(100, 4, 1, 1.0, 0.4);
    for _ in 0..5 {
        buf.push(sample_record(4), 1.0).unwrap();
    }
    let epsilon = 0.001;
    buf.update_priorities_from_loss(&[0], &[0.0], epsilon)
        .unwrap();

    // Updating priorities must not add or remove transitions.
    assert_eq!(buf.len(), 5, "priority update must not change len");

    // Sampling must still succeed with well-formed indices and weights.
    let batch = buf.sample(3, 42).unwrap();
    assert_eq!(batch.batch_size, 3);
    for &idx in &batch.indices {
        assert!(idx < buf.len(), "index {idx} must be < len");
    }
    for &w in &batch.weights {
        assert!(
            w.is_finite() && w > 0.0 && w <= 1.0 + 1e-10,
            "weight {w} must be finite and in (0, 1] after a zero-loss update"
        );
    }
}
1002
// One transition is given a loss of 100.0 and the other nine a loss of
// 0.001. Across 100 differently-seeded batches of 5 (500 draws in total),
// index 0 is expected to be drawn more than 100 times.
#[test]
fn test_update_priorities_from_loss_high_loss_favored() {
    let mut buf = PrioritizedReplayBuffer::new(100, 4, 1, 1.0, 0.4);
    (0..10).for_each(|_| {
        buf.push(sample_record(4), 1.0).unwrap();
    });

    // Index 0 receives a very high loss ...
    buf.update_priorities_from_loss(&[0], &[100.0], 0.01)
        .unwrap();
    // ... while every other index receives a low one.
    for low_idx in 1..10 {
        buf.update_priorities_from_loss(&[low_idx], &[0.001], 0.01)
            .unwrap();
    }

    // Tally how often index 0 shows up over all seeded batches.
    let count_idx0: usize = (0..100)
        .map(|seed| {
            let batch = buf.sample(5, seed).unwrap();
            batch.indices.iter().filter(|&&idx| idx == 0).count()
        })
        .sum();
    assert!(
        count_idx0 > 100,
        "high-loss item should be sampled frequently, got {count_idx0}/500"
    );
}
1032
1033    mod lap_proptests {
1034        use super::*;
1035        use proptest::prelude::*;
1036
1037        proptest! {
1038            #[test]
1039            fn prop_lap_priorities_non_negative(
1040                n in 1usize..100,
1041            ) {
1042                let td: Vec<f64> = (0..n).map(|i| (i as f64) - 50.0).collect();
1043                let loss: Vec<f64> = (0..n).map(|i| (i as f64) * 0.1).collect();
1044                let config = LAPConfig::default();
1045                let priorities = compute_lap_priorities(&td, &loss, &config).unwrap();
1046                for (i, &p) in priorities.iter().enumerate() {
1047                    prop_assert!(p >= config.min_priority,
1048                        "priority[{i}]={p} < min_priority={}", config.min_priority);
1049                }
1050            }
1051
1052            #[test]
1053            fn prop_lap_monotone_in_td(
1054                td_a in 0.0f64..100.0,
1055                td_b in 0.0f64..100.0,
1056                loss in 0.0f64..10.0,
1057            ) {
1058                let config = LAPConfig {
1059                    eta: 1.0,
1060                    min_priority: 0.0, // disable floor for this test
1061                };
1062                let p_a = compute_lap_priorities(&[td_a], &[loss], &config).unwrap()[0];
1063                let p_b = compute_lap_priorities(&[td_b], &[loss], &config).unwrap()[0];
1064                if td_a.abs() > td_b.abs() {
1065                    prop_assert!(p_a >= p_b,
1066                        "|td_a|={} > |td_b|={} but p_a={} < p_b={}",
1067                        td_a.abs(), td_b.abs(), p_a, p_b);
1068                }
1069            }
1070        }
1071    }
1072}