1use rand::Rng;
8use rand::SeedableRng;
9use rand_chacha::ChaCha8Rng;
10
11use crate::error::RloxError;
12
13use super::episode::{EpisodeMeta, EpisodeTracker};
14use super::ringbuf::{ReplayBuffer, SampledBatch};
15
/// Strategy for choosing which transition of an episode supplies the
/// substitute goal when a sampled transition is relabeled.
#[derive(Clone, Copy, Debug)]
pub enum HERStrategy {
    /// Always use the last transition of the episode.
    Final,
    /// Use a transition strictly after the sampled one in the same episode
    /// (the sampled transition itself when it is already the last one).
    ///
    /// `k` is the number of indices `HERBuffer::compute_relabel_indices`
    /// returns per transition.
    Future { k: usize },
    /// Use a uniformly chosen transition of the same episode.
    Episode,
}
29
30impl Default for HERStrategy {
31 fn default() -> Self {
32 HERStrategy::Future { k: 4 }
33 }
34}
35
36#[derive(Debug)]
42pub struct HERBuffer {
43 buffer: ReplayBuffer,
44 tracker: EpisodeTracker,
45 obs_dim: usize,
46 act_dim: usize,
47 goal_dim: usize,
48 achieved_goal_start: usize,
49 desired_goal_start: usize,
50 capacity: usize,
51 strategy: HERStrategy,
52 goal_tolerance: f32,
53}
54
55impl HERBuffer {
56 pub fn new(
68 capacity: usize,
69 obs_dim: usize,
70 act_dim: usize,
71 goal_dim: usize,
72 achieved_goal_start: usize,
73 desired_goal_start: usize,
74 strategy: HERStrategy,
75 goal_tolerance: f32,
76 ) -> Self {
77 Self {
78 buffer: ReplayBuffer::new(capacity, obs_dim, act_dim),
79 tracker: EpisodeTracker::new(capacity),
80 obs_dim,
81 act_dim,
82 goal_dim,
83 achieved_goal_start,
84 desired_goal_start,
85 capacity,
86 strategy,
87 goal_tolerance,
88 }
89 }
90
91 pub fn push_slices(
93 &mut self,
94 obs: &[f32],
95 next_obs: &[f32],
96 action: &[f32],
97 reward: f32,
98 terminated: bool,
99 truncated: bool,
100 ) -> Result<(), RloxError> {
101 let write_pos = self.buffer.write_pos();
102 let was_full = self.buffer.len() == self.capacity;
103
104 if was_full {
105 self.tracker.invalidate_overwritten(write_pos, 1);
106 }
107
108 self.buffer
109 .push_slices(obs, next_obs, action, reward, terminated, truncated)?;
110
111 let done = terminated || truncated;
112 self.tracker.notify_push(write_pos, done);
113
114 Ok(())
115 }
116
117 pub fn sample_with_relabeling(
122 &self,
123 batch_size: usize,
124 her_ratio: f32,
125 seed: u64,
126 ) -> Result<SampledBatch, RloxError> {
127 if self.buffer.is_empty() {
128 return Err(RloxError::BufferError("buffer is empty".into()));
129 }
130
131 let episodes = self.tracker.episodes();
132 let complete: Vec<usize> = episodes
133 .iter()
134 .enumerate()
135 .filter(|(_, ep)| ep.complete)
136 .map(|(i, _)| i)
137 .collect();
138
139 if complete.is_empty() {
140 return Err(RloxError::BufferError(
141 "no complete episodes for HER relabeling".into(),
142 ));
143 }
144
145 let mut rng = ChaCha8Rng::seed_from_u64(seed);
146 let n_relabeled = ((batch_size as f32) * her_ratio).ceil() as usize;
147 let n_original = batch_size - n_relabeled;
148
149 let mut batch = SampledBatch::with_capacity(batch_size, self.obs_dim, self.act_dim);
150
151 if n_original > 0 {
153 let original = self.buffer.sample(n_original, rng.random())?;
154 batch.observations.extend_from_slice(&original.observations);
155 batch
156 .next_observations
157 .extend_from_slice(&original.next_observations);
158 batch.actions.extend_from_slice(&original.actions);
159 batch.rewards.extend_from_slice(&original.rewards);
160 batch.terminated.extend_from_slice(&original.terminated);
161 batch.truncated.extend_from_slice(&original.truncated);
162 }
163
164 for _ in 0..n_relabeled {
166 let ep_idx = complete[rng.random_range(0..complete.len())];
168 let ep = &episodes[ep_idx];
169
170 let trans_offset = rng.random_range(0..ep.length);
172 let trans_idx = (ep.start + trans_offset) % self.capacity;
173
174 let (obs, next_obs, action, _reward, terminated, truncated) =
176 self.buffer.get(trans_idx);
177
178 let relabel_offset = match self.strategy {
180 HERStrategy::Final => ep.length - 1,
181 HERStrategy::Future { .. } => {
182 if trans_offset >= ep.length - 1 {
183 trans_offset
185 } else {
186 rng.random_range((trans_offset + 1)..ep.length)
187 }
188 }
189 HERStrategy::Episode => rng.random_range(0..ep.length),
190 };
191 let relabel_idx = (ep.start + relabel_offset) % self.capacity;
192
193 let (_, relabel_next_obs, _, _, _, _) = self.buffer.get(relabel_idx);
195 let new_goal = &relabel_next_obs
196 [self.achieved_goal_start..self.achieved_goal_start + self.goal_dim];
197
198 let mut new_obs = obs.to_vec();
200 new_obs[self.desired_goal_start..self.desired_goal_start + self.goal_dim]
201 .copy_from_slice(new_goal);
202
203 let mut new_next_obs = next_obs.to_vec();
204 new_next_obs[self.desired_goal_start..self.desired_goal_start + self.goal_dim]
205 .copy_from_slice(new_goal);
206
207 let achieved_in_next =
209 &next_obs[self.achieved_goal_start..self.achieved_goal_start + self.goal_dim];
210 let new_reward = sparse_goal_reward(achieved_in_next, new_goal, self.goal_tolerance);
211
212 batch.observations.extend_from_slice(&new_obs);
213 batch.next_observations.extend_from_slice(&new_next_obs);
214 batch.actions.extend_from_slice(action);
215 batch.rewards.push(new_reward);
216 batch.terminated.push(terminated);
217 batch.truncated.push(truncated);
218 }
219
220 batch.batch_size = batch_size;
221
222 Ok(batch)
223 }
224
225 pub fn compute_relabel_indices(
229 &self,
230 episode: &EpisodeMeta,
231 transition_offset: usize,
232 seed: u64,
233 ) -> Vec<usize> {
234 let mut rng = ChaCha8Rng::seed_from_u64(seed);
235 match self.strategy {
236 HERStrategy::Final => vec![episode.length - 1],
237 HERStrategy::Future { k } => {
238 if transition_offset >= episode.length - 1 {
239 vec![transition_offset; k]
241 } else {
242 (0..k)
243 .map(|_| rng.random_range((transition_offset + 1)..episode.length))
244 .collect()
245 }
246 }
247 HERStrategy::Episode => {
248 vec![rng.random_range(0..episode.length)]
249 }
250 }
251 }
252
253 pub fn len(&self) -> usize {
255 self.buffer.len()
256 }
257
258 pub fn is_empty(&self) -> bool {
260 self.buffer.is_empty()
261 }
262
263 pub fn num_complete_episodes(&self) -> usize {
265 self.tracker.num_complete_episodes()
266 }
267}
268
/// Sparse goal-reaching reward.
///
/// Returns `0.0` when the squared Euclidean distance between `achieved` and
/// `desired` is strictly below `tolerance * tolerance`, and `-1.0` otherwise.
/// Only the leading elements common to both slices are compared.
#[inline]
pub fn sparse_goal_reward(achieved: &[f32], desired: &[f32], tolerance: f32) -> f32 {
    let squared_distance = achieved
        .iter()
        .zip(desired)
        .map(|(a, d)| {
            let delta = a - d;
            delta * delta
        })
        .sum::<f32>();
    let threshold = tolerance * tolerance;

    if squared_distance < threshold {
        return 0.0;
    }
    -1.0
}
287
#[cfg(test)]
mod tests {
    use super::*;

    /// Number of non-goal elements at the front of every test observation.
    const CORE_DIM: usize = 2;

    /// Builds a flat observation laid out as `[core | achieved | desired]`.
    fn make_obs(core: &[f32], achieved: &[f32], desired: &[f32]) -> Vec<f32> {
        [core, achieved, desired].concat()
    }

    /// Buffer with the test observation layout, one action dimension,
    /// tolerance 0.05 and the given relabeling strategy.
    fn buffer_with_strategy(capacity: usize, goal_dim: usize, strategy: HERStrategy) -> HERBuffer {
        HERBuffer::new(
            capacity,
            CORE_DIM + 2 * goal_dim,
            1,
            goal_dim,
            CORE_DIM,
            CORE_DIM + goal_dim,
            strategy,
            0.05,
        )
    }

    /// Buffer using the default strategy.
    fn make_her_buffer(capacity: usize, goal_dim: usize) -> HERBuffer {
        buffer_with_strategy(capacity, goal_dim, HERStrategy::default())
    }

    /// Pushes one terminated episode of `length` steps whose achieved goal
    /// moves linearly towards the desired goal; every stored reward is -1.0.
    fn push_goal_episode(buf: &mut HERBuffer, length: usize, goal_dim: usize) {
        let desired_goal = vec![10.0; goal_dim];
        for step in 0..length {
            let progress = (step as f32 + 1.0) / length as f32;
            let achieved = vec![10.0 * progress; goal_dim];
            let next_achieved = vec![10.0 * (progress + 1.0 / length as f32).min(1.0); goal_dim];

            let obs = make_obs(&[progress, progress], &achieved, &desired_goal);
            let next_obs = make_obs(
                &[progress + 0.1, progress + 0.1],
                &next_achieved,
                &desired_goal,
            );
            let is_last = step == length - 1;

            buf.push_slices(&obs, &next_obs, &[0.0], -1.0, is_last, false)
                .unwrap();
        }
    }

    /// Pushes a single episode into a fresh 100-slot buffer (goal_dim 2) and
    /// returns the relabel indices for `offset` with seed 42.
    fn relabel_indices_for(
        strategy: HERStrategy,
        episode_len: usize,
        offset: usize,
    ) -> Vec<usize> {
        let goal_dim = 2;
        let mut buf = buffer_with_strategy(100, goal_dim, strategy);
        push_goal_episode(&mut buf, episode_len, goal_dim);

        let ep = &buf.tracker.episodes()[0];
        buf.compute_relabel_indices(ep, offset, 42)
    }

    #[test]
    fn test_her_new_is_empty() {
        let buf = make_her_buffer(100, 3);
        assert!(buf.is_empty());
        assert_eq!(buf.len(), 0);
    }

    #[test]
    fn test_her_push_increments() {
        let mut buf = make_her_buffer(100, 3);
        push_goal_episode(&mut buf, 5, 3);
        assert_eq!(buf.len(), 5);
        assert_eq!(buf.num_complete_episodes(), 1);
    }

    #[test]
    fn test_final_strategy_uses_last_state() {
        let indices = relabel_indices_for(HERStrategy::Final, 5, 2);
        assert_eq!(indices.len(), 1);
        // Last transition of a 5-step episode.
        assert_eq!(indices[0], 4);
    }

    #[test]
    fn test_future_strategy_picks_future_state() {
        let indices = relabel_indices_for(HERStrategy::Future { k: 4 }, 10, 3);
        assert_eq!(indices.len(), 4);
        for &idx in &indices {
            assert!(
                idx > 3,
                "future index {idx} should be > transition offset 3"
            );
            assert!(idx < 10, "future index {idx} should be < episode length 10");
        }
    }

    #[test]
    fn test_episode_strategy_picks_any_state() {
        let indices = relabel_indices_for(HERStrategy::Episode, 10, 5);
        assert_eq!(indices.len(), 1);
        assert!(indices[0] < 10);
    }

    #[test]
    fn test_sparse_goal_reward_achieved() {
        let goal = [1.0, 2.0, 3.0];
        assert_eq!(sparse_goal_reward(&goal, &[1.0, 2.0, 3.0], 0.05), 0.0);
    }

    #[test]
    fn test_sparse_goal_reward_not_achieved() {
        let achieved = [1.0, 2.0, 3.0];
        let desired = [10.0, 20.0, 30.0];
        assert_eq!(sparse_goal_reward(&achieved, &desired, 0.05), -1.0);
    }

    #[test]
    fn test_relabel_indices_future_k4() {
        let indices = relabel_indices_for(HERStrategy::Future { k: 4 }, 10, 3);
        assert_eq!(indices.len(), 4);
        for &idx in &indices {
            assert!(idx > 3 && idx < 10);
        }
    }

    #[test]
    fn test_relabel_indices_deterministic() {
        let goal_dim = 2;
        let mut buf = buffer_with_strategy(100, goal_dim, HERStrategy::Future { k: 4 });
        push_goal_episode(&mut buf, 10, goal_dim);

        // Same buffer, same episode, same seed: results must match.
        let ep = &buf.tracker.episodes()[0];
        let first = buf.compute_relabel_indices(ep, 3, 42);
        let second = buf.compute_relabel_indices(ep, 3, 42);
        assert_eq!(first, second);
    }

    #[test]
    fn test_her_sample_batch_shape() {
        let goal_dim = 3;
        let mut buf = make_her_buffer(100, goal_dim);
        push_goal_episode(&mut buf, 10, goal_dim);

        let batch = buf.sample_with_relabeling(8, 0.8, 42).unwrap();
        let obs_dim = 2 + goal_dim * 2;
        assert_eq!(batch.batch_size, 8);
        assert_eq!(batch.observations.len(), 8 * obs_dim);
        assert_eq!(batch.actions.len(), 8);
        assert_eq!(batch.rewards.len(), 8);
    }

    #[test]
    fn test_her_ratio_controls_relabeling() {
        let goal_dim = 2;
        let mut buf = make_her_buffer(200, goal_dim);
        (0..10).for_each(|_| push_goal_episode(&mut buf, 10, goal_dim));

        // Ratio 0 means nothing is relabeled, so stored rewards come back as-is.
        let batch = buf.sample_with_relabeling(32, 0.0, 42).unwrap();
        for &r in &batch.rewards {
            assert_eq!(
                r, -1.0,
                "with ratio=0, all rewards should be original (-1.0)"
            );
        }
    }

    #[test]
    fn test_her_with_ring_wrap() {
        let goal_dim = 2;
        let mut buf = make_her_buffer(50, goal_dim);
        // 100 transitions into 50 slots: the ring wraps exactly once.
        (0..10).for_each(|_| push_goal_episode(&mut buf, 10, goal_dim));

        assert_eq!(buf.len(), 50);
        let result = buf.sample_with_relabeling(4, 0.8, 42);
        assert!(result.is_ok());
    }

    #[test]
    fn test_her_empty_buffer_errors() {
        let buf = make_her_buffer(100, 3);
        let result = buf.sample_with_relabeling(4, 0.8, 42);
        assert!(result.is_err());
    }

    mod proptests {
        use super::*;
        use proptest::prelude::*;

        /// Empty 100-slot buffer (goal_dim 2) using `Future { k: 4 }`.
        fn future_buffer() -> HERBuffer {
            buffer_with_strategy(100, 2, HERStrategy::Future { k: 4 })
        }

        /// Complete episode starting at slot 0.
        fn episode_of(length: usize) -> EpisodeMeta {
            EpisodeMeta { start: 0, length, complete: true }
        }

        proptest! {
            #[test]
            fn prop_relabel_indices_in_range(
                ep_len in 2usize..20,
                trans_offset in 0usize..19,
            ) {
                let trans_offset = trans_offset.min(ep_len - 1);
                let buf = future_buffer();
                let ep = episode_of(ep_len);
                let indices = buf.compute_relabel_indices(&ep, trans_offset, 42);
                for &idx in &indices {
                    prop_assert!(idx < ep_len, "index {idx} >= episode length {ep_len}");
                }
            }

            #[test]
            fn prop_future_indices_strictly_future(
                ep_len in 3usize..20,
                trans_offset in 0usize..18,
            ) {
                // Keep at least one later transition available.
                let trans_offset = trans_offset.min(ep_len - 2);
                let buf = future_buffer();
                let ep = episode_of(ep_len);
                let indices = buf.compute_relabel_indices(&ep, trans_offset, 42);
                for &idx in &indices {
                    prop_assert!(idx > trans_offset,
                        "future index {idx} should be > offset {trans_offset}");
                }
            }

            #[test]
            fn prop_sparse_reward_binary(
                a0 in -10.0f32..10.0,
                a1 in -10.0f32..10.0,
                d0 in -10.0f32..10.0,
                d1 in -10.0f32..10.0,
            ) {
                let r = sparse_goal_reward(&[a0, a1], &[d0, d1], 0.05);
                prop_assert!(r == 0.0 || r == -1.0, "reward should be 0.0 or -1.0, got {r}");
            }
        }
    }
}