1use rand::Rng;
7use rand::SeedableRng;
8use rand_chacha::ChaCha8Rng;
9
10use crate::error::RloxError;
11
12pub trait ImageAugmentation: Send + Sync {
18 fn augment_batch(
23 &self,
24 images: &[f32],
25 batch_size: usize,
26 channels: usize,
27 height: usize,
28 width: usize,
29 seed: u64,
30 ) -> Result<Vec<f32>, RloxError>;
31
32 fn name(&self) -> &str;
34}
35
/// Random-translation ("random shift") augmentation.
///
/// Every image in a batch is translated by an independently drawn integer offset of at
/// most `pad` pixels along each spatial axis, with zero fill for pixels shifted in from
/// outside the image. See [`random_shift_batch`] for the exact semantics.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RandomShift {
    /// Maximum shift, in pixels, along each spatial axis. `0` leaves images unchanged.
    pub pad: usize,
}
43
44impl ImageAugmentation for RandomShift {
45 fn augment_batch(
46 &self,
47 images: &[f32],
48 batch_size: usize,
49 channels: usize,
50 height: usize,
51 width: usize,
52 seed: u64,
53 ) -> Result<Vec<f32>, RloxError> {
54 random_shift_batch(images, batch_size, channels, height, width, self.pad, seed)
55 }
56
57 fn name(&self) -> &str {
58 "RandomShift"
59 }
60}
61
62#[inline]
78pub fn random_shift_batch(
79 images: &[f32],
80 batch_size: usize,
81 channels: usize,
82 height: usize,
83 width: usize,
84 pad: usize,
85 seed: u64,
86) -> Result<Vec<f32>, RloxError> {
87 let expected_len = batch_size * channels * height * width;
88 if images.len() != expected_len {
89 return Err(RloxError::ShapeMismatch {
90 expected: format!(
91 "B*C*H*W = {}*{}*{}*{} = {}",
92 batch_size, channels, height, width, expected_len
93 ),
94 got: format!("images.len() = {}", images.len()),
95 });
96 }
97
98 if expected_len == 0 {
99 return Ok(Vec::new());
100 }
101
102 if pad == 0 {
103 return Ok(images.to_vec());
104 }
105
106 let img_size = channels * height * width;
107 let mut output = vec![0.0f32; expected_len];
108 let mut rng = ChaCha8Rng::seed_from_u64(seed);
109
110 for b in 0..batch_size {
111 let dy = rng.random_range(0..=(2 * pad));
113 let dx = rng.random_range(0..=(2 * pad));
114
115 let img_offset = b * img_size;
116
117 for c in 0..channels {
118 let ch_offset = img_offset + c * height * width;
119 for y in 0..height {
120 let src_y = y as isize + dy as isize - pad as isize;
121 if src_y < 0 || src_y >= height as isize {
122 continue;
123 }
124 let src_y = src_y as usize;
125
126 let x_lo = pad.saturating_sub(dx);
131 let x_hi = if dx > pad {
132 width.saturating_sub(dx - pad)
133 } else {
134 width
135 };
136
137 if x_lo < x_hi {
138 let src_x_start = x_lo + dx - pad;
139 let row_len = x_hi - x_lo;
140 let src_base = ch_offset + src_y * width + src_x_start;
141 let dst_base = ch_offset + y * width + x_lo;
142 output[dst_base..dst_base + row_len]
143 .copy_from_slice(&images[src_base..src_base + row_len]);
144 }
145 }
146 }
147 }
148
149 Ok(output)
150}
151
#[cfg(test)]
mod tests {
    use super::*;

    /// Naive reference translation of one `[c][h][w]` image: output pixel `(y, x)` of every
    /// channel reads input pixel `(y + sy, x + sx)`, or is `0.0` when that source lies outside
    /// the image. Used to check the optimized row-copy implementation pixel by pixel.
    fn reference_shift(
        img: &[f32],
        c: usize,
        h: usize,
        w: usize,
        sy: isize,
        sx: isize,
    ) -> Vec<f32> {
        let mut out = vec![0.0f32; img.len()];
        for ch in 0..c {
            let base = ch * h * w;
            for y in 0..h {
                for x in 0..w {
                    let yy = y as isize + sy;
                    let xx = x as isize + sx;
                    if (0..h as isize).contains(&yy) && (0..w as isize).contains(&xx) {
                        out[base + y * w + x] = img[base + yy as usize * w + xx as usize];
                    }
                }
            }
        }
        out
    }

    #[test]
    fn test_random_shift_preserves_shape() {
        let images = vec![1.0f32; 2 * 3 * 8 * 8];
        let output = random_shift_batch(&images, 2, 3, 8, 8, 2, 42).unwrap();
        assert_eq!(output.len(), 2 * 3 * 8 * 8);
    }

    #[test]
    fn test_random_shift_different_seeds_differ() {
        let images: Vec<f32> = (0..2 * 3 * 8 * 8).map(|i| i as f32 / 100.0).collect();
        let a = random_shift_batch(&images, 2, 3, 8, 8, 2, 42).unwrap();
        let b = random_shift_batch(&images, 2, 3, 8, 8, 2, 99).unwrap();
        assert_ne!(a, b);
    }

    #[test]
    fn test_random_shift_pad_zero_is_identity() {
        let images: Vec<f32> = (0..16).map(|i| i as f32).collect();
        let output = random_shift_batch(&images, 1, 1, 4, 4, 0, 42).unwrap();
        assert_eq!(output, images);
    }

    #[test]
    fn test_random_shift_values_bounded() {
        let images: Vec<f32> = (0..16).map(|i| i as f32 / 15.0).collect();
        let output = random_shift_batch(&images, 1, 1, 4, 4, 2, 42).unwrap();
        for &v in &output {
            assert!((0.0..=1.0).contains(&v), "value out of bounds: {v}");
        }
    }

    #[test]
    fn test_random_shift_single_pixel() {
        let images = vec![5.0f32];
        let output = random_shift_batch(&images, 1, 1, 1, 1, 1, 42).unwrap();
        assert_eq!(output.len(), 1);
        assert!(output[0] == 0.0 || output[0] == 5.0);
    }

    #[test]
    fn test_random_shift_large_pad_mostly_zeros() {
        let images = vec![1.0f32; 4];
        let output = random_shift_batch(&images, 1, 1, 2, 2, 10, 42).unwrap();
        let num_zeros = output.iter().filter(|&&v| v == 0.0).count();
        assert!(
            num_zeros >= 2,
            "expected mostly zeros with large pad, got {num_zeros}/4 zeros"
        );
    }

    #[test]
    fn test_empty_batch_returns_empty() {
        let output = random_shift_batch(&[], 0, 3, 8, 8, 2, 42).unwrap();
        assert!(output.is_empty());
    }

    #[test]
    fn test_shape_validation_rejects_mismatched_input() {
        let images = vec![1.0f32; 10];
        let result = random_shift_batch(&images, 2, 3, 8, 8, 2, 42);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            matches!(err, RloxError::ShapeMismatch { .. }),
            "expected ShapeMismatch, got {err:?}"
        );
    }

    #[test]
    fn test_trait_object_safety() {
        let aug: Box<dyn ImageAugmentation> = Box::new(RandomShift { pad: 4 });
        assert_eq!(aug.name(), "RandomShift");
    }

    #[test]
    fn test_random_shift_deterministic_with_same_seed() {
        let images: Vec<f32> = (0..2 * 3 * 8 * 8).map(|i| i as f32 / 100.0).collect();
        let a = random_shift_batch(&images, 2, 3, 8, 8, 2, 42).unwrap();
        let b = random_shift_batch(&images, 2, 3, 8, 8, 2, 42).unwrap();
        assert_eq!(a, b);
    }

    /// Every output image must equal a zero-padded translation of its own input by some
    /// offset in `[-pad, pad]` per axis, with the same offset applied to every channel.
    #[test]
    fn test_random_shift_matches_reference_translation() {
        let (b, c, h, w, pad) = (4usize, 2usize, 5usize, 6usize, 2usize);
        // Strictly positive, distinct values: a zero can only come from padding.
        let images: Vec<f32> = (0..b * c * h * w).map(|i| (i + 1) as f32).collect();
        let img_size = c * h * w;
        let p = pad as isize;
        for seed in 0..16u64 {
            let output = random_shift_batch(&images, b, c, h, w, pad, seed).unwrap();
            for (src, dst) in images
                .chunks_exact(img_size)
                .zip(output.chunks_exact(img_size))
            {
                let matched = (-p..=p)
                    .any(|sy| (-p..=p).any(|sx| reference_shift(src, c, h, w, sy, sx) == dst));
                assert!(
                    matched,
                    "seed {seed}: output image is not a zero-padded translation of its input"
                );
            }
        }
    }

    mod proptests {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            #[test]
            fn prop_shift_batch_size_preserved(
                b in 1usize..8,
                c in 1usize..4,
                h in 2usize..16,
                w in 2usize..16,
                pad in 0usize..4,
            ) {
                let images = vec![0.5f32; b * c * h * w];
                let output = random_shift_batch(&images, b, c, h, w, pad, 42).unwrap();
                prop_assert_eq!(output.len(), b * c * h * w);
            }

            #[test]
            fn prop_shift_deterministic_with_seed(
                b in 1usize..4,
                c in 1usize..3,
                h in 2usize..8,
                w in 2usize..8,
                seed in 0u64..1000,
            ) {
                let images = vec![1.0f32; b * c * h * w];
                let a = random_shift_batch(&images, b, c, h, w, 2, seed).unwrap();
                let b_out = random_shift_batch(&images, b, c, h, w, 2, seed).unwrap();
                prop_assert_eq!(a, b_out);
            }

            #[test]
            fn prop_shift_values_in_input_range(
                b in 1usize..4,
                c in 1usize..3,
                h in 2usize..8,
                w in 2usize..8,
            ) {
                let n = b * c * h * w;
                let images: Vec<f32> = (0..n).map(|i| (i as f32) / (n as f32)).collect();
                let min_val = images.iter().cloned().fold(f32::INFINITY, f32::min);
                let output = random_shift_batch(&images, b, c, h, w, 2, 42).unwrap();
                for &v in &output {
                    prop_assert!((0.0..=1.0).contains(&v),
                        "value {v} not in [0.0, 1.0], min_val={min_val}");
                }
            }
        }
    }
}