pub fn categorical_sample(logits: &[f32], uniform_rand: f32) -> usize
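Samples from the categorical distribution defined by `logits` (unnormalized log-probabilities) and returns the index of the sampled category. Uses the Gumbel-max trick. The argmax in Gumbel-max is not itself differentiable. The trick is differentiable-friendly only in the sense that replacing the argmax with a temperature-controlled softmax gives the Gumbel-softmax relaxation, which is what allows gradient-based training.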
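The Gumbel-max trick adds independent standard Gumbel noise g_i = -ln(-ln(u_i)) to each logit, with each u_i drawn uniformly from (0, 1), and returns the argmax. The resulting index is distributed as softmax(logits). As usually stated, the trick needs one uniform draw per category, so this sketch takes a slice of uniforms. `gumbel_max_sample` is a hypothetical helper, not the documented function, and this sketch does not show how `categorical_sample` derives per-category noise from its single `uniform_rand`. For comparison, the second hypothetical helper, `inverse_cdf_sample`, uses a different technique: inverse-CDF sampling. It draws from the same distribution using exactly one uniform, which matches the shape of the signature.

```rust
/// Hypothetical helper (not the documented API): Gumbel-max sampling with
/// one uniform draw per logit. Each `uniforms[i]` is assumed to lie in (0, 1).
pub fn gumbel_max_sample(logits: &[f32], uniforms: &[f32]) -> usize {
    assert!(!logits.is_empty(), "logits must be non-empty");
    assert_eq!(logits.len(), uniforms.len(), "need one uniform per logit");
    let mut best_idx = 0;
    let mut best_score = f32::NEG_INFINITY;
    for (i, (&logit, &u)) in logits.iter().zip(uniforms).enumerate() {
        // Keep u strictly inside (0, 1) so both logarithms stay finite.
        let u = u.clamp(f32::MIN_POSITIVE, 1.0 - f32::EPSILON);
        // Standard Gumbel noise via the inverse CDF: g = -ln(-ln(u)).
        let gumbel = -(-u.ln()).ln();
        let score = logit + gumbel;
        if score > best_score {
            best_score = score;
            best_idx = i;
        }
    }
    best_idx
}

/// Hypothetical helper (a different technique): inverse-CDF sampling that
/// needs a single uniform draw, assumed to lie in [0, 1).
pub fn inverse_cdf_sample(logits: &[f32], uniform_rand: f32) -> usize {
    assert!(!logits.is_empty(), "logits must be non-empty");
    // Subtract the largest logit before exponentiating to avoid overflow.
    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let weights: Vec<f32> = logits.iter().map(|&l| (l - max).exp()).collect();
    let total: f32 = weights.iter().sum();
    // Scale the uniform to the unnormalized total instead of normalizing.
    let target = uniform_rand * total;
    let mut cumulative = 0.0;
    for (i, &w) in weights.iter().enumerate() {
        cumulative += w;
        if target < cumulative {
            return i;
        }
    }
    // Floating-point rounding can leave target >= cumulative; use the last index.
    logits.len() - 1
}
```

Both helpers take caller-supplied random numbers rather than owning an RNG. This keeps them deterministic and easy to test, and leaves seeding and generating the uniforms to the caller.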