rlox_core/buffer/her.rs

1//! Hindsight Experience Replay (HER) buffer.
2//!
3//! Stores transitions with goal information and performs goal relabeling
4//! during sampling (Andrychowicz et al., 2017). Supports Final, Future(k),
5//! and Episode relabeling strategies.
6
7use rand::Rng;
8use rand::SeedableRng;
9use rand_chacha::ChaCha8Rng;
10
11use crate::error::RloxError;
12
13use super::episode::{EpisodeMeta, EpisodeTracker};
14use super::ringbuf::{ReplayBuffer, SampledBatch};
15
/// HER goal relabeling strategy.
///
/// Determines which achieved state within an episode is substituted for the
/// desired goal when a transition is relabeled during sampling.
#[derive(Debug, Clone, Copy)]
pub enum HERStrategy {
    /// Replace goal with the final state achieved in the episode.
    Final,
    /// Replace goal with a future state sampled uniformly from the remainder.
    Future {
        /// Number of relabeled goals per original transition. Default: 4.
        k: usize,
    },
    /// Replace goal with a random state from the episode.
    Episode,
}
29
30impl Default for HERStrategy {
31    fn default() -> Self {
32        HERStrategy::Future { k: 4 }
33    }
34}
35
/// Hindsight Experience Replay buffer.
///
/// Stores transitions with goal information and performs goal relabeling
/// during sampling. The obs vector layout is:
/// `[obs_core | achieved_goal | desired_goal | ...]`
#[derive(Debug)]
pub struct HERBuffer {
    /// Underlying ring buffer holding the raw transitions.
    buffer: ReplayBuffer,
    /// Tracks episode boundaries so relabeling can sample within an episode.
    tracker: EpisodeTracker,
    /// Full observation dimension (core + achieved goal + desired goal).
    obs_dim: usize,
    /// Action dimension.
    act_dim: usize,
    /// Dimension of a single goal vector.
    goal_dim: usize,
    /// Index within obs where the achieved-goal slice starts.
    achieved_goal_start: usize,
    /// Index within obs where the desired-goal slice starts.
    desired_goal_start: usize,
    /// Maximum number of transitions; also the ring-buffer modulus.
    capacity: usize,
    /// Relabeling strategy used at sampling time.
    strategy: HERStrategy,
    /// Tolerance for the sparse goal-distance reward.
    goal_tolerance: f32,
}
54
55impl HERBuffer {
56    /// Create a new HER buffer.
57    ///
58    /// # Arguments
59    /// * `capacity` - maximum transitions
60    /// * `obs_dim` - full observation dimension (includes goal components)
61    /// * `act_dim` - action dimension
62    /// * `goal_dim` - goal vector dimension
63    /// * `achieved_goal_start` - index within obs where achieved goal starts
64    /// * `desired_goal_start` - index within obs where desired goal starts
65    /// * `strategy` - relabeling strategy
66    /// * `goal_tolerance` - tolerance for sparse reward computation
67    pub fn new(
68        capacity: usize,
69        obs_dim: usize,
70        act_dim: usize,
71        goal_dim: usize,
72        achieved_goal_start: usize,
73        desired_goal_start: usize,
74        strategy: HERStrategy,
75        goal_tolerance: f32,
76    ) -> Self {
77        Self {
78            buffer: ReplayBuffer::new(capacity, obs_dim, act_dim),
79            tracker: EpisodeTracker::new(capacity),
80            obs_dim,
81            act_dim,
82            goal_dim,
83            achieved_goal_start,
84            desired_goal_start,
85            capacity,
86            strategy,
87            goal_tolerance,
88        }
89    }
90
91    /// Push a single transition, notifying the episode tracker.
92    pub fn push_slices(
93        &mut self,
94        obs: &[f32],
95        next_obs: &[f32],
96        action: &[f32],
97        reward: f32,
98        terminated: bool,
99        truncated: bool,
100    ) -> Result<(), RloxError> {
101        let write_pos = self.buffer.write_pos();
102        let was_full = self.buffer.len() == self.capacity;
103
104        if was_full {
105            self.tracker.invalidate_overwritten(write_pos, 1);
106        }
107
108        self.buffer
109            .push_slices(obs, next_obs, action, reward, terminated, truncated)?;
110
111        let done = terminated || truncated;
112        self.tracker.notify_push(write_pos, done);
113
114        Ok(())
115    }
116
117    /// Sample a batch with HER relabeling.
118    ///
119    /// `her_ratio` controls the fraction of samples that get relabeled goals.
120    /// The remaining samples use their original goals.
121    pub fn sample_with_relabeling(
122        &self,
123        batch_size: usize,
124        her_ratio: f32,
125        seed: u64,
126    ) -> Result<SampledBatch, RloxError> {
127        if self.buffer.is_empty() {
128            return Err(RloxError::BufferError("buffer is empty".into()));
129        }
130
131        let episodes = self.tracker.episodes();
132        let complete: Vec<usize> = episodes
133            .iter()
134            .enumerate()
135            .filter(|(_, ep)| ep.complete)
136            .map(|(i, _)| i)
137            .collect();
138
139        if complete.is_empty() {
140            return Err(RloxError::BufferError(
141                "no complete episodes for HER relabeling".into(),
142            ));
143        }
144
145        let mut rng = ChaCha8Rng::seed_from_u64(seed);
146        let n_relabeled = ((batch_size as f32) * her_ratio).ceil() as usize;
147        let n_original = batch_size - n_relabeled;
148
149        let mut batch = SampledBatch::with_capacity(batch_size, self.obs_dim, self.act_dim);
150
151        // Sample original (unrelabeled) transitions
152        if n_original > 0 {
153            let original = self.buffer.sample(n_original, rng.random())?;
154            batch.observations.extend_from_slice(&original.observations);
155            batch
156                .next_observations
157                .extend_from_slice(&original.next_observations);
158            batch.actions.extend_from_slice(&original.actions);
159            batch.rewards.extend_from_slice(&original.rewards);
160            batch.terminated.extend_from_slice(&original.terminated);
161            batch.truncated.extend_from_slice(&original.truncated);
162        }
163
164        // Sample relabeled transitions
165        for _ in 0..n_relabeled {
166            // Pick a random complete episode
167            let ep_idx = complete[rng.random_range(0..complete.len())];
168            let ep = &episodes[ep_idx];
169
170            // Pick a random transition within the episode
171            let trans_offset = rng.random_range(0..ep.length);
172            let trans_idx = (ep.start + trans_offset) % self.capacity;
173
174            // Get the original transition
175            let (obs, next_obs, action, _reward, terminated, truncated) =
176                self.buffer.get(trans_idx);
177
178            // Compute the relabel index based on strategy
179            let relabel_offset = match self.strategy {
180                HERStrategy::Final => ep.length - 1,
181                HERStrategy::Future { .. } => {
182                    if trans_offset >= ep.length - 1 {
183                        // Already at the end, use the same position
184                        trans_offset
185                    } else {
186                        rng.random_range((trans_offset + 1)..ep.length)
187                    }
188                }
189                HERStrategy::Episode => rng.random_range(0..ep.length),
190            };
191            let relabel_idx = (ep.start + relabel_offset) % self.capacity;
192
193            // Get the achieved goal from the relabel transition's next_obs
194            let (_, relabel_next_obs, _, _, _, _) = self.buffer.get(relabel_idx);
195            let new_goal = &relabel_next_obs
196                [self.achieved_goal_start..self.achieved_goal_start + self.goal_dim];
197
198            // Create modified observation with new desired goal
199            let mut new_obs = obs.to_vec();
200            new_obs[self.desired_goal_start..self.desired_goal_start + self.goal_dim]
201                .copy_from_slice(new_goal);
202
203            let mut new_next_obs = next_obs.to_vec();
204            new_next_obs[self.desired_goal_start..self.desired_goal_start + self.goal_dim]
205                .copy_from_slice(new_goal);
206
207            // Compute new reward based on achieved goal in next_obs vs new desired goal
208            let achieved_in_next =
209                &next_obs[self.achieved_goal_start..self.achieved_goal_start + self.goal_dim];
210            let new_reward = sparse_goal_reward(achieved_in_next, new_goal, self.goal_tolerance);
211
212            batch.observations.extend_from_slice(&new_obs);
213            batch.next_observations.extend_from_slice(&new_next_obs);
214            batch.actions.extend_from_slice(action);
215            batch.rewards.push(new_reward);
216            batch.terminated.push(terminated);
217            batch.truncated.push(truncated);
218        }
219
220        batch.batch_size = batch_size;
221
222        Ok(batch)
223    }
224
225    /// Compute relabeling indices for a given episode and transition.
226    ///
227    /// Returns indices (offsets within the episode) to use as substitute goals.
228    pub fn compute_relabel_indices(
229        &self,
230        episode: &EpisodeMeta,
231        transition_offset: usize,
232        seed: u64,
233    ) -> Vec<usize> {
234        let mut rng = ChaCha8Rng::seed_from_u64(seed);
235        match self.strategy {
236            HERStrategy::Final => vec![episode.length - 1],
237            HERStrategy::Future { k } => {
238                if transition_offset >= episode.length - 1 {
239                    // At the last step, can only relabel with itself
240                    vec![transition_offset; k]
241                } else {
242                    (0..k)
243                        .map(|_| rng.random_range((transition_offset + 1)..episode.length))
244                        .collect()
245                }
246            }
247            HERStrategy::Episode => {
248                vec![rng.random_range(0..episode.length)]
249            }
250        }
251    }
252
253    /// Number of valid transitions currently stored.
254    pub fn len(&self) -> usize {
255        self.buffer.len()
256    }
257
258    /// Whether the buffer is empty.
259    pub fn is_empty(&self) -> bool {
260        self.buffer.is_empty()
261    }
262
263    /// Number of complete episodes currently tracked.
264    pub fn num_complete_episodes(&self) -> usize {
265        self.tracker.num_complete_episodes()
266    }
267}
268
/// Compute sparse goal-conditioned reward.
///
/// Returns `0.0` if `||achieved - desired||_2 < tolerance`, else `-1.0`.
///
/// Uses squared distance comparison to avoid a costly `sqrt`.
///
/// Both slices must have the same length; a mismatch is a caller bug and is
/// reported in debug builds (release builds keep the old truncating behavior).
#[inline]
pub fn sparse_goal_reward(achieved: &[f32], desired: &[f32], tolerance: f32) -> f32 {
    // `zip` silently truncates to the shorter slice; surface dimension
    // mismatches during development instead of computing a bogus distance.
    debug_assert_eq!(
        achieved.len(),
        desired.len(),
        "achieved and desired goals must have the same dimension"
    );
    let dist_sq: f32 = achieved
        .iter()
        .zip(desired.iter())
        .map(|(&a, &d)| (a - d) * (a - d))
        .sum();
    if dist_sq < tolerance * tolerance {
        0.0
    } else {
        -1.0
    }
}
287
#[cfg(test)]
mod tests {
    use super::*;

    /// Helper: build an obs vector with embedded achieved and desired goals.
    /// Layout: [core(2) | achieved_goal(goal_dim) | desired_goal(goal_dim)]
    fn make_obs(core: &[f32], achieved: &[f32], desired: &[f32]) -> Vec<f32> {
        let mut obs = Vec::with_capacity(core.len() + achieved.len() + desired.len());
        obs.extend_from_slice(core);
        obs.extend_from_slice(achieved);
        obs.extend_from_slice(desired);
        obs
    }

    /// Helper: HER buffer with a 2-dim core observation and the default
    /// `Future { k: 4 }` strategy.
    fn make_her_buffer(capacity: usize, goal_dim: usize) -> HERBuffer {
        let core_dim = 2;
        let obs_dim = core_dim + goal_dim * 2; // core + achieved + desired
        HERBuffer::new(
            capacity,
            obs_dim,
            1, // act_dim
            goal_dim,
            core_dim,               // achieved_goal_start
            core_dim + goal_dim,    // desired_goal_start
            HERStrategy::default(), // Future { k: 4 }
            0.05,                   // goal_tolerance
        )
    }

    /// Push an episode where the agent moves from origin toward a goal.
    /// The desired goal is fixed at 10.0 per dim; the achieved goal linearly
    /// approaches it, reaching it on the final step.
    fn push_goal_episode(buf: &mut HERBuffer, length: usize, goal_dim: usize) {
        let desired_goal = vec![10.0; goal_dim];
        for i in 0..length {
            let progress = (i as f32 + 1.0) / length as f32;
            let achieved = vec![10.0 * progress; goal_dim];
            let next_achieved = vec![10.0 * (progress + 1.0 / length as f32).min(1.0); goal_dim];
            let core = vec![progress, progress];

            let obs = make_obs(&core, &achieved, &desired_goal);
            let next_obs = make_obs(
                &[progress + 0.1, progress + 0.1],
                &next_achieved,
                &desired_goal,
            );
            let action = vec![0.0];
            let reward = -1.0; // sparse: not at goal yet
            let done = i == length - 1;

            buf.push_slices(&obs, &next_obs, &action, reward, done, false)
                .unwrap();
        }
    }

    #[test]
    fn test_her_new_is_empty() {
        let buf = make_her_buffer(100, 3);
        assert_eq!(buf.len(), 0);
        assert!(buf.is_empty());
    }

    #[test]
    fn test_her_push_increments() {
        let mut buf = make_her_buffer(100, 3);
        push_goal_episode(&mut buf, 5, 3);
        assert_eq!(buf.len(), 5);
        assert_eq!(buf.num_complete_episodes(), 1);
    }

    #[test]
    fn test_final_strategy_uses_last_state() {
        let goal_dim = 2;
        let core_dim = 2;
        let obs_dim = core_dim + goal_dim * 2;
        let mut buf = HERBuffer::new(
            100,
            obs_dim,
            1,
            goal_dim,
            core_dim,
            core_dim + goal_dim,
            HERStrategy::Final,
            0.05,
        );
        push_goal_episode(&mut buf, 5, goal_dim);

        let ep = &buf.tracker.episodes()[0];
        let indices = buf.compute_relabel_indices(ep, 2, 42);
        assert_eq!(indices.len(), 1);
        assert_eq!(indices[0], 4); // last step of episode (length 5)
    }

    #[test]
    fn test_future_strategy_picks_future_state() {
        let goal_dim = 2;
        let core_dim = 2;
        let obs_dim = core_dim + goal_dim * 2;
        let mut buf = HERBuffer::new(
            100,
            obs_dim,
            1,
            goal_dim,
            core_dim,
            core_dim + goal_dim,
            HERStrategy::Future { k: 4 },
            0.05,
        );
        push_goal_episode(&mut buf, 10, goal_dim);

        // All k sampled offsets must be strictly after the queried transition
        // and within the episode.
        let ep = &buf.tracker.episodes()[0];
        let indices = buf.compute_relabel_indices(ep, 3, 42);
        assert_eq!(indices.len(), 4);
        for &idx in &indices {
            assert!(
                idx > 3,
                "future index {idx} should be > transition offset 3"
            );
            assert!(idx < 10, "future index {idx} should be < episode length 10");
        }
    }

    #[test]
    fn test_episode_strategy_picks_any_state() {
        let goal_dim = 2;
        let core_dim = 2;
        let obs_dim = core_dim + goal_dim * 2;
        let mut buf = HERBuffer::new(
            100,
            obs_dim,
            1,
            goal_dim,
            core_dim,
            core_dim + goal_dim,
            HERStrategy::Episode,
            0.05,
        );
        push_goal_episode(&mut buf, 10, goal_dim);

        let ep = &buf.tracker.episodes()[0];
        let indices = buf.compute_relabel_indices(ep, 5, 42);
        assert_eq!(indices.len(), 1);
        assert!(indices[0] < 10);
    }

    #[test]
    fn test_sparse_goal_reward_achieved() {
        let achieved = [1.0, 2.0, 3.0];
        let desired = [1.0, 2.0, 3.0];
        assert_eq!(sparse_goal_reward(&achieved, &desired, 0.05), 0.0);
    }

    #[test]
    fn test_sparse_goal_reward_not_achieved() {
        let achieved = [1.0, 2.0, 3.0];
        let desired = [10.0, 20.0, 30.0];
        assert_eq!(sparse_goal_reward(&achieved, &desired, 0.05), -1.0);
    }

    #[test]
    fn test_relabel_indices_future_k4() {
        let goal_dim = 2;
        let core_dim = 2;
        let obs_dim = core_dim + goal_dim * 2;
        let mut buf = HERBuffer::new(
            100,
            obs_dim,
            1,
            goal_dim,
            core_dim,
            core_dim + goal_dim,
            HERStrategy::Future { k: 4 },
            0.05,
        );
        push_goal_episode(&mut buf, 10, goal_dim);

        let ep = &buf.tracker.episodes()[0];
        let indices = buf.compute_relabel_indices(ep, 3, 42);
        assert_eq!(indices.len(), 4);
        for &idx in &indices {
            assert!(idx > 3 && idx < 10);
        }
    }

    #[test]
    fn test_relabel_indices_deterministic() {
        let goal_dim = 2;
        let core_dim = 2;
        let obs_dim = core_dim + goal_dim * 2;
        let mut buf = HERBuffer::new(
            100,
            obs_dim,
            1,
            goal_dim,
            core_dim,
            core_dim + goal_dim,
            HERStrategy::Future { k: 4 },
            0.05,
        );
        push_goal_episode(&mut buf, 10, goal_dim);

        // Same seed must produce identical relabel indices.
        let ep = &buf.tracker.episodes()[0];
        let i1 = buf.compute_relabel_indices(ep, 3, 42);
        let i2 = buf.compute_relabel_indices(ep, 3, 42);
        assert_eq!(i1, i2);
    }

    #[test]
    fn test_her_sample_batch_shape() {
        let goal_dim = 3;
        let mut buf = make_her_buffer(100, goal_dim);
        push_goal_episode(&mut buf, 10, goal_dim);

        let batch = buf.sample_with_relabeling(8, 0.8, 42).unwrap();
        let obs_dim = 2 + goal_dim * 2;
        assert_eq!(batch.batch_size, 8);
        assert_eq!(batch.observations.len(), 8 * obs_dim);
        assert_eq!(batch.actions.len(), 8);
        assert_eq!(batch.rewards.len(), 8);
    }

    #[test]
    fn test_her_ratio_controls_relabeling() {
        let goal_dim = 2;
        let mut buf = make_her_buffer(200, goal_dim);
        // Push multiple episodes so we have enough data
        for _ in 0..10 {
            push_goal_episode(&mut buf, 10, goal_dim);
        }

        // With ratio=0.0, no relabeling => all rewards should be -1.0 (original)
        let batch = buf.sample_with_relabeling(32, 0.0, 42).unwrap();
        // All original rewards are -1.0
        for &r in &batch.rewards {
            assert_eq!(
                r, -1.0,
                "with ratio=0, all rewards should be original (-1.0)"
            );
        }
    }

    #[test]
    fn test_her_with_ring_wrap() {
        let goal_dim = 2;
        let mut buf = make_her_buffer(50, goal_dim);
        // Push 100 transitions (wraps around)
        for _ in 0..10 {
            push_goal_episode(&mut buf, 10, goal_dim);
        }
        assert_eq!(buf.len(), 50);
        // Should still be able to sample
        let result = buf.sample_with_relabeling(4, 0.8, 42);
        assert!(result.is_ok());
    }

    #[test]
    fn test_her_empty_buffer_errors() {
        let buf = make_her_buffer(100, 3);
        let result = buf.sample_with_relabeling(4, 0.8, 42);
        assert!(result.is_err());
    }

    mod proptests {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            // Relabel indices must always land inside the episode.
            #[test]
            fn prop_relabel_indices_in_range(
                ep_len in 2usize..20,
                trans_offset in 0usize..19,
            ) {
                let trans_offset = trans_offset.min(ep_len - 1);
                let goal_dim = 2;
                let core_dim = 2;
                let obs_dim = core_dim + goal_dim * 2;
                let buf = HERBuffer::new(
                    100, obs_dim, 1, goal_dim, core_dim, core_dim + goal_dim,
                    HERStrategy::Future { k: 4 }, 0.05,
                );
                let ep = EpisodeMeta { start: 0, length: ep_len, complete: true };
                let indices = buf.compute_relabel_indices(&ep, trans_offset, 42);
                for &idx in &indices {
                    prop_assert!(idx < ep_len, "index {idx} >= episode length {ep_len}");
                }
            }

            // With room to spare, Future indices must be strictly later than
            // the queried transition offset.
            #[test]
            fn prop_future_indices_strictly_future(
                ep_len in 3usize..20,
                trans_offset in 0usize..18,
            ) {
                let trans_offset = trans_offset.min(ep_len - 2); // ensure room for future
                let goal_dim = 2;
                let core_dim = 2;
                let obs_dim = core_dim + goal_dim * 2;
                let buf = HERBuffer::new(
                    100, obs_dim, 1, goal_dim, core_dim, core_dim + goal_dim,
                    HERStrategy::Future { k: 4 }, 0.05,
                );
                let ep = EpisodeMeta { start: 0, length: ep_len, complete: true };
                let indices = buf.compute_relabel_indices(&ep, trans_offset, 42);
                for &idx in &indices {
                    prop_assert!(idx > trans_offset,
                        "future index {idx} should be > offset {trans_offset}");
                }
            }

            // The sparse reward is binary by construction.
            #[test]
            fn prop_sparse_reward_binary(
                a0 in -10.0f32..10.0,
                a1 in -10.0f32..10.0,
                d0 in -10.0f32..10.0,
                d1 in -10.0f32..10.0,
            ) {
                let r = sparse_goal_reward(&[a0, a1], &[d0, d1], 0.05);
                prop_assert!(r == 0.0 || r == -1.0, "reward should be 0.0 or -1.0, got {r}");
            }
        }
    }
}