rlox_core/training/
kl.rs

/// Adaptive KL penalty controller (Ziegler et al. 2019).
///
/// Adjusts a coefficient based on measured vs. target KL divergence:
/// - If measured KL > 1.5 * target: multiply the coefficient by 2
/// - If measured KL < target / 1.5: divide the coefficient by 2
/// - Otherwise: leave the coefficient unchanged
///
/// A positive coefficient is kept inside `[f64::MIN_POSITIVE, f64::MAX]`, so an
/// arbitrarily long run of one-sided updates can neither underflow it to `0.0`
/// nor overflow it to infinity. Both of those values are fixed points of
/// doubling/halving and would leave the controller permanently stuck.
#[derive(Debug, Clone)]
pub struct KLController {
    /// Current penalty coefficient, rescaled by [`KLController::update`].
    coefficient: f64,
    /// KL divergence the controller steers towards.
    target_kl: f64,
}

impl KLController {
    /// Creates a controller that starts at `init_coeff` and steers towards `target_kl`.
    ///
    /// Neither argument is validated; both are stored as given.
    pub fn new(init_coeff: f64, target_kl: f64) -> Self {
        Self {
            coefficient: init_coeff,
            target_kl,
        }
    }

    /// Returns the current penalty coefficient.
    pub fn coefficient(&self) -> f64 {
        self.coefficient
    }

    /// Adapts the coefficient to one KL measurement.
    ///
    /// The coefficient is doubled when `measured_kl` is strictly above
    /// `1.5 * target_kl`, halved when it is strictly below `target_kl / 1.5`,
    /// and left untouched otherwise. A NaN measurement fails both comparisons
    /// and therefore leaves the coefficient untouched as well.
    pub fn update(&mut self, measured_kl: f64) {
        if measured_kl > 1.5 * self.target_kl {
            self.rescale(2.0);
        } else if measured_kl < self.target_kl / 1.5 {
            // Multiplying by 0.5 yields the same correctly rounded value as dividing by 2.
            self.rescale(0.5);
        }
    }

    /// Multiplies the coefficient by `factor`, saturating positive coefficients.
    ///
    /// Only a strictly positive coefficient is clamped. Zero, negative and NaN
    /// coefficients go through the plain multiplication so their behaviour is
    /// not altered by the saturation.
    fn rescale(&mut self, factor: f64) {
        let scaled = self.coefficient * factor;
        self.coefficient = if self.coefficient > 0.0 {
            scaled.clamp(f64::MIN_POSITIVE, f64::MAX)
        } else {
            scaled
        };
    }
}
32
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh controller reports the coefficient it was constructed with.
    #[test]
    fn kl_controller_initial_coefficient() {
        let kl = KLController::new(0.01, 0.02);
        assert_eq!(kl.coefficient(), 0.01);
    }

    /// A measurement above 1.5x the target raises the coefficient.
    #[test]
    fn kl_controller_increases_on_high_kl() {
        let mut kl = KLController::new(0.01, 0.02);
        kl.update(0.05); // measured KL >> target
        assert!(kl.coefficient() > 0.01);
    }

    /// A measurement below target / 1.5 lowers the coefficient.
    #[test]
    fn kl_controller_decreases_on_low_kl() {
        let mut kl = KLController::new(0.01, 0.02);
        kl.update(0.005); // measured KL << target
        assert!(kl.coefficient() < 0.01);
    }

    /// A measurement equal to the target sits inside the dead zone.
    #[test]
    fn kl_controller_stays_near_target() {
        let mut kl = KLController::new(0.01, 0.02);
        kl.update(0.02); // measured KL == target
        // Inside the dead zone the coefficient is not touched at all, so it
        // must compare exactly equal rather than merely fall within a tolerance.
        assert_eq!(kl.coefficient(), 0.01);
    }

    /// Both thresholds are strict, so a measurement exactly on one changes nothing.
    #[test]
    fn kl_controller_dead_zone_bounds_are_exclusive() {
        let target = 0.02;
        let mut kl = KLController::new(0.01, target);
        kl.update(1.5 * target);
        assert_eq!(kl.coefficient(), 0.01);
        kl.update(target / 1.5);
        assert_eq!(kl.coefficient(), 0.01);
    }

    /// Repeated low-KL updates keep the coefficient strictly positive.
    #[test]
    fn kl_controller_has_floor() {
        let mut kl = KLController::new(0.01, 0.02);
        // 100 halvings shrink 0.01 to roughly 8e-33, which is tiny but still
        // far above zero; this does not exercise f64 underflow.
        for _ in 0..100 {
            kl.update(0.0001); // very low KL
        }
        assert!(kl.coefficient() > 0.0); // never goes to zero
    }
}
73}