1use std::f64::consts::PI;
2
3use rand::Rng;
4use rand_chacha::ChaCha8Rng;
5
6use crate::env::spaces::{Action, ActionSpace, ObsSpace, Observation};
7use crate::env::{RLEnv, Transition};
8use crate::error::RloxError;
9use crate::seed::rng_from_seed;
10
/// Gravitational acceleration used by the cart-pole dynamics.
const GRAVITY: f64 = 9.8;
/// Mass of the cart.
const MASSCART: f64 = 1.0;
/// Mass of the pole.
const MASSPOLE: f64 = 0.1;
/// Combined cart + pole mass, used to normalise the applied force.
const TOTAL_MASS: f64 = MASSPOLE + MASSCART;
/// Pole length parameter of the dynamics.
const LENGTH: f64 = 0.5;
/// Pole mass times `LENGTH`, a product that recurs in the equations of motion.
const POLEMASS_LENGTH: f64 = LENGTH * MASSPOLE;
/// Magnitude of the horizontal force either action applies to the cart.
const FORCE_MAG: f64 = 10.0;
/// Euler integration time step.
const TAU: f64 = 0.02;
/// Pole-angle limit (12 degrees, expressed in radians); exceeding it terminates the episode.
const THETA_THRESHOLD: f64 = 12.0 * 2.0 * PI / 360.0;
/// Cart-position limit; exceeding it terminates the episode.
const X_THRESHOLD: f64 = 2.4;
/// Step budget per episode; reaching it truncates the episode.
const MAX_STEPS: u32 = 500;

/// Upper observation bounds for `[x, x_dot, theta, theta_dot]`.
///
/// Positions are bounded at twice their termination thresholds, while the two
/// velocity components are effectively unbounded (`f32::MAX`).
const OBS_HIGH: [f32; 4] = [
    (2.0 * X_THRESHOLD) as f32,
    f32::MAX,
    (2.0 * THETA_THRESHOLD) as f32,
    f32::MAX,
];
31
32pub struct CartPole {
34 state: [f64; 4],
36 rng: ChaCha8Rng,
37 steps: u32,
38 action_space: ActionSpace,
39 obs_space: ObsSpace,
40 done: bool,
41}
42
43impl CartPole {
44 pub fn new(seed: Option<u64>) -> Self {
45 let seed = seed.unwrap_or(0);
46 let rng = rng_from_seed(seed);
47 let obs_low: Vec<f32> = OBS_HIGH.iter().map(|h| -h).collect();
48 let obs_high: Vec<f32> = OBS_HIGH.to_vec();
49
50 let mut env = CartPole {
51 state: [0.0; 4],
52 rng,
53 steps: 0,
54 action_space: ActionSpace::Discrete(2),
55 obs_space: ObsSpace::Box {
56 low: obs_low,
57 high: obs_high,
58 shape: vec![4],
59 },
60 done: true,
61 };
62 let _ = env.reset(Some(seed));
64 env
65 }
66
67 fn obs(&self) -> Observation {
68 Observation::Flat(self.state.iter().map(|&v| v as f32).collect())
69 }
70}
71
72impl RLEnv for CartPole {
73 fn step(&mut self, action: &Action) -> Result<Transition, RloxError> {
74 if self.done {
75 return Err(RloxError::EnvError(
76 "Environment is done. Call reset() before stepping.".into(),
77 ));
78 }
79
80 let action_idx = match action {
81 Action::Discrete(a) => *a,
82 _ => {
83 return Err(RloxError::InvalidAction(
84 "CartPole expects a Discrete action".into(),
85 ))
86 }
87 };
88
89 if !self.action_space.contains(action) {
90 return Err(RloxError::InvalidAction(format!(
91 "Action {} is out of range for Discrete(2)",
92 action_idx
93 )));
94 }
95
96 let [x, x_dot, theta, theta_dot] = self.state;
97
98 let force = if action_idx == 1 {
99 FORCE_MAG
100 } else {
101 -FORCE_MAG
102 };
103
104 let cos_theta = theta.cos();
105 let sin_theta = theta.sin();
106
107 let temp = (force + POLEMASS_LENGTH * theta_dot * theta_dot * sin_theta) / TOTAL_MASS;
109 let theta_acc = (GRAVITY * sin_theta - cos_theta * temp)
110 / (LENGTH * (4.0 / 3.0 - MASSPOLE * cos_theta * cos_theta / TOTAL_MASS));
111 let x_acc = temp - POLEMASS_LENGTH * theta_acc * cos_theta / TOTAL_MASS;
112
113 let new_x = x + TAU * x_dot;
115 let new_x_dot = x_dot + TAU * x_acc;
116 let new_theta = theta + TAU * theta_dot;
117 let new_theta_dot = theta_dot + TAU * theta_acc;
118
119 self.state = [new_x, new_x_dot, new_theta, new_theta_dot];
120 self.steps += 1;
121
122 let terminated = new_x < -X_THRESHOLD
123 || new_x > X_THRESHOLD
124 || new_theta < -THETA_THRESHOLD
125 || new_theta > THETA_THRESHOLD;
126
127 let truncated = !terminated && self.steps >= MAX_STEPS;
128
129 self.done = terminated || truncated;
130
131 Ok(Transition {
132 obs: self.obs(),
133 reward: 1.0,
134 terminated,
135 truncated,
136 info: None,
137 })
138 }
139
140 fn reset(&mut self, seed: Option<u64>) -> Result<Observation, RloxError> {
141 if let Some(s) = seed {
142 self.rng = rng_from_seed(s);
143 }
144
145 for s in self.state.iter_mut() {
147 *s = self.rng.random_range(-0.05..0.05);
148 }
149
150 self.steps = 0;
151 self.done = false;
152
153 Ok(self.obs())
154 }
155
156 fn action_space(&self) -> &ActionSpace {
157 &self.action_space
158 }
159
160 fn obs_space(&self) -> &ObsSpace {
161 &self.obs_space
162 }
163
164 fn render(&self) -> Option<String> {
165 Some(format!(
166 "CartPole | step={} | x={:.4} theta={:.4}",
167 self.steps, self.state[0], self.state[2]
168 ))
169 }
170}
171
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cartpole_reset_produces_valid_obs() {
        let env = CartPole::new(Some(42));
        let obs = env.obs();
        assert_eq!(obs.as_slice().len(), 4);
        for &v in obs.as_slice() {
            assert!(v.abs() <= 0.05, "initial state out of range: {}", v);
        }
    }

    #[test]
    fn cartpole_step_returns_reward_one() {
        let mut env = CartPole::new(Some(42));
        let t = env.step(&Action::Discrete(1)).unwrap();
        assert!((t.reward - 1.0).abs() < f64::EPSILON);
        assert!(!t.terminated);
        assert!(!t.truncated);
    }

    #[test]
    fn cartpole_invalid_action() {
        let mut env = CartPole::new(Some(42));
        let before = env.state;

        assert!(env.step(&Action::Discrete(5)).is_err());
        assert!(env.step(&Action::Continuous(vec![0.0])).is_err());

        // Rejected actions must neither advance the simulation nor end the episode.
        assert_eq!(env.state, before);
        assert_eq!(env.steps, 0);
        assert!(env.step(&Action::Discrete(0)).is_ok());
    }

    #[test]
    fn cartpole_step_without_reset_after_done() {
        let mut env = CartPole::new(Some(42));
        loop {
            let t = env.step(&Action::Discrete(1)).unwrap();
            if t.terminated || t.truncated {
                break;
            }
        }
        let result = env.step(&Action::Discrete(0));
        assert!(result.is_err());

        // A reset makes the environment steppable again.
        env.reset(None).unwrap();
        assert!(env.step(&Action::Discrete(0)).is_ok());
    }

    #[test]
    fn cartpole_seeded_determinism() {
        let run = |seed: u64| -> Vec<Vec<f32>> {
            let mut env = CartPole::new(Some(seed));
            let mut observations = vec![env.obs().into_inner()];
            for _ in 0..50 {
                match env.step(&Action::Discrete(1)) {
                    Ok(t) => observations.push(t.obs.into_inner()),
                    Err(_) => break,
                }
            }
            observations
        };

        let run1 = run(123);
        let run2 = run(123);
        assert_eq!(run1, run2);

        let run3 = run(456);
        assert_ne!(run1, run3);
    }

    #[test]
    fn cartpole_truncates_at_500() {
        let mut env = CartPole::new(Some(0));
        // Pin the system at its equilibrium before every step. With zero
        // velocities the Euler update leaves x and theta at exactly 0, so the
        // episode cannot terminate and must end through the step limit.
        for step in 1..=MAX_STEPS {
            env.state = [0.0; 4];
            let t = env.step(&Action::Discrete(step % 2)).unwrap();
            assert!(!t.terminated, "unexpected termination at step {step}");
            assert_eq!(
                t.truncated,
                step == MAX_STEPS,
                "wrong truncation flag at step {step}"
            );
        }
        assert_eq!(env.steps, MAX_STEPS);
        // The truncated episode is over until the next reset.
        assert!(env.step(&Action::Discrete(0)).is_err());
    }

    #[test]
    fn cartpole_numerical_equivalence_seed_42() {
        let mut env = CartPole::new(Some(42));
        let obs = env.obs();
        assert_eq!(obs.as_slice().len(), 4);
        for &v in obs.as_slice() {
            assert!(v.abs() <= 0.05, "initial obs out of expected range: {v}");
        }

        // Hand-integrate one Euler step of the cart-pole equations (action 1
        // pushes right) and compare with the environment's own update.
        let [x, x_dot, theta, theta_dot] = env.state;
        let force = FORCE_MAG;
        let sin_t = theta.sin();
        let cos_t = theta.cos();
        let temp = (force + POLEMASS_LENGTH * theta_dot * theta_dot * sin_t) / TOTAL_MASS;
        let theta_acc = (GRAVITY * sin_t - cos_t * temp)
            / (LENGTH * (4.0 / 3.0 - MASSPOLE * cos_t * cos_t / TOTAL_MASS));
        let x_acc = temp - POLEMASS_LENGTH * theta_acc * cos_t / TOTAL_MASS;
        let expected = [
            x + TAU * x_dot,
            x_dot + TAU * x_acc,
            theta + TAU * theta_dot,
            theta_dot + TAU * theta_acc,
        ];

        let t = env.step(&Action::Discrete(1)).unwrap();
        for (got, want) in env.state.iter().zip(expected.iter()) {
            assert!(
                (got - want).abs() < 1e-12,
                "state mismatch: got {got}, expected {want}"
            );
        }
        for (&got, &want) in t.obs.as_slice().iter().zip(expected.iter()) {
            assert!(
                (got - want as f32).abs() <= 1e-6,
                "obs mismatch: got {got}, expected {want}"
            );
        }
    }

    #[test]
    fn cartpole_many_steps_reward_sum() {
        let mut env = CartPole::new(Some(42));
        let mut total_reward = 0.0;
        let mut steps = 0;
        for _ in 0..100 {
            match env.step(&Action::Discrete(1)) {
                Ok(t) => {
                    total_reward += t.reward;
                    steps += 1;
                    if t.terminated || t.truncated {
                        break;
                    }
                }
                Err(_) => break,
            }
        }
        assert!(steps > 0);
        assert!((total_reward - steps as f64).abs() < f64::EPSILON);
    }

    #[test]
    fn cartpole_terminates_on_out_of_bounds() {
        let mut env = CartPole::new(Some(42));
        let mut terminated = false;
        for _ in 0..500 {
            match env.step(&Action::Discrete(1)) {
                Ok(t) => {
                    if t.terminated {
                        terminated = true;
                        break;
                    }
                }
                Err(_) => break,
            }
        }
        assert!(
            terminated,
            "CartPole should terminate when always pushing right"
        );
    }
}
327
/// Gravitational constant of the pendulum dynamics.
const PENDULUM_GRAVITY: f64 = 10.0;
/// Pendulum mass.
const PENDULUM_MASS: f64 = 1.0;
/// Pendulum length.
const PENDULUM_LENGTH: f64 = 1.0;
/// Integration time step.
const PENDULUM_DT: f64 = 0.05;
/// Angular-velocity limit; the velocity is clamped to `±PENDULUM_MAX_VEL`.
const PENDULUM_MAX_VEL: f64 = 8.0;
/// Torque limit; actions are clamped to `±PENDULUM_MAX_TORQUE`.
const PENDULUM_MAX_TORQUE: f64 = 2.0;
/// Steps per episode before truncation.
const PENDULUM_MAX_STEPS: u32 = 200;
340
/// Wraps an angle (radians) into `[-π, π)` by shifting, reducing modulo `2π`
/// with a non-negative remainder, and shifting back.
///
/// Rounding in `rem_euclid` can, for inputs just below `-π`, yield exactly `π`.
#[inline]
fn angle_normalize(x: f64) -> f64 {
    let full_turn = 2.0 * PI;
    let shifted = x + PI;
    shifted.rem_euclid(full_turn) - PI
}
349
350pub struct Pendulum {
356 theta: f64,
358 vel: f64,
359 rng: ChaCha8Rng,
360 steps: u32,
361 action_space: ActionSpace,
362 obs_space: ObsSpace,
363 done: bool,
364}
365
366impl Pendulum {
367 pub fn new(seed: Option<u64>) -> Self {
368 let seed = seed.unwrap_or(0);
369 let rng = rng_from_seed(seed);
370
371 let mut env = Pendulum {
372 theta: 0.0,
373 vel: 0.0,
374 rng,
375 steps: 0,
376 action_space: ActionSpace::Box {
377 low: vec![-PENDULUM_MAX_TORQUE as f32],
378 high: vec![PENDULUM_MAX_TORQUE as f32],
379 shape: vec![1],
380 },
381 obs_space: ObsSpace::Box {
382 low: vec![-1.0, -1.0, -PENDULUM_MAX_VEL as f32],
383 high: vec![1.0, 1.0, PENDULUM_MAX_VEL as f32],
384 shape: vec![3],
385 },
386 done: true,
387 };
388 let _ = env.reset(Some(seed));
389 env
390 }
391
392 #[inline]
393 fn obs(&self) -> Observation {
394 Observation::Flat(vec![
395 self.theta.cos() as f32,
396 self.theta.sin() as f32,
397 self.vel as f32,
398 ])
399 }
400}
401
402impl RLEnv for Pendulum {
403 fn step(&mut self, action: &Action) -> Result<Transition, RloxError> {
404 if self.done {
405 return Err(RloxError::EnvError(
406 "Environment is done. Call reset() before stepping.".into(),
407 ));
408 }
409
410 let torque = match action {
411 Action::Continuous(vals) if vals.len() == 1 => {
412 (vals[0] as f64).clamp(-PENDULUM_MAX_TORQUE, PENDULUM_MAX_TORQUE)
413 }
414 _ => {
415 return Err(RloxError::InvalidAction(
416 "Pendulum expects a Continuous action with 1 element".into(),
417 ));
418 }
419 };
420
421 let theta = self.theta;
422 let vel = self.vel;
423
424 let norm_theta = angle_normalize(theta);
426 let reward = -(norm_theta * norm_theta + 0.1 * vel * vel + 0.001 * torque * torque);
427
428 let g = PENDULUM_GRAVITY;
430 let m = PENDULUM_MASS;
431 let l = PENDULUM_LENGTH;
432 let dt = PENDULUM_DT;
433
434 let new_vel = vel + (3.0 * g / (2.0 * l) * theta.sin() + 3.0 / (m * l * l) * torque) * dt;
435 let new_vel = new_vel.clamp(-PENDULUM_MAX_VEL, PENDULUM_MAX_VEL);
436 let new_theta = theta + new_vel * dt;
437
438 self.theta = new_theta;
439 self.vel = new_vel;
440 self.steps += 1;
441
442 let truncated = self.steps >= PENDULUM_MAX_STEPS;
444 self.done = truncated;
445
446 Ok(Transition {
447 obs: self.obs(),
448 reward,
449 terminated: false,
450 truncated,
451 info: None,
452 })
453 }
454
455 fn reset(&mut self, seed: Option<u64>) -> Result<Observation, RloxError> {
456 if let Some(s) = seed {
457 self.rng = rng_from_seed(s);
458 }
459
460 self.theta = self.rng.random_range(-PI..PI);
462 self.vel = self.rng.random_range(-1.0..1.0);
463 self.steps = 0;
464 self.done = false;
465
466 Ok(self.obs())
467 }
468
469 fn action_space(&self) -> &ActionSpace {
470 &self.action_space
471 }
472
473 fn obs_space(&self) -> &ObsSpace {
474 &self.obs_space
475 }
476
477 fn render(&self) -> Option<String> {
478 Some(format!(
479 "Pendulum | step={} | theta={:.4} vel={:.4}",
480 self.steps, self.theta, self.vel
481 ))
482 }
483}
484
#[cfg(test)]
mod pendulum_tests {
    use super::*;

    /// Independent re-derivation of one Pendulum update from `(theta0, vel0)`
    /// under an already-clamped `torque`. Returns `(theta, vel)` after the step.
    fn reference_step(theta0: f64, vel0: f64, torque: f64) -> (f64, f64) {
        let (g, m, l, dt) = (
            PENDULUM_GRAVITY,
            PENDULUM_MASS,
            PENDULUM_LENGTH,
            PENDULUM_DT,
        );
        let accel = 3.0 * g / (2.0 * l) * theta0.sin() + 3.0 / (m * l * l) * torque;
        let vel = (vel0 + accel * dt).clamp(-PENDULUM_MAX_VEL, PENDULUM_MAX_VEL);
        (theta0 + vel * dt, vel)
    }

    /// Reward expected for applying `torque` from the pre-step state `(theta0, vel0)`.
    fn reference_reward(theta0: f64, vel0: f64, torque: f64) -> f64 {
        let wrapped = angle_normalize(theta0);
        -(wrapped * wrapped + 0.1 * vel0 * vel0 + 0.001 * torque * torque)
    }

    #[test]
    fn pendulum_reset_produces_valid_obs() {
        let obs = Pendulum::new(Some(42)).obs();
        let s = obs.as_slice();
        assert_eq!(s.len(), 3);
        assert!(
            (-1.0..=1.0).contains(&s[0]),
            "cos(theta) out of range: {}",
            s[0]
        );
        assert!(
            (-1.0..=1.0).contains(&s[1]),
            "sin(theta) out of range: {}",
            s[1]
        );
        assert!(s[2].abs() <= 8.0, "vel out of range: {}", s[2]);
    }

    #[test]
    fn pendulum_step_known_state() {
        let mut env = Pendulum::new(Some(42));
        env.reset(Some(42)).unwrap();
        let (theta0, vel0) = (env.theta, env.vel);

        let t = env.step(&Action::Continuous(vec![0.0])).unwrap();

        let (expected_theta, expected_vel) = reference_step(theta0, vel0, 0.0);
        assert!(
            (env.theta - expected_theta).abs() < 1e-10,
            "theta mismatch: got {}, expected {}",
            env.theta,
            expected_theta
        );
        assert!(
            (env.vel - expected_vel).abs() < 1e-10,
            "vel mismatch: got {}, expected {}",
            env.vel,
            expected_vel
        );

        let expected_reward = reference_reward(theta0, vel0, 0.0);
        assert!(
            (t.reward - expected_reward).abs() < 1e-10,
            "reward mismatch: got {}, expected {}",
            t.reward,
            expected_reward
        );

        assert!(!t.terminated);
        assert!(!t.truncated);
    }

    #[test]
    fn pendulum_step_with_torque() {
        let mut env = Pendulum::new(Some(7));
        env.reset(Some(7)).unwrap();
        let (theta0, vel0) = (env.theta, env.vel);
        let torque = 1.5_f32;

        let t = env.step(&Action::Continuous(vec![torque])).unwrap();

        let applied = torque as f64;
        let (expected_theta, expected_vel) = reference_step(theta0, vel0, applied);
        assert!(
            (env.theta - expected_theta).abs() < 1e-10,
            "theta: got {}, expected {}",
            env.theta,
            expected_theta
        );
        assert!(
            (env.vel - expected_vel).abs() < 1e-10,
            "vel: got {}, expected {}",
            env.vel,
            expected_vel
        );

        let expected_reward = reference_reward(theta0, vel0, applied);
        assert!(
            (t.reward - expected_reward).abs() < 1e-10,
            "reward: got {}, expected {}",
            t.reward,
            expected_reward
        );
    }

    #[test]
    fn pendulum_torque_clamped() {
        let mut env = Pendulum::new(Some(42));
        env.reset(Some(42)).unwrap();
        let (theta0, vel0) = (env.theta, env.vel);

        env.step(&Action::Continuous(vec![10.0])).unwrap();

        // A request of 10.0 must behave exactly like the maximum torque.
        let (_, expected_vel) = reference_step(theta0, vel0, PENDULUM_MAX_TORQUE);
        assert!(
            (env.vel - expected_vel).abs() < 1e-10,
            "torque clamping failed: vel={}, expected={}",
            env.vel,
            expected_vel
        );
    }

    #[test]
    fn pendulum_truncates_at_200() {
        let mut env = Pendulum::new(Some(42));
        env.reset(Some(42)).unwrap();

        for step in 1..=200 {
            let t = env.step(&Action::Continuous(vec![0.0])).unwrap();
            if step == 200 {
                assert!(t.truncated, "should truncate at step 200");
                assert!(!t.terminated);
            } else {
                assert!(!t.truncated, "should not truncate at step {}", step);
            }
        }

        // Stepping past the limit without a reset is an error.
        let result = env.step(&Action::Continuous(vec![0.0]));
        assert!(result.is_err());
    }

    #[test]
    fn pendulum_never_terminates() {
        let mut env = Pendulum::new(Some(42));
        env.reset(Some(42)).unwrap();

        for _ in 0..200 {
            let t = env.step(&Action::Continuous(vec![0.0])).unwrap();
            assert!(!t.terminated);
        }
    }

    #[test]
    fn pendulum_observation_bounds() {
        let mut env = Pendulum::new(Some(42));
        env.reset(Some(42)).unwrap();
        let vel_limit = PENDULUM_MAX_VEL as f32 + 1e-6;

        for _ in 0..200 {
            let t = env.step(&Action::Continuous(vec![2.0])).unwrap();
            let s = t.obs.as_slice();
            assert!((-1.0..=1.0).contains(&s[0]), "cos out of [-1,1]: {}", s[0]);
            assert!((-1.0..=1.0).contains(&s[1]), "sin out of [-1,1]: {}", s[1]);
            assert!(s[2].abs() <= vel_limit, "vel out of [-8,8]: {}", s[2]);
            if t.truncated {
                break;
            }
        }
    }

    #[test]
    fn pendulum_seeded_determinism() {
        let run = |seed: u64| -> Vec<f64> {
            let mut env = Pendulum::new(Some(seed));
            (0..100)
                .map(|_| env.step(&Action::Continuous(vec![1.0])).unwrap().reward)
                .collect()
        };

        let r1 = run(123);
        let r2 = run(123);
        assert_eq!(r1, r2);

        let r3 = run(456);
        assert_ne!(r1, r3);
    }

    #[test]
    fn pendulum_invalid_action_discrete() {
        let mut env = Pendulum::new(Some(42));
        env.reset(Some(42)).unwrap();
        let result = env.step(&Action::Discrete(0));
        assert!(result.is_err());
    }

    #[test]
    fn pendulum_invalid_action_wrong_dim() {
        let mut env = Pendulum::new(Some(42));
        env.reset(Some(42)).unwrap();
        let result = env.step(&Action::Continuous(vec![1.0, 2.0]));
        assert!(result.is_err());
    }

    #[test]
    fn angle_normalize_basic() {
        // Multiples of π land on the lower end of [-π, π); even multiples on 0.
        assert!((angle_normalize(0.0)).abs() < 1e-10);
        assert!((angle_normalize(PI) - (-PI)).abs() < 1e-10);
        assert!((angle_normalize(-PI) - (-PI)).abs() < 1e-10);
        assert!((angle_normalize(2.0 * PI)).abs() < 1e-10);
        assert!((angle_normalize(3.0 * PI) - (-PI)).abs() < 1e-10);
    }
}