1use rayon::prelude::*;
2
3use crate::env::batch::BatchSteppable;
4use crate::env::spaces::{Action, ActionSpace, ObsSpace, Observation};
5use crate::env::RLEnv;
6use crate::error::RloxError;
7use crate::seed::derive_seed;
8
/// Result of stepping every environment in a [`VecEnv`] once.
///
/// Observations are stored in one of two layouts, depending on which method
/// produced the batch:
/// - [`VecEnv::step_all`] fills `obs` (one `Vec<f32>` per env) and leaves
///   `obs_flat` empty with `obs_dim == 0`.
/// - [`VecEnv::step_all_flat`] fills `obs_flat` (row-major, `num_envs * obs_dim`
///   values) and leaves `obs` empty.
///
/// Every other field has exactly one entry per environment, in environment order.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct BatchTransition {
    /// Per-env observation after the step; for an env whose episode just ended
    /// this is the first observation of the freshly reset episode.
    pub obs: Vec<Vec<f32>>,
    /// Row-major concatenation of all observations; env `i` occupies
    /// `obs_flat[i * obs_dim..(i + 1) * obs_dim]`.
    pub obs_flat: Vec<f32>,
    /// Length of a single observation inside `obs_flat`; `0` when unused.
    pub obs_dim: usize,
    /// Reward each env received on this step.
    pub rewards: Vec<f64>,
    /// Whether each env's episode terminated on this step.
    pub terminated: Vec<bool>,
    /// Whether each env's episode was truncated on this step.
    pub truncated: Vec<bool>,
    /// Final observation of an episode that ended on this step (`Some` when
    /// `terminated[i] || truncated[i]`). It is kept separately because the
    /// observation slot already holds the post-reset observation.
    pub terminal_obs: Vec<Option<Vec<f32>>>,
}
29
/// A fixed set of independent environments that are stepped together.
///
/// The action and observation spaces are cloned from the first environment
/// when the `VecEnv` is built.
// NOTE(review): `new` does not check that the other environments share those
// spaces — callers are presumably expected to pass homogeneous envs.
pub struct VecEnv {
    /// The wrapped environments; slot `i` receives `actions[i]` when stepping.
    envs: Vec<Box<dyn RLEnv>>,
    /// Action space cloned from `envs[0]` at construction.
    action_space: ActionSpace,
    /// Observation space cloned from `envs[0]` at construction.
    obs_space: ObsSpace,
}
36
37impl std::fmt::Debug for VecEnv {
38 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39 f.debug_struct("VecEnv")
40 .field("num_envs", &self.envs.len())
41 .field("action_space", &self.action_space)
42 .field("obs_space", &self.obs_space)
43 .finish()
44 }
45}
46
47impl VecEnv {
48 pub fn new(envs: Vec<Box<dyn RLEnv>>) -> Result<Self, RloxError> {
54 if envs.is_empty() {
55 return Err(RloxError::EnvError(
56 "VecEnv requires at least one environment".into(),
57 ));
58 }
59 let action_space = envs[0].action_space().clone();
60 let obs_space = envs[0].obs_space().clone();
61 Ok(VecEnv {
62 envs,
63 action_space,
64 obs_space,
65 })
66 }
67
68 pub fn num_envs(&self) -> usize {
69 self.envs.len()
70 }
71
72 pub fn action_space(&self) -> &ActionSpace {
73 &self.action_space
74 }
75
76 fn step_raw(
79 &mut self,
80 actions: &[Action],
81 ) -> Result<Vec<(Vec<f32>, f64, bool, bool, Option<Vec<f32>>)>, RloxError> {
82 if actions.len() != self.envs.len() {
83 return Err(RloxError::ShapeMismatch {
84 expected: format!("{}", self.envs.len()),
85 got: format!("{}", actions.len()),
86 });
87 }
88
89 let results: Vec<Result<(Vec<f32>, f64, bool, bool, Option<Vec<f32>>), RloxError>> = self
90 .envs
91 .par_iter_mut()
92 .zip(actions.par_iter())
93 .map(|(env, action)| {
94 let mut transition = env.step(action)?;
95 let mut term_obs = None;
96 if transition.terminated || transition.truncated {
97 term_obs = Some(transition.obs.clone().into_inner());
98 let new_obs = env.reset(None)?;
99 transition.obs = new_obs;
100 }
101 let obs_data = transition.obs.into_inner();
102 Ok((
103 obs_data,
104 transition.reward,
105 transition.terminated,
106 transition.truncated,
107 term_obs,
108 ))
109 })
110 .collect();
111
112 results.into_iter().collect()
113 }
114
115 pub fn step_all(&mut self, actions: &[Action]) -> Result<BatchTransition, RloxError> {
120 let raw = self.step_raw(actions)?;
121 let n = raw.len();
122 let mut obs = Vec::with_capacity(n);
123 let mut rewards = Vec::with_capacity(n);
124 let mut terminated = Vec::with_capacity(n);
125 let mut truncated = Vec::with_capacity(n);
126 let mut terminal_obs = Vec::with_capacity(n);
127
128 for (obs_data, reward, term, trunc, tobs) in raw {
129 obs.push(obs_data);
130 rewards.push(reward);
131 terminated.push(term);
132 truncated.push(trunc);
133 terminal_obs.push(tobs);
134 }
135
136 Ok(BatchTransition {
137 obs,
138 obs_flat: Vec::new(),
139 obs_dim: 0,
140 rewards,
141 terminated,
142 truncated,
143 terminal_obs,
144 })
145 }
146
147 pub fn step_all_flat(&mut self, actions: &[Action]) -> Result<BatchTransition, RloxError> {
153 let obs_dim = match &self.obs_space {
154 ObsSpace::Discrete(_) => 1,
155 ObsSpace::Box { shape, .. } => shape.iter().product(),
156 ObsSpace::MultiDiscrete(v) => v.len(),
157 ObsSpace::Dict(entries) => entries.iter().map(|(_, d)| d).sum(),
158 };
159
160 let raw = self.step_raw(actions)?;
161 let n = raw.len();
162 let mut obs_flat = vec![0.0f32; n * obs_dim];
163 let mut rewards = Vec::with_capacity(n);
164 let mut terminated = Vec::with_capacity(n);
165 let mut truncated = Vec::with_capacity(n);
166 let mut terminal_obs = Vec::with_capacity(n);
167
168 for (i, (obs_data, reward, term, trunc, tobs)) in raw.into_iter().enumerate() {
169 obs_flat[i * obs_dim..(i + 1) * obs_dim].copy_from_slice(&obs_data);
170 rewards.push(reward);
171 terminated.push(term);
172 truncated.push(trunc);
173 terminal_obs.push(tobs);
174 }
175
176 Ok(BatchTransition {
177 obs: Vec::new(),
178 obs_flat,
179 obs_dim,
180 rewards,
181 terminated,
182 truncated,
183 terminal_obs,
184 })
185 }
186
187 pub fn reset_all(&mut self, seed: Option<u64>) -> Result<Vec<Observation>, RloxError> {
191 self.envs
192 .iter_mut()
193 .enumerate()
194 .map(|(i, env)| {
195 let env_seed = seed.map(|s| derive_seed(s, i));
196 env.reset(env_seed)
197 })
198 .collect()
199 }
200}
201
202impl BatchSteppable for VecEnv {
203 fn step_batch(&mut self, actions: &[Action]) -> Result<BatchTransition, RloxError> {
204 self.step_all(actions)
205 }
206
207 fn reset_batch(&mut self, seed: Option<u64>) -> Result<Vec<Observation>, RloxError> {
208 self.reset_all(seed)
209 }
210
211 fn num_envs(&self) -> usize {
212 self.num_envs()
213 }
214
215 fn action_space(&self) -> &ActionSpace {
216 &self.action_space
217 }
218
219 fn obs_space(&self) -> &ObsSpace {
220 &self.obs_space
221 }
222}
223
#[cfg(test)]
mod tests {
    use super::*;
    use crate::env::builtins::CartPole;

    /// Builds `n` CartPole envs, env `i` seeded with `derive_seed(seed, i)`.
    fn make_vec_env(n: usize, seed: u64) -> VecEnv {
        let envs: Vec<Box<dyn RLEnv>> = (0..n)
            .map(|i| {
                let s = derive_seed(seed, i);
                Box::new(CartPole::new(Some(s))) as Box<dyn RLEnv>
            })
            .collect();
        VecEnv::new(envs).unwrap()
    }

    /// `n` discrete actions alternating 0, 1, 0, 1, ...
    fn alternating_actions(n: usize) -> Vec<Action> {
        (0..n).map(|i| Action::Discrete((i % 2) as u32)).collect()
    }

    #[test]
    fn vec_env_num_envs() {
        let venv = make_vec_env(4, 42);
        assert_eq!(venv.num_envs(), 4);
    }

    #[test]
    fn vec_env_step_all_returns_correct_shapes() {
        let mut venv = make_vec_env(4, 42);
        let batch = venv.step_all(&alternating_actions(4)).unwrap();
        assert_eq!(batch.obs.len(), 4);
        assert_eq!(batch.rewards.len(), 4);
        assert_eq!(batch.terminated.len(), 4);
        assert_eq!(batch.truncated.len(), 4);
        assert_eq!(batch.terminal_obs.len(), 4);
        for obs in &batch.obs {
            assert_eq!(obs.len(), 4);
        }
    }

    #[test]
    fn vec_env_step_all_flat_returns_contiguous_obs() {
        let mut venv = make_vec_env(4, 42);
        let batch_flat = venv.step_all_flat(&alternating_actions(4)).unwrap();
        assert!(
            batch_flat.obs.is_empty(),
            "obs Vec should be empty in flat mode"
        );
        assert_eq!(batch_flat.obs_flat.len(), 4 * 4);
        assert_eq!(batch_flat.obs_dim, 4);
        assert_eq!(batch_flat.rewards.len(), 4);
    }

    #[test]
    fn vec_env_step_all_flat_matches_step_all() {
        let mut venv1 = make_vec_env(4, 42);
        let mut venv2 = make_vec_env(4, 42);
        let actions = alternating_actions(4);

        let batch_vec = venv1.step_all(&actions).unwrap();
        let batch_flat = venv2.step_all_flat(&actions).unwrap();

        for (i, obs_vec) in batch_vec.obs.iter().enumerate() {
            let flat_slice = &batch_flat.obs_flat[i * 4..(i + 1) * 4];
            assert_eq!(obs_vec, flat_slice, "env {i} obs mismatch");
        }
        assert_eq!(batch_vec.rewards, batch_flat.rewards);
        assert_eq!(batch_vec.terminated, batch_flat.terminated);
        assert_eq!(batch_vec.truncated, batch_flat.truncated);
        assert_eq!(batch_vec.terminal_obs, batch_flat.terminal_obs);
    }

    #[test]
    fn vec_env_step_all_wrong_action_count() {
        let mut venv = make_vec_env(4, 42);
        let actions = vec![Action::Discrete(0); 3];
        let result = venv.step_all(&actions);
        assert!(
            matches!(result, Err(RloxError::ShapeMismatch { .. })),
            "wrong action count must yield ShapeMismatch"
        );
    }

    #[test]
    fn vec_env_reset_all_deterministic() {
        let mut venv1 = make_vec_env(4, 0);
        let mut venv2 = make_vec_env(4, 0);

        let obs1 = venv1.reset_all(Some(99)).unwrap();
        let obs2 = venv2.reset_all(Some(99)).unwrap();

        assert_eq!(obs1.len(), 4);
        assert_eq!(obs2.len(), 4);
        for (o1, o2) in obs1.iter().zip(obs2.iter()) {
            assert_eq!(o1.as_slice(), o2.as_slice());
        }
    }

    #[test]
    fn vec_env_large_parallel_stepping() {
        let mut venv = make_vec_env(256, 42);
        let batch = venv.step_all(&alternating_actions(256)).unwrap();
        assert_eq!(batch.obs.len(), 256);
        assert_eq!(batch.rewards.len(), 256);
        // A single step from a fresh CartPole episode yields reward 1.0 everywhere.
        for &r in &batch.rewards {
            assert!((r - 1.0).abs() < f64::EPSILON);
        }
    }

    #[test]
    fn vec_env_1024_envs_no_panic() {
        let mut venv = make_vec_env(1024, 42);
        let actions = alternating_actions(1024);
        for _ in 0..10 {
            let batch = venv.step_all(&actions).unwrap();
            assert_eq!(batch.obs.len(), 1024);
        }
    }

    #[test]
    fn vec_env_parallel_determinism() {
        let run = || {
            let mut venv = make_vec_env(64, 42);
            venv.reset_all(Some(42)).unwrap();
            let actions = alternating_actions(64);
            let mut all_rewards = Vec::new();
            for _ in 0..50 {
                let batch = venv.step_all(&actions).unwrap();
                all_rewards.extend(batch.rewards);
            }
            all_rewards
        };
        let run1 = run();
        let run2 = run();
        assert_eq!(run1, run2);
    }

    #[test]
    fn vec_env_auto_reset_on_done() {
        let mut venv = make_vec_env(2, 42);
        // Constantly pushing right is expected to end CartPole episodes well
        // within 100 steps, so the auto-reset path is actually exercised.
        let actions = vec![Action::Discrete(1); 2];
        let mut saw_done = false;

        for _ in 0..100 {
            let batch = venv
                .step_all(&actions)
                .unwrap_or_else(|e| panic!("step_all should not error with auto-reset: {}", e));
            for i in 0..2 {
                let done = batch.terminated[i] || batch.truncated[i];
                assert_eq!(
                    batch.terminal_obs[i].is_some(),
                    done,
                    "terminal_obs must be Some exactly when env {i} is done"
                );
                saw_done |= done;
            }
        }
        assert!(saw_done, "expected at least one episode to end within 100 steps");
    }
}
374
#[cfg(test)]
mod terminal_obs_tests {
    use super::*;
    use crate::env::builtins::CartPole;
    use crate::seed::derive_seed;

    /// Builds `n` CartPole envs, env `i` seeded with `derive_seed(seed, i)`.
    fn make_vec_env(n: usize, seed: u64) -> VecEnv {
        let envs: Vec<Box<dyn RLEnv>> = (0..n)
            .map(|i| Box::new(CartPole::new(Some(derive_seed(seed, i)))) as Box<dyn RLEnv>)
            .collect();
        VecEnv::new(envs).unwrap()
    }

    #[test]
    fn step_result_has_terminal_obs_on_truncation() {
        let mut venv = make_vec_env(4, 42);
        venv.reset_all(Some(42)).unwrap();
        // NOTE(review): with a constant action, CartPole episodes presumably
        // terminate long before any truncation limit. The truncation branch
        // below may therefore never fire; the invariant is still checked
        // for every step.
        let actions = vec![Action::Discrete(0); 4];

        for _ in 0..600 {
            let batch = venv.step_all(&actions).unwrap();

            for i in 0..4 {
                if batch.truncated[i] {
                    assert!(
                        batch.terminal_obs[i].is_some(),
                        "terminal_obs must be Some when truncated"
                    );
                }
                if batch.terminated[i] {
                    assert!(
                        batch.terminal_obs[i].is_some(),
                        "terminal_obs must be Some when terminated"
                    );
                }
                if !batch.terminated[i] && !batch.truncated[i] {
                    assert!(
                        batch.terminal_obs[i].is_none(),
                        "terminal_obs must be None when not done"
                    );
                }
            }
        }
    }

    #[test]
    fn terminal_obs_has_correct_dimension() {
        let mut venv = make_vec_env(2, 42);
        venv.reset_all(Some(42)).unwrap();
        let actions = vec![Action::Discrete(1); 2];
        let mut checked = 0usize;

        for _ in 0..200 {
            let batch = venv.step_all(&actions).unwrap();
            for tobs in batch.terminal_obs.iter().flatten() {
                assert_eq!(tobs.len(), 4, "CartPole terminal obs must have dim 4");
                checked += 1;
            }
        }
        // Without this the test would pass vacuously if no episode ever ended.
        assert!(checked > 0, "expected at least one terminal observation to check");
    }

    #[test]
    fn returned_obs_after_reset_is_fresh_not_terminal() {
        let mut venv = make_vec_env(1, 42);
        venv.reset_all(Some(42)).unwrap();
        let actions = vec![Action::Discrete(1)];
        let mut saw_termination = false;

        for _ in 0..200 {
            let batch = venv.step_all(&actions).unwrap();
            if !batch.terminated[0] {
                continue;
            }

            let fresh_obs = &batch.obs[0];
            for &v in fresh_obs {
                assert!(
                    v.abs() <= 0.06,
                    "post-reset obs should be near zero, got {v}"
                );
            }
            let tobs = batch.terminal_obs[0]
                .as_ref()
                .expect("terminal_obs must exist on termination");
            let max_abs = tobs.iter().map(|v| v.abs()).fold(0.0_f32, f32::max);
            assert!(
                max_abs > 0.05,
                "terminal obs should be out-of-bounds, got max_abs={max_abs}"
            );
            saw_termination = true;
            break;
        }
        // Without this the test would pass vacuously if no termination occurred.
        assert!(saw_termination, "expected a termination within 200 steps");
    }
}
465
#[cfg(test)]
mod pendulum_vec_env_tests {
    use super::*;
    use crate::env::builtins::Pendulum;
    use crate::seed::derive_seed;

    /// Builds `count` Pendulum envs; env `idx` is seeded with
    /// `derive_seed(base_seed, idx)`.
    fn make_pendulum_vec_env(count: usize, base_seed: u64) -> VecEnv {
        let mut envs: Vec<Box<dyn RLEnv>> = Vec::with_capacity(count);
        for idx in 0..count {
            envs.push(Box::new(Pendulum::new(Some(derive_seed(base_seed, idx)))));
        }
        VecEnv::new(envs).unwrap()
    }

    #[test]
    fn pendulum_vec_env_step_continuous_actions() {
        let mut venv = make_pendulum_vec_env(4, 42);
        let torques: Vec<Action> = (0..4)
            .map(|k| Action::Continuous(vec![(k as f32 - 1.5) * 0.5]))
            .collect();
        let step = venv.step_all(&torques).unwrap();
        assert_eq!(step.obs.len(), 4);
        assert_eq!(step.rewards.len(), 4);
        step.obs
            .iter()
            .for_each(|o| assert_eq!(o.len(), 3, "Pendulum obs should have 3 dims"));
    }

    #[test]
    fn pendulum_vec_env_step_flat() {
        let mut venv = make_pendulum_vec_env(4, 42);
        let torques = vec![Action::Continuous(vec![0.5]); 4];
        let step = venv.step_all_flat(&torques).unwrap();
        assert!(step.obs.is_empty());
        assert_eq!(step.obs_flat.len(), 4 * 3);
        assert_eq!(step.obs_dim, 3);
    }

    #[test]
    fn pendulum_vec_env_auto_reset() {
        let mut venv = make_pendulum_vec_env(2, 42);
        // NOTE(review): 300 steps is presumably longer than one Pendulum
        // episode, so auto-reset should be exercised — confirm the step limit.
        let torques = vec![Action::Continuous(vec![1.0]); 2];
        for _ in 0..300 {
            let step = venv.step_all(&torques).unwrap();
            assert_eq!(step.obs.len(), 2);
        }
    }

    #[test]
    fn pendulum_vec_env_action_space() {
        let venv = make_pendulum_vec_env(2, 42);
        let space = venv.action_space();
        match space {
            ActionSpace::Box { low, high, shape } => {
                assert_eq!(shape, &[1]);
                assert_eq!(low, &[-2.0]);
                assert_eq!(high, &[2.0]);
            }
            other => panic!("Expected Box action space, got {:?}", other),
        }
    }

    #[test]
    fn pendulum_vec_env_determinism() {
        /// Runs 50 batched steps from a seeded reset and returns every reward.
        fn rollout() -> Vec<f64> {
            let mut venv = make_pendulum_vec_env(8, 42);
            venv.reset_all(Some(42)).unwrap();
            let torques: Vec<Action> = (0..8)
                .map(|k| Action::Continuous(vec![(k as f32) * 0.25 - 1.0]))
                .collect();
            (0..50)
                .flat_map(|_| venv.step_all(&torques).unwrap().rewards)
                .collect()
        }

        let first = rollout();
        let second = rollout();
        assert_eq!(first, second);
    }
}