1use crate::env::parallel::BatchTransition;
2use crate::env::spaces::{Action, ActionSpace, ObsSpace, Observation};
3use crate::error::RloxError;
4
5pub trait BatchSteppable: Send {
13 fn step_batch(&mut self, actions: &[Action]) -> Result<BatchTransition, RloxError>;
15
16 fn reset_batch(&mut self, seed: Option<u64>) -> Result<Vec<Observation>, RloxError>;
18
19 fn num_envs(&self) -> usize;
21
22 fn action_space(&self) -> &ActionSpace;
24
25 fn obs_space(&self) -> &ObsSpace;
27}
28
#[cfg(test)]
mod tests {
    use super::*;
    use crate::env::builtins::CartPole;
    use crate::env::parallel::VecEnv;
    use crate::env::RLEnv;
    use crate::seed::derive_seed;

    /// Builds a `VecEnv` of `n` CartPole environments.
    ///
    /// Environment `i` is seeded with `derive_seed(seed, i)`.
    fn make_batch(n: usize, seed: u64) -> VecEnv {
        let envs: Vec<Box<dyn RLEnv>> = (0..n)
            .map(|i| Box::new(CartPole::new(Some(derive_seed(seed, i)))) as Box<dyn RLEnv>)
            .collect();
        VecEnv::new(envs).unwrap()
    }

    #[test]
    fn test_vecenv_implements_batch_steppable() {
        let mut batch: Box<dyn BatchSteppable> = Box::new(make_batch(4, 42));
        let obs = batch.reset_batch(Some(42)).unwrap();
        assert_eq!(obs.len(), 4);

        // Alternate between the two CartPole actions across the batch.
        let actions: Vec<Action> = (0..4u32).map(|i| Action::Discrete(i % 2)).collect();
        let result = batch.step_batch(&actions).unwrap();

        assert_eq!(result.obs.len(), 4);
        assert_eq!(result.rewards.len(), 4);
        assert_eq!(result.terminated.len(), 4);
        assert_eq!(result.truncated.len(), 4);
    }

    #[test]
    fn test_batch_steppable_num_envs() {
        let batch = make_batch(4, 42);
        let steppable: &dyn BatchSteppable = &batch;
        assert_eq!(steppable.num_envs(), 4);
    }

    #[test]
    fn test_batch_steppable_action_space_propagates() {
        let batch = make_batch(4, 42);
        let steppable: &dyn BatchSteppable = &batch;
        assert_eq!(steppable.action_space(), &ActionSpace::Discrete(2));
    }

    #[test]
    fn test_batch_steppable_wrong_action_count() {
        let mut batch: Box<dyn BatchSteppable> = Box::new(make_batch(4, 42));
        let actions = vec![Action::Discrete(0); 3];
        let result = batch.step_batch(&actions);
        assert!(result.is_err());
    }

    #[test]
    fn test_batch_steppable_is_object_safe() {
        // Compiles only if `BatchSteppable` is object-safe. The call also checks
        // that a concrete `VecEnv` coerces to `&dyn BatchSteppable`.
        fn accepts_trait_object(_: &dyn BatchSteppable) {}
        accepts_trait_object(&make_batch(2, 7));

        // The `Send` supertrait must make boxed batches transferable across threads.
        fn assert_send<T: Send + ?Sized>() {}
        assert_send::<Box<dyn BatchSteppable>>();
    }
}