rlox_core/buffer/
concurrent.rs

1use parking_lot::Mutex;
2
3use super::ringbuf::{ReplayBuffer, SampledBatch};
4use crate::error::RloxError;
5
6/// Thread-safe concurrent replay buffer backed by `parking_lot::Mutex`.
7///
8/// Multiple actor threads can push transitions concurrently.
9/// A single learner thread samples batches.
10///
11/// `parking_lot::Mutex` is ~2x faster than `std::Mutex` on uncontended locks
12/// (~10ns), which is negligible compared to the data copy cost of each push.
13///
14/// # Thread Safety
15///
16/// Automatically `Send + Sync` because `Mutex<T: Send>` is `Send + Sync`.
17pub struct ConcurrentReplayBuffer {
18    inner: Mutex<ReplayBuffer>,
19    capacity: usize,
20}
21
22impl ConcurrentReplayBuffer {
23    /// Create a concurrent replay buffer with fixed capacity.
24    ///
25    /// All arrays are pre-allocated inside the inner `ReplayBuffer`.
26    /// `obs_dim` and `act_dim` define the per-transition dimensionality.
27    pub fn new(capacity: usize, obs_dim: usize, act_dim: usize) -> Self {
28        Self {
29            inner: Mutex::new(ReplayBuffer::new(capacity, obs_dim, act_dim)),
30            capacity,
31        }
32    }
33
34    /// Push a transition into the buffer from borrowed slices.
35    ///
36    /// Thread-safe: multiple threads may call `push` concurrently.
37    pub fn push(
38        &self,
39        obs: &[f32],
40        next_obs: &[f32],
41        action: &[f32],
42        reward: f32,
43        terminated: bool,
44        truncated: bool,
45    ) -> Result<(), RloxError> {
46        let mut buf = self.inner.lock();
47        buf.push_slices(obs, next_obs, action, reward, terminated, truncated)
48    }
49
50    /// Sample a batch of transitions uniformly at random.
51    ///
52    /// Uses `ChaCha8Rng` seeded with `seed` for deterministic cross-platform
53    /// reproducibility.
54    pub fn sample(&self, batch_size: usize, seed: u64) -> Result<SampledBatch, RloxError> {
55        let buf = self.inner.lock();
56        buf.sample(batch_size, seed)
57    }
58
59    /// Number of transitions currently stored.
60    ///
61    /// This is always `<= capacity`.
62    pub fn len(&self) -> usize {
63        self.inner.lock().len()
64    }
65
66    /// Whether the buffer has no transitions.
67    pub fn is_empty(&self) -> bool {
68        self.len() == 0
69    }
70
71    /// Maximum number of transitions the buffer can hold.
72    pub fn capacity(&self) -> usize {
73        self.capacity
74    }
75}
76
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    /// Pushes one non-terminated, non-truncated transition whose `obs` and
    /// `next_obs` are `dim`-long vectors, each filled with a single value.
    fn push_filled(
        buf: &ConcurrentReplayBuffer,
        dim: usize,
        obs_fill: f32,
        next_fill: f32,
        action: &[f32],
        reward: f32,
    ) {
        let obs = vec![obs_fill; dim];
        let next_obs = vec![next_fill; dim];
        buf.push(&obs, &next_obs, action, reward, false, false)
            .expect("push should succeed");
    }

    /// Fifty pushes from one thread, then a sample with the expected shapes.
    #[test]
    fn test_concurrent_push_single_thread() {
        const OBS_DIM: usize = 4;
        const ACT_DIM: usize = 2;
        const BATCH: usize = 10;

        let buf = ConcurrentReplayBuffer::new(100, OBS_DIM, ACT_DIM);
        for step in 0..50 {
            push_filled(
                &buf,
                OBS_DIM,
                step as f32,
                (step + 1) as f32,
                &[0.1, 0.2],
                step as f32,
            );
        }
        assert_eq!(buf.len(), 50);

        let batch = buf.sample(BATCH, 42).expect("sample should succeed");
        assert_eq!(batch.batch_size, BATCH);
        assert_eq!(batch.observations.len(), BATCH * OBS_DIM);
        assert_eq!(batch.actions.len(), BATCH * ACT_DIM);
        assert_eq!(batch.rewards.len(), BATCH);
    }

    /// Four writers, no wrap-around: every push must land and no stored
    /// observation row may mix values from different pushes.
    #[test]
    fn test_concurrent_push_multi_thread() {
        const N_THREADS: usize = 4;
        const PUSHES_PER_THREAD: usize = 200;

        let buf = Arc::new(ConcurrentReplayBuffer::new(1000, 4, 1));

        let mut workers = Vec::with_capacity(N_THREADS);
        for t in 0..N_THREADS {
            let worker_buf = Arc::clone(&buf);
            workers.push(thread::spawn(move || {
                // Each worker owns a disjoint range of values.
                let start = t * PUSHES_PER_THREAD;
                for idx in start..start + PUSHES_PER_THREAD {
                    let val = idx as f32;
                    push_filled(&worker_buf, 4, val, val + 1.0, &[val * 0.01], val);
                }
            }));
        }
        for worker in workers {
            worker.join().expect("thread should not panic");
        }

        assert_eq!(buf.len(), N_THREADS * PUSHES_PER_THREAD);

        // Shape check first; the row-wise scan below relies on it.
        let batch = buf.sample(100, 42).expect("sample should succeed");
        assert_eq!(batch.batch_size, 100);
        assert_eq!(batch.observations.len(), 100 * 4);

        // Every pushed obs was [v, v, v, v]; a row with differing entries
        // would mean two pushes interleaved.
        for obs_slice in batch.observations.chunks_exact(4).take(batch.batch_size) {
            let first = obs_slice[0];
            let uniform = obs_slice[1..]
                .iter()
                .all(|&val| (val - first).abs() < f32::EPSILON);
            assert!(uniform, "data corruption detected: obs={obs_slice:?}");
        }
    }

    /// Sampling on the main thread while another thread keeps pushing.
    #[test]
    fn test_concurrent_sample_during_push() {
        let buf = Arc::new(ConcurrentReplayBuffer::new(1000, 4, 1));

        // Seed the buffer so the sampler has data from the first iteration.
        for i in 0..100 {
            push_filled(&buf, 4, i as f32, (i + 1) as f32, &[0.0], 1.0);
        }

        let writer_buf = Arc::clone(&buf);
        let push_handle = thread::spawn(move || {
            for i in 100..500 {
                push_filled(&writer_buf, 4, i as f32, (i + 1) as f32, &[0.0], 1.0);
            }
        });

        // Runs alongside the writer thread spawned above.
        let mut sample_count = 0;
        for seed in 0..50u64 {
            if buf.len() < 10 {
                continue;
            }
            let batch = buf.sample(10, seed).expect("sample should succeed");
            assert_eq!(batch.batch_size, 10);
            sample_count += 1;
        }

        push_handle.join().expect("push thread should not panic");
        assert!(sample_count > 0, "should have sampled at least once");
    }

    /// 400 concurrent pushes into 50 slots: length saturates at capacity and
    /// sampling still works.
    #[test]
    fn test_concurrent_wrap_around() {
        const CAPACITY: usize = 50;
        const N_THREADS: usize = 4;
        const PUSHES_PER_THREAD: usize = 100; // 400 total > capacity 50

        let buf = Arc::new(ConcurrentReplayBuffer::new(CAPACITY, 2, 1));

        let mut workers = Vec::with_capacity(N_THREADS);
        for t in 0..N_THREADS {
            let worker_buf = Arc::clone(&buf);
            workers.push(thread::spawn(move || {
                let start = t * PUSHES_PER_THREAD;
                for idx in start..start + PUSHES_PER_THREAD {
                    let val = idx as f32;
                    let obs = [val, val];
                    let next_obs = [val + 1.0, val + 1.0];
                    worker_buf
                        .push(&obs, &next_obs, &[val], val, false, false)
                        .expect("push should succeed");
                }
            }));
        }
        for worker in workers {
            worker.join().expect("thread should not panic");
        }

        // The stored count is pinned at capacity after overflow.
        assert_eq!(buf.len(), CAPACITY);
        assert_eq!(buf.capacity(), CAPACITY);

        // A wrapped buffer must remain sampleable.
        let batch = buf.sample(20, 42).expect("sample should succeed");
        assert_eq!(batch.batch_size, 20);
    }

    /// Compile-time check that the wrapper can be shared across threads.
    #[test]
    fn test_concurrent_send_sync() {
        fn require_send_sync<T>()
        where
            T: Send + Sync,
        {
        }
        require_send_sync::<ConcurrentReplayBuffer>();
    }

    /// Same seed, same contents: the two batches must match.
    #[test]
    fn test_concurrent_deterministic_sample() {
        let buf = ConcurrentReplayBuffer::new(500, 4, 1);
        for i in 0..200 {
            push_filled(&buf, 4, i as f32, (i + 1) as f32, &[0.0], i as f32);
        }

        let first = buf.sample(32, 42).expect("sample should succeed");
        let second = buf.sample(32, 42).expect("sample should succeed");
        assert_eq!(first.observations, second.observations);
        assert_eq!(first.rewards, second.rewards);
        assert_eq!(first.actions, second.actions);
    }

    /// Sampling before any push is an error, and the buffer reports empty.
    #[test]
    fn test_concurrent_empty_sample_errors() {
        let buf = ConcurrentReplayBuffer::new(100, 4, 1);
        let outcome = buf.sample(1, 42);
        assert!(outcome.is_err());
        assert!(buf.is_empty());
    }

    /// Each of the three slice arguments is length-checked on push.
    #[test]
    fn test_concurrent_shape_mismatch_errors() {
        let buf = ConcurrentReplayBuffer::new(100, 4, 2);

        let ok_obs = [1.0_f32; 4];
        let short_obs = [1.0_f32, 2.0];
        let ok_action = [0.0_f32, 0.0];
        let short_action = [0.0_f32];

        // Wrong obs dim
        let wrong_obs = buf.push(&short_obs, &ok_obs, &ok_action, 1.0, false, false);
        assert!(wrong_obs.is_err());

        // Wrong next_obs dim
        let wrong_next = buf.push(&ok_obs, &short_obs, &ok_action, 1.0, false, false);
        assert!(wrong_next.is_err());

        // Wrong action dim
        let wrong_action = buf.push(&ok_obs, &ok_obs, &short_action, 1.0, false, false);
        assert!(wrong_action.is_err());
    }

    mod proptests {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            /// For any capacity and push count, the stored length is
            /// `min(pushes, capacity)`.
            #[test]
            fn prop_concurrent_len_never_exceeds_capacity(
                capacity in 1..500usize,
                num_pushes in 0..2000usize
            ) {
                let buf = ConcurrentReplayBuffer::new(capacity, 2, 1);
                for i in 0..num_pushes {
                    let v = i as f32;
                    let pair = [v, v];
                    buf.push(&pair, &pair, &[v], v, false, false)
                        .expect("push should succeed");
                }
                prop_assert!(buf.len() <= capacity);
                prop_assert_eq!(buf.len(), std::cmp::min(num_pushes, capacity));
            }
        }
    }
}