rlox Examples¶
Quick Start¶
Train PPO on CartPole (2 lines)¶
from rlox import Trainer
metrics = Trainer("ppo", env="CartPole-v1", seed=42).train(total_timesteps=100_000)
CLI¶
python -m rlox train --algo ppo --env CartPole-v1 --timesteps 100000
python -m rlox train --algo sac --env Pendulum-v1 --timesteps 50000
python -m rlox eval --algo ppo --checkpoint model.pt --env CartPole-v1
RL Algorithms¶
PPO with Callbacks¶
from rlox.algorithms.ppo import PPO
from rlox.callbacks import ProgressBarCallback, TimingCallback
from rlox.logging import ConsoleLogger
ppo = PPO(
env_id="CartPole-v1",
n_envs=8,
learning_rate=3e-4,
n_steps=128,
n_epochs=4,
callbacks=[ProgressBarCallback(), TimingCallback()],
logger=ConsoleLogger(log_interval=1000),
seed=42,
)
metrics = ppo.train(total_timesteps=100_000)
# Check training phase breakdown
timing_cb = ppo.callbacks.callbacks[1]
print(timing_cb.summary())
# {'env_step': 45.2, 'gae_compute': 5.1, 'gradient_update': 49.7}
SAC on MuJoCo¶
from rlox.algorithms.sac import SAC
sac = SAC(
env_id="HalfCheetah-v4",
buffer_size=1_000_000,
learning_rate=3e-4,
batch_size=256,
tau=0.005,
gamma=0.99,
learning_starts=10_000,
hidden=256,
seed=42,
)
metrics = sac.train(total_timesteps=1_000_000)
# Get actions for evaluation
import numpy as np
obs = np.zeros(17, dtype=np.float32)
action = sac.predict(obs, deterministic=True) # scaled to env range
# Save/load
sac.save("sac_halfcheetah.pt")
SAC with Custom Environment¶
import gymnasium as gym
from rlox.algorithms.sac import SAC
# Pass a pre-constructed environment
env = gym.make("Pendulum-v1", g=5.0) # custom gravity
sac = SAC(env_id=env, learning_starts=1000)
sac.train(total_timesteps=50_000)
Multi-Environment Off-Policy Training¶
All off-policy algorithms support parallel data collection with n_envs:
from rlox.algorithms.sac import SAC
from rlox.algorithms.td3 import TD3
from rlox.algorithms.dqn import DQN
# SAC with 4 parallel environments
sac = SAC(env_id="HalfCheetah-v4", n_envs=4, learning_starts=10_000)
metrics = sac.train(total_timesteps=1_000_000)
# TD3 with 4 parallel environments
td3 = TD3(env_id="Pendulum-v1", n_envs=4, learning_starts=1000)
metrics = td3.train(total_timesteps=50_000)
# DQN with 8 parallel environments
dqn = DQN(env_id="CartPole-v1", n_envs=8, learning_starts=1000)
metrics = dqn.train(total_timesteps=100_000)
Custom Collector with Exploration¶
import rlox
from rlox.algorithms.sac import SAC
from rlox.off_policy_collector import OffPolicyCollector
from rlox.exploration import GaussianNoise
# Share buffer between collector and algorithm
buf = rlox.ReplayBuffer(1_000_000, obs_dim=3, act_dim=1)
collector = OffPolicyCollector(
env_id="Pendulum-v1",
n_envs=4,
buffer=buf,
exploration=GaussianNoise(sigma=0.1, clip=0.3),
)
sac = SAC(env_id="Pendulum-v1", buffer=buf, collector=collector)
metrics = sac.train(total_timesteps=50_000)
TD3 on Pendulum¶
from rlox.algorithms.td3 import TD3
td3 = TD3(
env_id="Pendulum-v1",
learning_rate=3e-4,
policy_delay=2,
target_noise=0.2,
noise_clip=0.5,
exploration_noise=0.1,
)
metrics = td3.train(total_timesteps=50_000)
action = td3.predict(obs)
DQN with Prioritized Experience Replay¶
from rlox.algorithms.dqn import DQN
dqn = DQN(
env_id="CartPole-v1",
double_dqn=True,
dueling=True,
prioritized=True,
n_step=3,
buffer_size=100_000,
learning_starts=1000,
)
metrics = dqn.train(total_timesteps=100_000)
action = dqn.predict(obs) # returns int
A2C¶
from rlox.algorithms.a2c import A2C
a2c = A2C(env_id="CartPole-v1", n_envs=8, learning_rate=7e-4)
metrics = a2c.train(total_timesteps=50_000)
A2CTrainer (High-Level)¶
TD3Trainer (High-Level)¶
MAPPO (Multi-Agent)¶
Multi-agent PPO with centralised critic. Works with PettingZoo environments.
from rlox import Trainer
trainer = Trainer("mappo", env="spread_v3", config={"n_agents": 3}, seed=42)
metrics = trainer.train(total_timesteps=500_000)
DreamerV3Trainer (World Model)¶
Learns a latent dynamics model and trains the policy inside the learned world model. Particularly effective for image-based and sparse-reward environments.
from rlox import Trainer
trainer = Trainer("dreamer", env="Pendulum-v1", seed=42)
metrics = trainer.train(total_timesteps=200_000)
IMPALATrainer (Distributed Actors)¶
Distributed actor-learner architecture with V-trace off-policy correction. Actors collect experience in parallel (optionally across machines via gRPC) while a central learner trains the policy.
from rlox import Trainer
trainer = Trainer("impala", env="CartPole-v1", n_actors=8, seed=42)
metrics = trainer.train(total_timesteps=1_000_000)
Config-Driven Training¶
Define experiments in YAML and launch without writing Python scripts:
# experiment.yaml
algorithm: ppo
env: CartPole-v1
total_timesteps: 100_000
seed: 42
config:
n_envs: 16
learning_rate: 3.0e-4  # note: a bare "3e-4" is parsed as a string by YAML 1.1 loaders (e.g. PyYAML)
VecNormalize¶
Running observation and reward normalisation for vectorised environments:
from rlox import Trainer
from rlox.wrappers import VecNormalize
trainer = Trainer(
"ppo", env="HalfCheetah-v4",
config={"normalize_obs": True, "normalize_rewards": True},
seed=42,
)
metrics = trainer.train(total_timesteps=1_000_000)
Diagnostics Dashboard & HTML Report¶
from rlox.dashboard import MetricsCollector, HTMLReport, TerminalDashboard
from rlox import Trainer
# Collect metrics and generate an HTML report after training
collector = MetricsCollector()
trainer = Trainer("ppo", env="CartPole-v1", callbacks=[collector], seed=42)
trainer.train(total_timesteps=50_000)
report = HTMLReport(collector)
report.save("training_report.html")
# Or use a live terminal dashboard during training
trainer = Trainer(
"ppo", env="CartPole-v1",
callbacks=[TerminalDashboard()],
seed=42,
)
trainer.train(total_timesteps=50_000)
LLM Post-Training¶
See llm-post-training.md for the full guide.
DPO Quick Example¶
from rlox.algorithms.dpo import DPO
dpo = DPO(model=model, ref_model=ref_model, beta=0.1)
for prompt, chosen, rejected in dataset:
metrics = dpo.train_step(prompt, chosen, rejected)
GRPO Quick Example¶
from rlox.algorithms.grpo import GRPO
grpo = GRPO(
model=model, ref_model=ref_model,
reward_fn=my_reward_fn, group_size=4,
)
for prompts in dataloader:
metrics = grpo.train_step(prompts)
Core Primitives¶
Environment Stepping¶
import rlox
# Native Rust CartPole (fastest)
env = rlox.VecEnv(n=1024, seed=42)
result = env.step_all(actions) # actions: numpy uint32 array
# result["obs"]: (1024, 4) float32
# Gymnasium wrapper (any env)
from rlox import GymVecEnv
env = GymVecEnv("HalfCheetah-v4", n_envs=8)
result = env.step_all(actions)
use rlox_core::env::builtins::CartPole;
use rlox_core::env::parallel::VecEnv;
use rlox_core::env::spaces::Action;
use rlox_core::env::RLEnv;
use rlox_core::seed::derive_seed;
// Create 64 parallel CartPole environments
let envs: Vec<Box<dyn RLEnv>> = (0..64)
.map(|i| Box::new(CartPole::new(Some(derive_seed(42, i)))) as _)
.collect();
let mut vec_env = VecEnv::new(envs).unwrap();
let observations = vec_env.reset_all(Some(42)).unwrap();
let actions: Vec<Action> = (0..64)
.map(|i| Action::Discrete((i % 2) as u32))
.collect();
let batch = vec_env.step_all(&actions).unwrap();
assert_eq!(batch.obs.len(), 64);
Continuous Actions (Pendulum)¶
use rlox_core::env::builtins::Pendulum;
use rlox_core::env::spaces::Action;
use rlox_core::env::RLEnv;
let mut env = Pendulum::new(Some(42));
let obs = env.reset(Some(42)).unwrap();
println!("obs: {:?}", obs.as_slice()); // [cos θ, sin θ, ω]
let t = env.step(&Action::Continuous(vec![1.5])).unwrap();
println!("reward: {:.2}", t.reward);
GAE Computation¶
import rlox
import numpy as np
rewards = np.random.randn(2048)
values = np.random.randn(2048)
dones = np.zeros(2048)
adv, ret = rlox.compute_gae(rewards, values, dones,
last_value=0.0, gamma=0.99, lam=0.95)
# Batched (8 envs, Rayon-parallel)
rewards_flat = np.random.randn(8 * 2048)
values_flat = np.random.randn(8 * 2048)
dones_flat = np.zeros(8 * 2048)
last_vals = np.random.randn(8)
adv, ret = rlox.compute_gae_batched(rewards_flat, values_flat,
dones_flat, last_vals, 2048, 0.99, 0.95)
use rlox_core::training::gae::compute_gae;
let rewards = &[1.0, 1.0, 1.0, 0.0, 1.0];
let values = &[0.5, 0.6, 0.7, 0.3, 0.8];
let dones = &[0.0, 0.0, 0.0, 1.0, 0.0];
let (advantages, returns) = compute_gae(
rewards, values, dones,
0.9, // last_value
0.99, // gamma
0.95, // gae_lambda
);
// Invariant: returns[t] == advantages[t] + values[t]
for t in 0..5 {
assert!((returns[t] - (advantages[t] + values[t])).abs() < 1e-10);
}
Replay Buffers¶
import rlox
import numpy as np
buf = rlox.ReplayBuffer(capacity=100_000, obs_dim=4, act_dim=1)
obs = np.zeros(4, dtype=np.float32)
buf.push(obs, np.zeros(1), reward=1.0, terminated=False,
truncated=False, next_obs=obs)
batch = buf.sample(batch_size=256, seed=42)
# batch["obs"], batch["actions"], batch["rewards"], ...
# Prioritized replay
pbuf = rlox.PrioritizedReplayBuffer(100_000, 4, 1, alpha=0.6, beta=0.4)
pbuf.push(obs, np.zeros(1), 1.0, False, False, obs, priority=1.0)
batch = pbuf.sample(256, seed=42)
# batch["weights"], batch["indices"]
use rlox_core::buffer::ringbuf::ReplayBuffer;
use rlox_core::buffer::priority::PrioritizedReplayBuffer;
use rlox_core::buffer::ExperienceRecord;
// Uniform replay buffer
let mut buf = ReplayBuffer::new(100_000, 4, 1);
buf.push(ExperienceRecord {
obs: vec![0.1, 0.2, 0.3, 0.4],
next_obs: vec![0.2, 0.3, 0.4, 0.5],
action: vec![0.0],
reward: 1.0,
terminated: false,
truncated: false,
}).unwrap();
let batch = buf.sample(32, 42).unwrap();
assert_eq!(batch.batch_size, 32);
// Prioritized replay
let mut per = PrioritizedReplayBuffer::new(100_000, 4, 1, 0.6, 0.4);
per.push(ExperienceRecord {
obs: vec![0.1; 4], next_obs: vec![0.2; 4],
action: vec![0.0], reward: 1.0,
terminated: false, truncated: false,
}, 1.0).unwrap();
Reward Shaping (PBRS)¶
import rlox
import numpy as np
rewards = np.array([1.0, 1.0, 1.0])
phi = np.array([0.5, 0.6, 0.7]) # potential of current state
phi_next = np.array([0.6, 0.7, 0.8]) # potential of next state
dones = np.array([0.0, 0.0, 1.0])
shaped = rlox.shape_rewards_pbrs(rewards, phi, phi_next, gamma=0.99, dones=dones)
use rlox_core::training::reward_shaping::shape_rewards_pbrs;
let rewards = &[1.0, 1.0, 1.0];
let phi = &[0.5, 0.6, 0.7];
let phi_next = &[0.6, 0.7, 0.8];
let dones = &[0.0, 0.0, 1.0];
let shaped = shape_rewards_pbrs(rewards, phi, phi_next, 0.99, dones).unwrap();
// done=1 → raw reward only (no shaping at episode boundary)
assert!((shaped[2] - 1.0).abs() < 1e-10);
Weight Operations (Meta-Learning)¶
use rlox_core::training::weight_ops::{reptile_update, polyak_update};
let mut meta = vec![1.0f32, 2.0, 3.0];
let task = vec![4.0f32, 5.0, 6.0];
reptile_update(&mut meta, &task, 0.1);
// meta is now [1.3, 2.3, 3.3]
let mut target = vec![0.0f32; 3];
let source = vec![1.0f32; 3];
polyak_update(&mut target, &source, 0.005);
KL Divergence¶
Plugin Ecosystem¶
Register a Custom Environment¶
from rlox.plugins import ENV_REGISTRY
from rlox import Trainer
# Define a custom environment factory
def make_custom_env():
import gymnasium as gym
env = gym.make("CartPole-v1")
# Add custom wrappers, reward shaping, etc.
return env
# Register it
ENV_REGISTRY.register("my-custom-env", make_custom_env)
# Use it with any Trainer
trainer = Trainer("ppo", env="my-custom-env", seed=42)
metrics = trainer.train(total_timesteps=50_000)
Discover Plugins from Installed Packages¶
from rlox.plugins import discover_plugins, ENV_REGISTRY
# Auto-discover plugins from all installed packages
discover_plugins()
# Now any plugin-registered environments are available
print(ENV_REGISTRY.list()) # shows all registered env names
Visual RL¶
FrameStack + ImagePreprocess¶
import gymnasium as gym
from rlox.wrappers.visual import FrameStack, ImagePreprocess
from rlox import Trainer
# Standard visual RL preprocessing pipeline
env = gym.make("ALE/Breakout-v5", render_mode="rgb_array")
env = ImagePreprocess(env, width=84, height=84, grayscale=True)
env = FrameStack(env, n_frames=4)
# Train with preprocessed pixel observations
trainer = Trainer("dqn", env=env, config={"buffer_size": 100_000}, seed=42)
metrics = trainer.train(total_timesteps=1_000_000)
AtariWrapper (All-in-One)¶
import gymnasium as gym
from rlox.wrappers.visual import AtariWrapper
# AtariWrapper applies: NoopReset, MaxAndSkip, EpisodicLife,
# FireReset, ClipReward, ImagePreprocess, FrameStack
env = AtariWrapper(gym.make("ALE/Pong-v5"), frame_stack=4)
Cloud Deploy¶
Generate a Dockerfile¶
from rlox.deploy import generate_dockerfile
dockerfile = generate_dockerfile(
checkpoint_path="checkpoints/ppo_cartpole.pt",
algorithm="ppo",
env="CartPole-v1",
)
with open("Dockerfile", "w") as f:
f.write(dockerfile)
Generate a Kubernetes Job¶
from rlox.deploy import generate_k8s_job
manifest = generate_k8s_job(
name="rlox-ppo-training",
image="my-registry/rlox:latest",
gpu=1,
memory="16Gi",
)
with open("k8s-job.yaml", "w") as f:
f.write(manifest)
Monitoring & Profiling¶
Console Logger¶
from rlox.logging import ConsoleLogger
logger = ConsoleLogger(log_interval=500)
# Prints: step=500 | SPS=1234 | reward=45.20
W&B Integration¶
from rlox.logging import WandbLogger
logger = WandbLogger(project="my-rl-project", config={"algo": "ppo"})
TensorBoard¶
Profiling with TimingCallback¶
from rlox.callbacks import TimingCallback
timing = TimingCallback()
ppo = PPO(env_id="CartPole-v1", callbacks=[timing])
ppo.train(50_000)
print(timing.summary())
# {'env_step': 42.1, 'gae_compute': 8.3, 'gradient_update': 49.6}