pub trait StochasticPolicy {
    // Required methods
    fn sample_actions(
        &self,
        obs: &TensorData,
    ) -> Result<(TensorData, TensorData), NNError>;
    fn deterministic_action(
        &self,
        obs: &TensorData,
    ) -> Result<TensorData, NNError>;
    fn learning_rate(&self) -> f32;
    fn set_learning_rate(&mut self, lr: f32);
    fn save(&self, path: &Path) -> Result<(), NNError>;
    fn load(&mut self, path: &Path) -> Result<(), NNError>;
}
Continuous stochastic policy for SAC.
Training steps (sac_actor_step) are intentionally NOT on this trait because they require autograd to flow through the critic's Q-network. Trait methods convert tensors to TensorData (Vec<f32>-backed), which detaches them from the autograd graph; for training, call the concrete type's sac_actor_step method instead.
Required Methods

fn sample_actions(
    &self,
    obs: &TensorData,
) -> Result<(TensorData, TensorData), NNError>

Sample actions with the reparameterization trick. Returns (squashed_actions [batch, act_dim], log_probs [batch]).
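As a sketch of what such an implementation typically computes per action dimension, assuming a diagonal tanh-squashed Gaussian head (the helper name and scalar formulation below are illustrative, not part of the crate):

```rust
use std::f32::consts::PI;

/// Reparameterized sample for one scalar action dimension:
///   pre = mean + std * eps,  a = tanh(pre),
///   log p(a) = log N(pre; mean, std) - log(1 - tanh(pre)^2)
/// The second term is the change-of-variables correction for the
/// tanh squashing.
fn sample_tanh_gaussian(mean: f32, log_std: f32, eps: f32) -> (f32, f32) {
    let std = log_std.exp();
    let pre = mean + std * eps; // gradient flows through mean and std
    let action = pre.tanh();
    let log_prob_gauss =
        -0.5 * ((pre - mean) / std).powi(2) - log_std - 0.5 * (2.0 * PI).ln();
    // Clamp avoids log(0) when the action saturates near +/-1.
    let log_det = (1.0 - action * action).max(1e-6).ln();
    (action, log_prob_gauss - log_det)
}

fn main() {
    let (a, lp) = sample_tanh_gaussian(0.3, -1.0, 0.5);
    assert!(a > -1.0 && a < 1.0); // squashed into the action range
    assert!(lp.is_finite());
    println!("action = {a:.4}, log_prob = {lp:.4}");
}
```

With eps drawn from a standard normal per sample, this is the usual reparameterized tanh-Gaussian used by SAC actors; here eps is passed in explicitly so the sketch stays deterministic.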
fn deterministic_action(&self, obs: &TensorData) -> Result<TensorData, NNError>

Deterministic action (the distribution mean passed through the squashing function).