pub trait ContinuousQFunction {
// Required methods
fn q_value(
&self,
obs: &TensorData,
actions: &TensorData,
) -> Result<TensorData, NNError>;
fn twin_q_values(
&self,
obs: &TensorData,
actions: &TensorData,
) -> Result<(TensorData, TensorData), NNError>;
fn target_twin_q_values(
&self,
obs: &TensorData,
actions: &TensorData,
) -> Result<(TensorData, TensorData), NNError>;
fn critic_step(
&mut self,
obs: &TensorData,
actions: &TensorData,
targets: &TensorData,
) -> Result<TrainMetrics, NNError>;
fn soft_update_targets(&mut self, tau: f32);
}Expand description
Continuous Q-function (critic) for SAC/TD3; it takes both an observation and an action as input.
Required Methods
fn q_value(&self, obs: &TensorData, actions: &TensorData) -> Result<TensorData, NNError>
Computes the Q-value for each (obs, action) pair. Returns a tensor of shape [batch_size].
fn twin_q_values(&self, obs: &TensorData, actions: &TensorData) -> Result<(TensorData, TensorData), NNError>
Computes the twin Q-values for each (obs, action) pair. Returns (q1, q2), each of shape [batch_size].
fn target_twin_q_values(&self, obs: &TensorData, actions: &TensorData) -> Result<(TensorData, TensorData), NNError>
Computes the twin Q-values for each (obs, action) pair using the target networks.
fn critic_step(&mut self, obs: &TensorData, actions: &TensorData, targets: &TensorData) -> Result<TrainMetrics, NNError>
Performs one temporal-difference (TD) gradient step on both critics using the supplied targets.
fn soft_update_targets(&mut self, tau: f32)
Performs a Polyak (soft) update of the target networks, controlled by the interpolation factor tau.