rlox_core/buffer/
episode.rs

1//! Episode boundary tracking for ring buffers.
2//!
3//! Provides [`EpisodeTracker`] which maintains metadata about episode boundaries
4//! within a ring buffer, enabling sequence sampling and HER-style relabeling.
5
6use rand::Rng;
7use rand::SeedableRng;
8use rand_chacha::ChaCha8Rng;
9
10use crate::error::RloxError;
11
/// Common interface for buffer components that follow episode boundaries.
///
/// Implementors are told about every transition written into a ring buffer
/// and about every region that is about to be overwritten, and report how
/// many finished episodes they still know about. Used by
/// SequenceReplayBuffer and HERBuffer.
pub trait EpisodeAware {
    /// Called after a transition has been written at `write_pos`.
    ///
    /// `done` is the transition's done flag.
    fn notify_push(&mut self, write_pos: usize, done: bool);

    /// Called when `count` slots starting at `write_pos` are overwritten, so
    /// that every episode touching that region can be discarded.
    fn invalidate_overwritten(&mut self, write_pos: usize, count: usize);

    /// How many complete (terminated/truncated) episodes are currently tracked.
    fn num_complete_episodes(&self) -> usize;
}
26
/// Bookkeeping record describing one episode stored in the ring buffer.
#[derive(Clone, Copy, Debug)]
pub struct EpisodeMeta {
    /// Ring-buffer position of the episode's first transition.
    pub start: usize,
    /// How many transitions the episode contains.
    pub length: usize,
    /// `true` once the episode has seen a transition with `done = true`.
    ///
    /// An episode that is still being built is incomplete; incomplete
    /// episodes are not returned by `eligible_episodes`.
    pub complete: bool,
}
39
/// A run of consecutive transitions taken from a single episode, as used
/// for sequence sampling.
#[derive(Clone, Copy, Debug)]
pub struct EpisodeWindow {
    /// Which episode (by index) the window was cut from.
    pub episode_idx: usize,
    /// Ring-buffer position of the window's first transition.
    pub ring_start: usize,
    /// How many transitions the window spans.
    pub length: usize,
}
51
52/// Tracks episode boundaries within a ring buffer.
53///
54/// Maintains a list of [`EpisodeMeta`] entries, invalidating episodes
55/// whose data has been overwritten by the ring buffer's write pointer.
56/// Provides efficient sampling of contiguous windows for sequence models.
57#[derive(Debug)]
58pub struct EpisodeTracker {
59    ring_capacity: usize,
60    episodes: Vec<EpisodeMeta>,
61    /// Start position of the episode currently being built.
62    current_episode_start: Option<usize>,
63    current_episode_length: usize,
64}
65
66impl EpisodeTracker {
67    /// Create a new episode tracker for a ring buffer with the given capacity.
68    pub fn new(ring_capacity: usize) -> Self {
69        Self {
70            ring_capacity,
71            episodes: Vec::new(),
72            current_episode_start: None,
73            current_episode_length: 0,
74        }
75    }
76
77    /// Record a push at `write_pos`. If `done`, the current episode is finalized.
78    #[inline]
79    pub fn notify_push(&mut self, write_pos: usize, done: bool) {
80        if self.current_episode_start.is_none() {
81            self.current_episode_start = Some(write_pos);
82            self.current_episode_length = 0;
83        }
84
85        self.current_episode_length += 1;
86
87        if done {
88            self.episodes.push(EpisodeMeta {
89                start: self.current_episode_start.take().unwrap_or(write_pos),
90                length: self.current_episode_length,
91                complete: true,
92            });
93            self.current_episode_start = None;
94            self.current_episode_length = 0;
95        }
96    }
97
98    /// Remove any episodes whose transitions have been overwritten.
99    ///
100    /// Called when the ring buffer wraps. An episode is invalidated if any
101    /// of its positions overlap with the region `[write_pos, write_pos + count)`
102    /// modulo the ring capacity.
103    #[inline]
104    pub fn invalidate_overwritten(&mut self, write_pos: usize, count: usize) {
105        self.episodes.retain(|ep| {
106            !ring_range_overlaps(ep.start, ep.length, write_pos, count, self.ring_capacity)
107        });
108
109        // Also invalidate current in-progress episode if it overlaps
110        if let Some(start) = self.current_episode_start {
111            if ring_range_overlaps(
112                start,
113                self.current_episode_length,
114                write_pos,
115                count,
116                self.ring_capacity,
117            ) {
118                self.current_episode_start = None;
119                self.current_episode_length = 0;
120            }
121        }
122    }
123
124    /// Number of complete episodes currently tracked.
125    #[inline]
126    pub fn num_complete_episodes(&self) -> usize {
127        self.episodes.iter().filter(|ep| ep.complete).count()
128    }
129
130    /// All currently tracked episodes (complete and in-progress).
131    pub fn episodes(&self) -> &[EpisodeMeta] {
132        &self.episodes
133    }
134
135    /// Indices of episodes long enough for a given sequence length.
136    pub fn eligible_episodes(&self, min_length: usize) -> Vec<usize> {
137        self.episodes
138            .iter()
139            .enumerate()
140            .filter(|(_, ep)| ep.complete && ep.length >= min_length)
141            .map(|(i, _)| i)
142            .collect()
143    }
144
145    /// Sample `batch_size` windows of `seq_len` consecutive transitions,
146    /// each entirely within a single complete episode.
147    ///
148    /// Uses ChaCha8Rng seeded with `seed`.
149    pub fn sample_windows(
150        &self,
151        batch_size: usize,
152        seq_len: usize,
153        seed: u64,
154    ) -> Result<Vec<EpisodeWindow>, RloxError> {
155        let eligible = self.eligible_episodes(seq_len);
156        if eligible.is_empty() {
157            return Err(RloxError::BufferError(format!(
158                "no episodes with length >= {seq_len}"
159            )));
160        }
161
162        let mut rng = ChaCha8Rng::seed_from_u64(seed);
163        let mut windows = Vec::with_capacity(batch_size);
164
165        for _ in 0..batch_size {
166            let ep_idx = eligible[rng.random_range(0..eligible.len())];
167            let ep = &self.episodes[ep_idx];
168
169            // Random start offset within the episode
170            let max_offset = ep.length - seq_len;
171            let offset = if max_offset == 0 {
172                0
173            } else {
174                rng.random_range(0..=max_offset)
175            };
176
177            let ring_start = (ep.start + offset) % self.ring_capacity;
178
179            windows.push(EpisodeWindow {
180                episode_idx: ep_idx,
181                ring_start,
182                length: seq_len,
183            });
184        }
185
186        Ok(windows)
187    }
188}
189
190impl EpisodeAware for EpisodeTracker {
191    fn notify_push(&mut self, write_pos: usize, done: bool) {
192        EpisodeTracker::notify_push(self, write_pos, done);
193    }
194
195    fn invalidate_overwritten(&mut self, write_pos: usize, count: usize) {
196        EpisodeTracker::invalidate_overwritten(self, write_pos, count);
197    }
198
199    fn num_complete_episodes(&self) -> usize {
200        EpisodeTracker::num_complete_episodes(self)
201    }
202}
203
/// Check if two ring-buffer ranges overlap in O(1) using modular arithmetic.
///
/// Range A: `[a_start, a_start + a_len)` mod `cap`
/// Range B: `[b_start, b_start + b_len)` mod `cap`
///
/// Two circular ranges overlap iff the start of either range falls
/// within the other range.
///
/// Start positions are reduced modulo `cap` first, so callers may pass
/// positions that have not been wrapped yet. An empty range never overlaps
/// anything, and a ring with `cap == 0` has no slots, so both cases return
/// `false`.
#[inline]
fn ring_range_overlaps(
    a_start: usize,
    a_len: usize,
    b_start: usize,
    b_len: usize,
    cap: usize,
) -> bool {
    if cap == 0 || a_len == 0 || b_len == 0 {
        return false;
    }

    let a = a_start % cap;
    let b = b_start % cap;

    // Number of steps needed to walk forward around the ring from `from` to
    // `to`; always in `[0, cap)`. Written without `from + cap - to` so it can
    // neither underflow nor overflow.
    let forward = |from: usize, to: usize| {
        if to >= from {
            to - from
        } else {
            cap - (from - to)
        }
    };

    let a_in_b = forward(b, a) < b_len;
    let b_in_a = forward(a, b) < a_len;
    a_in_b || b_in_a
}
226
#[cfg(test)]
mod tests {
    use super::*;

    /// Push `len` consecutive transitions starting at `start`, flagging only
    /// the last one as done. Returns the position right after the episode.
    fn push_episode(tracker: &mut EpisodeTracker, start: usize, len: usize) -> usize {
        for offset in 0..len {
            tracker.notify_push(start + offset, offset + 1 == len);
        }
        start + len
    }

    /// Push back-to-back episodes with the given lengths, starting at position 0.
    fn push_episodes(tracker: &mut EpisodeTracker, lengths: &[usize]) {
        let mut next = 0;
        for &len in lengths {
            next = push_episode(tracker, next, len);
        }
    }

    // A freshly built tracker knows no episodes at all.
    #[test]
    fn test_new_tracker_is_empty() {
        let tracker = EpisodeTracker::new(100);
        assert!(tracker.episodes().is_empty());
        assert_eq!(tracker.num_complete_episodes(), 0);
    }

    // Five pushes with `done` on the last one yield one episode [0, 5).
    #[test]
    fn test_single_episode_tracked() {
        let mut tracker = EpisodeTracker::new(100);
        push_episode(&mut tracker, 0, 5);

        assert_eq!(tracker.num_complete_episodes(), 1);
        let first = tracker.episodes()[0];
        assert_eq!(first.length, 5);
        assert_eq!(first.start, 0);
        assert!(first.complete);
    }

    // Three back-to-back episodes of lengths 3, 5 and 2.
    #[test]
    fn test_multiple_episodes() {
        let mut tracker = EpisodeTracker::new(100);
        let lengths = [3usize, 5, 2];
        push_episodes(&mut tracker, &lengths);

        assert_eq!(tracker.num_complete_episodes(), 3);
        for (idx, &want) in lengths.iter().enumerate() {
            assert_eq!(tracker.episodes()[idx].length, want);
        }
    }

    // Pushes that never see `done` must not show up as a complete episode.
    #[test]
    fn test_incomplete_episode_not_counted() {
        let mut tracker = EpisodeTracker::new(100);
        (0..5).for_each(|pos| tracker.notify_push(pos, false));
        assert_eq!(tracker.num_complete_episodes(), 0);
    }

    // Overwriting the head of an episode removes the episode.
    #[test]
    fn test_invalidate_removes_overwritten() {
        let mut tracker = EpisodeTracker::new(10);
        // Episode occupies positions 0..5.
        push_episode(&mut tracker, 0, 5);
        assert_eq!(tracker.num_complete_episodes(), 1);

        // Positions 0..3 get overwritten.
        tracker.invalidate_overwritten(0, 3);
        assert_eq!(tracker.num_complete_episodes(), 0);
    }

    // Every sampled window has to lie inside the episode it names.
    #[test]
    fn test_sample_windows_within_episode() {
        let mut tracker = EpisodeTracker::new(100);
        // Two episodes of ten transitions each: [0, 10) and [10, 20).
        push_episodes(&mut tracker, &[10, 10]);
        assert_eq!(tracker.num_complete_episodes(), 2);

        let windows = tracker.sample_windows(5, 5, 42).unwrap();
        assert_eq!(windows.len(), 5);
        for w in windows.iter() {
            assert_eq!(w.length, 5);

            let ep = &tracker.episodes()[w.episode_idx];
            let ep_end = ep.start + ep.length;
            let starts_inside = w.ring_start >= ep.start;
            let ends_inside = w.ring_start + w.length <= ep_end;
            assert!(
                starts_inside && ends_inside,
                "window [{}, {}) not within episode [{}, {})",
                w.ring_start,
                w.ring_start + w.length,
                ep.start,
                ep_end
            );
        }
    }

    // The same seed must reproduce the same windows.
    #[test]
    fn test_sample_windows_deterministic() {
        let mut tracker = EpisodeTracker::new(100);
        push_episode(&mut tracker, 0, 10);

        let first = tracker.sample_windows(5, 3, 42).unwrap();
        let second = tracker.sample_windows(5, 3, 42).unwrap();
        for (lhs, rhs) in first.iter().zip(second.iter()) {
            assert_eq!(lhs.ring_start, rhs.ring_start);
            assert_eq!(lhs.episode_idx, rhs.episode_idx);
            assert_eq!(lhs.length, rhs.length);
        }
    }

    // Asking for a sequence longer than any episode is an error.
    #[test]
    fn test_sample_windows_rejects_too_long_seq() {
        let mut tracker = EpisodeTracker::new(100);
        push_episode(&mut tracker, 0, 3);
        assert!(tracker.sample_windows(1, 5, 42).is_err());
    }

    // Only episodes with at least `min_length` transitions are eligible.
    #[test]
    fn test_eligible_episodes_filters_short() {
        let mut tracker = EpisodeTracker::new(100);
        // Episode 0: length 2, episode 1: length 5,
        // episode 2: length 3, episode 3: length 8.
        push_episodes(&mut tracker, &[2, 5, 3, 8]);

        assert_eq!(tracker.eligible_episodes(4), vec![1, 3]);
    }

    // Overwriting a single slot in the middle is enough to drop the episode.
    #[test]
    fn test_invalidate_partial_episode() {
        let mut tracker = EpisodeTracker::new(10);
        // Episode occupies positions 0..5.
        push_episode(&mut tracker, 0, 5);

        // Position 2 sits in the middle of that episode.
        tracker.invalidate_overwritten(2, 1);
        assert_eq!(
            tracker.num_complete_episodes(),
            0,
            "partially overwritten episode should be removed"
        );
    }

    // Two `done` pushes in a row produce two one-step episodes.
    #[test]
    fn test_consecutive_dones() {
        let mut tracker = EpisodeTracker::new(100);
        for pos in 0..2 {
            tracker.notify_push(pos, true);
        }
        assert_eq!(tracker.num_complete_episodes(), 2);
        assert_eq!(tracker.episodes()[0].length, 1);
        assert_eq!(tracker.episodes()[1].length, 1);
    }

    // Sampling from a tracker without episodes is an error.
    #[test]
    fn test_empty_tracker_sample_windows_errors() {
        let tracker = EpisodeTracker::new(100);
        assert!(tracker.sample_windows(1, 1, 42).is_err());
    }

    // A one-step episode can be sampled with a one-step window.
    #[test]
    fn test_single_transition_episode() {
        let mut tracker = EpisodeTracker::new(100);
        tracker.notify_push(0, true);
        assert_eq!(tracker.num_complete_episodes(), 1);

        let windows = tracker.sample_windows(1, 1, 42).unwrap();
        let only = &windows[0];
        assert_eq!(only.ring_start, 0);
        assert_eq!(only.length, 1);
    }

    // The trait must be usable behind `dyn`.
    #[test]
    fn test_trait_object_safety() {
        let boxed: Box<dyn EpisodeAware> = Box::new(EpisodeTracker::new(100));
        assert_eq!(boxed.num_complete_episodes(), 0);
    }

    mod proptests {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            // Without wrapping, every `done` push yields exactly one episode.
            #[test]
            fn prop_episode_count_matches_dones(
                n in 1usize..200,
                done_rate in 0.05f64..0.5,
            ) {
                // Capacity is twice the number of pushes, so nothing wraps.
                let mut tracker = EpisodeTracker::new(n * 2);
                let mut expected_complete: usize = 0;
                for step in 0..n {
                    // `done` fires whenever the scaled step counter crosses
                    // an integer boundary.
                    let before = (step as f64 * done_rate) as usize;
                    let after = ((step as f64 + 1.0) * done_rate) as usize;
                    let done = after > before;
                    tracker.notify_push(step, done);
                    expected_complete += usize::from(done);
                }
                prop_assert_eq!(
                    tracker.num_complete_episodes(),
                    expected_complete,
                    "expected {} complete episodes", expected_complete
                );
            }

            // Sampled windows never run past the end of the ring.
            #[test]
            fn prop_window_within_bounds(
                ep_len in 5usize..50,
                seq_len in 1usize..5,
                batch_size in 1usize..10,
            ) {
                let cap = ep_len * 3;
                let mut tracker = EpisodeTracker::new(cap);
                push_episode(&mut tracker, 0, ep_len);

                let windows = tracker.sample_windows(batch_size, seq_len, 42).unwrap();
                for w in windows.iter() {
                    prop_assert!(
                        w.ring_start + w.length <= cap,
                        "window [{}, {}) exceeds capacity {cap}",
                        w.ring_start,
                        w.ring_start + w.length
                    );
                }
            }

            // Sampled windows never straddle two episodes.
            #[test]
            fn prop_no_cross_episode_windows(
                n_episodes in 2usize..10,
                ep_len in 5usize..20,
                seq_len in 1usize..5,
            ) {
                let cap = n_episodes * ep_len * 2;
                let mut tracker = EpisodeTracker::new(cap);
                let mut next = 0;
                for _ in 0..n_episodes {
                    next = push_episode(&mut tracker, next, ep_len);
                }

                let windows = tracker.sample_windows(n_episodes * 2, seq_len, 42).unwrap();
                for w in windows.iter() {
                    let ep = &tracker.episodes()[w.episode_idx];
                    let ep_end = ep.start + ep.length;
                    let starts_inside = w.ring_start >= ep.start;
                    let ends_inside = w.ring_start + w.length <= ep_end;
                    prop_assert!(
                        starts_inside && ends_inside,
                        "window crosses episode boundary"
                    );
                }
            }

            // After arbitrary wrapping, surviving episodes start inside the ring.
            #[test]
            fn prop_invalidation_never_returns_overwritten(
                cap in 10usize..100,
                n_pushes in 1usize..300,
            ) {
                let mut tracker = EpisodeTracker::new(cap);
                for step in 0..n_pushes {
                    let slot = step % cap;
                    // Episodes of ~7 steps.
                    let done = step % 7 == 6;
                    if step >= cap {
                        // The ring has wrapped: the slot about to be written
                        // is invalidated first.
                        tracker.invalidate_overwritten(slot, 1);
                    }
                    tracker.notify_push(slot, done);
                }

                // Every remaining episode must have a valid start position.
                for ep in tracker.episodes().iter() {
                    prop_assert!(
                        ep.start < cap,
                        "episode start {} >= capacity {cap}",
                        ep.start
                    );
                }
            }
        }
    }
}