compute_vtrace

Function compute_vtrace 

pub fn compute_vtrace(
    log_rhos: &[f32],
    rewards: &[f32],
    values: &[f32],
    dones: &[f32],
    bootstrap_value: f32,
    gamma: f32,
    rho_bar: f32,
    c_bar: f32,
) -> Result<(Vec<f32>, Vec<f32>), RloxError>

Compute V-trace targets and policy gradient advantages (Espeholt et al. 2018).

Processes backwards from t=n-1 to t=0:

    rho_t = min(rho_bar, exp(log_rhos[t]))
    c_t = min(c_bar, exp(log_rhos[t]))
    non_terminal = 1.0 - dones[t]
    delta_t = rho_t * (rewards[t] + gamma * non_terminal * values[t+1] - values[t])
    vs[t] = values[t] + delta_t + gamma * non_terminal * c_t * (vs[t+1] - values[t+1])
    pg_advantages[t] = rho_t * (rewards[t] + gamma * non_terminal * vs[t+1] - values[t])

At t=n-1, bootstrap_value stands in for both values[n] and vs[n]. Its contribution is zeroed when the last step is terminal (dones[n-1] == 1.0), because non_terminal is then 0.

Returns (vs, pg_advantages).
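
The backward recursion described above can be sketched as a stand-alone function. This is a minimal illustration under stated assumptions, not the crate's actual implementation: the name vtrace_sketch is hypothetical, a plain String replaces RloxError, and the only validation assumed is that all four slices have the same length.

```rust
/// Hypothetical sketch of the V-trace backward pass; not the library's code.
/// Errors are plain strings here, standing in for the crate's RloxError.
fn vtrace_sketch(
    log_rhos: &[f32],
    rewards: &[f32],
    values: &[f32],
    dones: &[f32],
    bootstrap_value: f32,
    gamma: f32,
    rho_bar: f32,
    c_bar: f32,
) -> Result<(Vec<f32>, Vec<f32>), String> {
    let n = log_rhos.len();
    // Assumed validation: every per-step slice must cover the same n steps.
    if rewards.len() != n || values.len() != n || dones.len() != n {
        return Err("all input slices must have the same length".to_string());
    }

    let mut vs = vec![0.0f32; n];
    let mut pg_advantages = vec![0.0f32; n];

    // values[t+1] and vs[t+1]; both start at bootstrap_value for t = n-1.
    let mut next_value = bootstrap_value;
    let mut next_vs = bootstrap_value;

    for t in (0..n).rev() {
        // Truncated importance weights.
        let ratio = log_rhos[t].exp();
        let rho = ratio.min(rho_bar);
        let c = ratio.min(c_bar);

        // 0.0 on terminal steps, which cuts off everything from t+1 onward.
        let non_terminal = 1.0 - dones[t];

        let delta = rho * (rewards[t] + gamma * non_terminal * next_value - values[t]);
        vs[t] = values[t] + delta + gamma * non_terminal * c * (next_vs - next_value);
        pg_advantages[t] = rho * (rewards[t] + gamma * non_terminal * next_vs - values[t]);

        // Shift the "t+1" quantities one step back for the next iteration.
        next_value = values[t];
        next_vs = vs[t];
    }

    Ok((vs, pg_advantages))
}
```

As a hand-checkable case, take two on-policy steps (log_rhos = [0, 0], rho_bar = c_bar = 1) with rewards [1, 1], values [0, 0], dones [0, 1], gamma = 0.5 and bootstrap_value = 10. The terminal flag on the last step discards the bootstrap value, so the sketch yields vs = [1.5, 1.0] and pg_advantages = [1.5, 1.0].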