1use std::path::Path;
2
3use candle_core::{Device, Tensor};
4use candle_nn::{linear, Linear, Module, Optimizer, VarBuilder, VarMap};
5use rand::SeedableRng;
6use rand_chacha::ChaCha8Rng;
7
8use rlox_nn::{Activation, MLPConfig, NNError, StochasticPolicy, TensorData, TrainMetrics};
9
10use crate::convert::*;
11use crate::mlp::MLP;
12
// Clamp bounds for the predicted log standard deviation, as used in standard
// SAC implementations: keeps std inside roughly [2e-9, 7.39] so that
// `exp(log_std)` never underflows to 0 (division by std) or explodes.
const LOG_STD_MIN: f32 = -20.0;
const LOG_STD_MAX: f32 = 2.0;
15
/// Tanh-squashed Gaussian (stochastic) policy backed by candle tensors.
///
/// A shared MLP trunk feeds two linear heads producing the per-dimension
/// action mean and log standard deviation; actions are sampled with the
/// reparameterization trick and squashed into (-1, 1) with `tanh`.
pub struct CandleStochasticPolicy {
    // Shared feature trunk applied to raw observations.
    shared: MLP,
    // Linear head producing the Gaussian mean, one output per action dim.
    mean_head: Linear,
    // Linear head producing log-std; clamped to [LOG_STD_MIN, LOG_STD_MAX] in forward().
    log_std_head: Linear,
    // Owns all trainable variables; used for save/load as well.
    varmap: VarMap,
    optimizer: candle_nn::AdamW,
    device: Device,
    act_dim: usize,
    // Mirrors the optimizer's learning rate for the learning_rate() getter.
    lr: f64,
    // NOTE(review): seeded but never used — sampling goes through
    // `Tensor::randn` (device RNG) instead. Confirm whether seeded
    // reproducibility was intended.
    #[allow(dead_code)]
    rng: ChaCha8Rng,
}
28
29impl CandleStochasticPolicy {
30 pub fn new(
31 obs_dim: usize,
32 act_dim: usize,
33 hidden: usize,
34 lr: f64,
35 device: Device,
36 seed: u64,
37 ) -> Result<Self, NNError> {
38 let varmap = VarMap::new();
39 let vb = VarBuilder::from_varmap(&varmap, candle_core::DType::F32, &device);
40
41 let shared_config = MLPConfig::new(obs_dim, hidden)
42 .with_hidden(vec![hidden])
43 .with_activation(Activation::ReLU);
44
45 let shared = MLP::new(&shared_config, vb.pp("shared")).nn_err()?;
46 let mean_head = linear(hidden, act_dim, vb.pp("mean")).nn_err()?;
47 let log_std_head = linear(hidden, act_dim, vb.pp("log_std")).nn_err()?;
48
49 let params = varmap.all_vars();
50 let optimizer = candle_nn::AdamW::new(
51 params,
52 candle_nn::ParamsAdamW {
53 lr,
54 ..Default::default()
55 },
56 )
57 .nn_err()?;
58
59 Ok(Self {
60 shared,
61 mean_head,
62 log_std_head,
63 varmap,
64 optimizer,
65 device,
66 act_dim,
67 lr,
68 rng: ChaCha8Rng::seed_from_u64(seed),
69 })
70 }
71
72 fn forward(&self, obs: &Tensor) -> candle_core::Result<(Tensor, Tensor)> {
73 let h = self.shared.forward(obs)?;
74 let mean = self.mean_head.forward(&h)?;
75 let log_std = self
76 .log_std_head
77 .forward(&h)?
78 .clamp(LOG_STD_MIN, LOG_STD_MAX)?;
79 Ok((mean, log_std))
80 }
81}
82
impl CandleStochasticPolicy {
    /// Performs one SAC actor gradient step on a batch of observations.
    ///
    /// Samples reparameterized tanh-squashed actions, scores them with the
    /// twin critic, and minimizes `E[alpha * log_pi(a|s) - min(Q1, Q2)]`.
    /// Returns metrics containing `actor_loss`.
    pub fn sac_actor_step(
        &mut self,
        obs: &TensorData,
        alpha: f32,
        critic: &crate::continuous_q::CandleTwinQ,
    ) -> Result<TrainMetrics, NNError> {
        // assumes obs.shape is [batch, obs_dim] — TODO confirm to_tensor_2d enforces this
        let batch_size = obs.shape[0];
        let obs_t = to_tensor_2d(obs, &self.device).nn_err()?;
        let (mean, log_std) = self.forward(&obs_t).nn_err()?;
        let std = log_std.exp().nn_err()?;

        // Reparameterization trick: x = mean + std * eps, eps ~ N(0, I).
        // NOTE(review): uses candle's device RNG, not the seeded `self.rng` —
        // confirm whether seeded reproducibility is required here.
        let eps = Tensor::randn(0.0_f32, 1.0, (batch_size, self.act_dim), &self.device).nn_err()?;
        let x_t = (&mean + &(&std * &eps).nn_err()?).nn_err()?;
        // Squash the pre-activation sample into (-1, 1).
        let y_t = x_t.tanh().nn_err()?;

        // Diagonal-Gaussian log density of x_t, assembled term by term:
        //   -(x - mean)^2 / (2 var) - log(std) - 0.5 * log(2*pi)
        let var = std.sqr().nn_err()?;
        let residual = (&x_t - &mean).nn_err()?;
        let normal_lp = (&residual.sqr().nn_err()? / &var)
            .nn_err()?
            .neg()
            .nn_err()?
            / 2.0;
        let normal_lp = normal_lp.nn_err()?;
        let log_std_term = std.log().nn_err()?.neg().nn_err()?;
        let const_term = -0.5 * (2.0 * std::f32::consts::PI).ln();
        let normal_lp_full =
            (&(&normal_lp + &log_std_term).nn_err()? + const_term as f64).nn_err()?;

        // Tanh change-of-variables correction: -log(1 - tanh(x)^2 + 1e-6);
        // the epsilon guards against log(0) when |y| saturates at 1.
        let y_sq = y_t.sqr().nn_err()?;
        let one = Tensor::ones_like(&y_sq).nn_err()?;
        let correction = ((&one - &y_sq).nn_err()? + 1e-6_f64)
            .nn_err()?
            .log()
            .nn_err()?
            .neg()
            .nn_err()?;
        // Sum per-dimension log-probs over the action axis -> shape [batch].
        let log_prob = (&normal_lp_full + &correction).nn_err()?.sum(1).nn_err()?;

        // Clipped double-Q: elementwise minimum of the twin critic outputs.
        let (q1, q2) = critic.twin_q_forward(&obs_t, &y_t).nn_err()?;
        let q1 = q1.squeeze(1).nn_err()?;
        let q2 = q2.squeeze(1).nn_err()?;
        let q_min = q1.minimum(&q2).nn_err()?;

        // Actor objective: mean over batch of (alpha * log_pi - min Q);
        // gradient descent on this maximizes entropy-regularized Q.
        let actor_loss = (&(&log_prob * alpha as f64).nn_err()? - &q_min)
            .nn_err()?
            .mean_all()
            .nn_err()?;

        self.optimizer.backward_step(&actor_loss).nn_err()?;

        let loss_val: f32 = actor_loss.to_scalar().nn_err()?;

        let mut metrics = TrainMetrics::new();
        metrics.insert("actor_loss", loss_val as f64);
        Ok(metrics)
    }
}
147
148impl StochasticPolicy for CandleStochasticPolicy {
149 fn sample_actions(&self, obs: &TensorData) -> Result<(TensorData, TensorData), NNError> {
150 let batch_size = obs.shape[0];
151 let obs_t = to_tensor_2d(obs, &self.device).nn_err()?;
152 let (mean, log_std) = self.forward(&obs_t).nn_err()?;
153 let std = log_std.exp().nn_err()?;
154
155 let eps = Tensor::randn(0.0_f32, 1.0, (batch_size, self.act_dim), &self.device).nn_err()?;
156 let x_t = (&mean + &(&std * &eps).nn_err()?).nn_err()?;
157 let y_t = x_t.tanh().nn_err()?;
158
159 let mean_data: Vec<f32> = mean.flatten_all().nn_err()?.to_vec1().nn_err()?;
160 let std_data: Vec<f32> = std.flatten_all().nn_err()?.to_vec1().nn_err()?;
161 let x_data: Vec<f32> = x_t.flatten_all().nn_err()?.to_vec1().nn_err()?;
162 let y_data: Vec<f32> = y_t.flatten_all().nn_err()?.to_vec1().nn_err()?;
163
164 let mut actions_flat = Vec::with_capacity(batch_size * self.act_dim);
165 let mut log_probs = Vec::with_capacity(batch_size);
166
167 for i in 0..batch_size {
168 let mut lp_sum = 0.0_f32;
169 for j in 0..self.act_dim {
170 let idx = i * self.act_dim + j;
171 let x = x_data[idx];
172 let m = mean_data[idx];
173 let s = std_data[idx];
174 let y = y_data[idx];
175
176 actions_flat.push(y);
177
178 let normal_lp =
179 -0.5 * ((x - m) / s).powi(2) - s.ln() - 0.5 * (2.0 * std::f32::consts::PI).ln();
180 let correction = -(1.0 - y * y + 1e-6).ln();
181 lp_sum += normal_lp + correction;
182 }
183 log_probs.push(lp_sum);
184 }
185
186 Ok((
187 TensorData::new(actions_flat, vec![batch_size, self.act_dim]),
188 TensorData::new(log_probs, vec![batch_size]),
189 ))
190 }
191
192 fn deterministic_action(&self, obs: &TensorData) -> Result<TensorData, NNError> {
193 let obs_t = to_tensor_2d(obs, &self.device).nn_err()?;
194 let (mean, _) = self.forward(&obs_t).nn_err()?;
195 let action = mean.tanh().nn_err()?;
196 from_tensor_2d(&action).nn_err()
197 }
198
199 fn learning_rate(&self) -> f32 {
200 self.lr as f32
201 }
202
203 fn set_learning_rate(&mut self, lr: f32) {
204 self.lr = lr as f64;
205 self.optimizer.set_learning_rate(lr as f64);
206 }
207
208 fn save(&self, path: &Path) -> Result<(), NNError> {
209 self.varmap
210 .save(path)
211 .map_err(|e| NNError::Serialization(e.to_string()))
212 }
213
214 fn load(&mut self, path: &Path) -> Result<(), NNError> {
215 self.varmap
216 .load(path)
217 .map_err(|e| NNError::Serialization(e.to_string()))
218 }
219}
220
// SAFETY(review): these blanket impls assert thread-safety the compiler did
// not derive, which means some field (presumably `Device`, the optimizer, or
// `VarMap` internals) is not auto-Send/Sync. Verify every field is genuinely
// safe to move/share across threads — if not, this is unsound. TODO confirm.
unsafe impl Send for CandleStochasticPolicy {}
unsafe impl Sync for CandleStochasticPolicy {}
223
#[cfg(test)]
mod tests {
    use super::*;

    // Common fixture: CPU policy with a fixed width, lr, and seed.
    fn make_policy(obs_dim: usize, act_dim: usize) -> CandleStochasticPolicy {
        CandleStochasticPolicy::new(obs_dim, act_dim, 64, 3e-4, Device::Cpu, 42).unwrap()
    }

    #[test]
    fn test_sample_actions_shape() {
        let policy = make_policy(3, 2);
        let obs = TensorData::zeros(vec![8, 3]);
        let (actions, log_probs) = policy.sample_actions(&obs).unwrap();
        assert_eq!(actions.shape, vec![8, 2]);
        assert_eq!(log_probs.shape, vec![8]);
    }

    #[test]
    fn test_sample_in_range() {
        let policy = make_policy(3, 1);
        let obs = TensorData::zeros(vec![100, 3]);
        let (actions, _) = policy.sample_actions(&obs).unwrap();
        // tanh squashing must keep every component inside [-1, 1].
        for &a in &actions.data {
            assert!(a >= -1.0 && a <= 1.0, "out of range: {a}");
        }
    }

    #[test]
    fn test_deterministic_shape() {
        let policy = make_policy(3, 2);
        let obs = TensorData::zeros(vec![4, 3]);
        let actions = policy.deterministic_action(&obs).unwrap();
        assert_eq!(actions.shape, vec![4, 2]);
    }
}
254}