# rlox Mathematical Reference
This document provides the mathematical formulations for every algorithm and computation implemented in rlox. Equations are numbered for cross-referencing.
## Notation
| Symbol | Meaning |
|---|---|
| \(s_t, a_t\) | State and action at time \(t\) |
| \(r_t\) | Reward at time \(t\) |
| \(\gamma \in [0, 1)\) | Discount factor |
| \(\pi_\theta(a \mid s)\) | Policy parameterised by \(\theta\) |
| \(V^\pi(s)\) | State value function under policy \(\pi\) |
| \(Q^\pi(s, a)\) | State-action value function under policy \(\pi\) |
| \(A^\pi(s, a)\) | Advantage function: \(Q^\pi(s,a) - V^\pi(s)\) |
| \(\hat{A}_t\) | Estimated advantage at time \(t\) |
| \(\mathcal{H}[\pi]\) | Entropy of the policy: \(-\mathbb{E}[\log \pi]\) |
| \(D_\text{KL}\) | Kullback-Leibler divergence |
| \(\tau\) | Polyak averaging coefficient |
## 1. Generalized Advantage Estimation (GAE)
Reference: Schulman et al. (2016) [2]
Implementation: rlox_core::training::gae::compute_gae (Rust), rlox.compute_gae (Python)
### Derivation

The TD residual at time \(t\) is:

$$\delta_t = r_t + \gamma (1 - d_t)\, V(s_{t+1}) - V(s_t) \tag{1}$$

where \(d_t \in \{0, 1\}\) is the episode termination flag. GAE defines the advantage estimator as an exponentially-weighted average of \(k\)-step TD errors:

$$\hat{A}_t = \sum_{l=0}^{T-t-1} (\gamma\lambda)^l\, \delta_{t+l} \tag{2}$$

This is computed via the backward recursion:

$$\hat{A}_t = \delta_t + \gamma\lambda (1 - d_t)\, \hat{A}_{t+1} \tag{3}$$

with \(\hat{A}_T = 0\). The return target is:

$$\hat{R}_t = \hat{A}_t + V(s_t) \tag{4}$$
### Special Cases
- \(\lambda = 0\): \(\hat{A}_t = \delta_t\) (one-step TD error, low variance, high bias)
- \(\lambda = 1\): \(\hat{A}_t = \sum_{l=0}^{T-t-1} \gamma^l r_{t+l} - V(s_t)\) (Monte Carlo return minus baseline, high variance, low bias)
- \(\lambda \in (0,1)\): Interpolates between the two
### Complexity
- Time: \(O(T)\) per trajectory (single backward pass)
- Space: \(O(T)\) for the advantages vector
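The backward recursion above can be sketched as a standalone NumPy function. This is an illustrative sketch, not the `rlox_core::training::gae::compute_gae` implementation; the argument layout (`dones[t]` flagging whether \(s_{t+1}\) is terminal, `last_value` bootstrapping the final step) is an assumption for the example.

```python
import numpy as np

def compute_gae(rewards, values, dones, last_value, gamma=0.99, lam=0.95):
    """Single backward pass: A_t = delta_t + gamma*lam*(1 - d_t)*A_{t+1}."""
    T = len(rewards)
    advantages = np.zeros(T)
    next_adv = 0.0
    next_value = last_value            # V(s_T) bootstrap
    for t in reversed(range(T)):
        nonterminal = 1.0 - dones[t]
        delta = rewards[t] + gamma * nonterminal * next_value - values[t]
        advantages[t] = delta + gamma * lam * nonterminal * next_adv
        next_adv = advantages[t]
        next_value = values[t]
    returns = advantages + values      # R_hat_t = A_hat_t + V(s_t)
    return advantages, returns
```

With \(\gamma = \lambda = 1\) and a zero value function this reduces to reversed cumulative rewards, matching the Monte Carlo special case.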
## 2. Proximal Policy Optimization (PPO)
Reference: Schulman et al. (2017) [1]
Implementation: rlox.losses.PPOLoss (Python), rlox.algorithms.ppo.PPO
### Clipped Surrogate Objective

Let \(r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_\text{old}}(a_t \mid s_t)}\) be the probability ratio. The clipped objective is:

$$L^\text{CLIP}(\theta) = \mathbb{E}_t\left[\min\left(r_t(\theta)\, \hat{A}_t,\ \text{clip}\left(r_t(\theta),\ 1 - \epsilon,\ 1 + \epsilon\right) \hat{A}_t\right)\right] \tag{5}$$

where \(\epsilon\) is the clip range (default 0.2).
### Value Loss

With optional clipping (matching CleanRL):

$$L^V(\theta) = \frac{1}{2}\, \mathbb{E}_t\left[\max\left(\left(V_\theta(s_t) - \hat{R}_t\right)^2,\ \left(V_\text{clip}(s_t) - \hat{R}_t\right)^2\right)\right] \tag{6}$$

where \(V_\text{clip}(s_t) = V_{\theta_\text{old}}(s_t) + \text{clip}(V_\theta(s_t) - V_{\theta_\text{old}}(s_t), -\epsilon, \epsilon)\).
Without clipping: \(L^V(\theta) = \frac{1}{2}\mathbb{E}_t\left[(V_\theta(s_t) - \hat{R}_t)^2\right]\).
### Entropy Bonus

$$\mathcal{H}[\pi_\theta](s_t) = -\mathbb{E}_{a \sim \pi_\theta(\cdot \mid s_t)}\left[\log \pi_\theta(a \mid s_t)\right] \tag{7}$$

### Total Loss

$$L(\theta) = -L^\text{CLIP}(\theta) + c_v\, L^V(\theta) - c_h\, \mathbb{E}_t\left[\mathcal{H}[\pi_\theta](s_t)\right] \tag{8}$$

Default coefficients: \(c_v = 0.5\), \(c_h = 0.01\).
### Diagnostics

Approximate KL divergence (from ratio):

$$\hat{D}_\text{KL} = \mathbb{E}_t\left[r_t(\theta) - 1 - \log r_t(\theta)\right] \tag{9}$$
Clip fraction: \(\mathbb{E}_t\left[\mathbf{1}[|r_t(\theta) - 1| > \epsilon]\right]\).
### Training Procedure

For each update:

1. Collect \(n_\text{envs} \times n_\text{steps}\) transitions using \(\pi_{\theta_\text{old}}\)
2. Compute GAE advantages using Eq. (3)
3. For \(K\) epochs (default 4):
    - Shuffle and split into minibatches
    - Normalise advantages per minibatch: \(\hat{A}_t \leftarrow (\hat{A}_t - \bar{A}) / (\sigma_A + 10^{-8})\)
    - Compute the loss (Eq. 8) and update \(\theta\) via Adam
    - Clip gradients: \(\|\nabla\|_2 \leq 0.5\)
4. Linearly anneal the learning rate (optional)
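The clipped objective and both diagnostics can be computed together from per-sample log-probabilities. A minimal NumPy sketch (illustrative only; the names and signature are not the `rlox.losses.PPOLoss` API):

```python
import numpy as np

def ppo_clip_objective(log_probs, old_log_probs, advantages, clip_range=0.2):
    """Clipped surrogate objective plus approx-KL and clip-fraction diagnostics."""
    ratio = np.exp(log_probs - old_log_probs)
    unclipped = ratio * advantages
    clipped = np.clip(ratio, 1.0 - clip_range, 1.0 + clip_range) * advantages
    objective = np.minimum(unclipped, clipped).mean()
    # Ratio-based KL estimate: mean(r - 1 - log r), always non-negative
    approx_kl = np.mean(ratio - 1.0 - np.log(ratio))
    # Fraction of samples where the ratio left the trust region
    clip_frac = np.mean(np.abs(ratio - 1.0) > clip_range)
    return objective, approx_kl, clip_frac
```

When the policy has not yet moved (`log_probs == old_log_probs`), the ratio is 1 everywhere, so the objective equals the mean advantage and both diagnostics are zero.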
## 3. Advantage Actor-Critic (A2C)
Reference: Mnih et al. (2016) [13]
Implementation: rlox.algorithms.a2c.A2C
A2C is the synchronous variant of A3C. It uses the same advantage estimation as PPO but without ratio clipping or multiple epochs.
### Policy Gradient

$$\nabla_\theta J(\theta) = \mathbb{E}_t\left[\nabla_\theta \log \pi_\theta(a_t \mid s_t)\, \hat{A}_t\right] \tag{10}$$

### Total Loss

$$L(\theta) = -\mathbb{E}_t\left[\log \pi_\theta(a_t \mid s_t)\, \hat{A}_t\right] + c_v\, L^V(\theta) - c_h\, \mathbb{E}_t\left[\mathcal{H}[\pi_\theta](s_t)\right] \tag{11}$$

Default: \(c_v = 0.5\), \(c_h = 0.01\), optimised with RMSprop.
### Key Differences from PPO

| | PPO | A2C |
|---|---|---|
| Clipping | Yes (Eq. 5) | No |
| Epochs per rollout | \(K\) (typically 4) | 1 |
| GAE lambda | 0.95 | 1.0 (Monte Carlo) |
| Optimizer | Adam | RMSprop |
| n_steps | 128 | 5 |
## 4. Soft Actor-Critic (SAC)
Reference: Haarnoja et al. (2018) [3]
Implementation: rlox.algorithms.sac.SAC
### Entropy-Regularised Objective

SAC maximises the maximum-entropy objective:

$$J(\pi) = \sum_{t} \mathbb{E}_{(s_t, a_t) \sim \rho_\pi}\left[r(s_t, a_t) + \alpha\, \mathcal{H}[\pi(\cdot \mid s_t)]\right] \tag{12}$$
where \(\alpha\) is the temperature parameter controlling the entropy-reward tradeoff.
### Soft Bellman Equation

The soft Q-function satisfies:

$$Q^\pi(s, a) = r(s, a) + \gamma\, \mathbb{E}_{s' \sim p}\left[V^\pi(s')\right] \tag{13}$$

with the soft value function:

$$V^\pi(s) = \mathbb{E}_{a \sim \pi}\left[Q^\pi(s, a) - \alpha \log \pi(a \mid s)\right] \tag{14}$$
### Critic Loss (Twin Q-Networks)

Two independent Q-networks are trained with:

$$y = r + \gamma (1 - d) \left(\min_{i=1,2} Q_{\bar{\phi}_i}(s', \tilde{a}') - \alpha \log \pi_\theta(\tilde{a}' \mid s')\right) \tag{15}$$

$$J_Q(\phi_i) = \frac{1}{2}\, \mathbb{E}_{(s, a, r, s') \sim \mathcal{D}}\left[\left(Q_{\phi_i}(s, a) - y\right)^2\right] \tag{16}$$
where \(\tilde{a}' \sim \pi_\theta(\cdot \mid s')\) and \(\bar{\phi}_i\) are target network parameters.
### Actor Loss

$$J_\pi(\theta) = \mathbb{E}_{s \sim \mathcal{D}}\left[\alpha \log \pi_\theta(\tilde{a} \mid s) - \min_{i=1,2} Q_{\phi_i}(s, \tilde{a})\right] \tag{17}$$
where \(\tilde{a} \sim \pi_\theta(\cdot \mid s)\) via the reparameterisation trick.
### Squashed Gaussian Policy

Actions are sampled as \(a = \tanh(u)\) with \(u = \mu(s) + \sigma(s) \odot \xi\), \(\xi \sim \mathcal{N}(0, I)\). The log-probability with the Jacobian correction is:

$$\log \pi(a \mid s) = \log \mathcal{N}\left(u;\ \mu(s), \sigma(s)\right) - \sum_{j=1}^{\dim(\mathcal{A})} \log\left(1 - \tanh^2(u_j)\right) \tag{18}$$
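The squashing correction is numerically delicate because \(1 - \tanh^2(u)\) underflows for large \(|u|\). A standalone sketch using the stable identity \(\log(1 - \tanh^2 u) = 2(\log 2 - u - \text{softplus}(-2u))\); the function name and layout are illustrative, not the rlox policy code:

```python
import numpy as np

def squashed_gaussian_logprob(mu, log_std, xi):
    """Sample a = tanh(mu + std*xi); return action and its corrected log-prob."""
    std = np.exp(log_std)
    u = mu + std * xi                  # pre-squash Gaussian sample
    # Diagonal Gaussian log-density of u
    log_normal = -0.5 * (((u - mu) / std) ** 2 + 2.0 * log_std + np.log(2.0 * np.pi))
    # Stable log(1 - tanh(u)^2) = 2*(log 2 - u - softplus(-2u))
    log_det = 2.0 * (np.log(2.0) - u - np.logaddexp(0.0, -2.0 * u))
    a = np.tanh(u)
    return a, np.sum(log_normal - log_det)
```

At \(u = 0\) the Jacobian term vanishes (the slope of \(\tanh\) is 1 there), so the log-prob reduces to the plain Gaussian density.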
### Automatic Entropy Tuning

The temperature \(\alpha\) is optimised to satisfy a target entropy \(\bar{\mathcal{H}}\):

$$J(\alpha) = \mathbb{E}_{a \sim \pi_\theta}\left[-\alpha \left(\log \pi_\theta(a \mid s) + \bar{\mathcal{H}}\right)\right] \tag{19}$$
Default target entropy: \(\bar{\mathcal{H}} = -\dim(\mathcal{A})\).
### Soft Target Update

$$\bar{\phi}_i \leftarrow \tau \phi_i + (1 - \tau)\, \bar{\phi}_i \tag{20}$$
## 5. Twin Delayed DDPG (TD3)
Reference: Fujimoto et al. (2018) [4]
Implementation: rlox.algorithms.td3.TD3
TD3 addresses overestimation bias in DDPG with three techniques:
### Twin Critics

Same as SAC (Eq. 15-16), but with a deterministic target policy and no entropy term:

$$y = r + \gamma (1 - d) \min_{i=1,2} Q_{\bar{\phi}_i}(s', \tilde{a}') \tag{21}$$
### Target Policy Smoothing

Adds clipped noise to the target action:

$$\tilde{a}' = \text{clip}\left(\pi_{\bar{\theta}}(s') + \epsilon,\ a_\text{low},\ a_\text{high}\right), \qquad \epsilon = \text{clip}\left(\mathcal{N}(0, \sigma^2),\ -c,\ c\right) \tag{22}$$
Default: \(\sigma = 0.2\), \(c = 0.5\).
### Delayed Policy Updates

The actor and target networks are updated every \(d\) critic updates (default \(d = 2\)), using the deterministic policy gradient and Polyak averaging:

$$\nabla_\theta J(\theta) = \mathbb{E}_{s \sim \mathcal{D}}\left[\nabla_a Q_{\phi_1}(s, a)\big|_{a = \pi_\theta(s)}\, \nabla_\theta \pi_\theta(s)\right] \tag{23}$$

$$\bar{\phi}_i \leftarrow \tau \phi_i + (1 - \tau)\, \bar{\phi}_i, \qquad \bar{\theta} \leftarrow \tau \theta + (1 - \tau)\, \bar{\theta} \tag{24}$$
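The target-action computation (smoothing noise, then bounding) can be sketched in NumPy. The helper name, `act_limit` symmetry, and `rng` argument are assumptions for this example, not the rlox API:

```python
import numpy as np

def td3_target_action(target_policy_action, rng, sigma=0.2, noise_clip=0.5,
                      act_limit=1.0):
    """Target policy smoothing: add clipped Gaussian noise, then clip to bounds."""
    noise = np.clip(rng.normal(0.0, sigma, size=target_policy_action.shape),
                    -noise_clip, noise_clip)
    return np.clip(target_policy_action + noise, -act_limit, act_limit)
```

Because the noise is clipped to \([-c, c]\) before the action bound is applied, the smoothed action never strays more than \(c\) from the target policy's output.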
## 6. Deep Q-Network (DQN)
Reference: Mnih et al. (2015) [5]
Implementation: rlox.algorithms.dqn.DQN
### Bellman Equation

The optimal Q-function satisfies:

$$Q^*(s, a) = \mathbb{E}_{s'}\left[r + \gamma \max_{a'} Q^*(s', a')\right] \tag{25}$$
### DQN Loss

$$L(\theta) = \mathbb{E}\left[\left(R_t^{(n)} + \gamma^n (1 - d) \max_{a'} Q_{\bar{\theta}}(s_{t+n}, a') - Q_\theta(s_t, a_t)\right)^2\right] \tag{26}$$

where \(n\) is the N-step return horizon and \(\bar{\theta}\) are target network parameters updated every \(K\) steps (hard copy).
### Double DQN

Reference: van Hasselt et al. (2016) [6]

Decouples action selection from evaluation to reduce overestimation:

$$y^\text{Double} = r + \gamma (1 - d)\, Q_{\bar{\theta}}\left(s', \arg\max_{a'} Q_\theta(s', a')\right) \tag{27}$$
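The decoupling is visible in a short batched NumPy sketch (illustrative, not the rlox DQN code): the online network picks the action, the target network scores it.

```python
import numpy as np

def double_dqn_target(rewards, dones, next_q_online, next_q_target, gamma=0.99):
    """Double DQN target: argmax with the online net, evaluate with the target net."""
    best_actions = np.argmax(next_q_online, axis=1)          # selection
    next_values = next_q_target[np.arange(len(rewards)), best_actions]  # evaluation
    return rewards + gamma * (1.0 - dones) * next_values
```

If the online net prefers action 1 but the target net rates it lower than the target net's own argmax, the lower estimate is used, which is exactly what damps overestimation.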
### Dueling Architecture

Reference: Wang et al. (2016) [7]

Decomposes Q into value and advantage streams:

$$Q_\theta(s, a) = V_\theta(s) + A_\theta(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a'} A_\theta(s, a') \tag{28}$$
### N-Step Returns

Instead of single-step bootstrapping, uses the \(n\)-step return:

$$R_t^{(n)} = \sum_{k=0}^{n-1} \gamma^k r_{t+k} \tag{29}$$
### Prioritized Experience Replay (PER)

Reference: Schaul et al. (2016) [8]

Samples transitions proportional to their TD error magnitude:

$$P(i) = \frac{p_i^\alpha}{\sum_k p_k^\alpha}, \qquad p_i = |\delta_i| + \epsilon \tag{30}$$

Importance-sampling weights correct the bias:

$$w_i = \frac{\left(N \cdot P(i)\right)^{-\beta}}{\max_j w_j} \tag{31}$$
where \(\beta\) is annealed from \(\beta_0\) (default 0.4) to 1.0 over training.
Implementation: rlox_core::buffer::priority uses a sum-tree for \(O(\log N)\) sampling and priority updates.
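The priority-to-probability and importance-weight computations can be sketched directly (without the sum-tree, which only accelerates sampling). A standalone NumPy illustration; the function name and defaults here are assumptions, not the rlox buffer API:

```python
import numpy as np

def per_probabilities_and_weights(td_errors, alpha=0.6, beta=0.4, eps=1e-6):
    """Priorities p_i = |delta_i| + eps; P(i) ∝ p_i^alpha; normalised IS weights."""
    priorities = np.abs(td_errors) + eps
    probs = priorities ** alpha
    probs /= probs.sum()
    n = len(td_errors)
    weights = (n * probs) ** (-beta)
    weights /= weights.max()           # normalise so the largest weight is 1
    return probs, weights
```

With uniform TD errors the scheme degenerates to uniform sampling and unit weights, which is a useful sanity check.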
### Epsilon-Greedy Exploration
\(\epsilon\) is linearly decayed from \(\epsilon_0\) (default 1.0) to \(\epsilon_f\) (default 0.05) over the first fraction of training.
## 7. V-trace
Reference: Espeholt et al. (2018) [9]
Implementation: rlox_core::training::vtrace::compute_vtrace (Rust), rlox.compute_vtrace (Python)
V-trace provides off-policy correction for the IMPALA architecture.
### Importance Weights

Let \(\rho_t = \frac{\pi(a_t \mid s_t)}{\mu(a_t \mid s_t)}\) be the importance ratio, where \(\mu\) is the behaviour policy. Define clipped weights:

$$\bar{\rho}_t = \min(\bar{\rho}, \rho_t) \tag{32}$$

$$c_t = \lambda \min(\bar{c}, \rho_t) \tag{33}$$
### V-trace Target

$$v_t = V(s_t) + \sum_{k=t}^{t+n-1} \gamma^{k-t} \left(\prod_{i=t}^{k-1} c_i\right) \delta_k V \tag{34}$$

where the temporal difference is:

$$\delta_k V = \bar{\rho}_k \left(r_k + \gamma V(s_{k+1}) - V(s_k)\right) \tag{35}$$
### Backward Recursion

The implementation computes V-trace via backward iteration:

$$v_t = V(s_t) + \delta_t V + \gamma c_t \left(v_{t+1} - V(s_{t+1})\right) \tag{36}$$

Policy gradient advantages:

$$\hat{A}_t = \bar{\rho}_t \left(r_t + \gamma v_{t+1} - V(s_t)\right) \tag{37}$$
### Properties
- \(\bar{\rho} = \bar{c} = \infty\): on-policy, equivalent to GAE(\(\lambda=1\))
- \(\bar{\rho} = \bar{c} = 1\): default, limits variance while allowing some off-policy correction
- \(\bar{\rho}\) controls the bias of the value function fixed point
- \(\bar{c}\) controls the speed of convergence (trace cutting)
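The backward iteration can be sketched in NumPy. This is an illustrative standalone function, not `rlox_core::training::vtrace::compute_vtrace`; the `bootstrap_value` argument and omission of the \(\lambda\) factor on \(c_t\) (i.e. \(\lambda = 1\)) are assumptions for the example:

```python
import numpy as np

def compute_vtrace(behaviour_logp, target_logp, rewards, values, bootstrap_value,
                   gamma=0.99, rho_bar=1.0, c_bar=1.0):
    """Backward V-trace recursion with clipped importance weights."""
    rhos = np.exp(target_logp - behaviour_logp)    # pi / mu per step
    clipped_rhos = np.minimum(rho_bar, rhos)
    cs = np.minimum(c_bar, rhos)
    T = len(rewards)
    values_tp1 = np.append(values[1:], bootstrap_value)
    deltas = clipped_rhos * (rewards + gamma * values_tp1 - values)
    vs = np.zeros(T)
    acc = 0.0                                      # v_{t+1} - V(s_{t+1})
    for t in reversed(range(T)):
        acc = deltas[t] + gamma * cs[t] * acc
        vs[t] = values[t] + acc
    vs_tp1 = np.append(vs[1:], bootstrap_value)
    pg_advantages = clipped_rhos * (rewards + gamma * vs_tp1 - values)
    return vs, pg_advantages
```

On-policy data (equal log-probs) makes every ratio 1, so the targets reduce to ordinary multi-step TD corrections.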
## 8. Group Relative Policy Optimization (GRPO)
Reference: Shao et al. (2024) [11]
Implementation: rlox_core::llm::ops::compute_group_advantages (Rust), rlox.algorithms.grpo.GRPO (Python)
### Group-Relative Advantages

For a prompt \(x\), generate \(G\) completions \(\{y_1, \ldots, y_G\}\) and compute rewards \(\{r_1, \ldots, r_G\}\). The advantage for completion \(i\) is:

$$\hat{A}_i = \frac{r_i - \bar{r}}{\sigma_r} \tag{38}$$
where \(\bar{r} = \frac{1}{G}\sum_{j=1}^G r_j\) and \(\sigma_r = \sqrt{\frac{1}{G}\sum_{j=1}^G (r_j - \bar{r})^2}\).
If \(\sigma_r < 10^{-8}\) (constant rewards), all advantages are set to zero.
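The standardisation, including the constant-reward guard described above, fits in a few lines. This Python sketch is illustrative and distinct from the Rust `compute_group_advantages`:

```python
import numpy as np

def compute_group_advantages(rewards, eps=1e-8):
    """Group-relative advantages: standardise rewards within one group."""
    rewards = np.asarray(rewards, dtype=np.float64)
    mean = rewards.mean()
    std = rewards.std()                # population std, matching the definition
    if std < eps:                      # constant rewards carry no learning signal
        return np.zeros_like(rewards)
    return (rewards - mean) / std
```

A group with rewards \([1, 3]\) yields advantages \([-1, 1]\); a constant group yields all zeros rather than dividing by (near) zero.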
### GRPO Loss

$$L^\text{GRPO}(\theta) = -\frac{1}{G} \sum_{i=1}^{G} \frac{1}{|y_i|} \sum_{t=1}^{|y_i|} \left[\min\left(r_{i,t}(\theta)\, \hat{A}_i,\ \text{clip}\left(r_{i,t}(\theta),\ 1 - \epsilon,\ 1 + \epsilon\right) \hat{A}_i\right) - \beta\, D_\text{KL}\left[\pi_\theta \,\|\, \pi_\text{ref}\right]\right] \tag{39}$$

where \(r_{i,t}(\theta) = \frac{\pi_\theta(y_{i,t} \mid x, y_{i,<t})}{\pi_{\theta_\text{old}}(y_{i,t} \mid x, y_{i,<t})}\) is the per-token probability ratio. The KL penalty prevents the policy from drifting too far from the reference model.
### Key Difference from PPO
GRPO eliminates the need for a learned value function. Instead of \(V(s)\) as a baseline, it uses the group mean reward. This is particularly suited for LLM post-training where:

- Episodes are single-turn (generate once, score once)
- The reward function is an external model (e.g. reward model, verifier)
- Training a value head for language models is expensive
## 9. Direct Preference Optimization (DPO)
Reference: Rafailov et al. (2023) [10]
Implementation: rlox.algorithms.dpo.DPO
### Bradley-Terry Preference Model

Given chosen completion \(y_w\) and rejected completion \(y_l\) for prompt \(x\), the preference probability under the Bradley-Terry model is:

$$p^*(y_w \succ y_l \mid x) = \sigma\left(r^*(x, y_w) - r^*(x, y_l)\right) \tag{40}$$

where \(\sigma\) is the sigmoid function and \(r^*\) is the ground-truth reward.
### From RLHF to DPO

The optimal policy under KL-constrained reward maximisation has the form:

$$\pi^*(y \mid x) = \frac{1}{Z(x)}\, \pi_\text{ref}(y \mid x) \exp\left(\frac{1}{\beta}\, r(x, y)\right) \tag{41}$$

Solving for the reward and substituting into the Bradley-Terry model yields:

$$p^*(y_w \succ y_l \mid x) = \sigma\left(\beta \log \frac{\pi^*(y_w \mid x)}{\pi_\text{ref}(y_w \mid x)} - \beta \log \frac{\pi^*(y_l \mid x)}{\pi_\text{ref}(y_l \mid x)}\right) \tag{42}$$
### DPO Loss

$$L_\text{DPO}(\theta) = -\mathbb{E}_{(x, y_w, y_l)}\left[\log \sigma\left(\beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_\text{ref}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_\text{ref}(y_l \mid x)}\right)\right] \tag{43}$$

where \(\beta\) is the temperature parameter (default 0.1).
### Implicit Reward

DPO implicitly defines a reward (up to a per-prompt constant):

$$\hat{r}(x, y) = \beta \log \frac{\pi_\theta(y \mid x)}{\pi_\text{ref}(y \mid x)} \tag{44}$$
The diagnostic metrics track chosen_reward and rejected_reward as \(\beta \cdot \mathbb{E}[\log(\pi_\theta / \pi_\text{ref})]\) for chosen and rejected completions respectively. A well-trained model should have chosen_reward > rejected_reward.
### Sequence Log-Probabilities

For a sequence \(y = (y_1, \ldots, y_T)\):

$$\log \pi(y \mid x) = \sum_{t=1}^{T} \log \pi(y_t \mid x, y_{<t}) \tag{45}$$
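Given the four sequence log-probabilities, the loss and the implicit-reward diagnostics come out of one small function. A NumPy sketch using the stable identity \(-\log \sigma(m) = \text{softplus}(-m)\); the function name and return layout are illustrative, not the `rlox.algorithms.dpo.DPO` API:

```python
import numpy as np

def dpo_loss(policy_chosen_logp, policy_rejected_logp,
             ref_chosen_logp, ref_rejected_logp, beta=0.1):
    """DPO loss from sequence log-probs, plus mean implicit rewards."""
    chosen_reward = beta * (policy_chosen_logp - ref_chosen_logp)
    rejected_reward = beta * (policy_rejected_logp - ref_rejected_logp)
    margin = chosen_reward - rejected_reward
    # -log sigmoid(margin), computed stably as softplus(-margin)
    loss = np.mean(np.logaddexp(0.0, -margin))
    return loss, chosen_reward.mean(), rejected_reward.mean()
```

At initialisation the policy equals the reference, so both implicit rewards are zero and the loss starts at \(\log 2\).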
## 10. Token-Level KL Divergence
Implementation: rlox_core::llm::ops::compute_token_kl (Rust), rlox.compute_token_kl (Python)
The forward KL divergence at the token level is estimated as:

$$\hat{D}_\text{KL}(p \,\|\, q) = \frac{1}{T} \sum_{t=1}^{T} \left(\log p_t - \log q_t\right) \tag{46}$$
where \(\log p_t\) and \(\log q_t\) are the per-token log-probabilities under the policy and reference model respectively.
Properties:

- \(D_\text{KL}(p \| p) = 0\) (identical distributions)
- \(D_\text{KL}(p \| q) \geq 0\) (Gibbs' inequality)
- Not symmetric: \(D_\text{KL}(p \| q) \neq D_\text{KL}(q \| p)\)
Used as a regularisation penalty in GRPO (Eq. 39) and general RLHF training to prevent policy collapse.
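The per-token estimate above is a Monte Carlo estimate over sampled tokens; the listed properties hold exactly for the full-distribution KL, which can be computed per position from the logits. A standalone sketch (illustrative, not `compute_token_kl`):

```python
import numpy as np

def token_kl(policy_logits, ref_logits):
    """Exact forward KL(policy || ref) per token position, from raw logits."""
    def log_softmax(x):
        x = x - x.max(axis=-1, keepdims=True)   # shift for numerical stability
        return x - np.log(np.exp(x).sum(axis=-1, keepdims=True))
    log_p = log_softmax(policy_logits)
    log_q = log_softmax(ref_logits)
    # sum_v p(v) * (log p(v) - log q(v)) over the vocabulary axis
    return np.sum(np.exp(log_p) * (log_p - log_q), axis=-1)
```

The exact form is zero for identical logits, non-negative, and visibly asymmetric when the two argument orders are swapped.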
## 11. Bootstrap Confidence Intervals and IQM
Reference: Agarwal et al. (2021) [12]
Implementation: rlox.evaluation
### Interquartile Mean (IQM)

A robust measure of central tendency that discards the bottom and top 25% of scores:

$$\text{IQM} = \frac{1}{\lceil 3n/4 \rceil - \lfloor n/4 \rfloor} \sum_{i=\lfloor n/4 \rfloor + 1}^{\lceil 3n/4 \rceil} X_{(i)} \tag{47}$$
where \(X_{(i)}\) are the order statistics. IQM is more robust to outliers than the mean while being more statistically efficient than the median.
### Stratified Bootstrap Confidence Interval
For a set of scores \(\{x_1, \ldots, x_n\}\):
- Draw \(B\) bootstrap resamples \(\{x_1^*, \ldots, x_n^*\}\) with replacement
- Compute \(\bar{x}^*_b = \frac{1}{n}\sum_i x_{i,b}^*\) for each resample \(b\)
- The \((1 - \alpha)\) confidence interval is \([\bar{x}^*_{(\alpha/2)}, \bar{x}^*_{(1-\alpha/2)}]\)
Default: \(B = 10{,}000\) resamples, 95% CI.
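Both estimators fit in a short NumPy sketch. These are illustrative standalone functions, not the `rlox.evaluation` API, and the bootstrap here is the plain (unstratified) percentile variant:

```python
import numpy as np

def iqm(scores):
    """Interquartile mean: mean of the middle 50% of sorted scores."""
    x = np.sort(np.asarray(scores, dtype=np.float64))
    n = len(x)
    return x[n // 4 : n - n // 4].mean()

def bootstrap_ci(scores, statistic=np.mean, n_resamples=10_000, alpha=0.05, seed=0):
    """Percentile bootstrap confidence interval for a statistic of the scores."""
    rng = np.random.default_rng(seed)
    x = np.asarray(scores, dtype=np.float64)
    stats = np.array([statistic(rng.choice(x, size=len(x), replace=True))
                      for _ in range(n_resamples)])
    return np.quantile(stats, alpha / 2), np.quantile(stats, 1 - alpha / 2)
```

Note how a single outlier of 100 barely moves the IQM, while it would dominate the plain mean.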
### Performance Profiles

The performance profile of algorithm \(A\) over \(M\) environments and \(N\) seeds is:

$$\bar{F}_A(\tau) = \frac{1}{MN} \sum_{m=1}^{M} \sum_{n=1}^{N} \mathbf{1}\left[x_{m,n} \geq \tau\right] \tag{48}$$
This gives the fraction of runs where algorithm \(A\) achieves at least score \(\tau\), aggregating across environments and seeds.
## 12. Polyak (Soft) Target Update

Used by SAC, TD3, and other off-policy algorithms:

$$\bar{\theta} \leftarrow \tau \theta + (1 - \tau)\, \bar{\theta} \tag{49}$$
Implementation: rlox.networks.polyak_update
Default \(\tau = 0.005\). This is equivalent to an exponential moving average of the online parameters, providing a slowly-evolving target that stabilises training.
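The update is a one-line exponential moving average per parameter tensor. A minimal NumPy sketch (illustrative, not the `rlox.networks.polyak_update` signature):

```python
import numpy as np

def polyak_update(online_params, target_params, tau=0.005):
    """In-place soft update: target <- tau * online + (1 - tau) * target."""
    for online, target in zip(online_params, target_params):
        target *= (1.0 - tau)
        target += tau * online
```

Updating in place avoids allocating new arrays for the target network on every training step.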
## 13. Orthogonal Initialisation
Reference: Andrychowicz et al. (2021) [14]
Implementation: rlox.policies._orthogonal_init
Policy networks use orthogonal weight initialisation with gain-dependent scaling:

- Hidden layers: gain \(= \sqrt{2}\) (for ReLU/Tanh activations)
- Policy head: gain \(= 0.01\) (encourages near-uniform initial distribution)
- Value head: gain \(= 1.0\)
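A common way to build such a matrix is the QR decomposition of a Gaussian matrix, scaled by the gain. A standalone sketch assuming 2-D weights; this is not `rlox.policies._orthogonal_init` itself:

```python
import numpy as np

def orthogonal_init(shape, gain=1.0, seed=0):
    """Orthogonal weight matrix via QR decomposition of a Gaussian matrix."""
    rng = np.random.default_rng(seed)
    rows, cols = shape
    a = rng.normal(size=(max(rows, cols), min(rows, cols)))
    q, r = np.linalg.qr(a)             # reduced QR: q has orthonormal columns
    q *= np.sign(np.diag(r))           # fix signs to make the result unique
    if rows < cols:
        q = q.T
    return gain * q[:rows, :cols]
```

With gain \(\sqrt{2}\), the columns satisfy \(W^\top W = 2 I\), i.e. orthogonal directions with the variance boost appropriate for ReLU layers.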
## Cross-References
- Rust User Guide -- code-level documentation of each implementation
- Python User Guide -- API usage examples
- References -- full academic citations for all referenced papers