1use std::sync::{Arc, RwLock};
34
35use rlox_nn::{ActorCritic, TensorData};
36
37use crate::actor_critic::CandleActorCritic;
38
39pub struct SharedPolicy {
42 inner: Arc<RwLock<CandleActorCritic>>,
43}
44
45impl SharedPolicy {
46 pub fn new(policy: CandleActorCritic) -> Self {
47 Self {
48 inner: Arc::new(RwLock::new(policy)),
49 }
50 }
51
52 pub fn clone_ref(&self) -> Arc<RwLock<CandleActorCritic>> {
54 self.inner.clone()
55 }
56
57 pub fn sync_weights(&self, flat_params: &[f32]) -> Result<(), rlox_nn::NNError> {
62 let mut policy = self
63 .inner
64 .write()
65 .map_err(|e| rlox_nn::NNError::Backend(format!("Failed to acquire write lock: {e}")))?;
66 load_flat_params(&mut policy, flat_params)
67 }
68
69 pub fn get_weights(&self) -> Result<Vec<f32>, rlox_nn::NNError> {
71 let policy = self
72 .inner
73 .read()
74 .map_err(|e| rlox_nn::NNError::Backend(format!("Failed to acquire read lock: {e}")))?;
75 extract_flat_params(&policy)
76 }
77}
78
79pub fn make_candle_callbacks(
84 policy: Arc<RwLock<CandleActorCritic>>,
85 obs_dim: usize,
86) -> (
87 Arc<dyn Fn(&[f32]) -> (Vec<f32>, Vec<f64>) + Send + Sync>,
88 Arc<dyn Fn(&[f32]) -> Vec<f64> + Send + Sync>,
89) {
90 let policy_act = policy.clone();
91 let policy_val = policy;
92
93 let action_fn: Arc<dyn Fn(&[f32]) -> (Vec<f32>, Vec<f64>) + Send + Sync> =
94 Arc::new(move |obs_flat: &[f32]| {
95 let n_envs = obs_flat.len() / obs_dim;
96 let obs = TensorData::new(obs_flat.to_vec(), vec![n_envs, obs_dim]);
97
98 let p = policy_act.read().unwrap();
99 let output = p.act(&obs).expect("Candle act() failed");
100
101 let actions = output.actions.data;
102 let log_probs: Vec<f64> = output.log_probs.data.iter().map(|&x| x as f64).collect();
103 (actions, log_probs)
104 });
105
106 let value_fn: Arc<dyn Fn(&[f32]) -> Vec<f64> + Send + Sync> =
107 Arc::new(move |obs_flat: &[f32]| {
108 let n_envs = obs_flat.len() / obs_dim;
109 let obs = TensorData::new(obs_flat.to_vec(), vec![n_envs, obs_dim]);
110
111 let p = policy_val.read().unwrap();
112 let values = p.value(&obs).expect("Candle value() failed");
113 values.data.iter().map(|&x| x as f64).collect()
114 });
115
116 (action_fn, value_fn)
117}
118
119fn load_flat_params(
121 policy: &mut CandleActorCritic,
122 flat_params: &[f32],
123) -> Result<(), rlox_nn::NNError> {
124 let vars = policy.varmap.all_vars();
125 let mut offset = 0;
126
127 for var in &vars {
128 let numel = var.elem_count();
129 if offset + numel > flat_params.len() {
130 return Err(rlox_nn::NNError::ShapeMismatch {
131 expected: format!("at least {} elements", offset + numel),
132 got: format!("{} elements", flat_params.len()),
133 });
134 }
135
136 let slice = &flat_params[offset..offset + numel];
137 let shape = var.dims();
138 let tensor = candle_core::Tensor::from_vec(slice.to_vec(), shape, var.device())
139 .map_err(|e| rlox_nn::NNError::Backend(e.to_string()))?;
140 var.set(&tensor)
141 .map_err(|e| rlox_nn::NNError::Backend(e.to_string()))?;
142 offset += numel;
143 }
144
145 if offset != flat_params.len() {
146 return Err(rlox_nn::NNError::ShapeMismatch {
147 expected: format!("{} total elements", offset),
148 got: format!("{} elements", flat_params.len()),
149 });
150 }
151
152 Ok(())
153}
154
155fn extract_flat_params(policy: &CandleActorCritic) -> Result<Vec<f32>, rlox_nn::NNError> {
157 let vars = policy.varmap.all_vars();
158 let total: usize = vars.iter().map(|v| v.elem_count()).sum();
159 let mut flat = Vec::with_capacity(total);
160
161 for var in &vars {
162 let data: Vec<f32> = var
163 .flatten_all()
164 .map_err(|e| rlox_nn::NNError::Backend(e.to_string()))?
165 .to_vec1()
166 .map_err(|e| rlox_nn::NNError::Backend(e.to_string()))?;
167 flat.extend(data);
168 }
169
170 Ok(flat)
171}
172
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::Device;
    use rlox_nn::ActorCritic;

    /// Builds the CPU policy shared by all tests, using the same constructor
    /// arguments everywhere so the tests stay comparable.
    fn make_shared() -> SharedPolicy {
        let ac = CandleActorCritic::new(4, 2, 64, 2.5e-4, Device::Cpu, 42).unwrap();
        SharedPolicy::new(ac)
    }

    /// Asserts that two flat weight vectors match element-wise within `1e-6`.
    fn assert_weights_close(expected: &[f32], actual: &[f32]) {
        assert_eq!(expected.len(), actual.len());
        for (a, b) in expected.iter().zip(actual.iter()) {
            assert!((a - b).abs() < 1e-6, "Weight mismatch: {a} vs {b}");
        }
    }

    #[test]
    fn test_shared_policy_act() {
        let shared = make_shared();
        let policy_ref = shared.clone_ref();

        let obs = TensorData::zeros(vec![8, 4]);
        let p = policy_ref.read().unwrap();
        let result = p.act(&obs).unwrap();
        assert_eq!(result.actions.shape, vec![8]);
    }

    #[test]
    fn test_weight_roundtrip() {
        let shared = make_shared();

        let weights = shared.get_weights().unwrap();
        assert!(!weights.is_empty());

        // Writing back the weights that were just read must not change them.
        shared.sync_weights(&weights).unwrap();
        let weights2 = shared.get_weights().unwrap();
        assert_weights_close(&weights, &weights2);

        // Sync *different* values as well; otherwise the test would still pass
        // if `sync_weights` silently did nothing.
        let perturbed: Vec<f32> = weights.iter().map(|w| w + 1.0).collect();
        shared.sync_weights(&perturbed).unwrap();
        let reloaded = shared.get_weights().unwrap();
        assert_weights_close(&perturbed, &reloaded);
    }

    #[test]
    fn test_sync_weights_wrong_size() {
        let shared = make_shared();

        let result = shared.sync_weights(&[1.0, 2.0, 3.0]);
        assert!(result.is_err());
    }

    #[test]
    fn test_sync_weights_too_long() {
        let shared = make_shared();

        // One element more than the policy holds must be rejected too.
        let mut too_long = shared.get_weights().unwrap();
        too_long.push(0.0);
        assert!(shared.sync_weights(&too_long).is_err());
    }

    #[test]
    fn test_make_candle_callbacks() {
        let shared = make_shared();
        let (action_fn, value_fn) = make_candle_callbacks(shared.clone_ref(), 4);

        // 8 values with obs_dim = 4 form a batch of two observations.
        let obs = vec![0.0f32; 8];
        let (actions, log_probs) = action_fn(&obs);
        assert_eq!(actions.len(), 2);
        assert_eq!(log_probs.len(), 2);

        let values = value_fn(&obs);
        assert_eq!(values.len(), 2);
    }

    #[test]
    fn test_candle_callbacks_with_async_collector() {
        use rlox_core::env::builtins::CartPole;
        use rlox_core::env::parallel::VecEnv;
        use rlox_core::env::RLEnv;
        use rlox_core::pipeline::channel::Pipeline;
        use rlox_core::pipeline::collector::AsyncCollector;
        use rlox_core::seed::derive_seed;

        let n_envs = 2;
        let envs: Vec<Box<dyn RLEnv>> = (0..n_envs)
            .map(|i| Box::new(CartPole::new(Some(derive_seed(42, i)))) as Box<dyn RLEnv>)
            .collect();
        let vec_env = Box::new(VecEnv::new(envs).unwrap());

        let shared = make_shared();
        let (action_fn, value_fn) = make_candle_callbacks(shared.clone_ref(), 4);

        let pipe = Pipeline::new(4);
        let tx = pipe.sender();
        let mut collector = AsyncCollector::start(vec_env, 16, 0.99, 0.95, tx, value_fn, action_fn);

        let batch = pipe.recv().unwrap();
        assert_eq!(batch.n_steps, 16);
        assert_eq!(batch.n_envs, 2);
        assert_eq!(batch.observations.len(), 16 * 2 * 4);
        assert_eq!(batch.log_probs.len(), 16 * 2);
        assert_eq!(batch.values.len(), 16 * 2);
        assert_eq!(batch.advantages.len(), 16 * 2);

        for &lp in &batch.log_probs {
            assert!(lp.is_finite(), "log_prob must be finite");
        }
        for &v in &batch.values {
            assert!(v.is_finite(), "value must be finite");
        }

        // Read and write the weights before the collector is stopped, while it
        // still holds the callbacks that read-lock the same policy.
        let weights = shared.get_weights().unwrap();
        shared.sync_weights(&weights).unwrap();

        collector.stop();
    }
}