rlox_candle/deterministic.rs

use std::path::Path;

use candle_core::Device;
use candle_nn::{Optimizer, VarBuilder, VarMap};

use rlox_nn::{
    Activation, DeterministicPolicy as DeterministicPolicyTrait, MLPConfig, NNError, TensorData,
    TrainMetrics,
};

use crate::convert::*;
use crate::mlp::MLP;

/// TD3-style deterministic actor: an online MLP, a target MLP held in its
/// own `VarMap`, and an AdamW optimizer over the online parameters only.
pub struct CandleDeterministicPolicy {
    net: MLP,
    target_net: MLP,
    varmap: VarMap,
    target_varmap: VarMap,
    optimizer: candle_nn::AdamW,
    device: Device,
    max_action: f32,
    lr: f64,
}

impl CandleDeterministicPolicy {
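    /// Builds the online actor, a target actor initialized to identical
    /// weights, and an AdamW optimizer over the online parameters.
    ///
    /// A minimal construction sketch (arguments mirror the tests below;
    /// marked `ignore` so it is not run as a doctest):
    ///
    /// ```ignore
    /// let policy = CandleDeterministicPolicy::new(3, 1, 64, 1.0, 3e-4, Device::Cpu)?;
    /// let action = policy.act(&obs)?; // obs: TensorData with shape [batch, 3]
    /// ```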
    pub fn new(
        obs_dim: usize,
        act_dim: usize,
        hidden: usize,
        max_action: f32,
        lr: f64,
        device: Device,
    ) -> Result<Self, NNError> {
        let config = MLPConfig::new(obs_dim, act_dim)
            .with_hidden(vec![hidden, hidden])
            .with_activation(Activation::ReLU)
            .with_output_activation(Activation::Tanh);

        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, candle_core::DType::F32, &device);
        let net = MLP::new(&config, vb.pp("actor")).nn_err()?;

        let target_varmap = VarMap::new();
        let tvb = VarBuilder::from_varmap(&target_varmap, candle_core::DType::F32, &device);
        let target_net = MLP::new(&config, tvb.pp("actor")).nn_err()?;

        // Hard-copy the online weights into the target so both networks start
        // identical; variable names line up because both use the "actor" prefix.
        {
            let src = varmap.data().lock().unwrap();
            let tgt = target_varmap.data().lock().unwrap();
            for (name, var) in src.iter() {
                if let Some(tvar) = tgt.get(name) {
                    tvar.set(var.as_tensor()).unwrap();
                }
            }
        }

        let params = varmap.all_vars();
        let optimizer = candle_nn::AdamW::new(
            params,
            candle_nn::ParamsAdamW {
                lr,
                ..Default::default()
            },
        )
        .nn_err()?;

        Ok(Self {
            net,
            target_net,
            varmap,
            target_varmap,
            optimizer,
            device,
            max_action,
            lr,
        })
    }
}

impl CandleDeterministicPolicy {
    /// TD3 actor gradient step with autograd flowing through the critic.
    ///
    /// Takes the concrete `CandleTwinQ` to preserve gradient flow from
    /// Q1(s, a) back to the actor parameters via the deterministic actions.
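    ///
    /// Sketch of a call site (hypothetical: `obs_batch` and `critic` stand in
    /// for a replay-buffer batch and a `CandleTwinQ` with matching dims):
    ///
    /// ```ignore
    /// let metrics = policy.td3_actor_step(&obs_batch, &critic)?;
    /// ```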
    pub fn td3_actor_step(
        &mut self,
        obs: &TensorData,
        critic: &crate::continuous_q::CandleTwinQ,
    ) -> Result<TrainMetrics, NNError> {
        let obs_t = to_tensor_2d(obs, &self.device).nn_err()?;
        let actions = self.net.forward(&obs_t).nn_err()?;
        let scaled = (&actions * self.max_action as f64).nn_err()?;

        // Q1 with full autograd: the gradient flows through the critic
        // back to the actor parameters.
        let q1 = critic
            .q1_forward(&obs_t, &scaled)
            .nn_err()?
            .squeeze(1)
            .nn_err()?;

        // Deterministic policy gradient: maximize Q1 by minimizing -mean(Q1).
        let actor_loss = q1.neg().nn_err()?.mean_all().nn_err()?;

        self.optimizer.backward_step(&actor_loss).nn_err()?;

        let loss_val: f32 = actor_loss.to_scalar().nn_err()?;
        let mut metrics = TrainMetrics::new();
        metrics.insert("actor_loss", loss_val as f64);
        Ok(metrics)
    }
}

impl DeterministicPolicyTrait for CandleDeterministicPolicy {
    fn act(&self, obs: &TensorData) -> Result<TensorData, NNError> {
        let obs_t = to_tensor_2d(obs, &self.device).nn_err()?;
        let actions = self.net.forward(&obs_t).nn_err()?;
        let scaled = (&actions * self.max_action as f64).nn_err()?;
        from_tensor_2d(&scaled).nn_err()
    }

    fn target_act(&self, obs: &TensorData) -> Result<TensorData, NNError> {
        let obs_t = to_tensor_2d(obs, &self.device).nn_err()?;
        let actions = self.target_net.forward(&obs_t).nn_err()?;
        let scaled = (&actions * self.max_action as f64).nn_err()?;
        from_tensor_2d(&scaled).nn_err()
    }

    // Polyak averaging: theta_target <- tau * theta + (1 - tau) * theta_target.
    fn soft_update_target(&mut self, tau: f32) {
        let src = self.varmap.data().lock().unwrap();
        let tgt = self.target_varmap.data().lock().unwrap();
        for (name, var) in src.iter() {
            if let Some(tvar) = tgt.get(name) {
                let src_t = var.as_tensor();
                let tgt_t = tvar.as_tensor();
                let new_val = ((src_t * tau as f64).unwrap()
                    + (tgt_t * (1.0 - tau) as f64).unwrap())
                .unwrap();
                tvar.set(&new_val).unwrap();
            }
        }
    }

    fn learning_rate(&self) -> f32 {
        self.lr as f32
    }

    fn set_learning_rate(&mut self, lr: f32) {
        self.lr = lr as f64;
        self.optimizer.set_learning_rate(lr as f64);
    }

    fn save(&self, path: &Path) -> Result<(), NNError> {
        self.varmap
            .save(path)
            .map_err(|e| NNError::Serialization(e.to_string()))
    }

    // Note: only the online weights are restored; the target keeps whatever
    // weights it currently holds.
    fn load(&mut self, path: &Path) -> Result<(), NNError> {
        self.varmap
            .load(path)
            .map_err(|e| NNError::Serialization(e.to_string()))
    }
}

// SAFETY: manually asserts that the policy can be moved and shared across
// threads; the VarMaps guard their state behind a Mutex.
unsafe impl Send for CandleDeterministicPolicy {}
unsafe impl Sync for CandleDeterministicPolicy {}
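
// Where this policy sits in a TD3 loop (sketch only; `step`, `policy_delay`,
// `tau`, `batch`, and `critic` are hypothetical driver-side names, and the
// critic's own update is out of scope for this file):
//
//     if step % policy_delay == 0 {
//         // Delayed actor update, then Polyak-average the actor target.
//         policy.td3_actor_step(&batch.obs, &critic)?;
//         policy.soft_update_target(tau);
//     }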

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_act_shape() {
        let policy = CandleDeterministicPolicy::new(3, 1, 64, 1.0, 3e-4, Device::Cpu).unwrap();
        let obs = TensorData::zeros(vec![8, 3]);
        let actions = policy.act(&obs).unwrap();
        assert_eq!(actions.shape, vec![8, 1]);
    }

    #[test]
    fn test_target_matches_initially() {
        let policy = CandleDeterministicPolicy::new(3, 1, 64, 1.0, 3e-4, Device::Cpu).unwrap();
        let obs = TensorData::zeros(vec![4, 3]);
        let act = policy.act(&obs).unwrap();
        let tgt = policy.target_act(&obs).unwrap();
        for (a, b) in act.data.iter().zip(tgt.data.iter()) {
            assert!((a - b).abs() < 1e-5, "{a} vs {b}");
        }
    }

    #[test]
    fn test_action_range() {
        let policy = CandleDeterministicPolicy::new(3, 1, 64, 2.0, 3e-4, Device::Cpu).unwrap();
        // 300 values reshaped to [100, 3]; tanh * max_action bounds actions to [-2, 2].
        let obs = TensorData::new(
            (0..300).map(|i| (i as f32) * 0.1 - 15.0).collect(),
            vec![100, 3],
        );
        let actions = policy.act(&obs).unwrap();
        for &a in &actions.data {
            assert!(
                a >= -2.0 - 1e-4 && a <= 2.0 + 1e-4,
                "action out of range: {a}"
            );
        }
    }
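
    // Sanity check of the Polyak update: actor and target start identical,
    // so theta_target <- tau*theta + (1 - tau)*theta_target must leave them
    // identical for any tau.
    #[test]
    fn test_soft_update_keeps_identical_weights() {
        let mut policy = CandleDeterministicPolicy::new(3, 1, 64, 1.0, 3e-4, Device::Cpu).unwrap();
        policy.soft_update_target(0.5);
        let obs = TensorData::zeros(vec![4, 3]);
        let act = policy.act(&obs).unwrap();
        let tgt = policy.target_act(&obs).unwrap();
        for (a, b) in act.data.iter().zip(tgt.data.iter()) {
            assert!((a - b).abs() < 1e-5, "{a} vs {b}");
        }
    }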
}