Trait StochasticPolicy

pub trait StochasticPolicy {
    // Required methods
    fn sample_actions(
        &self,
        obs: &TensorData,
    ) -> Result<(TensorData, TensorData), NNError>;
    fn deterministic_action(
        &self,
        obs: &TensorData,
    ) -> Result<TensorData, NNError>;
    fn learning_rate(&self) -> f32;
    fn set_learning_rate(&mut self, lr: f32);
    fn save(&self, path: &Path) -> Result<(), NNError>;
    fn load(&mut self, path: &Path) -> Result<(), NNError>;
}

Continuous stochastic policy for SAC.

Training steps (sac_actor_step) are intentionally NOT part of this trait because they require autograd to flow through the critic’s Q-network. Trait methods exchange tensors as TensorData (plain Vecs), which severs the computation graph. Use the backend-specific inherent sac_actor_step method instead.
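A minimal sketch of the intended split, using illustrative stand-ins for this crate's types (TensorData modeled as a flat Vec<f32>, NNError as a unit struct, and a ToyPolicy implementor — all assumptions, not the real definitions): inference goes through the trait, while the training step stays an inherent method on the concrete backend type.

```rust
use std::path::Path;

// Stand-ins for the crate's types (assumptions for illustration only).
type TensorData = Vec<f32>;
#[derive(Debug)]
struct NNError;

trait StochasticPolicy {
    fn sample_actions(&self, obs: &TensorData)
        -> Result<(TensorData, TensorData), NNError>;
    fn deterministic_action(&self, obs: &TensorData)
        -> Result<TensorData, NNError>;
    fn learning_rate(&self) -> f32;
    fn set_learning_rate(&mut self, lr: f32);
    fn save(&self, path: &Path) -> Result<(), NNError>;
    fn load(&mut self, path: &Path) -> Result<(), NNError>;
}

// Toy implementor: a real policy is a neural network; this one just
// squashes the observation so the control flow is visible.
struct ToyPolicy { lr: f32 }

impl StochasticPolicy for ToyPolicy {
    fn sample_actions(&self, obs: &TensorData)
        -> Result<(TensorData, TensorData), NNError> {
        // No noise here; a real policy adds reparameterized Gaussian noise.
        let actions: TensorData = obs.iter().map(|x| x.tanh()).collect();
        let log_probs = vec![0.0]; // placeholder [batch] log-probs
        Ok((actions, log_probs))
    }
    fn deterministic_action(&self, obs: &TensorData) -> Result<TensorData, NNError> {
        Ok(obs.iter().map(|x| x.tanh()).collect())
    }
    fn learning_rate(&self) -> f32 { self.lr }
    fn set_learning_rate(&mut self, lr: f32) { self.lr = lr; }
    fn save(&self, _path: &Path) -> Result<(), NNError> { Ok(()) }
    fn load(&mut self, _path: &Path) -> Result<(), NNError> { Ok(()) }
}

impl ToyPolicy {
    // Inherent (non-trait) training step, mirroring the backend-specific
    // sac_actor_step: it stays on the concrete type so the backend can
    // keep autograd alive through the critic's Q-network.
    fn sac_actor_step(&mut self) { /* backend-specific autograd here */ }
}

fn main() {
    let mut policy = ToyPolicy { lr: 3e-4 };
    // Inference through the trait object: graph-free TensorData in and out.
    let rollout: &dyn StochasticPolicy = &policy;
    let (acts, _logp) = rollout.sample_actions(&vec![0.2, -1.0]).unwrap();
    assert!(acts.iter().all(|a| a.abs() < 1.0));
    // Training only through the concrete type.
    policy.sac_actor_step();
}
```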

Required Methods


fn sample_actions(&self, obs: &TensorData) -> Result<(TensorData, TensorData), NNError>

Sample actions using the reparameterization trick. Returns (squashed_actions [batch, act_dim], log_probs [batch]).
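A per-element sketch of the math this method implements, assuming the standard SAC formulation (tanh-squashed Gaussian with a change-of-variables log-prob correction); the function name and the stability epsilon are illustrative, not part of this crate:

```rust
// Reparameterized sample from a tanh-squashed Gaussian (standard SAC math).
// `eps` is standard-normal noise supplied by the caller, so the sample is a
// deterministic, differentiable function of (mean, log_std, eps).
fn squashed_sample(mean: f32, log_std: f32, eps: f32) -> (f32, f32) {
    let std = log_std.exp();
    let u = mean + std * eps; // pre-squash Gaussian sample
    let a = u.tanh();         // squashed action in (-1, 1)

    // log N(u; mean, std)
    let log_normal = -0.5 * ((u - mean) / std).powi(2)
        - log_std
        - 0.5 * (2.0 * std::f32::consts::PI).ln();

    // Change-of-variables correction: subtract log(1 - tanh(u)^2), with a
    // small epsilon (an illustrative choice) for stability near |a| = 1.
    let log_prob = log_normal - (1.0 - a * a + 1e-6).ln();
    (a, log_prob)
}

fn main() {
    // With zero noise, the sample reduces to the deterministic action tanh(mean).
    let (a, lp) = squashed_sample(0.5, -1.0, 0.0);
    assert!((a - 0.5f32.tanh()).abs() < 1e-6);
    assert!(lp.is_finite());
    println!("action = {a}, log_prob = {lp}");
}
```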


fn deterministic_action(&self, obs: &TensorData) -> Result<TensorData, NNError>

Deterministic action (the distribution mean passed through the squashing function).


fn learning_rate(&self) -> f32

Current learning rate of the policy optimizer.


fn set_learning_rate(&mut self, lr: f32)

Set the learning rate of the policy optimizer.


fn save(&self, path: &Path) -> Result<(), NNError>

Save the policy parameters to path.


fn load(&mut self, path: &Path) -> Result<(), NNError>

Load the policy parameters from path.

Implementors