pub fn categorical_log_prob(logits: &[f32], action: usize) -> f32
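Computes the log-probability of action under the categorical distribution parameterized by logits (unnormalized log-probabilities), i.e. log softmax(logits)[action].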
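
The source gives only the signature, so the body below is a minimal sketch of one way it could be implemented, not the actual implementation. It assumes logits holds unnormalized log-probabilities, that the slice is non-empty, and that an out-of-range action should panic. It uses the standard log-sum-exp shift (subtracting the maximum logit) so that large logits do not overflow.

```rust
/// Sketch: log-probability of `action` under softmax(logits).
/// Assumption: `logits` is non-empty and `action < logits.len()`; panics otherwise.
pub fn categorical_log_prob(logits: &[f32], action: usize) -> f32 {
    assert!(action < logits.len(), "action index out of range");

    // Subtract the maximum logit before exponentiating to avoid overflow.
    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);

    // Log of the sum of exponentials of the shifted logits.
    let log_sum_exp = logits.iter().map(|&l| (l - max).exp()).sum::<f32>().ln();

    // log softmax(logits)[action] = logits[action] - max - ln(sum(exp(logits - max)))
    logits[action] - max - log_sum_exp
}
```

For two equal logits, for example categorical_log_prob(&[0.0, 0.0], 0), the result is ln(1/2), and exponentiating the log-probabilities of all actions should sum to approximately 1.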