1use candle_core::Device;
2use candle_nn::{Optimizer, VarBuilder, VarMap};
3
4use rlox_nn::{Activation, MLPConfig, NNError, QFunction, TensorData};
5
6use crate::convert::*;
7use crate::mlp::MLP;
8
/// Deep Q-Network backed by the `candle` tensor library.
///
/// Holds an online Q-network and a structurally identical target network,
/// each with its own `VarMap` parameter store, plus the AdamW optimizer
/// that is constructed over the online network's variables only.
// NOTE(review): `dead_code` presumably covers fields kept for
// introspection (`obs_dim`, `hidden`, `lr`) that are never read in this
// file — confirm before removing the allow.
#[allow(dead_code)]
pub struct CandleDQN {
    // Online network: trained by `td_step`, queried by `q_values`.
    q_network: MLP,
    // Frozen copy of the online network; refreshed by `hard_update_target`.
    target_network: MLP,
    // Parameter store for `q_network` (the optimizer updates these vars).
    varmap: VarMap,
    // Parameter store for `target_network`; never passed to the optimizer.
    target_varmap: VarMap,
    optimizer: candle_nn::AdamW,
    device: Device,
    n_actions: usize,
    obs_dim: usize,
    hidden: usize,
    lr: f64,
}
22
23impl CandleDQN {
24 pub fn new(
25 obs_dim: usize,
26 n_actions: usize,
27 hidden: usize,
28 lr: f64,
29 device: Device,
30 ) -> Result<Self, NNError> {
31 let config = MLPConfig::new(obs_dim, n_actions)
32 .with_hidden(vec![hidden, hidden])
33 .with_activation(Activation::ReLU);
34
35 let varmap = VarMap::new();
36 let vb = VarBuilder::from_varmap(&varmap, candle_core::DType::F32, &device);
37 let q_network = MLP::new(&config, vb.pp("q")).nn_err()?;
38
39 let target_varmap = VarMap::new();
40 let tvb = VarBuilder::from_varmap(&target_varmap, candle_core::DType::F32, &device);
41 let target_network = MLP::new(&config, tvb.pp("q")).nn_err()?;
42
43 let q_data = varmap.data().lock().unwrap();
45 let t_data = target_varmap.data().lock().unwrap();
46 for (name, var) in q_data.iter() {
47 if let Some(tvar) = t_data.get(name) {
48 tvar.set(&var.as_tensor().clone()).unwrap();
49 }
50 }
51 drop(q_data);
52 drop(t_data);
53
54 let params = varmap.all_vars();
55 let optimizer = candle_nn::AdamW::new(
56 params,
57 candle_nn::ParamsAdamW {
58 lr,
59 ..Default::default()
60 },
61 )
62 .nn_err()?;
63
64 Ok(Self {
65 q_network,
66 target_network,
67 varmap,
68 target_varmap,
69 optimizer,
70 device,
71 n_actions,
72 obs_dim,
73 hidden,
74 lr,
75 })
76 }
77}
78
79impl QFunction for CandleDQN {
80 fn q_values(&self, obs: &TensorData) -> Result<TensorData, NNError> {
81 let obs_t = to_tensor_2d(obs, &self.device).nn_err()?;
82 let q = self.q_network.forward(&obs_t).nn_err()?;
83 from_tensor_2d(&q).nn_err()
84 }
85
86 fn q_value_at(&self, obs: &TensorData, actions: &TensorData) -> Result<TensorData, NNError> {
87 let obs_t = to_tensor_2d(obs, &self.device).nn_err()?;
88 let q = self.q_network.forward(&obs_t).nn_err()?;
89 let actions_idx = to_int_tensor_1d(actions, &self.device).nn_err()?;
90 let actions_2d = actions_idx.unsqueeze(1).nn_err()?;
91 let gathered = q.gather(&actions_2d, 1).nn_err()?.squeeze(1).nn_err()?;
92 from_tensor_1d(&gathered).nn_err()
93 }
94
95 fn td_step(
96 &mut self,
97 obs: &TensorData,
98 actions: &TensorData,
99 targets: &TensorData,
100 weights: Option<&TensorData>,
101 ) -> Result<(f64, TensorData), NNError> {
102 let obs_t = to_tensor_2d(obs, &self.device).nn_err()?;
103 let q_all = self.q_network.forward(&obs_t).nn_err()?;
104 let actions_idx = to_int_tensor_1d(actions, &self.device).nn_err()?;
105 let actions_2d = actions_idx.unsqueeze(1).nn_err()?;
106 let q = q_all.gather(&actions_2d, 1).nn_err()?.squeeze(1).nn_err()?;
107
108 let target = to_tensor_1d(targets, &self.device).nn_err()?;
109 let td_error = (&q - &target).nn_err()?;
110
111 let loss = if let Some(w) = weights {
112 let w_t = to_tensor_1d(w, &self.device).nn_err()?;
113 (&w_t * &td_error.sqr().nn_err()?)
114 .nn_err()?
115 .mean_all()
116 .nn_err()?
117 } else {
118 td_error.sqr().nn_err()?.mean_all().nn_err()?
119 };
120
121 self.optimizer.backward_step(&loss).nn_err()?;
122
123 let loss_val: f32 = loss.to_scalar().nn_err()?;
124 let td_err_data = from_tensor_1d(&td_error.detach()).nn_err()?;
125
126 Ok((loss_val as f64, td_err_data))
127 }
128
129 fn target_q_values(&self, obs: &TensorData) -> Result<TensorData, NNError> {
130 let obs_t = to_tensor_2d(obs, &self.device).nn_err()?;
131 let q = self.target_network.forward(&obs_t).nn_err()?;
132 from_tensor_2d(&q).nn_err()
133 }
134
135 fn hard_update_target(&mut self) {
136 let q_data = self.varmap.data().lock().unwrap();
137 let t_data = self.target_varmap.data().lock().unwrap();
138 for (name, var) in q_data.iter() {
139 if let Some(tvar) = t_data.get(name) {
140 tvar.set(&var.as_tensor().clone()).unwrap();
141 }
142 }
143 }
144}
145
// SAFETY: NOTE(review) — these blanket impls assert every field (MLP,
// VarMap, AdamW, Device) is safe to move/share across threads. VarMap
// guards its data behind a Mutex, but nothing in this file proves the
// other fields are thread-safe; confirm against the candle types (or let
// the auto-traits derive naturally) before relying on these impls.
unsafe impl Send for CandleDQN {}
unsafe impl Sync for CandleDQN {}
148
#[cfg(test)]
mod tests {
    use super::*;

    // All tests share the same small CPU network:
    // 4-dim observations, 2 actions, 64 hidden units, lr 1e-4.
    fn make_dqn() -> CandleDQN {
        CandleDQN::new(4, 2, 64, 1e-4, Device::Cpu).unwrap()
    }

    #[test]
    fn test_q_values_shape() {
        let dqn = make_dqn();
        let batch = TensorData::zeros(vec![8, 4]);
        let out = dqn.q_values(&batch).unwrap();
        assert_eq!(out.shape, vec![8, 2]);
    }

    #[test]
    fn test_q_value_at() {
        let dqn = make_dqn();
        let batch = TensorData::zeros(vec![4, 4]);
        let acts = TensorData::new(vec![0.0, 1.0, 0.0, 1.0], vec![4]);
        let picked = dqn.q_value_at(&batch, &acts).unwrap();
        assert_eq!(picked.shape, vec![4]);
    }

    #[test]
    fn test_td_step() {
        let mut dqn = make_dqn();
        let batch = TensorData::zeros(vec![32, 4]);
        let acts = TensorData::new(vec![0.0; 32], vec![32]);
        let tgts = TensorData::new(vec![1.0; 32], vec![32]);

        let (loss, errs) = dqn.td_step(&batch, &acts, &tgts, None).unwrap();
        assert!(loss.is_finite());
        assert_eq!(errs.shape, vec![32]);
    }

    #[test]
    fn test_hard_update() {
        let mut dqn = make_dqn();
        let batch = TensorData::zeros(vec![4, 4]);

        // One gradient step drifts the online network away from the target.
        let acts = TensorData::new(vec![0.0; 4], vec![4]);
        let tgts = TensorData::new(vec![10.0; 4], vec![4]);
        dqn.td_step(&batch, &acts, &tgts, None).unwrap();
        assert_ne!(
            dqn.q_values(&batch).unwrap().data,
            dqn.target_q_values(&batch).unwrap().data
        );

        // Hard update copies the online weights into the target network.
        dqn.hard_update_target();
        assert_eq!(
            dqn.q_values(&batch).unwrap().data,
            dqn.target_q_values(&batch).unwrap().data
        );
    }
}