// rlox_candle/collector.rs

//! Candle-powered rollout collector.
//!
//! Connects `CandleActorCritic` directly to `AsyncCollector` so that
//! policy inference runs in pure Rust — no Python calls during collection.
//!
//! # Architecture
//!
//! ```text
//! ┌─────────────────────────────────────────┐
//! │  Background Thread (pure Rust)          │
//! │                                         │
//! │  loop {                                 │
//! │    obs = VecEnv.step_all()     ~96μs    │
//! │    act = Candle.act(obs)       ~15μs    │  ← no Python dispatch
//! │    val = Candle.value(obs)     ~10μs    │
//! │    buffer.push(obs, act, val)           │
//! │    GAE = compute_gae_batched() ~5μs     │
//! │    channel.send(batch)                  │
//! │  }                                      │
//! └─────────────────────────────────────────┘
//!         ↕ crossbeam channel
//! ┌─────────────────────────────────────────┐
//! │  Python Main Thread                     │
//! │                                         │
//! │  batch = collector.recv()               │
//! │  loss = ppo_loss(batch)    (PyTorch)    │
//! │  loss.backward()                        │
//! │  optimizer.step()                       │
//! │  collector.sync_weights(flat_params)    │
//! └─────────────────────────────────────────┘
//! ```

33use std::sync::{Arc, RwLock};
34
35use rlox_nn::{ActorCritic, TensorData};
36
37use crate::actor_critic::CandleActorCritic;
38
/// Shared policy wrapper that allows the collection thread to read the policy
/// while the main thread updates weights.
///
/// Internally an `Arc<RwLock<_>>`: the background collection thread takes
/// short read locks for inference (see `make_candle_callbacks`), while
/// `sync_weights` takes a write lock to install new parameters. Reads
/// dominate writes, so `RwLock` keeps inference mostly uncontended.
pub struct SharedPolicy {
    // Owning handle; `clone_ref` hands out additional Arc clones for the
    // collection thread.
    inner: Arc<RwLock<CandleActorCritic>>,
}
44
45impl SharedPolicy {
46    pub fn new(policy: CandleActorCritic) -> Self {
47        Self {
48            inner: Arc::new(RwLock::new(policy)),
49        }
50    }
51
52    /// Get a clone of the Arc for the collection thread.
53    pub fn clone_ref(&self) -> Arc<RwLock<CandleActorCritic>> {
54        self.inner.clone()
55    }
56
57    /// Synchronize weights from a flat f32 buffer (exported from PyTorch).
58    ///
59    /// The buffer must contain all parameters in the same order as
60    /// `CandleActorCritic`'s VarMap, flattened and concatenated.
61    pub fn sync_weights(&self, flat_params: &[f32]) -> Result<(), rlox_nn::NNError> {
62        let mut policy = self
63            .inner
64            .write()
65            .map_err(|e| rlox_nn::NNError::Backend(format!("Failed to acquire write lock: {e}")))?;
66        load_flat_params(&mut policy, flat_params)
67    }
68
69    /// Extract weights as a flat f32 buffer (for PyTorch initialization).
70    pub fn get_weights(&self) -> Result<Vec<f32>, rlox_nn::NNError> {
71        let policy = self
72            .inner
73            .read()
74            .map_err(|e| rlox_nn::NNError::Backend(format!("Failed to acquire read lock: {e}")))?;
75        extract_flat_params(&policy)
76    }
77}
78
79/// Build `action_fn` and `value_fn` closures that read from a shared policy.
80///
81/// These closures are compatible with [`AsyncCollector::start`] and perform
82/// policy inference entirely in Rust (no GIL, no Python dispatch).
83pub fn make_candle_callbacks(
84    policy: Arc<RwLock<CandleActorCritic>>,
85    obs_dim: usize,
86) -> (
87    Arc<dyn Fn(&[f32]) -> (Vec<f32>, Vec<f64>) + Send + Sync>,
88    Arc<dyn Fn(&[f32]) -> Vec<f64> + Send + Sync>,
89) {
90    let policy_act = policy.clone();
91    let policy_val = policy;
92
93    let action_fn: Arc<dyn Fn(&[f32]) -> (Vec<f32>, Vec<f64>) + Send + Sync> =
94        Arc::new(move |obs_flat: &[f32]| {
95            let n_envs = obs_flat.len() / obs_dim;
96            let obs = TensorData::new(obs_flat.to_vec(), vec![n_envs, obs_dim]);
97
98            let p = policy_act.read().unwrap();
99            let output = p.act(&obs).expect("Candle act() failed");
100
101            let actions = output.actions.data;
102            let log_probs: Vec<f64> = output.log_probs.data.iter().map(|&x| x as f64).collect();
103            (actions, log_probs)
104        });
105
106    let value_fn: Arc<dyn Fn(&[f32]) -> Vec<f64> + Send + Sync> =
107        Arc::new(move |obs_flat: &[f32]| {
108            let n_envs = obs_flat.len() / obs_dim;
109            let obs = TensorData::new(obs_flat.to_vec(), vec![n_envs, obs_dim]);
110
111            let p = policy_val.read().unwrap();
112            let values = p.value(&obs).expect("Candle value() failed");
113            values.data.iter().map(|&x| x as f64).collect()
114        });
115
116    (action_fn, value_fn)
117}
118
119/// Load parameters from a flat f32 buffer into a CandleActorCritic's VarMap.
120fn load_flat_params(
121    policy: &mut CandleActorCritic,
122    flat_params: &[f32],
123) -> Result<(), rlox_nn::NNError> {
124    let vars = policy.varmap.all_vars();
125    let mut offset = 0;
126
127    for var in &vars {
128        let numel = var.elem_count();
129        if offset + numel > flat_params.len() {
130            return Err(rlox_nn::NNError::ShapeMismatch {
131                expected: format!("at least {} elements", offset + numel),
132                got: format!("{} elements", flat_params.len()),
133            });
134        }
135
136        let slice = &flat_params[offset..offset + numel];
137        let shape = var.dims();
138        let tensor = candle_core::Tensor::from_vec(slice.to_vec(), shape, var.device())
139            .map_err(|e| rlox_nn::NNError::Backend(e.to_string()))?;
140        var.set(&tensor)
141            .map_err(|e| rlox_nn::NNError::Backend(e.to_string()))?;
142        offset += numel;
143    }
144
145    if offset != flat_params.len() {
146        return Err(rlox_nn::NNError::ShapeMismatch {
147            expected: format!("{} total elements", offset),
148            got: format!("{} elements", flat_params.len()),
149        });
150    }
151
152    Ok(())
153}
154
155/// Extract all parameters from a CandleActorCritic as a flat f32 buffer.
156fn extract_flat_params(policy: &CandleActorCritic) -> Result<Vec<f32>, rlox_nn::NNError> {
157    let vars = policy.varmap.all_vars();
158    let total: usize = vars.iter().map(|v| v.elem_count()).sum();
159    let mut flat = Vec::with_capacity(total);
160
161    for var in &vars {
162        let data: Vec<f32> = var
163            .flatten_all()
164            .map_err(|e| rlox_nn::NNError::Backend(e.to_string()))?
165            .to_vec1()
166            .map_err(|e| rlox_nn::NNError::Backend(e.to_string()))?;
167        flat.extend(data);
168    }
169
170    Ok(flat)
171}
172
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::Device;
    use rlox_nn::ActorCritic;

    // Sanity check: a policy wrapped in SharedPolicy is usable for inference
    // through a clone_ref handle, as the collection thread would use it.
    #[test]
    fn test_shared_policy_act() {
        let ac = CandleActorCritic::new(4, 2, 64, 2.5e-4, Device::Cpu, 42).unwrap();
        let shared = SharedPolicy::new(ac);
        let policy_ref = shared.clone_ref();

        // 8 envs, obs_dim 4 — act() should produce one action per env.
        let obs = TensorData::zeros(vec![8, 4]);
        let p = policy_ref.read().unwrap();
        let result = p.act(&obs).unwrap();
        assert_eq!(result.actions.shape, vec![8]);
    }

    // get_weights → sync_weights → get_weights must be an identity
    // round-trip (within f32 tolerance).
    #[test]
    fn test_weight_roundtrip() {
        let ac = CandleActorCritic::new(4, 2, 64, 2.5e-4, Device::Cpu, 42).unwrap();
        let shared = SharedPolicy::new(ac);

        // Extract weights
        let weights = shared.get_weights().unwrap();
        assert!(!weights.is_empty());

        // Sync back (should be a no-op identity)
        shared.sync_weights(&weights).unwrap();

        // Verify still works
        let weights2 = shared.get_weights().unwrap();
        assert_eq!(weights.len(), weights2.len());
        for (a, b) in weights.iter().zip(weights2.iter()) {
            assert!((a - b).abs() < 1e-6, "Weight mismatch: {a} vs {b}");
        }
    }

    // A buffer with the wrong total element count must be rejected and must
    // not panic.
    #[test]
    fn test_sync_weights_wrong_size() {
        let ac = CandleActorCritic::new(4, 2, 64, 2.5e-4, Device::Cpu, 42).unwrap();
        let shared = SharedPolicy::new(ac);

        let result = shared.sync_weights(&[1.0, 2.0, 3.0]); // too few
        assert!(result.is_err());
    }

    // The generated closures infer n_envs from buffer length / obs_dim and
    // return one action / log-prob / value per env.
    #[test]
    fn test_make_candle_callbacks() {
        let ac = CandleActorCritic::new(4, 2, 64, 2.5e-4, Device::Cpu, 42).unwrap();
        let shared = SharedPolicy::new(ac);
        let (action_fn, value_fn) = make_candle_callbacks(shared.clone_ref(), 4);

        // 2 envs, obs_dim=4
        let obs = vec![0.0f32; 8];
        let (actions, log_probs) = action_fn(&obs);
        assert_eq!(actions.len(), 2);
        assert_eq!(log_probs.len(), 2);

        let values = value_fn(&obs);
        assert_eq!(values.len(), 2);
    }

    // End-to-end: Candle callbacks driving a real AsyncCollector over
    // CartPole. NOTE: statement order matters here — recv() before
    // sync_weights() (exercising weight sync while the background thread
    // holds read locks), and stop() last.
    #[test]
    fn test_candle_callbacks_with_async_collector() {
        use rlox_core::env::builtins::CartPole;
        use rlox_core::env::parallel::VecEnv;
        use rlox_core::env::RLEnv;
        use rlox_core::pipeline::channel::Pipeline;
        use rlox_core::pipeline::collector::AsyncCollector;
        use rlox_core::seed::derive_seed;

        // Deterministic per-env seeds derived from a single master seed.
        let n_envs = 2;
        let envs: Vec<Box<dyn RLEnv>> = (0..n_envs)
            .map(|i| Box::new(CartPole::new(Some(derive_seed(42, i)))) as Box<dyn RLEnv>)
            .collect();
        let vec_env = Box::new(VecEnv::new(envs).unwrap());

        // Create Candle policy and callbacks
        let ac = CandleActorCritic::new(4, 2, 64, 2.5e-4, Device::Cpu, 42).unwrap();
        let shared = SharedPolicy::new(ac);
        let (action_fn, value_fn) = make_candle_callbacks(shared.clone_ref(), 4);

        // Run async collector with Candle callbacks
        let pipe = Pipeline::new(4);
        let tx = pipe.sender();
        let mut collector = AsyncCollector::start(vec_env, 16, 0.99, 0.95, tx, value_fn, action_fn);

        // Should receive a batch with log_probs and values
        let batch = pipe.recv().unwrap();
        assert_eq!(batch.n_steps, 16);
        assert_eq!(batch.n_envs, 2);
        // observations: n_steps * n_envs * obs_dim; per-step stats: n_steps * n_envs.
        assert_eq!(batch.observations.len(), 16 * 2 * 4);
        assert_eq!(batch.log_probs.len(), 16 * 2);
        assert_eq!(batch.values.len(), 16 * 2);
        assert_eq!(batch.advantages.len(), 16 * 2);

        for &lp in &batch.log_probs {
            assert!(lp.is_finite(), "log_prob must be finite");
        }
        for &v in &batch.values {
            assert!(v.is_finite(), "value must be finite");
        }

        // Weight sync should work while collector is running
        let weights = shared.get_weights().unwrap();
        shared.sync_weights(&weights).unwrap();

        collector.stop();
    }
}
283}