pub fn compute_vtrace(
log_rhos: &[f32],
rewards: &[f32],
values: &[f32],
dones: &[f32],
bootstrap_value: f32,
gamma: f32,
rho_bar: f32,
c_bar: f32,
) -> Result<(Vec<f32>, Vec<f32>), RloxError>
Compute V-trace targets and policy gradient advantages (Espeholt et al. 2018).
Processes backwards from t=n-1 to t=0:
    rho_t = min(rho_bar, exp(log_rhos[t]))
    c_t = min(c_bar, exp(log_rhos[t]))
    non_terminal = 1.0 - dones[t]
    delta_t = rho_t * (rewards[t] + gamma * non_terminal * values[t+1] - values[t])
    vs[t] = values[t] + delta_t + gamma * non_terminal * c_t * (vs[t+1] - values[t+1])
    pg_advantages[t] = rho_t * (rewards[t] + gamma * non_terminal * vs[t+1] - values[t])
Uses bootstrap_value for values[n] and vs[n]; its contribution is zeroed
when the last step is terminal (dones[n-1] == 1.0).
Returns (vs, pg_advantages).
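To make the recursion concrete, here is a minimal standalone sketch of the backward pass described above. The name `vtrace_sketch` is hypothetical and this is not the crate's implementation: it returns the tuple directly rather than a `Result`, and it assumes all four slices have the same length, which it checks with an `assert!` instead of an `RloxError`.

```rust
/// Hypothetical reference sketch of the V-trace backward recursion.
/// Assumption: `log_rhos`, `rewards`, `values`, and `dones` share one length.
fn vtrace_sketch(
    log_rhos: &[f32],
    rewards: &[f32],
    values: &[f32],
    dones: &[f32],
    bootstrap_value: f32,
    gamma: f32,
    rho_bar: f32,
    c_bar: f32,
) -> (Vec<f32>, Vec<f32>) {
    let n = log_rhos.len();
    assert!(rewards.len() == n && values.len() == n && dones.len() == n);

    let mut vs = vec![0.0_f32; n];
    let mut pg_advantages = vec![0.0_f32; n];

    // values[n] and vs[n] are both taken to be the bootstrap value.
    let mut next_value = bootstrap_value;
    let mut next_vs = bootstrap_value;

    for t in (0..n).rev() {
        // Truncated importance weights.
        let ratio = log_rhos[t].exp();
        let rho = ratio.min(rho_bar);
        let c = ratio.min(c_bar);

        // A terminal step masks everything that comes from t+1.
        let non_terminal = 1.0 - dones[t];

        let delta = rho * (rewards[t] + gamma * non_terminal * next_value - values[t]);
        vs[t] = values[t] + delta + gamma * non_terminal * c * (next_vs - next_value);
        pg_advantages[t] = rho * (rewards[t] + gamma * non_terminal * next_vs - values[t]);

        // Shift the "t+1" quantities for the next (earlier) step.
        next_value = values[t];
        next_vs = vs[t];
    }

    (vs, pg_advantages)
}
```

In the on-policy case (all `log_rhos` equal to 0.0, `rho_bar = c_bar = 1.0`) with `gamma = 1.0`, zero value estimates, rewards `[1.0, 1.0]`, and a terminal last step, the sketch ignores the bootstrap value and `vs` reduces to the undiscounted returns `[2.0, 1.0]`.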