// rlox_nn/traits.rs

1use std::path::Path;
2
3use crate::error::NNError;
4use crate::tensor_data::TensorData;
5
/// Action output from a policy.
///
/// Both tensors share the leading batch dimension.
#[derive(Debug, Clone)]
pub struct ActionOutput {
    /// Actions: [batch_size] for discrete, [batch_size, act_dim] for continuous.
    pub actions: TensorData,
    /// Log-probabilities of the selected actions: [batch_size].
    pub log_probs: TensorData,
}
14
/// Output from evaluating a policy on (obs, actions) pairs.
///
/// All three tensors are [batch_size].
#[derive(Debug, Clone)]
pub struct EvalOutput {
    /// Log-probabilities: [batch_size].
    pub log_probs: TensorData,
    /// Entropy of the distribution: [batch_size].
    pub entropy: TensorData,
    /// State values: [batch_size].
    pub values: TensorData,
}
25
/// PPO step configuration.
///
/// Hyper-parameters consumed by one PPO gradient step.
#[derive(Debug, Clone)]
pub struct PPOStepConfig {
    /// Ratio clipping epsilon.
    pub clip_eps: f32,
    /// Value-function loss coefficient.
    pub vf_coef: f32,
    /// Entropy bonus coefficient.
    pub ent_coef: f32,
    /// Global gradient-norm clipping threshold.
    pub max_grad_norm: f32,
    /// Whether to also clip the value loss.
    pub clip_vloss: bool,
}

impl Default for PPOStepConfig {
    /// Common PPO defaults: clip 0.2, vf 0.5, ent 0.01, grad-norm 0.5,
    /// value-loss clipping enabled.
    fn default() -> Self {
        PPOStepConfig {
            clip_vloss: true,
            max_grad_norm: 0.5,
            ent_coef: 0.01,
            vf_coef: 0.5,
            clip_eps: 0.2,
        }
    }
}
47
/// DQN step configuration.
#[derive(Debug, Clone)]
pub struct DQNStepConfig {
    /// Discount factor.
    pub gamma: f32,
    /// Horizon for n-step returns.
    pub n_step: usize,
    /// Enable Double-DQN style target computation.
    pub double_dqn: bool,
}

impl Default for DQNStepConfig {
    /// Defaults: γ = 0.99, 1-step returns, Double-DQN on.
    fn default() -> Self {
        DQNStepConfig {
            double_dqn: true,
            n_step: 1,
            gamma: 0.99,
        }
    }
}
65
/// SAC step configuration.
#[derive(Debug, Clone)]
pub struct SACStepConfig {
    /// Discount factor.
    pub gamma: f32,
    /// Polyak averaging coefficient for target-network updates.
    pub tau: f32,
    /// Target entropy used when `auto_entropy` is enabled.
    pub target_entropy: f32,
    /// Automatically tune the entropy temperature.
    pub auto_entropy: bool,
}

impl Default for SACStepConfig {
    /// Defaults: γ = 0.99, τ = 0.005, target entropy −1.0, auto-tuning on.
    fn default() -> Self {
        SACStepConfig {
            auto_entropy: true,
            target_entropy: -1.0,
            tau: 0.005,
            gamma: 0.99,
        }
    }
}
85
/// TD3 step configuration.
#[derive(Debug, Clone)]
pub struct TD3StepConfig {
    /// Discount factor.
    pub gamma: f32,
    /// Polyak averaging coefficient for target-network updates.
    pub tau: f32,
    /// Actor update frequency (in critic steps).
    pub policy_delay: usize,
    /// Std-dev of smoothing noise added to target actions.
    pub target_noise: f32,
    /// Clipping bound for the target smoothing noise.
    pub noise_clip: f32,
}

impl Default for TD3StepConfig {
    /// Defaults match the original TD3 hyper-parameters.
    fn default() -> Self {
        TD3StepConfig {
            noise_clip: 0.5,
            target_noise: 0.2,
            policy_delay: 2,
            tau: 0.005,
            gamma: 0.99,
        }
    }
}
107
/// Training metrics dictionary.
///
/// Entries are kept in insertion order so metrics can be logged in a
/// deterministic sequence.
#[derive(Debug, Clone, Default)]
pub struct TrainMetrics {
    /// (name, value) pairs in insertion order.
    pub entries: Vec<(String, f64)>,
}

impl TrainMetrics {
    /// Create an empty metrics collection.
    pub fn new() -> Self {
        Self::default()
    }

    /// Insert a metric, overwriting the value if the key already exists.
    ///
    /// The previous push-only implementation let duplicate keys accumulate,
    /// and `get` (first match wins) would keep returning the stale value.
    pub fn insert(&mut self, key: impl Into<String>, value: f64) {
        let key = key.into();
        match self.entries.iter_mut().find(|(k, _)| *k == key) {
            Some(entry) => entry.1 = value,
            None => self.entries.push((key, value)),
        }
    }

    /// Look up a metric by name; `None` if absent.
    pub fn get(&self, key: &str) -> Option<f64> {
        self.entries.iter().find(|(k, _)| k == key).map(|(_, v)| *v)
    }
}
127
128// ────────────────────────────────────────────────────────────
129// Phase 1 Traits: Actor-Critic (PPO/A2C) and Q-Function (DQN)
130// ────────────────────────────────────────────────────────────
131
/// Actor-Critic policy for on-policy algorithms (PPO, A2C).
pub trait ActorCritic {
    /// Sample actions from the policy (inference, no gradient tracking).
    fn act(&self, obs: &TensorData) -> Result<ActionOutput, NNError>;

    /// Compute state values (inference, no gradient tracking).
    fn value(&self, obs: &TensorData) -> Result<TensorData, NNError>;

    /// Evaluate the policy on (obs, actions) pairs. Differentiable.
    fn evaluate(&self, obs: &TensorData, actions: &TensorData) -> Result<EvalOutput, NNError>;

    /// Perform one PPO gradient step. Bundles forward→loss→backward→clip→step.
    ///
    /// Returns per-step training metrics (exact keys are backend-defined).
    #[allow(clippy::too_many_arguments)]
    fn ppo_step(
        &mut self,
        obs: &TensorData,
        actions: &TensorData,
        old_log_probs: &TensorData,
        advantages: &TensorData,
        returns: &TensorData,
        old_values: &TensorData,
        config: &PPOStepConfig,
    ) -> Result<TrainMetrics, NNError>;

    /// Current learning rate.
    fn learning_rate(&self) -> f32;
    /// Set the learning rate (e.g. for LR annealing).
    fn set_learning_rate(&mut self, lr: f32);

    /// Serialize parameters to `path`.
    fn save(&self, path: &Path) -> Result<(), NNError>;
    /// Load parameters from `path` into this network.
    fn load(&mut self, path: &Path) -> Result<(), NNError>;
}
162
/// Q-value network for off-policy algorithms (DQN).
pub trait QFunction {
    /// Compute Q-values for all actions. Returns [batch_size, n_actions].
    fn q_values(&self, obs: &TensorData) -> Result<TensorData, NNError>;

    /// Compute Q-value for (obs, action) pairs. Returns [batch_size].
    fn q_value_at(&self, obs: &TensorData, actions: &TensorData) -> Result<TensorData, NNError>;

    /// Perform one DQN TD gradient step.
    ///
    /// `weights` are optional per-sample weights (e.g. importance weights
    /// from a prioritized replay buffer); `None` weights all samples equally.
    /// Returns (loss, td_errors) where td_errors can be used for PER.
    fn td_step(
        &mut self,
        obs: &TensorData,
        actions: &TensorData,
        targets: &TensorData,
        weights: Option<&TensorData>,
    ) -> Result<(f64, TensorData), NNError>;

    /// Compute target Q-values using the target network.
    fn target_q_values(&self, obs: &TensorData) -> Result<TensorData, NNError>;

    /// Hard-copy parameters to the target network.
    fn hard_update_target(&mut self);
}
187
188// ────────────────────────────────────────────────────────────
189// Phase 2 Traits: Stochastic Policy (SAC) and Deterministic (TD3)
190// ────────────────────────────────────────────────────────────
191
/// Continuous stochastic policy for SAC.
///
/// Training steps (`sac_actor_step`) are intentionally NOT on this trait because
/// they require autograd to flow through the critic's Q-network. Trait methods
/// convert tensors to `TensorData` (Vec<f32>), severing the computation graph.
/// Use the backend-specific inherent `sac_actor_step` method instead.
pub trait StochasticPolicy {
    /// Sample actions with reparameterization trick.
    /// Returns (squashed_actions [batch, act_dim], log_probs [batch]).
    fn sample_actions(&self, obs: &TensorData) -> Result<(TensorData, TensorData), NNError>;

    /// Deterministic action (mean through squashing).
    fn deterministic_action(&self, obs: &TensorData) -> Result<TensorData, NNError>;

    /// Current learning rate.
    fn learning_rate(&self) -> f32;
    /// Set the learning rate (e.g. for LR annealing).
    fn set_learning_rate(&mut self, lr: f32);

    /// Serialize parameters to `path`.
    fn save(&self, path: &Path) -> Result<(), NNError>;
    /// Load parameters from `path` into this network.
    fn load(&mut self, path: &Path) -> Result<(), NNError>;
}
212
/// Continuous Q-function for SAC/TD3 (takes obs + action as input).
pub trait ContinuousQFunction {
    /// Compute Q-value for (obs, action). Returns [batch_size].
    fn q_value(&self, obs: &TensorData, actions: &TensorData) -> Result<TensorData, NNError>;

    /// Compute twin Q-values for (obs, action). Returns (q1, q2), each [batch_size].
    fn twin_q_values(
        &self,
        obs: &TensorData,
        actions: &TensorData,
    ) -> Result<(TensorData, TensorData), NNError>;

    /// Compute target twin Q-values (from target networks).
    fn target_twin_q_values(
        &self,
        obs: &TensorData,
        actions: &TensorData,
    ) -> Result<(TensorData, TensorData), NNError>;

    /// Perform one TD gradient step on both critics.
    ///
    /// Returns per-step training metrics (exact keys are backend-defined).
    fn critic_step(
        &mut self,
        obs: &TensorData,
        actions: &TensorData,
        targets: &TensorData,
    ) -> Result<TrainMetrics, NNError>;

    /// Polyak soft update of target networks.
    fn soft_update_targets(&mut self, tau: f32);
}
243
/// Deterministic policy for TD3.
///
/// Training steps (`td3_actor_step`) are intentionally NOT on this trait because
/// they require autograd to flow through the critic's Q-network. Trait methods
/// convert tensors to `TensorData` (Vec<f32>), severing the computation graph.
/// Use the backend-specific inherent `td3_actor_step` method instead.
pub trait DeterministicPolicy {
    /// Compute deterministic action. Returns [batch_size, act_dim].
    fn act(&self, obs: &TensorData) -> Result<TensorData, NNError>;

    /// Compute target policy action (from target network).
    fn target_act(&self, obs: &TensorData) -> Result<TensorData, NNError>;

    /// Polyak soft update of target network.
    fn soft_update_target(&mut self, tau: f32);

    /// Current learning rate.
    fn learning_rate(&self) -> f32;
    /// Set the learning rate (e.g. for LR annealing).
    fn set_learning_rate(&mut self, lr: f32);

    /// Serialize parameters to `path`.
    fn save(&self, path: &Path) -> Result<(), NNError>;
    /// Load parameters from `path` into this network.
    fn load(&mut self, path: &Path) -> Result<(), NNError>;
}
266
/// Entropy tuning for SAC.
pub trait EntropyTuner {
    /// Current entropy temperature alpha.
    fn alpha(&self) -> f32;
    /// Update alpha from a batch of log-probs toward `target_entropy`.
    /// Returns a scalar (presumably the alpha loss — confirm with backend impls).
    fn update(&mut self, log_probs: &TensorData, target_entropy: f32) -> Result<f64, NNError>;
}
272
273// ────────────────────────────────────────────────────────────
274// Network builder config (shared across backends)
275// ────────────────────────────────────────────────────────────
276
/// Activation function.
// `Eq` added: fieldless enum with `PartialEq` should also derive `Eq`
// (clippy::derive_partial_eq_without_eq); enables use as a map key etc.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Activation {
    ReLU,
    Tanh,
}

/// Configuration for building an MLP.
#[derive(Debug, Clone)]
pub struct MLPConfig {
    /// Input feature dimension.
    pub input_dim: usize,
    /// Output dimension.
    pub output_dim: usize,
    /// Hidden layer sizes, in order.
    pub hidden_dims: Vec<usize>,
    /// Activation between hidden layers.
    pub activation: Activation,
    /// Optional activation applied to the output layer.
    pub output_activation: Option<Activation>,
}

impl MLPConfig {
    /// Create a config with defaults: two hidden layers of 64 units,
    /// `Tanh` activation, no output activation.
    pub fn new(input_dim: usize, output_dim: usize) -> Self {
        Self {
            input_dim,
            output_dim,
            hidden_dims: vec![64, 64],
            activation: Activation::Tanh,
            output_activation: None,
        }
    }

    /// Builder: replace the hidden layer sizes.
    #[must_use]
    pub fn with_hidden(mut self, dims: Vec<usize>) -> Self {
        self.hidden_dims = dims;
        self
    }

    /// Builder: set the hidden-layer activation.
    #[must_use]
    pub fn with_activation(mut self, act: Activation) -> Self {
        self.activation = act;
        self
    }

    /// Builder: set the output activation.
    #[must_use]
    pub fn with_output_activation(mut self, act: Activation) -> Self {
        self.output_activation = Some(act);
        self
    }
}
320
#[cfg(test)]
mod tests {
    use super::*;

    /// Approximate f32 equality used by the config-default tests.
    fn approx(a: f32, b: f32) -> bool {
        (a - b).abs() < 1e-6
    }

    #[test]
    fn test_ppo_config_default() {
        let cfg = PPOStepConfig::default();
        assert!(approx(cfg.clip_eps, 0.2));
        assert!(approx(cfg.vf_coef, 0.5));
        assert!(approx(cfg.ent_coef, 0.01));
        assert!(approx(cfg.max_grad_norm, 0.5));
        assert!(cfg.clip_vloss);
    }

    #[test]
    fn test_train_metrics() {
        let mut m = TrainMetrics::new();
        for (key, value) in [("loss", 0.5), ("entropy", 1.2)] {
            m.insert(key, value);
        }
        assert_eq!(m.get("loss"), Some(0.5));
        assert_eq!(m.get("entropy"), Some(1.2));
        assert!(m.get("missing").is_none());
    }

    #[test]
    fn test_mlp_config_builder() {
        let cfg = MLPConfig::new(4, 2)
            .with_hidden(vec![128, 128])
            .with_activation(Activation::ReLU)
            .with_output_activation(Activation::Tanh);
        assert_eq!((cfg.input_dim, cfg.output_dim), (4, 2));
        assert_eq!(cfg.hidden_dims, vec![128, 128]);
        assert_eq!(cfg.activation, Activation::ReLU);
        assert_eq!(cfg.output_activation, Some(Activation::Tanh));
    }

    #[test]
    fn test_action_output_shape() {
        let out = ActionOutput {
            actions: TensorData::zeros(vec![8]),
            log_probs: TensorData::zeros(vec![8]),
        };
        assert_eq!(out.log_probs.shape, vec![8]);
        assert_eq!(out.actions.shape, vec![8]);
    }

    #[test]
    fn test_eval_output_shape() {
        let out = EvalOutput {
            log_probs: TensorData::zeros(vec![32]),
            entropy: TensorData::zeros(vec![32]),
            values: TensorData::zeros(vec![32]),
        };
        for n in [out.log_probs.numel(), out.entropy.numel(), out.values.numel()] {
            assert_eq!(n, 32);
        }
    }

    // Compile-time trait object safety checks: these only need to type-check.
    fn _assert_actor_critic_object_safe(_: &dyn ActorCritic) {}
    fn _assert_q_function_object_safe(_: &dyn QFunction) {}
    fn _assert_stochastic_policy_object_safe(_: &dyn StochasticPolicy) {}
    fn _assert_continuous_q_object_safe(_: &dyn ContinuousQFunction) {}
    fn _assert_deterministic_policy_object_safe(_: &dyn DeterministicPolicy) {}
    fn _assert_entropy_tuner_object_safe(_: &dyn EntropyTuner) {}
}