1use std::sync::atomic::{AtomicBool, Ordering};
2use std::sync::Arc;
3use std::thread::{self, JoinHandle};
4
5use crossbeam_channel::Sender;
6
7use crate::env::batch::BatchSteppable;
8use crate::env::spaces::Action;
9use crate::pipeline::channel::RolloutBatch;
10use crate::training::gae;
11
/// Owner of a background rollout-collection thread.
///
/// The worker thread is created by `AsyncCollector::start` and is asked to
/// finish through `stop` (which `Drop` also calls).
pub struct AsyncCollector {
    /// Join handle of the worker thread. `stop` takes it out when joining, so
    /// it is `None` once the thread has been joined.
    handle: Option<thread::JoinHandle<()>>,
    /// Flag shared with the worker thread; storing `true` requests that the
    /// worker stop collecting and exit.
    stop_flag: Arc<AtomicBool>,
}
28
29impl AsyncCollector {
30 pub fn start(
39 mut envs: Box<dyn BatchSteppable>,
40 n_steps: usize,
41 gamma: f64,
42 gae_lambda: f64,
43 tx: Sender<RolloutBatch>,
44 value_fn: Arc<dyn Fn(&[f32]) -> Vec<f64> + Send + Sync>,
45 action_fn: Arc<dyn Fn(&[f32]) -> (Vec<f32>, Vec<f64>) + Send + Sync>,
46 ) -> Self {
47 let stop_flag = Arc::new(AtomicBool::new(false));
48 let stop = stop_flag.clone();
49
50 let handle = thread::spawn(move || {
51 let n_envs = envs.num_envs();
52
53 let obs_dim = match envs.obs_space() {
55 crate::env::spaces::ObsSpace::Discrete(_) => 1,
56 crate::env::spaces::ObsSpace::Box { shape, .. } => shape.iter().product(),
57 crate::env::spaces::ObsSpace::MultiDiscrete(v) => v.len(),
58 crate::env::spaces::ObsSpace::Dict(entries) => entries.iter().map(|(_, d)| d).sum(),
59 };
60
61 let act_dim = match envs.action_space() {
63 crate::env::spaces::ActionSpace::Discrete(_) => 1,
64 crate::env::spaces::ActionSpace::Box { shape, .. } => shape.iter().product(),
65 crate::env::spaces::ActionSpace::MultiDiscrete(v) => v.len(),
66 };
67
68 let init_obs = match envs.reset_batch(None) {
70 Ok(obs) => obs,
71 Err(_) => return,
72 };
73 let mut current_obs: Vec<f32> =
74 init_obs.into_iter().flat_map(|o| o.into_inner()).collect();
75
76 while !stop.load(Ordering::Relaxed) {
77 let total = n_steps * n_envs;
78
79 let mut all_obs = Vec::with_capacity(total * obs_dim);
80 let mut all_actions = Vec::with_capacity(total * act_dim);
81 let mut all_rewards = Vec::with_capacity(total);
82 let mut all_dones = Vec::with_capacity(total);
83 let mut all_values = Vec::with_capacity(total);
84 let mut all_log_probs = Vec::with_capacity(total);
85
86 let mut ok = true;
88 for _ in 0..n_steps {
89 if stop.load(Ordering::Relaxed) {
90 return;
91 }
92
93 let values = value_fn(¤t_obs);
95 let (actions_flat, log_probs) = action_fn(¤t_obs);
96
97 let actions: Vec<Action> = match envs.action_space() {
99 crate::env::spaces::ActionSpace::Discrete(_) => actions_flat
100 .iter()
101 .map(|&a| Action::Discrete(a as u32))
102 .collect(),
103 crate::env::spaces::ActionSpace::Box { shape, .. } => {
104 let dim: usize = shape.iter().product();
105 actions_flat
106 .chunks(dim)
107 .map(|chunk| Action::Continuous(chunk.to_vec()))
108 .collect()
109 }
110 crate::env::spaces::ActionSpace::MultiDiscrete(_) => actions_flat
111 .iter()
112 .map(|&a| Action::Discrete(a as u32))
113 .collect(),
114 };
115
116 all_obs.extend_from_slice(¤t_obs);
118 all_actions.extend_from_slice(&actions_flat);
119 all_values.extend(&values);
120 all_log_probs.extend(&log_probs);
121
122 match envs.step_batch(&actions) {
124 Ok(transition) => {
125 for i in 0..n_envs {
126 all_rewards.push(transition.rewards[i]);
127 let done = if transition.terminated[i] || transition.truncated[i] {
128 1.0
129 } else {
130 0.0
131 };
132 all_dones.push(done);
133 }
134 let mut offset = 0;
136 for obs_vec in transition.obs {
137 current_obs[offset..offset + obs_vec.len()]
138 .copy_from_slice(&obs_vec);
139 offset += obs_vec.len();
140 }
141 }
142 Err(_) => {
143 ok = false;
144 break;
145 }
146 }
147 }
148
149 if !ok {
150 break;
151 }
152
153 let last_values = value_fn(¤t_obs);
155
156 let mut env_major_rewards = vec![0.0; total];
158 let mut env_major_values = vec![0.0; total];
159 let mut env_major_dones = vec![0.0; total];
160 for t in 0..n_steps {
161 for e in 0..n_envs {
162 env_major_rewards[e * n_steps + t] = all_rewards[t * n_envs + e];
163 env_major_values[e * n_steps + t] = all_values[t * n_envs + e];
164 env_major_dones[e * n_steps + t] = all_dones[t * n_envs + e];
165 }
166 }
167
168 let (env_major_adv, env_major_ret) = gae::compute_gae_batched(
169 &env_major_rewards,
170 &env_major_values,
171 &env_major_dones,
172 &last_values,
173 n_steps,
174 gamma,
175 gae_lambda,
176 );
177
178 let mut advantages = vec![0.0; total];
180 let mut returns = vec![0.0; total];
181 for e in 0..n_envs {
182 for t in 0..n_steps {
183 advantages[t * n_envs + e] = env_major_adv[e * n_steps + t];
184 returns[t * n_envs + e] = env_major_ret[e * n_steps + t];
185 }
186 }
187
188 let batch = RolloutBatch {
189 observations: all_obs,
190 actions: all_actions,
191 rewards: all_rewards,
192 dones: all_dones,
193 log_probs: all_log_probs,
194 values: all_values,
195 advantages,
196 returns,
197 obs_dim,
198 act_dim,
199 n_steps,
200 n_envs,
201 };
202
203 if tx.send(batch).is_err() {
205 break; }
207 }
208 });
209
210 Self {
211 handle: Some(handle),
212 stop_flag,
213 }
214 }
215
216 pub fn stop(&mut self) {
218 self.stop_flag.store(true, Ordering::Relaxed);
219 if let Some(handle) = self.handle.take() {
220 let _ = handle.join();
221 }
222 }
223
224 pub fn is_stopped(&self) -> bool {
226 self.stop_flag.load(Ordering::Relaxed)
227 }
228}
229
230impl Drop for AsyncCollector {
231 fn drop(&mut self) {
232 self.stop();
233 }
234}
235
#[cfg(test)]
mod tests {
    use super::*;
    use crate::env::builtins::CartPole;
    use crate::env::parallel::VecEnv;
    use crate::env::RLEnv;
    use crate::pipeline::channel::Pipeline;
    use crate::seed::derive_seed;

    /// Value-function handle in the form `AsyncCollector::start` accepts.
    type ValueFn = Arc<dyn Fn(&[f32]) -> Vec<f64> + Send + Sync>;
    /// Policy handle in the form `AsyncCollector::start` accepts.
    type ActionFn = Arc<dyn Fn(&[f32]) -> (Vec<f32>, Vec<f64>) + Send + Sync>;

    /// Divisor the stub policies use to turn a flat observation length into an
    /// environment count (the tests assert `obs_dim == 4` for CartPole).
    const OBS_DIM: usize = 4;

    /// Builds a vectorised CartPole environment with `n` sub-environments, each
    /// seeded from `seed` and its index.
    fn make_vec_env(n: usize, seed: u64) -> Box<dyn BatchSteppable> {
        let mut envs: Vec<Box<dyn RLEnv>> = Vec::with_capacity(n);
        for i in 0..n {
            let env: Box<dyn RLEnv> = Box::new(CartPole::new(Some(derive_seed(seed, i))));
            envs.push(env);
        }
        Box::new(VecEnv::new(envs).unwrap())
    }

    /// Value function that returns `value` for every environment.
    fn constant_value_fn(value: f64) -> ValueFn {
        Arc::new(move |obs: &[f32]| vec![value; obs.len() / OBS_DIM])
    }

    /// Policy that returns `action` and `log_prob` for every environment.
    fn constant_action_fn(action: f32, log_prob: f64) -> ActionFn {
        Arc::new(move |obs: &[f32]| {
            let n = obs.len() / OBS_DIM;
            (vec![action; n], vec![log_prob; n])
        })
    }

    #[test]
    fn test_async_collector_produces_batches() {
        let pipe = Pipeline::new(4);
        let tx = pipe.sender();

        let mut collector = AsyncCollector::start(
            make_vec_env(2, 42),
            8,
            0.99,
            0.95,
            tx,
            constant_value_fn(0.0),
            constant_action_fn(0.0, 0.0),
        );

        // The first batch must carry the configured dimensions and buffer sizes.
        let batch = pipe.recv().unwrap();
        assert_eq!(batch.n_steps, 8);
        assert_eq!(batch.n_envs, 2);
        assert_eq!(batch.obs_dim, 4);
        assert_eq!(batch.act_dim, 1);
        assert_eq!(batch.observations.len(), 8 * 2 * 4);
        assert_eq!(batch.rewards.len(), 8 * 2);
        assert_eq!(batch.advantages.len(), 8 * 2);

        collector.stop();
        assert!(collector.is_stopped());
    }

    #[test]
    fn test_async_collector_stop_is_idempotent() {
        let pipe = Pipeline::new(2);
        let tx = pipe.sender();

        let mut collector = AsyncCollector::start(
            make_vec_env(1, 0),
            4,
            0.99,
            0.95,
            tx,
            constant_value_fn(0.0),
            constant_action_fn(0.0, 0.0),
        );

        // A second stop must neither panic nor block.
        collector.stop();
        collector.stop();
    }

    #[test]
    fn test_async_collector_gae_values_are_finite() {
        let pipe = Pipeline::new(4);
        let tx = pipe.sender();

        let mut collector = AsyncCollector::start(
            make_vec_env(4, 42),
            16,
            0.99,
            0.95,
            tx,
            constant_value_fn(0.5),
            constant_action_fn(1.0, -0.5),
        );

        let batch = pipe.recv().unwrap();
        for &a in &batch.advantages {
            assert!(a.is_finite(), "advantage must be finite, got {a}");
        }
        for &r in &batch.returns {
            assert!(r.is_finite(), "return must be finite, got {r}");
        }

        collector.stop();
    }
}