rlox_core/pipeline/
channel.rs1use crossbeam_channel::{bounded, Receiver, Sender, TrySendError};
2
3use crate::error::RloxError;
4
/// One batch of rollout data; this is the message type carried by [`Pipeline`].
///
/// Every buffer is a flat `Vec`. The struct performs no validation, so the
/// relationship between the buffer lengths and `obs_dim` / `act_dim` /
/// `n_steps` / `n_envs` is whatever the producer chose. The tests in this file
/// use `n_steps * n_envs` entries for the per-step buffers,
/// `n_steps * n_envs * obs_dim` for `observations` and
/// `n_steps * n_envs * act_dim` for `actions`.
///
/// `PartialEq` is derived so whole batches can be compared directly (for
/// example with `assert_eq!`). Equality is element-wise on floats, so a batch
/// that contains `NaN` does not compare equal to its own clone.
#[derive(Debug, Clone, PartialEq)]
pub struct RolloutBatch {
    /// Flattened observation values (single precision).
    pub observations: Vec<f32>,
    /// Flattened action values (single precision).
    pub actions: Vec<f32>,
    /// Reward values, one per stored step in the tests of this file.
    pub rewards: Vec<f64>,
    /// Done flags stored as floats; the tests in this file use `0.0` and `1.0`.
    /// NOTE(review): presumably `1.0` marks an episode end — confirm with the producer.
    pub dones: Vec<f64>,
    /// Log-probability values, one per stored step in the tests of this file.
    pub log_probs: Vec<f64>,
    /// Value estimates, one per stored step in the tests of this file.
    pub values: Vec<f64>,
    /// Advantage values, one per stored step in the tests of this file.
    pub advantages: Vec<f64>,
    /// Return values, one per stored step in the tests of this file.
    pub returns: Vec<f64>,
    /// Number of values per observation.
    pub obs_dim: usize,
    /// Number of values per action.
    pub act_dim: usize,
    /// Number of steps recorded in this batch.
    pub n_steps: usize,
    /// Number of environments recorded in this batch.
    pub n_envs: usize,
}
37
38pub struct Pipeline {
44 tx: Sender<RolloutBatch>,
45 rx: Receiver<RolloutBatch>,
46}
47
48impl Pipeline {
49 pub fn new(capacity: usize) -> Self {
55 assert!(capacity > 0, "Pipeline capacity must be at least 1");
56 let (tx, rx) = bounded(capacity);
57 Self { tx, rx }
58 }
59
60 pub fn send(&self, batch: RolloutBatch) -> Result<(), RloxError> {
62 self.tx
63 .send(batch)
64 .map_err(|_| RloxError::BufferError("pipeline channel disconnected".into()))
65 }
66
67 pub fn try_send(&self, batch: RolloutBatch) -> Result<(), RloxError> {
70 self.tx.try_send(batch).map_err(|e| match e {
71 TrySendError::Full(_) => RloxError::BufferError("pipeline channel full".into()),
72 TrySendError::Disconnected(_) => {
73 RloxError::BufferError("pipeline channel disconnected".into())
74 }
75 })
76 }
77
78 pub fn recv(&self) -> Result<RolloutBatch, RloxError> {
80 self.rx
81 .recv()
82 .map_err(|_| RloxError::BufferError("pipeline channel disconnected".into()))
83 }
84
85 pub fn try_recv(&self) -> Option<RolloutBatch> {
87 self.rx.try_recv().ok()
88 }
89
90 pub fn len(&self) -> usize {
92 self.rx.len()
93 }
94
95 pub fn is_empty(&self) -> bool {
97 self.rx.is_empty()
98 }
99
100 pub fn sender(&self) -> Sender<RolloutBatch> {
102 self.tx.clone()
103 }
104
105 pub fn receiver(&self) -> Receiver<RolloutBatch> {
107 self.rx.clone()
108 }
109}
110
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a 4-step, 2-env batch (obs_dim 3, act_dim 1) whose observations,
    /// actions and rewards are filled with `tag`, so a batch can be recognised
    /// after it has travelled through the pipeline.
    fn sample_batch(tag: usize) -> RolloutBatch {
        let (n_steps, n_envs, obs_dim, act_dim) = (4, 2, 3, 1);
        let transitions = n_steps * n_envs;
        let tag_f32 = tag as f32;
        let tag_f64 = tag as f64;
        RolloutBatch {
            observations: vec![tag_f32; transitions * obs_dim],
            actions: vec![tag_f32; transitions * act_dim],
            rewards: vec![tag_f64; transitions],
            dones: vec![0.0; transitions],
            log_probs: vec![-0.5; transitions],
            values: vec![0.0; transitions],
            advantages: vec![tag_f64 * 0.1; transitions],
            returns: vec![tag_f64 * 0.5; transitions],
            obs_dim,
            act_dim,
            n_steps,
            n_envs,
        }
    }

    // A freshly created pipeline holds nothing.
    #[test]
    fn test_pipeline_new_is_empty() {
        let pipeline = Pipeline::new(4);
        assert!(pipeline.is_empty());
        assert_eq!(pipeline.len(), 0);
    }

    // A batch that is sent comes back out with its payload and shape fields intact.
    #[test]
    fn test_pipeline_send_recv_roundtrip() {
        let pipeline = Pipeline::new(4);
        pipeline.send(sample_batch(42)).unwrap();
        assert_eq!(pipeline.len(), 1);

        let RolloutBatch {
            observations,
            rewards,
            obs_dim,
            act_dim,
            n_steps,
            n_envs,
            ..
        } = pipeline.recv().unwrap();
        assert_eq!(observations[0], 42.0);
        assert_eq!(rewards[0], 42.0);
        assert_eq!(obs_dim, 3);
        assert_eq!(act_dim, 1);
        assert_eq!(n_steps, 4);
        assert_eq!(n_envs, 2);
    }

    // Polling an empty pipeline yields `None` instead of blocking.
    #[test]
    fn test_pipeline_try_recv_empty_returns_none() {
        let pipeline = Pipeline::new(4);
        let polled = pipeline.try_recv();
        assert!(polled.is_none());
    }

    // With capacity 1 the second non-blocking send is rejected, and the batch
    // that was queued first is still the one delivered.
    #[test]
    fn test_pipeline_backpressure_blocks() {
        let pipeline = Pipeline::new(1);
        pipeline.send(sample_batch(1)).unwrap();
        assert_eq!(pipeline.len(), 1);

        let overflow = pipeline.try_send(sample_batch(2));
        assert!(overflow.is_err());

        let delivered = pipeline.recv().unwrap();
        assert_eq!(delivered.observations[0], 1.0);
    }

    // Non-uniform data survives the trip through the channel unchanged.
    #[test]
    fn test_pipeline_rollout_batch_data_integrity() {
        let pipeline = Pipeline::new(4);
        let (n_steps, n_envs, obs_dim, act_dim) = (4, 2, 3, 1);
        let total = n_steps * n_envs;

        // Done flags are set at flat indices 2 and 7 only.
        let dones: Vec<f64> = (0..total)
            .map(|idx| if idx == 2 || idx == 7 { 1.0 } else { 0.0 })
            .collect();

        let original = RolloutBatch {
            observations: (0..total * obs_dim).map(|idx| idx as f32).collect(),
            actions: (0..total * act_dim).map(|idx| idx as f32 * 0.1).collect(),
            rewards: (0..total).map(|idx| idx as f64).collect(),
            dones,
            log_probs: vec![-0.5; total],
            values: vec![0.0; total],
            advantages: (0..total).map(|idx| idx as f64 * 0.01).collect(),
            returns: (0..total).map(|idx| idx as f64 * 0.5).collect(),
            obs_dim,
            act_dim,
            n_steps,
            n_envs,
        };

        pipeline.send(original).unwrap();
        let out = pipeline.recv().unwrap();

        assert_eq!(out.observations.len(), total * obs_dim);
        assert_eq!(out.actions.len(), total * act_dim);
        for per_step_len in [
            out.rewards.len(),
            out.dones.len(),
            out.advantages.len(),
            out.returns.len(),
        ] {
            assert_eq!(per_step_len, total);
        }
        assert_eq!(out.observations[5], 5.0);
        assert_eq!(out.dones[2], 1.0);
        assert_eq!(out.dones[7], 1.0);
        let returns_drift = (out.returns[3] - 1.5).abs();
        assert!(returns_drift < 1e-10);
    }

    // Batches leave the pipeline in the order they entered it.
    #[test]
    fn test_pipeline_multiple_batches_fifo_order() {
        let pipeline = Pipeline::new(8);

        for tag in 0..5 {
            pipeline.send(sample_batch(tag)).unwrap();
        }
        assert_eq!(pipeline.len(), 5);

        for i in 0..5 {
            let next = pipeline.recv().unwrap();
            assert_eq!(next.observations[0], i as f32, "batch {i} out of order");
        }
        assert!(pipeline.is_empty());
    }

    // Zero capacity is rejected by the constructor's assertion.
    #[test]
    #[should_panic(expected = "Pipeline capacity must be at least 1")]
    fn test_pipeline_zero_capacity_panics() {
        let _unused = Pipeline::new(0);
    }

    // A cloned sender moved to another thread feeds the same channel.
    #[test]
    fn test_pipeline_cross_thread_send_recv() {
        let pipeline = Pipeline::new(4);
        let producer = pipeline.sender();

        let worker = std::thread::spawn(move || {
            producer.send(sample_batch(99)).unwrap();
        });
        worker.join().unwrap();

        let delivered = pipeline.recv().unwrap();
        assert_eq!(delivered.observations[0], 99.0);
    }
}
257}