LLM Post-Training with rlox
rlox provides Rust-accelerated primitives for LLM post-training (RLHF, DPO, GRPO) alongside pure-Python algorithm implementations. The Rust primitives handle the compute-heavy operations (KL divergence, GRPO advantages, sequence packing) while Python handles the model forward/backward passes.
Overview
| Algorithm | Class | Use Case |
|---|---|---|
| DPO | rlox.algorithms.dpo.DPO | Offline preference optimization with static dataset |
| GRPO | rlox.algorithms.grpo.GRPO | Group-relative policy optimization (DeepSeek-R1 style) |
| OnlineDPO | rlox.algorithms.online_dpo.OnlineDPO | DPO with online generation + preference oracle |
| BestOfN | rlox.algorithms.best_of_n.BestOfN | Inference-time rejection sampling |
Rust-Accelerated Primitives
These are the building blocks — use them directly for custom training loops:
import rlox
import numpy as np
# --- KL Divergence (2-27x faster than TRL/NumPy) ---
# Single sequence (f64)
kl = rlox.compute_token_kl(log_probs_policy, log_probs_ref)
# Schulman estimator (TRL-compatible)
kl = rlox.compute_token_kl_schulman(log_probs_policy, log_probs_ref)
# Batched (single FFI call for all sequences)
# log_probs are flat arrays of shape (batch * seq_len,)
kl_per_seq = rlox.compute_batch_token_kl(
    log_p_flat, log_q_flat, seq_len=512
)
# f32 variant (2x faster, matches PyTorch native precision)
kl_per_seq = rlox.compute_batch_token_kl_schulman_f32(
    log_p_flat.astype(np.float32),
    log_q_flat.astype(np.float32),
    seq_len=512,
)
# --- GRPO Group Advantages (27x faster than TRL at small groups) ---
# Single group
advantages = rlox.compute_group_advantages(rewards) # (G,) -> (G,)
# Batched: all prompts at once
# rewards shape: (n_prompts * group_size,)
advantages = rlox.compute_batch_group_advantages(rewards, group_size=8)
# --- Sequence Packing ---
# First-fit-decreasing bin packing for variable-length sequences
from rlox import pack_sequences
bins = pack_sequences(sequence_lengths, max_length=2048)
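For intuition, first-fit-decreasing packing can be sketched in pure Python. This is an illustrative reference, not rlox's implementation: the name `ffd_pack` and its return format (a list of bins, each holding sequence indices) are assumptions, and `pack_sequences` may return a different structure.

```python
def ffd_pack(lengths: list[int], max_length: int) -> list[list[int]]:
    """Pack sequence indices into bins whose total length is <= max_length."""
    # Visit sequences from longest to shortest (the "decreasing" part).
    order = sorted(range(len(lengths)), key=lambda i: lengths[i], reverse=True)
    bins: list[list[int]] = []  # sequence indices assigned to each bin
    free: list[int] = []        # remaining capacity of each bin
    for i in order:
        if lengths[i] > max_length:
            raise ValueError(f"sequence {i} is longer than max_length")
        # Place the sequence in the first bin with enough room ("first fit").
        for b, room in enumerate(free):
            if lengths[i] <= room:
                bins[b].append(i)
                free[b] -= lengths[i]
                break
        else:
            # No existing bin fits, so open a new one.
            bins.append([i])
            free.append(max_length - lengths[i])
    return bins


example_bins = ffd_pack([1200, 700, 900, 300, 1500], max_length=2048)
# example_bins == [[4, 3], [0, 1], [2]]
```

Sorting first means the longest sequences claim bins early and shorter ones fill the leftover space, which usually wastes less padding than packing in arrival order.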
DPO: Direct Preference Optimization
Train a language model to prefer chosen completions over rejected ones, using a reference model for KL regularization.
import torch
import torch.nn as nn
from rlox.algorithms.dpo import DPO
# Your language model and a frozen reference copy
model = MyLanguageModel(vocab_size=32000, hidden=256)
ref_model = MyLanguageModel(vocab_size=32000, hidden=256)
ref_model.load_state_dict(model.state_dict())
for p in ref_model.parameters():
    p.requires_grad = False
# Create DPO trainer
dpo = DPO(
    model=model,
    ref_model=ref_model,
    beta=0.1,  # KL penalty strength
    learning_rate=1e-4,
)
# Train on preference pairs
# prompt, chosen, rejected are (1, seq_len) token ID tensors
for prompt, chosen, rejected in dataset:
    metrics = dpo.train_step(prompt, chosen, rejected)
    print(f"Loss: {metrics['loss']:.4f}")
# Save checkpoint
dpo.save("dpo_checkpoint.pt")
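To make the role of `beta` concrete, here is a minimal sketch of the standard DPO loss for a single preference pair. It assumes each input is the summed log-probability of a full completion under the policy or the reference model. The function name `dpo_loss` is hypothetical, and the reduction rlox applies inside `train_step` is not shown in this guide.

```python
import math


def dpo_loss(
    policy_chosen_logp: float,
    policy_rejected_logp: float,
    ref_chosen_logp: float,
    ref_rejected_logp: float,
    beta: float = 0.1,
) -> float:
    """Standard DPO loss for one pair, from summed sequence log-probs."""
    # Implicit reward of each completion: log-ratio of policy to reference.
    chosen_logratio = policy_chosen_logp - ref_chosen_logp
    rejected_logratio = policy_rejected_logp - ref_rejected_logp
    margin = beta * (chosen_logratio - rejected_logratio)
    # -log(sigmoid(margin)), written as a numerically stable softplus(-margin).
    return max(-margin, 0.0) + math.log1p(math.exp(-abs(margin)))


# When the policy equals the reference, the margin is 0 and the loss is log(2).
loss_at_init = dpo_loss(-12.0, -15.0, -12.0, -15.0)
# The loss drops once the policy favors the chosen completion more than the reference does.
loss_after = dpo_loss(-10.0, -18.0, -12.0, -15.0)
```

A larger `beta` scales the margin, so the same shift in log-probabilities moves the loss further, and the policy needs a smaller departure from the reference to reach a given loss.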
Key Features
- Gradient clipping (max_grad_norm=1.0 by default)
- Callback support for monitoring
- Logger integration (W&B, TensorBoard, Console)
GRPO: Group Relative Policy Optimization
GRPO generates multiple completions per prompt, scores them with a reward function, and computes advantages relative to the group (DeepSeek-R1 approach).
from rlox.algorithms.grpo import GRPO
def reward_fn(completions: list[torch.Tensor]) -> list[float]:
    """Score completions. Higher is better."""
    return [score_completion(c) for c in completions]
grpo = GRPO(
    model=model,
    ref_model=ref_model,
    reward_fn=reward_fn,
    group_size=4,  # completions per prompt
    kl_coef=0.1,  # KL penalty
    learning_rate=1e-4,
    max_new_tokens=128,
)
# Train on prompts
for prompt_batch in dataloader:
    metrics = grpo.train_step(prompt_batch)
    print(
        f"Loss: {metrics['loss']:.4f}, "
        f"Mean reward: {metrics['mean_reward']:.2f}, "
        f"KL: {metrics['mean_kl']:.4f}"
    )
grpo.save("grpo_checkpoint.pt")
How GRPO Uses Rust Primitives
Internally, GRPO uses two Rust-accelerated operations:
- compute_batch_group_advantages: normalizes rewards within each group in a single Rust call (Rayon-parallelized for large batches)
- compute_batch_token_kl: computes per-sequence KL divergence in a single batched call instead of looping
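The group normalization described above can be sketched in pure Python as a reference for checking semantics. The name `group_advantages`, the use of the population standard deviation, and the `eps` term are assumptions; the source does not state which convention rlox uses.

```python
import statistics


def group_advantages(
    rewards: list[float], group_size: int, eps: float = 1e-8
) -> list[float]:
    """Normalize a flat reward list of shape (n_prompts * group_size,) per group."""
    if len(rewards) % group_size != 0:
        raise ValueError("len(rewards) must be a multiple of group_size")
    advantages: list[float] = []
    for start in range(0, len(rewards), group_size):
        group = rewards[start:start + group_size]
        mean = statistics.fmean(group)
        std = statistics.pstdev(group)  # population std (assumed convention)
        # Center and scale each reward relative to its own group.
        advantages.extend((r - mean) / (std + eps) for r in group)
    return advantages


# Two prompts with four completions each.
adv = group_advantages([1.0, 0.0, 0.0, 1.0, 2.0, 2.0, 2.0, 2.0], group_size=4)
```

In the second group every completion earns the same reward, so all advantages are zero and that prompt contributes no policy-gradient signal.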
OnlineDPO: Online Direct Preference Optimization
Like DPO but generates completions online and queries a preference oracle:
from rlox.algorithms.online_dpo import OnlineDPO
def preference_fn(pairs: list[tuple[torch.Tensor, torch.Tensor]]) -> list[int]:
    """Return 0 if first completion is preferred, 1 if second."""
    return [0 if score(a) > score(b) else 1 for a, b in pairs]
online_dpo = OnlineDPO(
    model=model,
    ref_model=ref_model,
    preference_fn=preference_fn,
    beta=0.1,
    learning_rate=1e-4,
)
# Each step generates pairs, queries preferences, and updates
for prompt_batch in dataloader:
    metrics = online_dpo.train_step(prompt_batch)
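The sketch below shows one way the 0/1 labels returned by `preference_fn` could be turned into chosen/rejected lists for a DPO update. How OnlineDPO does this internally is not shown in this guide. The helper `split_by_preference` is hypothetical and uses strings in place of token tensors.

```python
from typing import Sequence, TypeVar

T = TypeVar("T")


def split_by_preference(
    pairs: Sequence[tuple[T, T]], prefs: Sequence[int]
) -> tuple[list[T], list[T]]:
    """Map labels (0 = first preferred, 1 = second) to chosen/rejected lists."""
    if len(pairs) != len(prefs):
        raise ValueError("need exactly one preference label per pair")
    chosen: list[T] = []
    rejected: list[T] = []
    for (first, second), pref in zip(pairs, prefs):
        if pref == 0:
            chosen.append(first)
            rejected.append(second)
        elif pref == 1:
            chosen.append(second)
            rejected.append(first)
        else:
            raise ValueError(f"preference must be 0 or 1, got {pref!r}")
    return chosen, rejected


chosen, rejected = split_by_preference(
    [("good answer", "bad answer"), ("weak reply", "strong reply")],
    [0, 1],
)
```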
BestOfN: Rejection Sampling
Generate N completions per prompt, score them, and return the best:
from rlox.algorithms.best_of_n import BestOfN
bon = BestOfN(
    model=model,
    reward_fn=reward_fn,
    n=8,  # generate 8 candidates
    max_new_tokens=128,
)
# Generate best completions (no training, inference only)
best_completions = bon.generate(prompts) # (B, P+T)
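The selection step of best-of-N can be sketched as a per-prompt argmax over candidate scores. This assumes the scores are stored flat with the N candidates of each prompt next to each other. That layout and the helper name `select_best_of_n` are assumptions, not documented rlox behavior.

```python
def select_best_of_n(scores: list[float], n: int) -> list[int]:
    """Return, for each prompt, the flat index of its highest-scoring candidate."""
    if n <= 0 or len(scores) % n != 0:
        raise ValueError("len(scores) must be a positive multiple of n")
    best: list[int] = []
    for start in range(0, len(scores), n):
        group = scores[start:start + n]
        # Ties resolve to the earliest candidate in the group.
        offset = max(range(n), key=group.__getitem__)
        best.append(start + offset)
    return best


# Two prompts with three candidates each.
best_indices = select_best_of_n([0.1, 0.9, 0.3, 0.7, 0.2, 0.5], n=3)
# best_indices == [1, 3]
```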
Performance Comparison vs TRL
Benchmarked on GCP (8 vCPU, CPU-only):
| Operation | rlox | TRL (PyTorch) | Speedup |
|---|---|---|---|
| GRPO 16x4 | 1.3 us | 34.1 us | 27x |
| GRPO 1024x32 | 115 us | 243 us | 2.1x |
| KL 1x128 (f32) | 0.3 us | 3.1 us | 9.4x |
| KL 32x2048 (f32) | 72 us | 194 us | 2.7x |
Custom Training Loop
For full control, use the Rust primitives directly:
import rlox
import torch
import numpy as np
model = MyModel()
ref_model = MyModel()
ref_model.load_state_dict(model.state_dict())
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
for batch in dataloader:
    prompts, completions = batch
    # Forward pass
    policy_logprobs = get_per_token_logprobs(model, completions)
    with torch.no_grad():
        ref_logprobs = get_per_token_logprobs(ref_model, completions)
    # Compute KL using Rust (batched, f32)
    p = policy_logprobs.detach().cpu().numpy().astype(np.float32).ravel()
    r = ref_logprobs.cpu().numpy().astype(np.float32).ravel()
    seq_len = policy_logprobs.shape[-1]
    kl_per_seq = rlox.compute_batch_token_kl_schulman_f32(p, r, seq_len)
    mean_kl = float(kl_per_seq.mean())
    # Compute GRPO advantages using Rust
    rewards = np.array(reward_fn(completions))
    advantages = rlox.compute_batch_group_advantages(rewards, group_size=4)
    # Your custom loss
    loss = compute_loss(policy_logprobs, advantages, kl_per_seq)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
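To clarify the flat `(batch * seq_len,)` layout used for the batched KL call, here is a pure-Python reference of a per-sequence Schulman-style estimate. It assumes the "k3" form, `exp(log_q - log_p) - 1 - (log_q - log_p)`, summed over the tokens of each sequence. The exact estimator and reduction used by rlox are not stated here, and plain lists stand in for the NumPy arrays.

```python
import math


def batch_token_kl_schulman(
    log_p: list[float], log_q: list[float], seq_len: int
) -> list[float]:
    """Per-sequence KL(p || q) estimate from flat per-token log-probs."""
    if len(log_p) != len(log_q) or len(log_p) % seq_len != 0:
        raise ValueError("inputs must have equal length divisible by seq_len")
    kl_per_seq: list[float] = []
    # Each consecutive run of seq_len entries is one sequence.
    for start in range(0, len(log_p), seq_len):
        total = 0.0
        for lp, lq in zip(log_p[start:start + seq_len], log_q[start:start + seq_len]):
            log_ratio = lq - lp  # log(q / p) for the sampled token
            # k3 term: exp(x) - 1 - x, zero when p == q and positive otherwise.
            total += math.expm1(log_ratio) - log_ratio
        kl_per_seq.append(total)
    return kl_per_seq


# Two sequences of two tokens each; only the second sequence diverges.
kl = batch_token_kl_schulman(
    [-1.0, -2.0, -0.5, -0.5],
    [-1.0, -2.0, -1.5, -0.5],
    seq_len=2,
)
```

Comparing a reference like this against the Rust output on a small batch is a quick way to confirm that the `ravel()` ordering and `seq_len` in a custom loop line up as intended.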
Monitoring Training
from rlox.algorithms.dpo import DPO
from rlox.logging import ConsoleLogger, WandbLogger
# Console output every 100 steps
logger = ConsoleLogger(log_interval=100)
# Or Weights & Biases
logger = WandbLogger(project="my-llm-training")
dpo = DPO(model=model, ref_model=ref_model, logger=logger)
Saving and Loading
# Save
dpo.save("checkpoint.pt")
# Load
dpo_loaded = DPO.from_checkpoint("checkpoint.pt", env_id="unused")