1use crate::error::RloxError;
7
8use super::ringbuf::{ReplayBuffer, SampledBatch};
9
10pub fn sample_mixed(
22 buffer_a: &ReplayBuffer,
23 buffer_b: &ReplayBuffer,
24 ratio: f64,
25 batch_size: usize,
26 seed: u64,
27) -> Result<SampledBatch, RloxError> {
28 if !(0.0..=1.0).contains(&ratio) {
29 return Err(RloxError::BufferError(format!(
30 "ratio must be in [0.0, 1.0], got {ratio}"
31 )));
32 }
33
34 if buffer_a.is_empty() && ratio > 0.0 {
35 return Err(RloxError::BufferError(
36 "buffer_a is empty but ratio > 0".into(),
37 ));
38 }
39 if buffer_b.is_empty() && ratio < 1.0 {
40 return Err(RloxError::BufferError(
41 "buffer_b is empty but ratio < 1.0".into(),
42 ));
43 }
44
45 let n_from_a = ((batch_size as f64) * ratio).ceil() as usize;
46 let n_from_a = n_from_a.min(batch_size); let n_from_b = batch_size - n_from_a;
48
49 let seed_a = seed;
51 let seed_b = seed.wrapping_add(0x9E37_79B9_7F4A_7C15); let batch_a = if n_from_a > 0 {
54 Some(buffer_a.sample(n_from_a, seed_a)?)
55 } else {
56 None
57 };
58
59 let batch_b = if n_from_b > 0 {
60 Some(buffer_b.sample(n_from_b, seed_b)?)
61 } else {
62 None
63 };
64
65 let (obs_dim, act_dim) = match (&batch_a, &batch_b) {
67 (Some(a), Some(b)) => {
68 if a.obs_dim != b.obs_dim || a.act_dim != b.act_dim {
69 return Err(RloxError::ShapeMismatch {
70 expected: format!("obs_dim={}, act_dim={}", a.obs_dim, a.act_dim),
71 got: format!("obs_dim={}, act_dim={}", b.obs_dim, b.act_dim),
72 });
73 }
74 (a.obs_dim, a.act_dim)
75 }
76 (Some(a), None) => (a.obs_dim, a.act_dim),
77 (None, Some(b)) => (b.obs_dim, b.act_dim),
78 (None, None) => {
79 return Err(RloxError::BufferError(
80 "batch_size is 0 or both buffers empty".into(),
81 ));
82 }
83 };
84
85 let mut merged = SampledBatch::with_capacity(batch_size, obs_dim, act_dim);
86
87 if let Some(a) = batch_a {
88 merged.observations.extend_from_slice(&a.observations);
89 merged
90 .next_observations
91 .extend_from_slice(&a.next_observations);
92 merged.actions.extend_from_slice(&a.actions);
93 merged.rewards.extend_from_slice(&a.rewards);
94 merged.terminated.extend_from_slice(&a.terminated);
95 merged.truncated.extend_from_slice(&a.truncated);
96 }
97
98 if let Some(b) = batch_b {
99 merged.observations.extend_from_slice(&b.observations);
100 merged
101 .next_observations
102 .extend_from_slice(&b.next_observations);
103 merged.actions.extend_from_slice(&b.actions);
104 merged.rewards.extend_from_slice(&b.rewards);
105 merged.terminated.extend_from_slice(&b.terminated);
106 merged.truncated.extend_from_slice(&b.truncated);
107 }
108
109 merged.batch_size = batch_size;
110
111 Ok(merged)
112}
113
#[cfg(test)]
mod tests {
    use super::*;
    use crate::buffer::sample_record;

    /// Creates a buffer with action dimension 1 and pushes `n` sample records into it,
    /// each with its reward overwritten by `reward` so that the origin of a sampled
    /// transition can be recognised from its reward value.
    fn make_buffer_with_reward(
        capacity: usize,
        obs_dim: usize,
        n: usize,
        reward: f32,
    ) -> ReplayBuffer {
        let mut buffer = ReplayBuffer::new(capacity, obs_dim, 1);
        (0..n).for_each(|_| {
            let mut record = sample_record(obs_dim);
            record.reward = reward;
            buffer.push(record).unwrap();
        });
        buffer
    }

    /// Fixture used by most tests: two buffers (capacity 100, obs_dim 4, 50 records each)
    /// whose records carry reward 1.0 (first buffer) and 2.0 (second buffer).
    fn tagged_pair() -> (ReplayBuffer, ReplayBuffer) {
        let first = make_buffer_with_reward(100, 4, 50, 1.0);
        let second = make_buffer_with_reward(100, 4, 50, 2.0);
        (first, second)
    }

    /// With ratio 1.0 every reward in the batch must carry buffer A's tag.
    #[test]
    fn test_mixed_ratio_one_all_from_a() {
        let (buf_a, buf_b) = tagged_pair();
        let batch = sample_mixed(&buf_a, &buf_b, 1.0, 32, 42).unwrap();
        assert_eq!(batch.batch_size, 32);
        batch.rewards.iter().copied().for_each(|r| {
            assert_eq!(r, 1.0, "all samples should come from buffer_a");
        });
    }

    /// With ratio 0.0 every reward in the batch must carry buffer B's tag.
    #[test]
    fn test_mixed_ratio_zero_all_from_b() {
        let (buf_a, buf_b) = tagged_pair();
        let batch = sample_mixed(&buf_a, &buf_b, 0.0, 32, 42).unwrap();
        assert_eq!(batch.batch_size, 32);
        batch.rewards.iter().copied().for_each(|r| {
            assert_eq!(r, 2.0, "all samples should come from buffer_b");
        });
    }

    /// A ratio of 0.5 on a batch of 32 must yield exactly 16 samples per buffer.
    #[test]
    fn test_mixed_ratio_half() {
        let (buf_a, buf_b) = tagged_pair();
        let batch = sample_mixed(&buf_a, &buf_b, 0.5, 32, 42).unwrap();
        assert_eq!(batch.batch_size, 32);
        let tagged_a = batch.rewards.iter().filter(|&&r| r == 1.0).count();
        let tagged_b = batch.rewards.iter().filter(|&&r| r == 2.0).count();
        assert_eq!(tagged_a, 16);
        assert_eq!(tagged_b, 16);
    }

    /// Two calls with identical arguments must return identical data.
    #[test]
    fn test_mixed_deterministic() {
        let (buf_a, buf_b) = tagged_pair();
        let draw = || sample_mixed(&buf_a, &buf_b, 0.5, 32, 42).unwrap();
        let first = draw();
        let second = draw();
        assert_eq!(first.observations, second.observations);
        assert_eq!(first.rewards, second.rewards);
    }

    /// Every field of the merged batch must have the length implied by the batch size.
    #[test]
    fn test_mixed_batch_shape() {
        const BATCH: usize = 32;
        const OBS_DIM: usize = 4;

        let (buf_a, buf_b) = tagged_pair();
        let batch = sample_mixed(&buf_a, &buf_b, 0.5, BATCH, 42).unwrap();
        assert_eq!(batch.observations.len(), BATCH * OBS_DIM);
        assert_eq!(batch.next_observations.len(), BATCH * OBS_DIM);
        assert_eq!(batch.actions.len(), BATCH);
        assert_eq!(batch.rewards.len(), BATCH);
        assert_eq!(batch.terminated.len(), BATCH);
    }

    /// Requesting samples from an empty buffer A (ratio 0.5) must fail.
    #[test]
    fn test_mixed_empty_buffer_errors() {
        let empty_a = ReplayBuffer::new(100, 4, 1);
        let filled_b = make_buffer_with_reward(100, 4, 50, 2.0);
        let outcome = sample_mixed(&empty_a, &filled_b, 0.5, 32, 42);
        assert!(outcome.is_err());
    }

    /// Ratios above 1.0 and below 0.0 must both be rejected.
    #[test]
    fn test_mixed_validates_ratio_range() {
        let (buf_a, buf_b) = tagged_pair();

        let too_high = sample_mixed(&buf_a, &buf_b, 1.5, 32, 42);
        assert!(too_high.is_err());

        let negative = sample_mixed(&buf_a, &buf_b, -0.1, 32, 42);
        assert!(negative.is_err());
    }

    mod proptests {
        use super::*;
        use crate::buffer::sample_record;
        use proptest::prelude::*;

        proptest! {
            /// For any batch size in 1..50 and ratio in [0, 1), the merged batch
            /// reports and contains exactly `batch_size` entries.
            #[test]
            fn prop_mixed_batch_size_correct(
                batch_size in 1usize..50,
                ratio in 0.0f64..1.0,
            ) {
                let mut buf_a = ReplayBuffer::new(200, 4, 1);
                let mut buf_b = ReplayBuffer::new(200, 4, 1);
                // Fill both buffers in lockstep, one record each per iteration.
                (0..100).for_each(|_| {
                    buf_a.push(sample_record(4)).unwrap();
                    buf_b.push(sample_record(4)).unwrap();
                });
                let batch = sample_mixed(&buf_a, &buf_b, ratio, batch_size, 42).unwrap();
                prop_assert_eq!(batch.batch_size, batch_size);
                prop_assert_eq!(batch.rewards.len(), batch_size);
            }
        }
    }
}