1use parking_lot::Mutex;
2
3use super::ringbuf::{ReplayBuffer, SampledBatch};
4use crate::error::RloxError;
5
6pub struct ConcurrentReplayBuffer {
18 inner: Mutex<ReplayBuffer>,
19 capacity: usize,
20}
21
22impl ConcurrentReplayBuffer {
23 pub fn new(capacity: usize, obs_dim: usize, act_dim: usize) -> Self {
28 Self {
29 inner: Mutex::new(ReplayBuffer::new(capacity, obs_dim, act_dim)),
30 capacity,
31 }
32 }
33
34 pub fn push(
38 &self,
39 obs: &[f32],
40 next_obs: &[f32],
41 action: &[f32],
42 reward: f32,
43 terminated: bool,
44 truncated: bool,
45 ) -> Result<(), RloxError> {
46 let mut buf = self.inner.lock();
47 buf.push_slices(obs, next_obs, action, reward, terminated, truncated)
48 }
49
50 pub fn sample(&self, batch_size: usize, seed: u64) -> Result<SampledBatch, RloxError> {
55 let buf = self.inner.lock();
56 buf.sample(batch_size, seed)
57 }
58
59 pub fn len(&self) -> usize {
63 self.inner.lock().len()
64 }
65
66 pub fn is_empty(&self) -> bool {
68 self.len() == 0
69 }
70
71 pub fn capacity(&self) -> usize {
73 self.capacity
74 }
75}
76
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Waits for every worker and fails the test if any of them panicked.
    fn join_all(workers: Vec<std::thread::JoinHandle<()>>) {
        for worker in workers {
            worker.join().expect("thread should not panic");
        }
    }

    // Pushing and sampling from a single thread yields correctly sized batches.
    #[test]
    fn test_concurrent_push_single_thread() {
        let buf = ConcurrentReplayBuffer::new(100, 4, 2);

        for step in 0..50 {
            let reward = step as f32;
            let obs = [reward; 4];
            let next_obs = [(step + 1) as f32; 4];
            buf.push(&obs, &next_obs, &[0.1, 0.2], reward, false, false)
                .expect("push should succeed");
        }

        assert_eq!(buf.len(), 50);

        let batch = buf.sample(10, 42).expect("sample should succeed");
        assert_eq!(batch.batch_size, 10);
        assert_eq!(batch.observations.len(), 10 * 4);
        assert_eq!(batch.actions.len(), 10 * 2);
        assert_eq!(batch.rewards.len(), 10);
    }

    // Several writers push at once; nothing is lost and no observation row is
    // torn (each pushed row holds four identical values).
    #[test]
    fn test_concurrent_push_multi_thread() {
        let buf = Arc::new(ConcurrentReplayBuffer::new(1000, 4, 1));
        let n_threads = 4;
        let pushes_per_thread = 200;

        let mut workers = Vec::new();
        for t in 0..n_threads {
            let shared = Arc::clone(&buf);
            workers.push(std::thread::spawn(move || {
                for i in 0..pushes_per_thread {
                    let val = (t * pushes_per_thread + i) as f32;
                    shared
                        .push(
                            &[val; 4],
                            &[val + 1.0; 4],
                            &[val * 0.01],
                            val,
                            false,
                            false,
                        )
                        .expect("push should succeed");
                }
            }));
        }
        join_all(workers);

        assert_eq!(buf.len(), n_threads * pushes_per_thread);

        let batch = buf.sample(100, 42).expect("sample should succeed");
        assert_eq!(batch.batch_size, 100);
        assert_eq!(batch.observations.len(), 100 * 4);

        for row in 0..batch.batch_size {
            let start = row * 4;
            let obs_slice = &batch.observations[start..start + 4];
            let first = obs_slice[0];
            // A row mixing values from two different pushes would show up here.
            assert!(
                obs_slice[1..]
                    .iter()
                    .all(|&val| (val - first).abs() < f32::EPSILON),
                "data corruption detected: obs={obs_slice:?}"
            );
        }
    }

    // Sampling on the main thread while another thread keeps pushing.
    #[test]
    fn test_concurrent_sample_during_push() {
        let buf = Arc::new(ConcurrentReplayBuffer::new(1000, 4, 1));

        // Pre-fill so that sampling is possible from the very first attempt.
        for i in 0..100 {
            let obs = [i as f32; 4];
            let next_obs = [(i + 1) as f32; 4];
            buf.push(&obs, &next_obs, &[0.0], 1.0, false, false)
                .expect("push should succeed");
        }

        let writer = {
            let shared = Arc::clone(&buf);
            std::thread::spawn(move || {
                for i in 100..500 {
                    let obs = [i as f32; 4];
                    let next_obs = [(i + 1) as f32; 4];
                    shared
                        .push(&obs, &next_obs, &[0.0], 1.0, false, false)
                        .expect("push should succeed");
                }
            })
        };

        let mut sample_count = 0;
        for seed in 0..50u64 {
            if buf.len() < 10 {
                continue;
            }
            let batch = buf.sample(10, seed).expect("sample should succeed");
            assert_eq!(batch.batch_size, 10);
            sample_count += 1;
        }

        writer.join().expect("push thread should not panic");
        assert!(sample_count > 0, "should have sampled at least once");
    }

    // Far more pushes than capacity: the length saturates at the capacity.
    #[test]
    fn test_concurrent_wrap_around() {
        let capacity = 50;
        let buf = Arc::new(ConcurrentReplayBuffer::new(capacity, 2, 1));
        let n_threads = 4;
        let pushes_per_thread = 100;

        let mut workers = Vec::new();
        for t in 0..n_threads {
            let shared = Arc::clone(&buf);
            workers.push(std::thread::spawn(move || {
                for i in 0..pushes_per_thread {
                    let val = (t * pushes_per_thread + i) as f32;
                    let obs = [val, val];
                    let next_obs = [val + 1.0, val + 1.0];
                    shared
                        .push(&obs, &next_obs, &[val], val, false, false)
                        .expect("push should succeed");
                }
            }));
        }
        join_all(workers);

        assert_eq!(buf.len(), capacity);
        assert_eq!(buf.capacity(), capacity);

        let batch = buf.sample(20, 42).expect("sample should succeed");
        assert_eq!(batch.batch_size, 20);
    }

    // Compile-time check that the wrapper can be shared across threads.
    #[test]
    fn test_concurrent_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<ConcurrentReplayBuffer>();
    }

    // Identical seed on unchanged contents gives identical batches.
    #[test]
    fn test_concurrent_deterministic_sample() {
        let buf = ConcurrentReplayBuffer::new(500, 4, 1);
        for i in 0..200 {
            let reward = i as f32;
            let obs = [reward; 4];
            let next_obs = [(i + 1) as f32; 4];
            buf.push(&obs, &next_obs, &[0.0], reward, false, false)
                .expect("push should succeed");
        }

        let b1 = buf.sample(32, 42).expect("sample should succeed");
        let b2 = buf.sample(32, 42).expect("sample should succeed");
        assert_eq!(b1.observations, b2.observations);
        assert_eq!(b1.rewards, b2.rewards);
        assert_eq!(b1.actions, b2.actions);
    }

    // Sampling from an empty buffer is an error, not a panic.
    #[test]
    fn test_concurrent_empty_sample_errors() {
        let buf = ConcurrentReplayBuffer::new(100, 4, 1);
        assert!(buf.sample(1, 42).is_err());
        assert!(buf.is_empty());
    }

    // Each of the three slices is validated against the configured dimensions.
    #[test]
    fn test_concurrent_shape_mismatch_errors() {
        let buf = ConcurrentReplayBuffer::new(100, 4, 2);

        let short_obs = buf.push(&[1.0, 2.0], &[1.0; 4], &[0.0, 0.0], 1.0, false, false);
        assert!(short_obs.is_err());

        let short_next_obs = buf.push(&[1.0; 4], &[1.0, 2.0], &[0.0, 0.0], 1.0, false, false);
        assert!(short_next_obs.is_err());

        let short_action = buf.push(&[1.0; 4], &[1.0; 4], &[0.0], 1.0, false, false);
        assert!(short_action.is_err());
    }

    mod proptests {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            // For any capacity and push count, the length is min(pushes, capacity).
            #[test]
            fn prop_concurrent_len_never_exceeds_capacity(
                capacity in 1..500usize,
                num_pushes in 0..2000usize
            ) {
                let buf = ConcurrentReplayBuffer::new(capacity, 2, 1);
                for i in 0..num_pushes {
                    let v = i as f32;
                    let pair = [v, v];
                    buf.push(&pair, &pair, &[v], v, false, false).expect("push should succeed");
                }
                let stored = buf.len();
                prop_assert!(stored <= capacity);
                prop_assert_eq!(stored, num_pushes.min(capacity));
            }
        }
    }
}