/// Adaptive controller that rescales a coefficient so a measured KL value tracks a target.
///
/// On every [`KLController::update`] the coefficient is doubled when the measured KL is
/// more than 1.5x the target, halved when it is less than the target divided by 1.5, and
/// left untouched inside that band.
// NOTE(review): presumably the coefficient is the KL-penalty weight of an RL objective —
// confirm against callers.
#[derive(Debug, Clone, PartialEq)]
pub struct KLController {
    /// Current value of the adaptive coefficient.
    coefficient: f64,
    /// KL value the controller tries to keep the measurements close to.
    target_kl: f64,
}

impl KLController {
    /// Width of the dead band around the target: measurements within
    /// `[target / TOLERANCE, target * TOLERANCE]` leave the coefficient unchanged.
    const TOLERANCE: f64 = 1.5;

    /// Multiplicative step applied to the coefficient when leaving the dead band.
    const STEP: f64 = 2.0;

    /// Smallest value a positive coefficient may shrink to. Without a floor, repeated
    /// halving underflows to exactly `0.0`, from which doubling can never recover.
    const MIN_COEFFICIENT: f64 = f64::MIN_POSITIVE;

    /// Largest value a positive coefficient may grow to, so that repeated doubling
    /// saturates instead of overflowing to infinity.
    const MAX_COEFFICIENT: f64 = f64::MAX;

    /// Creates a controller starting at `init_coeff` and steering toward `target_kl`.
    ///
    /// Neither argument is validated or modified here.
    pub fn new(init_coeff: f64, target_kl: f64) -> Self {
        Self {
            coefficient: init_coeff,
            target_kl,
        }
    }

    /// Returns the current coefficient.
    pub fn coefficient(&self) -> f64 {
        self.coefficient
    }

    /// Adjusts the coefficient according to the latest `measured_kl`.
    ///
    /// * `measured_kl > 1.5 * target_kl` doubles the coefficient.
    /// * `measured_kl < target_kl / 1.5` halves the coefficient.
    /// * Anything else (including a NaN measurement, for which both comparisons are
    ///   false) leaves it unchanged.
    ///
    /// A positive, finite coefficient is kept within
    /// `[f64::MIN_POSITIVE, f64::MAX]`, so the controller can neither collapse to zero
    /// nor blow up to infinity after a long run of one-sided measurements.
    pub fn update(&mut self, measured_kl: f64) {
        let current = self.coefficient;

        let scaled = if measured_kl > Self::TOLERANCE * self.target_kl {
            current * Self::STEP
        } else if measured_kl < self.target_kl / Self::TOLERANCE {
            current / Self::STEP
        } else {
            return;
        };

        // Only bound coefficients that are already in the sane (positive, finite) range.
        // A zero, negative or non-finite coefficient was chosen by the caller, so it is
        // scaled as-is rather than being silently pulled into the positive range.
        self.coefficient = if current > 0.0 && current.is_finite() {
            scaled.clamp(Self::MIN_COEFFICIENT, Self::MAX_COEFFICIENT)
        } else {
            scaled
        };
    }
}
32
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the fixture every test starts from: coefficient 0.01, target KL 0.02.
    fn fresh_controller() -> KLController {
        KLController::new(0.01, 0.02)
    }

    /// A freshly built controller reports the coefficient it was constructed with.
    #[test]
    fn kl_controller_initial_coefficient() {
        let controller = fresh_controller();
        assert_eq!(controller.coefficient(), 0.01);
    }

    /// A measurement of 0.05 against a 0.02 target raises the coefficient.
    #[test]
    fn kl_controller_increases_on_high_kl() {
        let mut controller = fresh_controller();
        controller.update(0.05);
        let after = controller.coefficient();
        assert!(after > 0.01);
    }

    /// A measurement of 0.005 against a 0.02 target lowers the coefficient.
    #[test]
    fn kl_controller_decreases_on_low_kl() {
        let mut controller = fresh_controller();
        controller.update(0.005);
        let after = controller.coefficient();
        assert!(after < 0.01);
    }

    /// A measurement equal to the target keeps the coefficient within 0.005 of its start.
    #[test]
    fn kl_controller_stays_near_target() {
        let mut controller = fresh_controller();
        controller.update(0.02);
        let drift = (controller.coefficient() - 0.01).abs();
        assert!(drift < 0.005);
    }

    /// After 100 consecutive low measurements the coefficient is still strictly positive.
    #[test]
    fn kl_controller_has_floor() {
        let mut controller = fresh_controller();
        (0..100).for_each(|_| controller.update(0.0001));
        assert!(controller.coefficient() > 0.0);
    }
}