1use std::path::Path;
2
3use crate::error::NNError;
4use crate::tensor_data::TensorData;
5
/// Result of a single policy forward pass used for action selection.
#[derive(Debug, Clone)]
pub struct ActionOutput {
    /// Actions produced for the batch.
    pub actions: TensorData,
    /// Log-probabilities of `actions` under the sampling distribution.
    // NOTE(review): shape/batch layout not visible from here — defined by implementations.
    pub log_probs: TensorData,
}
14
/// Result of re-evaluating stored actions under the current policy
/// (consumed by [`ActorCritic::ppo_step`]-style losses).
#[derive(Debug, Clone)]
pub struct EvalOutput {
    /// Log-probabilities of the evaluated actions.
    pub log_probs: TensorData,
    /// Entropy of the policy distribution at each observation.
    pub entropy: TensorData,
    /// Critic value estimates for the observations.
    pub values: TensorData,
}
25
/// Hyper-parameters for a single PPO optimization step.
#[derive(Debug, Clone)]
pub struct PPOStepConfig {
    /// Clipping epsilon for the PPO surrogate objective.
    pub clip_eps: f32,
    /// Weight of the value-function loss term.
    pub vf_coef: f32,
    /// Weight of the entropy bonus term.
    pub ent_coef: f32,
    /// Global gradient-norm clipping threshold.
    pub max_grad_norm: f32,
    /// Whether the value loss is clipped around the old value estimates.
    pub clip_vloss: bool,
}

impl Default for PPOStepConfig {
    /// Widely used PPO defaults: clip 0.2, vf 0.5, ent 0.01,
    /// grad-norm 0.5, value-loss clipping enabled.
    fn default() -> Self {
        PPOStepConfig {
            clip_vloss: true,
            clip_eps: 0.2,
            max_grad_norm: 0.5,
            vf_coef: 0.5,
            ent_coef: 0.01,
        }
    }
}
47
/// Hyper-parameters for a DQN-family training step.
#[derive(Debug, Clone)]
pub struct DQNStepConfig {
    /// Discount factor applied to future rewards.
    pub gamma: f32,
    /// Number of steps used for n-step returns.
    pub n_step: usize,
    /// Whether Double-DQN target action selection is used.
    pub double_dqn: bool,
}

impl Default for DQNStepConfig {
    /// Common defaults: gamma 0.99, 1-step returns, Double DQN on.
    fn default() -> Self {
        DQNStepConfig {
            double_dqn: true,
            n_step: 1,
            gamma: 0.99,
        }
    }
}
65
/// Hyper-parameters for a SAC training step.
#[derive(Debug, Clone)]
pub struct SACStepConfig {
    /// Discount factor applied to future rewards.
    pub gamma: f32,
    /// Polyak averaging rate for target-network updates.
    pub tau: f32,
    /// Entropy target used when tuning the temperature automatically.
    pub target_entropy: f32,
    /// Whether the entropy coefficient is tuned automatically.
    pub auto_entropy: bool,
}

impl Default for SACStepConfig {
    /// Common defaults: gamma 0.99, tau 0.005, target entropy -1.0,
    /// automatic entropy tuning on.
    fn default() -> Self {
        SACStepConfig {
            auto_entropy: true,
            target_entropy: -1.0,
            tau: 0.005,
            gamma: 0.99,
        }
    }
}
85
/// Hyper-parameters for a TD3 training step.
#[derive(Debug, Clone)]
pub struct TD3StepConfig {
    /// Discount factor applied to future rewards.
    pub gamma: f32,
    /// Polyak averaging rate for target-network updates.
    pub tau: f32,
    /// Actor/target update frequency relative to critic updates.
    pub policy_delay: usize,
    /// Magnitude of the smoothing noise added to target actions.
    pub target_noise: f32,
    /// Absolute clip applied to the target smoothing noise.
    pub noise_clip: f32,
}

impl Default for TD3StepConfig {
    /// Standard TD3 defaults: gamma 0.99, tau 0.005, delayed policy
    /// updates every 2 critic steps, noise 0.2 clipped at 0.5.
    fn default() -> Self {
        TD3StepConfig {
            noise_clip: 0.5,
            target_noise: 0.2,
            policy_delay: 2,
            tau: 0.005,
            gamma: 0.99,
        }
    }
}
107
/// Ordered collection of named scalar training metrics.
///
/// Backed by a `Vec` of `(name, value)` pairs: insertion order is
/// preserved and duplicate names are allowed.
#[derive(Debug, Clone, Default)]
pub struct TrainMetrics {
    /// `(name, value)` pairs in insertion order.
    pub entries: Vec<(String, f64)>,
}

impl TrainMetrics {
    /// Creates an empty metrics collection.
    pub fn new() -> Self {
        Self::default()
    }

    /// Records `value` under `key`. Duplicate keys are kept as separate
    /// entries; `get` returns the earliest one.
    pub fn insert(&mut self, key: impl Into<String>, value: f64) {
        let key = key.into();
        self.entries.push((key, value));
    }

    /// Returns the earliest value recorded under `key`, or `None` if the
    /// key was never inserted.
    pub fn get(&self, key: &str) -> Option<f64> {
        for (name, value) in &self.entries {
            if name == key {
                return Some(*value);
            }
        }
        None
    }
}
127
/// Interface for a combined actor-critic model trained with PPO-style updates.
///
/// Implementations own the policy and value heads plus optimizer state;
/// `ppo_step` performs one gradient update on a minibatch.
pub trait ActorCritic {
    /// Runs the policy on a batch of observations and returns actions with
    /// their log-probabilities.
    fn act(&self, obs: &TensorData) -> Result<ActionOutput, NNError>;

    /// Returns the critic's value estimates for a batch of observations.
    fn value(&self, obs: &TensorData) -> Result<TensorData, NNError>;

    /// Re-evaluates the given `actions` under the current policy, producing
    /// log-probabilities, entropy, and value estimates.
    fn evaluate(&self, obs: &TensorData, actions: &TensorData) -> Result<EvalOutput, NNError>;

    /// Performs one PPO gradient step.
    ///
    /// `old_log_probs` / `old_values` are quantities recorded at rollout time
    /// (used for the probability ratio and, when `config.clip_vloss` is set,
    /// value-loss clipping); `advantages` / `returns` are the advantage
    /// estimates and value targets. Returns per-step training metrics.
    #[allow(clippy::too_many_arguments)]
    fn ppo_step(
        &mut self,
        obs: &TensorData,
        actions: &TensorData,
        old_log_probs: &TensorData,
        advantages: &TensorData,
        returns: &TensorData,
        old_values: &TensorData,
        config: &PPOStepConfig,
    ) -> Result<TrainMetrics, NNError>;

    /// Current optimizer learning rate.
    fn learning_rate(&self) -> f32;
    /// Sets the optimizer learning rate (e.g. for LR schedules).
    fn set_learning_rate(&mut self, lr: f32);

    /// Serializes model weights to `path`.
    fn save(&self, path: &Path) -> Result<(), NNError>;
    /// Restores model weights from `path`.
    fn load(&mut self, path: &Path) -> Result<(), NNError>;
}
162
/// Interface for a discrete-action Q-network (DQN family) with a target copy.
pub trait QFunction {
    /// Q-values for all actions at each observation in the batch.
    fn q_values(&self, obs: &TensorData) -> Result<TensorData, NNError>;

    /// Q-values of the specific `actions` taken at each observation.
    fn q_value_at(&self, obs: &TensorData, actions: &TensorData) -> Result<TensorData, NNError>;

    /// One temporal-difference regression step toward `targets`.
    ///
    /// `weights` are optional per-sample weights (presumably importance-sampling
    /// weights from prioritized replay — confirm with implementations).
    /// Returns the scalar loss and a per-sample tensor (likely the TD errors —
    /// confirm with implementations).
    fn td_step(
        &mut self,
        obs: &TensorData,
        actions: &TensorData,
        targets: &TensorData,
        weights: Option<&TensorData>,
    ) -> Result<(f64, TensorData), NNError>;

    /// Q-values computed by the target network.
    fn target_q_values(&self, obs: &TensorData) -> Result<TensorData, NNError>;

    /// Copies the online-network weights into the target network.
    fn hard_update_target(&mut self);
}
187
/// Interface for a stochastic policy (SAC-style actor).
pub trait StochasticPolicy {
    /// Samples actions for a batch of observations.
    ///
    /// Returns a pair of tensors — presumably `(actions, log_probs)` by
    /// position; confirm with implementations.
    fn sample_actions(&self, obs: &TensorData) -> Result<(TensorData, TensorData), NNError>;

    /// Deterministic action (e.g. distribution mean/mode) for evaluation.
    fn deterministic_action(&self, obs: &TensorData) -> Result<TensorData, NNError>;

    /// Current optimizer learning rate.
    fn learning_rate(&self) -> f32;
    /// Sets the optimizer learning rate.
    fn set_learning_rate(&mut self, lr: f32);

    /// Serializes policy weights to `path`.
    fn save(&self, path: &Path) -> Result<(), NNError>;
    /// Restores policy weights from `path`.
    fn load(&mut self, path: &Path) -> Result<(), NNError>;
}
212
/// Interface for continuous-action critics with twin Q-heads and target
/// networks (SAC/TD3 style).
pub trait ContinuousQFunction {
    /// Q-value of each (observation, action) pair from the primary head.
    fn q_value(&self, obs: &TensorData, actions: &TensorData) -> Result<TensorData, NNError>;

    /// Q-values from both online heads (e.g. for clipped double-Q targets).
    fn twin_q_values(
        &self,
        obs: &TensorData,
        actions: &TensorData,
    ) -> Result<(TensorData, TensorData), NNError>;

    /// Q-values from both target heads.
    fn target_twin_q_values(
        &self,
        obs: &TensorData,
        actions: &TensorData,
    ) -> Result<(TensorData, TensorData), NNError>;

    /// One regression step of the critics toward `targets`; returns training
    /// metrics (losses, diagnostics).
    fn critic_step(
        &mut self,
        obs: &TensorData,
        actions: &TensorData,
        targets: &TensorData,
    ) -> Result<TrainMetrics, NNError>;

    /// Moves target weights toward the online weights at rate `tau`
    /// (Polyak averaging).
    fn soft_update_targets(&mut self, tau: f32);
}
243
/// Interface for a deterministic actor with a target copy (DDPG/TD3 style).
pub trait DeterministicPolicy {
    /// Deterministic action from the online policy.
    fn act(&self, obs: &TensorData) -> Result<TensorData, NNError>;

    /// Deterministic action from the target policy.
    fn target_act(&self, obs: &TensorData) -> Result<TensorData, NNError>;

    /// Moves target-policy weights toward the online weights at rate `tau`
    /// (Polyak averaging).
    fn soft_update_target(&mut self, tau: f32);

    /// Current optimizer learning rate.
    fn learning_rate(&self) -> f32;
    /// Sets the optimizer learning rate.
    fn set_learning_rate(&mut self, lr: f32);

    /// Serializes policy weights to `path`.
    fn save(&self, path: &Path) -> Result<(), NNError>;
    /// Restores policy weights from `path`.
    fn load(&mut self, path: &Path) -> Result<(), NNError>;
}
266
/// Interface for an adaptive entropy-temperature (alpha) tuner, as used by
/// SAC-style automatic entropy adjustment.
pub trait EntropyTuner {
    /// Current entropy coefficient.
    fn alpha(&self) -> f32;
    /// Updates the coefficient from a batch of action `log_probs` and the
    /// desired `target_entropy`; returns a scalar (presumably the alpha
    /// loss — confirm with implementations).
    fn update(&mut self, log_probs: &TensorData, target_entropy: f32) -> Result<f64, NNError>;
}
272
/// Activation functions available for MLP layers.
///
/// Fieldless and `Copy`, so equality is total: derive `Eq` (fixes clippy's
/// `derive_partial_eq_without_eq`) and `Hash` so the enum can be used as a
/// map/set key. Both additions are backward compatible.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Activation {
    ReLU,
    Tanh,
}

/// Configuration for a multi-layer perceptron.
#[derive(Debug, Clone)]
pub struct MLPConfig {
    /// Number of input features.
    pub input_dim: usize,
    /// Number of output features.
    pub output_dim: usize,
    /// Sizes of the hidden layers, in order.
    pub hidden_dims: Vec<usize>,
    /// Activation applied after each hidden layer.
    pub activation: Activation,
    /// Optional activation applied to the output layer (`None` = linear output).
    pub output_activation: Option<Activation>,
}

impl MLPConfig {
    /// Creates a config with two 64-unit tanh hidden layers and a linear output.
    pub fn new(input_dim: usize, output_dim: usize) -> Self {
        Self {
            input_dim,
            output_dim,
            hidden_dims: vec![64, 64],
            activation: Activation::Tanh,
            output_activation: None,
        }
    }

    /// Replaces the hidden-layer sizes (builder style).
    pub fn with_hidden(mut self, dims: Vec<usize>) -> Self {
        self.hidden_dims = dims;
        self
    }

    /// Replaces the hidden-layer activation (builder style).
    pub fn with_activation(mut self, act: Activation) -> Self {
        self.activation = act;
        self
    }

    /// Sets the output-layer activation (builder style).
    pub fn with_output_activation(mut self, act: Activation) -> Self {
        self.output_activation = Some(act);
        self
    }
}
320
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute-tolerance comparison for `f32` config fields.
    fn close(a: f32, b: f32) -> bool {
        (a - b).abs() < 1e-6
    }

    #[test]
    fn test_ppo_config_default() {
        let cfg = PPOStepConfig::default();
        assert!(close(cfg.clip_eps, 0.2));
        assert!(close(cfg.vf_coef, 0.5));
        assert!(close(cfg.ent_coef, 0.01));
        assert!(close(cfg.max_grad_norm, 0.5));
        assert!(cfg.clip_vloss);
    }

    #[test]
    fn test_train_metrics() {
        let mut metrics = TrainMetrics::new();
        metrics.insert("loss", 0.5);
        metrics.insert("entropy", 1.2);
        assert_eq!(metrics.get("loss"), Some(0.5));
        assert_eq!(metrics.get("entropy"), Some(1.2));
        assert_eq!(metrics.get("missing"), None);
    }

    #[test]
    fn test_mlp_config_builder() {
        let cfg = MLPConfig::new(4, 2)
            .with_hidden(vec![128, 128])
            .with_activation(Activation::ReLU)
            .with_output_activation(Activation::Tanh);
        assert_eq!(cfg.input_dim, 4);
        assert_eq!(cfg.output_dim, 2);
        assert_eq!(cfg.hidden_dims, vec![128, 128]);
        assert_eq!(cfg.activation, Activation::ReLU);
        assert_eq!(cfg.output_activation, Some(Activation::Tanh));
    }

    #[test]
    fn test_action_output_shape() {
        let out = ActionOutput {
            actions: TensorData::zeros(vec![8]),
            log_probs: TensorData::zeros(vec![8]),
        };
        assert_eq!(out.actions.shape, vec![8]);
        assert_eq!(out.log_probs.shape, vec![8]);
    }

    #[test]
    fn test_eval_output_shape() {
        let out = EvalOutput {
            log_probs: TensorData::zeros(vec![32]),
            entropy: TensorData::zeros(vec![32]),
            values: TensorData::zeros(vec![32]),
        };
        assert_eq!(out.log_probs.numel(), 32);
        assert_eq!(out.entropy.numel(), 32);
        assert_eq!(out.values.numel(), 32);
    }

    // Compile-time checks that the public traits remain object-safe
    // (taking `&dyn Trait` fails to compile if a trait stops being so).
    fn _assert_actor_critic_object_safe(_: &dyn ActorCritic) {}
    fn _assert_q_function_object_safe(_: &dyn QFunction) {}
    fn _assert_stochastic_policy_object_safe(_: &dyn StochasticPolicy) {}
    fn _assert_continuous_q_object_safe(_: &dyn ContinuousQFunction) {}
    fn _assert_deterministic_policy_object_safe(_: &dyn DeterministicPolicy) {}
    fn _assert_entropy_tuner_object_safe(_: &dyn EntropyTuner) {}
}