pub fn compute_token_kl( log_probs_policy: &[f32], log_probs_ref: &[f32], ) -> Result<f32, RloxError>
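Token-level KL divergence KL(p || q) = sum(exp(log_p) * (log_p - log_q)), where log_p and log_q are the corresponding entries of `log_probs_policy` and `log_probs_ref`, respectively.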
sum(exp(log_p) * (log_p - log_q))