rlox_core/training/reward_shaping.rs

1//! Potential-based reward shaping (PBRS) and goal-distance potentials.
2//!
3//! Implements the PBRS transform `r' = r + gamma * Phi(s') - Phi(s)` which
4//! preserves the optimal policy (Ng et al., 1999), plus goal-distance
5//! potential computation for goal-conditioned RL.
6
7use crate::error::RloxError;
8
/// Context passed to reward transforms for access to transition metadata.
pub struct RewardContext<'a> {
    /// Episode termination flags, one per transition (1.0 = done).
    pub dones: &'a [f64],
    /// Discount factor used in the PBRS term `gamma * Phi(s')`.
    pub gamma: f64,
    /// Arbitrary f64 slices keyed by name, for potential values, goal distances, etc.
    pub extras: &'a [(&'a str, &'a [f64])],
}
16
/// Trait for composable reward transformations.
///
/// Each transform takes raw rewards and a context and produces shaped rewards.
/// Transforms can be chained: `ClipReward -> PBRSTransform -> ScaleReward`.
pub trait RewardTransform: Send + Sync {
    /// Transform a batch of rewards.
    ///
    /// # Errors
    ///
    /// Implementations return an error when required entries are missing from
    /// `context.extras` or when input slice lengths disagree.
    fn transform(
        &self,
        rewards: &[f64],
        context: &RewardContext<'_>,
    ) -> Result<Vec<f64>, RloxError>;

    /// Human-readable name.
    fn name(&self) -> &str;
}
32
/// Potential-based reward shaping transform.
///
/// Expects `extras` to contain entries named `"phi_current"` and `"phi_next"`.
/// Stateless unit struct; all inputs arrive through [`RewardContext`].
pub struct PBRSTransform;
37
38impl RewardTransform for PBRSTransform {
39    fn transform(
40        &self,
41        rewards: &[f64],
42        context: &RewardContext<'_>,
43    ) -> Result<Vec<f64>, RloxError> {
44        let phi_current = context
45            .extras
46            .iter()
47            .find(|(name, _)| *name == "phi_current")
48            .map(|(_, v)| *v)
49            .ok_or_else(|| RloxError::BufferError("missing 'phi_current' in extras".into()))?;
50        let phi_next = context
51            .extras
52            .iter()
53            .find(|(name, _)| *name == "phi_next")
54            .map(|(_, v)| *v)
55            .ok_or_else(|| RloxError::BufferError("missing 'phi_next' in extras".into()))?;
56
57        shape_rewards_pbrs(rewards, phi_current, phi_next, context.gamma, context.dones)
58    }
59
60    fn name(&self) -> &str {
61        "PBRS"
62    }
63}
64
/// Goal-distance reward transform.
///
/// Computes `phi(s) = -scale * ||s[goal_slice] - goal||` and applies PBRS.
/// Expects `extras` to contain `"phi_current"` and `"phi_next"` entries
/// pre-computed via [`compute_goal_distance_potentials`].
pub struct GoalDistanceTransform {
    // NOTE(review): these fields are not read by `transform` itself (it consumes
    // pre-computed potentials from extras); presumably callers use them when
    // invoking `compute_goal_distance_potentials` — confirm against call sites.
    /// Scaling factor for the potential.
    pub scale: f64,
    /// Starting index within each observation where goal-relevant dims begin.
    pub goal_start: usize,
    /// Number of goal-relevant dimensions.
    pub goal_dim: usize,
}
75
76impl RewardTransform for GoalDistanceTransform {
77    fn transform(
78        &self,
79        rewards: &[f64],
80        context: &RewardContext<'_>,
81    ) -> Result<Vec<f64>, RloxError> {
82        let phi_current = context
83            .extras
84            .iter()
85            .find(|(name, _)| *name == "phi_current")
86            .map(|(_, v)| *v)
87            .ok_or_else(|| RloxError::BufferError("missing 'phi_current' in extras".into()))?;
88        let phi_next = context
89            .extras
90            .iter()
91            .find(|(name, _)| *name == "phi_next")
92            .map(|(_, v)| *v)
93            .ok_or_else(|| RloxError::BufferError("missing 'phi_next' in extras".into()))?;
94
95        shape_rewards_pbrs(rewards, phi_current, phi_next, context.gamma, context.dones)
96    }
97
98    fn name(&self) -> &str {
99        "GoalDistance"
100    }
101}
102
103/// Compute shaped rewards: `r' = r + gamma * Phi(s') - Phi(s)`
104///
105/// At episode boundaries (`dones[i] == 1.0`), the potential difference
106/// is zeroed out: `r'_i = r_i` (no shaping across episode boundaries).
107///
108/// # Arguments
109/// * `rewards` - raw rewards, length N
110/// * `potentials_current` - Phi(s_t), length N
111/// * `potentials_next` - Phi(s_{t+1}), length N
112/// * `gamma` - discount factor
113/// * `dones` - episode termination flags (1.0 = done), length N
114#[inline]
115pub fn shape_rewards_pbrs(
116    rewards: &[f64],
117    potentials_current: &[f64],
118    potentials_next: &[f64],
119    gamma: f64,
120    dones: &[f64],
121) -> Result<Vec<f64>, RloxError> {
122    let n = rewards.len();
123    if potentials_current.len() != n || potentials_next.len() != n || dones.len() != n {
124        return Err(RloxError::ShapeMismatch {
125            expected: format!("all slices length {n}"),
126            got: format!(
127                "phi_current={}, phi_next={}, dones={}",
128                potentials_current.len(),
129                potentials_next.len(),
130                dones.len()
131            ),
132        });
133    }
134
135    let mut output = Vec::with_capacity(n);
136    for i in 0..n {
137        if dones[i] == 1.0 {
138            output.push(rewards[i]);
139        } else {
140            output.push(rewards[i] + gamma * potentials_next[i] - potentials_current[i]);
141        }
142    }
143    Ok(output)
144}
145
146/// Goal-distance potential: `Phi(s) = -scale * ||s[goal_slice] - goal||_2`
147///
148/// # Arguments
149/// * `observations` - flat `(N * obs_dim)` array
150/// * `goal` - target goal vector, length `goal_dim`
151/// * `obs_dim` - dimensionality of each observation
152/// * `goal_start` - starting index within obs where goal-relevant dims begin
153/// * `goal_dim` - number of goal-relevant dimensions
154/// * `scale` - scaling factor for the potential
155#[inline]
156pub fn compute_goal_distance_potentials(
157    observations: &[f64],
158    goal: &[f64],
159    obs_dim: usize,
160    goal_start: usize,
161    goal_dim: usize,
162    scale: f64,
163) -> Result<Vec<f64>, RloxError> {
164    if goal.len() != goal_dim {
165        return Err(RloxError::ShapeMismatch {
166            expected: format!("goal.len() == goal_dim={goal_dim}"),
167            got: format!("goal.len()={}", goal.len()),
168        });
169    }
170    if obs_dim == 0 {
171        return Err(RloxError::ShapeMismatch {
172            expected: "obs_dim > 0".into(),
173            got: "obs_dim=0".into(),
174        });
175    }
176    if !observations.len().is_multiple_of(obs_dim) {
177        return Err(RloxError::ShapeMismatch {
178            expected: format!("observations.len() divisible by obs_dim={obs_dim}"),
179            got: format!("observations.len()={}", observations.len()),
180        });
181    }
182    if goal_start + goal_dim > obs_dim {
183        return Err(RloxError::ShapeMismatch {
184            expected: format!("goal_start + goal_dim <= obs_dim={obs_dim}"),
185            got: format!("goal_start={goal_start}, goal_dim={goal_dim}"),
186        });
187    }
188
189    let n = observations.len() / obs_dim;
190    let mut potentials = Vec::with_capacity(n);
191
192    for i in 0..n {
193        let obs_start = i * obs_dim + goal_start;
194        let obs_slice = &observations[obs_start..obs_start + goal_dim];
195        let dist_sq: f64 = obs_slice
196            .iter()
197            .zip(goal.iter())
198            .map(|(&o, &g)| (o - g) * (o - g))
199            .sum();
200        potentials.push(-scale * dist_sq.sqrt());
201    }
202
203    Ok(potentials)
204}
205
#[cfg(test)]
mod tests {
    use super::*;

    // Hand-computed check of r' = r + gamma * phi_next - phi on two steps.
    #[test]
    fn test_pbrs_known_values() {
        let rewards = &[1.0, 2.0];
        let phi = &[0.5, 0.3];
        let phi_next = &[0.3, 0.8];
        let gamma = 0.99;
        let dones = &[0.0, 0.0];
        let result = shape_rewards_pbrs(rewards, phi, phi_next, gamma, dones).unwrap();
        // r'[0] = 1.0 + 0.99*0.3 - 0.5 = 0.797
        assert!(
            (result[0] - 0.797).abs() < 1e-10,
            "expected 0.797, got {}",
            result[0]
        );
        // r'[1] = 2.0 + 0.99*0.8 - 0.3 = 2.492
        assert!(
            (result[1] - 2.492).abs() < 1e-10,
            "expected 2.492, got {}",
            result[1]
        );
    }

    // A done flag must suppress the shaping term entirely for that step.
    #[test]
    fn test_pbrs_done_resets_potential() {
        let rewards = &[1.0, 2.0];
        let phi = &[0.5, 0.3];
        let phi_next = &[0.3, 0.8];
        let gamma = 0.99;
        let dones = &[0.0, 1.0];
        let result = shape_rewards_pbrs(rewards, phi, phi_next, gamma, dones).unwrap();
        assert!(
            (result[1] - 2.0).abs() < 1e-10,
            "done step should return raw reward, got {}",
            result[1]
        );
    }

    // Phi == 0 everywhere makes the transform the identity.
    #[test]
    fn test_pbrs_zero_potentials_no_change() {
        let rewards = &[1.0, 2.0, 3.0];
        let phi = &[0.0, 0.0, 0.0];
        let phi_next = &[0.0, 0.0, 0.0];
        let dones = &[0.0, 0.0, 0.0];
        let result = shape_rewards_pbrs(rewards, phi, phi_next, 0.99, dones).unwrap();
        for i in 0..3 {
            assert!(
                (result[i] - rewards[i]).abs() < 1e-10,
                "zero potentials should not change reward"
            );
        }
    }

    // The Ng et al. invariant: total shaping reduces to a telescoping-style sum
    // of per-step potential differences when no done flags interrupt the episode.
    #[test]
    fn test_pbrs_preserves_optimal_policy() {
        // For a complete episode, sum(shaped) = sum(raw) + gamma*Phi(s_T) - Phi(s_0)
        // With done at the last step, the last step gets raw reward
        let rewards = &[1.0, 2.0, 3.0, 4.0, 5.0];
        let phi = &[0.5, 0.3, 0.8, 0.1, 0.9];
        let phi_next = &[0.3, 0.8, 0.1, 0.9, 0.0];
        let gamma = 0.99;
        // No dones (all within episode)
        let dones = &[0.0, 0.0, 0.0, 0.0, 0.0];
        let shaped = shape_rewards_pbrs(rewards, phi, phi_next, gamma, dones).unwrap();
        let sum_raw: f64 = rewards.iter().sum();
        let sum_shaped: f64 = shaped.iter().sum();
        // Telescoping sum: sum(gamma*phi_next[i] - phi[i]) for non-done steps
        let shaping_sum: f64 = (0..5).map(|i| gamma * phi_next[i] - phi[i]).sum::<f64>();
        assert!(
            (sum_shaped - (sum_raw + shaping_sum)).abs() < 1e-10,
            "sum_shaped={sum_shaped}, expected {}",
            sum_raw + shaping_sum
        );
    }

    // Mismatched slice lengths must be rejected rather than silently truncated.
    #[test]
    fn test_pbrs_length_mismatch_errors() {
        let result = shape_rewards_pbrs(
            &[1.0, 2.0, 3.0],
            &[0.0, 0.0],
            &[0.0, 0.0],
            0.99,
            &[0.0, 0.0],
        );
        assert!(matches!(result, Err(RloxError::ShapeMismatch { .. })));
    }

    // Potential is monotone in distance: closer to goal => less negative.
    #[test]
    fn test_goal_distance_decreasing_near_goal() {
        let goal = &[0.0, 0.0];
        // obs far from goal, obs close to goal
        let observations = &[
            1.0, 0.0, 1.0, 0.0, // obs_dim=4, goal at [0..2], far (dist=1)
            0.1, 0.0, 0.1, 0.0, // close (dist=0.1)
        ];
        let potentials =
            compute_goal_distance_potentials(observations, goal, 4, 0, 2, 1.0).unwrap();
        assert!(
            potentials[1] > potentials[0],
            "closer obs should have less negative potential: far={}, close={}",
            potentials[0],
            potentials[1]
        );
    }

    // Exactly at the goal, ||s - g|| = 0 so Phi = 0.
    #[test]
    fn test_goal_distance_at_goal_is_zero() {
        let goal = &[3.0, 4.0];
        let observations = &[3.0, 4.0, 0.0, 0.0]; // obs_dim=4, goal at [0..2]
        let potentials =
            compute_goal_distance_potentials(observations, goal, 4, 0, 2, 1.0).unwrap();
        assert!(
            potentials[0].abs() < 1e-10,
            "at goal, potential should be 0, got {}",
            potentials[0]
        );
    }

    // Phi scales linearly with the `scale` argument.
    #[test]
    fn test_goal_distance_scale_factor() {
        let goal = &[0.0];
        let observations = &[1.0, 0.0]; // obs_dim=2, goal at [0..1]
        let phi_1 = compute_goal_distance_potentials(observations, goal, 2, 0, 1, 1.0).unwrap();
        let phi_2 = compute_goal_distance_potentials(observations, goal, 2, 0, 1, 2.0).unwrap();
        assert!(
            (phi_2[0] - 2.0 * phi_1[0]).abs() < 1e-10,
            "scale=2 should double potential: phi_1={}, phi_2={}",
            phi_1[0],
            phi_2[0]
        );
    }

    // A goal vector whose length disagrees with goal_dim must error.
    #[test]
    fn test_goal_distance_validates_dimensions() {
        let result = compute_goal_distance_potentials(
            &[1.0, 2.0, 3.0, 4.0],
            &[0.0, 0.0, 0.0], // goal_dim=3 but we say 2
            4,
            0,
            2,
            1.0,
        );
        assert!(matches!(result, Err(RloxError::ShapeMismatch { .. })));
    }

    // RewardTransform must remain object-safe (used as Box<dyn RewardTransform>).
    #[test]
    fn test_trait_object_safety() {
        let transform: Box<dyn RewardTransform> = Box::new(PBRSTransform);
        assert_eq!(transform.name(), "PBRS");
    }

    // Property-based tests (proptest crate): structural invariants over
    // randomized input sizes rather than fixed examples.
    mod proptests {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            // Output length always equals input length.
            #[test]
            fn prop_pbrs_length_matches_input(n in 1usize..500) {
                let rewards: Vec<f64> = (0..n).map(|i| i as f64 * 0.1).collect();
                let phi: Vec<f64> = vec![0.5; n];
                let phi_next: Vec<f64> = vec![0.3; n];
                let dones: Vec<f64> = vec![0.0; n];
                let result = shape_rewards_pbrs(&rewards, &phi, &phi_next, 0.99, &dones).unwrap();
                prop_assert_eq!(result.len(), n);
            }

            // With gamma = 0 the phi_next term vanishes: r' = r - phi.
            #[test]
            fn prop_pbrs_zero_gamma_no_future(n in 1usize..100) {
                let rewards: Vec<f64> = (0..n).map(|i| i as f64).collect();
                let phi: Vec<f64> = (0..n).map(|i| i as f64 * 0.5).collect();
                let phi_next: Vec<f64> = (0..n).map(|i| i as f64 * 0.3).collect();
                let dones: Vec<f64> = vec![0.0; n];
                let result = shape_rewards_pbrs(&rewards, &phi, &phi_next, 0.0, &dones).unwrap();
                for i in 0..n {
                    let expected = rewards[i] - phi[i]; // gamma=0 -> no phi_next term
                    prop_assert!(
                        (result[i] - expected).abs() < 1e-10,
                        "index {i}: got {}, expected {expected}",
                        result[i]
                    );
                }
            }

            // -scale * ||.|| with positive scale is never positive.
            #[test]
            fn prop_goal_distance_non_positive(
                n in 1usize..50,
                goal_dim in 1usize..4,
            ) {
                let obs_dim = goal_dim + 2;
                let obs: Vec<f64> = (0..(n * obs_dim)).map(|i| i as f64 * 0.1).collect();
                let goal: Vec<f64> = vec![0.0; goal_dim];
                let potentials = compute_goal_distance_potentials(
                    &obs, &goal, obs_dim, 0, goal_dim, 1.0
                ).unwrap();
                for (i, &p) in potentials.iter().enumerate() {
                    prop_assert!(p <= 0.0 + 1e-10,
                        "potential[{i}] = {p} should be <= 0 for positive scale");
                }
            }
        }
    }
}