
Function compute_token_kl 

pub fn compute_token_kl(
    log_probs_policy: &[f64],
    log_probs_ref: &[f64],
) -> Result<f64, RloxError>

Token-level KL divergence KL(p ‖ q), where `log_p` is `log_probs_policy` and `log_q` is `log_probs_ref`: sum(exp(log_p) * (log_p - log_q)).
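A minimal sketch of the documented formula follows. It is not the crate's implementation. `KlError` and `token_kl_sketch` are hypothetical stand-ins for `RloxError` and `compute_token_kl`, and rejecting slices of unequal length is an assumed error condition that this page does not state.

```rust
/// Hypothetical error type standing in for `RloxError`; the real variants are not documented here.
#[derive(Debug, PartialEq)]
pub enum KlError {
    /// Assumed failure mode: the policy and reference slices differ in length.
    LengthMismatch { policy: usize, reference: usize },
}

/// Sketch of KL(p || q) = sum_i exp(log_p_i) * (log_p_i - log_q_i),
/// with log_p taken from the policy and log_q from the reference.
pub fn token_kl_sketch(log_probs_policy: &[f64], log_probs_ref: &[f64]) -> Result<f64, KlError> {
    // Assumption: mismatched lengths are reported as an error rather than truncated by `zip`.
    if log_probs_policy.len() != log_probs_ref.len() {
        return Err(KlError::LengthMismatch {
            policy: log_probs_policy.len(),
            reference: log_probs_ref.len(),
        });
    }
    // Weight each log-ratio by the policy probability exp(log_p) and sum over tokens.
    let kl: f64 = log_probs_policy
        .iter()
        .zip(log_probs_ref)
        .map(|(&lp, &lq)| lp.exp() * (lp - lq))
        .sum();
    Ok(kl)
}
```

For example, with p = [0.5, 0.5] and q = [0.25, 0.75] passed as natural-log probabilities, the sum is 0.5·ln 2 + 0.5·ln(2/3) = 0.5·ln(4/3) ≈ 0.1438, and identical inputs give exactly 0.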