use std::cell::RefCell;
use std::path::Path;

use candle_core::Device;
use candle_nn::{Optimizer, VarBuilder, VarMap};
use rand::SeedableRng;
use rand_chacha::ChaCha8Rng;

use rlox_nn::distributions::{categorical_entropy, categorical_log_prob, categorical_sample};
use rlox_nn::{
    ActionOutput, Activation, EvalOutput, MLPConfig, NNError, PPOStepConfig, TensorData,
    TrainMetrics,
};

use crate::convert::*;
use crate::mlp::MLP;

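/// Discrete-action actor-critic backed by Candle: an MLP policy head, an MLP
/// value head, and an AdamW optimizer over their shared `VarMap`. Action
/// sampling uses a seeded `ChaCha8Rng` kept behind a `RefCell` so that `act`
/// can take `&self`.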
pub struct CandleActorCritic {
    actor: MLP,
    critic: MLP,
    pub varmap: VarMap,
    optimizer: candle_nn::AdamW,
    device: Device,
    n_actions: usize,
    lr: f64,
    rng: RefCell<ChaCha8Rng>,
}

impl CandleActorCritic {
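    /// Builds the actor and critic against a shared `VarMap` (so `save`/`load`
    /// cover both networks) and one AdamW optimizer over all registered
    /// variables.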
    pub fn new(
        obs_dim: usize,
        n_actions: usize,
        hidden: usize,
        lr: f64,
        device: Device,
        seed: u64,
    ) -> Result<Self, NNError> {
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, candle_core::DType::F32, &device);

        let actor_config = MLPConfig::new(obs_dim, n_actions)
            .with_hidden(vec![hidden, hidden])
            .with_activation(Activation::Tanh);
        let critic_config = MLPConfig::new(obs_dim, 1)
            .with_hidden(vec![hidden, hidden])
            .with_activation(Activation::Tanh);

        let actor = MLP::new(&actor_config, vb.pp("actor")).nn_err()?;
        let critic = MLP::new(&critic_config, vb.pp("critic")).nn_err()?;

        let params = varmap.all_vars();
        let optimizer = candle_nn::AdamW::new(
            params,
            candle_nn::ParamsAdamW {
                lr,
                ..Default::default()
            },
        )
        .nn_err()?;

        Ok(Self {
            actor,
            critic,
            varmap,
            optimizer,
            device,
            n_actions,
            lr,
            rng: RefCell::new(ChaCha8Rng::seed_from_u64(seed)),
        })
    }

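    /// Runs the actor forward pass and splits the flattened
    /// `[batch * n_actions]` output into one logits row per batch element.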
    fn compute_logits(&self, obs: &TensorData) -> Result<Vec<Vec<f32>>, NNError> {
        let batch_size = obs.shape[0];
        let obs_t = to_tensor_2d(obs, &self.device).nn_err()?;
        let logits = self.actor.forward(&obs_t).nn_err()?;
        let logits_flat: Vec<f32> = logits.flatten_all().nn_err()?.to_vec1().nn_err()?;

        Ok((0..batch_size)
            .map(|i| logits_flat[i * self.n_actions..(i + 1) * self.n_actions].to_vec())
            .collect())
    }
}

impl rlox_nn::ActorCritic for CandleActorCritic {
    fn act(&self, obs: &TensorData) -> Result<ActionOutput, NNError> {
        if obs.shape.len() != 2 {
            return Err(NNError::ShapeMismatch {
                expected: "2D [batch, obs_dim]".into(),
                got: format!("{:?}", obs.shape),
            });
        }

        let batch_size = obs.shape[0];
        let batch_logits = self.compute_logits(obs)?;

        let mut actions = Vec::with_capacity(batch_size);
        let mut log_probs = Vec::with_capacity(batch_size);

        let mut rng = self.rng.borrow_mut();
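        // Sample one action per row from its categorical distribution using a
        // single uniform draw, and record that action's log-probability.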
        for logits in &batch_logits {
            let u: f32 = rand::Rng::random(&mut *rng);
            let action = categorical_sample(logits, u);
            let lp = categorical_log_prob(logits, action);
            actions.push(action as f32);
            log_probs.push(lp);
        }

        Ok(ActionOutput {
            actions: TensorData::new(actions, vec![batch_size]),
            log_probs: TensorData::new(log_probs, vec![batch_size]),
        })
    }

    fn value(&self, obs: &TensorData) -> Result<TensorData, NNError> {
        if obs.shape.len() != 2 {
            return Err(NNError::ShapeMismatch {
                expected: "2D [batch, obs_dim]".into(),
                got: format!("{:?}", obs.shape),
            });
        }

        let obs_t = to_tensor_2d(obs, &self.device).nn_err()?;
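        // The critic emits [batch, 1]; squeeze the trailing dim to get [batch].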
        let values = self.critic.forward(&obs_t).nn_err()?;
        let values = values.squeeze(1).nn_err()?;
        from_tensor_1d(&values).nn_err()
    }

    fn evaluate(&self, obs: &TensorData, actions: &TensorData) -> Result<EvalOutput, NNError> {
        let batch_size = obs.shape[0];
        let batch_logits = self.compute_logits(obs)?;

        let mut log_probs = Vec::with_capacity(batch_size);
        let mut entropies = Vec::with_capacity(batch_size);

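        // Log-probability of each taken action, plus the policy entropy, under
        // the current policy.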
        for (i, logits) in batch_logits.iter().enumerate() {
            let action = actions.data[i] as usize;
            log_probs.push(categorical_log_prob(logits, action));
            entropies.push(categorical_entropy(logits));
        }

        let values = self.value(obs)?;

        Ok(EvalOutput {
            log_probs: TensorData::new(log_probs, vec![batch_size]),
            entropy: TensorData::new(entropies, vec![batch_size]),
            values,
        })
    }

    fn ppo_step(
        &mut self,
        obs: &TensorData,
        actions: &TensorData,
        old_log_probs: &TensorData,
        advantages: &TensorData,
        returns: &TensorData,
        old_values: &TensorData,
        config: &PPOStepConfig,
    ) -> Result<TrainMetrics, NNError> {
        let batch_size = obs.shape[0];
        let obs_t = to_tensor_2d(obs, &self.device).nn_err()?;

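        // New log-probabilities of the taken actions, gathered from the
        // log-softmax over all actions.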
        let logits = self.actor.forward(&obs_t).nn_err()?;
        let log_probs_all = candle_nn::ops::log_softmax(&logits, 1).nn_err()?;

        let actions_idx = to_int_tensor_1d(actions, &self.device).nn_err()?;
        let actions_2d = actions_idx.unsqueeze(1).nn_err()?;
        let new_log_probs = log_probs_all
            .gather(&actions_2d, 1)
            .nn_err()?
            .squeeze(1)
            .nn_err()?;

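        // Per-sample policy entropy: -sum_a p(a) * log p(a).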
        let probs = candle_nn::ops::softmax(&logits, 1).nn_err()?;
        let entropy = (&probs * &log_probs_all)
            .nn_err()?
            .sum(1)
            .nn_err()?
            .neg()
            .nn_err()?;

        let new_values = self.critic.forward(&obs_t).nn_err()?.squeeze(1).nn_err()?;

        let old_lp = to_tensor_1d(old_log_probs, &self.device).nn_err()?;
        let adv = to_tensor_1d(advantages, &self.device).nn_err()?;
        let ret = to_tensor_1d(returns, &self.device).nn_err()?;

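        // Importance ratio r = exp(new_log_prob - old_log_prob).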
        let log_ratio = (&new_log_probs - &old_lp).nn_err()?;
        let ratio = log_ratio.exp().nn_err()?;

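        // PPO clipped surrogate, negated so that minimizing it maximizes the
        // clipped objective.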
        let pg_loss1 = (&adv.neg().nn_err()? * &ratio).nn_err()?;
        let clamped = ratio
            .clamp(1.0 - config.clip_eps, 1.0 + config.clip_eps)
            .nn_err()?;
        let pg_loss2 = (&adv.neg().nn_err()? * &clamped).nn_err()?;
        let policy_loss = pg_loss1.maximum(&pg_loss2).nn_err()?.mean_all().nn_err()?;

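        // Value loss: 0.5 * MSE against returns, optionally clipped around the
        // old value estimates.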
        let value_loss = if config.clip_vloss {
            let old_v = to_tensor_1d(old_values, &self.device).nn_err()?;
            let v_diff = (&new_values - &old_v).nn_err()?;
            let v_clipped =
                (&old_v + v_diff.clamp(-config.clip_eps, config.clip_eps).nn_err()?).nn_err()?;
            let vf1 = (&new_values - &ret).nn_err()?.sqr().nn_err()?;
            let vf2 = (&v_clipped - &ret).nn_err()?.sqr().nn_err()?;
            (vf1.maximum(&vf2).nn_err()?.mean_all().nn_err()? * 0.5).nn_err()?
        } else {
            ((&new_values - &ret)
                .nn_err()?
                .sqr()
                .nn_err()?
                .mean_all()
                .nn_err()?
                * 0.5)
                .nn_err()?
        };

        let entropy_loss = entropy.mean_all().nn_err()?;

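        // total = policy_loss + vf_coef * value_loss - ent_coef * entropy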
        let total_loss = ((&policy_loss + (&value_loss * config.vf_coef as f64).nn_err()?)
            .nn_err()?
            - (&entropy_loss * config.ent_coef as f64).nn_err()?)
            .nn_err()?;

        self.optimizer.backward_step(&total_loss).nn_err()?;

        let policy_loss_val: f32 = policy_loss.to_scalar().nn_err()?;
        let value_loss_val: f32 = value_loss.to_scalar().nn_err()?;
        let entropy_val: f32 = entropy_loss.to_scalar().nn_err()?;

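        // Diagnostics: approx_kl uses the (r - 1) - log(r) estimator;
        // clip_fraction is the share of ratios outside [1 - eps, 1 + eps].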
        let ratio_data: Vec<f32> = ratio.to_vec1().nn_err()?;
        let log_ratio_data: Vec<f32> = log_ratio.to_vec1().nn_err()?;
        let approx_kl: f32 = ratio_data
            .iter()
            .zip(log_ratio_data.iter())
            .map(|(&r, &lr)| (r - 1.0) - lr)
            .sum::<f32>()
            / batch_size as f32;
        let clip_fraction: f32 = ratio_data
            .iter()
            .filter(|&&r| (r - 1.0).abs() > config.clip_eps)
            .count() as f32
            / batch_size as f32;

        let mut metrics = TrainMetrics::new();
        metrics.insert("policy_loss", policy_loss_val as f64);
        metrics.insert("value_loss", value_loss_val as f64);
        metrics.insert("entropy", entropy_val as f64);
        metrics.insert("approx_kl", approx_kl as f64);
        metrics.insert("clip_fraction", clip_fraction as f64);

        Ok(metrics)
    }

    fn learning_rate(&self) -> f32 {
        self.lr as f32
    }

    fn set_learning_rate(&mut self, lr: f32) {
        self.lr = lr as f64;
        self.optimizer.set_learning_rate(lr as f64);
    }

    fn save(&self, path: &Path) -> Result<(), NNError> {
        self.varmap
            .save(path)
            .map_err(|e| NNError::Serialization(e.to_string()))
    }

    fn load(&mut self, path: &Path) -> Result<(), NNError> {
        self.varmap
            .load(path)
            .map_err(|e| NNError::Serialization(e.to_string()))
    }
}

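// SAFETY: `RefCell<ChaCha8Rng>` is not `Sync`, so these impls rely on the
// invariant that a `CandleActorCritic` is never used from two threads at the
// same time; concurrent `act` calls would race on the RNG. Wrapping the RNG
// in a `Mutex` would make the field thread-safe without this caller-side
// invariant.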
unsafe impl Send for CandleActorCritic {}
unsafe impl Sync for CandleActorCritic {}

#[cfg(test)]
mod tests {
    use super::*;
    use rlox_nn::ActorCritic;

    #[test]
    fn test_act_shapes() {
        let ac = CandleActorCritic::new(4, 2, 64, 2.5e-4, Device::Cpu, 42).unwrap();
        let obs = TensorData::zeros(vec![8, 4]);
        let result = ac.act(&obs).unwrap();
        assert_eq!(result.actions.shape, vec![8]);
        assert_eq!(result.log_probs.shape, vec![8]);
    }

    #[test]
    fn test_value_shape() {
        let ac = CandleActorCritic::new(4, 2, 64, 2.5e-4, Device::Cpu, 42).unwrap();
        let obs = TensorData::zeros(vec![8, 4]);
        let values = ac.value(&obs).unwrap();
        assert_eq!(values.shape, vec![8]);
    }

    #[test]
    fn test_evaluate_shapes() {
        let ac = CandleActorCritic::new(4, 2, 64, 2.5e-4, Device::Cpu, 42).unwrap();
        let obs = TensorData::zeros(vec![4, 4]);
        let actions = TensorData::new(vec![0.0, 1.0, 0.0, 1.0], vec![4]);
        let eval = ac.evaluate(&obs, &actions).unwrap();
        assert_eq!(eval.log_probs.shape, vec![4]);
        assert_eq!(eval.entropy.shape, vec![4]);
        assert_eq!(eval.values.shape, vec![4]);
    }

    #[test]
    fn test_ppo_step_runs() {
        let mut ac = CandleActorCritic::new(4, 2, 64, 2.5e-4, Device::Cpu, 42).unwrap();
        let bs = 32;
        let obs = TensorData::zeros(vec![bs, 4]);
        let actions = TensorData::new(vec![0.0; bs], vec![bs]);
        let old_lp = TensorData::new(vec![-0.7; bs], vec![bs]);
        let adv = TensorData::new(vec![1.0; bs], vec![bs]);
        let ret = TensorData::new(vec![1.0; bs], vec![bs]);
        let old_v = TensorData::zeros(vec![bs]);
        let config = PPOStepConfig::default();

        let metrics = ac
            .ppo_step(&obs, &actions, &old_lp, &adv, &ret, &old_v, &config)
            .unwrap();
        assert!(metrics.get("policy_loss").is_some());
        assert!(metrics.get("value_loss").is_some());
        assert!(metrics.get("entropy").is_some());
    }

    #[test]
    fn test_lr_get_set() {
        let mut ac = CandleActorCritic::new(4, 2, 64, 2.5e-4, Device::Cpu, 42).unwrap();
        assert!((ac.learning_rate() - 2.5e-4).abs() < 1e-8);
        ac.set_learning_rate(1e-3);
        assert!((ac.learning_rate() - 1e-3).abs() < 1e-8);
    }

    #[test]
    fn test_act_invalid_shape() {
        let ac = CandleActorCritic::new(4, 2, 64, 2.5e-4, Device::Cpu, 42).unwrap();
        let obs = TensorData::zeros(vec![4]);
        assert!(ac.act(&obs).is_err());
    }

    #[test]
    fn test_act_rng_advances() {
        let ac = CandleActorCritic::new(4, 2, 64, 2.5e-4, Device::Cpu, 42).unwrap();
        let obs = TensorData::zeros(vec![1, 4]);
        let mut seen_different = false;
        let first = ac.act(&obs).unwrap().log_probs.data[0];
        for _ in 0..20 {
            let lp = ac.act(&obs).unwrap().log_probs.data[0];
            if (lp - first).abs() > 1e-6 {
                seen_different = true;
                break;
            }
        }
        assert!(seen_different, "RNG should advance between act() calls");
    }
}