rlox_core/training/
packing.rs

1use crate::error::RloxError;
2
/// A single packed batch: several sequences concatenated back to back.
///
/// In batches built by [`pack_sequences`], `input_ids`, `attention_mask` and
/// `position_ids` are index-aligned and have the same length.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PackedBatch {
    /// Token IDs of the packed sequences, concatenated in packing order.
    ///
    /// [`pack_sequences`] does not pad, so its batches are exactly as long as
    /// the tokens placed in them (at most `max_length`).
    pub input_ids: Vec<u32>,
    /// Attention mask aligned with `input_ids`: 1 for real tokens, 0 for
    /// padding. [`pack_sequences`] emits no padding, so every entry it
    /// produces is 1.
    pub attention_mask: Vec<u32>,
    /// Position IDs aligned with `input_ids`; they restart at 0 at the start
    /// of every packed sequence (see `sequence_starts`).
    pub position_ids: Vec<u32>,
    /// Offset into `input_ids` at which each packed sequence begins. As
    /// produced by [`pack_sequences`] these are in non-decreasing order; a
    /// sequence ends where the next one starts, or at the end of `input_ids`
    /// for the last one.
    pub sequence_starts: Vec<usize>,
}
15
16/// Pack variable-length sequences into fixed-size bins using first-fit-decreasing.
17///
18/// Sorts sequences by length (longest first), then greedily assigns each
19/// to the first bin that has enough remaining capacity. Creates a new bin
20/// if no existing bin can fit the sequence.
21///
22/// Returns `Err` if any single sequence exceeds `max_length`.
23pub fn pack_sequences(
24    sequences: &[&[u32]],
25    max_length: usize,
26) -> Result<Vec<PackedBatch>, RloxError> {
27    if sequences.is_empty() {
28        return Ok(Vec::new());
29    }
30
31    // Validate: no sequence exceeds max_length
32    for (i, seq) in sequences.iter().enumerate() {
33        if seq.len() > max_length {
34            return Err(RloxError::BufferError(format!(
35                "sequence {} has length {} which exceeds max_length {}",
36                i,
37                seq.len(),
38                max_length
39            )));
40        }
41    }
42
43    // Sort by length descending (first-fit-decreasing)
44    let mut indexed: Vec<(usize, &[u32])> =
45        sequences.iter().enumerate().map(|(i, s)| (i, *s)).collect();
46    indexed.sort_by(|a, b| b.1.len().cmp(&a.1.len()));
47
48    // Bins: track remaining capacity and accumulated sequences
49    struct Bin {
50        tokens: Vec<u32>,
51        attention_mask: Vec<u32>,
52        position_ids: Vec<u32>,
53        sequence_starts: Vec<usize>,
54        used: usize,
55    }
56
57    let mut bins: Vec<Bin> = Vec::new();
58
59    for (_orig_idx, seq) in &indexed {
60        let seq_len = seq.len();
61
62        // First-fit: find first bin with enough room
63        let mut placed = false;
64        for bin in bins.iter_mut() {
65            if bin.used + seq_len <= max_length {
66                bin.sequence_starts.push(bin.used);
67                bin.tokens.extend_from_slice(seq);
68                bin.attention_mask
69                    .extend(std::iter::repeat_n(1u32, seq_len));
70                for j in 0..seq_len {
71                    bin.position_ids.push(j as u32);
72                }
73                bin.used += seq_len;
74                placed = true;
75                break;
76            }
77        }
78
79        if !placed {
80            let mut bin = Bin {
81                tokens: Vec::with_capacity(max_length),
82                attention_mask: Vec::with_capacity(max_length),
83                position_ids: Vec::with_capacity(max_length),
84                sequence_starts: Vec::new(),
85                used: 0,
86            };
87            bin.sequence_starts.push(0);
88            bin.tokens.extend_from_slice(seq);
89            bin.attention_mask
90                .extend(std::iter::repeat_n(1u32, seq_len));
91            for j in 0..seq_len {
92                bin.position_ids.push(j as u32);
93            }
94            bin.used = seq_len;
95            bins.push(bin);
96        }
97    }
98
99    // Convert bins to PackedBatch (pad to used length, not max_length — no wasteful padding)
100    let result = bins
101        .into_iter()
102        .map(|bin| PackedBatch {
103            input_ids: bin.tokens,
104            attention_mask: bin.attention_mask,
105            position_ids: bin.position_ids,
106            sequence_starts: bin.sequence_starts,
107        })
108        .collect();
109
110    Ok(result)
111}
112
#[cfg(test)]
mod tests {
    use super::*;

    /// Borrow each owned sequence as a slice, the input shape `pack_sequences` expects.
    fn as_slices(seqs: &[Vec<u32>]) -> Vec<&[u32]> {
        seqs.iter().map(Vec::as_slice).collect()
    }

    /// Every token whose attention-mask entry is non-zero, sorted, so the
    /// multiset of packed tokens can be compared against the input.
    fn sorted_real_tokens(packed: &[PackedBatch]) -> Vec<u32> {
        let mut tokens: Vec<u32> = packed
            .iter()
            .flat_map(|b| {
                b.input_ids
                    .iter()
                    .zip(&b.attention_mask)
                    .filter(|&(_, &m)| m != 0)
                    .map(|(&t, _)| t)
            })
            .collect();
        tokens.sort_unstable();
        tokens
    }

    /// All input tokens, flattened and sorted.
    fn sorted_input_tokens(seqs: &[Vec<u32>]) -> Vec<u32> {
        let mut tokens: Vec<u32> = seqs.iter().flatten().copied().collect();
        tokens.sort_unstable();
        tokens
    }

    #[test]
    fn pack_sequences_one_bin_exact_fit() {
        let seqs: Vec<Vec<u32>> = vec![vec![1, 2, 3], vec![4, 5, 6]];
        let packed = pack_sequences(&as_slices(&seqs), 6).unwrap();
        assert_eq!(
            packed.len(),
            1,
            "should produce 1 bin for total_len=max_len"
        );
        assert_eq!(packed[0].input_ids, vec![1, 2, 3, 4, 5, 6]);
        // Equal-length sequences keep their input order (stable sort).
        assert_eq!(packed[0].sequence_starts, vec![0, 3]);
    }

    #[test]
    fn pack_sequences_two_bins_when_overflow() {
        let seqs: Vec<Vec<u32>> = vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8, 9]];
        let packed = pack_sequences(&as_slices(&seqs), 6).unwrap();
        assert_eq!(
            packed.len(),
            2,
            "three length-3 sequences need exactly 2 bins of 6"
        );
        for bin in &packed {
            assert!(
                bin.input_ids.len() <= 6,
                "bin length {} exceeds max_length 6",
                bin.input_ids.len()
            );
        }
    }

    #[test]
    fn pack_sequences_all_sequences_present() {
        let seqs: Vec<Vec<u32>> = vec![
            vec![1, 2, 3],
            vec![4, 5],
            vec![6, 7, 8, 9],
            vec![10],
            vec![11, 12],
        ];
        let packed = pack_sequences(&as_slices(&seqs), 8).unwrap();

        assert_eq!(
            sorted_real_tokens(&packed),
            sorted_input_tokens(&seqs),
            "all input tokens must appear exactly once in output"
        );
    }

    #[test]
    fn pack_sequences_sequence_exceeds_max_length_returns_error() {
        let long_seq = vec![1u32; 100];
        let slices = vec![long_seq.as_slice()];
        let result = pack_sequences(&slices, 50);
        assert!(
            matches!(result, Err(RloxError::BufferError(_))),
            "sequence longer than max_length must return Err(BufferError)"
        );
    }

    #[test]
    fn pack_sequences_empty_input_returns_empty() {
        let result = pack_sequences(&[], 512).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn pack_sequences_single_sequence() {
        let seq = vec![10u32, 20, 30];
        let slices = vec![seq.as_slice()];
        let packed = pack_sequences(&slices, 512).unwrap();
        assert_eq!(packed.len(), 1);
        assert_eq!(packed[0].input_ids, vec![10, 20, 30]);
    }

    #[test]
    fn pack_sequences_empty_sequence_is_recorded_with_zero_length() {
        let seqs: Vec<Vec<u32>> = vec![vec![], vec![1, 2]];
        let packed = pack_sequences(&as_slices(&seqs), 4).unwrap();
        assert_eq!(packed.len(), 1);
        assert_eq!(packed[0].input_ids, vec![1, 2]);
        // The empty sequence is sorted last and starts where the previous one ends.
        assert_eq!(packed[0].sequence_starts, vec![0, 2]);
    }

    #[test]
    fn pack_sequences_attention_mask_matches_input_ids_length() {
        let seqs: Vec<Vec<u32>> = vec![vec![1, 2, 3], vec![4, 5]];
        let packed = pack_sequences(&as_slices(&seqs), 8).unwrap();
        for bin in &packed {
            assert_eq!(
                bin.input_ids.len(),
                bin.attention_mask.len(),
                "attention_mask must have same length as input_ids"
            );
            assert!(
                bin.attention_mask.iter().all(|&m| m == 1),
                "packed bins contain no padding, so every mask entry must be 1"
            );
        }
    }

    #[test]
    fn pack_sequences_position_ids_per_sequence_start_from_zero() {
        let seqs: Vec<Vec<u32>> = vec![vec![1, 2, 3], vec![4, 5]];
        let packed = pack_sequences(&as_slices(&seqs), 8).unwrap();

        for bin in &packed {
            for (k, &start) in bin.sequence_starts.iter().enumerate() {
                let end = bin
                    .sequence_starts
                    .get(k + 1)
                    .copied()
                    .unwrap_or(bin.input_ids.len());
                for (j, pos) in (start..end).enumerate() {
                    assert_eq!(
                        bin.position_ids[pos], j as u32,
                        "position_ids[{pos}] should be {j}"
                    );
                }
            }
        }
    }

    #[test]
    fn pack_sequences_fill_rate_good_for_varied_lengths() {
        let lengths = [10, 50, 20, 80, 30, 60, 15, 100, 40, 25];
        let seqs: Vec<Vec<u32>> = lengths
            .iter()
            .enumerate()
            .map(|(i, &l)| (0..l as u32).map(|t| i as u32 * 1000 + t).collect())
            .collect();
        let max_len = 128;
        let packed = pack_sequences(&as_slices(&seqs), max_len).unwrap();

        // Bins are not padded, so measure against the nominal capacity of each
        // bin (`max_len`), not against `input_ids.len()` (which would always
        // give a fill rate of exactly 1.0).
        let total_tokens: usize = lengths.iter().sum();
        let total_capacity = packed.len() * max_len;
        let fill_rate = total_tokens as f64 / total_capacity as f64;

        // FFD packs these 430 tokens into 4 bins (the optimum): 430 / 512 ≈ 0.84.
        assert!(
            fill_rate > 0.8,
            "fill_rate {fill_rate:.2} is below 0.8 — bin packing too inefficient"
        );
    }

    mod proptests {
        use super::*;
        use proptest::collection::vec;
        use proptest::prelude::*;

        proptest! {
            #[test]
            fn pack_sequences_no_bin_exceeds_max_length(
                lengths in vec(1usize..100, 1..20),
                max_len in 10usize..200,
            ) {
                let seqs: Vec<Vec<u32>> = lengths
                    .iter()
                    .filter(|&&l| l <= max_len)
                    .map(|&l| (0..l as u32).collect())
                    .collect();
                if seqs.is_empty() {
                    return Ok(());
                }
                let packed = pack_sequences(&as_slices(&seqs), max_len).unwrap();
                for bin in &packed {
                    prop_assert!(
                        bin.input_ids.len() <= max_len,
                        "bin length {} exceeds max_len {}",
                        bin.input_ids.len(),
                        max_len
                    );
                }
            }

            #[test]
            fn pack_sequences_all_tokens_present(
                lengths in vec(1usize..50, 1..15),
            ) {
                let max_len = 128;
                let seqs: Vec<Vec<u32>> = lengths
                    .iter()
                    .filter(|&&l| l <= max_len)
                    .enumerate()
                    .map(|(i, &l)| (0..l as u32).map(|t| i as u32 * 1000 + t).collect())
                    .collect();
                if seqs.is_empty() {
                    return Ok(());
                }
                let packed = pack_sequences(&as_slices(&seqs), max_len).unwrap();

                prop_assert_eq!(sorted_real_tokens(&packed), sorted_input_tokens(&seqs));
            }
        }
    }
}