//! rlox_core/llm/ops.rs
//!
//! Numeric ops for RL fine-tuning: GRPO group advantages, token-level KL
//! divergence (exact and Schulman-estimator forms) with batched variants
//! for `f64`/`f32`, and the `DPOPair` preference-pair container.

/// Generates a module of GRPO / KL ops specialised for one float type.
///
/// Invoked once per precision (`f64`, `f32`); the generated functions are
/// identical except for the scalar type.
macro_rules! impl_kl_ops {
    ($mod_name:ident, $float:ty) => {
        pub mod $mod_name {
            use crate::error::RloxError;

            /// Total element count at or above which the batched entry points
            /// switch from a sequential loop to rayon parallelism.
            const PAR_ELEMENT_THRESHOLD: usize = 4096;

            /// Shared validation: the policy/reference slices must have equal
            /// length. Error text matches the historical per-function checks.
            fn check_equal_lengths(
                log_probs_policy: &[$float],
                log_probs_ref: &[$float],
            ) -> Result<(), RloxError> {
                if log_probs_policy.len() != log_probs_ref.len() {
                    return Err(RloxError::ShapeMismatch {
                        expected: format!("len={}", log_probs_policy.len()),
                        got: format!("len={}", log_probs_ref.len()),
                    });
                }
                Ok(())
            }

            /// Shared validation: `chunk` (reported as `name` in the error)
            /// must be non-zero and divide `len` evenly. The zero check runs
            /// first so `len % chunk` can never panic.
            fn check_chunking(len: usize, chunk: usize, name: &str) -> Result<(), RloxError> {
                if chunk == 0 {
                    return Err(RloxError::ShapeMismatch {
                        expected: format!("{name} > 0"),
                        got: "0".to_string(),
                    });
                }
                if len % chunk != 0 {
                    return Err(RloxError::ShapeMismatch {
                        expected: format!("len divisible by {chunk}"),
                        got: format!("len={len}"),
                    });
                }
                Ok(())
            }

            /// GRPO group advantage: `(reward - mean) / std`.
            /// Returns zeros if std < 1e-8.
            pub fn compute_group_advantages(rewards: &[$float]) -> Vec<$float> {
                if rewards.is_empty() {
                    return Vec::new();
                }

                let n = rewards.len() as $float;
                let mean = rewards.iter().sum::<$float>() / n;
                // Population variance (divide by n, not n - 1): the group is
                // treated as the whole population for normalisation.
                let variance = rewards
                    .iter()
                    .map(|&r| (r - mean) * (r - mean))
                    .sum::<$float>()
                    / n;
                let std = variance.sqrt();

                // Degenerate group (all rewards ~equal): return all-zero
                // advantages instead of dividing by a near-zero std.
                if std < 1e-8 as $float {
                    return vec![0.0 as $float; rewards.len()];
                }

                let inv_std = 1.0 as $float / std;
                rewards.iter().map(|&r| (r - mean) * inv_std).collect()
            }

            /// Token-level KL divergence: `sum(exp(log_p) * (log_p - log_q))`.
            ///
            /// # Errors
            /// `RloxError::ShapeMismatch` if the slices differ in length.
            pub fn compute_token_kl(
                log_probs_policy: &[$float],
                log_probs_ref: &[$float],
            ) -> Result<$float, RloxError> {
                check_equal_lengths(log_probs_policy, log_probs_ref)?;

                Ok(log_probs_policy
                    .iter()
                    .zip(log_probs_ref.iter())
                    .map(|(&log_p, &log_q)| log_p.exp() * (log_p - log_q))
                    .sum())
            }

            /// Batched GRPO group advantages: process all groups in a single call.
            ///
            /// `rewards` is a flat slice of length `n_prompts * group_size`.
            /// Returns a Vec of the same length with per-group z-score normalisation.
            ///
            /// # Errors
            /// `RloxError::ShapeMismatch` if `group_size` is zero or does not
            /// divide `rewards.len()` evenly.
            pub fn compute_batch_group_advantages(
                rewards: &[$float],
                group_size: usize,
            ) -> Result<Vec<$float>, RloxError> {
                check_chunking(rewards.len(), group_size, "group_size")?;

                if rewards.len() >= PAR_ELEMENT_THRESHOLD {
                    use rayon::prelude::*;
                    let out: Vec<$float> = rewards
                        .par_chunks_exact(group_size)
                        .flat_map_iter(|group| compute_group_advantages(group))
                        .collect();
                    Ok(out)
                } else {
                    let mut out = Vec::with_capacity(rewards.len());
                    for group in rewards.chunks_exact(group_size) {
                        out.extend(compute_group_advantages(group));
                    }
                    Ok(out)
                }
            }

            /// Token-level KL divergence using the Schulman (2020) estimator
            /// form: `sum(exp(r) - r - 1)` with `r = log_p - log_q`.
            ///
            /// Each term is non-negative, and the sum is numerically more
            /// stable than the exact `exp(log_p) * (log_p - log_q)`.
            ///
            /// NOTE(review): TRL (HuggingFace) computes this with
            /// `r = log_ref - log_policy`, i.e. the opposite sign of `r` used
            /// here; with this sign the estimator is not the unbiased `k3`
            /// estimate of `KL(p||q)` under samples from `p`. Confirm the
            /// intended convention before relying on unbiasedness — existing
            /// callers and tests pin the current sign, so it is preserved.
            ///
            /// # Errors
            /// `RloxError::ShapeMismatch` if the slices differ in length.
            pub fn compute_token_kl_schulman(
                log_probs_policy: &[$float],
                log_probs_ref: &[$float],
            ) -> Result<$float, RloxError> {
                check_equal_lengths(log_probs_policy, log_probs_ref)?;

                Ok(log_probs_policy
                    .iter()
                    .zip(log_probs_ref.iter())
                    .map(|(&log_p, &log_q)| {
                        let r = log_p - log_q;
                        r.exp() - r - 1.0 as $float
                    })
                    .sum())
            }

            /// Shared driver for the batched KL entry points: validates the
            /// flat `batch * seq_len` layout, then reduces each sequence
            /// window with `token_term` (parallel above the element
            /// threshold, sequential otherwise).
            fn batch_kl_with<F>(
                log_probs_policy: &[$float],
                log_probs_ref: &[$float],
                seq_len: usize,
                token_term: F,
            ) -> Result<Vec<$float>, RloxError>
            where
                // `Sync` lets rayon share the closure across worker threads.
                F: Fn($float, $float) -> $float + Sync,
            {
                check_equal_lengths(log_probs_policy, log_probs_ref)?;
                check_chunking(log_probs_policy.len(), seq_len, "seq_len")?;

                let batch_size = log_probs_policy.len() / seq_len;
                let kl_for_seq = |i: usize| -> $float {
                    let off = i * seq_len;
                    let ps = &log_probs_policy[off..off + seq_len];
                    let qs = &log_probs_ref[off..off + seq_len];
                    ps.iter()
                        .zip(qs.iter())
                        .map(|(&log_p, &log_q)| token_term(log_p, log_q))
                        .sum()
                };

                if log_probs_policy.len() >= PAR_ELEMENT_THRESHOLD {
                    use rayon::prelude::*;
                    Ok((0..batch_size).into_par_iter().map(kl_for_seq).collect())
                } else {
                    Ok((0..batch_size).map(kl_for_seq).collect())
                }
            }

            /// Batched token-level KL divergence: process all sequences in a single call.
            ///
            /// `log_probs_policy` and `log_probs_ref` are flat slices of length `batch * seq_len`.
            /// Returns a Vec of length `batch` with per-sequence KL values.
            ///
            /// # Errors
            /// `RloxError::ShapeMismatch` on length mismatch, zero `seq_len`,
            /// or a length not divisible by `seq_len`.
            pub fn compute_batch_token_kl(
                log_probs_policy: &[$float],
                log_probs_ref: &[$float],
                seq_len: usize,
            ) -> Result<Vec<$float>, RloxError> {
                // Same per-token term as `compute_token_kl`.
                batch_kl_with(log_probs_policy, log_probs_ref, seq_len, |log_p, log_q| {
                    log_p.exp() * (log_p - log_q)
                })
            }

            /// Batched token-level KL divergence using the Schulman (2020) estimator.
            ///
            /// `log_probs_policy` and `log_probs_ref` are flat slices of length `batch * seq_len`.
            /// Returns a Vec of length `batch` with per-sequence KL values.
            ///
            /// # Errors
            /// `RloxError::ShapeMismatch` on length mismatch, zero `seq_len`,
            /// or a length not divisible by `seq_len`.
            pub fn compute_batch_token_kl_schulman(
                log_probs_policy: &[$float],
                log_probs_ref: &[$float],
                seq_len: usize,
            ) -> Result<Vec<$float>, RloxError> {
                // Same per-token term as `compute_token_kl_schulman` (see the
                // sign-convention note on that function).
                batch_kl_with(log_probs_policy, log_probs_ref, seq_len, |log_p, log_q| {
                    let r = log_p - log_q;
                    r.exp() - r - 1.0 as $float
                })
            }
        }
    };
}
218
// Instantiate the op set once per supported precision; the generated
// modules (`f64_ops`, `f32_ops`) have identical APIs.
impl_kl_ops!(f64_ops, f64);
impl_kl_ops!(f32_ops, f32);

// Re-export f64 versions at module level for backward compatibility.
pub use f64_ops::*;
224
/// A DPO preference pair holding tokenized prompt, chosen, and rejected sequences.
///
/// `chosen_tokens` is the completion preferred over `rejected_tokens` for the
/// same `prompt_tokens` in Direct Preference Optimization training data.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
pub struct DPOPair {
    /// Tokenized prompt shared by both completions.
    pub prompt_tokens: Vec<u32>,
    /// Tokenized preferred completion.
    pub chosen_tokens: Vec<u32>,
    /// Tokenized dispreferred completion.
    pub rejected_tokens: Vec<u32>,
}

impl DPOPair {
    /// Builds a pair from already-tokenized sequences.
    pub fn new(
        prompt_tokens: Vec<u32>,
        chosen_tokens: Vec<u32>,
        rejected_tokens: Vec<u32>,
    ) -> Self {
        Self {
            prompt_tokens,
            chosen_tokens,
            rejected_tokens,
        }
    }

    /// Number of tokens in the chosen completion (prompt excluded).
    #[must_use]
    pub fn chosen_len(&self) -> usize {
        self.chosen_tokens.len()
    }

    /// Number of tokens in the rejected completion (prompt excluded).
    #[must_use]
    pub fn rejected_len(&self) -> usize {
        self.rejected_tokens.len()
    }
}
254
#[cfg(test)]
mod tests {
    // Exercises the f64 re-exports (via `pub use f64_ops::*`) plus a few
    // explicit `f32_ops` calls at the end.
    use super::*;

    #[test]
    fn test_group_advantages_basic() {
        let rewards = [1.0, 0.5, 0.8];
        let adv = compute_group_advantages(&rewards);
        assert_eq!(adv.len(), 3);
        // z-scores of any non-degenerate group must average to ~zero.
        let mean: f64 = adv.iter().sum::<f64>() / adv.len() as f64;
        assert!(mean.abs() < 1e-10);
    }

    #[test]
    fn test_group_advantages_constant_rewards() {
        // std ~ 0 triggers the all-zeros fallback instead of dividing by it.
        let rewards = [5.0, 5.0, 5.0];
        let adv = compute_group_advantages(&rewards);
        assert!(adv.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_group_advantages_empty() {
        let adv = compute_group_advantages(&[]);
        assert!(adv.is_empty());
    }

    #[test]
    fn test_token_kl_identical() {
        // KL of a distribution against itself is exactly zero term-by-term.
        let log_p = [-1.0, -2.0, -0.5];
        let kl = compute_token_kl(&log_p, &log_p).unwrap();
        assert!(kl.abs() < 1e-15);
    }

    #[test]
    fn test_token_kl_known_value() {
        // Single term: exp(-1) * ((-1) - (-2)) = e^-1.
        let log_p = [-1.0];
        let log_q = [-2.0];
        let kl = compute_token_kl(&log_p, &log_q).unwrap();
        assert!((kl - (-1.0_f64).exp()).abs() < 1e-10);
    }

    #[test]
    fn test_token_kl_mismatched_lengths_returns_err() {
        let result = compute_token_kl(&[1.0, 2.0], &[1.0]);
        assert!(result.is_err());
    }

    // NOTE(review): largely duplicates the test above; kept so existing
    // test-name references stay valid.
    #[test]
    fn token_kl_mismatched_lengths_returns_err_not_panic() {
        let log_p = vec![-1.0f64, -2.0];
        let log_q = vec![-1.0f64];
        let result = compute_token_kl(&log_p, &log_q);
        assert!(result.is_err(), "mismatched lengths must return Err");
    }

    #[test]
    fn token_kl_matching_lengths_returns_ok() {
        let log_p = vec![-1.0f64, -2.0, -0.5];
        let log_q = vec![-1.0f64, -2.0, -0.5];
        let result = compute_token_kl(&log_p, &log_q);
        assert!(result.is_ok());
        assert!(result.unwrap().abs() < 1e-15);
    }

    #[test]
    fn token_kl_empty_slices_returns_zero() {
        // Empty sum is 0.0, not an error.
        let result = compute_token_kl(&[], &[]);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 0.0);
    }

    #[test]
    fn token_kl_nan_input_propagates_to_output() {
        let log_p = vec![f64::NAN];
        let log_q = vec![-1.0f64];
        let result = compute_token_kl(&log_p, &log_q);
        if let Ok(v) = result {
            assert!(v.is_nan(), "NaN input should produce NaN output");
        }
    }

    #[test]
    fn token_kl_inf_input_does_not_panic() {
        // exp(inf) = inf; only requirement is no panic.
        let log_p = vec![f64::INFINITY];
        let log_q = vec![-1.0f64];
        let _result = compute_token_kl(&log_p, &log_q);
    }

    // NOTE(review): duplicates test_token_kl_known_value with Vec inputs.
    #[test]
    fn token_kl_known_value_still_correct_after_refactor() {
        let log_p = vec![-1.0f64];
        let log_q = vec![-2.0f64];
        let kl = compute_token_kl(&log_p, &log_q).unwrap();
        assert!((kl - (-1.0_f64).exp()).abs() < 1e-10);
    }

    #[test]
    fn test_batch_group_advantages() {
        // Two groups of 3: the first normalises to mean-zero z-scores,
        // the second (constant rewards) hits the all-zeros fallback.
        let rewards = [1.0, 2.0, 3.0, 10.0, 10.0, 10.0];
        let adv = compute_batch_group_advantages(&rewards, 3).unwrap();
        assert_eq!(adv.len(), 6);
        let g1_mean: f64 = adv[..3].iter().sum::<f64>() / 3.0;
        assert!(g1_mean.abs() < 1e-10);
        assert!(adv[3..6].iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_batch_group_advantages_bad_size() {
        // Non-divisible length and zero group_size both error.
        assert!(compute_batch_group_advantages(&[1.0, 2.0, 3.0], 2).is_err());
        assert!(compute_batch_group_advantages(&[1.0], 0).is_err());
    }

    #[test]
    fn test_token_kl_schulman_identical() {
        // r = 0 for every token, and exp(0) - 0 - 1 = 0.
        let log_p = [-1.0, -2.0, -0.5];
        let kl = compute_token_kl_schulman(&log_p, &log_p).unwrap();
        assert!(kl.abs() < 1e-15);
    }

    #[test]
    fn test_token_kl_schulman_known_value() {
        // r = log_p - log_q = 1, so the single term is e^1 - 1 - 1 = e - 2.
        let log_p = [-1.0];
        let log_q = [-2.0];
        let kl = compute_token_kl_schulman(&log_p, &log_q).unwrap();
        assert!((kl - (1.0_f64.exp() - 2.0)).abs() < 1e-10);
    }

    #[test]
    fn test_token_kl_schulman_non_negative() {
        // exp(r) - r - 1 >= 0 for all real r, so the sum is non-negative.
        let log_p = [-0.5, -1.0, -3.0, 0.0];
        let log_q = [-1.0, -0.5, -0.1, -2.0];
        let kl = compute_token_kl_schulman(&log_p, &log_q).unwrap();
        assert!(kl >= 0.0, "Schulman KL should be non-negative, got {kl}");
    }

    #[test]
    fn test_dpo_pair() {
        let pair = DPOPair::new(vec![1, 2, 3], vec![4, 5], vec![6, 7, 8]);
        assert_eq!(pair.chosen_len(), 2);
        assert_eq!(pair.rejected_len(), 3);
        assert_eq!(pair.prompt_tokens.len(), 3);
    }

    // --- Batched KL tests ---

    #[test]
    fn test_batch_token_kl_matches_unbatched() {
        // batch=2, seq_len=3: batched output must equal per-sequence calls.
        let log_p = vec![-1.0, -2.0, -0.5, -1.5, -0.3, -2.5];
        let log_q = vec![-1.1, -1.9, -0.6, -1.4, -0.4, -2.4];
        let batched = compute_batch_token_kl(&log_p, &log_q, 3).unwrap();
        let kl0 = compute_token_kl(&log_p[..3], &log_q[..3]).unwrap();
        let kl1 = compute_token_kl(&log_p[3..], &log_q[3..]).unwrap();
        assert_eq!(batched.len(), 2);
        assert!((batched[0] - kl0).abs() < 1e-12);
        assert!((batched[1] - kl1).abs() < 1e-12);
    }

    #[test]
    fn test_batch_token_kl_schulman_matches_unbatched() {
        let log_p = vec![-1.0, -2.0, -0.5, -1.5, -0.3, -2.5];
        let log_q = vec![-1.1, -1.9, -0.6, -1.4, -0.4, -2.4];
        let batched = compute_batch_token_kl_schulman(&log_p, &log_q, 3).unwrap();
        let kl0 = compute_token_kl_schulman(&log_p[..3], &log_q[..3]).unwrap();
        let kl1 = compute_token_kl_schulman(&log_p[3..], &log_q[3..]).unwrap();
        assert_eq!(batched.len(), 2);
        assert!((batched[0] - kl0).abs() < 1e-12);
        assert!((batched[1] - kl1).abs() < 1e-12);
    }

    #[test]
    fn test_batch_token_kl_bad_seq_len() {
        // seq_len of 0 and a non-divisible seq_len both error.
        assert!(compute_batch_token_kl(&[1.0, 2.0, 3.0], &[1.0, 2.0, 3.0], 0).is_err());
        assert!(compute_batch_token_kl(&[1.0, 2.0, 3.0], &[1.0, 2.0, 3.0], 2).is_err());
    }

    #[test]
    fn test_batch_token_kl_mismatched_lengths() {
        assert!(compute_batch_token_kl(&[1.0, 2.0], &[1.0], 1).is_err());
    }

    // --- f32 variant tests ---

    #[test]
    fn test_f32_token_kl_identical() {
        // Looser tolerance (1e-6) for single precision.
        let log_p: Vec<f32> = vec![-1.0, -2.0, -0.5];
        let kl = f32_ops::compute_token_kl(&log_p, &log_p).unwrap();
        assert!(kl.abs() < 1e-6);
    }

    #[test]
    fn test_f32_batch_token_kl_schulman() {
        let log_p: Vec<f32> = vec![-1.0, -2.0, -0.5, -1.5];
        let log_q: Vec<f32> = vec![-1.1, -1.9, -0.6, -1.4];
        let batched = f32_ops::compute_batch_token_kl_schulman(&log_p, &log_q, 2).unwrap();
        assert_eq!(batched.len(), 2);
        let kl0 = f32_ops::compute_token_kl_schulman(&log_p[..2], &log_q[..2]).unwrap();
        assert!((batched[0] - kl0).abs() < 1e-6);
    }

    #[test]
    fn test_f32_group_advantages() {
        let rewards: Vec<f32> = vec![1.0, 2.0, 3.0];
        let adv = f32_ops::compute_group_advantages(&rewards);
        assert_eq!(adv.len(), 3);
        let mean: f32 = adv.iter().sum::<f32>() / 3.0;
        assert!(mean.abs() < 1e-5);
    }
}