Trait ContinuousQFunction 

pub trait ContinuousQFunction {
    // Required methods
    fn q_value(
        &self,
        obs: &TensorData,
        actions: &TensorData,
    ) -> Result<TensorData, NNError>;
    fn twin_q_values(
        &self,
        obs: &TensorData,
        actions: &TensorData,
    ) -> Result<(TensorData, TensorData), NNError>;
    fn target_twin_q_values(
        &self,
        obs: &TensorData,
        actions: &TensorData,
    ) -> Result<(TensorData, TensorData), NNError>;
    fn critic_step(
        &mut self,
        obs: &TensorData,
        actions: &TensorData,
        targets: &TensorData,
    ) -> Result<TrainMetrics, NNError>;
    fn soft_update_targets(&mut self, tau: f32);
}

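Continuous-action Q-function (critic) for SAC and TD3. Unlike a discrete-action Q-network, it takes both the observation and the action as input and produces one scalar Q-value per sample.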
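The sketch below shows one plausible way a TD3-style training loop could drive this trait: query the target twin critics, build clipped double-Q targets, call critic_step, then call soft_update_targets. The TensorData, NNError and TrainMetrics stand-ins, the critic_update helper and the ToyCritic implementor are all hypothetical. TensorData is simplified to a flat Vec<f32> holding one value per sample, and next_actions would normally come from the target actor.

```rust
// Hypothetical stand-ins: the real TensorData, NNError and TrainMetrics are not shown here.
#[derive(Clone, Debug, PartialEq)]
pub struct TensorData(pub Vec<f32>); // simplified: one value per sample

#[derive(Debug)]
pub struct NNError(pub String);

#[derive(Debug, Default)]
pub struct TrainMetrics {
    pub loss: f32,
}

pub trait ContinuousQFunction {
    fn q_value(&self, obs: &TensorData, actions: &TensorData) -> Result<TensorData, NNError>;
    fn twin_q_values(&self, obs: &TensorData, actions: &TensorData)
        -> Result<(TensorData, TensorData), NNError>;
    fn target_twin_q_values(&self, obs: &TensorData, actions: &TensorData)
        -> Result<(TensorData, TensorData), NNError>;
    fn critic_step(&mut self, obs: &TensorData, actions: &TensorData, targets: &TensorData)
        -> Result<TrainMetrics, NNError>;
    fn soft_update_targets(&mut self, tau: f32);
}

/// Hypothetical helper: one TD3-style critic update driven only through the trait.
#[allow(clippy::too_many_arguments)]
pub fn critic_update<Q: ContinuousQFunction>(
    critic: &mut Q,
    obs: &TensorData,
    actions: &TensorData,
    rewards: &[f32],
    dones: &[f32],
    next_obs: &TensorData,
    next_actions: &TensorData,
    gamma: f32,
    tau: f32,
) -> Result<TrainMetrics, NNError> {
    // 1. Evaluate both target critics at the next state.
    let (tq1, tq2) = critic.target_twin_q_values(next_obs, next_actions)?;
    // 2. Clipped double-Q target: y = r + gamma * (1 - done) * min(q1', q2').
    let targets: Vec<f32> = rewards
        .iter()
        .zip(dones)
        .zip(tq1.0.iter().zip(&tq2.0))
        .map(|((&r, &d), (&a, &b))| r + gamma * (1.0 - d) * a.min(b))
        .collect();
    // 3. One gradient step on both online critics, then 4. Polyak-average the targets.
    let metrics = critic.critic_step(obs, actions, &TensorData(targets))?;
    critic.soft_update_targets(tau);
    Ok(metrics)
}

/// Hypothetical toy implementor: each head outputs one constant Q for every sample.
pub struct ToyCritic {
    pub q: (f32, f32),
    pub target_q: (f32, f32),
    pub last_targets: Vec<f32>,
}

fn filled(n: usize, v: f32) -> TensorData {
    TensorData(vec![v; n])
}

impl ContinuousQFunction for ToyCritic {
    fn q_value(&self, obs: &TensorData, _a: &TensorData) -> Result<TensorData, NNError> {
        Ok(filled(obs.0.len(), self.q.0))
    }
    fn twin_q_values(&self, obs: &TensorData, _a: &TensorData)
        -> Result<(TensorData, TensorData), NNError> {
        let n = obs.0.len();
        Ok((filled(n, self.q.0), filled(n, self.q.1)))
    }
    fn target_twin_q_values(&self, obs: &TensorData, _a: &TensorData)
        -> Result<(TensorData, TensorData), NNError> {
        let n = obs.0.len();
        Ok((filled(n, self.target_q.0), filled(n, self.target_q.1)))
    }
    fn critic_step(&mut self, _o: &TensorData, _a: &TensorData, targets: &TensorData)
        -> Result<TrainMetrics, NNError> {
        // No real training: record the targets and report Q1's mean-squared TD error.
        self.last_targets = targets.0.clone();
        let n = targets.0.len().max(1) as f32;
        let loss = targets.0.iter().map(|&y| (self.q.0 - y).powi(2)).sum::<f32>() / n;
        Ok(TrainMetrics { loss })
    }
    fn soft_update_targets(&mut self, tau: f32) {
        self.target_q.0 = tau * self.q.0 + (1.0 - tau) * self.target_q.0;
        self.target_q.1 = tau * self.q.1 + (1.0 - tau) * self.target_q.1;
    }
}
```

The ordering (targets from the target networks, then the critic step, then the soft update) is the usual TD3/SAC sequence; whether a given implementor of this trait expects exactly this call order is an assumption.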

Required Methods

fn q_value(&self, obs: &TensorData, actions: &TensorData) -> Result<TensorData, NNError>

Computes Q(obs, action) for each (observation, action) pair in the batch. Returns a tensor of shape [batch_size].
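The shape contract of q_value can be illustrated with a toy critic that concatenates each observation row with its action row and maps the result to a single scalar. The linear_q function, its row-major flat-slice layout and the linear model itself are hypothetical; a real critic would be a neural network operating on TensorData.

```rust
/// Hypothetical linear critic: Q(s, a) = w · [s; a] + b, evaluated row by row.
/// `obs` is row-major [batch, obs_dim]; `actions` is row-major [batch, act_dim].
/// Returns one Q-value per row, i.e. a vector of length batch_size.
pub fn linear_q(
    obs: &[f32],
    obs_dim: usize,
    actions: &[f32],
    act_dim: usize,
    weights: &[f32],
    bias: f32,
) -> Vec<f32> {
    assert!(obs_dim > 0 && obs.len() % obs_dim == 0, "obs must be [batch, obs_dim]");
    assert_eq!(weights.len(), obs_dim + act_dim, "one weight per input feature");
    let batch = obs.len() / obs_dim;
    assert_eq!(actions.len(), batch * act_dim, "obs and actions must share batch_size");
    (0..batch)
        .map(|i| {
            let s = &obs[i * obs_dim..(i + 1) * obs_dim];
            let a = &actions[i * act_dim..(i + 1) * act_dim];
            // Dot product of the concatenated (obs, action) row with the weights.
            s.iter().chain(a).zip(weights).map(|(x, w)| x * w).sum::<f32>() + bias
        })
        .collect()
}
```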

fn twin_q_values(&self, obs: &TensorData, actions: &TensorData) -> Result<(TensorData, TensorData), NNError>

Computes the twin Q-values Q1(obs, action) and Q2(obs, action) from the two online critics. Returns (q1, q2), each of shape [batch_size].
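A common consumer of twin online Q-values is the SAC actor update, which takes the elementwise minimum of the two heads (clipped double-Q) and minimizes alpha · log π(a|s) − min(Q1, Q2) averaged over the batch; TD3 instead uses Q1 alone for its actor. The sac_actor_loss helper below is a hypothetical sketch over plain f32 slices, with log-probabilities assumed to come from a policy not documented on this page; whether this crate's SAC follows exactly this form is an assumption.

```rust
/// Hypothetical helper: SAC policy loss from twin online Q-values,
/// mean over the batch of (alpha * log_pi(a|s) - min(q1, q2)).
pub fn sac_actor_loss(q1: &[f32], q2: &[f32], log_probs: &[f32], alpha: f32) -> f32 {
    assert!(q1.len() == q2.len() && q1.len() == log_probs.len(), "batch sizes must match");
    if q1.is_empty() {
        return 0.0;
    }
    let total: f32 = q1
        .iter()
        .zip(q2)
        .zip(log_probs)
        .map(|((&a, &b), &lp)| alpha * lp - a.min(b)) // clipped double-Q
        .sum();
    total / q1.len() as f32
}
```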

fn target_twin_q_values(&self, obs: &TensorData, actions: &TensorData) -> Result<(TensorData, TensorData), NNError>

Computes the twin Q-values for (obs, action) using the target networks instead of the online critics. Returns (q1, q2).
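Target twin Q-values are typically turned into the regression targets later passed to critic_step. The hypothetical td_targets helper below computes the SAC soft target y = r + γ(1 − done)(min(Q1', Q2') − α log π(a'|s')), which reduces to the TD3 clipped double-Q target when alpha = 0. The slice-based inputs and parameter names are assumptions, not this crate's API.

```rust
/// Hypothetical helper: soft TD targets from the target twin Q-values,
/// y = r + gamma * (1 - done) * (min(q1', q2') - alpha * log_pi(a'|s')).
/// With alpha = 0.0 this reduces to the TD3 clipped double-Q target.
pub fn td_targets(
    rewards: &[f32],
    dones: &[f32],
    target_q1: &[f32],
    target_q2: &[f32],
    next_log_probs: &[f32],
    gamma: f32,
    alpha: f32,
) -> Vec<f32> {
    let n = rewards.len();
    assert!(
        [dones.len(), target_q1.len(), target_q2.len(), next_log_probs.len()]
            .iter()
            .all(|&len| len == n),
        "all inputs must have length batch_size"
    );
    (0..n)
        .map(|i| {
            let soft_q = target_q1[i].min(target_q2[i]) - alpha * next_log_probs[i];
            // (1 - done) stops bootstrapping past terminal transitions.
            rewards[i] + gamma * (1.0 - dones[i]) * soft_q
        })
        .collect()
}
```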

fn critic_step(&mut self, obs: &TensorData, actions: &TensorData, targets: &TensorData) -> Result<TrainMetrics, NNError>

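Performs one temporal-difference (TD) gradient step on both critics, regressing their Q-values for (obs, actions) toward targets, and returns the resulting training metrics.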
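critic_step does not document its loss. TD3 and SAC both regress each critic onto the same target y with a squared error, so a reasonable mental model, assumed here rather than confirmed, is that the step minimizes the sum of the two mean-squared TD errors. The hypothetical twin_td_loss helper below only computes that value on plain slices; a real implementation would also backpropagate it through both critics and apply an optimizer, and might use a Huber loss instead.

```rust
/// Hypothetical helper: combined critic loss MSE(q1, y) + MSE(q2, y).
pub fn twin_td_loss(q1: &[f32], q2: &[f32], targets: &[f32]) -> f32 {
    assert!(q1.len() == targets.len() && q2.len() == targets.len(), "batch sizes must match");
    if targets.is_empty() {
        return 0.0;
    }
    // Mean-squared TD error of one critic head against the shared targets.
    let mse = |q: &[f32]| -> f32 {
        q.iter().zip(targets).map(|(&qi, &yi)| (qi - yi).powi(2)).sum::<f32>()
            / targets.len() as f32
    };
    mse(q1) + mse(q2)
}
```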

fn soft_update_targets(&mut self, tau: f32)

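Performs a Polyak (soft) update of the target networks: each target parameter becomes tau · online + (1 − tau) · target, so tau = 1.0 copies the online critics and small tau makes the targets track them slowly.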
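On flattened parameter vectors the Polyak update is a per-element interpolation. The polyak_update function below is a hypothetical sketch assuming the target and online critics expose their parameters as equal-length f32 slices; small values such as tau = 0.005 are commonly used for TD3 and SAC.

```rust
/// Hypothetical helper: in-place Polyak averaging of flattened parameters,
/// target <- tau * online + (1 - tau) * target.
pub fn polyak_update(target: &mut [f32], online: &[f32], tau: f32) {
    assert_eq!(target.len(), online.len(), "parameter shapes must match");
    assert!((0.0..=1.0).contains(&tau), "tau must lie in [0, 1]");
    for (t, &o) in target.iter_mut().zip(online) {
        *t = tau * o + (1.0 - tau) * *t;
    }
}
```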

Implementors