rlox_candle/stochastic.rs

use std::path::Path;

use candle_core::{Device, Tensor};
use candle_nn::{linear, Linear, Module, Optimizer, VarBuilder, VarMap};
use rand::SeedableRng;
use rand_chacha::ChaCha8Rng;

use rlox_nn::{Activation, MLPConfig, NNError, StochasticPolicy, TensorData, TrainMetrics};

use crate::convert::*;
use crate::mlp::MLP;

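// Clamp bounds for the predicted log standard deviation, the values
// conventionally used in SAC implementations: the lower bound keeps
// exp(log_std) from underflowing to zero, the upper bound keeps the
// exploration noise from blowing up early in training.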
const LOG_STD_MIN: f32 = -20.0;
const LOG_STD_MAX: f32 = 2.0;

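/// Tanh-squashed Gaussian policy over a shared MLP trunk, as used by
/// SAC-style actors: the network predicts a per-dimension mean and log-std,
/// and actions are sampled as `tanh(mean + std * eps)`.
///
/// A minimal usage sketch (marked `ignore`; adjust imports to your crate
/// layout):
///
/// ```ignore
/// let policy = CandleStochasticPolicy::new(3, 2, 64, 3e-4, Device::Cpu, 42)?;
/// let obs = TensorData::zeros(vec![8, 3]);
/// let (actions, log_probs) = policy.sample_actions(&obs)?;
/// assert_eq!(actions.shape, vec![8, 2]);
/// ```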
pub struct CandleStochasticPolicy {
    shared: MLP,
    mean_head: Linear,
    log_std_head: Linear,
    varmap: VarMap,
    optimizer: candle_nn::AdamW,
    device: Device,
    act_dim: usize,
    lr: f64,
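    // Held for reproducible host-side sampling. Note that `Tensor::randn`
    // below draws from candle's own device RNG, so this seed does not
    // currently influence the sampled noise.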
    #[allow(dead_code)]
    rng: ChaCha8Rng,
}

impl CandleStochasticPolicy {
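    /// Builds the policy: a shared MLP trunk feeding separate mean and
    /// log-std heads, with all parameters registered in a single `VarMap`
    /// so one AdamW step updates trunk and heads together.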
    pub fn new(
        obs_dim: usize,
        act_dim: usize,
        hidden: usize,
        lr: f64,
        device: Device,
        seed: u64,
    ) -> Result<Self, NNError> {
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, candle_core::DType::F32, &device);

        let shared_config = MLPConfig::new(obs_dim, hidden)
            .with_hidden(vec![hidden])
            .with_activation(Activation::ReLU);

        let shared = MLP::new(&shared_config, vb.pp("shared")).nn_err()?;
        let mean_head = linear(hidden, act_dim, vb.pp("mean")).nn_err()?;
        let log_std_head = linear(hidden, act_dim, vb.pp("log_std")).nn_err()?;

        let params = varmap.all_vars();
        let optimizer = candle_nn::AdamW::new(
            params,
            candle_nn::ParamsAdamW {
                lr,
                ..Default::default()
            },
        )
        .nn_err()?;

        Ok(Self {
            shared,
            mean_head,
            log_std_head,
            varmap,
            optimizer,
            device,
            act_dim,
            lr,
            rng: ChaCha8Rng::seed_from_u64(seed),
        })
    }

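    /// Shared trunk plus the two heads; returns the Gaussian parameters
    /// `(mean, log_std)` of the pre-tanh action distribution, with the
    /// log-std clamped to `[LOG_STD_MIN, LOG_STD_MAX]`.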
    fn forward(&self, obs: &Tensor) -> candle_core::Result<(Tensor, Tensor)> {
        let h = self.shared.forward(obs)?;
        let mean = self.mean_head.forward(&h)?;
        let log_std = self
            .log_std_head
            .forward(&h)?
            .clamp(LOG_STD_MIN, LOG_STD_MAX)?;
        Ok((mean, log_std))
    }
}

impl CandleStochasticPolicy {
    /// SAC actor gradient step with autograd flowing through the critic.
    ///
    /// Takes the concrete `CandleTwinQ` to preserve gradient flow from the
    /// Q-values back to the actor parameters via the reparameterized actions.
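    ///
    /// Minimizes the reparameterized SAC objective
    /// `E[alpha * log pi(a|s) - min(Q1(s, a), Q2(s, a))]`,
    /// where `a = tanh(mean + std * eps)` and `eps ~ N(0, I)`.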
    pub fn sac_actor_step(
        &mut self,
        obs: &TensorData,
        alpha: f32,
        critic: &crate::continuous_q::CandleTwinQ,
    ) -> Result<TrainMetrics, NNError> {
        let batch_size = obs.shape[0];
        let obs_t = to_tensor_2d(obs, &self.device).nn_err()?;
        let (mean, log_std) = self.forward(&obs_t).nn_err()?;
        let std = log_std.exp().nn_err()?;

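        // Reparameterization trick: draw eps ~ N(0, I) and express the action
        // as a deterministic, differentiable function of (mean, std).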
        let eps = Tensor::randn(0.0_f32, 1.0, (batch_size, self.act_dim), &self.device).nn_err()?;
        let x_t = (&mean + &(&std * &eps).nn_err()?).nn_err()?;
        let y_t = x_t.tanh().nn_err()?;

        // Differentiable Gaussian log-prob. Writing the residual as
        // (x_t - mean) keeps the graph wired through both heads, even though
        // it equals std * eps analytically.
        let var = std.sqr().nn_err()?;
        let residual = (&x_t - &mean).nn_err()?;
        let normal_lp = (&residual.sqr().nn_err()? / &var)
            .nn_err()?
            .neg()
            .nn_err()?
            / 2.0;
        let normal_lp = normal_lp.nn_err()?;
        let log_std_term = std.log().nn_err()?.neg().nn_err()?;
        let const_term = -0.5 * (2.0 * std::f32::consts::PI).ln();
        let normal_lp_full =
            (&(&normal_lp + &log_std_term).nn_err()? + const_term as f64).nn_err()?;

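        // Change-of-variables correction for the tanh squash:
        // log pi(y) = log N(x; mean, std) - sum_j log(1 - tanh(x_j)^2 + 1e-6).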
        let y_sq = y_t.sqr().nn_err()?;
        let one = Tensor::ones_like(&y_sq).nn_err()?;
        let correction = ((&one - &y_sq).nn_err()? + 1e-6_f64)
            .nn_err()?
            .log()
            .nn_err()?
            .neg()
            .nn_err()?;
        let log_prob = (&normal_lp_full + &correction).nn_err()?.sum(1).nn_err()?;

        // Q-values with full autograd through the critic
        let (q1, q2) = critic.twin_q_forward(&obs_t, &y_t).nn_err()?;
        let q1 = q1.squeeze(1).nn_err()?;
        let q2 = q2.squeeze(1).nn_err()?;
        let q_min = q1.minimum(&q2).nn_err()?;

        let actor_loss = (&(&log_prob * alpha as f64).nn_err()? - &q_min)
            .nn_err()?
            .mean_all()
            .nn_err()?;

        self.optimizer.backward_step(&actor_loss).nn_err()?;

        let loss_val: f32 = actor_loss.to_scalar().nn_err()?;

        let mut metrics = TrainMetrics::new();
        metrics.insert("actor_loss", loss_val as f64);
        Ok(metrics)
    }
}

impl StochasticPolicy for CandleStochasticPolicy {
    fn sample_actions(&self, obs: &TensorData) -> Result<(TensorData, TensorData), NNError> {
        let batch_size = obs.shape[0];
        let obs_t = to_tensor_2d(obs, &self.device).nn_err()?;
        let (mean, log_std) = self.forward(&obs_t).nn_err()?;
        let std = log_std.exp().nn_err()?;

        let eps = Tensor::randn(0.0_f32, 1.0, (batch_size, self.act_dim), &self.device).nn_err()?;
        let x_t = (&mean + &(&std * &eps).nn_err()?).nn_err()?;
        let y_t = x_t.tanh().nn_err()?;

        let mean_data: Vec<f32> = mean.flatten_all().nn_err()?.to_vec1().nn_err()?;
        let std_data: Vec<f32> = std.flatten_all().nn_err()?.to_vec1().nn_err()?;
        let x_data: Vec<f32> = x_t.flatten_all().nn_err()?.to_vec1().nn_err()?;
        let y_data: Vec<f32> = y_t.flatten_all().nn_err()?.to_vec1().nn_err()?;

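        // Recompute the squashed-Gaussian log-prob per sample on the host;
        // no gradients are needed here, only plain f32 values for the caller.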
        let mut actions_flat = Vec::with_capacity(batch_size * self.act_dim);
        let mut log_probs = Vec::with_capacity(batch_size);

        for i in 0..batch_size {
            let mut lp_sum = 0.0_f32;
            for j in 0..self.act_dim {
                let idx = i * self.act_dim + j;
                let x = x_data[idx];
                let m = mean_data[idx];
                let s = std_data[idx];
                let y = y_data[idx];

                actions_flat.push(y);

                let normal_lp =
                    -0.5 * ((x - m) / s).powi(2) - s.ln() - 0.5 * (2.0 * std::f32::consts::PI).ln();
                let correction = -(1.0 - y * y + 1e-6).ln();
                lp_sum += normal_lp + correction;
            }
            log_probs.push(lp_sum);
        }

        Ok((
            TensorData::new(actions_flat, vec![batch_size, self.act_dim]),
            TensorData::new(log_probs, vec![batch_size]),
        ))
    }

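    /// Deterministic evaluation action: the tanh of the Gaussian mean,
    /// skipping the noise term entirely.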
    fn deterministic_action(&self, obs: &TensorData) -> Result<TensorData, NNError> {
        let obs_t = to_tensor_2d(obs, &self.device).nn_err()?;
        let (mean, _) = self.forward(&obs_t).nn_err()?;
        let action = mean.tanh().nn_err()?;
        from_tensor_2d(&action).nn_err()
    }

    fn learning_rate(&self) -> f32 {
        self.lr as f32
    }

    fn set_learning_rate(&mut self, lr: f32) {
        self.lr = lr as f64;
        self.optimizer.set_learning_rate(lr as f64);
    }

    fn save(&self, path: &Path) -> Result<(), NNError> {
        self.varmap
            .save(path)
            .map_err(|e| NNError::Serialization(e.to_string()))
    }

    fn load(&mut self, path: &Path) -> Result<(), NNError> {
        self.varmap
            .load(path)
            .map_err(|e| NNError::Serialization(e.to_string()))
    }
}

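// SAFETY: assumed sound on the premise that every field is either already
// thread-safe (candle's `VarMap` guards its variables behind a `Mutex`) or
// only mutated through `&mut self`. Re-audit if the field set changes.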
unsafe impl Send for CandleStochasticPolicy {}
unsafe impl Sync for CandleStochasticPolicy {}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sample_actions_shape() {
        let policy = CandleStochasticPolicy::new(3, 2, 64, 3e-4, Device::Cpu, 42).unwrap();
        let obs = TensorData::zeros(vec![8, 3]);
        let (actions, log_probs) = policy.sample_actions(&obs).unwrap();
        assert_eq!(actions.shape, vec![8, 2]);
        assert_eq!(log_probs.shape, vec![8]);
    }

    #[test]
    fn test_sample_in_range() {
        let policy = CandleStochasticPolicy::new(3, 1, 64, 3e-4, Device::Cpu, 42).unwrap();
        let obs = TensorData::zeros(vec![100, 3]);
        let (actions, _) = policy.sample_actions(&obs).unwrap();
        for &a in &actions.data {
            assert!((-1.0..=1.0).contains(&a), "out of range: {a}");
        }
    }

    #[test]
    fn test_deterministic_shape() {
        let policy = CandleStochasticPolicy::new(3, 2, 64, 3e-4, Device::Cpu, 42).unwrap();
        let obs = TensorData::zeros(vec![4, 3]);
        let actions = policy.deterministic_action(&obs).unwrap();
        assert_eq!(actions.shape, vec![4, 2]);
    }
}