// rlox_core/buffer/sequence.rs

//! Sequence replay buffer for recurrent/transformer-based RL algorithms.
//!
//! Wraps a [`ReplayBuffer`] and an [`EpisodeTracker`] to sample contiguous
//! transition sequences that never cross episode boundaries.
//! Used by DreamerV3, R2D2, and other sequence-based algorithms.
6
7use crate::error::RloxError;
8
9use super::episode::EpisodeTracker;
10use super::ringbuf::ReplayBuffer;
11
12/// Replay buffer that samples contiguous sequences of transitions.
13///
14/// Wraps a `ReplayBuffer` for storage and an `EpisodeTracker` for
15/// episode-aware sequence sampling. Sequences never cross episode
16/// boundaries.
17#[derive(Debug)]
18pub struct SequenceReplayBuffer {
19    buffer: ReplayBuffer,
20    tracker: EpisodeTracker,
21    obs_dim: usize,
22    act_dim: usize,
23    capacity: usize,
24}
25
/// A batch of sampled sequences.
///
/// Every per-step buffer is flat and sequence-major: step `t` of sequence `b`
/// starts at index `(b * seq_len + t) * width`, where `width` is `obs_dim` for
/// (next) observations, `act_dim` for actions and `1` for the scalar fields.
#[derive(Debug, Clone, PartialEq)]
pub struct SequenceBatch {
    /// Observations: flat `(batch_size * seq_len * obs_dim)`.
    pub observations: Vec<f32>,
    /// Next observations: flat `(batch_size * seq_len * obs_dim)`.
    pub next_observations: Vec<f32>,
    /// Actions: flat `(batch_size * seq_len * act_dim)`.
    pub actions: Vec<f32>,
    /// Rewards: flat `(batch_size * seq_len)`.
    pub rewards: Vec<f32>,
    /// Terminated flags: flat `(batch_size * seq_len)`.
    pub terminated: Vec<bool>,
    /// Truncated flags: flat `(batch_size * seq_len)`.
    pub truncated: Vec<bool>,
    /// Length of one observation vector.
    pub obs_dim: usize,
    /// Length of one action vector.
    pub act_dim: usize,
    /// Number of sequences in the batch.
    pub batch_size: usize,
    /// Number of consecutive transitions in each sequence.
    pub seq_len: usize,
}
46
47impl SequenceReplayBuffer {
48    /// Create a new sequence replay buffer with given capacity and dimensions.
49    pub fn new(capacity: usize, obs_dim: usize, act_dim: usize) -> Self {
50        Self {
51            buffer: ReplayBuffer::new(capacity, obs_dim, act_dim),
52            tracker: EpisodeTracker::new(capacity),
53            obs_dim,
54            act_dim,
55            capacity,
56        }
57    }
58
59    /// Push a single transition, notifying the episode tracker.
60    pub fn push_slices(
61        &mut self,
62        obs: &[f32],
63        next_obs: &[f32],
64        action: &[f32],
65        reward: f32,
66        terminated: bool,
67        truncated: bool,
68    ) -> Result<(), RloxError> {
69        let write_pos = self.buffer.write_pos();
70        let was_full = self.buffer.len() == self.capacity;
71
72        // If wrapping, invalidate episodes at the position about to be overwritten
73        if was_full {
74            self.tracker.invalidate_overwritten(write_pos, 1);
75        }
76
77        self.buffer
78            .push_slices(obs, next_obs, action, reward, terminated, truncated)?;
79
80        let done = terminated || truncated;
81        self.tracker.notify_push(write_pos, done);
82
83        Ok(())
84    }
85
86    /// Sample `batch_size` sequences of `seq_len` consecutive transitions.
87    ///
88    /// Each sequence is guaranteed to be within a single episode.
89    pub fn sample_sequences(
90        &self,
91        batch_size: usize,
92        seq_len: usize,
93        seed: u64,
94    ) -> Result<SequenceBatch, RloxError> {
95        let windows = self.tracker.sample_windows(batch_size, seq_len, seed)?;
96
97        let total_obs = batch_size * seq_len * self.obs_dim;
98        let total_act = batch_size * seq_len * self.act_dim;
99        let total_flat = batch_size * seq_len;
100
101        let mut batch = SequenceBatch {
102            observations: Vec::with_capacity(total_obs),
103            next_observations: Vec::with_capacity(total_obs),
104            actions: Vec::with_capacity(total_act),
105            rewards: Vec::with_capacity(total_flat),
106            terminated: Vec::with_capacity(total_flat),
107            truncated: Vec::with_capacity(total_flat),
108            obs_dim: self.obs_dim,
109            act_dim: self.act_dim,
110            batch_size,
111            seq_len,
112        };
113
114        for window in &windows {
115            for offset in 0..seq_len {
116                let idx = (window.ring_start + offset) % self.capacity;
117                let (obs, next_obs, action, reward, terminated, truncated) = self.buffer.get(idx);
118                batch.observations.extend_from_slice(obs);
119                batch.next_observations.extend_from_slice(next_obs);
120                batch.actions.extend_from_slice(action);
121                batch.rewards.push(reward);
122                batch.terminated.push(terminated);
123                batch.truncated.push(truncated);
124            }
125        }
126
127        Ok(batch)
128    }
129
130    /// Delegate to inner buffer for standard i.i.d. sampling.
131    pub fn sample(
132        &self,
133        batch_size: usize,
134        seed: u64,
135    ) -> Result<super::ringbuf::SampledBatch, RloxError> {
136        self.buffer.sample(batch_size, seed)
137    }
138
139    /// Number of valid transitions currently stored.
140    pub fn len(&self) -> usize {
141        self.buffer.len()
142    }
143
144    /// Whether the buffer is empty.
145    pub fn is_empty(&self) -> bool {
146        self.buffer.is_empty()
147    }
148
149    /// Number of complete episodes currently tracked.
150    pub fn num_complete_episodes(&self) -> usize {
151        self.tracker.num_complete_episodes()
152    }
153}
154
#[cfg(test)]
mod tests {
    use super::*;

    /// Helper: push an episode of `length` steps with identifiable data.
    ///
    /// Step `i` gets `obs = obs_base + i`, `next_obs = obs_base + i + 1` and
    /// `reward = obs_base + i`; only the last step is marked terminated.
    fn push_episode(buf: &mut SequenceReplayBuffer, length: usize, obs_base: f32) {
        let obs_dim = buf.obs_dim;
        let act_dim = buf.act_dim;
        for i in 0..length {
            let val = obs_base + i as f32;
            let obs = vec![val; obs_dim];
            let next_obs = vec![val + 1.0; obs_dim];
            let action = vec![0.0; act_dim];
            let reward = val;
            let done = i == length - 1;
            buf.push_slices(&obs, &next_obs, &action, reward, done, false)
                .unwrap();
        }
    }

    #[test]
    fn test_new_is_empty() {
        let buf = SequenceReplayBuffer::new(100, 4, 1);
        assert_eq!(buf.len(), 0);
        assert!(buf.is_empty());
        assert_eq!(buf.num_complete_episodes(), 0);
    }

    #[test]
    fn test_push_increments_len() {
        let mut buf = SequenceReplayBuffer::new(100, 4, 1);
        push_episode(&mut buf, 5, 0.0);
        assert_eq!(buf.len(), 5);
        assert_eq!(buf.num_complete_episodes(), 1);
    }

    #[test]
    fn test_single_episode_sequence_sample() {
        let mut buf = SequenceReplayBuffer::new(100, 4, 1);
        push_episode(&mut buf, 10, 0.0);
        let batch = buf.sample_sequences(1, 3, 42).unwrap();
        assert_eq!(batch.batch_size, 1);
        assert_eq!(batch.seq_len, 3);
        assert_eq!(batch.observations.len(), 3 * 4);
        assert_eq!(batch.rewards.len(), 3);
    }

    #[test]
    fn test_sequences_dont_cross_episodes() {
        let mut buf = SequenceReplayBuffer::new(100, 4, 1);
        // Two episodes of length 5 each
        push_episode(&mut buf, 5, 0.0);
        push_episode(&mut buf, 5, 100.0);
        assert_eq!(buf.num_complete_episodes(), 2);

        // Sample many sequences of length 4
        let batch = buf.sample_sequences(20, 4, 42).unwrap();

        // Each sequence should have rewards either all in [0,5) or all in [100,105)
        for seq_idx in 0..20 {
            let rewards: Vec<f32> = (0..4).map(|t| batch.rewards[seq_idx * 4 + t]).collect();
            let all_low = rewards.iter().all(|&r| r < 50.0);
            let all_high = rewards.iter().all(|&r| r >= 50.0);
            assert!(
                all_low || all_high,
                "sequence {seq_idx} crosses episode boundary: {rewards:?}"
            );
        }
    }

    #[test]
    fn test_sequence_contiguity() {
        let mut buf = SequenceReplayBuffer::new(100, 4, 1);
        push_episode(&mut buf, 10, 0.0);

        let batch = buf.sample_sequences(5, 5, 42).unwrap();
        let obs_dim = 4;
        let seq_len = 5;

        for seq_idx in 0..5 {
            for t in 0..(seq_len - 1) {
                let next_obs_start = (seq_idx * seq_len + t) * obs_dim;
                let obs_next_start = (seq_idx * seq_len + t + 1) * obs_dim;

                let next_obs = &batch.next_observations[next_obs_start..next_obs_start + obs_dim];
                let obs_t1 = &batch.observations[obs_next_start..obs_next_start + obs_dim];

                assert_eq!(
                    next_obs,
                    obs_t1,
                    "next_obs[{t}] != obs[{t_plus_1}] in seq {seq_idx}",
                    t_plus_1 = t + 1
                );
            }
        }
    }

    #[test]
    fn test_sequence_deterministic() {
        let mut buf = SequenceReplayBuffer::new(100, 4, 1);
        push_episode(&mut buf, 10, 0.0);
        let b1 = buf.sample_sequences(5, 3, 42).unwrap();
        let b2 = buf.sample_sequences(5, 3, 42).unwrap();
        assert_eq!(b1.observations, b2.observations);
        assert_eq!(b1.rewards, b2.rewards);
    }

    #[test]
    fn test_reject_too_long_sequence() {
        let mut buf = SequenceReplayBuffer::new(100, 4, 1);
        push_episode(&mut buf, 3, 0.0);
        let result = buf.sample_sequences(1, 5, 42);
        assert!(result.is_err());
    }

    #[test]
    fn test_capacity_respected() {
        let mut buf = SequenceReplayBuffer::new(100, 4, 1);
        // Push 200 transitions (20 episodes of length 10)
        for i in 0..20 {
            push_episode(&mut buf, 10, i as f32 * 100.0);
        }
        assert_eq!(buf.len(), 100);
    }

    #[test]
    fn test_batch_shape_correct() {
        let mut buf = SequenceReplayBuffer::new(200, 8, 2);
        push_episode(&mut buf, 20, 0.0);

        let batch = buf.sample_sequences(4, 3, 42).unwrap();
        assert_eq!(batch.observations.len(), 4 * 3 * 8);
        assert_eq!(batch.next_observations.len(), 4 * 3 * 8);
        assert_eq!(batch.actions.len(), 4 * 3 * 2);
        assert_eq!(batch.rewards.len(), 4 * 3);
        assert_eq!(batch.terminated.len(), 4 * 3);
        assert_eq!(batch.truncated.len(), 4 * 3);
    }

    #[test]
    fn test_empty_buffer_sample_errors() {
        let buf = SequenceReplayBuffer::new(100, 4, 1);
        let result = buf.sample_sequences(1, 1, 42);
        assert!(result.is_err());
    }

    #[test]
    fn test_push_slices_validates_dims() {
        let mut buf = SequenceReplayBuffer::new(100, 4, 1);
        let result = buf.push_slices(
            &[1.0, 2.0, 3.0], // wrong obs_dim: 3 instead of 4
            &[1.0, 2.0, 3.0, 4.0],
            &[0.0],
            1.0,
            false,
            false,
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_multiple_episodes_mixed_lengths() {
        let mut buf = SequenceReplayBuffer::new(200, 4, 1);
        push_episode(&mut buf, 3, 0.0);
        push_episode(&mut buf, 7, 100.0);
        push_episode(&mut buf, 2, 200.0);
        push_episode(&mut buf, 10, 300.0);

        // seq_len=4: only episodes of length >= 4 are eligible (7 and 10).
        // Rewards are `obs_base + step`, so every reward of a valid sequence
        // must lie in [100, 107) or [300, 310). A lower bound on the first
        // reward alone would not catch sampling from the length-2 episode,
        // whose rewards start at 200.
        let seq_len = 4;
        let batch = buf.sample_sequences(10, seq_len, 42).unwrap();
        for seq_idx in 0..10 {
            let rewards = &batch.rewards[seq_idx * seq_len..(seq_idx + 1) * seq_len];
            let within = |lo: f32, hi: f32| rewards.iter().all(|&r| (lo..hi).contains(&r));
            assert!(
                within(100.0, 107.0) || within(300.0, 310.0),
                "seq {seq_idx} not drawn from an eligible episode: {rewards:?}"
            );
        }
    }

    #[test]
    fn test_sequence_rewards_match_buffer() {
        let mut buf = SequenceReplayBuffer::new(100, 2, 1);
        // Push a known episode
        for i in 0..5 {
            let val = (i + 1) as f32 * 10.0;
            buf.push_slices(
                &[val, val],
                &[val + 1.0, val + 1.0],
                &[0.0],
                val,
                i == 4,
                false,
            )
            .unwrap();
        }

        // seq_len equals the episode length, so the only valid window is the
        // whole episode and the rewards must match it exactly.
        let batch = buf.sample_sequences(1, 5, 42).unwrap();
        let expected = vec![10.0, 20.0, 30.0, 40.0, 50.0];
        assert_eq!(batch.rewards, expected);
    }

    #[test]
    fn test_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<SequenceReplayBuffer>();
    }

    mod proptests {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            #[test]
            fn prop_batch_size_matches_request(
                batch_size in 1usize..10,
                seq_len in 1usize..5,
                ep_len in 5usize..20,
            ) {
                let cap = ep_len * 5;
                let mut buf = SequenceReplayBuffer::new(cap, 4, 1);
                push_episode(&mut buf, ep_len, 0.0);
                push_episode(&mut buf, ep_len, 100.0);
                let batch = buf.sample_sequences(batch_size, seq_len, 42).unwrap();
                prop_assert_eq!(batch.batch_size, batch_size);
                prop_assert_eq!(batch.seq_len, seq_len);
            }

            #[test]
            fn prop_len_never_exceeds_capacity(
                cap in 10usize..100,
                n_pushes in 1usize..300,
            ) {
                let mut buf = SequenceReplayBuffer::new(cap, 4, 1);
                for i in 0..n_pushes {
                    let done = i % 7 == 6;
                    buf.push_slices(
                        &[i as f32; 4],
                        &[(i + 1) as f32; 4],
                        &[0.0],
                        i as f32,
                        done,
                        false,
                    ).unwrap();
                }
                prop_assert!(buf.len() <= cap);
            }

            #[test]
            fn prop_sequence_obs_contiguous(
                ep_len in 5usize..20,
                seq_len in 2usize..5,
                batch_size in 1usize..5,
            ) {
                let mut buf = SequenceReplayBuffer::new(ep_len * 3, 4, 1);
                push_episode(&mut buf, ep_len, 0.0);
                let batch = buf.sample_sequences(batch_size, seq_len, 42).unwrap();
                let obs_dim = 4;
                for seq_idx in 0..batch_size {
                    for t in 0..(seq_len - 1) {
                        let next_start = (seq_idx * seq_len + t) * obs_dim;
                        let obs_next_start = (seq_idx * seq_len + t + 1) * obs_dim;
                        let next_obs = &batch.next_observations[next_start..next_start + obs_dim];
                        let obs_t1 = &batch.observations[obs_next_start..obs_next_start + obs_dim];
                        prop_assert_eq!(next_obs, obs_t1);
                    }
                }
            }
        }
    }
}