// rlox_core/env/mujoco.rs

//! MuJoCo environment bindings.
//!
//! Gated behind the `mujoco` feature flag. When the flag is not enabled,
//! this module provides [`SimplifiedMuJoCoEnv`] -- a placeholder that
//! implements `RLEnv` with basic linear dynamics, allowing the full
//! architecture (VecEnv, PyVecEnv factory, training loops) to be validated
//! without the MuJoCo C library installed.
#[cfg(feature = "mujoco")]
mod inner {
    // TODO: Full MuJoCo implementation using `mujoco-rs` crate.
    //
    // When the `mujoco` feature is enabled, this module will provide:
    //   - `MuJoCoEnv`: a generic wrapper around any MuJoCo XML model
    //   - Pre-built constructors for standard envs (HalfCheetah, Ant, Hopper, etc.)
    //   - Proper contact physics, joint limits, and reward shaping
    //
    // For now, re-export the simplified env so downstream code compiles
    // regardless of feature state. NOTE: until the real backend lands, this
    // branch is intentionally identical to the `not(feature = "mujoco")` one.
    pub use super::simplified::SimplifiedMuJoCoEnv;
}
22
#[cfg(not(feature = "mujoco"))]
mod inner {
    // Without the `mujoco` feature the simplified placeholder env is the only
    // backend; re-export it under the same path as the feature-gated branch
    // so downstream code is agnostic to the feature state.
    pub use super::simplified::SimplifiedMuJoCoEnv;
}
27
28pub use inner::*;
29
30mod simplified {
31    use std::collections::HashMap;
32
33    use rand::Rng;
34    use rand_chacha::ChaCha8Rng;
35
36    use crate::env::spaces::{Action, ActionSpace, ObsSpace, Observation};
37    use crate::env::{RLEnv, Transition};
38    use crate::error::RloxError;
39    use crate::seed::rng_from_seed;
40
    // ---------------------------------------------------------------------------
    // HalfCheetah-v4 dimensions (matching Gymnasium / MuJoCo)
    // ---------------------------------------------------------------------------

    /// Observation dimensionality for HalfCheetah-v4.
    ///
    /// 17 = 8 joint positions (excluding rootx) + 9 joint velocities.
    const OBS_DIM: usize = 17;

    /// Action dimensionality for HalfCheetah-v4.
    ///
    /// 6 torque actuators on the 6 non-root joints.
    const ACT_DIM: usize = 6;

    /// Integration timestep (seconds) used by the placeholder Euler updates.
    const DT: f64 = 0.05;

    /// Maximum episode length before truncation (matches Gymnasium's default
    /// `max_episode_steps` for HalfCheetah-v4).
    const MAX_STEPS: u32 = 1000;

    /// Control cost coefficient (matching Gymnasium default).
    const CTRL_COST_WEIGHT: f64 = 0.1;
63
    /// Simplified MuJoCo-like environment with HalfCheetah-v4 dimensions.
    ///
    /// Uses basic linear dynamics as a placeholder:
    ///   `next_state = state + dt * f(action)`
    ///
    /// This is **not** physically accurate -- it exists solely to validate
    /// the trait implementation, VecEnv integration, and Python bindings
    /// before the real MuJoCo backend is wired up.
    pub struct SimplifiedMuJoCoEnv {
        /// Internal state vector (17-dim). Indices 0..8 hold joint positions,
        /// 8..17 joint velocities.
        state: Vec<f64>,
        /// Seeded RNG, used only for reset-time state initialization.
        rng: ChaCha8Rng,
        /// Steps taken since the last reset; drives truncation at `MAX_STEPS`.
        steps: u32,
        /// Cached action space returned by `action_space()`.
        action_space: ActionSpace,
        /// Cached observation space returned by `obs_space()`.
        obs_space: ObsSpace,
        /// True until the first reset and again after truncation; stepping
        /// while set is an error.
        done: bool,
        /// Previous x-position for velocity reward computation.
        /// NOTE(review): currently written (in `new`/`reset`) but never read
        /// anywhere in this module -- placeholder for a future
        /// displacement-based velocity estimate; confirm before relying on it.
        prev_x_pos: f64,
    }
83
84    impl SimplifiedMuJoCoEnv {
85        /// Create a new simplified HalfCheetah-v4 environment.
86        pub fn new(seed: Option<u64>) -> Self {
87            let seed_val = seed.unwrap_or(0);
88            let rng = rng_from_seed(seed_val);
89
90            let action_low = vec![-1.0_f32; ACT_DIM];
91            let action_high = vec![1.0_f32; ACT_DIM];
92
93            // Observation bounds: generous range to avoid clipping during rollouts.
94            let obs_low = vec![-f32::INFINITY; OBS_DIM];
95            let obs_high = vec![f32::INFINITY; OBS_DIM];
96
97            let mut env = SimplifiedMuJoCoEnv {
98                state: vec![0.0; OBS_DIM],
99                rng,
100                steps: 0,
101                action_space: ActionSpace::Box {
102                    low: action_low,
103                    high: action_high,
104                    shape: vec![ACT_DIM],
105                },
106                obs_space: ObsSpace::Box {
107                    low: obs_low,
108                    high: obs_high,
109                    shape: vec![OBS_DIM],
110                },
111                done: true,
112                prev_x_pos: 0.0,
113            };
114            let _ = env.reset(Some(seed_val));
115            env
116        }
117
118        /// Build observation from current state.
119        fn obs(&self) -> Observation {
120            Observation::Flat(self.state.iter().map(|&v| v as f32).collect())
121        }
122
123        /// Forward velocity (used as the primary reward signal).
124        ///
125        /// In real HalfCheetah this comes from the MuJoCo simulator's
126        /// `(x_after - x_before) / dt`. Here we approximate it from the
127        /// velocity component of the state (index 8 = rootx velocity in the
128        /// simplified model).
129        fn forward_velocity(&self) -> f64 {
130            // state[8] represents the x-velocity in our simplified layout.
131            // Indices 0..8 are joint positions, 8..17 are joint velocities.
132            self.state[8]
133        }
134    }
135
136    impl RLEnv for SimplifiedMuJoCoEnv {
137        fn step(&mut self, action: &Action) -> Result<Transition, RloxError> {
138            if self.done {
139                return Err(RloxError::EnvError(
140                    "Environment is done. Call reset() before stepping.".into(),
141                ));
142            }
143
144            let torques = match action {
145                Action::Continuous(vals) if vals.len() == ACT_DIM => vals,
146                _ => {
147                    return Err(RloxError::InvalidAction(format!(
148                        "HalfCheetah expects a Continuous action with {} elements",
149                        ACT_DIM
150                    )));
151                }
152            };
153
154            // --- Simplified dynamics ---
155            // Position updates: affected by current velocities
156            // state[0..8] = joint positions, state[8..17] = joint velocities
157            //
158            // Velocity updates: affected by actions (torques) on joints 0..6
159            // This is a gross simplification -- real MuJoCo solves the full
160            // equations of motion with contacts, inertia, Coriolis forces, etc.
161
162            // Update velocities from torques (joints 0..6 map to actions 0..6)
163            for (i, &t) in torques.iter().enumerate().take(ACT_DIM) {
164                let torque = (t as f64).clamp(-1.0, 1.0);
165                self.state[8 + i] += DT * torque;
166            }
167
168            // Update positions from velocities
169            for i in 0..8 {
170                let vel_idx = 8 + i.min(OBS_DIM - 9);
171                self.state[i] += DT * self.state[vel_idx];
172            }
173
174            self.steps += 1;
175
176            // --- Reward ---
177            // HalfCheetah-v4 reward = forward_velocity - ctrl_cost
178            let forward_vel = self.forward_velocity();
179            let ctrl_cost: f64 = CTRL_COST_WEIGHT
180                * torques
181                    .iter()
182                    .map(|&t| (t as f64) * (t as f64))
183                    .sum::<f64>();
184            let reward = forward_vel - ctrl_cost;
185
186            // HalfCheetah never terminates early, only truncates at max_steps.
187            let truncated = self.steps >= MAX_STEPS;
188            self.done = truncated;
189
190            Ok(Transition {
191                obs: self.obs(),
192                reward,
193                terminated: false,
194                truncated,
195                info: Some({
196                    let mut info = HashMap::new();
197                    info.insert("x_velocity".to_string(), forward_vel);
198                    info.insert("reward_forward".to_string(), forward_vel);
199                    info.insert("reward_ctrl".to_string(), -ctrl_cost);
200                    info
201                }),
202            })
203        }
204
205        fn reset(&mut self, seed: Option<u64>) -> Result<Observation, RloxError> {
206            if let Some(s) = seed {
207                self.rng = rng_from_seed(s);
208            }
209
210            // Initialize state with small random values (matching Gymnasium's
211            // `init_qpos + noise` and `init_qvel + noise` strategy).
212            for s in self.state.iter_mut() {
213                *s = self.rng.random_range(-0.1..0.1);
214            }
215
216            self.steps = 0;
217            self.done = false;
218            self.prev_x_pos = 0.0;
219
220            Ok(self.obs())
221        }
222
223        fn action_space(&self) -> &ActionSpace {
224            &self.action_space
225        }
226
227        fn obs_space(&self) -> &ObsSpace {
228            &self.obs_space
229        }
230
231        fn render(&self) -> Option<String> {
232            Some(format!(
233                "SimplifiedHalfCheetah | step={} | x_vel={:.4}",
234                self.steps,
235                self.forward_velocity()
236            ))
237        }
238    }
239}
240
241// ---------------------------------------------------------------------------
242// Tests
243// ---------------------------------------------------------------------------
244
245#[cfg(test)]
246mod tests {
247    use super::SimplifiedMuJoCoEnv;
248    use crate::env::parallel::VecEnv;
249    use crate::env::spaces::{Action, ActionSpace, ObsSpace};
250    use crate::env::RLEnv;
251    use crate::seed::derive_seed;
252
253    fn zero_action() -> Action {
254        Action::Continuous(vec![0.0; 6])
255    }
256
257    fn random_action(seed: u32) -> Action {
258        // Deterministic pseudo-random action for testing
259        let vals: Vec<f32> = (0..6)
260            .map(|i| ((seed as f32 + i as f32) * 0.31415).sin() * 0.8)
261            .collect();
262        Action::Continuous(vals)
263    }
264
265    // ----- Observation shape -----
266
267    #[test]
268    fn obs_dim_is_17() {
269        let mut env = SimplifiedMuJoCoEnv::new(Some(42));
270        let obs = env.reset(Some(42)).unwrap();
271        assert_eq!(obs.as_slice().len(), 17, "HalfCheetah obs must be 17-dim");
272    }
273
274    #[test]
275    fn obs_space_shape_is_17() {
276        let env = SimplifiedMuJoCoEnv::new(Some(42));
277        match env.obs_space() {
278            ObsSpace::Box { shape, .. } => {
279                assert_eq!(shape, &[17]);
280            }
281            other => panic!("Expected Box obs space, got {:?}", other),
282        }
283    }
284
285    // ----- Action shape -----
286
287    #[test]
288    fn action_dim_is_6() {
289        let env = SimplifiedMuJoCoEnv::new(Some(42));
290        match env.action_space() {
291            ActionSpace::Box { low, high, shape } => {
292                assert_eq!(shape, &[6]);
293                assert_eq!(low.len(), 6);
294                assert_eq!(high.len(), 6);
295                for (&lo, &hi) in low.iter().zip(high.iter()) {
296                    assert!((lo - (-1.0)).abs() < f32::EPSILON);
297                    assert!((hi - 1.0).abs() < f32::EPSILON);
298                }
299            }
300            other => panic!("Expected Box action space, got {:?}", other),
301        }
302    }
303
304    // ----- Reset -----
305
306    #[test]
307    fn reset_returns_valid_obs() {
308        let mut env = SimplifiedMuJoCoEnv::new(Some(42));
309        let obs = env.reset(Some(99)).unwrap();
310        assert_eq!(obs.as_slice().len(), 17);
311        // Initial state should be small random values in [-0.1, 0.1]
312        for &v in obs.as_slice() {
313            assert!(
314                v.abs() <= 0.1 + f32::EPSILON,
315                "initial obs element out of range: {}",
316                v
317            );
318        }
319    }
320
321    #[test]
322    fn reset_clears_step_counter() {
323        let mut env = SimplifiedMuJoCoEnv::new(Some(42));
324        // Step a few times
325        for _ in 0..10 {
326            env.step(&zero_action()).unwrap();
327        }
328        // Reset and step again -- should not truncate immediately
329        env.reset(Some(42)).unwrap();
330        let t = env.step(&zero_action()).unwrap();
331        assert!(!t.truncated);
332    }
333
334    // ----- Step produces valid output -----
335
336    #[test]
337    fn step_returns_17_dim_obs() {
338        let mut env = SimplifiedMuJoCoEnv::new(Some(42));
339        let t = env.step(&zero_action()).unwrap();
340        assert_eq!(t.obs.as_slice().len(), 17);
341    }
342
343    #[test]
344    fn step_never_terminates() {
345        let mut env = SimplifiedMuJoCoEnv::new(Some(42));
346        for _ in 0..1000 {
347            let t = env.step(&random_action(0)).unwrap();
348            assert!(!t.terminated, "HalfCheetah should never terminate early");
349            if t.truncated {
350                break;
351            }
352        }
353    }
354
355    // ----- Truncation at 1000 steps -----
356
357    #[test]
358    fn truncates_at_1000_steps() {
359        let mut env = SimplifiedMuJoCoEnv::new(Some(42));
360        for i in 0..1000 {
361            let t = env.step(&zero_action()).unwrap();
362            if i < 999 {
363                assert!(!t.truncated, "should not truncate at step {}", i + 1);
364            } else {
365                assert!(t.truncated, "should truncate at step 1000");
366            }
367        }
368        // Stepping after truncation should error
369        let result = env.step(&zero_action());
370        assert!(result.is_err());
371    }
372
373    // ----- Reward structure -----
374
375    #[test]
376    fn zero_action_gives_zero_ctrl_cost() {
377        let mut env = SimplifiedMuJoCoEnv::new(Some(42));
378        let t = env.step(&zero_action()).unwrap();
379        // With zero action, ctrl_cost = 0, so reward = forward_velocity
380        let x_vel = t
381            .info
382            .as_ref()
383            .and_then(|m| m.get("x_velocity"))
384            .copied()
385            .unwrap_or(0.0);
386        let ctrl = t
387            .info
388            .as_ref()
389            .and_then(|m| m.get("reward_ctrl"))
390            .copied()
391            .unwrap_or(0.0);
392        assert!(
393            ctrl.abs() < 1e-10,
394            "ctrl cost should be ~0 for zero action, got {}",
395            ctrl
396        );
397        assert!(
398            (t.reward - x_vel).abs() < 1e-10,
399            "reward should equal x_velocity when ctrl_cost=0"
400        );
401    }
402
403    #[test]
404    fn nonzero_action_incurs_ctrl_cost() {
405        let mut env = SimplifiedMuJoCoEnv::new(Some(42));
406        let action = Action::Continuous(vec![0.5; 6]);
407        let t = env.step(&action).unwrap();
408        let ctrl = t
409            .info
410            .as_ref()
411            .and_then(|m| m.get("reward_ctrl"))
412            .copied()
413            .unwrap_or(0.0);
414        // ctrl_cost = 0.1 * sum(0.5^2 * 6) = 0.1 * 1.5 = 0.15
415        assert!(ctrl < 0.0, "ctrl reward should be negative, got {}", ctrl);
416    }
417
418    // ----- Invalid actions -----
419
420    #[test]
421    fn discrete_action_rejected() {
422        let mut env = SimplifiedMuJoCoEnv::new(Some(42));
423        let result = env.step(&Action::Discrete(0));
424        assert!(result.is_err());
425    }
426
427    #[test]
428    fn wrong_dim_action_rejected() {
429        let mut env = SimplifiedMuJoCoEnv::new(Some(42));
430        let result = env.step(&Action::Continuous(vec![0.0; 3]));
431        assert!(result.is_err());
432    }
433
434    // ----- Seeded determinism -----
435
436    #[test]
437    fn seeded_determinism() {
438        let run = |seed: u64| -> Vec<f64> {
439            let mut env = SimplifiedMuJoCoEnv::new(Some(seed));
440            let mut rewards = Vec::with_capacity(100);
441            for i in 0..100 {
442                let t = env.step(&random_action(i)).unwrap();
443                rewards.push(t.reward);
444            }
445            rewards
446        };
447
448        let r1 = run(123);
449        let r2 = run(123);
450        assert_eq!(r1, r2, "same seed must produce identical trajectories");
451
452        let r3 = run(456);
453        assert_ne!(
454            r1, r3,
455            "different seeds should produce different trajectories"
456        );
457    }
458
459    // ----- Step after done errors -----
460
461    #[test]
462    fn step_after_done_errors() {
463        let mut env = SimplifiedMuJoCoEnv::new(Some(42));
464        // Run to truncation
465        for _ in 0..1000 {
466            let _ = env.step(&zero_action()).unwrap();
467        }
468        let result = env.step(&zero_action());
469        assert!(result.is_err());
470    }
471
472    // ----- VecEnv integration -----
473
    #[test]
    fn vec_env_with_multiple_half_cheetahs() {
        let n = 4;
        // Each env gets a distinct seed deterministically derived from a base.
        let envs: Vec<Box<dyn RLEnv>> = (0..n)
            .map(|i| {
                let s = derive_seed(42, i);
                Box::new(SimplifiedMuJoCoEnv::new(Some(s))) as Box<dyn RLEnv>
            })
            .collect();

        let mut venv = VecEnv::new(envs).unwrap();
        assert_eq!(venv.num_envs(), 4);

        // Step all
        let actions: Vec<Action> = (0..n).map(|i| random_action(i as u32)).collect();
        let batch = venv.step_all(&actions).unwrap();

        // Every batch field must carry one entry per env.
        assert_eq!(batch.obs.len(), 4);
        assert_eq!(batch.rewards.len(), 4);
        assert_eq!(batch.terminated.len(), 4);
        assert_eq!(batch.truncated.len(), 4);

        for obs in &batch.obs {
            assert_eq!(obs.len(), 17, "each env obs must be 17-dim");
        }
    }
500
    #[test]
    fn vec_env_flat_stepping() {
        let n = 4;
        let envs: Vec<Box<dyn RLEnv>> = (0..n)
            .map(|i| {
                let s = derive_seed(42, i);
                Box::new(SimplifiedMuJoCoEnv::new(Some(s))) as Box<dyn RLEnv>
            })
            .collect();

        let mut venv = VecEnv::new(envs).unwrap();
        let actions: Vec<Action> = (0..n).map(|_| zero_action()).collect();
        let batch = venv.step_all_flat(&actions).unwrap();

        // The flat variant packs observations into `obs_flat` (n * obs_dim)
        // and leaves the per-env `obs` list empty.
        assert!(batch.obs.is_empty());
        assert_eq!(batch.obs_flat.len(), 4 * 17);
        assert_eq!(batch.obs_dim, 17);
    }
519
    #[test]
    fn vec_env_auto_reset_across_truncation() {
        let n = 2;
        let envs: Vec<Box<dyn RLEnv>> = (0..n)
            .map(|i| {
                let s = derive_seed(42, i);
                Box::new(SimplifiedMuJoCoEnv::new(Some(s))) as Box<dyn RLEnv>
            })
            .collect();

        let mut venv = VecEnv::new(envs).unwrap();
        let actions: Vec<Action> = (0..n).map(|_| zero_action()).collect();

        // Step past truncation (episodes truncate at 1000 steps) -- VecEnv
        // auto-resets so this should not error, unlike stepping a bare env.
        for _ in 0..1100 {
            let batch = venv.step_all(&actions).unwrap();
            assert_eq!(batch.obs.len(), 2);
        }
    }
539
    #[test]
    fn vec_env_determinism() {
        // One full vectorized rollout: 8 seeded envs, 50 batched steps,
        // concatenating every reward seen.
        let run = || {
            let n = 8;
            let envs: Vec<Box<dyn RLEnv>> = (0..n)
                .map(|i| {
                    let s = derive_seed(42, i);
                    Box::new(SimplifiedMuJoCoEnv::new(Some(s))) as Box<dyn RLEnv>
                })
                .collect();

            let mut venv = VecEnv::new(envs).unwrap();
            venv.reset_all(Some(42)).unwrap();

            let actions: Vec<Action> = (0..n).map(|i| random_action(i as u32)).collect();
            let mut all_rewards = Vec::new();
            for _ in 0..50 {
                let batch = venv.step_all(&actions).unwrap();
                all_rewards.extend(batch.rewards);
            }
            all_rewards
        };

        // Two identical runs must produce identical reward sequences.
        let r1 = run();
        let r2 = run();
        assert_eq!(r1, r2);
    }
567}