pub fn tanh_log_prob_correction(pre_tanh: f32) -> f32
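Tanh-squashing log-probability correction: `log_prob -= log(1 - tanh(pre_tanh)^2 + eps)`, where `eps` is a small constant that keeps the logarithm finite as `tanh(pre_tanh)^2` approaches 1.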