1use candle_core::{Device, Tensor};
2use candle_nn::{Optimizer, VarBuilder, VarMap};
3
4use rlox_nn::{Activation, ContinuousQFunction, MLPConfig, NNError, TensorData, TrainMetrics};
5
6use crate::convert::*;
7use crate::mlp::MLP;
8
/// Twin Q-networks backed by candle: two independently parameterized online
/// critics (`q1`, `q2`) plus target copies that are refreshed via
/// `soft_update_targets`, with one AdamW optimizer per critic.
#[allow(dead_code)]
pub struct CandleTwinQ {
    // Online critics: MLPs over the concatenated (obs, action) input,
    // producing a single Q-value per row.
    q1: MLP,
    q2: MLP,
    // Target copies; only ever written by the initial hard copy in `new`
    // and by `soft_update_targets`.
    q1_target: MLP,
    q2_target: MLP,
    // Owns the online weights (both critics share this varmap).
    varmap: VarMap,
    // Owns the target weights — kept in a separate varmap so the optimizers
    // never see target variables.
    target_varmap: VarMap,
    // NOTE(review): both optimizers are constructed over *all* online vars
    // in `new`, not per-critic subsets — confirm intent before relying on it.
    q1_optimizer: candle_nn::AdamW,
    q2_optimizer: candle_nn::AdamW,
    // Device all tensors are created and evaluated on.
    device: Device,
    // Learning rate kept for reference; not read after construction here.
    lr: f64,
}
22
23impl CandleTwinQ {
24 pub fn new(
25 obs_dim: usize,
26 act_dim: usize,
27 hidden: usize,
28 lr: f64,
29 device: Device,
30 ) -> Result<Self, NNError> {
31 let config = MLPConfig::new(obs_dim + act_dim, 1)
32 .with_hidden(vec![hidden, hidden])
33 .with_activation(Activation::ReLU);
34
35 let varmap = VarMap::new();
36 let vb = VarBuilder::from_varmap(&varmap, candle_core::DType::F32, &device);
37 let q1 = MLP::new(&config, vb.pp("q1")).nn_err()?;
38 let q2 = MLP::new(&config, vb.pp("q2")).nn_err()?;
39
40 let target_varmap = VarMap::new();
41 let tvb = VarBuilder::from_varmap(&target_varmap, candle_core::DType::F32, &device);
42 let q1_target = MLP::new(&config, tvb.pp("q1")).nn_err()?;
43 let q2_target = MLP::new(&config, tvb.pp("q2")).nn_err()?;
44
45 {
47 let src = varmap.data().lock().unwrap();
48 let tgt = target_varmap.data().lock().unwrap();
49 for (name, var) in src.iter() {
50 if let Some(tvar) = tgt.get(name) {
51 tvar.set(&var.as_tensor().clone()).unwrap();
52 }
53 }
54 }
55
56 let all_params = varmap.all_vars();
58 let q1_optimizer = candle_nn::AdamW::new(
59 all_params.clone(),
60 candle_nn::ParamsAdamW {
61 lr,
62 ..Default::default()
63 },
64 )
65 .nn_err()?;
66 let q2_optimizer = candle_nn::AdamW::new(
67 all_params,
68 candle_nn::ParamsAdamW {
69 lr,
70 ..Default::default()
71 },
72 )
73 .nn_err()?;
74
75 Ok(Self {
76 q1,
77 q2,
78 q1_target,
79 q2_target,
80 varmap,
81 target_varmap,
82 q1_optimizer,
83 q2_optimizer,
84 device,
85 lr,
86 })
87 }
88
89 fn forward_q(q: &MLP, obs: &Tensor, actions: &Tensor) -> candle_core::Result<Tensor> {
90 let input = Tensor::cat(&[obs, actions], 1)?;
91 q.forward(&input)
92 }
93
94 pub fn q1_forward(&self, obs: &Tensor, actions: &Tensor) -> candle_core::Result<Tensor> {
96 Self::forward_q(&self.q1, obs, actions)
97 }
98
99 pub fn twin_q_forward(
101 &self,
102 obs: &Tensor,
103 actions: &Tensor,
104 ) -> candle_core::Result<(Tensor, Tensor)> {
105 let q1 = Self::forward_q(&self.q1, obs, actions)?;
106 let q2 = Self::forward_q(&self.q2, obs, actions)?;
107 Ok((q1, q2))
108 }
109}
110
111impl ContinuousQFunction for CandleTwinQ {
112 fn q_value(&self, obs: &TensorData, actions: &TensorData) -> Result<TensorData, NNError> {
113 let obs_t = to_tensor_2d(obs, &self.device).nn_err()?;
114 let act_t = to_tensor_2d(actions, &self.device).nn_err()?;
115 let q = Self::forward_q(&self.q1, &obs_t, &act_t)
116 .nn_err()?
117 .squeeze(1)
118 .nn_err()?;
119 from_tensor_1d(&q).nn_err()
120 }
121
122 fn twin_q_values(
123 &self,
124 obs: &TensorData,
125 actions: &TensorData,
126 ) -> Result<(TensorData, TensorData), NNError> {
127 let obs_t = to_tensor_2d(obs, &self.device).nn_err()?;
128 let act_t = to_tensor_2d(actions, &self.device).nn_err()?;
129 let q1 = Self::forward_q(&self.q1, &obs_t, &act_t)
130 .nn_err()?
131 .squeeze(1)
132 .nn_err()?;
133 let q2 = Self::forward_q(&self.q2, &obs_t, &act_t)
134 .nn_err()?
135 .squeeze(1)
136 .nn_err()?;
137 Ok((from_tensor_1d(&q1).nn_err()?, from_tensor_1d(&q2).nn_err()?))
138 }
139
140 fn target_twin_q_values(
141 &self,
142 obs: &TensorData,
143 actions: &TensorData,
144 ) -> Result<(TensorData, TensorData), NNError> {
145 let obs_t = to_tensor_2d(obs, &self.device).nn_err()?;
146 let act_t = to_tensor_2d(actions, &self.device).nn_err()?;
147 let q1 = Self::forward_q(&self.q1_target, &obs_t, &act_t)
148 .nn_err()?
149 .squeeze(1)
150 .nn_err()?;
151 let q2 = Self::forward_q(&self.q2_target, &obs_t, &act_t)
152 .nn_err()?
153 .squeeze(1)
154 .nn_err()?;
155 Ok((from_tensor_1d(&q1).nn_err()?, from_tensor_1d(&q2).nn_err()?))
156 }
157
158 fn critic_step(
159 &mut self,
160 obs: &TensorData,
161 actions: &TensorData,
162 targets: &TensorData,
163 ) -> Result<TrainMetrics, NNError> {
164 let obs_t = to_tensor_2d(obs, &self.device).nn_err()?;
165 let act_t = to_tensor_2d(actions, &self.device).nn_err()?;
166 let target_t = to_tensor_1d(targets, &self.device).nn_err()?;
167
168 let q1 = Self::forward_q(&self.q1, &obs_t, &act_t)
169 .nn_err()?
170 .squeeze(1)
171 .nn_err()?;
172 let q1_loss = (&q1 - &target_t)
173 .nn_err()?
174 .sqr()
175 .nn_err()?
176 .mean_all()
177 .nn_err()?;
178
179 let q2 = Self::forward_q(&self.q2, &obs_t, &act_t)
180 .nn_err()?
181 .squeeze(1)
182 .nn_err()?;
183 let q2_loss = (&q2 - &target_t)
184 .nn_err()?
185 .sqr()
186 .nn_err()?
187 .mean_all()
188 .nn_err()?;
189
190 let total = (&q1_loss + &q2_loss).nn_err()?;
191 self.q1_optimizer.backward_step(&total).nn_err()?;
192
193 let q1_val: f32 = q1_loss.to_scalar().nn_err()?;
194 let q2_val: f32 = q2_loss.to_scalar().nn_err()?;
195
196 let mut metrics = TrainMetrics::new();
197 metrics.insert("q1_loss", q1_val as f64);
198 metrics.insert("q2_loss", q2_val as f64);
199 metrics.insert("critic_loss", ((q1_val + q2_val) / 2.0) as f64);
200 Ok(metrics)
201 }
202
203 fn soft_update_targets(&mut self, tau: f32) {
204 let src = self.varmap.data().lock().unwrap();
205 let tgt = self.target_varmap.data().lock().unwrap();
206 for (name, var) in src.iter() {
207 if let Some(tvar) = tgt.get(name) {
208 let src_t = var.as_tensor();
209 let tgt_t = tvar.as_tensor();
210 let new_val = ((src_t * tau as f64).unwrap()
211 + (tgt_t * (1.0 - tau) as f64).unwrap())
212 .unwrap();
213 tvar.set(&new_val).unwrap();
214 }
215 }
216 }
217}
218
// SAFETY: assumed sound on the grounds that the fields are thread-safe
// (VarMap guards its data behind a mutex; candle Device/Tensor are shared
// across threads elsewhere in the ecosystem). NOTE(review): confirm `MLP`
// holds no thread-unsafe state — and if every field is already Send + Sync,
// these manual unsafe impls are redundant and should be removed.
unsafe impl Send for CandleTwinQ {}
unsafe impl Sync for CandleTwinQ {}
221
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_twin_q_shapes() {
        // Small CPU critic pair over 3-dim obs / 1-dim actions.
        let critic = CandleTwinQ::new(3, 1, 64, 3e-4, Device::Cpu).unwrap();
        let obs = TensorData::zeros(vec![4, 3]);
        let acts = TensorData::zeros(vec![4, 1]);
        let (first, second) = critic.twin_q_values(&obs, &acts).unwrap();
        // One scalar per batch row once the unit dim is squeezed away.
        assert_eq!(first.shape, vec![4]);
        assert_eq!(second.shape, vec![4]);
    }

    #[test]
    fn test_critic_step() {
        let mut critic = CandleTwinQ::new(3, 1, 64, 3e-4, Device::Cpu).unwrap();
        let obs = TensorData::zeros(vec![16, 3]);
        let acts = TensorData::zeros(vec![16, 1]);
        let tgts = TensorData::new(vec![1.0; 16], vec![16]);
        // A single step must succeed and report a finite combined loss.
        let metrics = critic.critic_step(&obs, &acts, &tgts).unwrap();
        assert!(metrics.get("critic_loss").unwrap().is_finite());
    }

    #[test]
    fn test_soft_update() {
        let mut critic = CandleTwinQ::new(3, 1, 64, 3e-4, Device::Cpu).unwrap();
        let obs = TensorData::zeros(vec![4, 3]);
        let acts = TensorData::zeros(vec![4, 1]);

        // Push the online critics away from their (initially copied) targets.
        let tgts = TensorData::new(vec![10.0; 4], vec![4]);
        critic.critic_step(&obs, &acts, &tgts).unwrap();

        let (online, _) = critic.twin_q_values(&obs, &acts).unwrap();
        let (frozen, _) = critic.target_twin_q_values(&obs, &acts).unwrap();
        assert_ne!(online.data, frozen.data);

        // tau = 1.0 is a hard copy, so online and target must agree after it.
        critic.soft_update_targets(1.0);
        let (online_after, _) = critic.twin_q_values(&obs, &acts).unwrap();
        let (frozen_after, _) = critic.target_twin_q_values(&obs, &acts).unwrap();
        for (x, y) in online_after.data.iter().zip(frozen_after.data.iter()) {
            assert!((x - y).abs() < 1e-4, "should match: {x} vs {y}");
        }
    }
}