1use crate::error::RloxError;
8
/// Per-batch data handed to a [`RewardTransform`] alongside the raw rewards.
///
/// Every slice is borrowed from the caller for `'a`; constructing a context
/// copies nothing, and the context itself is cheap to copy.
#[derive(Debug, Clone, Copy)]
pub struct RewardContext<'a> {
    /// Episode-termination flags, one per reward. The PBRS code in this module
    /// treats a value of exactly `1.0` as terminal.
    pub dones: &'a [f64],
    /// Discount factor γ applied to the next-state potential when shaping.
    pub gamma: f64,
    /// Named auxiliary arrays (for PBRS: `"phi_current"` and `"phi_next"`),
    /// looked up by name rather than by position.
    pub extras: &'a [(&'a str, &'a [f64])],
}
16
17pub trait RewardTransform: Send + Sync {
22 fn transform(
24 &self,
25 rewards: &[f64],
26 context: &RewardContext<'_>,
27 ) -> Result<Vec<f64>, RloxError>;
28
29 fn name(&self) -> &str;
31}
32
33pub struct PBRSTransform;
37
38impl RewardTransform for PBRSTransform {
39 fn transform(
40 &self,
41 rewards: &[f64],
42 context: &RewardContext<'_>,
43 ) -> Result<Vec<f64>, RloxError> {
44 let phi_current = context
45 .extras
46 .iter()
47 .find(|(name, _)| *name == "phi_current")
48 .map(|(_, v)| *v)
49 .ok_or_else(|| RloxError::BufferError("missing 'phi_current' in extras".into()))?;
50 let phi_next = context
51 .extras
52 .iter()
53 .find(|(name, _)| *name == "phi_next")
54 .map(|(_, v)| *v)
55 .ok_or_else(|| RloxError::BufferError("missing 'phi_next' in extras".into()))?;
56
57 shape_rewards_pbrs(rewards, phi_current, phi_next, context.gamma, context.dones)
58 }
59
60 fn name(&self) -> &str {
61 "PBRS"
62 }
63}
64
65pub struct GoalDistanceTransform {
71 pub scale: f64,
72 pub goal_start: usize,
73 pub goal_dim: usize,
74}
75
76impl RewardTransform for GoalDistanceTransform {
77 fn transform(
78 &self,
79 rewards: &[f64],
80 context: &RewardContext<'_>,
81 ) -> Result<Vec<f64>, RloxError> {
82 let phi_current = context
83 .extras
84 .iter()
85 .find(|(name, _)| *name == "phi_current")
86 .map(|(_, v)| *v)
87 .ok_or_else(|| RloxError::BufferError("missing 'phi_current' in extras".into()))?;
88 let phi_next = context
89 .extras
90 .iter()
91 .find(|(name, _)| *name == "phi_next")
92 .map(|(_, v)| *v)
93 .ok_or_else(|| RloxError::BufferError("missing 'phi_next' in extras".into()))?;
94
95 shape_rewards_pbrs(rewards, phi_current, phi_next, context.gamma, context.dones)
96 }
97
98 fn name(&self) -> &str {
99 "GoalDistance"
100 }
101}
102
103#[inline]
115pub fn shape_rewards_pbrs(
116 rewards: &[f64],
117 potentials_current: &[f64],
118 potentials_next: &[f64],
119 gamma: f64,
120 dones: &[f64],
121) -> Result<Vec<f64>, RloxError> {
122 let n = rewards.len();
123 if potentials_current.len() != n || potentials_next.len() != n || dones.len() != n {
124 return Err(RloxError::ShapeMismatch {
125 expected: format!("all slices length {n}"),
126 got: format!(
127 "phi_current={}, phi_next={}, dones={}",
128 potentials_current.len(),
129 potentials_next.len(),
130 dones.len()
131 ),
132 });
133 }
134
135 let mut output = Vec::with_capacity(n);
136 for i in 0..n {
137 if dones[i] == 1.0 {
138 output.push(rewards[i]);
139 } else {
140 output.push(rewards[i] + gamma * potentials_next[i] - potentials_current[i]);
141 }
142 }
143 Ok(output)
144}
145
146#[inline]
156pub fn compute_goal_distance_potentials(
157 observations: &[f64],
158 goal: &[f64],
159 obs_dim: usize,
160 goal_start: usize,
161 goal_dim: usize,
162 scale: f64,
163) -> Result<Vec<f64>, RloxError> {
164 if goal.len() != goal_dim {
165 return Err(RloxError::ShapeMismatch {
166 expected: format!("goal.len() == goal_dim={goal_dim}"),
167 got: format!("goal.len()={}", goal.len()),
168 });
169 }
170 if obs_dim == 0 {
171 return Err(RloxError::ShapeMismatch {
172 expected: "obs_dim > 0".into(),
173 got: "obs_dim=0".into(),
174 });
175 }
176 if !observations.len().is_multiple_of(obs_dim) {
177 return Err(RloxError::ShapeMismatch {
178 expected: format!("observations.len() divisible by obs_dim={obs_dim}"),
179 got: format!("observations.len()={}", observations.len()),
180 });
181 }
182 if goal_start + goal_dim > obs_dim {
183 return Err(RloxError::ShapeMismatch {
184 expected: format!("goal_start + goal_dim <= obs_dim={obs_dim}"),
185 got: format!("goal_start={goal_start}, goal_dim={goal_dim}"),
186 });
187 }
188
189 let n = observations.len() / obs_dim;
190 let mut potentials = Vec::with_capacity(n);
191
192 for i in 0..n {
193 let obs_start = i * obs_dim + goal_start;
194 let obs_slice = &observations[obs_start..obs_start + goal_dim];
195 let dist_sq: f64 = obs_slice
196 .iter()
197 .zip(goal.iter())
198 .map(|(&o, &g)| (o - g) * (o - g))
199 .sum();
200 potentials.push(-scale * dist_sq.sqrt());
201 }
202
203 Ok(potentials)
204}
205
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pbrs_known_values() {
        let rewards = &[1.0, 2.0];
        let phi = &[0.5, 0.3];
        let phi_next = &[0.3, 0.8];
        let gamma = 0.99;
        let dones = &[0.0, 0.0];
        let result = shape_rewards_pbrs(rewards, phi, phi_next, gamma, dones).unwrap();
        // 1.0 + 0.99 * 0.3 - 0.5 = 0.797
        assert!(
            (result[0] - 0.797).abs() < 1e-10,
            "expected 0.797, got {}",
            result[0]
        );
        // 2.0 + 0.99 * 0.8 - 0.3 = 2.492
        assert!(
            (result[1] - 2.492).abs() < 1e-10,
            "expected 2.492, got {}",
            result[1]
        );
    }

    #[test]
    fn test_pbrs_done_resets_potential() {
        let rewards = &[1.0, 2.0];
        let phi = &[0.5, 0.3];
        let phi_next = &[0.3, 0.8];
        let gamma = 0.99;
        let dones = &[0.0, 1.0];
        let result = shape_rewards_pbrs(rewards, phi, phi_next, gamma, dones).unwrap();
        assert!(
            (result[1] - 2.0).abs() < 1e-10,
            "done step should return raw reward, got {}",
            result[1]
        );
    }

    #[test]
    fn test_pbrs_zero_potentials_no_change() {
        let rewards = &[1.0, 2.0, 3.0];
        let phi = &[0.0, 0.0, 0.0];
        let phi_next = &[0.0, 0.0, 0.0];
        let dones = &[0.0, 0.0, 0.0];
        let result = shape_rewards_pbrs(rewards, phi, phi_next, 0.99, dones).unwrap();
        for (shaped, raw) in result.iter().zip(rewards) {
            assert!(
                (shaped - raw).abs() < 1e-10,
                "zero potentials should not change reward"
            );
        }
    }

    #[test]
    fn test_pbrs_preserves_optimal_policy() {
        // A consistent trajectory: phi_next[t] == phi[t + 1], ending at Φ = 0.
        let rewards = &[1.0, 2.0, 3.0, 4.0, 5.0];
        let phi = &[0.5, 0.3, 0.8, 0.1, 0.9];
        let phi_next = &[0.3, 0.8, 0.1, 0.9, 0.0];
        let gamma = 0.99;
        let dones = &[0.0, 0.0, 0.0, 0.0, 0.0];

        let shaped = shape_rewards_pbrs(rewards, phi, phi_next, gamma, dones).unwrap();

        // Undiscounted: the total changes by exactly the summed shaping terms.
        let sum_raw: f64 = rewards.iter().sum();
        let sum_shaped: f64 = shaped.iter().sum();
        let shaping_sum: f64 = (0..5).map(|i| gamma * phi_next[i] - phi[i]).sum::<f64>();
        assert!(
            (sum_shaped - (sum_raw + shaping_sum)).abs() < 1e-10,
            "sum_shaped={sum_shaped}, expected {}",
            sum_raw + shaping_sum
        );

        // Discounted: the shaping bonus telescopes to γ^T Φ(s_T) - Φ(s_0)
        // = 0.0 - 0.5, which depends only on the start state. That is what
        // makes PBRS leave the optimal policy unchanged.
        fn discounted_return(xs: &[f64], gamma: f64) -> f64 {
            xs.iter().rev().fold(0.0, |acc, &x| x + gamma * acc)
        }
        let bonus = discounted_return(&shaped, gamma) - discounted_return(rewards, gamma);
        assert!(
            (bonus + 0.5).abs() < 1e-10,
            "discounted shaping bonus = {bonus}, expected -0.5"
        );
    }

    #[test]
    fn test_pbrs_length_mismatch_errors() {
        let result = shape_rewards_pbrs(
            &[1.0, 2.0, 3.0],
            &[0.0, 0.0],
            &[0.0, 0.0],
            0.99,
            &[0.0, 0.0],
        );
        assert!(matches!(result, Err(RloxError::ShapeMismatch { .. })));
    }

    #[test]
    fn test_pbrs_transform_reads_potentials_from_extras() {
        let rewards = [1.0, 2.0];
        let phi = [0.5, 0.3];
        let phi_next = [0.3, 0.8];
        let dones = [0.0, 1.0];
        // Keys deliberately listed out of order: lookup is by name.
        let extras: [(&str, &[f64]); 2] = [("phi_next", &phi_next[..]), ("phi_current", &phi[..])];
        let context = RewardContext {
            dones: &dones,
            gamma: 0.99,
            extras: &extras,
        };

        let expected = shape_rewards_pbrs(&rewards, &phi, &phi_next, 0.99, &dones).unwrap();
        assert_eq!(PBRSTransform.transform(&rewards, &context).unwrap(), expected);
    }

    #[test]
    fn test_transforms_require_both_potentials() {
        let rewards = [1.0];
        let phi = [0.5];
        let dones = [0.0];
        let goal = GoalDistanceTransform {
            scale: 1.0,
            goal_start: 0,
            goal_dim: 1,
        };

        let only_current: [(&str, &[f64]); 1] = [("phi_current", &phi[..])];
        let partial = RewardContext {
            dones: &dones,
            gamma: 0.99,
            extras: &only_current,
        };
        assert!(matches!(
            PBRSTransform.transform(&rewards, &partial),
            Err(RloxError::BufferError(_))
        ));
        assert!(matches!(
            goal.transform(&rewards, &partial),
            Err(RloxError::BufferError(_))
        ));

        let nothing: [(&str, &[f64]); 0] = [];
        let empty = RewardContext {
            dones: &dones,
            gamma: 0.99,
            extras: &nothing,
        };
        assert!(matches!(
            PBRSTransform.transform(&rewards, &empty),
            Err(RloxError::BufferError(_))
        ));
    }

    #[test]
    fn test_goal_distance_transform_uses_precomputed_potentials() {
        let rewards = [1.0, 2.0];
        let phi = [-1.0, -0.5];
        let phi_next = [-0.5, -0.1];
        let dones = [0.0, 0.0];
        let extras: [(&str, &[f64]); 2] = [("phi_current", &phi[..]), ("phi_next", &phi_next[..])];
        let context = RewardContext {
            dones: &dones,
            gamma: 0.9,
            extras: &extras,
        };
        let transform = GoalDistanceTransform {
            scale: 2.0,
            goal_start: 0,
            goal_dim: 2,
        };

        assert_eq!(transform.name(), "GoalDistance");
        // Pins current behaviour: the struct's fields do not affect `transform`.
        assert_eq!(
            transform.transform(&rewards, &context).unwrap(),
            PBRSTransform.transform(&rewards, &context).unwrap()
        );
    }

    #[test]
    fn test_goal_distance_decreasing_near_goal() {
        let goal = &[0.0, 0.0];
        // Two observations of dim 4; the goal occupies the first 2 columns.
        let observations = &[
            1.0, 0.0, 1.0, 0.0, // far
            0.1, 0.0, 0.1, 0.0, // close
        ];
        let potentials =
            compute_goal_distance_potentials(observations, goal, 4, 0, 2, 1.0).unwrap();
        assert!(
            potentials[1] > potentials[0],
            "closer obs should have less negative potential: far={}, close={}",
            potentials[0],
            potentials[1]
        );
    }

    #[test]
    fn test_goal_distance_at_goal_is_zero() {
        let goal = &[3.0, 4.0];
        let observations = &[3.0, 4.0, 0.0, 0.0];
        let potentials =
            compute_goal_distance_potentials(observations, goal, 4, 0, 2, 1.0).unwrap();
        assert!(
            potentials[0].abs() < 1e-10,
            "at goal, potential should be 0, got {}",
            potentials[0]
        );
    }

    #[test]
    fn test_goal_distance_scale_factor() {
        let goal = &[0.0];
        let observations = &[1.0, 0.0];
        let phi_1 = compute_goal_distance_potentials(observations, goal, 2, 0, 1, 1.0).unwrap();
        let phi_2 = compute_goal_distance_potentials(observations, goal, 2, 0, 1, 2.0).unwrap();
        assert!(
            (phi_2[0] - 2.0 * phi_1[0]).abs() < 1e-10,
            "scale=2 should double potential: phi_1={}, phi_2={}",
            phi_1[0],
            phi_2[0]
        );
    }

    #[test]
    fn test_goal_distance_validates_dimensions() {
        // goal has 3 entries but goal_dim is 2.
        let result = compute_goal_distance_potentials(
            &[1.0, 2.0, 3.0, 4.0],
            &[0.0, 0.0, 0.0],
            4,
            0,
            2,
            1.0,
        );
        assert!(matches!(result, Err(RloxError::ShapeMismatch { .. })));
    }

    #[test]
    fn test_trait_object_safety() {
        let transforms: Vec<Box<dyn RewardTransform>> = vec![
            Box::new(PBRSTransform),
            Box::new(GoalDistanceTransform {
                scale: 1.0,
                goal_start: 0,
                goal_dim: 1,
            }),
        ];
        let names: Vec<&str> = transforms.iter().map(|t| t.name()).collect();
        assert_eq!(names, ["PBRS", "GoalDistance"]);
    }

    mod proptests {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            #[test]
            fn prop_pbrs_length_matches_input(n in 1usize..500) {
                let rewards: Vec<f64> = (0..n).map(|i| i as f64 * 0.1).collect();
                let phi: Vec<f64> = vec![0.5; n];
                let phi_next: Vec<f64> = vec![0.3; n];
                let dones: Vec<f64> = vec![0.0; n];
                let result = shape_rewards_pbrs(&rewards, &phi, &phi_next, 0.99, &dones).unwrap();
                prop_assert_eq!(result.len(), n);
            }

            #[test]
            fn prop_pbrs_zero_gamma_no_future(n in 1usize..100) {
                let rewards: Vec<f64> = (0..n).map(|i| i as f64).collect();
                let phi: Vec<f64> = (0..n).map(|i| i as f64 * 0.5).collect();
                let phi_next: Vec<f64> = (0..n).map(|i| i as f64 * 0.3).collect();
                let dones: Vec<f64> = vec![0.0; n];
                let result = shape_rewards_pbrs(&rewards, &phi, &phi_next, 0.0, &dones).unwrap();
                for i in 0..n {
                    // With gamma = 0 the next-state potential must drop out.
                    let expected = rewards[i] - phi[i];
                    prop_assert!(
                        (result[i] - expected).abs() < 1e-10,
                        "index {i}: got {}, expected {expected}",
                        result[i]
                    );
                }
            }

            #[test]
            fn prop_goal_distance_non_positive(
                n in 1usize..50,
                goal_dim in 1usize..4,
            ) {
                let obs_dim = goal_dim + 2;
                let obs: Vec<f64> = (0..(n * obs_dim)).map(|i| i as f64 * 0.1).collect();
                let goal: Vec<f64> = vec![0.0; goal_dim];
                let potentials = compute_goal_distance_potentials(
                    &obs, &goal, obs_dim, 0, goal_dim, 1.0
                ).unwrap();
                for (i, &p) in potentials.iter().enumerate() {
                    prop_assert!(p <= 1e-10,
                        "potential[{i}] = {p} should be <= 0 for positive scale");
                }
            }
        }
    }
}