pub trait ActorCritic {
// Required methods
fn act(&self, obs: &TensorData) -> Result<ActionOutput, NNError>;
fn value(&self, obs: &TensorData) -> Result<TensorData, NNError>;
fn evaluate(
&self,
obs: &TensorData,
actions: &TensorData,
) -> Result<EvalOutput, NNError>;
fn ppo_step(
&mut self,
obs: &TensorData,
actions: &TensorData,
old_log_probs: &TensorData,
advantages: &TensorData,
returns: &TensorData,
old_values: &TensorData,
config: &PPOStepConfig,
) -> Result<TrainMetrics, NNError>;
fn learning_rate(&self) -> f32;
fn set_learning_rate(&mut self, lr: f32);
fn save(&self, path: &Path) -> Result<(), NNError>;
fn load(&mut self, path: &Path) -> Result<(), NNError>;
}Expand description
Actor-Critic policy for on-policy algorithms (PPO, A2C).
Required Methods
fn act(&self, obs: &TensorData) -> Result<ActionOutput, NNError>
fn act(&self, obs: &TensorData) -> Result<ActionOutput, NNError>
Sample actions from the policy (inference, no gradient tracking).
fn value(&self, obs: &TensorData) -> Result<TensorData, NNError>
fn value(&self, obs: &TensorData) -> Result<TensorData, NNError>
Compute state values (inference, no gradient tracking).
fn evaluate(
&self,
obs: &TensorData,
actions: &TensorData,
) -> Result<EvalOutput, NNError>
fn evaluate( &self, obs: &TensorData, actions: &TensorData, ) -> Result<EvalOutput, NNError>
Evaluate the policy on (obs, actions) pairs. Differentiable.
fn ppo_step(
&mut self,
obs: &TensorData,
actions: &TensorData,
old_log_probs: &TensorData,
advantages: &TensorData,
returns: &TensorData,
old_values: &TensorData,
config: &PPOStepConfig,
) -> Result<TrainMetrics, NNError>
fn ppo_step( &mut self, obs: &TensorData, actions: &TensorData, old_log_probs: &TensorData, advantages: &TensorData, returns: &TensorData, old_values: &TensorData, config: &PPOStepConfig, ) -> Result<TrainMetrics, NNError>
Perform one PPO gradient step. Bundles forward→loss→backward→clip→step.