rlox_core/env/
spaces.rs

1use serde::{Deserialize, Serialize};
2
3/// Describes the action space of an environment.
4#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
5pub enum ActionSpace {
6    /// A single discrete action in `0..n`.
7    Discrete(usize),
8    /// A continuous box space with per-dimension bounds.
9    Box {
10        low: Vec<f32>,
11        high: Vec<f32>,
12        shape: Vec<usize>,
13    },
14    /// Multiple discrete sub-spaces, each with its own cardinality.
15    MultiDiscrete(Vec<usize>),
16}
17
18/// Describes the observation space of an environment.
19#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
20pub enum ObsSpace {
21    Discrete(usize),
22    Box {
23        low: Vec<f32>,
24        high: Vec<f32>,
25        shape: Vec<usize>,
26    },
27    MultiDiscrete(Vec<usize>),
28    /// Dict observation space: ordered `(key, dim)` pairs.
29    Dict(Vec<(String, usize)>),
30}
31
/// A single action value handed to an environment.
#[derive(Debug, Clone, PartialEq)]
pub enum Action {
    /// Index of the chosen discrete action.
    Discrete(u32),
    /// One `f32` per dimension of a continuous action.
    Continuous(Vec<f32>),
}
38
/// A single observation value produced by an environment.
///
/// `Flat` is the common case: one dense vector of `f32` values.
/// `Dict` carries multi-modal observations (e.g. image + proprioception)
/// as named sub-vectors whose order is preserved.
#[derive(Debug, Clone, PartialEq)]
pub enum Observation {
    /// One dense float vector.
    Flat(Vec<f32>),
    /// `(name, values)` pairs kept in a fixed order.
    Dict(Vec<(String, Vec<f32>)>),
}
51
52impl Observation {
53    /// Convenience constructor matching the old tuple-struct API.
54    ///
55    /// `Observation::flat(vec)` is equivalent to the former `Observation(vec)`.
56    pub fn flat(data: Vec<f32>) -> Self {
57        Observation::Flat(data)
58    }
59
60    /// Try to view as a flat f32 slice.
61    ///
62    /// Returns `Some(&[f32])` for the `Flat` variant, `None` for `Dict`.
63    /// Prefer this over [`as_slice`](Self::as_slice) when handling observations
64    /// that may be either variant.
65    pub fn try_as_slice(&self) -> Option<&[f32]> {
66        match self {
67            Observation::Flat(v) => Some(v),
68            Observation::Dict(_) => None,
69        }
70    }
71
72    /// View as a flat f32 slice.
73    ///
74    /// For `Flat`, returns the inner data directly.
75    ///
76    /// # Panics
77    ///
78    /// Panics if the observation is the `Dict` variant. Use
79    /// [`try_as_slice`](Self::try_as_slice) for a fallible alternative, or
80    /// [`flatten`](Self::flatten) if you need a contiguous copy.
81    pub fn as_slice(&self) -> &[f32] {
82        match self {
83            Observation::Flat(v) => v,
84            Observation::Dict(_) => {
85                panic!("Observation::as_slice() called on Dict variant; use try_as_slice() or flatten() instead")
86            }
87        }
88    }
89
90    /// Consume and return the inner Vec for the `Flat` variant.
91    ///
92    /// For `Dict`, returns a flattened (concatenated) copy.
93    pub fn into_inner(self) -> Vec<f32> {
94        match self {
95            Observation::Flat(v) => v,
96            Observation::Dict(pairs) => {
97                let total = pairs.iter().map(|(_, v)| v.len()).sum();
98                let mut out = Vec::with_capacity(total);
99                for (_, v) in pairs {
100                    out.extend(v);
101                }
102                out
103            }
104        }
105    }
106
107    /// Total number of f32 elements across all keys.
108    pub fn total_dim(&self) -> usize {
109        match self {
110            Observation::Flat(v) => v.len(),
111            Observation::Dict(pairs) => pairs.iter().map(|(_, v)| v.len()).sum(),
112        }
113    }
114
115    /// Flatten to a single `Vec<f32>` (concatenate all values in key order).
116    ///
117    /// For `Flat`, returns a clone of the inner data.
118    /// For `Dict`, concatenates all sub-vectors.
119    pub fn flatten(&self) -> Vec<f32> {
120        match self {
121            Observation::Flat(v) => v.clone(),
122            Observation::Dict(pairs) => {
123                let total = pairs.iter().map(|(_, v)| v.len()).sum();
124                let mut out = Vec::with_capacity(total);
125                for (_, v) in pairs {
126                    out.extend_from_slice(v);
127                }
128                out
129            }
130        }
131    }
132
133    /// Get a named sub-observation (for the `Dict` variant).
134    ///
135    /// Returns `None` for the `Flat` variant or if the key is not found.
136    pub fn get(&self, key: &str) -> Option<&[f32]> {
137        match self {
138            Observation::Flat(_) => None,
139            Observation::Dict(pairs) => pairs
140                .iter()
141                .find(|(k, _)| k == key)
142                .map(|(_, v)| v.as_slice()),
143        }
144    }
145}
146
147impl ActionSpace {
148    /// Check whether an action is valid for this space.
149    pub fn contains(&self, action: &Action) -> bool {
150        match (self, action) {
151            (ActionSpace::Discrete(n), Action::Discrete(a)) => (*a as usize) < *n,
152            (ActionSpace::Box { low, high, .. }, Action::Continuous(vals)) => {
153                vals.len() == low.len()
154                    && vals
155                        .iter()
156                        .zip(low.iter().zip(high.iter()))
157                        .all(|(v, (lo, hi))| *v >= *lo && *v <= *hi)
158            }
159            _ => false,
160        }
161    }
162}
163
164impl ObsSpace {
165    /// Check whether an observation is valid for this space.
166    pub fn contains(&self, obs: &Observation) -> bool {
167        match self {
168            ObsSpace::Discrete(n) => {
169                let s = match obs {
170                    Observation::Flat(v) => v.as_slice(),
171                    _ => return false,
172                };
173                s.len() == 1 && (s[0] as usize) < *n
174            }
175            ObsSpace::Box { low, high, .. } => {
176                let s = match obs {
177                    Observation::Flat(v) => v.as_slice(),
178                    _ => return false,
179                };
180                s.len() == low.len()
181                    && s.iter()
182                        .zip(low.iter().zip(high.iter()))
183                        .all(|(v, (lo, hi))| *v >= *lo && *v <= *hi)
184            }
185            ObsSpace::MultiDiscrete(nvec) => {
186                let s = match obs {
187                    Observation::Flat(v) => v.as_slice(),
188                    _ => return false,
189                };
190                s.len() == nvec.len()
191                    && s.iter()
192                        .zip(nvec.iter())
193                        .all(|(v, n)| *v >= 0.0 && (*v as usize) < *n)
194            }
195            ObsSpace::Dict(entries) => {
196                let pairs = match obs {
197                    Observation::Dict(p) => p,
198                    _ => return false,
199                };
200                if pairs.len() != entries.len() {
201                    return false;
202                }
203                pairs
204                    .iter()
205                    .zip(entries.iter())
206                    .all(|((ok, ov), (ek, ed))| ok == ek && ov.len() == *ed)
207            }
208        }
209    }
210}
211
/// Unit tests for space membership checks and the `Observation` accessors.
#[cfg(test)]
mod tests {
    use super::*;

    // --- ActionSpace::contains ---

    #[test]
    fn discrete_action_space_contains() {
        let space = ActionSpace::Discrete(3);
        assert!(space.contains(&Action::Discrete(0)));
        assert!(space.contains(&Action::Discrete(2)));
        assert!(!space.contains(&Action::Discrete(3)));
        assert!(!space.contains(&Action::Continuous(vec![0.0])));
    }

    #[test]
    fn box_action_space_contains() {
        let space = ActionSpace::Box {
            low: vec![-1.0, -2.0],
            high: vec![1.0, 2.0],
            shape: vec![2],
        };
        assert!(space.contains(&Action::Continuous(vec![0.0, 0.0])));
        // Bounds are inclusive on both ends.
        assert!(space.contains(&Action::Continuous(vec![-1.0, 2.0])));
        assert!(!space.contains(&Action::Continuous(vec![1.5, 0.0])));
        // Wrong number of dimensions.
        assert!(!space.contains(&Action::Continuous(vec![0.0])));
    }

    #[test]
    fn box_action_space_rejects_discrete_and_nan() {
        let space = ActionSpace::Box {
            low: vec![0.0],
            high: vec![1.0],
            shape: vec![1],
        };
        assert!(!space.contains(&Action::Discrete(0)));
        assert!(!space.contains(&Action::Continuous(vec![f32::NAN])));
    }

    // --- ObsSpace::contains ---

    #[test]
    fn discrete_obs_space_contains() {
        let space = ObsSpace::Discrete(5);
        assert!(space.contains(&Observation::Flat(vec![3.0])));
        assert!(!space.contains(&Observation::Flat(vec![5.0])));
        assert!(!space.contains(&Observation::Flat(vec![1.0, 2.0])));
    }

    #[test]
    fn box_obs_space_contains() {
        let space = ObsSpace::Box {
            low: vec![-4.8; 4],
            high: vec![4.8; 4],
            shape: vec![4],
        };
        assert!(space.contains(&Observation::Flat(vec![0.0, 0.0, 0.0, 0.0])));
        assert!(!space.contains(&Observation::Flat(vec![0.0, 0.0, 5.0, 0.0])));
    }

    #[test]
    fn box_obs_space_bounds_length_and_nan() {
        let space = ObsSpace::Box {
            low: vec![-4.8; 4],
            high: vec![4.8; 4],
            shape: vec![4],
        };
        // Bounds are inclusive on both ends.
        assert!(space.contains(&Observation::Flat(vec![-4.8, 4.8, 0.0, 0.0])));
        // Wrong number of dimensions.
        assert!(!space.contains(&Observation::Flat(vec![0.0; 3])));
        // NaN compares false against every bound.
        assert!(!space.contains(&Observation::Flat(vec![0.0, f32::NAN, 0.0, 0.0])));
    }

    #[test]
    fn multi_discrete_obs_space_contains() {
        let space = ObsSpace::MultiDiscrete(vec![3, 5]);
        assert!(space.contains(&Observation::Flat(vec![2.0, 4.0])));
        assert!(!space.contains(&Observation::Flat(vec![3.0, 0.0])));
    }

    #[test]
    fn multi_discrete_obs_space_rejects_negative_and_wrong_length() {
        let space = ObsSpace::MultiDiscrete(vec![3, 5]);
        assert!(!space.contains(&Observation::Flat(vec![-1.0, 0.0])));
        assert!(!space.contains(&Observation::Flat(vec![1.0])));
    }

    // --- Dict Observation tests ---

    #[test]
    fn test_dict_observation_total_dim() {
        let obs = Observation::Dict(vec![
            ("image".into(), vec![0.0; 784]),
            ("proprio".into(), vec![0.0; 7]),
        ]);
        assert_eq!(obs.total_dim(), 791);
    }

    #[test]
    fn test_dict_observation_flatten() {
        let obs = Observation::Dict(vec![("a".into(), vec![1.0, 2.0]), ("b".into(), vec![3.0])]);
        assert_eq!(obs.flatten(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_dict_observation_get_key() {
        let obs = Observation::Dict(vec![
            ("image".into(), vec![1.0, 2.0, 3.0]),
            ("proprio".into(), vec![4.0, 5.0]),
        ]);
        assert_eq!(obs.get("image"), Some([1.0, 2.0, 3.0].as_slice()));
        assert_eq!(obs.get("proprio"), Some([4.0, 5.0].as_slice()));
        assert_eq!(obs.get("missing"), None);
    }

    #[test]
    fn test_empty_dict_observation() {
        let obs = Observation::Dict(Vec::new());
        assert_eq!(obs.total_dim(), 0);
        assert_eq!(obs.flatten(), Vec::<f32>::new());
        assert_eq!(obs.get("anything"), None);
        assert_eq!(obs.into_inner(), Vec::<f32>::new());
    }

    #[test]
    fn test_dict_obs_space_contains() {
        let space = ObsSpace::Dict(vec![("image".into(), 784), ("proprio".into(), 7)]);

        let valid = Observation::Dict(vec![
            ("image".into(), vec![0.0; 784]),
            ("proprio".into(), vec![0.0; 7]),
        ]);
        assert!(space.contains(&valid));

        // Wrong dim
        let bad_dim = Observation::Dict(vec![
            ("image".into(), vec![0.0; 784]),
            ("proprio".into(), vec![0.0; 8]),
        ]);
        assert!(!space.contains(&bad_dim));

        // Wrong key name
        let bad_key = Observation::Dict(vec![
            ("image".into(), vec![0.0; 784]),
            ("wrong".into(), vec![0.0; 7]),
        ]);
        assert!(!space.contains(&bad_key));

        // Wrong number of entries
        let bad_count = Observation::Dict(vec![("image".into(), vec![0.0; 784])]);
        assert!(!space.contains(&bad_count));

        // Flat obs should not match Dict space
        let flat = Observation::Flat(vec![0.0; 791]);
        assert!(!space.contains(&flat));
    }

    #[test]
    fn test_dict_obs_space_is_order_sensitive() {
        let space = ObsSpace::Dict(vec![("image".into(), 4), ("proprio".into(), 2)]);
        // Same keys and lengths, but in the opposite order.
        let swapped = Observation::Dict(vec![
            ("proprio".into(), vec![0.0; 2]),
            ("image".into(), vec![0.0; 4]),
        ]);
        assert!(!space.contains(&swapped));
    }

    // --- Observation accessors ---

    #[test]
    fn test_flat_constructor_builds_flat_variant() {
        assert_eq!(
            Observation::flat(vec![1.0, 2.0]),
            Observation::Flat(vec![1.0, 2.0])
        );
    }

    #[test]
    fn test_try_as_slice_flat_returns_some() {
        let obs = Observation::Flat(vec![1.0, 2.0, 3.0]);
        assert_eq!(obs.try_as_slice(), Some([1.0, 2.0, 3.0].as_slice()));
    }

    #[test]
    fn test_try_as_slice_dict_returns_none() {
        let obs = Observation::Dict(vec![("a".into(), vec![1.0])]);
        assert_eq!(obs.try_as_slice(), None);
    }

    #[test]
    #[should_panic(expected = "Dict variant")]
    fn test_as_slice_dict_panics() {
        let obs = Observation::Dict(vec![("a".into(), vec![1.0])]);
        let _ = obs.as_slice();
    }

    #[test]
    fn test_flat_observation_backward_compat() {
        // Ensure Flat variant works exactly like the old tuple struct
        let obs = Observation::Flat(vec![1.0, 2.0, 3.0]);
        assert_eq!(obs.as_slice(), &[1.0, 2.0, 3.0]);
        assert_eq!(obs.total_dim(), 3);
        assert_eq!(obs.flatten(), vec![1.0, 2.0, 3.0]);
        assert_eq!(obs.get("anything"), None);

        let inner = obs.into_inner();
        assert_eq!(inner, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_dict_observation_into_inner_flattens() {
        let obs = Observation::Dict(vec![
            ("a".into(), vec![1.0, 2.0]),
            ("b".into(), vec![3.0, 4.0, 5.0]),
        ]);
        assert_eq!(obs.into_inner(), vec![1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_discrete_obs_space_rejects_dict_obs() {
        // A Dict observation never belongs to a non-Dict space, even when its
        // only sub-vector would be a valid value on its own.
        let space = ObsSpace::Discrete(5);
        let dict_obs = Observation::Dict(vec![("x".into(), vec![3.0])]);
        assert!(!space.contains(&dict_obs));
    }
}