1use serde::{Deserialize, Serialize};
2
3#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
5pub enum ActionSpace {
6 Discrete(usize),
8 Box {
10 low: Vec<f32>,
11 high: Vec<f32>,
12 shape: Vec<usize>,
13 },
14 MultiDiscrete(Vec<usize>),
16}
17
18#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
20pub enum ObsSpace {
21 Discrete(usize),
22 Box {
23 low: Vec<f32>,
24 high: Vec<f32>,
25 shape: Vec<usize>,
26 },
27 MultiDiscrete(Vec<usize>),
28 Dict(Vec<(String, usize)>),
30}
31
/// An action submitted to an environment.
#[derive(Clone, Debug, PartialEq)]
pub enum Action {
    /// Index of the chosen discrete option.
    Discrete(u32),
    /// Real-valued action vector.
    Continuous(Vec<f32>),
}
38
/// An observation produced by an environment.
#[derive(Clone, Debug, PartialEq)]
pub enum Observation {
    /// A single flat feature vector.
    Flat(Vec<f32>),
    /// Ordered named components; flattening concatenates them in this order.
    Dict(Vec<(String, Vec<f32>)>),
}
51
52impl Observation {
53 pub fn flat(data: Vec<f32>) -> Self {
57 Observation::Flat(data)
58 }
59
60 pub fn try_as_slice(&self) -> Option<&[f32]> {
66 match self {
67 Observation::Flat(v) => Some(v),
68 Observation::Dict(_) => None,
69 }
70 }
71
72 pub fn as_slice(&self) -> &[f32] {
82 match self {
83 Observation::Flat(v) => v,
84 Observation::Dict(_) => {
85 panic!("Observation::as_slice() called on Dict variant; use try_as_slice() or flatten() instead")
86 }
87 }
88 }
89
90 pub fn into_inner(self) -> Vec<f32> {
94 match self {
95 Observation::Flat(v) => v,
96 Observation::Dict(pairs) => {
97 let total = pairs.iter().map(|(_, v)| v.len()).sum();
98 let mut out = Vec::with_capacity(total);
99 for (_, v) in pairs {
100 out.extend(v);
101 }
102 out
103 }
104 }
105 }
106
107 pub fn total_dim(&self) -> usize {
109 match self {
110 Observation::Flat(v) => v.len(),
111 Observation::Dict(pairs) => pairs.iter().map(|(_, v)| v.len()).sum(),
112 }
113 }
114
115 pub fn flatten(&self) -> Vec<f32> {
120 match self {
121 Observation::Flat(v) => v.clone(),
122 Observation::Dict(pairs) => {
123 let total = pairs.iter().map(|(_, v)| v.len()).sum();
124 let mut out = Vec::with_capacity(total);
125 for (_, v) in pairs {
126 out.extend_from_slice(v);
127 }
128 out
129 }
130 }
131 }
132
133 pub fn get(&self, key: &str) -> Option<&[f32]> {
137 match self {
138 Observation::Flat(_) => None,
139 Observation::Dict(pairs) => pairs
140 .iter()
141 .find(|(k, _)| k == key)
142 .map(|(_, v)| v.as_slice()),
143 }
144 }
145}
146
147impl ActionSpace {
148 pub fn contains(&self, action: &Action) -> bool {
150 match (self, action) {
151 (ActionSpace::Discrete(n), Action::Discrete(a)) => (*a as usize) < *n,
152 (ActionSpace::Box { low, high, .. }, Action::Continuous(vals)) => {
153 vals.len() == low.len()
154 && vals
155 .iter()
156 .zip(low.iter().zip(high.iter()))
157 .all(|(v, (lo, hi))| *v >= *lo && *v <= *hi)
158 }
159 _ => false,
160 }
161 }
162}
163
164impl ObsSpace {
165 pub fn contains(&self, obs: &Observation) -> bool {
167 match self {
168 ObsSpace::Discrete(n) => {
169 let s = match obs {
170 Observation::Flat(v) => v.as_slice(),
171 _ => return false,
172 };
173 s.len() == 1 && (s[0] as usize) < *n
174 }
175 ObsSpace::Box { low, high, .. } => {
176 let s = match obs {
177 Observation::Flat(v) => v.as_slice(),
178 _ => return false,
179 };
180 s.len() == low.len()
181 && s.iter()
182 .zip(low.iter().zip(high.iter()))
183 .all(|(v, (lo, hi))| *v >= *lo && *v <= *hi)
184 }
185 ObsSpace::MultiDiscrete(nvec) => {
186 let s = match obs {
187 Observation::Flat(v) => v.as_slice(),
188 _ => return false,
189 };
190 s.len() == nvec.len()
191 && s.iter()
192 .zip(nvec.iter())
193 .all(|(v, n)| *v >= 0.0 && (*v as usize) < *n)
194 }
195 ObsSpace::Dict(entries) => {
196 let pairs = match obs {
197 Observation::Dict(p) => p,
198 _ => return false,
199 };
200 if pairs.len() != entries.len() {
201 return false;
202 }
203 pairs
204 .iter()
205 .zip(entries.iter())
206 .all(|((ok, ov), (ek, ed))| ok == ek && ov.len() == *ed)
207 }
208 }
209 }
210}
211
#[cfg(test)]
mod tests {
    //! Unit tests for space membership and observation accessors.

    use super::*;

    #[test]
    fn discrete_action_space_contains() {
        let space = ActionSpace::Discrete(3);
        assert!(space.contains(&Action::Discrete(0)));
        assert!(space.contains(&Action::Discrete(2)));
        assert!(!space.contains(&Action::Discrete(3)));
        assert!(!space.contains(&Action::Continuous(vec![0.0])));
    }

    #[test]
    fn box_action_space_contains() {
        let space = ActionSpace::Box {
            low: vec![-1.0, -2.0],
            high: vec![1.0, 2.0],
            shape: vec![2],
        };
        assert!(space.contains(&Action::Continuous(vec![0.0, 0.0])));
        assert!(space.contains(&Action::Continuous(vec![-1.0, 2.0])));
        assert!(!space.contains(&Action::Continuous(vec![1.5, 0.0])));
        assert!(!space.contains(&Action::Continuous(vec![0.0])));
    }

    #[test]
    fn discrete_obs_space_contains() {
        let space = ObsSpace::Discrete(5);
        assert!(space.contains(&Observation::Flat(vec![3.0])));
        assert!(!space.contains(&Observation::Flat(vec![5.0])));
        assert!(!space.contains(&Observation::Flat(vec![1.0, 2.0])));
    }

    #[test]
    fn box_obs_space_contains() {
        let space = ObsSpace::Box {
            low: vec![-4.8; 4],
            high: vec![4.8; 4],
            shape: vec![4],
        };
        assert!(space.contains(&Observation::Flat(vec![0.0, 0.0, 0.0, 0.0])));
        assert!(!space.contains(&Observation::Flat(vec![0.0, 0.0, 5.0, 0.0])));
    }

    #[test]
    fn multi_discrete_obs_space_contains() {
        let space = ObsSpace::MultiDiscrete(vec![3, 5]);
        assert!(space.contains(&Observation::Flat(vec![2.0, 4.0])));
        assert!(!space.contains(&Observation::Flat(vec![3.0, 0.0])));
    }

    #[test]
    fn test_dict_observation_total_dim() {
        let obs = Observation::Dict(vec![
            ("image".into(), vec![0.0; 784]),
            ("proprio".into(), vec![0.0; 7]),
        ]);
        assert_eq!(obs.total_dim(), 791);
    }

    #[test]
    fn test_dict_observation_flatten() {
        let obs = Observation::Dict(vec![("a".into(), vec![1.0, 2.0]), ("b".into(), vec![3.0])]);
        assert_eq!(obs.flatten(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_dict_observation_get_key() {
        let obs = Observation::Dict(vec![
            ("image".into(), vec![1.0, 2.0, 3.0]),
            ("proprio".into(), vec![4.0, 5.0]),
        ]);
        assert_eq!(obs.get("image"), Some([1.0, 2.0, 3.0].as_slice()));
        assert_eq!(obs.get("proprio"), Some([4.0, 5.0].as_slice()));
        assert_eq!(obs.get("missing"), None);
    }

    #[test]
    fn test_dict_obs_space_contains() {
        let space = ObsSpace::Dict(vec![("image".into(), 784), ("proprio".into(), 7)]);

        let valid = Observation::Dict(vec![
            ("image".into(), vec![0.0; 784]),
            ("proprio".into(), vec![0.0; 7]),
        ]);
        assert!(space.contains(&valid));

        // Right keys, wrong length for one component.
        let bad_dim = Observation::Dict(vec![
            ("image".into(), vec![0.0; 784]),
            ("proprio".into(), vec![0.0; 8]),
        ]);
        assert!(!space.contains(&bad_dim));

        // Right lengths, wrong key name.
        let bad_key = Observation::Dict(vec![
            ("image".into(), vec![0.0; 784]),
            ("wrong".into(), vec![0.0; 7]),
        ]);
        assert!(!space.contains(&bad_key));

        // Missing a component.
        let bad_count = Observation::Dict(vec![("image".into(), vec![0.0; 784])]);
        assert!(!space.contains(&bad_count));

        // A flat observation of the same total size is still structurally wrong.
        let flat = Observation::Flat(vec![0.0; 791]);
        assert!(!space.contains(&flat));
    }

    #[test]
    fn test_try_as_slice_flat_returns_some() {
        let obs = Observation::Flat(vec![1.0, 2.0, 3.0]);
        assert_eq!(obs.try_as_slice(), Some([1.0, 2.0, 3.0].as_slice()));
    }

    #[test]
    fn test_try_as_slice_dict_returns_none() {
        let obs = Observation::Dict(vec![("a".into(), vec![1.0])]);
        assert_eq!(obs.try_as_slice(), None);
    }

    #[test]
    #[should_panic(expected = "Dict variant")]
    fn test_as_slice_dict_panics() {
        let obs = Observation::Dict(vec![("a".into(), vec![1.0])]);
        let _ = obs.as_slice();
    }

    #[test]
    fn test_flat_observation_backward_compat() {
        let obs = Observation::Flat(vec![1.0, 2.0, 3.0]);
        assert_eq!(obs.as_slice(), &[1.0, 2.0, 3.0]);
        assert_eq!(obs.total_dim(), 3);
        assert_eq!(obs.flatten(), vec![1.0, 2.0, 3.0]);
        assert_eq!(obs.get("anything"), None);

        let inner = obs.into_inner();
        assert_eq!(inner, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_dict_observation_into_inner_flattens() {
        let obs = Observation::Dict(vec![
            ("a".into(), vec![1.0, 2.0]),
            ("b".into(), vec![3.0, 4.0, 5.0]),
        ]);
        assert_eq!(obs.into_inner(), vec![1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    // A flat (Discrete) space must reject a dict observation.
    #[test]
    fn test_flat_obs_space_does_not_match_dict_obs() {
        let space = ObsSpace::Discrete(5);
        let dict_obs = Observation::Dict(vec![("x".into(), vec![3.0])]);
        assert!(!space.contains(&dict_obs));
    }

    #[test]
    fn test_flat_constructor_wraps_vec() {
        assert_eq!(
            Observation::flat(vec![1.0, 2.0]),
            Observation::Flat(vec![1.0, 2.0])
        );
    }

    #[test]
    fn test_empty_dict_observation() {
        let obs = Observation::Dict(Vec::new());
        assert_eq!(obs.total_dim(), 0);
        assert!(obs.flatten().is_empty());
        assert_eq!(obs.get("anything"), None);
        assert!(obs.into_inner().is_empty());
    }
}