rlox_core/training/
augmentation.rs

1//! Image augmentation for visual RL (DrQ-v2 style random shift).
2//!
3//! Provides the [`ImageAugmentation`] trait and concrete implementations
4//! for composable, reproducible image augmentations on flat `(B, C, H, W)` arrays.
5
6use rand::Rng;
7use rand::SeedableRng;
8use rand_chacha::ChaCha8Rng;
9
10use crate::error::RloxError;
11
12/// Trait for composable image augmentations.
13///
14/// Implementors transform a flat batch of images `(B, C, H, W)` stored as
15/// contiguous f32 arrays. The trait enables adding new augmentations
16/// (color jitter, cutout, etc.) without modifying existing code.
17pub trait ImageAugmentation: Send + Sync {
18    /// Apply the augmentation to a batch of images.
19    ///
20    /// `images` is a flat array of length `batch_size * channels * height * width`.
21    /// Returns a new flat array of the same length.
22    fn augment_batch(
23        &self,
24        images: &[f32],
25        batch_size: usize,
26        channels: usize,
27        height: usize,
28        width: usize,
29        seed: u64,
30    ) -> Result<Vec<f32>, RloxError>;
31
32    /// Human-readable name for logging/debugging.
33    fn name(&self) -> &str;
34}
35
/// DrQ-v2 random shift augmentation.
///
/// Pads the image with zeros, then randomly crops back to original size.
/// Effectively translates the image by up to `pad` pixels in each direction.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RandomShift {
    /// Number of zero-padding pixels added on each side before cropping,
    /// which is also the largest possible shift along either axis.
    pub pad: usize,
}
43
44impl ImageAugmentation for RandomShift {
45    fn augment_batch(
46        &self,
47        images: &[f32],
48        batch_size: usize,
49        channels: usize,
50        height: usize,
51        width: usize,
52        seed: u64,
53    ) -> Result<Vec<f32>, RloxError> {
54        random_shift_batch(images, batch_size, channels, height, width, self.pad, seed)
55    }
56
57    fn name(&self) -> &str {
58        "RandomShift"
59    }
60}
61
62/// Random shift: pad image with zeros, then crop a random (H, W) window.
63///
64/// For each image in the batch, a random offset `(dy, dx)` is sampled
65/// uniformly from `[0, 2*pad]`. The output pixel at `(y, x)` is taken
66/// from the padded image at `(y + dy, x + dx)`, which corresponds to
67/// the original image pixel at `(y + dy - pad, x + dx - pad)` if in
68/// bounds, or zero otherwise.
69///
70/// # Arguments
71/// * `images` - flat `(B * C * H * W)` f32 array
72/// * `pad` - number of zero-pad pixels on each side
73/// * `seed` - ChaCha8 RNG seed for reproducibility
74///
75/// # Returns
76/// New flat array of same length as input.
77#[inline]
78pub fn random_shift_batch(
79    images: &[f32],
80    batch_size: usize,
81    channels: usize,
82    height: usize,
83    width: usize,
84    pad: usize,
85    seed: u64,
86) -> Result<Vec<f32>, RloxError> {
87    let expected_len = batch_size * channels * height * width;
88    if images.len() != expected_len {
89        return Err(RloxError::ShapeMismatch {
90            expected: format!(
91                "B*C*H*W = {}*{}*{}*{} = {}",
92                batch_size, channels, height, width, expected_len
93            ),
94            got: format!("images.len() = {}", images.len()),
95        });
96    }
97
98    if expected_len == 0 {
99        return Ok(Vec::new());
100    }
101
102    if pad == 0 {
103        return Ok(images.to_vec());
104    }
105
106    let img_size = channels * height * width;
107    let mut output = vec![0.0f32; expected_len];
108    let mut rng = ChaCha8Rng::seed_from_u64(seed);
109
110    for b in 0..batch_size {
111        // Random offset in the padded image
112        let dy = rng.random_range(0..=(2 * pad));
113        let dx = rng.random_range(0..=(2 * pad));
114
115        let img_offset = b * img_size;
116
117        for c in 0..channels {
118            let ch_offset = img_offset + c * height * width;
119            for y in 0..height {
120                let src_y = y as isize + dy as isize - pad as isize;
121                if src_y < 0 || src_y >= height as isize {
122                    continue;
123                }
124                let src_y = src_y as usize;
125
126                // For output x, source is src_x = x + dx - pad.
127                // Valid range: 0 <= src_x < width, i.e.:
128                //   x >= pad - dx  (lower bound, clamped to 0)
129                //   x <  width + pad - dx  (upper bound, clamped to width)
130                let x_lo = pad.saturating_sub(dx);
131                let x_hi = if dx > pad {
132                    width.saturating_sub(dx - pad)
133                } else {
134                    width
135                };
136
137                if x_lo < x_hi {
138                    let src_x_start = x_lo + dx - pad;
139                    let row_len = x_hi - x_lo;
140                    let src_base = ch_offset + src_y * width + src_x_start;
141                    let dst_base = ch_offset + y * width + x_lo;
142                    output[dst_base..dst_base + row_len]
143                        .copy_from_slice(&images[src_base..src_base + row_len]);
144                }
145            }
146        }
147    }
148
149    Ok(output)
150}
151
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `len` pixels whose value at index `i` is `i / divisor`.
    fn ramp(len: usize, divisor: f32) -> Vec<f32> {
        (0..len).map(|i| i as f32 / divisor).collect()
    }

    /// The output always has as many elements as the input.
    #[test]
    fn test_random_shift_preserves_shape() {
        let (batch, chans, rows, cols) = (2, 3, 8, 8);
        let total = batch * chans * rows * cols;
        let frames = vec![1.0f32; total];
        let shifted = random_shift_batch(&frames, batch, chans, rows, cols, 2, 42).unwrap();
        assert_eq!(shifted.len(), total);
    }

    /// Two different seeds give different shifted batches for this input.
    #[test]
    fn test_random_shift_different_seeds_differ() {
        let frames = ramp(2 * 3 * 8 * 8, 100.0);
        let with_seed = |seed: u64| random_shift_batch(&frames, 2, 3, 8, 8, 2, seed).unwrap();
        assert_ne!(with_seed(42), with_seed(99));
    }

    /// With no padding there is nothing to shift, so the input comes back as is.
    #[test]
    fn test_random_shift_pad_zero_is_identity() {
        let frames = ramp(16, 1.0);
        let shifted = random_shift_batch(&frames, 1, 1, 4, 4, 0, 42).unwrap();
        assert_eq!(shifted, frames);
    }

    /// Shifting only copies input pixels or writes zeros, so `[0, 1]` inputs stay in `[0, 1]`.
    #[test]
    fn test_random_shift_values_bounded() {
        let frames = ramp(16, 15.0);
        let shifted = random_shift_batch(&frames, 1, 1, 4, 4, 2, 42).unwrap();
        for v in shifted.iter().copied() {
            assert!((0.0..=1.0).contains(&v), "value out of bounds: {v}");
        }
    }

    /// A 1x1 image either keeps its pixel or is shifted fully into the padding.
    #[test]
    fn test_random_shift_single_pixel() {
        let shifted = random_shift_batch(&[5.0f32], 1, 1, 1, 1, 1, 42).unwrap();
        assert_eq!(shifted.len(), 1);
        assert!([0.0f32, 5.0].contains(&shifted[0]));
    }

    /// A pad much larger than the image pushes most of the window into padding.
    #[test]
    fn test_random_shift_large_pad_mostly_zeros() {
        let shifted = random_shift_batch(&[1.0f32; 4], 1, 1, 2, 2, 10, 42).unwrap();
        let num_zeros = shifted.iter().filter(|v| **v == 0.0).count();
        assert!(
            num_zeros >= 2,
            "expected mostly zeros with large pad, got {num_zeros}/4 zeros"
        );
    }

    /// A zero-sized batch is valid input and produces no output.
    #[test]
    fn test_empty_batch_returns_empty() {
        let shifted = random_shift_batch(&[], 0, 3, 8, 8, 2, 42).unwrap();
        assert!(shifted.is_empty());
    }

    /// An input whose length disagrees with the declared shape is rejected.
    #[test]
    fn test_shape_validation_rejects_mismatched_input() {
        let too_short = vec![1.0f32; 10]; // wrong length for 2*3*8*8
        let outcome = random_shift_batch(&too_short, 2, 3, 8, 8, 2, 42);
        assert!(outcome.is_err());
        let err = outcome.unwrap_err();
        assert!(
            matches!(err, RloxError::ShapeMismatch { .. }),
            "expected ShapeMismatch, got {err:?}"
        );
    }

    /// The trait can be used behind `Box<dyn ImageAugmentation>`.
    #[test]
    fn test_trait_object_safety() {
        let boxed: Box<dyn ImageAugmentation> = Box::new(RandomShift { pad: 4 });
        assert_eq!(boxed.name(), "RandomShift");
    }

    /// Re-running with the same seed reproduces the same batch.
    #[test]
    fn test_random_shift_deterministic_with_same_seed() {
        let frames = ramp(2 * 3 * 8 * 8, 100.0);
        let run = || random_shift_batch(&frames, 2, 3, 8, 8, 2, 42).unwrap();
        assert_eq!(run(), run());
    }

    mod proptests {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            /// Output length equals `b * c * h * w` for every generated shape and pad.
            #[test]
            fn prop_shift_batch_size_preserved(
                b in 1usize..8,
                c in 1usize..4,
                h in 2usize..16,
                w in 2usize..16,
                pad in 0usize..4,
            ) {
                let total = b * c * h * w;
                let frames = vec![0.5f32; total];
                let shifted = random_shift_batch(&frames, b, c, h, w, pad, 42).unwrap();
                prop_assert_eq!(shifted.len(), total);
            }

            /// Identical arguments, seed included, always give identical output.
            #[test]
            fn prop_shift_deterministic_with_seed(
                b in 1usize..4,
                c in 1usize..3,
                h in 2usize..8,
                w in 2usize..8,
                seed in 0u64..1000,
            ) {
                let frames = vec![1.0f32; b * c * h * w];
                let shift = || random_shift_batch(&frames, b, c, h, w, 2, seed).unwrap();
                prop_assert_eq!(shift(), shift());
            }

            /// Every output value stays within `[0.0, 1.0]` for inputs drawn from `[0, 1)`.
            #[test]
            fn prop_shift_values_in_input_range(
                b in 1usize..4,
                c in 1usize..3,
                h in 2usize..8,
                w in 2usize..8,
            ) {
                let n = b * c * h * w;
                let frames: Vec<f32> = (0..n).map(|i| (i as f32) / (n as f32)).collect();
                let min_val = frames.iter().copied().fold(f32::INFINITY, f32::min);
                let shifted = random_shift_batch(&frames, b, c, h, w, 2, 42).unwrap();
                for v in shifted.iter().copied() {
                    // Values are either from the original image or zero (padding)
                    prop_assert!((0.0..=1.0).contains(&v),
                        "value {v} not in [0.0, 1.0], min_val={min_val}");
                }
            }
        }
    }
}