// rlox_core/pipeline/collector.rs

1use std::sync::atomic::{AtomicBool, Ordering};
2use std::sync::Arc;
3use std::thread::{self, JoinHandle};
4
5use crossbeam_channel::Sender;
6
7use crate::env::batch::BatchSteppable;
8use crate::env::spaces::Action;
9use crate::pipeline::channel::RolloutBatch;
10use crate::training::gae;
11
/// Asynchronous rollout collector that runs env stepping and GAE computation
/// in a background thread, sending completed batches through a channel.
///
/// The collector requires two external function pointers:
/// - `action_fn`: given flat observations `[n_envs * obs_dim]`, returns
///   `(actions_flat, log_probs)` where actions are `[n_envs * act_dim]`
///   and log_probs are `[n_envs]`.
/// - `value_fn`: given flat observations `[n_envs * obs_dim]`, returns
///   value estimates `[n_envs]`.
///
/// This design lets the Python side own the neural network while Rust handles
/// the tight env-stepping loop.
pub struct AsyncCollector {
    // Background worker thread; `None` once `stop()` has joined it,
    // which is what makes repeated `stop()` calls safe.
    handle: Option<JoinHandle<()>>,
    // Shutdown flag shared with the worker loop, which polls it between
    // steps and between rollouts.
    stop_flag: Arc<AtomicBool>,
}
28
impl AsyncCollector {
    /// Start the collector in a background thread.
    ///
    /// The collector will repeatedly:
    /// 1. Collect `n_steps` of experience from all envs
    /// 2. Compute GAE advantages and returns
    /// 3. Send the resulting `RolloutBatch` through `tx`
    ///
    /// It stops when `stop()` is called or the channel is disconnected.
    ///
    /// NOTE(review): `tx.send` blocks when the channel is full (backpressure).
    /// If the receiver stays alive but never drains the channel, the worker
    /// can sit in `send` past a `stop()` request, and `stop()`/`Drop` will
    /// block on `join()` until a slot frees — confirm callers drain or drop
    /// the receiver before stopping.
    pub fn start(
        mut envs: Box<dyn BatchSteppable>,
        n_steps: usize,
        gamma: f64,
        gae_lambda: f64,
        tx: Sender<RolloutBatch>,
        value_fn: Arc<dyn Fn(&[f32]) -> Vec<f64> + Send + Sync>,
        action_fn: Arc<dyn Fn(&[f32]) -> (Vec<f32>, Vec<f64>) + Send + Sync>,
    ) -> Self {
        let stop_flag = Arc::new(AtomicBool::new(false));
        let stop = stop_flag.clone();

        let handle = thread::spawn(move || {
            let n_envs = envs.num_envs();

            // Derive obs_dim from the observation space
            let obs_dim = match envs.obs_space() {
                crate::env::spaces::ObsSpace::Discrete(_) => 1,
                crate::env::spaces::ObsSpace::Box { shape, .. } => shape.iter().product(),
                crate::env::spaces::ObsSpace::MultiDiscrete(v) => v.len(),
                crate::env::spaces::ObsSpace::Dict(entries) => entries.iter().map(|(_, d)| d).sum(),
            };

            // Derive act_dim from the action space
            let act_dim = match envs.action_space() {
                crate::env::spaces::ActionSpace::Discrete(_) => 1,
                crate::env::spaces::ActionSpace::Box { shape, .. } => shape.iter().product(),
                crate::env::spaces::ActionSpace::MultiDiscrete(v) => v.len(),
            };

            // Get initial observations
            // NOTE(review): a failed reset exits the worker silently — the
            // channel just never produces. Consider surfacing the error.
            let init_obs = match envs.reset_batch(None) {
                Ok(obs) => obs,
                Err(_) => return,
            };
            // Flatten per-env observations into one `[n_envs * obs_dim]` buffer
            // that is reused (overwritten in place) for the whole run.
            let mut current_obs: Vec<f32> =
                init_obs.into_iter().flat_map(|o| o.into_inner()).collect();

            // One iteration of this loop = one complete rollout batch.
            while !stop.load(Ordering::Relaxed) {
                let total = n_steps * n_envs;

                // Step-major buffers: index layout is `t * n_envs + e`.
                let mut all_obs = Vec::with_capacity(total * obs_dim);
                let mut all_actions = Vec::with_capacity(total * act_dim);
                let mut all_rewards = Vec::with_capacity(total);
                let mut all_dones = Vec::with_capacity(total);
                let mut all_values = Vec::with_capacity(total);
                let mut all_log_probs = Vec::with_capacity(total);

                // Collect n_steps of experience
                let mut ok = true;
                for _ in 0..n_steps {
                    // Check the stop flag every step so shutdown is prompt
                    // even for large n_steps. A partial rollout is discarded.
                    if stop.load(Ordering::Relaxed) {
                        return;
                    }

                    // Get values and actions from the policy
                    let values = value_fn(&current_obs);
                    let (actions_flat, log_probs) = action_fn(&current_obs);

                    // Convert flat actions to Action enum for stepping
                    let actions: Vec<Action> = match envs.action_space() {
                        // Policy emits the action index as f32; `as u32`
                        // truncates toward zero (negative values become 0).
                        crate::env::spaces::ActionSpace::Discrete(_) => actions_flat
                            .iter()
                            .map(|&a| Action::Discrete(a as u32))
                            .collect(),
                        crate::env::spaces::ActionSpace::Box { shape, .. } => {
                            let dim: usize = shape.iter().product();
                            actions_flat
                                .chunks(dim)
                                .map(|chunk| Action::Continuous(chunk.to_vec()))
                                .collect()
                        }
                        // NOTE(review): for MultiDiscrete, `actions_flat` has
                        // `n_envs * act_dim` entries (see struct docs), so this
                        // produces one `Action::Discrete` per *component*, not
                        // per env — `step_batch` presumably expects exactly
                        // `n_envs` actions. Looks wrong for act_dim > 1;
                        // verify against a MultiDiscrete env.
                        crate::env::spaces::ActionSpace::MultiDiscrete(_) => actions_flat
                            .iter()
                            .map(|&a| Action::Discrete(a as u32))
                            .collect(),
                    };

                    // Store current obs, actions, values, log_probs
                    // (recorded *before* stepping, so row t pairs the obs the
                    // policy saw with the action/value/log_prob it produced).
                    all_obs.extend_from_slice(&current_obs);
                    all_actions.extend_from_slice(&actions_flat);
                    all_values.extend(&values);
                    all_log_probs.extend(&log_probs);

                    // Step environments
                    match envs.step_batch(&actions) {
                        Ok(transition) => {
                            for i in 0..n_envs {
                                all_rewards.push(transition.rewards[i]);
                                // Fold terminated/truncated into a single
                                // 0.0/1.0 "done" mask for GAE.
                                let done = if transition.terminated[i] || transition.truncated[i] {
                                    1.0
                                } else {
                                    0.0
                                };
                                all_dones.push(done);
                            }
                            // Update current observations (reuse allocation)
                            // NOTE(review): assumes `step_batch` returns the
                            // post-reset obs for done envs and that the per-env
                            // obs lengths always sum to n_envs * obs_dim —
                            // confirm against the VecEnv implementation.
                            let mut offset = 0;
                            for obs_vec in transition.obs {
                                current_obs[offset..offset + obs_vec.len()]
                                    .copy_from_slice(&obs_vec);
                                offset += obs_vec.len();
                            }
                        }
                        Err(_) => {
                            // Env error: abandon this rollout and shut the
                            // worker down (see `if !ok` below).
                            ok = false;
                            break;
                        }
                    }
                }

                if !ok {
                    break;
                }

                // Bootstrap value for GAE: V(s_{T}) for each env, taken from
                // the observations left after the final step.
                let last_values = value_fn(&current_obs);

                // Transpose step-major -> env-major for batched GAE
                let mut env_major_rewards = vec![0.0; total];
                let mut env_major_values = vec![0.0; total];
                let mut env_major_dones = vec![0.0; total];
                for t in 0..n_steps {
                    for e in 0..n_envs {
                        env_major_rewards[e * n_steps + t] = all_rewards[t * n_envs + e];
                        env_major_values[e * n_steps + t] = all_values[t * n_envs + e];
                        env_major_dones[e * n_steps + t] = all_dones[t * n_envs + e];
                    }
                }

                let (env_major_adv, env_major_ret) = gae::compute_gae_batched(
                    &env_major_rewards,
                    &env_major_values,
                    &env_major_dones,
                    &last_values,
                    n_steps,
                    gamma,
                    gae_lambda,
                );

                // Transpose back env-major -> step-major so advantages/returns
                // line up row-for-row with the other batch buffers.
                let mut advantages = vec![0.0; total];
                let mut returns = vec![0.0; total];
                for e in 0..n_envs {
                    for t in 0..n_steps {
                        advantages[t * n_envs + e] = env_major_adv[e * n_steps + t];
                        returns[t * n_envs + e] = env_major_ret[e * n_steps + t];
                    }
                }

                let batch = RolloutBatch {
                    observations: all_obs,
                    actions: all_actions,
                    rewards: all_rewards,
                    dones: all_dones,
                    log_probs: all_log_probs,
                    values: all_values,
                    advantages,
                    returns,
                    obs_dim,
                    act_dim,
                    n_steps,
                    n_envs,
                };

                // Send — blocks if channel is full (backpressure)
                if tx.send(batch).is_err() {
                    break; // receiver dropped
                }
            }
        });

        Self {
            handle: Some(handle),
            stop_flag,
        }
    }

    /// Signal the collector to stop and wait for the thread to finish.
    ///
    /// Idempotent: the handle is `take`n, so a second call is a no-op.
    pub fn stop(&mut self) {
        self.stop_flag.store(true, Ordering::Relaxed);
        if let Some(handle) = self.handle.take() {
            // `join` result ignored: a panicked worker is treated the same
            // as a clean exit.
            let _ = handle.join();
        }
    }

    /// Check whether the collector has been asked to stop.
    pub fn is_stopped(&self) -> bool {
        self.stop_flag.load(Ordering::Relaxed)
    }
}
229
230impl Drop for AsyncCollector {
231    fn drop(&mut self) {
232        self.stop();
233    }
234}
235
#[cfg(test)]
mod tests {
    use super::*;
    use crate::env::builtins::CartPole;
    use crate::env::parallel::VecEnv;
    use crate::env::RLEnv;
    use crate::pipeline::channel::Pipeline;
    use crate::seed::derive_seed;

    /// Build a vectorized env of `n` CartPole instances seeded from `seed`.
    fn make_vec_env(n: usize, seed: u64) -> Box<dyn BatchSteppable> {
        let mut envs: Vec<Box<dyn RLEnv>> = Vec::with_capacity(n);
        for i in 0..n {
            envs.push(Box::new(CartPole::new(Some(derive_seed(seed, i)))));
        }
        Box::new(VecEnv::new(envs).unwrap())
    }

    /// Constant critic: returns `v` per env (CartPole obs_dim = 4).
    fn const_value_fn(v: f64) -> Arc<dyn Fn(&[f32]) -> Vec<f64> + Send + Sync> {
        Arc::new(move |obs: &[f32]| vec![v; obs.len() / 4])
    }

    /// Constant policy: action `a` and log-prob `lp` per env (obs_dim = 4).
    fn const_policy_fn(a: f32, lp: f64) -> Arc<dyn Fn(&[f32]) -> (Vec<f32>, Vec<f64>) + Send + Sync> {
        Arc::new(move |obs: &[f32]| {
            let n = obs.len() / 4;
            (vec![a; n], vec![lp; n])
        })
    }

    #[test]
    fn test_async_collector_produces_batches() {
        let pipeline = Pipeline::new(4);

        let mut collector = AsyncCollector::start(
            make_vec_env(2, 42),
            8, // n_steps
            0.99,
            0.95,
            pipeline.sender(),
            const_value_fn(0.0),
            const_policy_fn(0.0, 0.0), // always action 0
        );

        // At least one batch must come through, with consistent shapes.
        let batch = pipeline.recv().unwrap();
        assert_eq!(batch.n_steps, 8);
        assert_eq!(batch.n_envs, 2);
        assert_eq!(batch.obs_dim, 4);
        assert_eq!(batch.act_dim, 1);
        assert_eq!(batch.observations.len(), 8 * 2 * 4);
        assert_eq!(batch.rewards.len(), 8 * 2);
        assert_eq!(batch.advantages.len(), 8 * 2);

        collector.stop();
        assert!(collector.is_stopped());
    }

    #[test]
    fn test_async_collector_stop_is_idempotent() {
        let pipeline = Pipeline::new(2);

        let mut collector = AsyncCollector::start(
            make_vec_env(1, 0),
            4,
            0.99,
            0.95,
            pipeline.sender(),
            const_value_fn(0.0),
            const_policy_fn(0.0, 0.0),
        );

        // Calling stop twice must not panic.
        collector.stop();
        collector.stop();
    }

    #[test]
    fn test_async_collector_gae_values_are_finite() {
        let pipeline = Pipeline::new(4);

        let mut collector = AsyncCollector::start(
            make_vec_env(4, 42),
            16,
            0.99,
            0.95,
            pipeline.sender(),
            const_value_fn(0.5),
            const_policy_fn(1.0, -0.5),
        );

        let batch = pipeline.recv().unwrap();
        for &a in &batch.advantages {
            assert!(a.is_finite(), "advantage must be finite, got {a}");
        }
        for &r in &batch.returns {
            assert!(r.is_finite(), "return must be finite, got {r}");
        }

        collector.stop();
    }
}
335}