Algorithms API¶
On-Policy¶
rlox.algorithms.ppo.PPO
¶
PPO trainer.
Auto-detects observation and action dimensions from the environment. Selects DiscretePolicy or ContinuousPolicy based on action space type. Supports callbacks, logging, and checkpoint save/load.
Parameters¶
env_id : str
Gymnasium environment ID.
n_envs : int
Number of parallel environments (default 8).
seed : int
Random seed.
policy : nn.Module, optional
Custom policy network. If None, auto-selects based on env.
logger : LoggerCallback, optional
Logger for metrics.
callbacks : list[Callback], optional
Training callbacks.
**config_kwargs
Override any PPOConfig fields.
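A minimal usage sketch based on the parameters above. Only the constructor arguments are documented here; the train(total_timesteps=...) call is assumed to follow the same convention as the MAPPO, IMPALA, and HybridPPO trainers below.

```python
from rlox.algorithms.ppo import PPO

# Extra keyword arguments are forwarded as PPOConfig overrides
# (e.g. n_steps, learning_rate).
agent = PPO(env_id="CartPole-v1", n_envs=8, seed=42, learning_rate=2.5e-4)

# Assumed entry point, mirroring the other trainers in this reference.
agent.train(total_timesteps=100_000)
```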
rlox.algorithms.a2c.A2C
¶
Synchronous A2C.
Like PPO, but with no ratio clipping, a single gradient step per rollout (no epochs), and a typically shorter n_steps.
Auto-detects observation and action dimensions from the environment.
Off-Policy¶
rlox.algorithms.sac.SAC
¶
Soft Actor-Critic.
Twin critics, squashed Gaussian policy, automatic entropy tuning. Uses rlox.ReplayBuffer for storage.
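A hedged sketch of the off-policy workflow. SAC's constructor and method signatures are not listed on this page, so the env_id/seed keywords and the train/predict calls below are assumed to follow the same conventions as the other trainers in this reference.

```python
import gymnasium as gym
from rlox.algorithms.sac import SAC

# Continuous-control task; transitions are stored in rlox.ReplayBuffer internally.
agent = SAC(env_id="Pendulum-v1", seed=0)
agent.train(total_timesteps=200_000)

# Deterministic evaluation on a fresh environment.
env = gym.make("Pendulum-v1")
obs, _ = env.reset(seed=0)
action = agent.predict(obs)
```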
rlox.algorithms.td3.TD3
¶
Twin Delayed DDPG.
Deterministic policy, target policy smoothing, delayed policy updates. Uses rlox.ReplayBuffer for storage.
rlox.algorithms.dqn.DQN
¶
DQN with optional Double DQN, Dueling architecture, N-step returns, and Prioritized Experience Replay.
Uses rlox.ReplayBuffer or rlox.PrioritizedReplayBuffer for storage.
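The DQN constructor's parameters are not listed on this page, so apart from env_id the keyword names below are hypothetical placeholders for the optional features described above; check the actual signature before use.

```python
from rlox.algorithms.dqn import DQN

agent = DQN(
    env_id="CartPole-v1",
    seed=0,
    # Hypothetical flags for the optional features named above:
    # double=True,        # Double DQN target selection
    # dueling=True,       # Dueling value/advantage heads
    # n_step=3,           # N-step returns
    # prioritized=True,   # rlox.PrioritizedReplayBuffer instead of rlox.ReplayBuffer
)
agent.train(total_timesteps=50_000)  # assumed train(total_timesteps) entry point
```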
Offline RL¶
rlox.algorithms.td3_bc.TD3BC
¶
Bases: OfflineAlgorithm
TD3+BC offline RL algorithm.
Parameters¶
dataset : OfflineDataset
Offline dataset (e.g., rlox.OfflineDatasetBuffer).
obs_dim : int
Observation dimension.
act_dim : int
Action dimension.
alpha : float
BC regularization weight (default 2.5).
hidden : int
Hidden layer width (default 256).
learning_rate : float
Learning rate (default 3e-4).
tau : float
Soft target update rate (default 0.005).
gamma : float
Discount factor (default 0.99).
policy_delay : int
Actor update frequency (default 2).
target_noise : float
Target policy smoothing noise (default 0.2).
noise_clip : float
Target noise clipping (default 0.5).
act_high : float
Action space upper bound for clipping (default 1.0).
batch_size : int
Minibatch size (default 256).
callbacks : list[Callback], optional
logger : LoggerCallback, optional
predict(obs: np.ndarray) -> np.ndarray
¶
Get deterministic action.
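A construction-and-inference sketch using the parameters above. The dataset variable is assumed to be an already-built OfflineDataset; the training entry point comes from the OfflineAlgorithm base class and is not documented here, so only predict() is shown.

```python
import numpy as np
from rlox.algorithms.td3_bc import TD3BC

# `dataset` is assumed to be a prepared OfflineDataset
# (e.g. an rlox.OfflineDatasetBuffer of logged transitions).
agent = TD3BC(
    dataset=dataset,
    obs_dim=17,
    act_dim=6,
    alpha=2.5,       # BC regularization weight
    batch_size=256,
)

# Deterministic action for a single observation.
action = agent.predict(np.zeros(17, dtype=np.float32))
```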
rlox.algorithms.iql.IQL
¶
Bases: OfflineAlgorithm
Implicit Q-Learning.
Parameters¶
dataset : OfflineDataset
Offline dataset.
obs_dim : int
Observation dimension.
act_dim : int
Action dimension.
expectile : float
Expectile τ for value function regression (default 0.7).
temperature : float
β for advantage-weighted actor extraction (default 3.0).
hidden : int
Hidden layer width (default 256).
learning_rate : float
Learning rate (default 3e-4).
tau : float
Soft target update rate (default 0.005).
gamma : float
Discount factor (default 0.99).
batch_size : int
Minibatch size (default 256).
callbacks : list[Callback], optional
logger : LoggerCallback, optional
predict(obs: np.ndarray) -> np.ndarray
¶
Get deterministic action.
rlox.algorithms.cql.CQL
¶
Bases: OfflineAlgorithm
Conservative Q-Learning.
Parameters¶
dataset : OfflineDataset
Offline dataset.
obs_dim : int
Observation dimension.
act_dim : int
Action dimension.
cql_alpha : float
CQL penalty weight (default 5.0).
n_random_actions : int
Number of random actions for the CQL penalty (default 10).
auto_alpha : bool
Whether to auto-tune cql_alpha via Lagrangian (default False).
cql_target_value : float
Target value for Lagrangian α tuning (default -1.0).
hidden : int
Hidden layer width (default 256).
learning_rate : float
Learning rate (default 3e-4).
tau : float
Soft target update rate (default 0.005).
gamma : float
Discount factor (default 0.99).
batch_size : int
Minibatch size (default 256).
auto_entropy : bool
Whether to auto-tune SAC entropy α (default True).
target_entropy : float, optional
Target entropy for SAC (default -act_dim).
callbacks : list[Callback], optional
logger : LoggerCallback, optional
predict(obs: np.ndarray, deterministic: bool = True) -> np.ndarray
¶
Get action from the policy.
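A sketch highlighting the Lagrangian option and the stochastic predict() flag; as with TD3+BC, the dataset is assumed to be prepared beforehand, and the training entry point from OfflineAlgorithm is not shown here.

```python
import numpy as np
from rlox.algorithms.cql import CQL

agent = CQL(
    dataset=dataset,        # assumed prepared OfflineDataset
    obs_dim=17,
    act_dim=6,
    auto_alpha=True,        # tune cql_alpha via the Lagrangian ...
    cql_target_value=-1.0,  # ... toward this target value
)

# SAC-style stochastic policy: deterministic=False samples an action.
action = agent.predict(np.zeros(17, dtype=np.float32), deterministic=False)
```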
rlox.algorithms.bc.BC
¶
Bases: OfflineAlgorithm
Behavioral Cloning.
Parameters¶
dataset : OfflineDataset
Offline dataset with expert demonstrations.
obs_dim : int
Observation dimension.
act_dim : int
Action dimension (or number of discrete actions).
continuous : bool
Whether the action space is continuous (default True).
hidden : int
Hidden layer width (default 256).
learning_rate : float
Learning rate (default 3e-4).
batch_size : int
Minibatch size (default 256).
callbacks : list[Callback], optional
logger : LoggerCallback, optional
predict(obs: np.ndarray) -> np.ndarray
¶
Get action from the learned policy.
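A sketch of the discrete-action case: act_dim becomes the number of actions and continuous=False switches the policy head. The demonstration dataset is assumed to be prepared beforehand.

```python
import numpy as np
from rlox.algorithms.bc import BC

agent = BC(
    dataset=demo_dataset,  # assumed OfflineDataset of expert demonstrations
    obs_dim=4,
    act_dim=2,             # number of discrete actions
    continuous=False,
)

# Returns a discrete action index for a CartPole-sized observation.
action = agent.predict(np.zeros(4, dtype=np.float32))
```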
LLM Post-Training¶
rlox.algorithms.grpo.GRPO
¶
GRPO trainer for language models.
Parameters¶
model : nn.Module
Language model with forward(input_ids) -> logits and
generate(prompt_ids, max_new_tokens) -> token_ids.
ref_model : nn.Module
Frozen reference model (same interface as model).
reward_fn : callable
(completions: list[torch.Tensor], prompts: torch.Tensor) -> list[float]
group_size : int
Number of completions per prompt.
kl_coef : float
KL penalty coefficient.
learning_rate : float
Optimiser learning rate.
max_new_tokens : int
Maximum number of tokens to generate per prompt.
callbacks : list[Callback], optional
Training callbacks.
logger : LoggerCallback, optional
Logger for metrics.
train_step(prompts: torch.Tensor) -> dict[str, float]
¶
One GRPO update on a batch of prompts.
Uses batched group advantage computation via Rust for all prompts at once, avoiding a Python loop per prompt.
train(prompts: torch.Tensor, n_epochs: int = 1) -> dict[str, float]
¶
Train over all prompts for n_epochs.
evaluate(prompts: torch.Tensor) -> float
¶
Return mean reward over prompts (single generation per prompt).
save(path: str) -> None
¶
Save training checkpoint.
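A wiring sketch based on the parameters above. The reward_fn signature matches the one documented; the length-based reward is only an illustrative stand-in, and model, ref_model, and prompts are assumed to be provided by the caller with the documented forward()/generate() interface.

```python
import torch
from rlox.algorithms.grpo import GRPO

def reward_fn(completions: list[torch.Tensor], prompts: torch.Tensor) -> list[float]:
    # Illustrative reward: prefer shorter completions. A real scorer would
    # decode the token ids and evaluate the generated text.
    return [-float(c.shape[-1]) for c in completions]

trainer = GRPO(
    model=model,          # nn.Module with forward()/generate() as documented above
    ref_model=ref_model,  # frozen copy of the same model
    reward_fn=reward_fn,
    group_size=8,
    kl_coef=0.05,
    learning_rate=1e-6,
    max_new_tokens=64,
)

metrics = trainer.train(prompts, n_epochs=1)  # prompts: tensor of prompt token ids
trainer.save("grpo.ckpt")
```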
rlox.algorithms.dpo.DPO
¶
DPO trainer for language models.
Parameters¶
model : nn.Module
Language model with forward(input_ids) -> logits.
ref_model : nn.Module
Frozen reference model.
beta : float
Temperature parameter for the DPO loss.
learning_rate : float
Optimiser learning rate.
callbacks : list[Callback], optional
Training callbacks.
logger : LoggerCallback, optional
Logger for metrics.
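For context on what beta controls, here is a minimal sketch of the standard DPO objective; this is the generic published loss, not necessarily the exact internals of this trainer.

```python
import torch.nn.functional as F

def dpo_loss(policy_chosen_logps, policy_rejected_logps,
             ref_chosen_logps, ref_rejected_logps, beta=0.1):
    # beta scales the margin between the chosen and rejected completions'
    # log-ratios relative to the frozen reference model.
    chosen_ratio = policy_chosen_logps - ref_chosen_logps
    rejected_ratio = policy_rejected_logps - ref_rejected_logps
    return -F.logsigmoid(beta * (chosen_ratio - rejected_ratio)).mean()
```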
rlox.algorithms.online_dpo.OnlineDPO
¶
Online Direct Preference Optimization.
Each step: generate pairs of candidates per prompt, query a preference function, then run a DPO update on the preferred/rejected pair.
train_step(prompts: torch.Tensor) -> dict[str, float]
¶
One OnlineDPO update on a batch of prompts.
rlox.algorithms.best_of_n.BestOfN
¶
Best-of-N rejection sampling.
Generate N completions per prompt, score them with a reward function, and return the best completion for each prompt.
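The BestOfN class's methods are not listed here, so the standalone function below only illustrates the generate-score-argmax pattern, reusing the model.generate() and reward_fn interfaces documented under GRPO.

```python
import torch

def best_of_n(model, reward_fn, prompt_ids: torch.Tensor,
              n: int = 8, max_new_tokens: int = 64) -> torch.Tensor:
    # Generate N candidates for one prompt, score them, return the best.
    completions = [model.generate(prompt_ids, max_new_tokens) for _ in range(n)]
    rewards = reward_fn(completions, prompt_ids)
    best = max(range(n), key=lambda i: rewards[i])
    return completions[best]
```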
Multi-Agent¶
rlox.algorithms.mappo.MAPPO
¶
Multi-Agent PPO with centralized critic, decentralized actors.
Auto-detects observation and action dimensions from the environment. For n_agents=1, this reduces to standard PPO on single-agent envs. For n_agents>1, uses MultiAgentCollector with PettingZoo parallel envs.
Parameters¶
env_id : str
Gymnasium environment ID (for n_agents=1) or PettingZoo env factory.
n_agents : int
Number of agents (default 1).
n_envs : int
Number of parallel environments (default 4).
seed : int
Random seed (default 42).
n_steps : int
Rollout length per environment per update (default 64).
n_epochs : int
Number of SGD passes per rollout (default 4).
batch_size : int
Minibatch size for SGD (default 128).
learning_rate : float
Adam learning rate (default 2.5e-4).
gamma : float
Discount factor (default 0.99).
gae_lambda : float
GAE lambda (default 0.95).
clip_eps : float
PPO clipping range (default 0.2).
vf_coef : float
Value loss coefficient (default 0.5).
ent_coef : float
Entropy bonus coefficient (default 0.01).
max_grad_norm : float
Maximum gradient norm for clipping (default 0.5).
env_fn : Callable, optional
PettingZoo parallel env factory for n_agents > 1.
logger : LoggerCallback, optional
Logger for metrics.
callbacks : list[Callback], optional
Training callbacks.
train(total_timesteps: int) -> dict[str, float]
¶
Run MAPPO training loop.
predict(obs: Any, deterministic: bool = True, agent_idx: int = 0) -> np.ndarray | int
¶
save(path: str) -> None
¶
Save training checkpoint.
from_checkpoint(path: str, env_id: str | None = None) -> MAPPO
classmethod
¶
Restore MAPPO from a checkpoint.
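A usage sketch for the multi-agent case, built from the documented parameters and methods; make_parallel_env and the env_id string are illustrative placeholders for your own PettingZoo setup.

```python
from rlox.algorithms.mappo import MAPPO

# For n_agents > 1, pass a PettingZoo parallel-env factory via env_fn.
agent = MAPPO(
    env_id="simple_spread_v3",   # illustrative; env_fn drives env creation here
    n_agents=3,
    n_envs=4,
    env_fn=make_parallel_env,    # placeholder factory returning a parallel env
)
agent.train(total_timesteps=500_000)

# Per-agent inference and checkpointing, as documented above.
action = agent.predict(obs, agent_idx=1)   # obs from the PettingZoo env
agent.save("mappo.ckpt")
restored = MAPPO.from_checkpoint("mappo.ckpt")
```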
Model-Based¶
rlox.algorithms.dreamer.DreamerV3
¶
DreamerV3 agent with RSSM world model and latent actor-critic.
Components:

- RSSM world model (GRU dynamics + categorical latents)
- Symlog transforms for reward/value prediction
- Actor-critic in latent space with imagination rollouts
- Sequence replay from buffer
- Continuous and discrete action support
Parameters¶
env_id : str
Gymnasium environment ID.
n_envs : int
Number of parallel environments (default 1).
seed : int
Random seed (default 42).
latent_dim : int
Latent feature dimension; ignored, computed from RSSM dims.
imagination_horizon : int
Steps to imagine for policy learning (default 5).
buffer_size : int
Replay buffer capacity (default 10000).
batch_size : int
Number of sequences per training batch (default 32).
seq_len : int
Sequence length for world model training (default 16).
learning_rate : float
Learning rate for all optimisers (default 3e-4).
gamma : float
Discount factor (default 0.99).
obs_dim : int
Observation dimensionality (default 4, auto-detected if possible).
n_actions : int
Number of actions or action dim (default 2, auto-detected if possible).
deter_dim : int
RSSM deterministic state dim (default 64).
stoch_dim : int
Number of categorical distributions (default 8).
stoch_classes : int
Classes per categorical (default 8).
kl_balance : float
KL balancing coefficient (default 0.8).
free_nats : float
Free nats for KL loss (default 1.0).
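A construction sketch from the parameters above; DreamerV3's training entry point is not listed on this page, so the train(total_timesteps) call is left commented as an assumption that it matches the other trainers.

```python
from rlox.algorithms.dreamer import DreamerV3

agent = DreamerV3(
    env_id="CartPole-v1",
    seed=42,
    imagination_horizon=5,  # latent rollout length for policy learning
    deter_dim=64,           # RSSM deterministic state size
    stoch_dim=8,            # 8 categorical distributions ...
    stoch_classes=8,        # ... with 8 classes each
)
# agent.train(total_timesteps=100_000)  # assumed entry point, not documented here
```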
Distributed¶
rlox.algorithms.impala.IMPALA
¶
IMPALA with V-trace off-policy correction.
Auto-detects observation and action dimensions from the environment. Actors collect data in parallel threads with the current policy snapshot. The learner applies V-trace corrected updates.
V-trace computation is vectorized: all environments are processed in a single batched call rather than looping per environment.
Parameters¶
env_id : str
Gymnasium environment ID.
n_actors : int
Number of actor threads (default 2).
n_envs : int
Number of environments per actor (default 2).
seed : int
Random seed (default 42).
n_steps : int
Rollout length per actor per batch (default 32).
learning_rate : float
RMSprop learning rate (default 5e-4).
gamma : float
Discount factor (default 0.99).
rho_bar : float
V-trace truncation for importance weights (default 1.0).
c_bar : float
V-trace truncation for trace coefficients (default 1.0).
vf_coef : float
Value loss coefficient (default 0.5).
ent_coef : float
Entropy bonus coefficient (default 0.01).
max_grad_norm : float
Maximum gradient norm for clipping (default 40.0).
logger : LoggerCallback, optional
Logger for metrics.
callbacks : list[Callback], optional
Training callbacks.
worker_addresses : list[str] | None
If provided, each actor uses a RemoteEnvPool connecting
to a subset of these gRPC worker addresses instead of local
VecEnv/GymVecEnv. The addresses are partitioned evenly across
actors, so len(worker_addresses) >= n_actors is required.
train(total_timesteps: int) -> dict[str, float]
¶
Run IMPALA training loop.
predict(obs: Any, deterministic: bool = True) -> np.ndarray | int
¶
save(path: str) -> None
¶
Save training checkpoint.
from_checkpoint(path: str, env_id: str | None = None) -> IMPALA
classmethod
¶
Restore IMPALA from a checkpoint.
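Two usage sketches from the documented parameters and methods: local actors with their own vectorized envs, and remote actors driven by gRPC workers via worker_addresses.

```python
from rlox.algorithms.impala import IMPALA

# Local actors: each actor steps its own VecEnv/GymVecEnv.
agent = IMPALA(env_id="CartPole-v1", n_actors=2, n_envs=2, seed=42)
agent.train(total_timesteps=200_000)
agent.save("impala.ckpt")

# Remote actors: addresses are partitioned across actors, so provide
# at least n_actors gRPC worker addresses.
remote = IMPALA(
    env_id="CartPole-v1",
    n_actors=2,
    worker_addresses=["localhost:50051", "localhost:50052"],
)

restored = IMPALA.from_checkpoint("impala.ckpt")
```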
Hybrid¶
rlox.algorithms.hybrid_ppo.HybridPPO
¶
PPO with Rust-native collection via Candle.
The collection loop (env stepping + policy inference + GAE computation) runs in a background Rust thread with zero Python overhead. PyTorch handles only the training backward pass.
Parameters¶
env_id : str
Environment ID. Currently only "CartPole-v1" (native Rust env).
n_envs : int
Number of parallel environments (default 16).
seed : int
Random seed.
hidden : int
Hidden layer width for both Candle (collection) and PyTorch (training).
logger : LoggerCallback, optional
Logger for metrics.
callbacks : list[Callback], optional
Training callbacks.
**config_kwargs
Override any PPOConfig fields (n_steps, learning_rate, etc.).
train(total_timesteps: int) -> dict[str, float]
¶
Run hybrid PPO training.
Uses a collect-then-train pattern: the Rust collection thread is stopped during PyTorch training to avoid CPU thread contention between the Candle/Rayon and PyTorch/Accelerate thread pools on macOS.
predict(obs: np.ndarray, deterministic: bool = True) -> int
¶
Get action from PyTorch policy.
timing_summary() -> dict[str, float]
¶
Return collection vs training time breakdown.
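A usage sketch from the documented parameters and methods; n_steps is passed through **config_kwargs as a PPOConfig override.

```python
from rlox.algorithms.hybrid_ppo import HybridPPO

# Only the native Rust CartPole env is currently supported for collection.
agent = HybridPPO(env_id="CartPole-v1", n_envs=16, seed=0, n_steps=128)
stats = agent.train(total_timesteps=100_000)

# Wall-clock breakdown of Rust collection vs PyTorch training time.
print(agent.timing_summary())
```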