use std::path::Path;

use candle_core::Device;
use candle_nn::{Optimizer, VarBuilder, VarMap};

use rlox_nn::{
    Activation, DeterministicPolicy as DeterministicPolicyTrait, MLPConfig, NNError, TensorData,
    TrainMetrics,
};

use crate::convert::*;
use crate::mlp::MLP;

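/// Deterministic actor for DDPG/TD3-style agents, backed by candle MLPs.
///
/// Keeps the online network and the target network in separate `VarMap`s so
/// the two parameter sets can be hard-copied at construction and
/// Polyak-averaged during training. Actions are tanh-squashed by the network
/// and scaled to `[-max_action, max_action]`.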
pub struct CandleDeterministicPolicy {
    net: MLP,
    target_net: MLP,
    varmap: VarMap,
    target_varmap: VarMap,
    optimizer: candle_nn::AdamW,
    device: Device,
    max_action: f32,
    lr: f64,
}

impl CandleDeterministicPolicy {
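    /// Builds online and target networks with identical architectures, copies
    /// the freshly initialized online weights into the target, and attaches an
    /// AdamW optimizer to the online parameters only.
    ///
    /// A minimal usage sketch (dimensions are illustrative, taken from the
    /// tests below):
    ///
    /// ```ignore
    /// let policy = CandleDeterministicPolicy::new(3, 1, 64, 1.0, 3e-4, Device::Cpu)?;
    /// let actions = policy.act(&obs)?;
    /// ```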
    pub fn new(
        obs_dim: usize,
        act_dim: usize,
        hidden: usize,
        max_action: f32,
        lr: f64,
        device: Device,
    ) -> Result<Self, NNError> {
        let config = MLPConfig::new(obs_dim, act_dim)
            .with_hidden(vec![hidden, hidden])
            .with_activation(Activation::ReLU)
            .with_output_activation(Activation::Tanh);

        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, candle_core::DType::F32, &device);
        let net = MLP::new(&config, vb.pp("actor")).nn_err()?;

        let target_varmap = VarMap::new();
        let tvb = VarBuilder::from_varmap(&target_varmap, candle_core::DType::F32, &device);
        let target_net = MLP::new(&config, tvb.pp("actor")).nn_err()?;

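        // Hard-copy the online weights into the target network so both start
        // from identical parameters.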
        {
            let src = varmap.data().lock().unwrap();
            let tgt = target_varmap.data().lock().unwrap();
            for (name, var) in src.iter() {
                if let Some(tvar) = tgt.get(name) {
                    tvar.set(var.as_tensor()).nn_err()?;
                }
            }
        }

        let params = varmap.all_vars();
        let optimizer = candle_nn::AdamW::new(
            params,
            candle_nn::ParamsAdamW {
                lr,
                ..Default::default()
            },
        )
        .nn_err()?;

        Ok(Self {
            net,
            target_net,
            varmap,
            target_varmap,
            optimizer,
            device,
            max_action,
            lr,
        })
    }
}

impl CandleDeterministicPolicy {
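    /// One TD3/DDPG actor update: the deterministic policy gradient step that
    /// minimizes `-mean(Q1(s, pi(s)))`, i.e. ascends the critic's Q1 estimate.
    /// Gradients flow back through the critic into the actor, but only the
    /// actor's variables are held by this optimizer, so only they are updated.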
    pub fn td3_actor_step(
        &mut self,
        obs: &TensorData,
        critic: &crate::continuous_q::CandleTwinQ,
    ) -> Result<TrainMetrics, NNError> {
        let obs_t = to_tensor_2d(obs, &self.device).nn_err()?;
        let actions = self.net.forward(&obs_t).nn_err()?;
        let scaled = (&actions * self.max_action as f64).nn_err()?;

        let q1 = critic
            .q1_forward(&obs_t, &scaled)
            .nn_err()?
            .squeeze(1)
            .nn_err()?;

        let actor_loss = q1.neg().nn_err()?.mean_all().nn_err()?;

        self.optimizer.backward_step(&actor_loss).nn_err()?;

        let loss_val: f32 = actor_loss.to_scalar().nn_err()?;
        let mut metrics = TrainMetrics::new();
        metrics.insert("actor_loss", loss_val as f64);
        Ok(metrics)
    }
}

impl DeterministicPolicyTrait for CandleDeterministicPolicy {
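    // Forward through the online actor; the MLP's tanh output is scaled to
    // [-max_action, max_action].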
    fn act(&self, obs: &TensorData) -> Result<TensorData, NNError> {
        let obs_t = to_tensor_2d(obs, &self.device).nn_err()?;
        let actions = self.net.forward(&obs_t).nn_err()?;
        let scaled = (&actions * self.max_action as f64).nn_err()?;
        from_tensor_2d(&scaled).nn_err()
    }

    fn target_act(&self, obs: &TensorData) -> Result<TensorData, NNError> {
        let obs_t = to_tensor_2d(obs, &self.device).nn_err()?;
        let actions = self.target_net.forward(&obs_t).nn_err()?;
        let scaled = (&actions * self.max_action as f64).nn_err()?;
        from_tensor_2d(&scaled).nn_err()
    }

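    // Polyak averaging: theta_target <- tau * theta + (1 - tau) * theta_target.
    // The trait signature returns `()`, so tensor-op failures panic here.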
    fn soft_update_target(&mut self, tau: f32) {
        let src = self.varmap.data().lock().unwrap();
        let tgt = self.target_varmap.data().lock().unwrap();
        for (name, var) in src.iter() {
            if let Some(tvar) = tgt.get(name) {
                let src_t = var.as_tensor();
                let tgt_t = tvar.as_tensor();
                let new_val = ((src_t * tau as f64).unwrap()
                    + (tgt_t * (1.0 - tau) as f64).unwrap())
                .unwrap();
                tvar.set(&new_val).unwrap();
            }
        }
    }

    fn learning_rate(&self) -> f32 {
        self.lr as f32
    }

    fn set_learning_rate(&mut self, lr: f32) {
        self.lr = lr as f64;
        self.optimizer.set_learning_rate(lr as f64);
    }

    fn save(&self, path: &Path) -> Result<(), NNError> {
        self.varmap
            .save(path)
            .map_err(|e| NNError::Serialization(e.to_string()))
    }

    // Note: only the online network's weights are restored; the target
    // network keeps whatever it currently holds.
    fn load(&mut self, path: &Path) -> Result<(), NNError> {
        self.varmap
            .load(path)
            .map_err(|e| NNError::Serialization(e.to_string()))
    }
}

// SAFETY: relies on the candle types held here (`VarMap`, `Tensor`, `Device`,
// `AdamW`) being safe to share across threads; their shared state sits behind
// `Arc`/`Mutex`/`RwLock` internals. These impls assert that explicitly.
unsafe impl Send for CandleDeterministicPolicy {}
unsafe impl Sync for CandleDeterministicPolicy {}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_act_shape() {
        let policy = CandleDeterministicPolicy::new(3, 1, 64, 1.0, 3e-4, Device::Cpu).unwrap();
        let obs = TensorData::zeros(vec![8, 3]);
        let actions = policy.act(&obs).unwrap();
        assert_eq!(actions.shape, vec![8, 1]);
    }

    #[test]
    fn test_target_matches_initially() {
        let policy = CandleDeterministicPolicy::new(3, 1, 64, 1.0, 3e-4, Device::Cpu).unwrap();
        let obs = TensorData::zeros(vec![4, 3]);
        let act = policy.act(&obs).unwrap();
        let tgt = policy.target_act(&obs).unwrap();
        for (a, b) in act.data.iter().zip(tgt.data.iter()) {
            assert!((a - b).abs() < 1e-5, "{a} vs {b}");
        }
    }

    #[test]
    fn test_action_range() {
        let policy = CandleDeterministicPolicy::new(3, 1, 64, 2.0, 3e-4, Device::Cpu).unwrap();
        let obs = TensorData::new(
            (0..300).map(|i| (i as f32) * 0.1 - 15.0).collect(),
            vec![100, 3],
        );
        let actions = policy.act(&obs).unwrap();
        for &a in &actions.data {
            assert!(
                a >= -2.0 - 1e-4 && a <= 2.0 + 1e-4,
                "action out of range: {a}"
            );
        }
    }
}