Trait ActorCritic

pub trait ActorCritic {
    // Required methods
    fn act(&self, obs: &TensorData) -> Result<ActionOutput, NNError>;
    fn value(&self, obs: &TensorData) -> Result<TensorData, NNError>;
    fn evaluate(
        &self,
        obs: &TensorData,
        actions: &TensorData,
    ) -> Result<EvalOutput, NNError>;
    fn ppo_step(
        &mut self,
        obs: &TensorData,
        actions: &TensorData,
        old_log_probs: &TensorData,
        advantages: &TensorData,
        returns: &TensorData,
        old_values: &TensorData,
        config: &PPOStepConfig,
    ) -> Result<TrainMetrics, NNError>;
    fn learning_rate(&self) -> f32;
    fn set_learning_rate(&mut self, lr: f32);
    fn save(&self, path: &Path) -> Result<(), NNError>;
    fn load(&mut self, path: &Path) -> Result<(), NNError>;
}

Actor-Critic policy for on-policy algorithms (PPO, A2C).

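In a typical setup, act and value drive rollout collection (no gradients), while evaluate and ppo_step drive the update phase. The method sketches below share one hypothetical experience container, shown here for illustration only; it is not part of this crate:

// Hypothetical storage for one rollout's worth of experience; the field
// names mirror the ppo_step parameters. Not part of this crate's API.
struct RolloutBatch {
    obs: TensorData,
    actions: TensorData,
    old_log_probs: TensorData,
    advantages: TensorData,
    returns: TensorData,
    old_values: TensorData,
}
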
Required Methods

fn act(&self, obs: &TensorData) -> Result<ActionOutput, NNError>

Sample actions from the policy (inference, no gradient tracking).

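For example, one step of experience collection might pair act with value; ActionOutput is kept opaque in this sketch because its fields are not listed on this page:

// One rollout step: sample an action and record V(s) for later
// advantage estimation. Both calls are inference-only (no gradients).
fn rollout_step<P: ActorCritic>(
    policy: &P,
    obs: &TensorData,
) -> Result<(ActionOutput, TensorData), NNError> {
    let action = policy.act(obs)?;
    let value = policy.value(obs)?;
    Ok((action, value))
}
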
fn value(&self, obs: &TensorData) -> Result<TensorData, NNError>

Compute state values (inference, no gradient tracking).

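A common standalone use of value is bootstrapping the final observation of a truncated rollout when computing n-step returns or GAE:

// V(s_T) for the last observation of a rollout; used to bootstrap
// returns/GAE for episodes cut off at the rollout boundary.
fn bootstrap<P: ActorCritic>(
    policy: &P,
    last_obs: &TensorData,
) -> Result<TensorData, NNError> {
    policy.value(last_obs)
}
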
fn evaluate(
    &self,
    obs: &TensorData,
    actions: &TensorData,
) -> Result<EvalOutput, NNError>

Evaluate the policy on (obs, actions) pairs. Differentiable.

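evaluate is the building block for a hand-rolled objective when the fused ppo_step update is too rigid. EvalOutput's fields are not documented on this page, so this sketch shows only the calling shape; that a custom loop would read fresh log-probs, entropy, and values from it is an assumption about its contents:

// Re-score stored actions under the current parameters, with gradients
// enabled, so the outputs can feed a custom PPO/A2C loss.
fn rescore<P: ActorCritic>(
    policy: &P,
    batch: &RolloutBatch, // hypothetical container from the overview above
) -> Result<EvalOutput, NNError> {
    policy.evaluate(&batch.obs, &batch.actions)
}
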
fn ppo_step(
    &mut self,
    obs: &TensorData,
    actions: &TensorData,
    old_log_probs: &TensorData,
    advantages: &TensorData,
    returns: &TensorData,
    old_values: &TensorData,
    config: &PPOStepConfig,
) -> Result<TrainMetrics, NNError>

Perform one PPO gradient step. Bundles the forward pass, loss computation, backward pass, gradient clipping, and optimizer step into a single call.

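A typical PPO update runs several epochs of minibatch steps over one rollout. A sketch, reusing the hypothetical RolloutBatch from the overview:

// Several PPO epochs over minibatch views of one rollout buffer.
fn ppo_update<P: ActorCritic>(
    policy: &mut P,
    minibatches: &[RolloutBatch],
    epochs: usize,
    config: &PPOStepConfig,
) -> Result<Vec<TrainMetrics>, NNError> {
    let mut metrics = Vec::new();
    for _ in 0..epochs {
        for mb in minibatches {
            metrics.push(policy.ppo_step(
                &mb.obs,
                &mb.actions,
                &mb.old_log_probs,
                &mb.advantages,
                &mb.returns,
                &mb.old_values,
                config,
            )?);
        }
    }
    Ok(metrics)
}
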
fn learning_rate(&self) -> f32

Return the current optimizer learning rate.

fn set_learning_rate(&mut self, lr: f32)

Set the optimizer learning rate, e.g. from an external schedule.

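Together, these two methods support external learning-rate schedules, such as the linear annealing common in PPO. The schedule below is the caller's choice, not something the trait prescribes:

// Linearly anneal the learning rate from initial_lr down to zero.
// progress runs from 0.0 at the start of training to 1.0 at the end.
fn anneal_lr<P: ActorCritic>(policy: &mut P, initial_lr: f32, progress: f32) {
    let lr = initial_lr * (1.0 - progress.clamp(0.0, 1.0));
    policy.set_learning_rate(lr);
}
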
fn save(&self, path: &Path) -> Result<(), NNError>

Save the policy to path.

fn load(&mut self, path: &Path) -> Result<(), NNError>

Load a previously saved policy from path.

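A checkpointing sketch built on save; the on-disk format is implementation-defined, so the .ckpt extension used here is an arbitrary placeholder:

use std::path::Path;

// Persist the policy every `every` updates; a matching implementor can
// restore it later with `load`.
fn maybe_checkpoint<P: ActorCritic>(
    policy: &P,
    update: usize,
    every: usize,
    dir: &Path,
) -> Result<(), NNError> {
    if every > 0 && update % every == 0 {
        policy.save(&dir.join(format!("policy_{update}.ckpt")))?;
    }
    Ok(())
}
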
Implementors