Trait QFunction 

pub trait QFunction {
    // Required methods
    fn q_values(&self, obs: &TensorData) -> Result<TensorData, NNError>;
    fn q_value_at(
        &self,
        obs: &TensorData,
        actions: &TensorData,
    ) -> Result<TensorData, NNError>;
    fn td_step(
        &mut self,
        obs: &TensorData,
        actions: &TensorData,
        targets: &TensorData,
        weights: Option<&TensorData>,
    ) -> Result<(f64, TensorData), NNError>;
    fn target_q_values(&self, obs: &TensorData) -> Result<TensorData, NNError>;
    fn hard_update_target(&mut self);
}

A Q-value network with a built-in target network, for off-policy algorithms such as DQN.

Required Methods§


fn q_values(&self, obs: &TensorData) -> Result<TensorData, NNError>

Compute Q-values for all actions. Returns [batch_size, n_actions].
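As a usage sketch: given the `[batch_size, n_actions]` output, greedy action selection reduces to a per-row argmax. The flat row-major layout is an assumption about `TensorData`, whose internals this page does not show.

```rust
// Hypothetical helper: greedy action selection over a flat row-major
// [batch_size, n_actions] buffer, as `q_values` might return.
// The layout is an assumption for this sketch.
fn greedy_actions(q: &[f64], n_actions: usize) -> Vec<usize> {
    q.chunks(n_actions)
        .map(|row| {
            row.iter()
                .enumerate()
                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
                .map(|(i, _)| i)
                .unwrap()
        })
        .collect()
}

fn main() {
    // Two observations, three actions each.
    let q = [0.1, 0.9, 0.3, 0.5, 0.2, 0.4];
    let actions = greedy_actions(&q, 3);
    println!("{:?}", actions); // [1, 0]
}
```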


fn q_value_at(&self, obs: &TensorData, actions: &TensorData) -> Result<TensorData, NNError>

Compute Q-value for (obs, action) pairs. Returns [batch_size].


fn td_step(&mut self, obs: &TensorData, actions: &TensorData, targets: &TensorData, weights: Option<&TensorData>) -> Result<(f64, TensorData), NNError>

Perform one DQN TD gradient step. Returns (loss, td_errors), where td_errors can be used as priorities for prioritized experience replay (PER).
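The returned `td_errors` can seed replay priorities. A common PER scheme is p_i = (|δ_i| + ε)^α; the constants below are conventional, not mandated by this trait.

```rust
// Sketch: turning TD errors (as `td_step` returns them) into PER
// priorities p_i = (|delta_i| + eps)^alpha. `alpha` and `eps` are
// conventional PER hyperparameters, not part of this trait.
fn priorities(td_errors: &[f64], alpha: f64, eps: f64) -> Vec<f64> {
    td_errors.iter().map(|d| (d.abs() + eps).powf(alpha)).collect()
}

fn main() {
    // Larger |TD error| -> higher priority; eps keeps zero-error
    // transitions sampleable.
    let p = priorities(&[0.5, -2.0, 0.0], 1.0, 1e-6);
    println!("{:?}", p);
}
```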


fn target_q_values(&self, obs: &TensorData) -> Result<TensorData, NNError>

Compute target Q-values using the target network.
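A typical use is building the TD targets passed to `td_step`: y = r + γ·max_a Q_target(s′, a), with the bootstrap term dropped on terminal transitions. A sketch, again assuming a flat row-major `[batch_size, n_actions]` buffer for the target-network output:

```rust
// Sketch: DQN TD targets from target-network Q-values. `next_q` is
// assumed flat row-major [batch_size, n_actions], matching the shape
// documented for `q_values`.
fn dqn_targets(rewards: &[f64], dones: &[bool], next_q: &[f64],
               n_actions: usize, gamma: f64) -> Vec<f64> {
    next_q.chunks(n_actions)
        .zip(rewards.iter().zip(dones))
        .map(|(row, (&r, &done))| {
            let max_q = row.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
            // Terminal transitions contribute the reward only.
            if done { r } else { r + gamma * max_q }
        })
        .collect()
}

fn main() {
    let targets = dqn_targets(&[1.0, 0.0], &[false, true],
                              &[0.5, 2.0, 1.0, 3.0], 2, 0.99);
    // First entry ≈ 1.0 + 0.99 * 2.0; second is 0.0 (terminal).
    println!("{:?}", targets);
}
```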


fn hard_update_target(&mut self)

Copy the online network's parameters verbatim into the target network (a hard update, as opposed to Polyak averaging).

Implementors§
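For orientation, a minimal tabular implementor might look like the following. `TensorData` and `NNError` are toy stand-ins invented for this sketch (the crate's real types are not shown on this page), and the tabular update replaces a real gradient step.

```rust
// Self-contained sketch of a `QFunction` implementor. `TensorData` and
// `NNError` are toy stand-ins; the crate's real types differ.
#[derive(Debug, Clone)]
struct TensorData(Vec<f64>);

#[derive(Debug)]
struct NNError;

trait QFunction {
    fn q_values(&self, obs: &TensorData) -> Result<TensorData, NNError>;
    fn q_value_at(&self, obs: &TensorData, actions: &TensorData) -> Result<TensorData, NNError>;
    fn td_step(&mut self, obs: &TensorData, actions: &TensorData, targets: &TensorData,
               weights: Option<&TensorData>) -> Result<(f64, TensorData), NNError>;
    fn target_q_values(&self, obs: &TensorData) -> Result<TensorData, NNError>;
    fn hard_update_target(&mut self);
}

// Tabular Q-function: one row of Q-values per discrete state, with a
// separate target table in place of a target network.
struct TabularQ {
    online: Vec<Vec<f64>>, // [n_states][n_actions]
    target: Vec<Vec<f64>>,
    lr: f64,
}

impl QFunction for TabularQ {
    fn q_values(&self, obs: &TensorData) -> Result<TensorData, NNError> {
        // obs holds state indices; concatenate rows -> [batch, n_actions].
        let mut out = Vec::new();
        for &s in &obs.0 {
            out.extend_from_slice(&self.online[s as usize]);
        }
        Ok(TensorData(out))
    }

    fn q_value_at(&self, obs: &TensorData, actions: &TensorData) -> Result<TensorData, NNError> {
        let vals = obs.0.iter().zip(&actions.0)
            .map(|(&s, &a)| self.online[s as usize][a as usize])
            .collect();
        Ok(TensorData(vals))
    }

    fn td_step(&mut self, obs: &TensorData, actions: &TensorData, targets: &TensorData,
               weights: Option<&TensorData>) -> Result<(f64, TensorData), NNError> {
        // Weighted MSE on the TD error; a tabular stand-in for the
        // gradient step a network-backed implementor would take.
        let mut td_errors = Vec::new();
        let mut loss = 0.0;
        for (i, ((&s, &a), &y)) in obs.0.iter().zip(&actions.0).zip(&targets.0).enumerate() {
            let w = weights.map_or(1.0, |w| w.0[i]);
            let q = self.online[s as usize][a as usize];
            let delta = y - q;
            self.online[s as usize][a as usize] += self.lr * w * delta;
            loss += w * delta * delta;
            td_errors.push(delta);
        }
        loss /= obs.0.len() as f64;
        Ok((loss, TensorData(td_errors)))
    }

    fn target_q_values(&self, obs: &TensorData) -> Result<TensorData, NNError> {
        let mut out = Vec::new();
        for &s in &obs.0 {
            out.extend_from_slice(&self.target[s as usize]);
        }
        Ok(TensorData(out))
    }

    fn hard_update_target(&mut self) {
        self.target = self.online.clone();
    }
}

fn main() {
    let mut q = TabularQ {
        online: vec![vec![0.0; 2]; 2],
        target: vec![vec![0.0; 2]; 2],
        lr: 0.5,
    };
    let (loss, td) = q.td_step(&TensorData(vec![0.0]), &TensorData(vec![1.0]),
                               &TensorData(vec![1.0]), None).unwrap();
    println!("loss={} td={:?}", loss, td.0);
    q.hard_update_target();
    println!("{:?}", q.target_q_values(&TensorData(vec![0.0])).unwrap().0);
}
```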