rlox_core/env/parallel.rs

1use rayon::prelude::*;
2
3use crate::env::batch::BatchSteppable;
4use crate::env::spaces::{Action, ActionSpace, ObsSpace, Observation};
5use crate::env::RLEnv;
6use crate::error::RloxError;
7use crate::seed::derive_seed;
8
/// Column-oriented batch produced by one parallel step of a [`VecEnv`].
///
/// Every per-environment column is indexed by environment position
/// (`0..num_envs`). Only one of `obs` / `obs_flat` is filled, depending on
/// which stepping method built the batch; the other is left empty.
#[derive(Clone, Debug)]
pub struct BatchTransition {
    /// Per-environment observations, `[num_envs][obs_dim]`, filled by
    /// `step_all` (empty after `step_all_flat`). For an environment that just
    /// finished an episode this is the first observation after auto-reset.
    pub obs: Vec<Vec<f32>>,
    /// Row-major observations, `[num_envs * obs_dim]`, filled by
    /// `step_all_flat` (empty after `step_all`).
    pub obs_flat: Vec<f32>,
    /// Length of one observation row in `obs_flat`; `0` when built by `step_all`.
    pub obs_dim: usize,
    /// Reward received by each environment, `[num_envs]`.
    pub rewards: Vec<f64>,
    /// Whether each environment reached a terminal state, `[num_envs]`.
    pub terminated: Vec<bool>,
    /// Whether each environment was cut off (truncated), `[num_envs]`.
    pub truncated: Vec<bool>,
    /// Last observation of the finished episode, captured *before* auto-reset
    /// (used for value bootstrapping). `Some` exactly when the environment
    /// terminated or was truncated on this step, `None` otherwise.
    pub terminal_obs: Vec<Option<Vec<f32>>>,
}
29
30/// A vectorized environment that steps multiple sub-environments in parallel.
31pub struct VecEnv {
32    envs: Vec<Box<dyn RLEnv>>,
33    action_space: ActionSpace,
34    obs_space: ObsSpace,
35}
36
37impl std::fmt::Debug for VecEnv {
38    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39        f.debug_struct("VecEnv")
40            .field("num_envs", &self.envs.len())
41            .field("action_space", &self.action_space)
42            .field("obs_space", &self.obs_space)
43            .finish()
44    }
45}
46
47impl VecEnv {
48    /// Create a new vectorized environment from a list of sub-environments.
49    ///
50    /// # Errors
51    ///
52    /// Returns [`RloxError::EnvError`] if `envs` is empty.
53    pub fn new(envs: Vec<Box<dyn RLEnv>>) -> Result<Self, RloxError> {
54        if envs.is_empty() {
55            return Err(RloxError::EnvError(
56                "VecEnv requires at least one environment".into(),
57            ));
58        }
59        let action_space = envs[0].action_space().clone();
60        let obs_space = envs[0].obs_space().clone();
61        Ok(VecEnv {
62            envs,
63            action_space,
64            obs_space,
65        })
66    }
67
68    pub fn num_envs(&self) -> usize {
69        self.envs.len()
70    }
71
72    pub fn action_space(&self) -> &ActionSpace {
73        &self.action_space
74    }
75
76    /// Step + auto-reset all environments in parallel. Returns the raw
77    /// per-environment results: `(obs_data, reward, terminated, truncated, terminal_obs)`.
78    fn step_raw(
79        &mut self,
80        actions: &[Action],
81    ) -> Result<Vec<(Vec<f32>, f64, bool, bool, Option<Vec<f32>>)>, RloxError> {
82        if actions.len() != self.envs.len() {
83            return Err(RloxError::ShapeMismatch {
84                expected: format!("{}", self.envs.len()),
85                got: format!("{}", actions.len()),
86            });
87        }
88
89        let results: Vec<Result<(Vec<f32>, f64, bool, bool, Option<Vec<f32>>), RloxError>> = self
90            .envs
91            .par_iter_mut()
92            .zip(actions.par_iter())
93            .map(|(env, action)| {
94                let mut transition = env.step(action)?;
95                let mut term_obs = None;
96                if transition.terminated || transition.truncated {
97                    term_obs = Some(transition.obs.clone().into_inner());
98                    let new_obs = env.reset(None)?;
99                    transition.obs = new_obs;
100                }
101                let obs_data = transition.obs.into_inner();
102                Ok((
103                    obs_data,
104                    transition.reward,
105                    transition.terminated,
106                    transition.truncated,
107                    term_obs,
108                ))
109            })
110            .collect();
111
112        results.into_iter().collect()
113    }
114
115    /// Step all environments in parallel using Rayon.
116    ///
117    /// If an environment is done after stepping, it is automatically reset
118    /// and the returned observation is from the fresh episode.
119    pub fn step_all(&mut self, actions: &[Action]) -> Result<BatchTransition, RloxError> {
120        let raw = self.step_raw(actions)?;
121        let n = raw.len();
122        let mut obs = Vec::with_capacity(n);
123        let mut rewards = Vec::with_capacity(n);
124        let mut terminated = Vec::with_capacity(n);
125        let mut truncated = Vec::with_capacity(n);
126        let mut terminal_obs = Vec::with_capacity(n);
127
128        for (obs_data, reward, term, trunc, tobs) in raw {
129            obs.push(obs_data);
130            rewards.push(reward);
131            terminated.push(term);
132            truncated.push(trunc);
133            terminal_obs.push(tobs);
134        }
135
136        Ok(BatchTransition {
137            obs,
138            obs_flat: Vec::new(),
139            obs_dim: 0,
140            rewards,
141            terminated,
142            truncated,
143            terminal_obs,
144        })
145    }
146
147    /// Step all environments in parallel, returning observations as a flat contiguous buffer.
148    ///
149    /// Unlike `step_all`, this avoids per-env Vec allocations by collecting
150    /// observations directly into `obs_flat: Vec<f32>` of shape `[n_envs * obs_dim]`.
151    /// The `obs` field is left empty.
152    pub fn step_all_flat(&mut self, actions: &[Action]) -> Result<BatchTransition, RloxError> {
153        let obs_dim = match &self.obs_space {
154            ObsSpace::Discrete(_) => 1,
155            ObsSpace::Box { shape, .. } => shape.iter().product(),
156            ObsSpace::MultiDiscrete(v) => v.len(),
157            ObsSpace::Dict(entries) => entries.iter().map(|(_, d)| d).sum(),
158        };
159
160        let raw = self.step_raw(actions)?;
161        let n = raw.len();
162        let mut obs_flat = vec![0.0f32; n * obs_dim];
163        let mut rewards = Vec::with_capacity(n);
164        let mut terminated = Vec::with_capacity(n);
165        let mut truncated = Vec::with_capacity(n);
166        let mut terminal_obs = Vec::with_capacity(n);
167
168        for (i, (obs_data, reward, term, trunc, tobs)) in raw.into_iter().enumerate() {
169            obs_flat[i * obs_dim..(i + 1) * obs_dim].copy_from_slice(&obs_data);
170            rewards.push(reward);
171            terminated.push(term);
172            truncated.push(trunc);
173            terminal_obs.push(tobs);
174        }
175
176        Ok(BatchTransition {
177            obs: Vec::new(),
178            obs_flat,
179            obs_dim,
180            rewards,
181            terminated,
182            truncated,
183            terminal_obs,
184        })
185    }
186
187    /// Reset all environments, optionally seeding them deterministically.
188    ///
189    /// When a master seed is provided, each env `i` gets `derive_seed(master, i)`.
190    pub fn reset_all(&mut self, seed: Option<u64>) -> Result<Vec<Observation>, RloxError> {
191        self.envs
192            .iter_mut()
193            .enumerate()
194            .map(|(i, env)| {
195                let env_seed = seed.map(|s| derive_seed(s, i));
196                env.reset(env_seed)
197            })
198            .collect()
199    }
200}
201
202impl BatchSteppable for VecEnv {
203    fn step_batch(&mut self, actions: &[Action]) -> Result<BatchTransition, RloxError> {
204        self.step_all(actions)
205    }
206
207    fn reset_batch(&mut self, seed: Option<u64>) -> Result<Vec<Observation>, RloxError> {
208        self.reset_all(seed)
209    }
210
211    fn num_envs(&self) -> usize {
212        self.num_envs()
213    }
214
215    fn action_space(&self) -> &ActionSpace {
216        &self.action_space
217    }
218
219    fn obs_space(&self) -> &ObsSpace {
220        &self.obs_space
221    }
222}
223
#[cfg(test)]
mod tests {
    use super::*;
    use crate::env::builtins::CartPole;

    /// Builds `n` CartPole sub-environments; env `i` is seeded with
    /// `derive_seed(seed, i)`.
    fn make_vec_env(n: usize, seed: u64) -> VecEnv {
        let envs: Vec<Box<dyn RLEnv>> = (0..n)
            .map(|i| {
                let s = derive_seed(seed, i);
                Box::new(CartPole::new(Some(s))) as Box<dyn RLEnv>
            })
            .collect();
        VecEnv::new(envs).unwrap()
    }

    /// Alternating discrete actions `0, 1, 0, 1, …`, one per environment.
    fn alternating_actions(n: usize) -> Vec<Action> {
        (0..n).map(|i| Action::Discrete((i % 2) as u32)).collect()
    }

    #[test]
    fn vec_env_num_envs() {
        let venv = make_vec_env(4, 42);
        assert_eq!(venv.num_envs(), 4);
    }

    #[test]
    fn vec_env_step_all_returns_correct_shapes() {
        let mut venv = make_vec_env(4, 42);
        let batch = venv.step_all(&alternating_actions(4)).unwrap();
        assert_eq!(batch.obs.len(), 4);
        assert_eq!(batch.rewards.len(), 4);
        assert_eq!(batch.terminated.len(), 4);
        assert_eq!(batch.truncated.len(), 4);
        assert_eq!(batch.terminal_obs.len(), 4);
        for obs in &batch.obs {
            assert_eq!(obs.len(), 4);
        }
    }

    #[test]
    fn vec_env_step_all_flat_returns_contiguous_obs() {
        let mut venv = make_vec_env(4, 42);

        let batch_flat = venv.step_all_flat(&alternating_actions(4)).unwrap();
        assert!(
            batch_flat.obs.is_empty(),
            "obs Vec should be empty in flat mode"
        );
        assert_eq!(batch_flat.obs_flat.len(), 4 * 4); // 4 envs * 4 obs_dim (CartPole)
        assert_eq!(batch_flat.obs_dim, 4);
        assert_eq!(batch_flat.rewards.len(), 4);
    }

    #[test]
    fn vec_env_step_all_flat_matches_step_all() {
        let mut venv1 = make_vec_env(4, 42);
        let mut venv2 = make_vec_env(4, 42);
        let actions = alternating_actions(4);

        let batch_vec = venv1.step_all(&actions).unwrap();
        let batch_flat = venv2.step_all_flat(&actions).unwrap();

        // Compare flat obs rows against per-env obs
        let dim = batch_flat.obs_dim;
        assert_eq!(batch_flat.obs_flat.len(), batch_vec.obs.len() * dim);
        for (i, (obs_vec, flat_slice)) in batch_vec
            .obs
            .iter()
            .zip(batch_flat.obs_flat.chunks_exact(dim))
            .enumerate()
        {
            assert_eq!(obs_vec.as_slice(), flat_slice, "env {i} obs mismatch");
        }
        assert_eq!(batch_vec.rewards, batch_flat.rewards);
        assert_eq!(batch_vec.terminated, batch_flat.terminated);
        assert_eq!(batch_vec.truncated, batch_flat.truncated);
        assert_eq!(batch_vec.terminal_obs, batch_flat.terminal_obs);
    }

    #[test]
    fn vec_env_step_all_wrong_action_count() {
        let mut venv = make_vec_env(4, 42);
        let actions = vec![Action::Discrete(0); 3];
        let result = venv.step_all(&actions);
        assert!(
            matches!(result, Err(RloxError::ShapeMismatch { .. })),
            "wrong action count must be reported as ShapeMismatch"
        );
    }

    #[test]
    fn vec_env_reset_all_deterministic() {
        let mut venv1 = make_vec_env(4, 0);
        let mut venv2 = make_vec_env(4, 0);

        let obs1 = venv1.reset_all(Some(99)).unwrap();
        let obs2 = venv2.reset_all(Some(99)).unwrap();

        assert_eq!(obs1.len(), obs2.len());
        for (o1, o2) in obs1.iter().zip(obs2.iter()) {
            assert_eq!(o1.as_slice(), o2.as_slice());
        }
    }

    #[test]
    fn vec_env_large_parallel_stepping() {
        // Validate 256+ envs step correctly in parallel
        let mut venv = make_vec_env(256, 42);
        let batch = venv.step_all(&alternating_actions(256)).unwrap();
        assert_eq!(batch.obs.len(), 256);
        assert_eq!(batch.rewards.len(), 256);
        // All rewards should be 1.0 (first step)
        for &r in &batch.rewards {
            assert!((r - 1.0).abs() < f64::EPSILON);
        }
    }

    #[test]
    fn vec_env_1024_envs_no_panic() {
        // Ensure 1024 parallel envs don't cause thread pool issues
        let mut venv = make_vec_env(1024, 42);
        let actions = alternating_actions(1024);
        // Step 10 times
        for _ in 0..10 {
            let batch = venv.step_all(&actions).unwrap();
            assert_eq!(batch.obs.len(), 1024);
        }
    }

    #[test]
    fn vec_env_parallel_determinism() {
        // Parallel stepping must be deterministic across runs
        let run = || {
            let mut venv = make_vec_env(64, 42);
            venv.reset_all(Some(42)).unwrap();
            let actions = alternating_actions(64);
            let mut all_rewards = Vec::new();
            for _ in 0..50 {
                let batch = venv.step_all(&actions).unwrap();
                all_rewards.extend(batch.rewards);
            }
            all_rewards
        };
        let run1 = run();
        let run2 = run();
        assert_eq!(run1, run2);
    }

    #[test]
    fn vec_env_auto_reset_on_done() {
        let mut venv = make_vec_env(2, 42);
        // Always pushing right makes the pole fall, so episodes end quickly.
        let actions = vec![Action::Discrete(1); 2];
        let mut saw_done = false;

        for _ in 0..100 {
            let batch = match venv.step_all(&actions) {
                Ok(batch) => batch, // should always succeed due to auto-reset
                Err(e) => panic!("step_all should not error with auto-reset: {}", e),
            };
            for i in 0..2 {
                if batch.terminated[i] || batch.truncated[i] {
                    saw_done = true;
                    assert!(
                        batch.terminal_obs[i].is_some(),
                        "a finished env must report its terminal observation"
                    );
                }
            }
        }
        // Without this the test would pass even if no auto-reset ever happened.
        assert!(saw_done, "expected at least one episode to end within 100 steps");
    }
}
374
#[cfg(test)]
mod terminal_obs_tests {
    use super::*;
    use crate::env::builtins::CartPole;
    use crate::seed::derive_seed;

    /// Builds `n` CartPole sub-environments; env `i` is seeded with
    /// `derive_seed(seed, i)`.
    fn make_vec_env(n: usize, seed: u64) -> VecEnv {
        let envs: Vec<Box<dyn RLEnv>> = (0..n)
            .map(|i| Box::new(CartPole::new(Some(derive_seed(seed, i)))) as Box<dyn RLEnv>)
            .collect();
        VecEnv::new(envs).unwrap()
    }

    #[test]
    fn step_result_has_terminal_obs_on_truncation() {
        let mut venv = make_vec_env(4, 42);
        venv.reset_all(Some(42)).unwrap();
        // NOTE(review): with a constant action CartPole most likely terminates
        // well before any step limit, so the truncation path may never be hit
        // here. The invariant is still checked for every step observed.
        let actions = vec![Action::Discrete(0); 4];

        for _ in 0..600 {
            let batch = venv.step_all(&actions).unwrap();

            for i in 0..4 {
                // terminal_obs must be Some exactly when the env is done.
                let done = batch.terminated[i] || batch.truncated[i];
                assert_eq!(
                    batch.terminal_obs[i].is_some(),
                    done,
                    "env {}: terminal_obs must be Some iff terminated or truncated \
                     (terminated={}, truncated={})",
                    i,
                    batch.terminated[i],
                    batch.truncated[i]
                );
            }
        }
    }

    #[test]
    fn terminal_obs_has_correct_dimension() {
        let mut venv = make_vec_env(2, 42);
        venv.reset_all(Some(42)).unwrap();
        let mut terminals_seen = 0usize;

        for _ in 0..200 {
            let actions: Vec<Action> = vec![Action::Discrete(1); 2];
            let batch = venv.step_all(&actions).unwrap();
            for i in 0..2 {
                if let Some(tobs) = &batch.terminal_obs[i] {
                    terminals_seen += 1;
                    assert_eq!(tobs.len(), 4, "CartPole terminal obs must have dim 4");
                }
            }
        }
        // Guard against a vacuous pass when no episode ever ends.
        assert!(terminals_seen > 0, "expected at least one terminal observation");
    }

    #[test]
    fn returned_obs_after_reset_is_fresh_not_terminal() {
        let mut venv = make_vec_env(1, 42);
        venv.reset_all(Some(42)).unwrap();
        let mut checked = false;

        for _ in 0..200 {
            let actions = vec![Action::Discrete(1)];
            let batch = venv.step_all(&actions).unwrap();
            if batch.terminated[0] {
                let fresh_obs = &batch.obs[0];
                for &v in fresh_obs {
                    assert!(
                        v.abs() <= 0.06,
                        "post-reset obs should be near zero, got {v}"
                    );
                }
                let tobs = batch.terminal_obs[0]
                    .as_ref()
                    .expect("terminal_obs must exist on termination");
                let max_abs = tobs.iter().map(|v| v.abs()).fold(0.0_f32, f32::max);
                assert!(
                    max_abs > 0.05,
                    "terminal obs should be out-of-bounds, got max_abs={max_abs}"
                );
                checked = true;
                break;
            }
        }
        // Without this the test would pass if the episode never terminated.
        assert!(checked, "expected the episode to terminate within 200 steps");
    }
}
465
#[cfg(test)]
mod pendulum_vec_env_tests {
    use super::*;
    use crate::env::builtins::Pendulum;
    use crate::seed::derive_seed;

    /// Builds `n` Pendulum sub-environments; env `i` is seeded with
    /// `derive_seed(seed, i)`.
    fn make_pendulum_vec_env(n: usize, seed: u64) -> VecEnv {
        let envs: Vec<Box<dyn RLEnv>> = (0..n)
            .map(|i| {
                let s = derive_seed(seed, i);
                Box::new(Pendulum::new(Some(s))) as Box<dyn RLEnv>
            })
            .collect();
        VecEnv::new(envs).unwrap()
    }

    #[test]
    fn pendulum_vec_env_step_continuous_actions() {
        let mut venv = make_pendulum_vec_env(4, 42);
        let actions: Vec<Action> = (0..4)
            .map(|i| Action::Continuous(vec![(i as f32 - 1.5) * 0.5]))
            .collect();
        let batch = venv.step_all(&actions).unwrap();
        assert_eq!(batch.obs.len(), 4);
        assert_eq!(batch.rewards.len(), 4);
        for obs in &batch.obs {
            assert_eq!(obs.len(), 3, "Pendulum obs should have 3 dims");
        }
    }

    #[test]
    fn pendulum_vec_env_step_flat() {
        let mut venv = make_pendulum_vec_env(4, 42);
        let actions: Vec<Action> = (0..4).map(|_| Action::Continuous(vec![0.5])).collect();
        let batch = venv.step_all_flat(&actions).unwrap();
        assert!(batch.obs.is_empty());
        assert_eq!(batch.obs_flat.len(), 4 * 3);
        assert_eq!(batch.obs_dim, 3);
    }

    #[test]
    fn pendulum_vec_env_auto_reset() {
        let mut venv = make_pendulum_vec_env(2, 42);
        let actions = vec![Action::Continuous(vec![1.0]); 2];
        let mut saw_truncation = false;
        // Step 300 times — past the 200 truncation limit
        for _ in 0..300 {
            let batch = venv.step_all(&actions).unwrap();
            assert_eq!(batch.obs.len(), 2);
            for i in 0..2 {
                if batch.truncated[i] {
                    saw_truncation = true;
                    let tobs = batch.terminal_obs[i]
                        .as_ref()
                        .expect("terminal_obs must be Some when truncated");
                    assert_eq!(tobs.len(), 3, "Pendulum terminal obs must have dim 3");
                }
            }
        }
        // Ensure the truncation + auto-reset path was actually exercised.
        assert!(saw_truncation, "expected a truncation within 300 steps");
    }

    #[test]
    fn pendulum_vec_env_action_space() {
        let venv = make_pendulum_vec_env(2, 42);
        match venv.action_space() {
            ActionSpace::Box { low, high, shape } => {
                assert_eq!(shape, &[1]);
                assert_eq!(low, &[-2.0]);
                assert_eq!(high, &[2.0]);
            }
            other => panic!("Expected Box action space, got {:?}", other),
        }
    }

    #[test]
    fn pendulum_vec_env_determinism() {
        let run = || {
            let mut venv = make_pendulum_vec_env(8, 42);
            venv.reset_all(Some(42)).unwrap();
            let actions: Vec<Action> = (0..8)
                .map(|i| Action::Continuous(vec![(i as f32) * 0.25 - 1.0]))
                .collect();
            let mut all_rewards = Vec::new();
            for _ in 0..50 {
                let batch = venv.step_all(&actions).unwrap();
                all_rewards.extend(batch.rewards);
            }
            all_rewards
        };
        let r1 = run();
        let r2 = run();
        assert_eq!(r1, r2);
    }
}
547        assert_eq!(r1, r2);
548    }
549}