rlox_core/buffer/
mixed.rs

//! Mixed sampling from two replay buffers.
//!
//! Used for offline-to-online fine-tuning: each batch draws a configurable
//! fraction of its transitions from an offline dataset and the remainder
//! from an online replay buffer.
5
6use crate::error::RloxError;
7
8use super::ringbuf::{ReplayBuffer, SampledBatch};
9
10/// Sample a mixed batch from two buffers.
11///
12/// Draws `ceil(batch_size * ratio)` from `buffer_a` and the remainder from
13/// `buffer_b`. Returns a merged `SampledBatch`.
14///
15/// # Arguments
16/// * `buffer_a` - first buffer (e.g., offline dataset)
17/// * `buffer_b` - second buffer (e.g., online replay)
18/// * `ratio` - fraction of samples from buffer_a (0.0 to 1.0)
19/// * `batch_size` - total number of transitions to sample
20/// * `seed` - RNG seed
21pub fn sample_mixed(
22    buffer_a: &ReplayBuffer,
23    buffer_b: &ReplayBuffer,
24    ratio: f64,
25    batch_size: usize,
26    seed: u64,
27) -> Result<SampledBatch, RloxError> {
28    if !(0.0..=1.0).contains(&ratio) {
29        return Err(RloxError::BufferError(format!(
30            "ratio must be in [0.0, 1.0], got {ratio}"
31        )));
32    }
33
34    if buffer_a.is_empty() && ratio > 0.0 {
35        return Err(RloxError::BufferError(
36            "buffer_a is empty but ratio > 0".into(),
37        ));
38    }
39    if buffer_b.is_empty() && ratio < 1.0 {
40        return Err(RloxError::BufferError(
41            "buffer_b is empty but ratio < 1.0".into(),
42        ));
43    }
44
45    let n_from_a = ((batch_size as f64) * ratio).ceil() as usize;
46    let n_from_a = n_from_a.min(batch_size); // safety clamp
47    let n_from_b = batch_size - n_from_a;
48
49    // Use different seeds for the two buffers to avoid correlation
50    let seed_a = seed;
51    let seed_b = seed.wrapping_add(0x9E37_79B9_7F4A_7C15); // golden-ratio derived offset
52
53    let batch_a = if n_from_a > 0 {
54        Some(buffer_a.sample(n_from_a, seed_a)?)
55    } else {
56        None
57    };
58
59    let batch_b = if n_from_b > 0 {
60        Some(buffer_b.sample(n_from_b, seed_b)?)
61    } else {
62        None
63    };
64
65    // Determine dimensions from whichever batch we have
66    let (obs_dim, act_dim) = match (&batch_a, &batch_b) {
67        (Some(a), Some(b)) => {
68            if a.obs_dim != b.obs_dim || a.act_dim != b.act_dim {
69                return Err(RloxError::ShapeMismatch {
70                    expected: format!("obs_dim={}, act_dim={}", a.obs_dim, a.act_dim),
71                    got: format!("obs_dim={}, act_dim={}", b.obs_dim, b.act_dim),
72                });
73            }
74            (a.obs_dim, a.act_dim)
75        }
76        (Some(a), None) => (a.obs_dim, a.act_dim),
77        (None, Some(b)) => (b.obs_dim, b.act_dim),
78        (None, None) => {
79            return Err(RloxError::BufferError(
80                "batch_size is 0 or both buffers empty".into(),
81            ));
82        }
83    };
84
85    let mut merged = SampledBatch::with_capacity(batch_size, obs_dim, act_dim);
86
87    if let Some(a) = batch_a {
88        merged.observations.extend_from_slice(&a.observations);
89        merged
90            .next_observations
91            .extend_from_slice(&a.next_observations);
92        merged.actions.extend_from_slice(&a.actions);
93        merged.rewards.extend_from_slice(&a.rewards);
94        merged.terminated.extend_from_slice(&a.terminated);
95        merged.truncated.extend_from_slice(&a.truncated);
96    }
97
98    if let Some(b) = batch_b {
99        merged.observations.extend_from_slice(&b.observations);
100        merged
101            .next_observations
102            .extend_from_slice(&b.next_observations);
103        merged.actions.extend_from_slice(&b.actions);
104        merged.rewards.extend_from_slice(&b.rewards);
105        merged.terminated.extend_from_slice(&b.terminated);
106        merged.truncated.extend_from_slice(&b.truncated);
107    }
108
109    merged.batch_size = batch_size;
110
111    Ok(merged)
112}
113
#[cfg(test)]
mod tests {
    use super::*;
    use crate::buffer::sample_record;

    /// Build a buffer of `n` records made by `sample_record(obs_dim)` with the
    /// reward overwritten to `reward`, so tests can tell which source buffer a
    /// sampled row came from.
    fn make_buffer_with_reward(
        capacity: usize,
        obs_dim: usize,
        n: usize,
        reward: f32,
    ) -> ReplayBuffer {
        let mut buf = ReplayBuffer::new(capacity, obs_dim, 1);
        for _ in 0..n {
            let mut r = sample_record(obs_dim);
            r.reward = reward;
            buf.push(r).unwrap();
        }
        buf
    }

    /// Count rows whose reward equals `buffer_a`'s marker (1.0) and
    /// `buffer_b`'s marker (2.0), respectively.
    fn count_sources(batch: &SampledBatch) -> (usize, usize) {
        let from_a = batch.rewards.iter().filter(|&&r| r == 1.0).count();
        let from_b = batch.rewards.iter().filter(|&&r| r == 2.0).count();
        (from_a, from_b)
    }

    #[test]
    fn test_mixed_ratio_one_all_from_a() {
        let buf_a = make_buffer_with_reward(100, 4, 50, 1.0);
        let buf_b = make_buffer_with_reward(100, 4, 50, 2.0);
        let batch = sample_mixed(&buf_a, &buf_b, 1.0, 32, 42).unwrap();
        assert_eq!(batch.batch_size, 32);
        for &r in &batch.rewards {
            assert_eq!(r, 1.0, "all samples should come from buffer_a");
        }
    }

    #[test]
    fn test_mixed_ratio_zero_all_from_b() {
        let buf_a = make_buffer_with_reward(100, 4, 50, 1.0);
        let buf_b = make_buffer_with_reward(100, 4, 50, 2.0);
        let batch = sample_mixed(&buf_a, &buf_b, 0.0, 32, 42).unwrap();
        assert_eq!(batch.batch_size, 32);
        for &r in &batch.rewards {
            assert_eq!(r, 2.0, "all samples should come from buffer_b");
        }
    }

    #[test]
    fn test_mixed_ratio_half() {
        let buf_a = make_buffer_with_reward(100, 4, 50, 1.0);
        let buf_b = make_buffer_with_reward(100, 4, 50, 2.0);
        let batch = sample_mixed(&buf_a, &buf_b, 0.5, 32, 42).unwrap();
        assert_eq!(batch.batch_size, 32);
        assert_eq!(count_sources(&batch), (16, 16));
    }

    #[test]
    fn test_mixed_ratio_quarter() {
        let buf_a = make_buffer_with_reward(100, 4, 50, 1.0);
        let buf_b = make_buffer_with_reward(100, 4, 50, 2.0);
        let batch = sample_mixed(&buf_a, &buf_b, 0.25, 32, 42).unwrap();
        assert_eq!(count_sources(&batch), (8, 24));
    }

    #[test]
    fn test_mixed_fractional_share_rounds_up_for_buffer_a() {
        // 32 * 0.1 = 3.2 -> ceil -> 4 rows from buffer_a.
        let buf_a = make_buffer_with_reward(100, 4, 50, 1.0);
        let buf_b = make_buffer_with_reward(100, 4, 50, 2.0);
        let batch = sample_mixed(&buf_a, &buf_b, 0.1, 32, 42).unwrap();
        assert_eq!(count_sources(&batch), (4, 28));
    }

    #[test]
    fn test_mixed_deterministic() {
        let buf_a = make_buffer_with_reward(100, 4, 50, 1.0);
        let buf_b = make_buffer_with_reward(100, 4, 50, 2.0);
        let b1 = sample_mixed(&buf_a, &buf_b, 0.5, 32, 42).unwrap();
        let b2 = sample_mixed(&buf_a, &buf_b, 0.5, 32, 42).unwrap();
        assert_eq!(b1.observations, b2.observations);
        assert_eq!(b1.rewards, b2.rewards);
    }

    #[test]
    fn test_mixed_batch_shape() {
        let buf_a = make_buffer_with_reward(100, 4, 50, 1.0);
        let buf_b = make_buffer_with_reward(100, 4, 50, 2.0);
        let batch = sample_mixed(&buf_a, &buf_b, 0.5, 32, 42).unwrap();
        assert_eq!(batch.observations.len(), 32 * 4);
        assert_eq!(batch.next_observations.len(), 32 * 4);
        assert_eq!(batch.actions.len(), 32);
        assert_eq!(batch.rewards.len(), 32);
        assert_eq!(batch.terminated.len(), 32);
        assert_eq!(batch.truncated.len(), 32);
    }

    #[test]
    fn test_mixed_empty_buffer_errors() {
        let buf_a = ReplayBuffer::new(100, 4, 1);
        let buf_b = make_buffer_with_reward(100, 4, 50, 2.0);
        let result = sample_mixed(&buf_a, &buf_b, 0.5, 32, 42);
        assert!(result.is_err());
    }

    #[test]
    fn test_mixed_empty_buffer_b_errors_when_needed() {
        let buf_a = make_buffer_with_reward(100, 4, 50, 1.0);
        let buf_b = ReplayBuffer::new(100, 4, 1);
        let result = sample_mixed(&buf_a, &buf_b, 0.5, 32, 42);
        assert!(result.is_err());
    }

    #[test]
    fn test_mixed_ratio_one_allows_empty_buffer_b() {
        let buf_a = make_buffer_with_reward(100, 4, 50, 1.0);
        let buf_b = ReplayBuffer::new(100, 4, 1);
        let batch = sample_mixed(&buf_a, &buf_b, 1.0, 32, 42).unwrap();
        assert_eq!(count_sources(&batch), (32, 0));
    }

    #[test]
    fn test_mixed_ratio_zero_allows_empty_buffer_a() {
        let buf_a = ReplayBuffer::new(100, 4, 1);
        let buf_b = make_buffer_with_reward(100, 4, 50, 2.0);
        let batch = sample_mixed(&buf_a, &buf_b, 0.0, 32, 42).unwrap();
        assert_eq!(count_sources(&batch), (0, 32));
    }

    #[test]
    fn test_mixed_zero_batch_size_errors() {
        let buf_a = make_buffer_with_reward(100, 4, 50, 1.0);
        let buf_b = make_buffer_with_reward(100, 4, 50, 2.0);
        assert!(sample_mixed(&buf_a, &buf_b, 0.5, 0, 42).is_err());
    }

    #[test]
    fn test_mixed_validates_ratio_range() {
        let buf_a = make_buffer_with_reward(100, 4, 50, 1.0);
        let buf_b = make_buffer_with_reward(100, 4, 50, 2.0);
        let result = sample_mixed(&buf_a, &buf_b, 1.5, 32, 42);
        assert!(result.is_err());

        let result2 = sample_mixed(&buf_a, &buf_b, -0.1, 32, 42);
        assert!(result2.is_err());

        let result3 = sample_mixed(&buf_a, &buf_b, f64::NAN, 32, 42);
        assert!(result3.is_err());
    }

    #[test]
    fn test_mixed_obs_dim_mismatch_errors() {
        let buf_a = make_buffer_with_reward(100, 4, 50, 1.0);
        let buf_b = make_buffer_with_reward(100, 3, 50, 2.0);
        let result = sample_mixed(&buf_a, &buf_b, 0.5, 32, 42);
        assert!(matches!(result, Err(RloxError::ShapeMismatch { .. })));
    }

    mod proptests {
        use super::*;
        use crate::buffer::sample_record;
        use proptest::prelude::*;

        proptest! {
            #[test]
            fn prop_mixed_batch_size_correct(
                batch_size in 1usize..50,
                ratio in 0.0f64..1.0,
            ) {
                let mut buf_a = ReplayBuffer::new(200, 4, 1);
                let mut buf_b = ReplayBuffer::new(200, 4, 1);
                for _ in 0..100 {
                    buf_a.push(sample_record(4)).unwrap();
                    buf_b.push(sample_record(4)).unwrap();
                }
                let batch = sample_mixed(&buf_a, &buf_b, ratio, batch_size, 42).unwrap();
                prop_assert_eq!(batch.batch_size, batch_size);
                prop_assert_eq!(batch.rewards.len(), batch_size);
                prop_assert_eq!(batch.observations.len(), batch_size * 4);
            }
        }
    }
}