rlox Rust Optimization Plan
Overview
| # | Item | File | Impact | Effort |
|---|---|---|---|---|
| P1 | AsyncCollector serial strided GAE | pipeline/collector.rs:167-193 | 2-5x GAE | Medium |
| P2 | VecEnv per-env obs allocation | env/parallel.rs, builtins.rs:69 | ~10M allocs/s removed | High |
| P3 | PyO3 push() double-copy | rlox-python/buffer.rs:222-239 | 40-50% push latency | Low |
| P4 | SampledBatch fresh alloc per sample() | buffer/ringbuf.rs:199-244 | ~56MB/step (Atari) | Low |
| P5 | PER tree_min_prob O(N) | buffer/priority.rs:320-337 | O(N) -> O(1) | Medium |
| P6 | Release profile optimizations | Cargo.toml | 5-15% all code | Trivial |
| P7 | GAE inner loop branch | training/gae.rs:28 | 1-3% | Trivial |
| P8 | Rayon PAR_THRESHOLD tuning | llm/ops.rs:124 | Prevent overhead | Trivial |
| P9 | V-trace double exp() | training/vtrace.rs:49-50 | ~20 cycles/step | Trivial |
| P10 | HashMap::new per CartPole step | env/builtins.rs:137 | Minor alloc | Trivial |
| P11 | PyO3 sample() missing allow_threads | rlox-python/buffer.rs | GIL concurrency | Trivial |
| P12 | ExperienceRecord forces ownership | buffer/mod.rs:9-19 | Root cause of P3 | Medium |
| P13 | current_obs re-alloc per step | pipeline/collector.rs:143-147 | 1 alloc/step | Low |
Detailed Items
P1: AsyncCollector serial strided GAE (CRITICAL)
File: crates/rlox-core/src/pipeline/collector.rs:167-193
Problem: The AsyncCollector computes GAE per-environment by extracting strided per-env slices into temporary Vecs, then calling single-env compute_gae in a serial loop. compute_gae_batched with Rayon parallelism exists but is not used.
Current code:
    for env_idx in 0..n_envs {
        // Strided gather: touches cache lines n_envs * 8 bytes apart
        let env_rewards: Vec<f64> = (0..n_steps)
            .map(|t| all_rewards[t * n_envs + env_idx])
            .collect();
        let env_values: Vec<f64> = (0..n_steps)
            .map(|t| all_values[t * n_envs + env_idx])
            .collect();
        let env_dones: Vec<f64> = (0..n_steps)
            .map(|t| all_dones[t * n_envs + env_idx])
            .collect();
        let (env_adv, env_ret) = gae::compute_gae(
            &env_rewards, &env_values, &env_dones,
            last_values[env_idx], gamma, gae_lambda,
        );
        // Strided write-back
        for t in 0..n_steps {
            advantages[t * n_envs + env_idx] = env_adv[t];
            returns[t * n_envs + env_idx] = env_ret[t];
        }
    }
Issues:
1. Cache-hostile strided access: with 64 envs, every gather touches cache lines 512 bytes apart.
2. 3 * n_envs temporary Vec allocations per rollout.
3. Serial loop even though a Rayon batched version exists.
4. The strided write-back also thrashes the cache.
Proposed fix:
Option A (minimal): Transpose data once, call compute_gae_batched, transpose result back:
    // Transpose step-major -> env-major
    let total = n_steps * n_envs;
    let mut env_major_rewards = vec![0.0; total];
    let mut env_major_values = vec![0.0; total];
    let mut env_major_dones = vec![0.0; total];
    for t in 0..n_steps {
        for e in 0..n_envs {
            env_major_rewards[e * n_steps + t] = all_rewards[t * n_envs + e];
            env_major_values[e * n_steps + t] = all_values[t * n_envs + e];
            env_major_dones[e * n_steps + t] = all_dones[t * n_envs + e];
        }
    }
    let (env_major_adv, env_major_ret) = gae::compute_gae_batched(
        &env_major_rewards, &env_major_values, &env_major_dones,
        &last_values, n_steps, gamma, gae_lambda,
    );
    // Transpose back env-major -> step-major
    for e in 0..n_envs {
        for t in 0..n_steps {
            advantages[t * n_envs + e] = env_major_adv[e * n_steps + t];
            returns[t * n_envs + e] = env_major_ret[e * n_steps + t];
        }
    }
Option B (better, larger refactor): Store rollout data env-major from the start in the collection loop (lines 130-147), then call compute_gae_batched directly without transposing.
Impact: 2-5x GAE speedup for typical configs (64 envs, 128 steps). Replaces the 3 * n_envs temporary Vec allocations per rollout with a fixed number of rollout-sized buffers.
Effort: Medium
Dependencies: None (compute_gae_batched already exists)
Testing: Existing proptest gae_batched_matches_unbatched validates correctness. Add integration test comparing AsyncCollector output before/after.
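The before/after comparison suggested above can be sketched in isolation. The block below is a minimal, self-contained illustration, not the rlox code: `compute_gae` is a stand-in written from the recurrence shown in P7, the returns are assumed to be advantages plus values, and `to_env_major` and `gae_step_major` are hypothetical helper names. It shows that transposing once and running GAE over contiguous per-env slices gives the same advantages as the strided gather.

```rust
/// Single-environment GAE over contiguous slices; returns (advantages, returns).
/// Stand-in for `gae::compute_gae`; returns = advantages + values is an assumption.
fn compute_gae(
    rewards: &[f64], values: &[f64], dones: &[f64],
    last_value: f64, gamma: f64, lam: f64,
) -> (Vec<f64>, Vec<f64>) {
    let n = rewards.len();
    let mut adv = vec![0.0; n];
    let mut last_gae = 0.0;
    for t in (0..n).rev() {
        let next_value = if t == n - 1 { last_value } else { values[t + 1] };
        let non_terminal = 1.0 - dones[t];
        let delta = rewards[t] + gamma * next_value * non_terminal - values[t];
        last_gae = delta + gamma * lam * non_terminal * last_gae;
        adv[t] = last_gae;
    }
    let ret: Vec<f64> = adv.iter().zip(values).map(|(a, v)| a + v).collect();
    (adv, ret)
}

/// Transpose a step-major [n_steps][n_envs] buffer into env-major [n_envs][n_steps].
fn to_env_major(src: &[f64], n_steps: usize, n_envs: usize) -> Vec<f64> {
    let mut dst = vec![0.0; src.len()];
    for t in 0..n_steps {
        for e in 0..n_envs {
            dst[e * n_steps + t] = src[t * n_envs + e];
        }
    }
    dst
}

/// GAE over step-major inputs using one transpose per array.
/// Returns step-major advantages.
fn gae_step_major(
    rewards: &[f64], values: &[f64], dones: &[f64],
    last_values: &[f64], n_steps: usize, gamma: f64, lam: f64,
) -> Vec<f64> {
    let n_envs = last_values.len();
    let r = to_env_major(rewards, n_steps, n_envs);
    let v = to_env_major(values, n_steps, n_envs);
    let d = to_env_major(dones, n_steps, n_envs);
    let mut adv = vec![0.0; n_steps * n_envs];
    for e in 0..n_envs {
        // Each environment now owns one contiguous span.
        let span = e * n_steps..(e + 1) * n_steps;
        let (a, _) = compute_gae(
            &r[span.clone()], &v[span.clone()], &d[span],
            last_values[e], gamma, lam,
        );
        for t in 0..n_steps {
            adv[t * n_envs + e] = a[t];
        }
    }
    adv
}
```

In the real collector the per-env loop would be replaced by the existing compute_gae_batched call; the serial loop here only keeps the sketch free of external crates.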
P2: VecEnv per-env observation allocation
File: crates/rlox-core/src/env/builtins.rs:68-69, crates/rlox-core/src/env/parallel.rs:60-92
Problem: Every CartPole::obs() call allocates a new Vec<f32>. BatchTransition stores observations as Vec<Vec<f32>> (Array-of-Structures), requiring n_envs heap allocations per step_all().
Proposed fix:
- Add write_obs(&self, buf: &mut [f32]) to the RLEnv trait.
- Change BatchTransition to use flat storage.
- Pre-allocate in VecEnv and reuse across steps.
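The three steps above might fit together roughly as follows. This is a sketch under stated assumptions: the `RLEnv` trait here is a reduced, hypothetical version with only the two methods needed, `Toy` stands in for CartPole, and `gather_obs` is an invented method name. The flat buffer plays the role of the flat BatchTransition storage.

```rust
/// Hypothetical, reduced environment trait; the real RLEnv has more methods.
trait RLEnv {
    fn obs_dim(&self) -> usize;
    /// Write the current observation into `buf`, whose length must be `obs_dim()`.
    fn write_obs(&self, buf: &mut [f32]);
}

/// Toy environment standing in for CartPole (4-dimensional state).
struct Toy {
    state: [f32; 4],
}

impl RLEnv for Toy {
    fn obs_dim(&self) -> usize {
        4
    }
    fn write_obs(&self, buf: &mut [f32]) {
        buf.copy_from_slice(&self.state);
    }
}

/// Vectorized wrapper that owns one flat [n_envs * obs_dim] buffer.
struct VecEnv<E: RLEnv> {
    envs: Vec<E>,
    obs_dim: usize,
    obs: Vec<f32>,
}

impl<E: RLEnv> VecEnv<E> {
    fn new(envs: Vec<E>) -> Self {
        let obs_dim = envs.first().map_or(0, |e| e.obs_dim());
        // Allocated once here, then reused for every step.
        let obs = vec![0.0; envs.len() * obs_dim];
        Self { envs, obs_dim, obs }
    }

    /// Refresh the flat buffer in place; no heap allocation happens here.
    fn gather_obs(&mut self) -> &[f32] {
        // max(1) avoids a zero chunk size when there are no environments.
        let chunk_len = self.obs_dim.max(1);
        for (env, chunk) in self.envs.iter().zip(self.obs.chunks_exact_mut(chunk_len)) {
            env.write_obs(chunk);
        }
        &self.obs
    }
}
```

Environment `e` then reads its observation from `obs[e * obs_dim..(e + 1) * obs_dim]`, which is also the layout a NumPy view on the Python side would expect.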
Impact: Eliminates n_envs heap allocations per step. For 1024 envs at 10kHz: ~10M allocations/second removed.
Effort: High (requires trait changes, struct changes, PyO3 adapter changes)
Dependencies: None
Testing: Existing CartPole tests + new test verifying flat obs matches per-env obs.
P3: PyO3 push() double-copy
File: crates/rlox-python/src/buffer.rs:222-239
Problem: Each push() call copies data twice:
    fn push(&mut self, obs: PyReadonlyArray1<f32>, ...) -> PyResult<()> {
        let obs_vec = obs.as_slice()?.to_vec(); // copy 1: numpy -> Vec
        let next_obs_vec = n.as_slice()?.to_vec(); // copy 1
        let action_vec = action.as_slice()?.to_vec(); // copy 1
        let record = ExperienceRecord { obs: obs_vec, ... };
        self.inner.push(record) // copy 2: Vec -> ring buffer
    }
Proposed fix: Add push_slices method to ReplayBuffer:
    impl ReplayBuffer {
        pub fn push_slices(
            &mut self,
            obs: &[f32],
            next_obs: &[f32],
            action: &[f32],
            reward: f32,
            terminated: bool,
            truncated: bool,
        ) -> Result<(), RloxError> {
            // copy_from_slice panics on a length mismatch, so validate
            // obs.len(), next_obs.len() and action.len() here and return Err instead.
            let idx = self.write_pos % self.capacity;
            let obs_start = idx * self.obs_dim;
            self.observations[obs_start..obs_start + self.obs_dim].copy_from_slice(obs);
            self.next_observations[obs_start..obs_start + self.obs_dim].copy_from_slice(next_obs);
            let act_start = idx * self.act_dim;
            self.actions[act_start..act_start + self.act_dim].copy_from_slice(action);
            self.rewards[idx] = reward;
            self.terminated[idx] = terminated;
            self.truncated[idx] = truncated;
            self.write_pos += 1;
            if self.count < self.capacity { self.count += 1; }
            Ok(())
        }
    }
Then in PyO3:
    fn push(&mut self, obs: PyReadonlyArray1<f32>, ...) -> PyResult<()> {
        self.inner.push_slices(
            obs.as_slice()?, next_obs.as_slice()?, action.as_slice()?,
            reward, terminated, truncated,
        ).map_err(...)
    }
Impact: 40-50% push latency reduction. For obs_dim=28224 (Atari): saves ~220KB copying per push.
Effort: Low
Dependencies: Closely tied to P12 (ExperienceRecord ownership); push_slices effectively replaces that item's fix.
Testing: Existing buffer push/sample round-trip tests validate correctness.
P4: SampledBatch fresh allocation every sample()
File: crates/rlox-core/src/buffer/ringbuf.rs:199-244
Problem: Every sample() creates new Vecs:
    pub fn sample(&self, batch_size: usize, seed: u64) -> Result<SampledBatch, RloxError> {
        let mut batch = SampledBatch::with_capacity(batch_size, self.obs_dim, self.act_dim);
        // ... fills vecs ...
    }
Proposed fix:
    pub fn sample_into(
        &self,
        batch: &mut SampledBatch,
        batch_size: usize,
        seed: u64,
    ) -> Result<(), RloxError> {
        batch.clear(); // reuse allocations
        // ... same gather logic, extending existing Vecs ...
    }
Add SampledBatch::clear():
    impl SampledBatch {
        pub fn clear(&mut self) {
            self.observations.clear();
            self.next_observations.clear();
            self.actions.clear();
            self.rewards.clear();
            self.terminated.clear();
            self.truncated.clear();
        }
    }
The PyO3 layer can cache a SampledBatch on the PyReplayBuffer struct.
Impact: Eliminates allocator pressure. For batch_size=256, obs_dim=28224: saves ~56MB alloc+dealloc per step.
Effort: Low
Dependencies: None
Testing: Existing sample tests. Verify same results with sample vs sample_into.
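The reuse relies on Vec::clear() dropping the elements while keeping the allocated capacity. A minimal sketch of that behaviour, assuming a reduced `SampledBatch` with only two fields and a free-standing `sample_into` that takes explicit indices instead of a seed (both simplifications are mine, not the rlox API):

```rust
/// Hypothetical, reduced batch with two of the fields named above.
struct SampledBatch {
    observations: Vec<f32>,
    rewards: Vec<f32>,
}

impl SampledBatch {
    fn with_capacity(batch_size: usize, obs_dim: usize) -> Self {
        Self {
            observations: Vec::with_capacity(batch_size * obs_dim),
            rewards: Vec::with_capacity(batch_size),
        }
    }

    /// Drops the contents but keeps the heap allocations.
    fn clear(&mut self) {
        self.observations.clear();
        self.rewards.clear();
    }
}

/// Gather the transitions at `indices` from flat storage into `batch`.
fn sample_into(
    batch: &mut SampledBatch,
    storage_obs: &[f32],
    storage_rew: &[f32],
    obs_dim: usize,
    indices: &[usize],
) {
    batch.clear();
    for &i in indices {
        batch
            .observations
            .extend_from_slice(&storage_obs[i * obs_dim..(i + 1) * obs_dim]);
        batch.rewards.push(storage_rew[i]);
    }
}
```

As long as later calls request no more rows than the batch was sized for, the backing buffers are never reallocated, which is what a cached batch on the PyO3 wrapper would depend on.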
P5: PrioritizedReplayBuffer tree_min_prob O(N)
File: crates/rlox-core/src/buffer/priority.rs:320-337
Problem:
    fn tree_min_prob(&self) -> f64 {
        let mut min_p = f64::MAX;
        for i in 0..self.count {
            let p = self.tree.get(i);
            // ...
        }
    }
This linear scan runs on every sample() call. For 1M entries: 1M f64 reads per call.
Proposed fix: Add a parallel min-tree alongside the sum-tree:
    pub struct SumTree {
        capacity: usize,
        tree: Vec<f64>,     // sum tree
        min_tree: Vec<f64>, // min tree (new); length 2 * capacity, filled with f64::INFINITY
    }
    impl SumTree {
        pub fn set(&mut self, index: usize, value: f64) {
            // Update sum tree (existing)
            // ...
            // Update min tree (new): leaves live at index + capacity, root at 1
            let mut pos = index + self.capacity;
            self.min_tree[pos] = value;
            while pos > 1 {
                pos /= 2;
                self.min_tree[pos] = self.min_tree[2 * pos].min(self.min_tree[2 * pos + 1]);
            }
        }
        pub fn min(&self) -> f64 {
            // O(1); unused leaves must hold f64::INFINITY, or they win the min
            self.min_tree[1]
        }
    }
Impact: O(N) -> O(1) for tree_min_prob. O(log N) overhead added to each set().
Effort: Medium
Dependencies: None
Testing: Extend existing SumTree proptests to verify min property.
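To make the min property concrete, here is a standalone min-tree sketch that could back such a proptest. `MinTree` is a hypothetical type separate from the real SumTree; it assumes the same layout as the proposal (leaves at `index + capacity`, root at index 1) and initializes every slot to positive infinity so that unwritten leaves never become the minimum.

```rust
/// Iterative min segment tree: leaves at [capacity, 2 * capacity), root at index 1.
struct MinTree {
    capacity: usize,
    tree: Vec<f64>,
}

impl MinTree {
    fn new(capacity: usize) -> Self {
        assert!(capacity > 0, "capacity must be positive");
        // +infinity is the identity for min, so empty slots are ignored.
        Self {
            capacity,
            tree: vec![f64::INFINITY; 2 * capacity],
        }
    }

    /// Write one leaf and repair the path to the root: O(log N).
    fn set(&mut self, index: usize, value: f64) {
        assert!(index < self.capacity, "index out of range");
        let mut pos = index + self.capacity;
        self.tree[pos] = value;
        while pos > 1 {
            pos /= 2;
            self.tree[pos] = self.tree[2 * pos].min(self.tree[2 * pos + 1]);
        }
    }

    /// Minimum over all written leaves: O(1).
    fn min(&self) -> f64 {
        self.tree[1]
    }
}
```

Because min is commutative, this bottom-up layout works for any positive capacity, not only powers of two. A property test can compare `min()` against a plain fold over the values written so far.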
P6: Release profile optimizations (QUICK WIN)
File: Workspace Cargo.toml
Problem: No [profile.release] section. Default codegen-units=16, no LTO, no target-cpu.
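Proposed fix: Add a [profile.release] section to the workspace Cargo.toml that sets codegen-units=1 and enables LTO, and add a .cargo/config.toml that sets target-cpu=native.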
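The two files could look like the fragments below. This is a sketch based only on the settings named in this item; the exact values the project ends up with (for example fat versus thin LTO) are an assumption and worth benchmarking.

```toml
# Workspace Cargo.toml
[profile.release]
codegen-units = 1   # one codegen unit per crate: slower builds, better optimization
lto = "fat"         # whole-program LTO across crates; "thin" is a cheaper alternative
```

```toml
# .cargo/config.toml
[build]
rustflags = ["-C", "target-cpu=native"]
```

Note that target-cpu=native ties the binary to the build machine's instruction set, so it suits local benchmarking better than wheels meant for distribution.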
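Impact: 5-15% across all numeric code. codegen-units=1 lets LLVM optimize each crate as a single unit, which improves inlining within the crate. LTO enables cross-crate optimization. target-cpu=native lets the compiler use every instruction set the build machine supports (for example AVX2 or AVX-512) for auto-vectorization, at the cost of binaries that may not run on older CPUs.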
Effort: Trivial (add 4 lines)
Dependencies: None
Testing: Run existing benchmarks before/after to measure.
P7: GAE inner loop branch (QUICK WIN)
File: crates/rlox-core/src/training/gae.rs:26-36
Problem:
    for t in (0..n).rev() {
        let next_value = if t == n - 1 { last_value } else { values[t + 1] }; // branch every iteration
        // ...
    }
Proposed fix:
    // Handle last step outside loop.
    // Requires n > 0: return early on empty input, otherwise n - 1 underflows.
    let next_non_terminal = 1.0 - dones[n - 1];
    let delta = rewards[n - 1] + gamma * last_value * next_non_terminal - values[n - 1];
    let mut last_gae = delta;
    advantages[n - 1] = last_gae;
    // Main loop: no branch
    for t in (0..n - 1).rev() {
        let next_non_terminal = 1.0 - dones[t];
        let delta = rewards[t] + gamma * values[t + 1] * next_non_terminal - values[t];
        last_gae = delta + gamma * gae_lambda * next_non_terminal * last_gae;
        advantages[t] = last_gae;
    }
Impact: 1-3%. Removes branch prediction miss risk in tight numerical loop.
Effort: Trivial
Dependencies: None
Testing: Existing GAE tests + proptests.
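A quick way to check the rewrite is to run both loop shapes on the same input. The functions below are illustrative stand-ins (the names `gae_with_branch` and `gae_peeled` are mine), written from the two snippets in this item, with an explicit guard for empty input added to the peeled version.

```rust
/// Reference version: chooses the bootstrap value with a branch on every iteration.
fn gae_with_branch(
    rewards: &[f64], values: &[f64], dones: &[f64],
    last_value: f64, gamma: f64, lam: f64,
) -> Vec<f64> {
    let n = rewards.len();
    let mut adv = vec![0.0; n];
    let mut last_gae = 0.0;
    for t in (0..n).rev() {
        let next_value = if t == n - 1 { last_value } else { values[t + 1] };
        let non_terminal = 1.0 - dones[t];
        let delta = rewards[t] + gamma * next_value * non_terminal - values[t];
        last_gae = delta + gamma * lam * non_terminal * last_gae;
        adv[t] = last_gae;
    }
    adv
}

/// Peeled version: the final step is handled before the loop,
/// so the loop body never tests `t == n - 1`.
fn gae_peeled(
    rewards: &[f64], values: &[f64], dones: &[f64],
    last_value: f64, gamma: f64, lam: f64,
) -> Vec<f64> {
    let n = rewards.len();
    let mut adv = vec![0.0; n];
    if n == 0 {
        return adv; // also avoids the usize underflow in `n - 1`
    }
    let non_terminal = 1.0 - dones[n - 1];
    let mut last_gae = rewards[n - 1] + gamma * last_value * non_terminal - values[n - 1];
    adv[n - 1] = last_gae;
    for t in (0..n - 1).rev() {
        let non_terminal = 1.0 - dones[t];
        let delta = rewards[t] + gamma * values[t + 1] * non_terminal - values[t];
        last_gae = delta + gamma * lam * non_terminal * last_gae;
        adv[t] = last_gae;
    }
    adv
}
```

Both versions perform the same floating-point operations in the same order, so their outputs should match exactly rather than only within a tolerance.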
P8: Rayon PAR_THRESHOLD tuning (QUICK WIN)
File: crates/rlox-core/src/llm/ops.rs:124
Problem: const PAR_THRESHOLD: usize = 16; gates parallelism on batch size alone, so Rayon's dispatch overhead (~1-10us) may dominate when the sequences are short.
Proposed fix:
    // Gate on total work, not just batch size
    let total_elements = log_probs_policy.len();
    const PAR_ELEMENT_THRESHOLD: usize = 4096;
    let out = if total_elements >= PAR_ELEMENT_THRESHOLD {
        use rayon::prelude::*;
        (0..batch_size).into_par_iter().map(kl_for_seq).collect()
    } else {
        (0..batch_size).map(kl_for_seq).collect()
    };
Also: parallelize compute_batch_group_advantages (currently serial).
Impact: Prevents Rayon overhead on small inputs. Adds parallelism to GRPO.
Effort: Trivial
Dependencies: None
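The gating idea can be tried without Rayon. The sketch below swaps Rayon for std scoped threads so it stays dependency-free, and uses a row sum as a placeholder for the per-sequence KL computation; `row_sums` and `sum_row` are hypothetical names, and the 4096 cutoff is simply the value proposed above, not a measured optimum.

```rust
use std::thread;

/// Cutoff taken from the proposal above; the right value needs benchmarking.
const PAR_ELEMENT_THRESHOLD: usize = 4096;

/// Placeholder for the per-sequence work (the real code computes a KL term).
fn sum_row(row: &[f64]) -> f64 {
    row.iter().sum()
}

/// Reduce each row of a flat [batch_size * seq_len] buffer.
/// Threads are used only when the total element count clears the threshold.
fn row_sums(data: &[f64], seq_len: usize) -> Vec<f64> {
    assert!(seq_len > 0 && data.len() % seq_len == 0);
    let batch_size = data.len() / seq_len;

    // Small inputs: stay serial so thread overhead cannot dominate.
    if data.len() < PAR_ELEMENT_THRESHOLD || batch_size < 2 {
        return data.chunks_exact(seq_len).map(sum_row).collect();
    }

    // Large inputs: split the rows in two, one half per thread.
    let mut out = vec![0.0; batch_size];
    let half = batch_size / 2;
    let (out_lo, out_hi) = out.split_at_mut(half);
    let (in_lo, in_hi) = data.split_at(half * seq_len);
    thread::scope(|s| {
        s.spawn(|| {
            for (o, row) in out_lo.iter_mut().zip(in_lo.chunks_exact(seq_len)) {
                *o = sum_row(row);
            }
        });
        // The calling thread handles the second half.
        for (o, row) in out_hi.iter_mut().zip(in_hi.chunks_exact(seq_len)) {
            *o = sum_row(row);
        }
    });
    out
}
```

The point is the condition, not the threading mechanism: a batch of 16 sequences of length 8 is 128 elements and stays serial, while 8 sequences of length 1024 cross the threshold.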
P9: V-trace double exp() (QUICK WIN)
File: crates/rlox-core/src/training/vtrace.rs:49-50
Problem:
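    let rho_t = rho_bar.min(log_rhos[t].exp()); // exp() call 1
    let c_t = c_bar.min(log_rhos[t].exp()); // exp() call 2 (identical)
Proposed fix: Compute log_rhos[t].exp() once per timestep and reuse the result for both clamps.
Impact: ~20 cycles saved per timestep. For 10K steps: ~200K cycles, or roughly 50-100us per vtrace call at 2-4 GHz.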
Effort: Trivial
Dependencies: None
Testing: Existing vtrace tests.
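As a sketch of the single-exp() version, the helper below computes both clipped weights for one timestep. `clipped_weights` is a hypothetical name; in vtrace.rs the same two lines would sit inline in the loop over `t`.

```rust
/// Clipped importance weights for one timestep, computing exp() once.
/// Returns (rho_t, c_t).
fn clipped_weights(log_rho: f64, rho_bar: f64, c_bar: f64) -> (f64, f64) {
    let rho = log_rho.exp(); // single exp(), shared by both clamps
    (rho_bar.min(rho), c_bar.min(rho))
}
```

It is worth benchmarking before and after, since the optimizer may already merge the two identical calls.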
P10: HashMap::new() per CartPole step (QUICK WIN)
File: crates/rlox-core/src/env/builtins.rs:137
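Problem: Every CartPole step constructs a fresh HashMap for Transition.info.
Proposed fix: Change Transition.info to Option<HashMap<String, f64>> and return None when there is nothing to report.
Note that HashMap::new() itself, like HashMap::with_capacity(0), starts at capacity 0 and does not heap-allocate until the first insert, so measure before and after to confirm the saving.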
Impact: Removes 1 allocation per environment per step.
Effort: Trivial
Dependencies: Requires updating all code that reads transition.info.
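The Option-based field and the reader-side change it forces might look like this. The `Transition` below is a reduced, hypothetical version with only three fields, and `step_without_info` and `info_value` are invented helpers for illustration.

```rust
use std::collections::HashMap;

/// Hypothetical, reduced Transition; the real struct has more fields.
struct Transition {
    reward: f64,
    done: bool,
    /// None on the common path; Some only when there is something to report.
    info: Option<HashMap<String, f64>>,
}

/// Common-path constructor: no map is built at all.
fn step_without_info(reward: f64, done: bool) -> Transition {
    Transition { reward, done, info: None }
}

/// Readers treat a missing map the same way as a missing key.
fn info_value(t: &Transition, key: &str) -> Option<f64> {
    t.info.as_ref().and_then(|m| m.get(key).copied())
}
```

Every existing read of `transition.info` needs the same `as_ref()` treatment, which is the update the dependency note refers to.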
P11: PyO3 sample() missing allow_threads (QUICK WIN)
File: crates/rlox-python/src/buffer.rs (ReplayBuffer and PrioritizedReplayBuffer sample methods)
Problem: sample() holds the GIL during Rust computation.
Proposed fix:
    fn sample<'py>(&self, py: Python<'py>, batch_size: usize, seed: u64) -> PyResult<...> {
        let batch = py.allow_threads(|| {
            self.inner.sample(batch_size, seed)
        }).map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
        // Build Python dict from batch (needs GIL)
        // ...
    }
Impact: Enables Python threads to run during sampling.
Effort: Trivial
Dependencies: None
P12: ExperienceRecord forces ownership
File: crates/rlox-core/src/buffer/mod.rs:9-19
Problem:
    pub struct ExperienceRecord {
        pub obs: Vec<f32>,      // owned, always cloned
        pub next_obs: Vec<f32>, // owned, always cloned
        pub action: Vec<f32>,   // owned, always cloned
        // ...
    }
Root cause of P3 (double-copy in push).
Proposed fix: Keep ExperienceRecord for the public API but add the push_slices path from P3 as the primary internal API. ExperienceRecord becomes a convenience wrapper:
    impl ReplayBuffer {
        pub fn push(&mut self, record: ExperienceRecord) -> Result<(), RloxError> {
            self.push_slices(&record.obs, &record.next_obs, &record.action,
                record.reward, record.terminated, record.truncated)
        }
    }
Impact: Combined with P3, eliminates all unnecessary copies in the push path.
Effort: Medium (API change)
Dependencies: Implement alongside P3.
P13: current_obs re-allocation per step in AsyncCollector
File: crates/rlox-core/src/pipeline/collector.rs:143-147
Problem:
    current_obs = transition
        .obs
        .into_iter()
        .flat_map(|o| o.into_iter())
        .collect(); // re-allocates every step
Proposed fix:
    // Before loop:
    let mut current_obs = vec![0.0f32; n_envs * obs_dim];
    // In loop:
    let mut offset = 0;
    for obs_vec in &transition.obs {
        current_obs[offset..offset + obs_dim].copy_from_slice(obs_vec);
        offset += obs_dim;
    }
Impact: Eliminates 1 Vec allocation per step.
Effort: Low
Dependencies: None (but benefits from P2 flat obs layout)
Implementation Order
Phase 1 — Quick Wins (1-2 hours total): P6 (Cargo profiles) -> P9 (vtrace exp) -> P7 (GAE branch) -> P10 (HashMap) -> P11 (sample GIL) -> P8 (threshold tuning)
Phase 2 — Buffer Performance (1 day): P3 + P12 (push_slices) -> P4 (sample_into) -> P5 (min-tree)
Phase 3 — Pipeline & VecEnv (2-3 days): P1 (AsyncCollector GAE) -> P13 (current_obs) -> P2 (flat obs layout)