rlox_core/env/batch.rs

1use crate::env::parallel::BatchTransition;
2use crate::env::spaces::{Action, ActionSpace, ObsSpace, Observation};
3use crate::error::RloxError;
4
5/// Trait for anything that can step a batch of environments.
6///
7/// Separates the parallelism strategy from step logic, so that both
8/// Rust-native `VecEnv` (Rayon) and future Python-backed `GymVecEnv`
9/// can share a common interface for rollout collectors and training loops.
10///
11/// The `Send` bound enables use from async / threaded contexts.
12pub trait BatchSteppable: Send {
13    /// Step all environments with the given actions (one per env).
14    fn step_batch(&mut self, actions: &[Action]) -> Result<BatchTransition, RloxError>;
15
16    /// Reset all environments, optionally seeding them deterministically.
17    fn reset_batch(&mut self, seed: Option<u64>) -> Result<Vec<Observation>, RloxError>;
18
19    /// Number of sub-environments in this batch.
20    fn num_envs(&self) -> usize;
21
22    /// The shared action space (all sub-environments must have the same space).
23    fn action_space(&self) -> &ActionSpace;
24
25    /// The shared observation space.
26    fn obs_space(&self) -> &ObsSpace;
27}
28
#[cfg(test)]
mod tests {
    use super::*;
    use crate::env::builtins::CartPole;
    use crate::env::parallel::VecEnv;
    use crate::env::RLEnv;
    use crate::seed::derive_seed;

    /// Number of sub-environments in every test batch.
    const N: usize = 4;

    /// Builds a `VecEnv` of `n` CartPole instances.
    ///
    /// Each instance is constructed with `derive_seed(seed, i)`, where `i` is
    /// its index in the batch.
    fn make_batch(n: usize, seed: u64) -> VecEnv {
        let envs: Vec<Box<dyn RLEnv>> = (0..n)
            .map(|i| Box::new(CartPole::new(Some(derive_seed(seed, i)))) as Box<dyn RLEnv>)
            .collect();
        VecEnv::new(envs).expect("VecEnv::new should accept a non-empty batch of CartPoles")
    }

    /// Compile-time proof that `BatchSteppable` is object-safe: this signature
    /// only type-checks if `dyn BatchSteppable` is a valid type.
    fn assert_object_safe(_: &dyn BatchSteppable) {}

    #[test]
    fn test_vecenv_implements_batch_steppable() {
        let mut batch: Box<dyn BatchSteppable> = Box::new(make_batch(N, 42));
        assert_eq!(batch.num_envs(), N);

        let obs = batch.reset_batch(Some(42)).unwrap();
        assert_eq!(obs.len(), N);

        // Alternate between CartPole's two discrete actions: 0, 1, 0, 1, ...
        let actions: Vec<Action> = (0..N).map(|i| Action::Discrete((i % 2) as u32)).collect();
        let result = batch.step_batch(&actions).unwrap();

        assert_eq!(result.obs.len(), N);
        assert_eq!(result.rewards.len(), N);
        assert_eq!(result.terminated.len(), N);
        assert_eq!(result.truncated.len(), N);
    }

    #[test]
    fn test_batch_steppable_action_space_propagates() {
        let batch = make_batch(N, 42);
        let steppable: &dyn BatchSteppable = &batch;
        assert_eq!(steppable.action_space(), &ActionSpace::Discrete(2));
    }

    #[test]
    fn test_batch_steppable_wrong_action_count() {
        let mut batch: Box<dyn BatchSteppable> = Box::new(make_batch(N, 42));
        // One action short of the number of sub-environments.
        let actions = vec![Action::Discrete(0); N - 1];
        let result = batch.step_batch(&actions);
        assert!(result.is_err());
    }

    #[test]
    fn test_batch_steppable_is_object_safe() {
        // Coercing a concrete `VecEnv` to `&dyn BatchSteppable` is the real
        // check. If this compiles, the trait is object-safe and `VecEnv`
        // implements it.
        assert_object_safe(&make_batch(N, 42));
    }
}