# Benchmark: LLM Post-Training Operations
Operations used in LLM alignment / post-training (GRPO, RLHF, DPO). No TorchRL or Stable-Baselines3 (SB3) equivalents exist, so the comparison is against NumPy and PyTorch baselines.
## What is Measured
### GRPO Group Advantages
Group-Relative Policy Optimization computes z-score normalization per completion group: (reward - mean) / std.
| Framework | Implementation |
|---|---|
| rlox | compute_group_advantages(rewards) — Rust, f64, Welford's online algorithm |
| NumPy | (rewards - rewards.mean()) / rewards.std() — vectorized |
| PyTorch | (g - g.mean()) / g.std() — tensor ops |
Tested with n_prompts × k_completions groups: (16×4), (64×8), (256×16).
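The NumPy baseline can be sketched as follows. This is an illustrative sketch, not the benchmark harness itself; it assumes rewards arrive as a 2-D array of shape (n_prompts, k_completions) and normalizes one group at a time, matching the per-group call pattern that is benchmarked:

```python
import numpy as np

def group_advantages_np(rewards: np.ndarray) -> np.ndarray:
    """Z-score each reward within its completion group (one row per prompt)."""
    out = np.empty_like(rewards, dtype=np.float64)
    for i, group in enumerate(rewards):
        # One mean/std/normalize call sequence per group, as in the baseline.
        out[i] = (group - group.mean()) / group.std()
    return out
```

Each iteration pays NumPy's per-call dispatch cost, which is what the analysis below attributes the rlox speedup to.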
### Token-Level KL Divergence
KL(policy ‖ reference) = sum(exp(log_p) * (log_p - log_q)), where log_p and log_q are per-token log-probabilities under the policy and reference model. Used in PPO-based RLHF to constrain policy drift from the reference model.
| Framework | Implementation |
|---|---|
| rlox | compute_token_kl(log_p, log_q) — Rust f64 loop |
| NumPy | np.sum(np.exp(log_p) * (log_p - log_q)) — vectorized |
| PyTorch | torch.sum(torch.exp(log_p) * (log_p - log_q)).item() — vectorized tensor ops |
Tested at sequence lengths: 128, 512, 2048, 8192.
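A minimal NumPy version of the formula, for reference (a sketch, not the benchmark harness):

```python
import numpy as np

def token_kl_np(log_p: np.ndarray, log_q: np.ndarray) -> float:
    """Forward KL(policy || reference) summed over tokens."""
    return float(np.sum(np.exp(log_p) * (log_p - log_q)))
```

For normalized log-probabilities this is non-negative, and exactly zero when the two distributions coincide.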
## Results
### GRPO Group Advantages
| Config | rlox (median) | NumPy (median) | PyTorch (median) |
|---|---|---|---|
| 16 × 4 | 2.2 us | 79.0 us | 76.2 us |
| 64 × 8 | 8.6 us | 309.9 us | 320.3 us |
| 256 × 16 | 36.1 us | 1,243.3 us | 1,261.6 us |
| Config | vs NumPy | 95% CI | vs PyTorch | 95% CI |
|---|---|---|---|---|
| 16 × 4 | 36.5x | [35.7, 37.2] | 35.2x | [35.0, 37.6] |
| 64 × 8 | 35.9x | [35.5, 36.5] | 37.1x | [36.6, 37.7] |
| 256 × 16 | 34.4x | [34.2, 34.6] | 34.9x | [34.6, 35.5] |
### Token-Level KL Divergence
| Seq Length | rlox (median) | NumPy (median) | PyTorch (median) |
|---|---|---|---|
| 128 | 0.4 us | 1.8 us | 2.0 us |
| 512 | 1.2 us | 2.8 us | 3.2 us |
| 2,048 | 3.9 us | 6.6 us | 5.8 us |
| 8,192 | 17.0 us | 26.8 us | 50.3 us |
| Seq Length | vs NumPy | 95% CI | vs PyTorch | 95% CI |
|---|---|---|---|---|
| 128 | 4.7x | [4.7, 4.7] | 5.4x | [5.4, 5.6] |
| 512 | 2.4x | [2.3, 2.4] | 2.7x | [2.6, 2.7] |
| 2,048 | 1.7x | [1.7, 1.9] | 1.5x | [1.5, 1.5] |
| 8,192 | 1.6x | [1.6, 1.6] | 3.0x | [2.8, 3.1] |
## Analysis
### GRPO: Why 35x faster
GRPO processes many small groups (4–16 elements each). For 256 prompts × 16 completions, that's 256 separate calls to compute mean/std/normalize.
Each Python call to np.mean() + np.std() on a 16-element array incurs:
- Function call overhead (~500ns)
- NumPy type dispatch
- Array creation for the result
rlox crosses the PyO3 boundary once per group and computes mean + std + normalization in a single Rust pass with no allocation. The per-call overhead drops from ~5us (Python) to ~140ns (Rust).
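Welford's algorithm accumulates the mean and variance in one sweep over the group, then normalizes. A pure-Python sketch of the idea (the real kernel is Rust; this is for illustration only):

```python
def group_advantages_welford(rewards):
    """Welford's online mean/variance, then normalize (sketch of the rlox kernel)."""
    mean, m2 = 0.0, 0.0
    for i, r in enumerate(rewards, start=1):
        delta = r - mean
        mean += delta / i          # running mean
        m2 += delta * (r - mean)   # running sum of squared deviations
    std = (m2 / len(rewards)) ** 0.5  # population std, matching np.std()
    return [(r - mean) / std for r in rewards]
```

No intermediate arrays are allocated, which is why the per-group cost stays in the hundreds of nanoseconds.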
NumPy and PyTorch perform nearly identically here — both are limited by Python function call overhead, not by the actual arithmetic.
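The per-call overhead can be observed directly with timeit (illustrative; absolute numbers depend on the machine):

```python
import timeit
import numpy as np

g = np.random.default_rng(0).standard_normal(16)

# At 16 elements the arithmetic is trivial; dispatch and result-array
# allocation dominate the per-call cost.
per_call = timeit.timeit(lambda: (g - g.mean()) / g.std(), number=10_000) / 10_000
print(f"{per_call * 1e6:.2f} us per 16-element normalization")
```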
### Token KL: Advantage narrows at large sequences
At seq_len=128, rlox is 4.7x faster than NumPy; at seq_len=8192, only 1.6x.
The reason: NumPy's vectorized exp() and sum() run in SIMD-optimized C kernels (on macOS, NumPy can link against the Accelerate framework). For large arrays, these kernels approach the same throughput as rlox's f64 loop. The rlox advantage at small sizes comes from avoiding NumPy's per-call dispatch overhead.
PyTorch's surprising regression at 8192 (3.0x slower than rlox, worse than NumPy) is likely due to PyTorch's larger per-op dispatch overhead compared to NumPy for simple element-wise operations.
Source: benchmarks/bench_llm_ops.py