rlox_core/pipeline/
channel.rs

1use crossbeam_channel::{bounded, Receiver, Sender, TrySendError};
2
3use crate::error::RloxError;
4
/// A batch of rollout data ready for the learner.
///
/// Every per-step vector is flat and row-major, with the time step as the
/// outermost axis and the environment index next. For example,
/// `observations` has length `n_steps * n_envs * obs_dim` and is laid out as
/// `[step0_env0, step0_env1, ..., step1_env0, ...]`.
///
/// The struct is a plain data carrier: the lengths documented on each field
/// are a convention between producer and consumer and are not checked by
/// this type.
///
/// `PartialEq` is derived so whole batches can be compared directly. The
/// comparison is element-wise and uses ordinary float equality, so a batch
/// containing `NaN` never compares equal to itself.
#[derive(Debug, Clone, PartialEq)]
pub struct RolloutBatch {
    /// Flat observations: `[n_steps * n_envs * obs_dim]`.
    pub observations: Vec<f32>,
    /// Flat actions: `[n_steps * n_envs * act_dim]`.
    pub actions: Vec<f32>,
    /// Rewards: `[n_steps * n_envs]`.
    pub rewards: Vec<f64>,
    /// Done flags (0.0 or 1.0): `[n_steps * n_envs]`.
    pub dones: Vec<f64>,
    /// Log-probabilities of the collected actions: `[n_steps * n_envs]`.
    pub log_probs: Vec<f64>,
    /// Value estimates at collection time: `[n_steps * n_envs]`.
    pub values: Vec<f64>,
    /// GAE advantages: `[n_steps * n_envs]`.
    pub advantages: Vec<f64>,
    /// Discounted returns: `[n_steps * n_envs]`.
    pub returns: Vec<f64>,
    /// Observation dimensionality (number of `f32` values per observation).
    pub obs_dim: usize,
    /// Action dimensionality (number of `f32` values per action).
    pub act_dim: usize,
    /// Number of time steps in this batch.
    pub n_steps: usize,
    /// Number of environments that contributed to this batch.
    pub n_envs: usize,
}
37
38/// Bounded experience pipeline for decoupled collection and training.
39///
40/// Uses a crossbeam bounded channel internally, providing backpressure when
41/// the learner falls behind the collector. The `capacity` controls how many
42/// `RolloutBatch`es can be buffered before the sender blocks.
43pub struct Pipeline {
44    tx: Sender<RolloutBatch>,
45    rx: Receiver<RolloutBatch>,
46}
47
48impl Pipeline {
49    /// Create a new pipeline with the given buffer capacity.
50    ///
51    /// # Panics
52    ///
53    /// Panics if `capacity` is zero.
54    pub fn new(capacity: usize) -> Self {
55        assert!(capacity > 0, "Pipeline capacity must be at least 1");
56        let (tx, rx) = bounded(capacity);
57        Self { tx, rx }
58    }
59
60    /// Send a batch into the pipeline (blocks if full).
61    pub fn send(&self, batch: RolloutBatch) -> Result<(), RloxError> {
62        self.tx
63            .send(batch)
64            .map_err(|_| RloxError::BufferError("pipeline channel disconnected".into()))
65    }
66
67    /// Try to send a batch without blocking. Returns `Ok(())` on success,
68    /// `Err` with the batch if the channel is full or disconnected.
69    pub fn try_send(&self, batch: RolloutBatch) -> Result<(), RloxError> {
70        self.tx.try_send(batch).map_err(|e| match e {
71            TrySendError::Full(_) => RloxError::BufferError("pipeline channel full".into()),
72            TrySendError::Disconnected(_) => {
73                RloxError::BufferError("pipeline channel disconnected".into())
74            }
75        })
76    }
77
78    /// Receive a batch (blocks until one is available).
79    pub fn recv(&self) -> Result<RolloutBatch, RloxError> {
80        self.rx
81            .recv()
82            .map_err(|_| RloxError::BufferError("pipeline channel disconnected".into()))
83    }
84
85    /// Try to receive a batch without blocking.
86    pub fn try_recv(&self) -> Option<RolloutBatch> {
87        self.rx.try_recv().ok()
88    }
89
90    /// Number of batches currently buffered in the channel.
91    pub fn len(&self) -> usize {
92        self.rx.len()
93    }
94
95    /// Whether the channel is currently empty.
96    pub fn is_empty(&self) -> bool {
97        self.rx.is_empty()
98    }
99
100    /// Get a clone of the sender (for use in collector threads).
101    pub fn sender(&self) -> Sender<RolloutBatch> {
102        self.tx.clone()
103    }
104
105    /// Get a clone of the receiver (for use in learner threads).
106    pub fn receiver(&self) -> Receiver<RolloutBatch> {
107        self.rx.clone()
108    }
109}
110
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a small, fully populated batch (4 steps x 2 envs, `obs_dim` 3,
    /// `act_dim` 1) whose payload is derived from `tag`, so batches can be
    /// told apart after travelling through the channel.
    fn sample_batch(tag: usize) -> RolloutBatch {
        const N_STEPS: usize = 4;
        const N_ENVS: usize = 2;
        const OBS_DIM: usize = 3;
        const ACT_DIM: usize = 1;
        const TRANSITIONS: usize = N_STEPS * N_ENVS;

        let tag_f32 = tag as f32;
        let tag_f64 = tag as f64;

        RolloutBatch {
            observations: vec![tag_f32; TRANSITIONS * OBS_DIM],
            actions: vec![tag_f32; TRANSITIONS * ACT_DIM],
            rewards: vec![tag_f64; TRANSITIONS],
            dones: vec![0.0; TRANSITIONS],
            log_probs: vec![-0.5; TRANSITIONS],
            values: vec![0.0; TRANSITIONS],
            advantages: vec![tag_f64 * 0.1; TRANSITIONS],
            returns: vec![tag_f64 * 0.5; TRANSITIONS],
            obs_dim: OBS_DIM,
            act_dim: ACT_DIM,
            n_steps: N_STEPS,
            n_envs: N_ENVS,
        }
    }

    /// A freshly created pipeline holds no batches.
    #[test]
    fn test_pipeline_new_is_empty() {
        let pipeline = Pipeline::new(4);
        assert_eq!(pipeline.len(), 0);
        assert!(pipeline.is_empty());
    }

    /// A batch sent through the pipeline comes back with payload and shape
    /// metadata unchanged.
    #[test]
    fn test_pipeline_send_recv_roundtrip() {
        let pipeline = Pipeline::new(4);
        pipeline.send(sample_batch(42)).unwrap();
        assert_eq!(pipeline.len(), 1);

        let got = pipeline.recv().unwrap();
        assert_eq!(got.observations[0], 42.0);
        assert_eq!(got.rewards[0], 42.0);
        assert_eq!(got.obs_dim, 3);
        assert_eq!(got.act_dim, 1);
        assert_eq!(got.n_steps, 4);
        assert_eq!(got.n_envs, 2);
    }

    /// Non-blocking receive on an empty pipeline yields `None`.
    #[test]
    fn test_pipeline_try_recv_empty_returns_none() {
        let pipeline = Pipeline::new(4);
        let polled = pipeline.try_recv();
        assert!(polled.is_none());
    }

    /// With capacity 1, a second non-blocking send is rejected and the first
    /// batch stays intact in the buffer.
    #[test]
    fn test_pipeline_backpressure_blocks() {
        let pipeline = Pipeline::new(1);

        // The single slot accepts the first batch.
        pipeline.send(sample_batch(1)).unwrap();
        assert_eq!(pipeline.len(), 1);

        // The slot is taken, so a non-blocking send must report an error.
        let overflow = pipeline.try_send(sample_batch(2));
        assert!(overflow.is_err());

        // Draining returns the batch that was queued first.
        let first = pipeline.recv().unwrap();
        assert_eq!(first.observations[0], 1.0);
    }

    /// Every field of a hand-built batch survives the trip through the
    /// channel with the expected length and values.
    #[test]
    fn test_pipeline_rollout_batch_data_integrity() {
        let (n_steps, n_envs, obs_dim, act_dim) = (4, 2, 3, 1);
        let total = n_steps * n_envs;
        let pipeline = Pipeline::new(4);

        let observations: Vec<f32> = (0..total * obs_dim).map(|idx| idx as f32).collect();
        let actions: Vec<f32> = (0..total * act_dim).map(|idx| idx as f32 * 0.1).collect();
        let rewards: Vec<f64> = (0..total).map(|idx| idx as f64).collect();
        // Episodes end at flat indices 2 and 7.
        let dones: Vec<f64> = (0..total)
            .map(|idx| if idx == 2 || idx == 7 { 1.0 } else { 0.0 })
            .collect();
        let advantages: Vec<f64> = (0..total).map(|idx| idx as f64 * 0.01).collect();
        let returns: Vec<f64> = (0..total).map(|idx| idx as f64 * 0.5).collect();

        pipeline
            .send(RolloutBatch {
                observations,
                actions,
                rewards,
                dones,
                log_probs: vec![-0.5; total],
                values: vec![0.0; total],
                advantages,
                returns,
                obs_dim,
                act_dim,
                n_steps,
                n_envs,
            })
            .unwrap();
        let received = pipeline.recv().unwrap();

        assert_eq!(received.observations.len(), total * obs_dim);
        assert_eq!(received.actions.len(), total * act_dim);
        assert_eq!(received.rewards.len(), total);
        assert_eq!(received.dones.len(), total);
        assert_eq!(received.advantages.len(), total);
        assert_eq!(received.returns.len(), total);
        // Spot-check a few individual values.
        assert_eq!(received.observations[5], 5.0);
        assert_eq!(received.dones[2], 1.0);
        assert_eq!(received.dones[7], 1.0);
        assert!((received.returns[3] - 1.5).abs() < 1e-10);
    }

    /// Batches come out in the order they were sent.
    #[test]
    fn test_pipeline_multiple_batches_fifo_order() {
        let pipeline = Pipeline::new(8);

        for tag in 0..5 {
            pipeline.send(sample_batch(tag)).unwrap();
        }
        assert_eq!(pipeline.len(), 5);

        for i in 0..5 {
            let head = pipeline.recv().unwrap();
            assert_eq!(head.observations[0], i as f32, "batch {i} out of order");
        }
        assert!(pipeline.is_empty());
    }

    /// Constructing a pipeline with zero capacity is rejected with a panic.
    #[test]
    #[should_panic(expected = "Pipeline capacity must be at least 1")]
    fn test_pipeline_zero_capacity_panics() {
        Pipeline::new(0);
    }

    /// A cloned sender moved to another thread feeds the same channel the
    /// pipeline receives from.
    #[test]
    fn test_pipeline_cross_thread_send_recv() {
        let pipeline = Pipeline::new(4);
        let collector_tx = pipeline.sender();

        let collector = std::thread::spawn(move || {
            collector_tx.send(sample_batch(99)).unwrap();
        });
        collector.join().unwrap();

        let got = pipeline.recv().unwrap();
        assert_eq!(got.observations[0], 99.0);
    }
}