rlox_candle/actor_critic.rs

use std::path::Path;
use std::sync::Mutex;

use candle_core::Device;
use candle_nn::{Optimizer, VarBuilder, VarMap};
use rand::SeedableRng;
use rand_chacha::ChaCha8Rng;

use rlox_nn::distributions::{categorical_entropy, categorical_log_prob, categorical_sample};
use rlox_nn::{
    ActionOutput, Activation, EvalOutput, MLPConfig, NNError, PPOStepConfig, TensorData,
    TrainMetrics,
};

use crate::convert::*;
use crate::mlp::MLP;

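/// Candle-backed actor-critic for discrete action spaces: an MLP policy head
/// producing action logits and an MLP value head, trained jointly by a single
/// AdamW optimizer over one shared `VarMap`. The seeded ChaCha8 RNG sits
/// behind a `Mutex` so `act(&self)` can advance it without `&mut self`.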
pub struct CandleActorCritic {
    actor: MLP,
    critic: MLP,
    pub varmap: VarMap,
    optimizer: candle_nn::AdamW,
    device: Device,
    n_actions: usize,
    lr: f64,
    rng: Mutex<ChaCha8Rng>,
}

impl CandleActorCritic {
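    /// Builds actor and critic MLPs (two hidden layers of `hidden` units,
    /// tanh activations) on `device`. Both networks register their parameters
    /// under distinct prefixes ("actor" / "critic") in the same `VarMap`, so
    /// `varmap.all_vars()` hands every trainable parameter to AdamW.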
    pub fn new(
        obs_dim: usize,
        n_actions: usize,
        hidden: usize,
        lr: f64,
        device: Device,
        seed: u64,
    ) -> Result<Self, NNError> {
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, candle_core::DType::F32, &device);

        let actor_config = MLPConfig::new(obs_dim, n_actions)
            .with_hidden(vec![hidden, hidden])
            .with_activation(Activation::Tanh);
        let critic_config = MLPConfig::new(obs_dim, 1)
            .with_hidden(vec![hidden, hidden])
            .with_activation(Activation::Tanh);

        let actor = MLP::new(&actor_config, vb.pp("actor")).nn_err()?;
        let critic = MLP::new(&critic_config, vb.pp("critic")).nn_err()?;

        let params = varmap.all_vars();
        let optimizer = candle_nn::AdamW::new(
            params,
            candle_nn::ParamsAdamW {
                lr,
                ..Default::default()
            },
        )
        .nn_err()?;

        Ok(Self {
            actor,
            critic,
            varmap,
            optimizer,
            device,
            n_actions,
            lr,
            rng: Mutex::new(ChaCha8Rng::seed_from_u64(seed)),
        })
    }

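    /// Runs the actor over a batch of observations and splits the flattened
    /// `[batch * n_actions]` logits back into one `Vec<f32>` per row.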
    fn compute_logits(&self, obs: &TensorData) -> Result<Vec<Vec<f32>>, NNError> {
        let batch_size = obs.shape[0];
        let obs_t = to_tensor_2d(obs, &self.device).nn_err()?;
        let logits = self.actor.forward(&obs_t).nn_err()?;
        let logits_flat: Vec<f32> = logits.flatten_all().nn_err()?.to_vec1().nn_err()?;

        Ok((0..batch_size)
            .map(|i| logits_flat[i * self.n_actions..(i + 1) * self.n_actions].to_vec())
            .collect())
    }
}

impl rlox_nn::ActorCritic for CandleActorCritic {
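    /// Samples one action per observation row, returning the sampled action
    /// indices (as f32) and their log-probabilities. Sampling happens on the
    /// CPU from extracted logits, so no gradients are involved here.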
    fn act(&self, obs: &TensorData) -> Result<ActionOutput, NNError> {
        if obs.shape.len() != 2 {
            return Err(NNError::ShapeMismatch {
                expected: "2D [batch, obs_dim]".into(),
                got: format!("{:?}", obs.shape),
            });
        }

        let batch_size = obs.shape[0];
        let batch_logits = self.compute_logits(obs)?;

        let mut actions = Vec::with_capacity(batch_size);
        let mut log_probs = Vec::with_capacity(batch_size);

        let mut rng = self.rng.lock().expect("rng mutex poisoned");
        for logits in &batch_logits {
            let u: f32 = rand::Rng::random(&mut *rng);
            let action = categorical_sample(logits, u);
            let lp = categorical_log_prob(logits, action);
            actions.push(action as f32);
            log_probs.push(lp);
        }

        Ok(ActionOutput {
            actions: TensorData::new(actions, vec![batch_size]),
            log_probs: TensorData::new(log_probs, vec![batch_size]),
        })
    }

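    /// Returns the critic's scalar value estimate for each observation row,
    /// squeezed from `[batch, 1]` down to `[batch]`.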
    fn value(&self, obs: &TensorData) -> Result<TensorData, NNError> {
        if obs.shape.len() != 2 {
            return Err(NNError::ShapeMismatch {
                expected: "2D [batch, obs_dim]".into(),
                got: format!("{:?}", obs.shape),
            });
        }

        let obs_t = to_tensor_2d(obs, &self.device).nn_err()?;
        let values = self.critic.forward(&obs_t).nn_err()?;
        let values = values.squeeze(1).nn_err()?;
        from_tensor_1d(&values).nn_err()
    }

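    /// Recomputes log-probabilities and entropies for previously taken
    /// actions, together with fresh value estimates for the observations.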
    fn evaluate(&self, obs: &TensorData, actions: &TensorData) -> Result<EvalOutput, NNError> {
        if obs.shape.len() != 2 {
            return Err(NNError::ShapeMismatch {
                expected: "2D [batch, obs_dim]".into(),
                got: format!("{:?}", obs.shape),
            });
        }

        let batch_size = obs.shape[0];
        let batch_logits = self.compute_logits(obs)?;

        let mut log_probs = Vec::with_capacity(batch_size);
        let mut entropies = Vec::with_capacity(batch_size);

        for (i, logits) in batch_logits.iter().enumerate() {
            let action = actions.data[i] as usize;
            log_probs.push(categorical_log_prob(logits, action));
            entropies.push(categorical_entropy(logits));
        }

        let values = self.value(obs)?;

        Ok(EvalOutput {
            log_probs: TensorData::new(log_probs, vec![batch_size]),
            entropy: TensorData::new(entropies, vec![batch_size]),
            values,
        })
    }

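    /// One PPO update on a minibatch: recomputes log-probs, entropy, and
    /// values with gradients enabled, forms the clipped-surrogate policy loss
    /// and (optionally clipped) value loss, takes an AdamW step, and returns
    /// diagnostics (losses, entropy, approximate KL, clip fraction).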
    fn ppo_step(
        &mut self,
        obs: &TensorData,
        actions: &TensorData,
        old_log_probs: &TensorData,
        advantages: &TensorData,
        returns: &TensorData,
        old_values: &TensorData,
        config: &PPOStepConfig,
    ) -> Result<TrainMetrics, NNError> {
        let batch_size = obs.shape[0];
        let obs_t = to_tensor_2d(obs, &self.device).nn_err()?;

        // Actor forward
        let logits = self.actor.forward(&obs_t).nn_err()?;
        let log_probs_all = candle_nn::ops::log_softmax(&logits, 1).nn_err()?;

        // Gather action log-probs
        let actions_idx = to_int_tensor_1d(actions, &self.device).nn_err()?;
        let actions_2d = actions_idx.unsqueeze(1).nn_err()?;
        let new_log_probs = log_probs_all
            .gather(&actions_2d, 1)
            .nn_err()?
            .squeeze(1)
            .nn_err()?;

        // Entropy
        let probs = candle_nn::ops::softmax(&logits, 1).nn_err()?;
        let entropy = (&probs * &log_probs_all)
            .nn_err()?
            .sum(1)
            .nn_err()?
            .neg()
            .nn_err()?;

        // Critic forward
        let new_values = self.critic.forward(&obs_t).nn_err()?.squeeze(1).nn_err()?;

        // PPO loss
        let old_lp = to_tensor_1d(old_log_probs, &self.device).nn_err()?;
        let adv = to_tensor_1d(advantages, &self.device).nn_err()?;
        let ret = to_tensor_1d(returns, &self.device).nn_err()?;

        let log_ratio = (&new_log_probs - &old_lp).nn_err()?;
        let ratio = log_ratio.exp().nn_err()?;

        let pg_loss1 = (&adv.neg().nn_err()? * &ratio).nn_err()?;
        let clamped = ratio
            .clamp(1.0 - config.clip_eps, 1.0 + config.clip_eps)
            .nn_err()?;
        let pg_loss2 = (&adv.neg().nn_err()? * &clamped).nn_err()?;
        let policy_loss = pg_loss1.maximum(&pg_loss2).nn_err()?.mean_all().nn_err()?;

        // Value loss
        let value_loss = if config.clip_vloss {
            let old_v = to_tensor_1d(old_values, &self.device).nn_err()?;
            let v_diff = (&new_values - &old_v).nn_err()?;
            let v_clipped =
                (&old_v + v_diff.clamp(-config.clip_eps, config.clip_eps).nn_err()?).nn_err()?;
            let vf1 = (&new_values - &ret).nn_err()?.sqr().nn_err()?;
            let vf2 = (&v_clipped - &ret).nn_err()?.sqr().nn_err()?;
            (vf1.maximum(&vf2).nn_err()?.mean_all().nn_err()? * 0.5).nn_err()?
        } else {
            ((&new_values - &ret)
                .nn_err()?
                .sqr()
                .nn_err()?
                .mean_all()
                .nn_err()?
                * 0.5)
                .nn_err()?
        };

        let entropy_loss = entropy.mean_all().nn_err()?;

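        // Minimize the policy and value losses while maximizing entropy,
        // hence the subtraction of the entropy bonus.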
        let total_loss = ((&policy_loss + (&value_loss * config.vf_coef as f64).nn_err()?)
            .nn_err()?
            - (&entropy_loss * config.ent_coef as f64).nn_err()?)
        .nn_err()?;

        self.optimizer.backward_step(&total_loss).nn_err()?;

        // Extract metrics
        let policy_loss_val: f32 = policy_loss.to_scalar().nn_err()?;
        let value_loss_val: f32 = value_loss.to_scalar().nn_err()?;
        let entropy_val: f32 = entropy_loss.to_scalar().nn_err()?;

        let ratio_data: Vec<f32> = ratio.to_vec1().nn_err()?;
        let log_ratio_data: Vec<f32> = log_ratio.to_vec1().nn_err()?;
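        // Low-variance, non-negative KL estimator: KL(old, new) is
        // approximated by the mean of (ratio - 1) - log(ratio).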
        let approx_kl: f32 = ratio_data
            .iter()
            .zip(log_ratio_data.iter())
            .map(|(&r, &lr)| (r - 1.0) - lr)
            .sum::<f32>()
            / batch_size as f32;
        let clip_fraction: f32 = ratio_data
            .iter()
            .filter(|&&r| (r - 1.0).abs() > config.clip_eps)
            .count() as f32
            / batch_size as f32;

        let mut metrics = TrainMetrics::new();
        metrics.insert("policy_loss", policy_loss_val as f64);
        metrics.insert("value_loss", value_loss_val as f64);
        metrics.insert("entropy", entropy_val as f64);
        metrics.insert("approx_kl", approx_kl as f64);
        metrics.insert("clip_fraction", clip_fraction as f64);

        Ok(metrics)
    }

    fn learning_rate(&self) -> f32 {
        self.lr as f32
    }

    fn set_learning_rate(&mut self, lr: f32) {
        self.lr = lr as f64;
        self.optimizer.set_learning_rate(lr as f64);
    }

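    /// Persists all actor and critic parameters through the shared `VarMap`;
    /// optimizer state (the AdamW moments) is not included.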
    fn save(&self, path: &Path) -> Result<(), NNError> {
        self.varmap
            .save(path)
            .map_err(|e| NNError::Serialization(e.to_string()))
    }

    fn load(&mut self, path: &Path) -> Result<(), NNError> {
        self.varmap
            .load(path)
            .map_err(|e| NNError::Serialization(e.to_string()))
    }
}

// The trait requires Send + Sync. With the sampling RNG behind a Mutex
// (rather than a RefCell, which is not Sync), every field is Send + Sync, so
// both impls are derived automatically and no unsafe impls are needed.

#[cfg(test)]
mod tests {
    use super::*;
    use rlox_nn::ActorCritic;

    #[test]
    fn test_act_shapes() {
        let ac = CandleActorCritic::new(4, 2, 64, 2.5e-4, Device::Cpu, 42).unwrap();
        let obs = TensorData::zeros(vec![8, 4]);
        let result = ac.act(&obs).unwrap();
        assert_eq!(result.actions.shape, vec![8]);
        assert_eq!(result.log_probs.shape, vec![8]);
    }

    #[test]
    fn test_value_shape() {
        let ac = CandleActorCritic::new(4, 2, 64, 2.5e-4, Device::Cpu, 42).unwrap();
        let obs = TensorData::zeros(vec![8, 4]);
        let values = ac.value(&obs).unwrap();
        assert_eq!(values.shape, vec![8]);
    }

    #[test]
    fn test_evaluate_shapes() {
        let ac = CandleActorCritic::new(4, 2, 64, 2.5e-4, Device::Cpu, 42).unwrap();
        let obs = TensorData::zeros(vec![4, 4]);
        let actions = TensorData::new(vec![0.0, 1.0, 0.0, 1.0], vec![4]);
        let eval = ac.evaluate(&obs, &actions).unwrap();
        assert_eq!(eval.log_probs.shape, vec![4]);
        assert_eq!(eval.entropy.shape, vec![4]);
        assert_eq!(eval.values.shape, vec![4]);
    }

    #[test]
    fn test_ppo_step_runs() {
        let mut ac = CandleActorCritic::new(4, 2, 64, 2.5e-4, Device::Cpu, 42).unwrap();
        let bs = 32;
        let obs = TensorData::zeros(vec![bs, 4]);
        let actions = TensorData::new(vec![0.0; bs], vec![bs]);
        let old_lp = TensorData::new(vec![-0.7; bs], vec![bs]);
        let adv = TensorData::new(vec![1.0; bs], vec![bs]);
        let ret = TensorData::new(vec![1.0; bs], vec![bs]);
        let old_v = TensorData::zeros(vec![bs]);
        let config = PPOStepConfig::default();

        let metrics = ac
            .ppo_step(&obs, &actions, &old_lp, &adv, &ret, &old_v, &config)
            .unwrap();
        assert!(metrics.get("policy_loss").is_some());
        assert!(metrics.get("value_loss").is_some());
        assert!(metrics.get("entropy").is_some());
    }

    #[test]
    fn test_lr_get_set() {
        let mut ac = CandleActorCritic::new(4, 2, 64, 2.5e-4, Device::Cpu, 42).unwrap();
        assert!((ac.learning_rate() - 2.5e-4).abs() < 1e-8);
        ac.set_learning_rate(1e-3);
        assert!((ac.learning_rate() - 1e-3).abs() < 1e-8);
    }

    #[test]
    fn test_act_invalid_shape() {
        let ac = CandleActorCritic::new(4, 2, 64, 2.5e-4, Device::Cpu, 42).unwrap();
        let obs = TensorData::zeros(vec![4]); // 1D should fail
        assert!(ac.act(&obs).is_err());
    }

    #[test]
    fn test_act_rng_advances() {
        let ac = CandleActorCritic::new(4, 2, 64, 2.5e-4, Device::Cpu, 42).unwrap();
        let obs = TensorData::zeros(vec![1, 4]);
        // Identical observations produce identical logits, so if the RNG state
        // did not advance between calls, act() would return the same action
        // (and log-prob) every time. Because the RNG sits behind interior
        // mutability and advances on each draw, repeated calls should
        // eventually sample a different action.
        let mut seen_different = false;
        let first = ac.act(&obs).unwrap().log_probs.data[0];
        for _ in 0..20 {
            let lp = ac.act(&obs).unwrap().log_probs.data[0];
            if (lp - first).abs() > 1e-6 {
                seen_different = true;
                break;
            }
        }
        // With 2 actions and a roughly uniform untrained policy, the chance of
        // 20 further draws all matching the first is about 0.5^20 ≈ 1e-6.
        assert!(seen_different, "RNG should advance between act() calls");
    }
}