1use rand::Rng;
7use rand::SeedableRng;
8use rand_chacha::ChaCha8Rng;
9
10use crate::error::RloxError;
11
/// Callbacks through which a ring-buffer writer reports episode structure.
///
/// The trait is object safe (the tests in this file use it as
/// `Box<dyn EpisodeAware>`).
///
/// NOTE(review): the exact call protocol is defined by the buffer that drives
/// this trait, which is not visible here. The tests in this file call
/// `invalidate_overwritten` for a slot *before* `notify_push` reuses it.
pub trait EpisodeAware {
    /// Reports that one transition was written at ring position `write_pos`.
    ///
    /// `done` is `true` when this transition ends its episode.
    fn notify_push(&mut self, write_pos: usize, done: bool);

    /// Reports that `count` ring slots starting at `write_pos` are being
    /// overwritten, so anything tracked in that range is no longer valid.
    fn invalidate_overwritten(&mut self, write_pos: usize, count: usize);

    /// Returns how many complete episodes are currently tracked.
    fn num_complete_episodes(&self) -> usize;
}
26
/// Location and size of one episode inside the ring buffer.
///
/// Equality is field-wise, so two values are equal when they describe the
/// same ring range with the same completion flag.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct EpisodeMeta {
    /// Ring position of the episode's first transition.
    pub start: usize,
    /// Number of transitions in the episode.
    pub length: usize,
    /// Whether the episode was closed by a `done` transition.
    pub complete: bool,
}
39
/// A fixed-length slice of one episode, expressed in ring coordinates.
///
/// Equality is field-wise, which lets callers compare sampled batches
/// directly (for example when checking that sampling is reproducible).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct EpisodeWindow {
    /// Index of the source episode in the tracker's episode list at the time
    /// the window was sampled.
    pub episode_idx: usize,
    /// Ring position of the first transition in the window.
    pub ring_start: usize,
    /// Number of transitions in the window.
    pub length: usize,
}
51
/// Bookkeeping for episodes stored in a ring buffer.
///
/// Records where each finished episode lives in the ring and accumulates the
/// episode that is still being written.
#[derive(Debug)]
pub struct EpisodeTracker {
    /// Number of slots in the ring; used as the modulus for ring positions.
    ring_capacity: usize,
    /// Finished episodes that have not been invalidated, in the order they
    /// were completed.
    episodes: Vec<EpisodeMeta>,
    /// Ring position where the in-progress episode began, or `None` when no
    /// episode is currently being accumulated.
    current_episode_start: Option<usize>,
    /// Number of transitions pushed so far for the in-progress episode.
    current_episode_length: usize,
}
65
66impl EpisodeTracker {
67 pub fn new(ring_capacity: usize) -> Self {
69 Self {
70 ring_capacity,
71 episodes: Vec::new(),
72 current_episode_start: None,
73 current_episode_length: 0,
74 }
75 }
76
77 #[inline]
79 pub fn notify_push(&mut self, write_pos: usize, done: bool) {
80 if self.current_episode_start.is_none() {
81 self.current_episode_start = Some(write_pos);
82 self.current_episode_length = 0;
83 }
84
85 self.current_episode_length += 1;
86
87 if done {
88 self.episodes.push(EpisodeMeta {
89 start: self.current_episode_start.take().unwrap_or(write_pos),
90 length: self.current_episode_length,
91 complete: true,
92 });
93 self.current_episode_start = None;
94 self.current_episode_length = 0;
95 }
96 }
97
98 #[inline]
104 pub fn invalidate_overwritten(&mut self, write_pos: usize, count: usize) {
105 self.episodes.retain(|ep| {
106 !ring_range_overlaps(ep.start, ep.length, write_pos, count, self.ring_capacity)
107 });
108
109 if let Some(start) = self.current_episode_start {
111 if ring_range_overlaps(
112 start,
113 self.current_episode_length,
114 write_pos,
115 count,
116 self.ring_capacity,
117 ) {
118 self.current_episode_start = None;
119 self.current_episode_length = 0;
120 }
121 }
122 }
123
124 #[inline]
126 pub fn num_complete_episodes(&self) -> usize {
127 self.episodes.iter().filter(|ep| ep.complete).count()
128 }
129
130 pub fn episodes(&self) -> &[EpisodeMeta] {
132 &self.episodes
133 }
134
135 pub fn eligible_episodes(&self, min_length: usize) -> Vec<usize> {
137 self.episodes
138 .iter()
139 .enumerate()
140 .filter(|(_, ep)| ep.complete && ep.length >= min_length)
141 .map(|(i, _)| i)
142 .collect()
143 }
144
145 pub fn sample_windows(
150 &self,
151 batch_size: usize,
152 seq_len: usize,
153 seed: u64,
154 ) -> Result<Vec<EpisodeWindow>, RloxError> {
155 let eligible = self.eligible_episodes(seq_len);
156 if eligible.is_empty() {
157 return Err(RloxError::BufferError(format!(
158 "no episodes with length >= {seq_len}"
159 )));
160 }
161
162 let mut rng = ChaCha8Rng::seed_from_u64(seed);
163 let mut windows = Vec::with_capacity(batch_size);
164
165 for _ in 0..batch_size {
166 let ep_idx = eligible[rng.random_range(0..eligible.len())];
167 let ep = &self.episodes[ep_idx];
168
169 let max_offset = ep.length - seq_len;
171 let offset = if max_offset == 0 {
172 0
173 } else {
174 rng.random_range(0..=max_offset)
175 };
176
177 let ring_start = (ep.start + offset) % self.ring_capacity;
178
179 windows.push(EpisodeWindow {
180 episode_idx: ep_idx,
181 ring_start,
182 length: seq_len,
183 });
184 }
185
186 Ok(windows)
187 }
188}
189
190impl EpisodeAware for EpisodeTracker {
191 fn notify_push(&mut self, write_pos: usize, done: bool) {
192 EpisodeTracker::notify_push(self, write_pos, done);
193 }
194
195 fn invalidate_overwritten(&mut self, write_pos: usize, count: usize) {
196 EpisodeTracker::invalidate_overwritten(self, write_pos, count);
197 }
198
199 fn num_complete_episodes(&self) -> usize {
200 EpisodeTracker::num_complete_episodes(self)
201 }
202}
203
/// Returns `true` when two ranges on a ring of `cap` slots share a slot.
///
/// Range `a` covers `a_len` consecutive slots starting at `a_start`, range
/// `b` covers `b_len` slots starting at `b_start`; both wrap around the end
/// of the ring. Start positions are reduced modulo `cap`, so positions that
/// are not already in `0..cap` are accepted.
///
/// Empty ranges never overlap anything, and a ring with `cap == 0` has no
/// slots, so the result is `false` in both cases.
#[inline]
fn ring_range_overlaps(
    a_start: usize,
    a_len: usize,
    b_start: usize,
    b_len: usize,
    cap: usize,
) -> bool {
    if a_len == 0 || b_len == 0 || cap == 0 {
        return false;
    }

    // Normalising first keeps every subtraction below within `0..cap`, which
    // rules out both underflow and overflow.
    let a = a_start % cap;
    let b = b_start % cap;

    // Forward distance around the ring from `b` to `a`, and the reverse.
    let a_after_b = if a >= b { a - b } else { cap - (b - a) };
    let b_after_a = if a_after_b == 0 { 0 } else { cap - a_after_b };

    // Two ring ranges intersect exactly when one of them contains the other's
    // first slot.
    a_after_b < b_len || b_after_a < a_len
}
226
#[cfg(test)]
mod tests {
    use super::*;

    /// Pushes one complete episode of `len` transitions at consecutive
    /// positions starting at `start`, marking only the last one as done.
    /// Returns the position following the episode.
    fn push_episode(tracker: &mut EpisodeTracker, start: usize, len: usize) -> usize {
        for offset in 0..len {
            tracker.notify_push(start + offset, offset + 1 == len);
        }
        start + len
    }

    /// A freshly built tracker holds no episodes.
    #[test]
    fn test_new_tracker_is_empty() {
        let tracker = EpisodeTracker::new(100);
        assert_eq!(tracker.num_complete_episodes(), 0);
        assert!(tracker.episodes().is_empty());
    }

    /// Five pushes ending in `done` form one episode starting at slot 0.
    #[test]
    fn test_single_episode_tracked() {
        let mut tracker = EpisodeTracker::new(100);
        for i in 0..5 {
            tracker.notify_push(i, i == 4);
        }
        assert_eq!(tracker.num_complete_episodes(), 1);
        assert_eq!(tracker.episodes()[0].length, 5);
        assert_eq!(tracker.episodes()[0].start, 0);
        assert!(tracker.episodes()[0].complete);
    }

    /// Back-to-back episodes are recorded separately with their own lengths.
    #[test]
    fn test_multiple_episodes() {
        let mut tracker = EpisodeTracker::new(100);
        let mut pos = 0;
        for len in [3, 5, 2] {
            pos = push_episode(&mut tracker, pos, len);
        }
        assert_eq!(tracker.num_complete_episodes(), 3);
        assert_eq!(tracker.episodes()[0].length, 3);
        assert_eq!(tracker.episodes()[1].length, 5);
        assert_eq!(tracker.episodes()[2].length, 2);
    }

    /// Transitions without a terminating `done` do not count as an episode.
    #[test]
    fn test_incomplete_episode_not_counted() {
        let mut tracker = EpisodeTracker::new(100);
        for i in 0..5 {
            tracker.notify_push(i, false);
        }
        assert_eq!(tracker.num_complete_episodes(), 0);
    }

    /// Overwriting the head of an episode removes it.
    #[test]
    fn test_invalidate_removes_overwritten() {
        let mut tracker = EpisodeTracker::new(10);
        push_episode(&mut tracker, 0, 5);
        assert_eq!(tracker.num_complete_episodes(), 1);

        tracker.invalidate_overwritten(0, 3);
        assert_eq!(tracker.num_complete_episodes(), 0);
    }

    /// Every sampled window lies fully inside the episode it names.
    #[test]
    fn test_sample_windows_within_episode() {
        let mut tracker = EpisodeTracker::new(100);
        let next = push_episode(&mut tracker, 0, 10);
        push_episode(&mut tracker, next, 10);
        assert_eq!(tracker.num_complete_episodes(), 2);

        let windows = tracker.sample_windows(5, 5, 42).unwrap();
        assert_eq!(windows.len(), 5);
        for w in &windows {
            assert_eq!(w.length, 5);
            let ep = &tracker.episodes()[w.episode_idx];
            let ep_end = ep.start + ep.length;
            assert!(
                w.ring_start >= ep.start && w.ring_start + w.length <= ep_end,
                "window [{}, {}) not within episode [{}, {})",
                w.ring_start,
                w.ring_start + w.length,
                ep.start,
                ep_end
            );
        }
    }

    /// The same seed yields the same windows.
    #[test]
    fn test_sample_windows_deterministic() {
        let mut tracker = EpisodeTracker::new(100);
        push_episode(&mut tracker, 0, 10);
        let w1 = tracker.sample_windows(5, 3, 42).unwrap();
        let w2 = tracker.sample_windows(5, 3, 42).unwrap();
        for (a, b) in w1.iter().zip(w2.iter()) {
            assert_eq!(a.ring_start, b.ring_start);
            assert_eq!(a.episode_idx, b.episode_idx);
            assert_eq!(a.length, b.length);
        }
    }

    /// Asking for windows longer than any episode is an error.
    #[test]
    fn test_sample_windows_rejects_too_long_seq() {
        let mut tracker = EpisodeTracker::new(100);
        push_episode(&mut tracker, 0, 3);
        let result = tracker.sample_windows(1, 5, 42);
        assert!(result.is_err());
    }

    /// Only episodes meeting the minimum length are eligible.
    #[test]
    fn test_eligible_episodes_filters_short() {
        let mut tracker = EpisodeTracker::new(100);
        let mut pos = 0;
        for len in [2, 5, 3, 8] {
            pos = push_episode(&mut tracker, pos, len);
        }
        let eligible = tracker.eligible_episodes(4);
        assert_eq!(eligible, vec![1, 3]);
    }

    /// Overwriting a single interior slot still removes the whole episode.
    #[test]
    fn test_invalidate_partial_episode() {
        let mut tracker = EpisodeTracker::new(10);
        push_episode(&mut tracker, 0, 5);
        tracker.invalidate_overwritten(2, 1);
        assert_eq!(
            tracker.num_complete_episodes(),
            0,
            "partially overwritten episode should be removed"
        );
    }

    /// Two consecutive `done` pushes give two one-step episodes.
    #[test]
    fn test_consecutive_dones() {
        let mut tracker = EpisodeTracker::new(100);
        tracker.notify_push(0, true);
        tracker.notify_push(1, true);
        assert_eq!(tracker.num_complete_episodes(), 2);
        assert_eq!(tracker.episodes()[0].length, 1);
        assert_eq!(tracker.episodes()[1].length, 1);
    }

    /// Sampling from an empty tracker is an error.
    #[test]
    fn test_empty_tracker_sample_windows_errors() {
        let tracker = EpisodeTracker::new(100);
        let result = tracker.sample_windows(1, 1, 42);
        assert!(result.is_err());
    }

    /// A one-step episode can be sampled with a one-step window.
    #[test]
    fn test_single_transition_episode() {
        let mut tracker = EpisodeTracker::new(100);
        tracker.notify_push(0, true);
        assert_eq!(tracker.num_complete_episodes(), 1);
        let windows = tracker.sample_windows(1, 1, 42).unwrap();
        assert_eq!(windows[0].ring_start, 0);
        assert_eq!(windows[0].length, 1);
    }

    /// `EpisodeAware` can be used as a trait object.
    #[test]
    fn test_trait_object_safety() {
        let tracker: Box<dyn EpisodeAware> = Box::new(EpisodeTracker::new(100));
        assert_eq!(tracker.num_complete_episodes(), 0);
    }

    mod proptests {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            // The number of complete episodes equals the number of `done`
            // flags pushed when nothing is ever overwritten.
            #[test]
            fn prop_episode_count_matches_dones(
                n in 1usize..200,
                done_rate in 0.05f64..0.5,
            ) {
                // Capacity is twice the push count, so the ring never wraps.
                let mut tracker = EpisodeTracker::new(n * 2);
                let mut expected_complete = 0;
                for i in 0..n {
                    // `done` fires whenever the running total crosses an
                    // integer, i.e. roughly every `1 / done_rate` pushes.
                    let done = ((i as f64 + 1.0) * done_rate) as usize
                        > (i as f64 * done_rate) as usize;
                    tracker.notify_push(i, done);
                    if done {
                        expected_complete += 1;
                    }
                }
                prop_assert_eq!(
                    tracker.num_complete_episodes(),
                    expected_complete,
                    "expected {} complete episodes", expected_complete
                );
            }

            // Windows sampled from a single unwrapped episode stay inside
            // the ring.
            #[test]
            fn prop_window_within_bounds(
                ep_len in 5usize..50,
                seq_len in 1usize..5,
                batch_size in 1usize..10,
            ) {
                let cap = ep_len * 3;
                let mut tracker = EpisodeTracker::new(cap);
                push_episode(&mut tracker, 0, ep_len);
                let windows = tracker.sample_windows(batch_size, seq_len, 42).unwrap();
                for w in &windows {
                    prop_assert!(
                        w.ring_start + w.length <= cap,
                        "window [{}, {}) exceeds capacity {cap}",
                        w.ring_start,
                        w.ring_start + w.length
                    );
                }
            }

            // No sampled window straddles two episodes.
            #[test]
            fn prop_no_cross_episode_windows(
                n_episodes in 2usize..10,
                ep_len in 5usize..20,
                seq_len in 1usize..5,
            ) {
                let cap = n_episodes * ep_len * 2;
                let mut tracker = EpisodeTracker::new(cap);
                let mut pos = 0;
                for _ in 0..n_episodes {
                    pos = push_episode(&mut tracker, pos, ep_len);
                }
                let windows = tracker.sample_windows(n_episodes * 2, seq_len, 42).unwrap();
                for w in &windows {
                    let ep = &tracker.episodes()[w.episode_idx];
                    let ep_end = ep.start + ep.length;
                    prop_assert!(
                        w.ring_start >= ep.start && w.ring_start + w.length <= ep_end,
                        "window crosses episode boundary"
                    );
                }
            }

            // After the ring wraps, every surviving episode still owns all
            // of its slots: none of them was rewritten by a later push.
            #[test]
            fn prop_invalidation_never_returns_overwritten(
                cap in 10usize..100,
                n_pushes in 1usize..300,
            ) {
                let mut tracker = EpisodeTracker::new(cap);
                // Reference model of the ring: the index of the push that
                // most recently wrote each slot.
                let mut last_writer: Vec<Option<usize>> = vec![None; cap];

                for i in 0..n_pushes {
                    let slot = i % cap;
                    let done = i % 7 == 6;
                    // From the second lap on, each push reuses a slot and
                    // must announce the overwrite first.
                    if i >= cap {
                        tracker.invalidate_overwritten(slot, 1);
                    }
                    tracker.notify_push(slot, done);
                    last_writer[slot] = Some(i);
                }

                for ep in tracker.episodes() {
                    prop_assert!(
                        ep.start < cap,
                        "episode start {} >= capacity {cap}",
                        ep.start
                    );

                    let first_push = last_writer[ep.start];
                    prop_assert!(
                        first_push.is_some(),
                        "episode starts at slot {} which was never written",
                        ep.start
                    );

                    // An intact episode was written by consecutive pushes, so
                    // the k-th slot must still hold push `first_push + k`.
                    for k in 0..ep.length {
                        let slot = (ep.start + k) % cap;
                        prop_assert_eq!(
                            last_writer[slot],
                            first_push.map(|p| p + k),
                            "slot {} of episode starting at {} was overwritten",
                            slot,
                            ep.start
                        );
                    }
                }
            }
        }
    }
}