//! `rlox_candle/continuous_q.rs` — twin continuous Q-function backed by candle.

use candle_core::{Device, Tensor};
use candle_nn::{Optimizer, VarBuilder, VarMap};

use rlox_nn::{Activation, ContinuousQFunction, MLPConfig, NNError, TensorData, TrainMetrics};

use crate::convert::*;
use crate::mlp::MLP;
8
/// Twin Q-networks with frozen target copies, backed by candle.
///
/// Holds the two live critics (`q1`, `q2`), their target counterparts
/// (`q1_target`, `q2_target`), the `VarMap`s that own each set of weights,
/// and one AdamW optimizer per critic head. See `new` for how optimizer
/// parameters are assigned.
#[allow(dead_code)]
pub struct CandleTwinQ {
    q1: MLP,
    q2: MLP,
    // Target networks live in a separate VarMap so they are never touched
    // by the optimizers; they only change via `soft_update_targets`.
    q1_target: MLP,
    q2_target: MLP,
    varmap: VarMap,
    target_varmap: VarMap,
    q1_optimizer: candle_nn::AdamW,
    q2_optimizer: candle_nn::AdamW,
    device: Device,
    // Learning rate kept for reference; the optimizers already carry it.
    lr: f64,
}
22
23impl CandleTwinQ {
24    pub fn new(
25        obs_dim: usize,
26        act_dim: usize,
27        hidden: usize,
28        lr: f64,
29        device: Device,
30    ) -> Result<Self, NNError> {
31        let config = MLPConfig::new(obs_dim + act_dim, 1)
32            .with_hidden(vec![hidden, hidden])
33            .with_activation(Activation::ReLU);
34
35        let varmap = VarMap::new();
36        let vb = VarBuilder::from_varmap(&varmap, candle_core::DType::F32, &device);
37        let q1 = MLP::new(&config, vb.pp("q1")).nn_err()?;
38        let q2 = MLP::new(&config, vb.pp("q2")).nn_err()?;
39
40        let target_varmap = VarMap::new();
41        let tvb = VarBuilder::from_varmap(&target_varmap, candle_core::DType::F32, &device);
42        let q1_target = MLP::new(&config, tvb.pp("q1")).nn_err()?;
43        let q2_target = MLP::new(&config, tvb.pp("q2")).nn_err()?;
44
45        // Copy weights
46        {
47            let src = varmap.data().lock().unwrap();
48            let tgt = target_varmap.data().lock().unwrap();
49            for (name, var) in src.iter() {
50                if let Some(tvar) = tgt.get(name) {
51                    tvar.set(&var.as_tensor().clone()).unwrap();
52                }
53            }
54        }
55
56        // Separate optimizers for q1 and q2 params
57        let all_params = varmap.all_vars();
58        let q1_optimizer = candle_nn::AdamW::new(
59            all_params.clone(),
60            candle_nn::ParamsAdamW {
61                lr,
62                ..Default::default()
63            },
64        )
65        .nn_err()?;
66        let q2_optimizer = candle_nn::AdamW::new(
67            all_params,
68            candle_nn::ParamsAdamW {
69                lr,
70                ..Default::default()
71            },
72        )
73        .nn_err()?;
74
75        Ok(Self {
76            q1,
77            q2,
78            q1_target,
79            q2_target,
80            varmap,
81            target_varmap,
82            q1_optimizer,
83            q2_optimizer,
84            device,
85            lr,
86        })
87    }
88
89    fn forward_q(q: &MLP, obs: &Tensor, actions: &Tensor) -> candle_core::Result<Tensor> {
90        let input = Tensor::cat(&[obs, actions], 1)?;
91        q.forward(&input)
92    }
93
94    /// Forward through Q1 with autograd preserved (for actor loss in SAC/TD3).
95    pub fn q1_forward(&self, obs: &Tensor, actions: &Tensor) -> candle_core::Result<Tensor> {
96        Self::forward_q(&self.q1, obs, actions)
97    }
98
99    /// Forward through both Q-networks with autograd preserved (for SAC actor loss).
100    pub fn twin_q_forward(
101        &self,
102        obs: &Tensor,
103        actions: &Tensor,
104    ) -> candle_core::Result<(Tensor, Tensor)> {
105        let q1 = Self::forward_q(&self.q1, obs, actions)?;
106        let q2 = Self::forward_q(&self.q2, obs, actions)?;
107        Ok((q1, q2))
108    }
109}
110
111impl ContinuousQFunction for CandleTwinQ {
112    fn q_value(&self, obs: &TensorData, actions: &TensorData) -> Result<TensorData, NNError> {
113        let obs_t = to_tensor_2d(obs, &self.device).nn_err()?;
114        let act_t = to_tensor_2d(actions, &self.device).nn_err()?;
115        let q = Self::forward_q(&self.q1, &obs_t, &act_t)
116            .nn_err()?
117            .squeeze(1)
118            .nn_err()?;
119        from_tensor_1d(&q).nn_err()
120    }
121
122    fn twin_q_values(
123        &self,
124        obs: &TensorData,
125        actions: &TensorData,
126    ) -> Result<(TensorData, TensorData), NNError> {
127        let obs_t = to_tensor_2d(obs, &self.device).nn_err()?;
128        let act_t = to_tensor_2d(actions, &self.device).nn_err()?;
129        let q1 = Self::forward_q(&self.q1, &obs_t, &act_t)
130            .nn_err()?
131            .squeeze(1)
132            .nn_err()?;
133        let q2 = Self::forward_q(&self.q2, &obs_t, &act_t)
134            .nn_err()?
135            .squeeze(1)
136            .nn_err()?;
137        Ok((from_tensor_1d(&q1).nn_err()?, from_tensor_1d(&q2).nn_err()?))
138    }
139
140    fn target_twin_q_values(
141        &self,
142        obs: &TensorData,
143        actions: &TensorData,
144    ) -> Result<(TensorData, TensorData), NNError> {
145        let obs_t = to_tensor_2d(obs, &self.device).nn_err()?;
146        let act_t = to_tensor_2d(actions, &self.device).nn_err()?;
147        let q1 = Self::forward_q(&self.q1_target, &obs_t, &act_t)
148            .nn_err()?
149            .squeeze(1)
150            .nn_err()?;
151        let q2 = Self::forward_q(&self.q2_target, &obs_t, &act_t)
152            .nn_err()?
153            .squeeze(1)
154            .nn_err()?;
155        Ok((from_tensor_1d(&q1).nn_err()?, from_tensor_1d(&q2).nn_err()?))
156    }
157
158    fn critic_step(
159        &mut self,
160        obs: &TensorData,
161        actions: &TensorData,
162        targets: &TensorData,
163    ) -> Result<TrainMetrics, NNError> {
164        let obs_t = to_tensor_2d(obs, &self.device).nn_err()?;
165        let act_t = to_tensor_2d(actions, &self.device).nn_err()?;
166        let target_t = to_tensor_1d(targets, &self.device).nn_err()?;
167
168        let q1 = Self::forward_q(&self.q1, &obs_t, &act_t)
169            .nn_err()?
170            .squeeze(1)
171            .nn_err()?;
172        let q1_loss = (&q1 - &target_t)
173            .nn_err()?
174            .sqr()
175            .nn_err()?
176            .mean_all()
177            .nn_err()?;
178
179        let q2 = Self::forward_q(&self.q2, &obs_t, &act_t)
180            .nn_err()?
181            .squeeze(1)
182            .nn_err()?;
183        let q2_loss = (&q2 - &target_t)
184            .nn_err()?
185            .sqr()
186            .nn_err()?
187            .mean_all()
188            .nn_err()?;
189
190        let total = (&q1_loss + &q2_loss).nn_err()?;
191        self.q1_optimizer.backward_step(&total).nn_err()?;
192
193        let q1_val: f32 = q1_loss.to_scalar().nn_err()?;
194        let q2_val: f32 = q2_loss.to_scalar().nn_err()?;
195
196        let mut metrics = TrainMetrics::new();
197        metrics.insert("q1_loss", q1_val as f64);
198        metrics.insert("q2_loss", q2_val as f64);
199        metrics.insert("critic_loss", ((q1_val + q2_val) / 2.0) as f64);
200        Ok(metrics)
201    }
202
203    fn soft_update_targets(&mut self, tau: f32) {
204        let src = self.varmap.data().lock().unwrap();
205        let tgt = self.target_varmap.data().lock().unwrap();
206        for (name, var) in src.iter() {
207            if let Some(tvar) = tgt.get(name) {
208                let src_t = var.as_tensor();
209                let tgt_t = tvar.as_tensor();
210                let new_val = ((src_t * tau as f64).unwrap()
211                    + (tgt_t * (1.0 - tau) as f64).unwrap())
212                .unwrap();
213                tvar.set(&new_val).unwrap();
214            }
215        }
216    }
217}
218
// SAFETY: asserts CandleTwinQ may be shared/sent across threads.
// NOTE(review): VarMap stores its data behind Arc<Mutex<..>>, and candle
// tensors/devices are generally already Send + Sync, so these manual impls
// are likely redundant — confirm that every field (including
// candle_nn::AdamW and MLP) is Send + Sync and consider removing them.
unsafe impl Send for CandleTwinQ {}
unsafe impl Sync for CandleTwinQ {}
221
222#[cfg(test)]
223mod tests {
224    use super::*;
225
226    #[test]
227    fn test_twin_q_shapes() {
228        let q = CandleTwinQ::new(3, 1, 64, 3e-4, Device::Cpu).unwrap();
229        let obs = TensorData::zeros(vec![4, 3]);
230        let actions = TensorData::zeros(vec![4, 1]);
231        let (q1, q2) = q.twin_q_values(&obs, &actions).unwrap();
232        assert_eq!(q1.shape, vec![4]);
233        assert_eq!(q2.shape, vec![4]);
234    }
235
236    #[test]
237    fn test_critic_step() {
238        let mut q = CandleTwinQ::new(3, 1, 64, 3e-4, Device::Cpu).unwrap();
239        let obs = TensorData::zeros(vec![16, 3]);
240        let actions = TensorData::zeros(vec![16, 1]);
241        let targets = TensorData::new(vec![1.0; 16], vec![16]);
242        let metrics = q.critic_step(&obs, &actions, &targets).unwrap();
243        assert!(metrics.get("critic_loss").unwrap().is_finite());
244    }
245
246    #[test]
247    fn test_soft_update() {
248        let mut q = CandleTwinQ::new(3, 1, 64, 3e-4, Device::Cpu).unwrap();
249        let obs = TensorData::zeros(vec![4, 3]);
250        let actions = TensorData::zeros(vec![4, 1]);
251
252        // Train
253        let targets = TensorData::new(vec![10.0; 4], vec![4]);
254        q.critic_step(&obs, &actions, &targets).unwrap();
255
256        // Before update
257        let (q1, _) = q.twin_q_values(&obs, &actions).unwrap();
258        let (tq1, _) = q.target_twin_q_values(&obs, &actions).unwrap();
259        assert_ne!(q1.data, tq1.data);
260
261        // Hard update
262        q.soft_update_targets(1.0);
263        let (q1b, _) = q.twin_q_values(&obs, &actions).unwrap();
264        let (tq1b, _) = q.target_twin_q_values(&obs, &actions).unwrap();
265        for (a, b) in q1b.data.iter().zip(tq1b.data.iter()) {
266            assert!((a - b).abs() < 1e-4, "should match: {a} vs {b}");
267        }
268    }
269}