# Hybrid Rust Collection + Candle Inference Architecture

## Motivation
Profiling reveals PyTorch dispatch overhead dominates PPO training time:
| Component | Latency | Share |
|---|---|---|
| Policy inference (PyTorch) | 123us/call | 56% of step |
| Actual compute | ~10us | — |
| Env step (Rust VecEnv) | 96us/call | 44% of step |
The root cause: each policy forward call crosses Python→PyTorch→Python boundaries with ~113us of dispatch overhead atop ~10us of actual BLAS work.
Solution: a three-phase approach that progressively eliminates the Python overhead.
## Three-Phase Roadmap

```mermaid
gantt
    title rlox Performance Roadmap
    dateFormat YYYY-MM-DD
    section Phase 1
    torch.compile integration :p1, 2026-03-28, 3d
    Auto-compile in trainers :p1b, after p1, 2d
    section Phase 2
    ONNX export from PyTorch :p2a, after p1b, 3d
    ONNX import to Candle :p2b, after p2a, 4d
    Candle inference server :p2c, after p2b, 3d
    section Phase 3
    Weight sync PyTorch→Candle :p3a, after p2c, 3d
    Candle collection loop :p3b, after p3a, 5d
    CUDA backend integration :p3c, after p3b, 4d
    Python PPO integration :p3d, after p3c, 3d
```
## Phase 1: torch.compile (Immediate, 15-30% speedup)

Status: `compile_policy()` already exists; it needs auto-application in the trainers.
```mermaid
flowchart LR
    A[PPO/SAC created] --> B{compile flag?}
    B -->|yes| C[torch.compile on policy networks]
    B -->|no| D[standard PyTorch]
    C --> E[15-30% faster inference]
    D --> F[baseline speed]
```
### Changes needed

- **Auto-compile in algorithm constructors:** add a `compile: bool = False` parameter.
- **Fix compile targets:** already done; `get_action_and_logprob`, `get_value`, and `get_logprob_and_entropy` are compiled individually for on-policy policies.
- **Benchmark:** measure SPS before/after on CartPole and HalfCheetah.
- Files: `python/rlox/algorithms/ppo.py`, `sac.py`, `td3.py`, `dqn.py`
- Effort: 1 day
- Impact: 15-30% wall-time reduction
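The benchmark step can be driven by a small, framework-agnostic harness. A minimal sketch (the `measure_sps` helper and its arguments are hypothetical, not part of rlox):

```python
import time


def measure_sps(step_fn, n_steps=1000, n_envs=8, warmup=100):
    """Steps-per-second for a callable that performs one vectorized
    step (policy forward + env step for n_envs environments)."""
    for _ in range(warmup):  # lets torch.compile finish its first-call tracing
        step_fn()
    start = time.perf_counter()
    for _ in range(n_steps):
        step_fn()
    elapsed = time.perf_counter() - start
    return n_steps * n_envs / elapsed


# Dummy step function standing in for policy + env:
sps = measure_sps(lambda: sum(range(100)), n_steps=1000)
print(f"{sps:,.0f} steps/s")
```

Running the same harness with and without `compile=True` gives a like-for-like before/after SPS comparison.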
## Phase 2: ONNX Export → Candle Import (Deployment)
Enable PyTorch-trained models to be deployed via Candle for inference without Python.
```mermaid
flowchart LR
    A[Train in PyTorch] --> B[Export to ONNX]
    B --> C[Load in Candle]
    C --> D[Rust inference server]
    D --> E[Microsecond-scale latency]
    subgraph Python Training
        A
        B
    end
    subgraph Rust Deployment
        C
        D
        E
    end
```
### 2.1 ONNX Export from PyTorch

File: new `python/rlox/export.py`

```python
import torch


def export_policy_onnx(algo, path: str, obs_dim: int):
    """Export a trained policy to ONNX format (assumes the model is on CPU)."""
    dummy_input = torch.randn(1, obs_dim)
    if hasattr(algo, 'actor'):  # SAC/TD3
        torch.onnx.export(algo.actor, dummy_input, path,
                          input_names=['obs'], output_names=['action'])
    elif hasattr(algo, 'policy'):  # PPO/A2C
        torch.onnx.export(algo.policy.actor, dummy_input, path,
                          input_names=['obs'], output_names=['logits'])
    elif hasattr(algo, 'q_network'):  # DQN
        torch.onnx.export(algo.q_network, dummy_input, path,
                          input_names=['obs'], output_names=['q_values'])
    else:
        raise ValueError("algo has no exportable network")
```
### 2.2 Candle ONNX Import

File: `crates/rlox-candle/src/onnx.rs`

Candle has built-in ONNX support via `candle-onnx` (`read_file` parses the model, `simple_eval` runs it):

```rust
use std::collections::HashMap;

use candle_core::{Device, Result, Tensor};
use candle_onnx::onnx::ModelProto;

pub struct OnnxPolicy {
    model: ModelProto,
    pub(crate) device: Device,
}

impl OnnxPolicy {
    pub fn load(path: &str, device: Device) -> Result<Self> {
        let model = candle_onnx::read_file(path)?;
        Ok(Self { model, device })
    }

    pub fn forward(&self, obs: &Tensor) -> Result<Tensor> {
        let inputs = HashMap::from([("obs".to_string(), obs.clone())]);
        // simple_eval returns a map from graph-output names to tensors;
        // the exported policies have a single output ("action", "logits", or "q_values")
        let outputs = candle_onnx::simple_eval(&self.model, inputs)?;
        outputs
            .into_values()
            .next()
            .ok_or_else(|| candle_core::Error::Msg("model produced no outputs".to_string()))
    }
}
```
### 2.3 Candle Inference Server

File: new `crates/rlox-candle/src/server.rs`

Pure-Rust inference endpoint:

```rust
use candle_core::{Result, Tensor};

use crate::onnx::OnnxPolicy;

pub struct CandleInferenceServer {
    policy: OnnxPolicy,
}

impl CandleInferenceServer {
    pub fn predict(&self, obs: &[f32]) -> Result<Vec<f32>> {
        let tensor = Tensor::from_slice(obs, (1, obs.len()), &self.policy.device)?;
        let output = self.policy.forward(&tensor)?;
        // flatten the [1, out_dim] output into a plain vector
        output.flatten_all()?.to_vec1()
    }
}
```
- Effort: 5-7 days
- Impact: PyTorch-free inference deployment, microsecond-scale policy latency
## Phase 3: Hybrid Collection Loop (Long-term, 2x PPO speedup)
Move the entire PPO collection loop into Rust with Candle inference.
```mermaid
sequenceDiagram
    participant Python as Python (SGD only)
    participant Rust as Rust Collection Loop
    participant Candle as Candle Policy
    participant VecEnv as Rust VecEnv
    participant GAE as Rust GAE
    Python->>Rust: collect(n_steps=128)
    Note over Python: GIL released
    loop 128 steps (pure Rust, no Python)
        Rust->>Candle: act(obs) → actions, log_probs
        Candle-->>Rust: actions, log_probs
        Rust->>Candle: get_value(obs) → values
        Candle-->>Rust: values
        Rust->>VecEnv: step_all(actions)
        VecEnv-->>Rust: next_obs, rewards, dones
    end
    Rust->>GAE: compute_gae_batched(...)
    GAE-->>Rust: advantages, returns
    Rust-->>Python: RolloutBatch (single FFI return)
    Note over Python: SGD in PyTorch
    Python->>Python: for epoch in n_epochs
    Python->>Python: loss.backward(), optimizer.step()
    Python->>Rust: sync_weights(flat_params)
    Note over Rust: scatter into Candle model
```
### 3.1 Weight Synchronization (PyTorch → Candle)
Strategy: Flat f32 buffer transfer. PyTorch concatenates all parameters into one contiguous array; Rust scatters them into Candle's Linear layers.
```mermaid
flowchart TD
    A[PyTorch model.parameters] --> B[torch.cat + flatten]
    B --> C[numpy array f32]
    C --> D[PyO3 zero-copy slice]
    D --> E[Rust scatter into Candle Linear layers]
    E --> F[Candle model ready for inference]
```
Python side:

```python
import torch


def get_flat_params(model):
    # .cpu() so this also works for CUDA-resident models
    return torch.cat([p.data.view(-1) for p in model.parameters()]).cpu().numpy()


# After each SGD phase:
rust_collector.sync_weights(get_flat_params(policy))
```
Rust side (`crates/rlox-candle/src/sync.rs`):

```rust
use candle_core::{Device, Tensor};

pub fn scatter_weights(model: &mut CandleActorCritic, flat: &[f32]) {
    let mut offset = 0;
    // named_parameters_mut() must yield parameters in the same order as
    // PyTorch's model.parameters() on the Python side
    for (_name, tensor) in model.named_parameters_mut() {
        let numel = tensor.elem_count();
        let slice = &flat[offset..offset + numel];
        let shape = tensor.shape().clone();
        *tensor = Tensor::from_slice(slice, shape, &Device::Cpu).unwrap();
        offset += numel;
    }
    debug_assert_eq!(offset, flat.len(), "weight size mismatch");
}
```
Cost: ~52KB memcpy for 64-hidden MLP, negligible vs rollout time.
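Before writing the Rust side, the flat-buffer layout can be prototyped in numpy. A sketch (the layer shapes are illustrative, chosen for a 17-dim observation, 6-dim action task, so the byte count differs slightly from the figure above):

```python
import numpy as np

# Parameter shapes for a hypothetical 64-hidden MLP actor-critic
# (weight then bias per Linear layer; actor followed by critic)
shapes = [(64, 17), (64,), (64, 64), (64,), (6, 64), (6,),   # actor
          (64, 17), (64,), (64, 64), (64,), (1, 64), (1,)]   # critic


def scatter(flat, shapes):
    """Split a flat f32 buffer back into per-parameter arrays,
    mirroring the offset bookkeeping in scatter_weights."""
    out, offset = [], 0
    for shape in shapes:
        numel = int(np.prod(shape))
        out.append(flat[offset:offset + numel].reshape(shape))
        offset += numel
    assert offset == flat.size, "weight size mismatch"
    return out


total = sum(int(np.prod(s)) for s in shapes)
print(f"{total} params, {total * 4 / 1024:.1f} KB")  # f32 buffer size
params = scatter(np.arange(total, dtype=np.float32), shapes)
```

The same shape list, serialized once at startup, is enough for the Rust side to validate the buffer length before scattering.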
### 3.2 Candle Collection Loop

File: new `crates/rlox-candle/src/collector.rs`
```rust
use candle_core::{Device, Result, Tensor};

pub struct CandleCollector {
    policy: CandleActorCritic,
    envs: Box<dyn BatchSteppable>,
    obs_normalizer: Option<RunningMeanStd>,
}

impl CandleCollector {
    pub fn collect(&mut self, n_steps: usize, gamma: f64, gae_lambda: f64) -> Result<RolloutBatch> {
        let n_envs = self.envs.num_envs();
        let obs_dim = self.envs.obs_dim(); // obs_dim assumed available from the env trait
        // Pre-allocate everything (initial reset / current_obs buffer setup elided)
        let mut all_obs = Vec::with_capacity(n_steps * n_envs * obs_dim);
        let mut all_actions = Vec::with_capacity(n_steps * n_envs);
        let mut all_log_probs = Vec::with_capacity(n_steps * n_envs);
        // ...
        for _step in 0..n_steps {
            // Candle inference (no Python, no GIL)
            let obs_tensor = Tensor::from_slice(&current_obs, (n_envs, obs_dim), &Device::Cpu)?;
            let (actions, log_probs) = self.policy.act(&obs_tensor)?;
            let values = self.policy.get_value(&obs_tensor)?;
            // Env step (Rust VecEnv)
            let transition = self.envs.step_batch(&action_vec)?;
            // Store data
            all_obs.extend_from_slice(&current_obs);
            // ...
        }
        // GAE in Rust
        let (advantages, returns) = compute_gae_batched(/* ... */);
        Ok(RolloutBatch { obs: all_obs, actions: all_actions, /* ... */ })
    }
}
```
### 3.3 CUDA Backend

Candle supports CUDA natively. The same collection loop works on GPU:

```rust
let device = Device::new_cuda(0)?; // or Device::Cpu
let policy = CandleActorCritic::new(obs_dim, act_dim, hidden, &device)?;
```

For GPU collection:

- Observations stay on GPU (no CPU→GPU transfer per step)
- Policy inference on GPU (batched CUDA kernels)
- Only the final RolloutBatch transfers to CPU for PyTorch SGD
### 3.4 PyO3 Integration

File: `crates/rlox-python/src/candle_collector.rs`

```rust
#[pyclass]
pub struct PyCandleCollector {
    inner: CandleCollector,
}

#[pymethods]
impl PyCandleCollector {
    #[new]
    fn new(env_id: &str, n_envs: usize, obs_dim: usize, act_dim: usize,
           hidden: usize, device: &str) -> PyResult<Self> { /* ... */ }

    fn collect<'py>(&mut self, py: Python<'py>, n_steps: usize,
                    gamma: f64, gae_lambda: f64) -> PyResult<Bound<'py, PyDict>> {
        // Release the GIL for the duration of the collection loop
        let batch = py
            .allow_threads(|| self.inner.collect(n_steps, gamma, gae_lambda))
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
        // Convert the batch to a Python dict
        // ...
    }

    fn sync_weights(&mut self, flat_params: PyReadonlyArray1<f32>) -> PyResult<()> {
        scatter_weights(&mut self.inner.policy, flat_params.as_slice()?);
        Ok(())
    }
}
```
### 3.5 Python PPO Integration

```python
from rlox import CandleCollector

collector = CandleCollector(
    env_id="CartPole-v1", n_envs=8,
    obs_dim=4, act_dim=2, hidden=64,
    device="cpu",  # or "cuda:0"
)

# Initial weight sync
collector.sync_weights(get_flat_params(policy))

for update in range(n_updates):
    # One FFI call: the entire collection runs in Rust+Candle
    batch = collector.collect(n_steps=128, gamma=0.99, gae_lambda=0.95)

    # SGD in PyTorch (unchanged)
    for epoch in range(n_epochs):
        for mb in batch.sample_minibatches(batch_size):
            loss = ppo_loss(policy, mb)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    # Sync updated weights back to Candle
    collector.sync_weights(get_flat_params(policy))
```
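The `sample_minibatches` call above is assumed to shuffle and slice the rollout PPO-style; its behavior can be sketched in numpy (hypothetical standalone helper, not the actual `RolloutBatch` API):

```python
import numpy as np


def sample_minibatches(batch_len, batch_size, rng=None):
    """Yield shuffled index arrays covering one rollout exactly once."""
    rng = rng or np.random.default_rng()
    idx = rng.permutation(batch_len)
    for start in range(0, batch_len, batch_size):
        yield idx[start:start + batch_size]


# 8 envs x 128 steps = 1024 transitions, split into 4 minibatches of 256
mbs = list(sample_minibatches(1024, 256))
assert len(mbs) == 4 and all(len(mb) == 256 for mb in mbs)
```

Because the indices partition the rollout, every transition is visited exactly once per epoch, matching standard PPO minibatching.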
## Performance Estimates

```mermaid
graph LR
    subgraph Current[Current: 219 us/step]
        A1[PyTorch inference<br>123 us] --> A2[VecEnv step<br>96 us]
    end
    subgraph Phase1[Phase 1: ~166 us/step]
        B1[torch.compile<br>~70 us] --> B2[VecEnv step<br>96 us]
    end
    subgraph Phase3[Phase 3: ~111 us/step]
        C1[Candle inference<br>~15 us] --> C2[VecEnv step<br>96 us]
    end
```
| Phase | Inference | Step | Total | PPO Speedup |
|---|---|---|---|---|
| Current | 123 us (PyTorch) | 96 us | 219 us/step | 1.0x |
| Phase 1 (torch.compile) | ~70 us | 96 us | ~166 us/step | 1.3x |
| Phase 3 (Candle CPU) | ~15 us | 96 us | ~111 us/step | 2.0x |
| Phase 3 (Candle CUDA) | ~5 us | 96 us | ~101 us/step | 2.2x |
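The totals and speedup columns follow directly from the per-component latencies; a quick arithmetic check:

```python
# (inference us, env-step us) per phase, from the profiling numbers above
phases = {
    "current": (123, 96),
    "phase1_compile": (70, 96),
    "phase3_candle_cpu": (15, 96),
    "phase3_candle_cuda": (5, 96),
}
baseline = sum(phases["current"])  # 219 us/step
for name, (infer, step) in phases.items():
    total = infer + step
    print(f"{name}: {total} us/step, {baseline / total:.1f}x")
```

Note the env step (96 us) becomes the dominant cost once inference drops below ~15 us, which caps the end-to-end speedup near 2x.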
## Dependencies

```mermaid
graph TD
    P1[Phase 1: torch.compile] --> P2[Phase 2: ONNX export/import]
    P2 --> P3A[Phase 3A: Weight sync]
    P3A --> P3B[Phase 3B: Candle collector]
    P3B --> P3C[Phase 3C: CUDA backend]
    P3C --> P3D[Phase 3D: Python integration]
    style P1 fill:#90EE90
    style P2 fill:#FFE4B5
    style P3A fill:#FFB6C1
    style P3B fill:#FFB6C1
    style P3C fill:#FFB6C1
    style P3D fill:#FFB6C1
```
## Cargo Dependencies

```toml
# crates/rlox-candle/Cargo.toml
[dependencies]
candle-core = "0.8"
candle-nn = "0.8"
candle-onnx = { version = "0.8", optional = true }          # Phase 2
candle-transformers = { version = "0.8", optional = true }  # for LLM inference

[features]
default = ["cpu"]
cpu = []
cuda = ["candle-core/cuda"]
collector = []              # Phase 3: collection loop
onnx = ["dep:candle-onnx"]  # Phase 2: ONNX import
```
## Files to Create/Modify

| Phase | File | Action |
|---|---|---|
| 1 | `python/rlox/algorithms/*.py` | Add `compile` param |
| 2 | `python/rlox/export.py` | New: ONNX export |
| 2 | `crates/rlox-candle/src/onnx.rs` | New: ONNX import |
| 2 | `crates/rlox-candle/src/server.rs` | New: inference server |
| 3 | `crates/rlox-candle/src/sync.rs` | New: weight sync |
| 3 | `crates/rlox-candle/src/collector.rs` | New: collection loop |
| 3 | `crates/rlox-python/src/candle_collector.rs` | New: PyO3 binding |
| 3 | `python/rlox/hybrid_ppo.py` | New: hybrid PPO trainer |
## Timeline

| Phase | Duration | Prerequisite |
|---|---|---|
| Phase 1 | 1-2 days | None |
| Phase 2 | 5-7 days | Phase 1 |
| Phase 3 | 10-15 days | Phase 2 |
| Total | 16-24 days | — |