/// Generates `pub mod $mod_name` containing GRPO-style group advantages and
/// per-token KL helpers specialised to the float type `$float`.
///
/// None of the generated functions mutate their inputs. The batched variants
/// switch to rayon once the input holds at least `PAR_ELEMENT_THRESHOLD` (4096)
/// elements; results are returned in input order on either path.
///
/// The expansion names `crate::error::RloxError` and the `rayon` crate, so it
/// only compiles when invoked inside this crate.
macro_rules! impl_kl_ops {
    ($mod_name:ident, $float:ty) => {
        pub mod $mod_name {
            use crate::error::RloxError;

            /// Inputs with at least this many elements are processed with rayon;
            /// below it, scheduling overhead outweighs the per-element work.
            const PAR_ELEMENT_THRESHOLD: usize = 4096;

            /// Returns `ShapeMismatch` unless both slices have the same length.
            fn check_same_len(
                log_probs_policy: &[$float],
                log_probs_ref: &[$float],
            ) -> Result<(), RloxError> {
                if log_probs_policy.len() != log_probs_ref.len() {
                    return Err(RloxError::ShapeMismatch {
                        expected: format!("len={}", log_probs_policy.len()),
                        got: format!("len={}", log_probs_ref.len()),
                    });
                }
                Ok(())
            }

            /// Returns `ShapeMismatch` unless `chunk > 0` and `len` is a multiple of
            /// `chunk`. `chunk_name` only appears in the error text.
            fn check_chunking(
                len: usize,
                chunk: usize,
                chunk_name: &str,
            ) -> Result<(), RloxError> {
                if chunk == 0 {
                    return Err(RloxError::ShapeMismatch {
                        expected: format!("{chunk_name} > 0"),
                        got: "0".to_string(),
                    });
                }
                if len % chunk != 0 {
                    return Err(RloxError::ShapeMismatch {
                        expected: format!("len divisible by {chunk}"),
                        got: format!("len={len}"),
                    });
                }
                Ok(())
            }

            /// Per-position term `exp(log_p) * (log_p - log_q)`.
            #[inline]
            fn token_kl_term(log_p: $float, log_q: $float) -> $float {
                log_p.exp() * (log_p - log_q)
            }

            /// Per-position Schulman ("k3") term `exp(r) - r - 1` with
            /// `r = log_p - log_q`; mathematically non-negative since `exp(r) >= 1 + r`.
            #[inline]
            fn schulman_kl_term(log_p: $float, log_q: $float) -> $float {
                let r = log_p - log_q;
                r.exp() - r - 1.0 as $float
            }

            /// Shared driver for the batched KL functions.
            ///
            /// Validates shapes the same way as the public batch functions, then sums
            /// `term` over each consecutive `seq_len`-long window, one output per window.
            fn batch_kl_with<F>(
                log_probs_policy: &[$float],
                log_probs_ref: &[$float],
                seq_len: usize,
                term: F,
            ) -> Result<Vec<$float>, RloxError>
            where
                F: Fn($float, $float) -> $float + Sync + Send,
            {
                check_same_len(log_probs_policy, log_probs_ref)?;
                check_chunking(log_probs_policy.len(), seq_len, "seq_len")?;

                let seq_kl = |(ps, qs): (&[$float], &[$float])| -> $float {
                    ps.iter().zip(qs).map(|(&p, &q)| term(p, q)).sum()
                };

                let out: Vec<$float> = if log_probs_policy.len() >= PAR_ELEMENT_THRESHOLD {
                    use rayon::prelude::*;
                    log_probs_policy
                        .par_chunks_exact(seq_len)
                        .zip(log_probs_ref.par_chunks_exact(seq_len))
                        .map(seq_kl)
                        .collect()
                } else {
                    log_probs_policy
                        .chunks_exact(seq_len)
                        .zip(log_probs_ref.chunks_exact(seq_len))
                        .map(seq_kl)
                        .collect()
                };
                Ok(out)
            }

            /// Normalises one group of rewards to `(r - mean) / std`.
            ///
            /// `std` is the population standard deviation (divides by `n`). Returns an
            /// empty vector for empty input and all zeros when `std < 1e-8`, which
            /// includes every group of identical finite rewards. Non-finite rewards
            /// propagate as NaN.
            pub fn compute_group_advantages(rewards: &[$float]) -> Vec<$float> {
                if rewards.is_empty() {
                    return Vec::new();
                }

                // Centre on the first reward before accumulating ("shifted data"). For a
                // group of identical finite rewards every shifted value is exactly 0, so
                // the degenerate-group branch below fires reliably. Without the shift,
                // rounding in `sum / n` can leave each `r - mean` at one ulp, e.g. three
                // f32 rewards of 1000.1 would all receive advantage 1.0 instead of 0.0.
                let shift = rewards[0];
                let n = rewards.len() as $float;
                let mean_shifted = rewards.iter().map(|&r| r - shift).sum::<$float>() / n;
                let centered = |r: $float| (r - shift) - mean_shifted;

                let variance = rewards
                    .iter()
                    .map(|&r| {
                        let d = centered(r);
                        d * d
                    })
                    .sum::<$float>()
                    / n;
                let std = variance.sqrt();

                if std < 1e-8 as $float {
                    return vec![0.0 as $float; rewards.len()];
                }

                let inv_std = 1.0 as $float / std;
                rewards.iter().map(|&r| centered(r) * inv_std).collect()
            }

            /// Sums `exp(log_p) * (log_p - log_q)` over aligned positions.
            ///
            /// Each term is `p * log(p / q)` for the listed token, with `p = exp(log_p)`.
            /// Empty input yields `0`.
            ///
            /// # Errors
            /// `RloxError::ShapeMismatch` if the slices differ in length.
            pub fn compute_token_kl(
                log_probs_policy: &[$float],
                log_probs_ref: &[$float],
            ) -> Result<$float, RloxError> {
                check_same_len(log_probs_policy, log_probs_ref)?;
                Ok(log_probs_policy
                    .iter()
                    .zip(log_probs_ref)
                    .map(|(&log_p, &log_q)| token_kl_term(log_p, log_q))
                    .sum())
            }

            /// Applies [`compute_group_advantages`] to each consecutive
            /// `group_size`-long chunk of `rewards`, concatenating the results in order.
            ///
            /// # Errors
            /// `RloxError::ShapeMismatch` if `group_size == 0` or `rewards.len()` is not
            /// a multiple of `group_size`.
            pub fn compute_batch_group_advantages(
                rewards: &[$float],
                group_size: usize,
            ) -> Result<Vec<$float>, RloxError> {
                check_chunking(rewards.len(), group_size, "group_size")?;

                if rewards.len() >= PAR_ELEMENT_THRESHOLD {
                    use rayon::prelude::*;
                    Ok(rewards
                        .par_chunks_exact(group_size)
                        .flat_map_iter(compute_group_advantages)
                        .collect())
                } else {
                    let mut out = Vec::with_capacity(rewards.len());
                    for group in rewards.chunks_exact(group_size) {
                        out.extend(compute_group_advantages(group));
                    }
                    Ok(out)
                }
            }

            /// Sums the Schulman term `exp(r) - r - 1`, `r = log_p - log_q`, over
            /// aligned positions. Empty input yields `0`.
            ///
            /// NOTE(review): for samples drawn from the policy, the k3 estimator of
            /// KL(policy || ref) uses `r = log_ref - log_policy` (the GRPO form
            /// `ref/policy - log(ref/policy) - 1`). This function uses the opposite
            /// sign, and the tests pin that. Confirm the intended direction.
            ///
            /// # Errors
            /// `RloxError::ShapeMismatch` if the slices differ in length.
            pub fn compute_token_kl_schulman(
                log_probs_policy: &[$float],
                log_probs_ref: &[$float],
            ) -> Result<$float, RloxError> {
                check_same_len(log_probs_policy, log_probs_ref)?;
                Ok(log_probs_policy
                    .iter()
                    .zip(log_probs_ref)
                    .map(|(&log_p, &log_q)| schulman_kl_term(log_p, log_q))
                    .sum())
            }

            /// Batched [`compute_token_kl`]: one sum per consecutive `seq_len`-long
            /// window of the flattened inputs.
            ///
            /// # Errors
            /// `RloxError::ShapeMismatch` if the slices differ in length, `seq_len == 0`,
            /// or the length is not a multiple of `seq_len`.
            pub fn compute_batch_token_kl(
                log_probs_policy: &[$float],
                log_probs_ref: &[$float],
                seq_len: usize,
            ) -> Result<Vec<$float>, RloxError> {
                batch_kl_with(log_probs_policy, log_probs_ref, seq_len, token_kl_term)
            }

            /// Batched [`compute_token_kl_schulman`]: one sum per consecutive
            /// `seq_len`-long window of the flattened inputs.
            ///
            /// # Errors
            /// `RloxError::ShapeMismatch` if the slices differ in length, `seq_len == 0`,
            /// or the length is not a multiple of `seq_len`.
            pub fn compute_batch_token_kl_schulman(
                log_probs_policy: &[$float],
                log_probs_ref: &[$float],
                seq_len: usize,
            ) -> Result<Vec<$float>, RloxError> {
                batch_kl_with(log_probs_policy, log_probs_ref, seq_len, schulman_kl_term)
            }
        }
    };
}
218
219impl_kl_ops!(f64_ops, f64);
220impl_kl_ops!(f32_ops, f32);
221
222pub use f64_ops::*;
224
/// One preference example: a prompt plus a chosen and a rejected continuation,
/// each stored as a vector of `u32` token ids.
///
/// Derives `PartialEq`/`Eq` so pairs can be compared directly, e.g. in tests or
/// when deduplicating a dataset.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DPOPair {
    /// Tokens of the shared prompt.
    pub prompt_tokens: Vec<u32>,
    /// Tokens of the chosen continuation.
    pub chosen_tokens: Vec<u32>,
    /// Tokens of the rejected continuation.
    pub rejected_tokens: Vec<u32>,
}
232
233impl DPOPair {
234 pub fn new(
235 prompt_tokens: Vec<u32>,
236 chosen_tokens: Vec<u32>,
237 rejected_tokens: Vec<u32>,
238 ) -> Self {
239 Self {
240 prompt_tokens,
241 chosen_tokens,
242 rejected_tokens,
243 }
244 }
245
246 pub fn chosen_len(&self) -> usize {
247 self.chosen_tokens.len()
248 }
249
250 pub fn rejected_len(&self) -> usize {
251 self.rejected_tokens.len()
252 }
253}
254
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic, non-trivial log-prob streams of length `n` for batch tests.
    fn synthetic_log_probs(n: usize) -> (Vec<f64>, Vec<f64>) {
        let log_p = (0..n).map(|i| -0.05 - 0.1 * (i % 7) as f64).collect();
        let log_q = (0..n).map(|i| -0.02 - 0.13 * (i % 5) as f64).collect();
        (log_p, log_q)
    }

    #[test]
    fn test_group_advantages_basic() {
        let rewards = [1.0, 0.5, 0.8];
        let adv = compute_group_advantages(&rewards);
        assert_eq!(adv.len(), 3);
        let mean: f64 = adv.iter().sum::<f64>() / adv.len() as f64;
        assert!(mean.abs() < 1e-10);
    }

    #[test]
    fn test_group_advantages_constant_rewards() {
        let rewards = [5.0, 5.0, 5.0];
        let adv = compute_group_advantages(&rewards);
        assert!(adv.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_group_advantages_single_element() {
        assert_eq!(compute_group_advantages(&[3.5]), vec![0.0]);
    }

    #[test]
    fn test_group_advantages_empty() {
        let adv = compute_group_advantages(&[]);
        assert!(adv.is_empty());
    }

    #[test]
    fn test_token_kl_identical() {
        let log_p = [-1.0, -2.0, -0.5];
        let kl = compute_token_kl(&log_p, &log_p).unwrap();
        assert!(kl.abs() < 1e-15);
    }

    #[test]
    fn test_token_kl_known_value() {
        let log_p = [-1.0];
        let log_q = [-2.0];
        let kl = compute_token_kl(&log_p, &log_q).unwrap();
        assert!((kl - (-1.0_f64).exp()).abs() < 1e-10);
    }

    #[test]
    fn test_token_kl_mismatched_lengths_returns_err() {
        let result = compute_token_kl(&[1.0, 2.0], &[1.0]);
        assert!(result.is_err());
    }

    #[test]
    fn token_kl_mismatched_lengths_returns_err_not_panic() {
        let log_p = vec![-1.0f64, -2.0];
        let log_q = vec![-1.0f64];
        let result = compute_token_kl(&log_p, &log_q);
        assert!(result.is_err(), "mismatched lengths must return Err");
    }

    #[test]
    fn token_kl_matching_lengths_returns_ok() {
        let log_p = vec![-1.0f64, -2.0, -0.5];
        let log_q = vec![-1.0f64, -2.0, -0.5];
        let result = compute_token_kl(&log_p, &log_q);
        assert!(result.is_ok());
        assert!(result.unwrap().abs() < 1e-15);
    }

    #[test]
    fn token_kl_empty_slices_returns_zero() {
        let result = compute_token_kl(&[], &[]);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 0.0);
    }

    #[test]
    fn token_kl_nan_input_propagates_to_output() {
        let log_p = vec![f64::NAN];
        let log_q = vec![-1.0f64];
        // Lengths match, so this must be Ok; an Err would otherwise let the NaN
        // check be skipped silently.
        let v = compute_token_kl(&log_p, &log_q).expect("equal lengths must return Ok");
        assert!(v.is_nan(), "NaN input should produce NaN output");
    }

    #[test]
    fn token_kl_inf_input_does_not_panic() {
        let log_p = vec![f64::INFINITY];
        let log_q = vec![-1.0f64];
        let _result = compute_token_kl(&log_p, &log_q);
    }

    #[test]
    fn token_kl_known_value_still_correct_after_refactor() {
        let log_p = vec![-1.0f64];
        let log_q = vec![-2.0f64];
        let kl = compute_token_kl(&log_p, &log_q).unwrap();
        assert!((kl - (-1.0_f64).exp()).abs() < 1e-10);
    }

    #[test]
    fn test_batch_group_advantages() {
        let rewards = [1.0, 2.0, 3.0, 10.0, 10.0, 10.0];
        let adv = compute_batch_group_advantages(&rewards, 3).unwrap();
        assert_eq!(adv.len(), 6);
        let g1_mean: f64 = adv[..3].iter().sum::<f64>() / 3.0;
        assert!(g1_mean.abs() < 1e-10);
        assert!(adv[3..6].iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_batch_group_advantages_bad_size() {
        assert!(compute_batch_group_advantages(&[1.0, 2.0, 3.0], 2).is_err());
        assert!(compute_batch_group_advantages(&[1.0], 0).is_err());
    }

    #[test]
    fn test_batch_group_advantages_parallel_path_matches_per_group() {
        // 4096 elements reaches the rayon branch.
        let group_size = 8;
        let rewards: Vec<f64> = (0..4096).map(|i| ((i * 37) % 11) as f64 * 0.25).collect();
        let batched = compute_batch_group_advantages(&rewards, group_size).unwrap();
        let expected: Vec<f64> = rewards
            .chunks_exact(group_size)
            .flat_map(compute_group_advantages)
            .collect();
        assert_eq!(batched, expected);
    }

    #[test]
    fn test_batch_empty_inputs_return_empty() {
        let empty: [f64; 0] = [];
        assert!(compute_batch_token_kl(&empty, &empty, 4).unwrap().is_empty());
        assert!(compute_batch_token_kl_schulman(&empty, &empty, 4)
            .unwrap()
            .is_empty());
        assert!(compute_batch_group_advantages(&empty, 4).unwrap().is_empty());
    }

    #[test]
    fn test_token_kl_schulman_identical() {
        let log_p = [-1.0, -2.0, -0.5];
        let kl = compute_token_kl_schulman(&log_p, &log_p).unwrap();
        assert!(kl.abs() < 1e-15);
    }

    #[test]
    fn test_token_kl_schulman_known_value() {
        let log_p = [-1.0];
        let log_q = [-2.0];
        let kl = compute_token_kl_schulman(&log_p, &log_q).unwrap();
        assert!((kl - (1.0_f64.exp() - 2.0)).abs() < 1e-10);
    }

    #[test]
    fn test_token_kl_schulman_non_negative() {
        let log_p = [-0.5, -1.0, -3.0, 0.0];
        let log_q = [-1.0, -0.5, -0.1, -2.0];
        let kl = compute_token_kl_schulman(&log_p, &log_q).unwrap();
        assert!(kl >= 0.0, "Schulman KL should be non-negative, got {kl}");
    }

    #[test]
    fn test_dpo_pair() {
        let pair = DPOPair::new(vec![1, 2, 3], vec![4, 5], vec![6, 7, 8]);
        assert_eq!(pair.chosen_len(), 2);
        assert_eq!(pair.rejected_len(), 3);
        assert_eq!(pair.prompt_tokens.len(), 3);
    }

    #[test]
    fn test_batch_token_kl_matches_unbatched() {
        let log_p = vec![-1.0, -2.0, -0.5, -1.5, -0.3, -2.5];
        let log_q = vec![-1.1, -1.9, -0.6, -1.4, -0.4, -2.4];
        let batched = compute_batch_token_kl(&log_p, &log_q, 3).unwrap();
        let kl0 = compute_token_kl(&log_p[..3], &log_q[..3]).unwrap();
        let kl1 = compute_token_kl(&log_p[3..], &log_q[3..]).unwrap();
        assert_eq!(batched.len(), 2);
        assert!((batched[0] - kl0).abs() < 1e-12);
        assert!((batched[1] - kl1).abs() < 1e-12);
    }

    #[test]
    fn test_batch_token_kl_schulman_matches_unbatched() {
        let log_p = vec![-1.0, -2.0, -0.5, -1.5, -0.3, -2.5];
        let log_q = vec![-1.1, -1.9, -0.6, -1.4, -0.4, -2.4];
        let batched = compute_batch_token_kl_schulman(&log_p, &log_q, 3).unwrap();
        let kl0 = compute_token_kl_schulman(&log_p[..3], &log_q[..3]).unwrap();
        let kl1 = compute_token_kl_schulman(&log_p[3..], &log_q[3..]).unwrap();
        assert_eq!(batched.len(), 2);
        assert!((batched[0] - kl0).abs() < 1e-12);
        assert!((batched[1] - kl1).abs() < 1e-12);
    }

    #[test]
    fn test_batch_token_kl_parallel_path_matches_unbatched() {
        // 4096 elements reaches the rayon branch.
        let seq_len = 16;
        let (log_p, log_q) = synthetic_log_probs(4096);
        let batched = compute_batch_token_kl(&log_p, &log_q, seq_len).unwrap();
        assert_eq!(batched.len(), 4096 / seq_len);
        for (i, (ps, qs)) in log_p
            .chunks_exact(seq_len)
            .zip(log_q.chunks_exact(seq_len))
            .enumerate()
        {
            let expected = compute_token_kl(ps, qs).unwrap();
            assert!((batched[i] - expected).abs() < 1e-12, "window {i}");
        }
    }

    #[test]
    fn test_batch_token_kl_schulman_parallel_path_matches_unbatched() {
        let seq_len = 32;
        let (log_p, log_q) = synthetic_log_probs(4096);
        let batched = compute_batch_token_kl_schulman(&log_p, &log_q, seq_len).unwrap();
        assert_eq!(batched.len(), 4096 / seq_len);
        for (i, (ps, qs)) in log_p
            .chunks_exact(seq_len)
            .zip(log_q.chunks_exact(seq_len))
            .enumerate()
        {
            let expected = compute_token_kl_schulman(ps, qs).unwrap();
            assert!((batched[i] - expected).abs() < 1e-12, "window {i}");
        }
    }

    #[test]
    fn test_batch_token_kl_bad_seq_len() {
        assert!(compute_batch_token_kl(&[1.0, 2.0, 3.0], &[1.0, 2.0, 3.0], 0).is_err());
        assert!(compute_batch_token_kl(&[1.0, 2.0, 3.0], &[1.0, 2.0, 3.0], 2).is_err());
    }

    #[test]
    fn test_batch_token_kl_mismatched_lengths() {
        assert!(compute_batch_token_kl(&[1.0, 2.0], &[1.0], 1).is_err());
    }

    #[test]
    fn test_batch_token_kl_schulman_bad_shapes() {
        let xs = [1.0, 2.0, 3.0];
        assert!(compute_batch_token_kl_schulman(&xs, &xs, 0).is_err());
        assert!(compute_batch_token_kl_schulman(&xs, &xs, 2).is_err());
        assert!(compute_batch_token_kl_schulman(&[1.0, 2.0], &[1.0], 1).is_err());
    }

    #[test]
    fn test_f32_token_kl_identical() {
        let log_p: Vec<f32> = vec![-1.0, -2.0, -0.5];
        let kl = f32_ops::compute_token_kl(&log_p, &log_p).unwrap();
        assert!(kl.abs() < 1e-6);
    }

    #[test]
    fn test_f32_batch_token_kl_schulman() {
        let log_p: Vec<f32> = vec![-1.0, -2.0, -0.5, -1.5];
        let log_q: Vec<f32> = vec![-1.1, -1.9, -0.6, -1.4];
        let batched = f32_ops::compute_batch_token_kl_schulman(&log_p, &log_q, 2).unwrap();
        assert_eq!(batched.len(), 2);
        let kl0 = f32_ops::compute_token_kl_schulman(&log_p[..2], &log_q[..2]).unwrap();
        let kl1 = f32_ops::compute_token_kl_schulman(&log_p[2..], &log_q[2..]).unwrap();
        assert!((batched[0] - kl0).abs() < 1e-6);
        assert!((batched[1] - kl1).abs() < 1e-6);
    }

    #[test]
    fn test_f32_group_advantages() {
        let rewards: Vec<f32> = vec![1.0, 2.0, 3.0];
        let adv = f32_ops::compute_group_advantages(&rewards);
        assert_eq!(adv.len(), 3);
        let mean: f32 = adv.iter().sum::<f32>() / 3.0;
        assert!(mean.abs() < 1e-5);
    }
}
462}