//! rlox_candle/dqn.rs

1use candle_core::Device;
2use candle_nn::{Optimizer, VarBuilder, VarMap};
3
4use rlox_nn::{Activation, MLPConfig, NNError, QFunction, TensorData};
5
6use crate::convert::*;
7use crate::mlp::MLP;
8
/// Deep Q-Network backed by candle tensors.
///
/// Holds an online Q-network and a separately parameterized target
/// network, plus the AdamW optimizer that updates only the online
/// network's variables.
#[allow(dead_code)]
pub struct CandleDQN {
    /// Online Q-network, trained by `td_step`.
    q_network: MLP,
    /// Target network; refreshed from `q_network` by `hard_update_target`.
    target_network: MLP,
    /// Variable store owning the online network's parameters.
    varmap: VarMap,
    /// Variable store owning the target network's parameters.
    target_varmap: VarMap,
    /// AdamW optimizer over the online network's variables only.
    optimizer: candle_nn::AdamW,
    /// Device all tensors are created on.
    device: Device,
    /// Number of discrete actions (network output width).
    n_actions: usize,
    /// Observation dimension (network input width).
    obs_dim: usize,
    /// Width of each of the two hidden layers.
    hidden: usize,
    /// AdamW learning rate.
    lr: f64,
}
22
23impl CandleDQN {
24    pub fn new(
25        obs_dim: usize,
26        n_actions: usize,
27        hidden: usize,
28        lr: f64,
29        device: Device,
30    ) -> Result<Self, NNError> {
31        let config = MLPConfig::new(obs_dim, n_actions)
32            .with_hidden(vec![hidden, hidden])
33            .with_activation(Activation::ReLU);
34
35        let varmap = VarMap::new();
36        let vb = VarBuilder::from_varmap(&varmap, candle_core::DType::F32, &device);
37        let q_network = MLP::new(&config, vb.pp("q")).nn_err()?;
38
39        let target_varmap = VarMap::new();
40        let tvb = VarBuilder::from_varmap(&target_varmap, candle_core::DType::F32, &device);
41        let target_network = MLP::new(&config, tvb.pp("q")).nn_err()?;
42
43        // Copy weights from q to target
44        let q_data = varmap.data().lock().unwrap();
45        let t_data = target_varmap.data().lock().unwrap();
46        for (name, var) in q_data.iter() {
47            if let Some(tvar) = t_data.get(name) {
48                tvar.set(&var.as_tensor().clone()).unwrap();
49            }
50        }
51        drop(q_data);
52        drop(t_data);
53
54        let params = varmap.all_vars();
55        let optimizer = candle_nn::AdamW::new(
56            params,
57            candle_nn::ParamsAdamW {
58                lr,
59                ..Default::default()
60            },
61        )
62        .nn_err()?;
63
64        Ok(Self {
65            q_network,
66            target_network,
67            varmap,
68            target_varmap,
69            optimizer,
70            device,
71            n_actions,
72            obs_dim,
73            hidden,
74            lr,
75        })
76    }
77}
78
79impl QFunction for CandleDQN {
80    fn q_values(&self, obs: &TensorData) -> Result<TensorData, NNError> {
81        let obs_t = to_tensor_2d(obs, &self.device).nn_err()?;
82        let q = self.q_network.forward(&obs_t).nn_err()?;
83        from_tensor_2d(&q).nn_err()
84    }
85
86    fn q_value_at(&self, obs: &TensorData, actions: &TensorData) -> Result<TensorData, NNError> {
87        let obs_t = to_tensor_2d(obs, &self.device).nn_err()?;
88        let q = self.q_network.forward(&obs_t).nn_err()?;
89        let actions_idx = to_int_tensor_1d(actions, &self.device).nn_err()?;
90        let actions_2d = actions_idx.unsqueeze(1).nn_err()?;
91        let gathered = q.gather(&actions_2d, 1).nn_err()?.squeeze(1).nn_err()?;
92        from_tensor_1d(&gathered).nn_err()
93    }
94
95    fn td_step(
96        &mut self,
97        obs: &TensorData,
98        actions: &TensorData,
99        targets: &TensorData,
100        weights: Option<&TensorData>,
101    ) -> Result<(f64, TensorData), NNError> {
102        let obs_t = to_tensor_2d(obs, &self.device).nn_err()?;
103        let q_all = self.q_network.forward(&obs_t).nn_err()?;
104        let actions_idx = to_int_tensor_1d(actions, &self.device).nn_err()?;
105        let actions_2d = actions_idx.unsqueeze(1).nn_err()?;
106        let q = q_all.gather(&actions_2d, 1).nn_err()?.squeeze(1).nn_err()?;
107
108        let target = to_tensor_1d(targets, &self.device).nn_err()?;
109        let td_error = (&q - &target).nn_err()?;
110
111        let loss = if let Some(w) = weights {
112            let w_t = to_tensor_1d(w, &self.device).nn_err()?;
113            (&w_t * &td_error.sqr().nn_err()?)
114                .nn_err()?
115                .mean_all()
116                .nn_err()?
117        } else {
118            td_error.sqr().nn_err()?.mean_all().nn_err()?
119        };
120
121        self.optimizer.backward_step(&loss).nn_err()?;
122
123        let loss_val: f32 = loss.to_scalar().nn_err()?;
124        let td_err_data = from_tensor_1d(&td_error.detach()).nn_err()?;
125
126        Ok((loss_val as f64, td_err_data))
127    }
128
129    fn target_q_values(&self, obs: &TensorData) -> Result<TensorData, NNError> {
130        let obs_t = to_tensor_2d(obs, &self.device).nn_err()?;
131        let q = self.target_network.forward(&obs_t).nn_err()?;
132        from_tensor_2d(&q).nn_err()
133    }
134
135    fn hard_update_target(&mut self) {
136        let q_data = self.varmap.data().lock().unwrap();
137        let t_data = self.target_varmap.data().lock().unwrap();
138        for (name, var) in q_data.iter() {
139            if let Some(tvar) = t_data.get(name) {
140                tvar.set(&var.as_tensor().clone()).unwrap();
141            }
142        }
143    }
144}
145
// SAFETY: NOTE(review) — these manual impls assert thread-safety the
// compiler did not derive on its own. VarMap's data is behind a Mutex
// (see the lock() calls above), but whether MLP, Device, and AdamW are
// themselves Send + Sync cannot be confirmed from this file; if they
// already are, these impls are redundant and should be removed. Verify
// before sharing a CandleDQN across threads.
unsafe impl Send for CandleDQN {}
unsafe impl Sync for CandleDQN {}
148
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh CPU-backed DQN with the dimensions shared by these tests.
    fn make_dqn(obs_dim: usize, n_actions: usize) -> CandleDQN {
        CandleDQN::new(obs_dim, n_actions, 64, 1e-4, Device::Cpu).unwrap()
    }

    #[test]
    fn test_q_values_shape() {
        let agent = make_dqn(4, 2);
        let batch = TensorData::zeros(vec![8, 4]);
        let out = agent.q_values(&batch).unwrap();
        assert_eq!(out.shape, vec![8, 2]);
    }

    #[test]
    fn test_q_value_at() {
        let agent = make_dqn(4, 2);
        let batch = TensorData::zeros(vec![4, 4]);
        let acts = TensorData::new(vec![0.0, 1.0, 0.0, 1.0], vec![4]);
        let out = agent.q_value_at(&batch, &acts).unwrap();
        assert_eq!(out.shape, vec![4]);
    }

    #[test]
    fn test_td_step() {
        let mut agent = make_dqn(4, 2);
        let batch = TensorData::zeros(vec![32, 4]);
        let acts = TensorData::new(vec![0.0; 32], vec![32]);
        let tgts = TensorData::new(vec![1.0; 32], vec![32]);

        let (loss, td_errors) = agent.td_step(&batch, &acts, &tgts, None).unwrap();
        assert!(loss.is_finite());
        assert_eq!(td_errors.shape, vec![32]);
    }

    #[test]
    fn test_hard_update() {
        let mut agent = make_dqn(4, 2);
        let batch = TensorData::zeros(vec![4, 4]);

        // One gradient step moves the online network away from the target.
        let acts = TensorData::new(vec![0.0; 4], vec![4]);
        let tgts = TensorData::new(vec![10.0; 4], vec![4]);
        agent.td_step(&batch, &acts, &tgts, None).unwrap();

        // The two networks disagree before the copy...
        assert_ne!(
            agent.q_values(&batch).unwrap().data,
            agent.target_q_values(&batch).unwrap().data
        );

        // ...and agree exactly afterwards.
        agent.hard_update_target();
        assert_eq!(
            agent.q_values(&batch).unwrap().data,
            agent.target_q_values(&batch).unwrap().data
        );
    }
}