pub trait QFunction {
    // Required methods
    fn q_values(&self, obs: &TensorData) -> Result<TensorData, NNError>;
    fn q_value_at(
        &self,
        obs: &TensorData,
        actions: &TensorData,
    ) -> Result<TensorData, NNError>;
    fn td_step(
        &mut self,
        obs: &TensorData,
        actions: &TensorData,
        targets: &TensorData,
        weights: Option<&TensorData>,
    ) -> Result<(f64, TensorData), NNError>;
    fn target_q_values(&self, obs: &TensorData) -> Result<TensorData, NNError>;
    fn hard_update_target(&mut self);
}
A Q-value network for off-policy algorithms such as DQN.
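A sketch of how these methods fit together in one DQN-style update. The helper name, the step/sync_every bookkeeping, and the precomputed targets tensor are illustrative assumptions, not part of this crate:

    // Illustrative only: one TD gradient step plus periodic target sync.
    // `targets` is assumed to be precomputed elsewhere as
    // r + gamma * (1 - done) * max_a' Q_target(s', a'), using the output of
    // `target_q_values`; building it needs tensor ops this trait does not expose.
    fn dqn_train_step<Q: QFunction>(
        q: &mut Q,
        step: usize,
        sync_every: usize,
        obs: &TensorData,
        actions: &TensorData,
        targets: &TensorData,
        is_weights: Option<&TensorData>, // per-sample weights, e.g. from PER
    ) -> Result<f64, NNError> {
        // One TD gradient step; the returned TD errors could feed PER priority updates.
        let (loss, _td_errors) = q.td_step(obs, actions, targets, is_weights)?;
        // Hard-copy online parameters into the target network every `sync_every` steps.
        if step > 0 && step % sync_every == 0 {
            q.hard_update_target();
        }
        Ok(loss)
    }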
Required Methods
fn q_values(&self, obs: &TensorData) -> Result<TensorData, NNError>

Compute Q-values for all actions. Returns [batch_size, n_actions].
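For example, greedy action selection reads the row-wise argmax of this matrix. The argmax_rows helper below is hypothetical (not provided by this crate), standing in for whatever reduction your tensor type offers:

    // Sketch: pick the greedy action for each batch element.
    fn greedy_actions<Q: QFunction>(q: &Q, obs: &TensorData) -> Result<TensorData, NNError> {
        let q_vals = q.q_values(obs)?;   // [batch_size, n_actions]
        Ok(argmax_rows(&q_vals))         // hypothetical helper -> [batch_size]
    }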
fn q_value_at(
    &self,
    obs: &TensorData,
    actions: &TensorData,
) -> Result<TensorData, NNError>

Compute the Q-value for each (obs, action) pair. Returns [batch_size].
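Conceptually (assuming integer action indices), this gathers one entry per row of the q_values matrix:

    // For each batch element i:
    //   q_value_at(obs, actions)[i] == q_values(obs)[i, actions[i]]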
fn td_step(
    &mut self,
    obs: &TensorData,
    actions: &TensorData,
    targets: &TensorData,
    weights: Option<&TensorData>,
) -> Result<(f64, TensorData), NNError>

Perform one DQN temporal-difference (TD) gradient step. Returns (loss, td_errors), where td_errors can be used to update priorities for prioritized experience replay (PER). The optional weights scale each sample's loss (for example, PER importance-sampling corrections).
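A sketch of feeding the returned TD errors back into a prioritized replay buffer. The batch fields and buffer.update_priorities are hypothetical stand-ins for whatever replay structure you use; only td_step itself comes from this trait:

    // Hypothetical PER wiring around a sampled batch.
    let (loss, td_errors) =
        q.td_step(&batch.obs, &batch.actions, &batch.targets, Some(&batch.is_weights))?;
    // New priorities are typically |td_error| plus a small epsilon.
    buffer.update_priorities(&batch.indices, &td_errors);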
fn target_q_values(&self, obs: &TensorData) -> Result<TensorData, NNError>

Compute target Q-values using the target network.
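These feed the standard DQN bootstrap target (textbook DQN, not crate-specific):

    // gamma = discount factor, d_i = done flag for transition i:
    //   y_i = r_i + gamma * (1 - d_i) * max_a' Q_target(s'_i, a')
    // where Q_target(s'_i, ·) is a row of target_q_values(next_obs).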
fn hard_update_target(&mut self)

Hard-copy parameters to the target network.
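In a typical DQN setup this is called every fixed number of gradient steps, as in the sketch after the trait description above.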