// rlox_candle/entropy.rs

1use candle_core::{Device, Tensor};
2use candle_nn::{Optimizer, VarBuilder, VarMap};
3
4use rlox_nn::{EntropyTuner, NNError, TensorData};
5
6use crate::convert::*;
7
/// Learnable entropy-temperature ("alpha") tuner backed by the candle ML
/// framework. Stores alpha in log-space and optimizes it with AdamW.
pub struct CandleEntropyTuner {
    // Never read directly, but must stay alive: `log_alpha` was created
    // through this map's VarBuilder, and the optimizer steps the vars it
    // registered — hence the explicit dead_code allow.
    #[allow(dead_code)]
    varmap: VarMap,
    // log(alpha); optimizing in log-space guarantees alpha = exp(log_alpha) > 0.
    log_alpha: Tensor,
    // AdamW over all vars in `varmap` (i.e. `log_alpha`).
    optimizer: candle_nn::AdamW,
    // Device on which incoming log-prob data is materialized in `update`.
    device: Device,
}
15
16impl CandleEntropyTuner {
17    pub fn new(lr: f64, device: Device) -> Result<Self, NNError> {
18        let varmap = VarMap::new();
19        let vb = VarBuilder::from_varmap(&varmap, candle_core::DType::F32, &device);
20        let log_alpha = vb
21            .get_with_hints(1, "log_alpha", candle_nn::Init::Const(0.0))
22            .nn_err()?;
23
24        let params = varmap.all_vars();
25        let optimizer = candle_nn::AdamW::new(
26            params,
27            candle_nn::ParamsAdamW {
28                lr,
29                ..Default::default()
30            },
31        )
32        .nn_err()?;
33
34        Ok(Self {
35            varmap,
36            log_alpha,
37            optimizer,
38            device,
39        })
40    }
41}
42
43impl EntropyTuner for CandleEntropyTuner {
44    fn alpha(&self) -> f32 {
45        let val: Vec<f32> = self.log_alpha.exp().unwrap().to_vec1().unwrap();
46        val[0]
47    }
48
49    fn update(&mut self, log_probs: &TensorData, target_entropy: f32) -> Result<f64, NNError> {
50        let lp = to_tensor_1d(log_probs, &self.device).nn_err()?;
51        let alpha = self.log_alpha.exp().nn_err()?;
52        let batch_size = log_probs.data.len();
53        let alpha_broadcast = alpha.broadcast_as(batch_size).nn_err()?;
54        let alpha_loss = (&alpha_broadcast * &(&lp + target_entropy as f64).nn_err()?)
55            .nn_err()?
56            .mean_all()
57            .nn_err()?
58            .neg()
59            .nn_err()?;
60
61        self.optimizer.backward_step(&alpha_loss).nn_err()?;
62
63        let loss_val: f32 = alpha_loss.to_scalar().nn_err()?;
64        Ok(loss_val as f64)
65    }
66}
67
// SAFETY(review): these assert that every field (VarMap, Tensor, AdamW,
// Device) may be moved/shared across threads. candle's VarMap and Tensor
// are documented as Send + Sync, in which case these manual unsafe impls
// are redundant — TODO: confirm AdamW/Device also satisfy the bounds and,
// if so, delete both impls and rely on the compiler's auto traits.
unsafe impl Send for CandleEntropyTuner {}
unsafe impl Sync for CandleEntropyTuner {}
70
#[cfg(test)]
mod tests {
    use super::*;

    /// log_alpha is initialized to 0, so the starting temperature is exp(0) = 1.
    #[test]
    fn test_initial_alpha() {
        let alpha = CandleEntropyTuner::new(3e-4, Device::Cpu).unwrap().alpha();
        assert!((alpha - 1.0).abs() < 1e-4, "exp(0) = 1: got {alpha}");
    }

    /// A single update on a uniform batch of log-probs yields a finite loss.
    #[test]
    fn test_update_runs() {
        let mut tuner = CandleEntropyTuner::new(3e-4, Device::Cpu).unwrap();
        let batch = TensorData::new(vec![-1.0; 32], vec![32]);
        let loss = tuner.update(&batch, -1.0).unwrap();
        assert!(loss.is_finite());
    }

    /// A nonzero gradient signal must move alpha after one optimizer step.
    #[test]
    fn test_alpha_changes() {
        let mut tuner = CandleEntropyTuner::new(1e-2, Device::Cpu).unwrap();
        let before = tuner.alpha();
        let batch = TensorData::new(vec![-2.0; 32], vec![32]);
        tuner.update(&batch, -1.0).unwrap();
        let after = tuner.alpha();
        assert!((before - after).abs() > 1e-6, "{before} vs {after}");
    }
}