1use candle_core::{Device, Tensor};
2use candle_nn::{Optimizer, VarBuilder, VarMap};
3
4use rlox_nn::{EntropyTuner, NNError, TensorData};
5
6use crate::convert::*;
7
/// Learnable entropy-temperature (alpha) tuner backed by candle.
///
/// Stores a single scalar `log_alpha` (so `alpha = exp(log_alpha)` is always
/// positive) and an AdamW optimizer that updates it via `EntropyTuner::update`.
pub struct CandleEntropyTuner {
    // Keeps the variable registry alive; the optimizer was built from its
    // vars, so dropping it early would orphan them. Not read after `new`.
    #[allow(dead_code)]
    varmap: VarMap,
    // alpha stored in log-space; `alpha()` applies `exp` to read it.
    log_alpha: Tensor,
    // AdamW over `varmap.all_vars()` (i.e. just `log_alpha`).
    optimizer: candle_nn::AdamW,
    // Device new tensors are created on (see `update`).
    device: Device,
}
15
16impl CandleEntropyTuner {
17 pub fn new(lr: f64, device: Device) -> Result<Self, NNError> {
18 let varmap = VarMap::new();
19 let vb = VarBuilder::from_varmap(&varmap, candle_core::DType::F32, &device);
20 let log_alpha = vb
21 .get_with_hints(1, "log_alpha", candle_nn::Init::Const(0.0))
22 .nn_err()?;
23
24 let params = varmap.all_vars();
25 let optimizer = candle_nn::AdamW::new(
26 params,
27 candle_nn::ParamsAdamW {
28 lr,
29 ..Default::default()
30 },
31 )
32 .nn_err()?;
33
34 Ok(Self {
35 varmap,
36 log_alpha,
37 optimizer,
38 device,
39 })
40 }
41}
42
43impl EntropyTuner for CandleEntropyTuner {
44 fn alpha(&self) -> f32 {
45 let val: Vec<f32> = self.log_alpha.exp().unwrap().to_vec1().unwrap();
46 val[0]
47 }
48
49 fn update(&mut self, log_probs: &TensorData, target_entropy: f32) -> Result<f64, NNError> {
50 let lp = to_tensor_1d(log_probs, &self.device).nn_err()?;
51 let alpha = self.log_alpha.exp().nn_err()?;
52 let batch_size = log_probs.data.len();
53 let alpha_broadcast = alpha.broadcast_as(batch_size).nn_err()?;
54 let alpha_loss = (&alpha_broadcast * &(&lp + target_entropy as f64).nn_err()?)
55 .nn_err()?
56 .mean_all()
57 .nn_err()?
58 .neg()
59 .nn_err()?;
60
61 self.optimizer.backward_step(&alpha_loss).nn_err()?;
62
63 let loss_val: f32 = alpha_loss.to_scalar().nn_err()?;
64 Ok(loss_val as f64)
65 }
66}
67
// SAFETY: NOTE(review) — these assert that all fields (VarMap, Tensor, AdamW,
// Device) are safe to move/share across threads. That soundness depends on the
// candle version in use and is not demonstrated in this file — confirm, and if
// the fields are already Send + Sync, these impls are redundant and should be
// removed.
unsafe impl Send for CandleEntropyTuner {}
unsafe impl Sync for CandleEntropyTuner {}
70
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh tuner has log_alpha = 0, so alpha must read back as exp(0) = 1.
    #[test]
    fn test_initial_alpha() {
        let tuner = CandleEntropyTuner::new(3e-4, Device::Cpu).unwrap();
        let a = tuner.alpha();
        assert!((a - 1.0).abs() < 1e-4, "exp(0) = 1: got {a}");
    }

    /// A single update on a small batch must succeed and yield a finite loss.
    #[test]
    fn test_update_runs() {
        let mut tuner = CandleEntropyTuner::new(3e-4, Device::Cpu).unwrap();
        let batch = TensorData::new(vec![-1.0; 32], vec![32]);
        assert!(tuner.update(&batch, -1.0).unwrap().is_finite());
    }

    /// With a non-zero gradient and a visible learning rate, one step must
    /// actually move alpha.
    #[test]
    fn test_alpha_changes() {
        let mut tuner = CandleEntropyTuner::new(1e-2, Device::Cpu).unwrap();
        let before = tuner.alpha();
        let batch = TensorData::new(vec![-2.0; 32], vec![32]);
        tuner.update(&batch, -1.0).unwrap();
        let after = tuner.alpha();
        assert!((before - after).abs() > 1e-6, "{before} vs {after}");
    }
}