1use crate::error::RloxError;
2
/// One packed batch: several token sequences laid end to end in flat buffers.
///
/// As filled in by `pack_sequences`, `input_ids`, `attention_mask` and
/// `position_ids` all have the same length (the number of tokens packed into
/// this batch); no padding entries are appended.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PackedBatch {
    /// Tokens of every packed sequence, concatenated in packing order.
    pub input_ids: Vec<u32>,
    /// One entry per token in `input_ids`; `pack_sequences` writes `1` for each.
    pub attention_mask: Vec<u32>,
    /// Position of each token inside its own sequence; restarts at `0` at
    /// every offset listed in `sequence_starts`.
    pub position_ids: Vec<u32>,
    /// Offset into `input_ids` at which each packed sequence begins.
    pub sequence_starts: Vec<usize>,
}
15
16pub fn pack_sequences(
24 sequences: &[&[u32]],
25 max_length: usize,
26) -> Result<Vec<PackedBatch>, RloxError> {
27 if sequences.is_empty() {
28 return Ok(Vec::new());
29 }
30
31 for (i, seq) in sequences.iter().enumerate() {
33 if seq.len() > max_length {
34 return Err(RloxError::BufferError(format!(
35 "sequence {} has length {} which exceeds max_length {}",
36 i,
37 seq.len(),
38 max_length
39 )));
40 }
41 }
42
43 let mut indexed: Vec<(usize, &[u32])> =
45 sequences.iter().enumerate().map(|(i, s)| (i, *s)).collect();
46 indexed.sort_by(|a, b| b.1.len().cmp(&a.1.len()));
47
48 struct Bin {
50 tokens: Vec<u32>,
51 attention_mask: Vec<u32>,
52 position_ids: Vec<u32>,
53 sequence_starts: Vec<usize>,
54 used: usize,
55 }
56
57 let mut bins: Vec<Bin> = Vec::new();
58
59 for (_orig_idx, seq) in &indexed {
60 let seq_len = seq.len();
61
62 let mut placed = false;
64 for bin in bins.iter_mut() {
65 if bin.used + seq_len <= max_length {
66 bin.sequence_starts.push(bin.used);
67 bin.tokens.extend_from_slice(seq);
68 bin.attention_mask
69 .extend(std::iter::repeat_n(1u32, seq_len));
70 for j in 0..seq_len {
71 bin.position_ids.push(j as u32);
72 }
73 bin.used += seq_len;
74 placed = true;
75 break;
76 }
77 }
78
79 if !placed {
80 let mut bin = Bin {
81 tokens: Vec::with_capacity(max_length),
82 attention_mask: Vec::with_capacity(max_length),
83 position_ids: Vec::with_capacity(max_length),
84 sequence_starts: Vec::new(),
85 used: 0,
86 };
87 bin.sequence_starts.push(0);
88 bin.tokens.extend_from_slice(seq);
89 bin.attention_mask
90 .extend(std::iter::repeat_n(1u32, seq_len));
91 for j in 0..seq_len {
92 bin.position_ids.push(j as u32);
93 }
94 bin.used = seq_len;
95 bins.push(bin);
96 }
97 }
98
99 let result = bins
101 .into_iter()
102 .map(|bin| PackedBatch {
103 input_ids: bin.tokens,
104 attention_mask: bin.attention_mask,
105 position_ids: bin.position_ids,
106 sequence_starts: bin.sequence_starts,
107 })
108 .collect();
109
110 Ok(result)
111}
112
#[cfg(test)]
mod tests {
    use super::*;

    // Two sequences whose combined length equals `max_length` must share one bin.
    #[test]
    fn pack_sequences_one_bin_exact_fit() {
        let seqs: Vec<Vec<u32>> = vec![vec![1, 2, 3], vec![4, 5, 6]];
        let slices: Vec<&[u32]> = seqs.iter().map(|s| s.as_slice()).collect();
        let packed = pack_sequences(&slices, 6).unwrap();
        assert_eq!(
            packed.len(),
            1,
            "should produce 1 bin for total_len=max_len"
        );
        assert_eq!(packed[0].input_ids.len(), 6);
    }

    // Nine tokens cannot fit into a single bin of six, and no bin may overflow.
    #[test]
    fn pack_sequences_two_bins_when_overflow() {
        let seqs: Vec<Vec<u32>> = vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8, 9]];
        let slices: Vec<&[u32]> = seqs.iter().map(|s| s.as_slice()).collect();
        let packed = pack_sequences(&slices, 6).unwrap();
        assert!(packed.len() >= 2, "should produce at least 2 bins");
        for bin in &packed {
            assert!(
                bin.input_ids.len() <= 6,
                "bin length {} exceeds max_length 6",
                bin.input_ids.len()
            );
        }
    }

    // Packing may reorder sequences but must neither drop nor duplicate tokens.
    #[test]
    fn pack_sequences_all_sequences_present() {
        let seqs: Vec<Vec<u32>> = vec![
            vec![1, 2, 3],
            vec![4, 5],
            vec![6, 7, 8, 9],
            vec![10],
            vec![11, 12],
        ];
        let slices: Vec<&[u32]> = seqs.iter().map(|s| s.as_slice()).collect();
        let packed = pack_sequences(&slices, 8).unwrap();

        // Only tokens with a non-zero mask count as real tokens.
        let mut all_packed_tokens: Vec<u32> = packed
            .iter()
            .flat_map(|b| {
                b.input_ids
                    .iter()
                    .copied()
                    .zip(b.attention_mask.iter().copied())
                    .filter_map(|(t, m)| if m != 0 { Some(t) } else { None })
            })
            .collect();
        all_packed_tokens.sort_unstable();

        let mut all_input_tokens: Vec<u32> = seqs.iter().flatten().copied().collect();
        all_input_tokens.sort_unstable();

        assert_eq!(
            all_packed_tokens, all_input_tokens,
            "all input tokens must appear exactly once in output"
        );
    }

    // A sequence that cannot fit into any bin is reported as an error.
    #[test]
    fn pack_sequences_sequence_exceeds_max_length_returns_error() {
        let long_seq = vec![1u32; 100];
        let slices = vec![long_seq.as_slice()];
        let result = pack_sequences(&slices, 50);
        assert!(
            result.is_err(),
            "sequence longer than max_length must return Err"
        );
    }

    // No input sequences means no batches.
    #[test]
    fn pack_sequences_empty_input_returns_empty() {
        let result = pack_sequences(&[], 512).unwrap();
        assert!(result.is_empty());
    }

    // A single sequence ends up at the start of the only bin.
    #[test]
    fn pack_sequences_single_sequence() {
        let seq = vec![10u32, 20, 30];
        let slices = vec![seq.as_slice()];
        let packed = pack_sequences(&slices, 512).unwrap();
        assert_eq!(packed.len(), 1);
        assert_eq!(&packed[0].input_ids[..3], &[10, 20, 30]);
    }

    // The mask is parallel to the token buffer.
    #[test]
    fn pack_sequences_attention_mask_matches_input_ids_length() {
        let seqs: Vec<Vec<u32>> = vec![vec![1, 2, 3], vec![4, 5]];
        let slices: Vec<&[u32]> = seqs.iter().map(|s| s.as_slice()).collect();
        let packed = pack_sequences(&slices, 8).unwrap();
        for bin in &packed {
            assert_eq!(
                bin.input_ids.len(),
                bin.attention_mask.len(),
                "attention_mask must have same length as input_ids"
            );
        }
    }

    // Within each bin, positions count up from zero between consecutive sequence starts.
    #[test]
    fn pack_sequences_position_ids_per_sequence_start_from_zero() {
        let seqs: Vec<Vec<u32>> = vec![vec![1, 2, 3], vec![4, 5]];
        let slices: Vec<&[u32]> = seqs.iter().map(|s| s.as_slice()).collect();
        let packed = pack_sequences(&slices, 8).unwrap();

        for bin in &packed {
            for (k, &start) in bin.sequence_starts.iter().enumerate() {
                // A sequence ends where the next one starts, or at the end of the bin.
                let end = if k + 1 < bin.sequence_starts.len() {
                    bin.sequence_starts[k + 1]
                } else {
                    bin.input_ids.len()
                };
                for (j, pos) in (start..end).enumerate() {
                    assert_eq!(
                        bin.position_ids[pos], j as u32,
                        "position_ids[{pos}] should be {j}"
                    );
                }
            }
        }
    }

    // Packing a mix of lengths should leave less than half of the bin capacity unused.
    #[test]
    fn pack_sequences_fill_rate_good_for_varied_lengths() {
        let lengths = [10, 50, 20, 80, 30, 60, 15, 100, 40, 25];
        let seqs: Vec<Vec<u32>> = lengths
            .iter()
            .enumerate()
            .map(|(i, &l)| (0..l as u32).map(|t| i as u32 * 1000 + t).collect())
            .collect();
        let slices: Vec<&[u32]> = seqs.iter().map(|s| s.as_slice()).collect();
        let max_len = 128;
        let packed = pack_sequences(&slices, max_len).unwrap();

        let total_tokens: usize = lengths.iter().sum();
        // Capacity is what the bins could hold, not what they do hold: bins
        // carry no padding, so summing `input_ids.len()` would always equal
        // `total_tokens` and make the assertion below vacuous.
        let total_capacity: usize = packed.len() * max_len;
        let fill_rate = total_tokens as f64 / total_capacity as f64;

        assert!(
            fill_rate > 0.5,
            "fill_rate {fill_rate:.2} is below 0.5 — bin packing too inefficient"
        );
    }

    mod proptests {
        use super::*;
        use proptest::collection::vec;
        use proptest::prelude::*;

        proptest! {
            // For arbitrary lengths that individually fit, no bin may exceed `max_len`.
            #[test]
            fn pack_sequences_no_bin_exceeds_max_length(
                lengths in vec(1usize..100, 1..20),
                max_len in 10usize..200,
            ) {
                // Drop lengths that would be rejected outright.
                let seqs: Vec<Vec<u32>> = lengths
                    .iter()
                    .filter(|&&l| l <= max_len)
                    .map(|&l| (0..l as u32).collect())
                    .collect();
                if seqs.is_empty() {
                    return Ok(());
                }
                let slices: Vec<&[u32]> = seqs.iter().map(|s| s.as_slice()).collect();
                let packed = pack_sequences(&slices, max_len).unwrap();
                for bin in &packed {
                    prop_assert!(
                        bin.input_ids.len() <= max_len,
                        "bin length {} exceeds max_len {}",
                        bin.input_ids.len(),
                        max_len
                    );
                }
            }

            // The multiset of packed tokens equals the multiset of input tokens.
            #[test]
            fn pack_sequences_all_tokens_present(
                lengths in vec(1usize..50, 1..15),
            ) {
                let max_len = 128;
                // Token values encode the sequence index so tokens are unique across sequences.
                let seqs: Vec<Vec<u32>> = lengths
                    .iter()
                    .filter(|&&l| l <= max_len)
                    .enumerate()
                    .map(|(i, &l)| (0..l as u32).map(|t| i as u32 * 1000 + t).collect())
                    .collect();
                if seqs.is_empty() {
                    return Ok(());
                }
                let slices: Vec<&[u32]> = seqs.iter().map(|s| s.as_slice()).collect();
                let packed = pack_sequences(&slices, max_len).unwrap();

                let mut packed_tokens: Vec<u32> = packed
                    .iter()
                    .flat_map(|b| {
                        b.input_ids
                            .iter()
                            .copied()
                            .zip(b.attention_mask.iter().copied())
                            .filter_map(|(t, m)| if m != 0 { Some(t) } else { None })
                    })
                    .collect();
                packed_tokens.sort_unstable();

                let mut input_tokens: Vec<u32> = seqs.iter().flatten().copied().collect();
                input_tokens.sort_unstable();

                prop_assert_eq!(packed_tokens, input_tokens);
            }
        }
    }
}