// Re-export selected when the `mujoco` feature is enabled.
// NOTE(review): as written, this branch re-exports the same simplified environment
// as the `not(feature = "mujoco")` branch, so the feature does not currently change
// which type is exported from this file.
#[cfg(feature = "mujoco")]
mod inner {
    pub use super::simplified::SimplifiedMuJoCoEnv;
}
22
// Re-export selected when the `mujoco` feature is disabled: the pure-Rust
// simplified environment defined in the `simplified` module of this file.
#[cfg(not(feature = "mujoco"))]
mod inner {
    pub use super::simplified::SimplifiedMuJoCoEnv;
}
27
28pub use inner::*;
29
30mod simplified {
31 use std::collections::HashMap;
32
33 use rand::Rng;
34 use rand_chacha::ChaCha8Rng;
35
36 use crate::env::spaces::{Action, ActionSpace, ObsSpace, Observation};
37 use crate::env::{RLEnv, Transition};
38 use crate::error::RloxError;
39 use crate::seed::rng_from_seed;
40
    /// Length of the internal state vector and of every observation.
    const OBS_DIM: usize = 17;

    /// Number of components a continuous action must contain.
    const ACT_DIM: usize = 6;

    /// Step size multiplied into every state update performed by `step`.
    const DT: f64 = 0.05;

    /// Step count at which `step` reports truncation and marks the episode done.
    const MAX_STEPS: u32 = 1000;

    /// Weight applied to the sum of squared action components in the reward.
    const CTRL_COST_WEIGHT: f64 = 0.1;
63
    /// Simplified HalfCheetah-style environment with a 17-element state and a
    /// 6-element continuous action.
    pub struct SimplifiedMuJoCoEnv {
        /// State vector of length `OBS_DIM`. `step` integrates elements `0..8`
        /// from elements `8..16`, and adds the action to elements `8..14`.
        state: Vec<f64>,
        /// Random source used by `reset` to sample the initial state.
        rng: ChaCha8Rng,
        /// Number of steps taken since the last reset.
        steps: u32,
        /// Box action space with bounds `[-1, 1]` per component.
        action_space: ActionSpace,
        /// Box observation space with infinite bounds.
        obs_space: ObsSpace,
        /// True once the episode has been truncated; `step` errors while set.
        done: bool,
        /// Set to `0.0` on construction and on reset.
        /// NOTE(review): never read in the code visible here — confirm whether it
        /// is still needed.
        prev_x_pos: f64,
    }
83
84 impl SimplifiedMuJoCoEnv {
85 pub fn new(seed: Option<u64>) -> Self {
87 let seed_val = seed.unwrap_or(0);
88 let rng = rng_from_seed(seed_val);
89
90 let action_low = vec![-1.0_f32; ACT_DIM];
91 let action_high = vec![1.0_f32; ACT_DIM];
92
93 let obs_low = vec![-f32::INFINITY; OBS_DIM];
95 let obs_high = vec![f32::INFINITY; OBS_DIM];
96
97 let mut env = SimplifiedMuJoCoEnv {
98 state: vec![0.0; OBS_DIM],
99 rng,
100 steps: 0,
101 action_space: ActionSpace::Box {
102 low: action_low,
103 high: action_high,
104 shape: vec![ACT_DIM],
105 },
106 obs_space: ObsSpace::Box {
107 low: obs_low,
108 high: obs_high,
109 shape: vec![OBS_DIM],
110 },
111 done: true,
112 prev_x_pos: 0.0,
113 };
114 let _ = env.reset(Some(seed_val));
115 env
116 }
117
118 fn obs(&self) -> Observation {
120 Observation::Flat(self.state.iter().map(|&v| v as f32).collect())
121 }
122
123 fn forward_velocity(&self) -> f64 {
130 self.state[8]
133 }
134 }
135
136 impl RLEnv for SimplifiedMuJoCoEnv {
137 fn step(&mut self, action: &Action) -> Result<Transition, RloxError> {
138 if self.done {
139 return Err(RloxError::EnvError(
140 "Environment is done. Call reset() before stepping.".into(),
141 ));
142 }
143
144 let torques = match action {
145 Action::Continuous(vals) if vals.len() == ACT_DIM => vals,
146 _ => {
147 return Err(RloxError::InvalidAction(format!(
148 "HalfCheetah expects a Continuous action with {} elements",
149 ACT_DIM
150 )));
151 }
152 };
153
154 for (i, &t) in torques.iter().enumerate().take(ACT_DIM) {
164 let torque = (t as f64).clamp(-1.0, 1.0);
165 self.state[8 + i] += DT * torque;
166 }
167
168 for i in 0..8 {
170 let vel_idx = 8 + i.min(OBS_DIM - 9);
171 self.state[i] += DT * self.state[vel_idx];
172 }
173
174 self.steps += 1;
175
176 let forward_vel = self.forward_velocity();
179 let ctrl_cost: f64 = CTRL_COST_WEIGHT
180 * torques
181 .iter()
182 .map(|&t| (t as f64) * (t as f64))
183 .sum::<f64>();
184 let reward = forward_vel - ctrl_cost;
185
186 let truncated = self.steps >= MAX_STEPS;
188 self.done = truncated;
189
190 Ok(Transition {
191 obs: self.obs(),
192 reward,
193 terminated: false,
194 truncated,
195 info: Some({
196 let mut info = HashMap::new();
197 info.insert("x_velocity".to_string(), forward_vel);
198 info.insert("reward_forward".to_string(), forward_vel);
199 info.insert("reward_ctrl".to_string(), -ctrl_cost);
200 info
201 }),
202 })
203 }
204
205 fn reset(&mut self, seed: Option<u64>) -> Result<Observation, RloxError> {
206 if let Some(s) = seed {
207 self.rng = rng_from_seed(s);
208 }
209
210 for s in self.state.iter_mut() {
213 *s = self.rng.random_range(-0.1..0.1);
214 }
215
216 self.steps = 0;
217 self.done = false;
218 self.prev_x_pos = 0.0;
219
220 Ok(self.obs())
221 }
222
223 fn action_space(&self) -> &ActionSpace {
224 &self.action_space
225 }
226
227 fn obs_space(&self) -> &ObsSpace {
228 &self.obs_space
229 }
230
231 fn render(&self) -> Option<String> {
232 Some(format!(
233 "SimplifiedHalfCheetah | step={} | x_vel={:.4}",
234 self.steps,
235 self.forward_velocity()
236 ))
237 }
238 }
239}
240
241#[cfg(test)]
246mod tests {
247 use super::SimplifiedMuJoCoEnv;
248 use crate::env::parallel::VecEnv;
249 use crate::env::spaces::{Action, ActionSpace, ObsSpace};
250 use crate::env::RLEnv;
251 use crate::seed::derive_seed;
252
253 fn zero_action() -> Action {
254 Action::Continuous(vec![0.0; 6])
255 }
256
257 fn random_action(seed: u32) -> Action {
258 let vals: Vec<f32> = (0..6)
260 .map(|i| ((seed as f32 + i as f32) * 0.31415).sin() * 0.8)
261 .collect();
262 Action::Continuous(vals)
263 }
264
265 #[test]
268 fn obs_dim_is_17() {
269 let mut env = SimplifiedMuJoCoEnv::new(Some(42));
270 let obs = env.reset(Some(42)).unwrap();
271 assert_eq!(obs.as_slice().len(), 17, "HalfCheetah obs must be 17-dim");
272 }
273
274 #[test]
275 fn obs_space_shape_is_17() {
276 let env = SimplifiedMuJoCoEnv::new(Some(42));
277 match env.obs_space() {
278 ObsSpace::Box { shape, .. } => {
279 assert_eq!(shape, &[17]);
280 }
281 other => panic!("Expected Box obs space, got {:?}", other),
282 }
283 }
284
285 #[test]
288 fn action_dim_is_6() {
289 let env = SimplifiedMuJoCoEnv::new(Some(42));
290 match env.action_space() {
291 ActionSpace::Box { low, high, shape } => {
292 assert_eq!(shape, &[6]);
293 assert_eq!(low.len(), 6);
294 assert_eq!(high.len(), 6);
295 for (&lo, &hi) in low.iter().zip(high.iter()) {
296 assert!((lo - (-1.0)).abs() < f32::EPSILON);
297 assert!((hi - 1.0).abs() < f32::EPSILON);
298 }
299 }
300 other => panic!("Expected Box action space, got {:?}", other),
301 }
302 }
303
304 #[test]
307 fn reset_returns_valid_obs() {
308 let mut env = SimplifiedMuJoCoEnv::new(Some(42));
309 let obs = env.reset(Some(99)).unwrap();
310 assert_eq!(obs.as_slice().len(), 17);
311 for &v in obs.as_slice() {
313 assert!(
314 v.abs() <= 0.1 + f32::EPSILON,
315 "initial obs element out of range: {}",
316 v
317 );
318 }
319 }
320
321 #[test]
322 fn reset_clears_step_counter() {
323 let mut env = SimplifiedMuJoCoEnv::new(Some(42));
324 for _ in 0..10 {
326 env.step(&zero_action()).unwrap();
327 }
328 env.reset(Some(42)).unwrap();
330 let t = env.step(&zero_action()).unwrap();
331 assert!(!t.truncated);
332 }
333
334 #[test]
337 fn step_returns_17_dim_obs() {
338 let mut env = SimplifiedMuJoCoEnv::new(Some(42));
339 let t = env.step(&zero_action()).unwrap();
340 assert_eq!(t.obs.as_slice().len(), 17);
341 }
342
343 #[test]
344 fn step_never_terminates() {
345 let mut env = SimplifiedMuJoCoEnv::new(Some(42));
346 for _ in 0..1000 {
347 let t = env.step(&random_action(0)).unwrap();
348 assert!(!t.terminated, "HalfCheetah should never terminate early");
349 if t.truncated {
350 break;
351 }
352 }
353 }
354
355 #[test]
358 fn truncates_at_1000_steps() {
359 let mut env = SimplifiedMuJoCoEnv::new(Some(42));
360 for i in 0..1000 {
361 let t = env.step(&zero_action()).unwrap();
362 if i < 999 {
363 assert!(!t.truncated, "should not truncate at step {}", i + 1);
364 } else {
365 assert!(t.truncated, "should truncate at step 1000");
366 }
367 }
368 let result = env.step(&zero_action());
370 assert!(result.is_err());
371 }
372
373 #[test]
376 fn zero_action_gives_zero_ctrl_cost() {
377 let mut env = SimplifiedMuJoCoEnv::new(Some(42));
378 let t = env.step(&zero_action()).unwrap();
379 let x_vel = t
381 .info
382 .as_ref()
383 .and_then(|m| m.get("x_velocity"))
384 .copied()
385 .unwrap_or(0.0);
386 let ctrl = t
387 .info
388 .as_ref()
389 .and_then(|m| m.get("reward_ctrl"))
390 .copied()
391 .unwrap_or(0.0);
392 assert!(
393 ctrl.abs() < 1e-10,
394 "ctrl cost should be ~0 for zero action, got {}",
395 ctrl
396 );
397 assert!(
398 (t.reward - x_vel).abs() < 1e-10,
399 "reward should equal x_velocity when ctrl_cost=0"
400 );
401 }
402
403 #[test]
404 fn nonzero_action_incurs_ctrl_cost() {
405 let mut env = SimplifiedMuJoCoEnv::new(Some(42));
406 let action = Action::Continuous(vec![0.5; 6]);
407 let t = env.step(&action).unwrap();
408 let ctrl = t
409 .info
410 .as_ref()
411 .and_then(|m| m.get("reward_ctrl"))
412 .copied()
413 .unwrap_or(0.0);
414 assert!(ctrl < 0.0, "ctrl reward should be negative, got {}", ctrl);
416 }
417
418 #[test]
421 fn discrete_action_rejected() {
422 let mut env = SimplifiedMuJoCoEnv::new(Some(42));
423 let result = env.step(&Action::Discrete(0));
424 assert!(result.is_err());
425 }
426
427 #[test]
428 fn wrong_dim_action_rejected() {
429 let mut env = SimplifiedMuJoCoEnv::new(Some(42));
430 let result = env.step(&Action::Continuous(vec![0.0; 3]));
431 assert!(result.is_err());
432 }
433
434 #[test]
437 fn seeded_determinism() {
438 let run = |seed: u64| -> Vec<f64> {
439 let mut env = SimplifiedMuJoCoEnv::new(Some(seed));
440 let mut rewards = Vec::with_capacity(100);
441 for i in 0..100 {
442 let t = env.step(&random_action(i)).unwrap();
443 rewards.push(t.reward);
444 }
445 rewards
446 };
447
448 let r1 = run(123);
449 let r2 = run(123);
450 assert_eq!(r1, r2, "same seed must produce identical trajectories");
451
452 let r3 = run(456);
453 assert_ne!(
454 r1, r3,
455 "different seeds should produce different trajectories"
456 );
457 }
458
459 #[test]
462 fn step_after_done_errors() {
463 let mut env = SimplifiedMuJoCoEnv::new(Some(42));
464 for _ in 0..1000 {
466 let _ = env.step(&zero_action()).unwrap();
467 }
468 let result = env.step(&zero_action());
469 assert!(result.is_err());
470 }
471
472 #[test]
475 fn vec_env_with_multiple_half_cheetahs() {
476 let n = 4;
477 let envs: Vec<Box<dyn RLEnv>> = (0..n)
478 .map(|i| {
479 let s = derive_seed(42, i);
480 Box::new(SimplifiedMuJoCoEnv::new(Some(s))) as Box<dyn RLEnv>
481 })
482 .collect();
483
484 let mut venv = VecEnv::new(envs).unwrap();
485 assert_eq!(venv.num_envs(), 4);
486
487 let actions: Vec<Action> = (0..n).map(|i| random_action(i as u32)).collect();
489 let batch = venv.step_all(&actions).unwrap();
490
491 assert_eq!(batch.obs.len(), 4);
492 assert_eq!(batch.rewards.len(), 4);
493 assert_eq!(batch.terminated.len(), 4);
494 assert_eq!(batch.truncated.len(), 4);
495
496 for obs in &batch.obs {
497 assert_eq!(obs.len(), 17, "each env obs must be 17-dim");
498 }
499 }
500
501 #[test]
502 fn vec_env_flat_stepping() {
503 let n = 4;
504 let envs: Vec<Box<dyn RLEnv>> = (0..n)
505 .map(|i| {
506 let s = derive_seed(42, i);
507 Box::new(SimplifiedMuJoCoEnv::new(Some(s))) as Box<dyn RLEnv>
508 })
509 .collect();
510
511 let mut venv = VecEnv::new(envs).unwrap();
512 let actions: Vec<Action> = (0..n).map(|_| zero_action()).collect();
513 let batch = venv.step_all_flat(&actions).unwrap();
514
515 assert!(batch.obs.is_empty());
516 assert_eq!(batch.obs_flat.len(), 4 * 17);
517 assert_eq!(batch.obs_dim, 17);
518 }
519
520 #[test]
521 fn vec_env_auto_reset_across_truncation() {
522 let n = 2;
523 let envs: Vec<Box<dyn RLEnv>> = (0..n)
524 .map(|i| {
525 let s = derive_seed(42, i);
526 Box::new(SimplifiedMuJoCoEnv::new(Some(s))) as Box<dyn RLEnv>
527 })
528 .collect();
529
530 let mut venv = VecEnv::new(envs).unwrap();
531 let actions: Vec<Action> = (0..n).map(|_| zero_action()).collect();
532
533 for _ in 0..1100 {
535 let batch = venv.step_all(&actions).unwrap();
536 assert_eq!(batch.obs.len(), 2);
537 }
538 }
539
540 #[test]
541 fn vec_env_determinism() {
542 let run = || {
543 let n = 8;
544 let envs: Vec<Box<dyn RLEnv>> = (0..n)
545 .map(|i| {
546 let s = derive_seed(42, i);
547 Box::new(SimplifiedMuJoCoEnv::new(Some(s))) as Box<dyn RLEnv>
548 })
549 .collect();
550
551 let mut venv = VecEnv::new(envs).unwrap();
552 venv.reset_all(Some(42)).unwrap();
553
554 let actions: Vec<Action> = (0..n).map(|i| random_action(i as u32)).collect();
555 let mut all_rewards = Vec::new();
556 for _ in 0..50 {
557 let batch = venv.step_all(&actions).unwrap();
558 all_rewards.extend(batch.rewards);
559 }
560 all_rewards
561 };
562
563 let r1 = run();
564 let r2 = run();
565 assert_eq!(r1, r2);
566 }
567}