
Algorithms API

On-Policy

rlox.algorithms.ppo.PPO

PPO trainer.

Auto-detects observation and action dimensions from the environment. Selects DiscretePolicy or ContinuousPolicy based on action space type. Supports callbacks, logging, and checkpoint save/load.

Parameters

env_id : str
    Gymnasium environment ID.
n_envs : int
    Number of parallel environments (default 8).
seed : int
    Random seed.
policy : nn.Module, optional
    Custom policy network. If None, auto-selects based on env.
logger : LoggerCallback, optional
    Logger for metrics.
callbacks : list[Callback], optional
    Training callbacks.
**config_kwargs
    Override any PPOConfig fields.
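The clipped surrogate objective that PPO maximises can be sketched in NumPy. This is the textbook formulation (clip_eps as in PPOConfig), shown for illustration rather than as rlox's internal implementation:

```python
import numpy as np

def ppo_clip_objective(log_probs, old_log_probs, advantages, clip_eps=0.2):
    """Clipped surrogate objective (to be maximised)."""
    ratio = np.exp(log_probs - old_log_probs)      # pi_new(a|s) / pi_old(a|s)
    unclipped = ratio * advantages
    clipped = np.clip(ratio, 1 - clip_eps, 1 + clip_eps) * advantages
    # Pessimistic bound: take the smaller of the two per sample
    return np.minimum(unclipped, clipped).mean()
```

When the new policy equals the old one the ratio is 1 and the objective reduces to the mean advantage; large ratios are clipped to 1 ± clip_eps.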

train(total_timesteps: int) -> dict[str, float]

Run PPO training and return final metrics.

save(path: str) -> None

Save training checkpoint.

from_checkpoint(path: str, env_id: str | None = None) -> PPO classmethod

Restore PPO from a checkpoint.

Parameters

path : str
    Path to the checkpoint file.
env_id : str, optional
    Environment ID. If None, uses the one stored in the checkpoint config.

rlox.algorithms.a2c.A2C

Synchronous A2C.

Like PPO, but with no ratio clipping, a single gradient step per rollout (no epochs), and typically a shorter n_steps.

Auto-detects observation and action dimensions from the environment.

predict(obs: Any, deterministic: bool = True) -> np.ndarray | int

Get action from the trained policy.

Parameters

obs : array-like
    Observation.
deterministic : bool
    If True, return the mode of the action distribution.

Returns

Action as an int (discrete) or numpy array (continuous).

save(path: str) -> None

Save training checkpoint.

from_checkpoint(path: str, env_id: str | None = None) -> A2C classmethod

Restore A2C from a checkpoint.

Off-Policy

rlox.algorithms.sac.SAC

Soft Actor-Critic.

Twin critics, squashed Gaussian policy, automatic entropy tuning. Uses rlox.ReplayBuffer for storage.
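The squashed Gaussian policy follows the standard SAC construction: sample from a Gaussian, squash through tanh, and correct the log-probability with the change-of-variables term. A NumPy sketch of the math (illustrative, not rlox internals):

```python
import numpy as np

def squashed_gaussian_sample(mean, log_std, rng):
    """Sample a tanh-squashed Gaussian action and its log-probability."""
    std = np.exp(log_std)
    u = mean + std * rng.standard_normal(mean.shape)   # pre-squash sample
    a = np.tanh(u)                                     # action in (-1, 1)
    # Gaussian log-density of u, summed over action dimensions
    log_prob = (-0.5 * (((u - mean) / std) ** 2
                        + 2 * log_std + np.log(2 * np.pi))).sum(axis=-1)
    # tanh change-of-variables correction (epsilon for numerical stability)
    log_prob -= np.log(1 - a ** 2 + 1e-6).sum(axis=-1)
    return a, log_prob
```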

predict(obs: np.ndarray, deterministic: bool = True) -> np.ndarray

Get action from the policy.

save(path: str) -> None

Save training checkpoint.

from_checkpoint(path: str, env_id: str | None = None) -> SAC classmethod

Restore SAC from a checkpoint.

rlox.algorithms.td3.TD3

Twin Delayed DDPG.

Deterministic policy, target policy smoothing, delayed policy updates. Uses rlox.ReplayBuffer for storage.
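Target policy smoothing adds clipped Gaussian noise to the target policy's action before evaluating the target critics. A NumPy sketch using the conventional TD3 defaults (0.2 noise scale, 0.5 clip):

```python
import numpy as np

def smoothed_target_action(target_policy_action, act_high=1.0,
                           target_noise=0.2, noise_clip=0.5, rng=None):
    """Add clipped Gaussian noise, then clip back to the action bounds."""
    rng = rng or np.random.default_rng()
    noise = np.clip(
        target_noise * rng.standard_normal(target_policy_action.shape),
        -noise_clip, noise_clip)
    return np.clip(target_policy_action + noise, -act_high, act_high)
```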

predict(obs: np.ndarray, deterministic: bool = True) -> np.ndarray

Get action from the policy (always deterministic for TD3).

save(path: str) -> None

Save training checkpoint.

from_checkpoint(path: str, env_id: str | None = None) -> TD3 classmethod

Restore TD3 from a checkpoint.

rlox.algorithms.dqn.DQN

DQN with optional Double DQN, Dueling architecture, N-step returns, and Prioritized Experience Replay.

Uses rlox.ReplayBuffer or rlox.PrioritizedReplayBuffer for storage.
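The Double DQN variant decouples action selection (online network) from action evaluation (target network) when forming the TD target, which reduces overestimation bias. A NumPy sketch of the target computation:

```python
import numpy as np

def double_dqn_target(rewards, dones, q_online_next, q_target_next, gamma=0.99):
    """Double DQN target: online net selects the argmax, target net evaluates it."""
    best = q_online_next.argmax(axis=1)                  # action selection
    q_eval = q_target_next[np.arange(len(best)), best]   # action evaluation
    return rewards + gamma * (1.0 - dones) * q_eval
```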

predict(obs: np.ndarray, deterministic: bool = True) -> int

Get action from the policy.

save(path: str) -> None

Save training checkpoint.

from_checkpoint(path: str, env_id: str | None = None) -> DQN classmethod

Restore DQN from a checkpoint.

Offline RL

rlox.algorithms.td3_bc.TD3BC

Bases: OfflineAlgorithm

TD3+BC offline RL algorithm.

Parameters

dataset : OfflineDataset
    Offline dataset (e.g., rlox.OfflineDatasetBuffer).
obs_dim : int
    Observation dimension.
act_dim : int
    Action dimension.
alpha : float
    BC regularization weight (default 2.5).
hidden : int
    Hidden layer width (default 256).
learning_rate : float
    Learning rate (default 3e-4).
tau : float
    Soft target update rate (default 0.005).
gamma : float
    Discount factor (default 0.99).
policy_delay : int
    Actor update frequency (default 2).
target_noise : float
    Target policy smoothing noise (default 0.2).
noise_clip : float
    Target noise clipping (default 0.5).
act_high : float
    Action space upper bound for clipping (default 1.0).
batch_size : int
    Minibatch size (default 256).
callbacks : list[Callback], optional
logger : LoggerCallback, optional
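The alpha parameter weights the behavioural-cloning term as in the TD3+BC paper: the Q term is scaled adaptively by lambda = alpha / mean|Q|, and an MSE term anchors the policy to dataset actions. A NumPy sketch of the actor loss:

```python
import numpy as np

def td3_bc_actor_loss(q_values, policy_actions, dataset_actions, alpha=2.5):
    """TD3+BC actor loss: normalised Q maximisation plus BC regression."""
    lam = alpha / (np.abs(q_values).mean() + 1e-8)   # adaptive Q scaling
    bc = ((policy_actions - dataset_actions) ** 2).mean()
    return -lam * q_values.mean() + bc
```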

predict(obs: np.ndarray) -> np.ndarray

Get deterministic action.

rlox.algorithms.iql.IQL

Bases: OfflineAlgorithm

Implicit Q-Learning.

Parameters

dataset : OfflineDataset
    Offline dataset.
obs_dim : int
    Observation dimension.
act_dim : int
    Action dimension.
expectile : float
    Expectile τ for value function regression (default 0.7).
temperature : float
    β for advantage-weighted actor extraction (default 3.0).
hidden : int
    Hidden layer width (default 256).
learning_rate : float
    Learning rate (default 3e-4).
tau : float
    Soft target update rate (default 0.005).
gamma : float
    Discount factor (default 0.99).
batch_size : int
    Minibatch size (default 256).
callbacks : list[Callback], optional
logger : LoggerCallback, optional
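The expectile parameter controls the asymmetric L2 loss used for value regression: residuals where q exceeds v are weighted by τ, the rest by 1 − τ, so τ > 0.5 biases the value estimate toward the upper part of the Q distribution. A NumPy sketch:

```python
import numpy as np

def expectile_loss(diff, expectile=0.7):
    """Asymmetric L2 loss for IQL value regression; diff = q - v."""
    weight = np.where(diff > 0, expectile, 1 - expectile)
    return (weight * diff ** 2).mean()
```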

predict(obs: np.ndarray) -> np.ndarray

Get deterministic action.

rlox.algorithms.cql.CQL

Bases: OfflineAlgorithm

Conservative Q-Learning.

Parameters

dataset : OfflineDataset
    Offline dataset.
obs_dim : int
    Observation dimension.
act_dim : int
    Action dimension.
cql_alpha : float
    CQL penalty weight (default 5.0).
n_random_actions : int
    Number of random actions for CQL penalty (default 10).
auto_alpha : bool
    Whether to auto-tune cql_alpha via Lagrangian (default False).
cql_target_value : float
    Target value for Lagrangian α tuning (default -1.0).
hidden : int
    Hidden layer width (default 256).
learning_rate : float
    Learning rate (default 3e-4).
tau : float
    Soft target update rate (default 0.005).
gamma : float
    Discount factor (default 0.99).
batch_size : int
    Minibatch size (default 256).
auto_entropy : bool
    Whether to auto-tune SAC entropy α (default True).
target_entropy : float, optional
    Target entropy for SAC (default -act_dim).
callbacks : list[Callback], optional
logger : LoggerCallback, optional
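The conservative penalty weighted by cql_alpha pushes down a log-sum-exp over Q-values of sampled out-of-distribution actions while pushing up Q-values of dataset actions. A NumPy sketch of the penalty term (rlox's exact sampling scheme over n_random_actions may differ):

```python
import numpy as np

def cql_penalty(q_random, q_dataset, cql_alpha=5.0):
    """Conservative penalty: logsumexp of OOD Q-values minus dataset Q-values.

    q_random  : (batch, n_random_actions) Q at sampled actions.
    q_dataset : (batch,) Q at dataset actions.
    """
    # Numerically stable log-sum-exp over the sampled actions
    m = q_random.max(axis=1, keepdims=True)
    lse = (m + np.log(np.exp(q_random - m).sum(axis=1, keepdims=True))).squeeze(1)
    return cql_alpha * (lse - q_dataset).mean()
```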

predict(obs: np.ndarray, deterministic: bool = True) -> np.ndarray

Get action from the policy.

rlox.algorithms.bc.BC

Bases: OfflineAlgorithm

Behavioral Cloning.

Parameters

dataset : OfflineDataset
    Offline dataset with expert demonstrations.
obs_dim : int
    Observation dimension.
act_dim : int
    Action dimension (or number of discrete actions).
continuous : bool
    Whether the action space is continuous (default True).
hidden : int
    Hidden layer width (default 256).
learning_rate : float
    Learning rate (default 3e-4).
batch_size : int
    Minibatch size (default 256).
callbacks : list[Callback], optional
logger : LoggerCallback, optional

predict(obs: np.ndarray) -> np.ndarray

Get action from the learned policy.

LLM Post-Training

rlox.algorithms.grpo.GRPO

GRPO trainer for language models.

Parameters

model : nn.Module
    Language model with forward(input_ids) -> logits and generate(prompt_ids, max_new_tokens) -> token_ids.
ref_model : nn.Module
    Frozen reference model (same interface as model).
reward_fn : callable
    (completions: list[torch.Tensor], prompts: torch.Tensor) -> list[float]
group_size : int
    Number of completions per prompt.
kl_coef : float
    KL penalty coefficient.
learning_rate : float
    Optimiser learning rate.
max_new_tokens : int
    Maximum number of tokens to generate per prompt.
callbacks : list[Callback], optional
    Training callbacks.
logger : LoggerCallback, optional
    Logger for metrics.

train_step(prompts: torch.Tensor) -> dict[str, float]

One GRPO update on a batch of prompts.

Uses batched group advantage computation via Rust for all prompts at once, avoiding a Python loop per prompt.
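Assuming the standard GRPO formulation (reward minus group mean, scaled by group standard deviation), the batched advantage computation can be sketched in NumPy; rlox performs the equivalent in Rust:

```python
import numpy as np

def group_advantages(rewards, group_size):
    """Normalise rewards within each group of completions sharing a prompt."""
    r = np.asarray(rewards, dtype=np.float64).reshape(-1, group_size)
    mean = r.mean(axis=1, keepdims=True)
    std = r.std(axis=1, keepdims=True)
    return ((r - mean) / (std + 1e-8)).reshape(-1)   # flat, one value per completion
```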

train(prompts: torch.Tensor, n_epochs: int = 1) -> dict[str, float]

Train over all prompts for n_epochs.

evaluate(prompts: torch.Tensor) -> float

Return mean reward over prompts (single generation per prompt).

save(path: str) -> None

Save training checkpoint.

rlox.algorithms.dpo.DPO

DPO trainer for language models.

Parameters

model : nn.Module
    Language model with forward(input_ids) -> logits.
ref_model : nn.Module
    Frozen reference model.
beta : float
    Temperature parameter for the DPO loss.
learning_rate : float
    Optimiser learning rate.
callbacks : list[Callback], optional
    Training callbacks.
logger : LoggerCallback, optional
    Logger for metrics.

compute_loss(prompt: torch.Tensor, chosen: torch.Tensor, rejected: torch.Tensor) -> tuple[torch.Tensor, dict[str, float]]

Compute DPO loss.

Parameters

prompt : (B, P)
chosen : (B, C)
rejected : (B, R)

Returns

loss, metrics
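Assuming per-sequence log-probabilities have already been summed over completion tokens, the standard DPO loss with temperature beta can be sketched as:

```python
import numpy as np

def dpo_loss(logp_chosen, logp_rejected, ref_logp_chosen, ref_logp_rejected,
             beta=0.1):
    """DPO loss: -log sigmoid(beta * (chosen log-ratio - rejected log-ratio))."""
    logits = beta * ((logp_chosen - ref_logp_chosen)
                     - (logp_rejected - ref_logp_rejected))
    # -log sigmoid(x) == log(1 + exp(-x)), computed stably via logaddexp
    return np.logaddexp(0.0, -logits).mean()
```

With no preference signal (all log-ratios equal) the loss sits at log 2, its value at zero logits.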

train_step(prompt: torch.Tensor, chosen: torch.Tensor, rejected: torch.Tensor) -> dict[str, float]

One DPO gradient step.

save(path: str) -> None

Save training checkpoint.

rlox.algorithms.online_dpo.OnlineDPO

Online Direct Preference Optimization.

Each step: generate pairs of candidates per prompt, query a preference function, then run a DPO update on the preferred/rejected pair.

train_step(prompts: torch.Tensor) -> dict[str, float]

One OnlineDPO update on a batch of prompts.

rlox.algorithms.best_of_n.BestOfN

Best-of-N rejection sampling.

Generate N completions per prompt, score them with a reward function, and return the best completion for each prompt.

generate(prompts: torch.Tensor) -> torch.Tensor

Generate best-of-N completions for each prompt.

Parameters

prompts : (B, P) token ids

Returns

(B, P + T) best completions
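The selection step can be sketched in NumPy, assuming the N completions for each prompt are stored as consecutive rows:

```python
import numpy as np

def select_best_of_n(completions, scores, n):
    """Pick the highest-scoring of every n consecutive completions.

    completions : (B * n, T) candidate completions.
    scores      : (B * n,) rewards; rows [i*n:(i+1)*n] share prompt i.
    """
    scores = np.asarray(scores).reshape(-1, n)
    best = scores.argmax(axis=1)                 # winner index within each group
    idx = np.arange(len(best)) * n + best        # flat row index of each winner
    return completions[idx]
```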

Multi-Agent

rlox.algorithms.mappo.MAPPO

Multi-Agent PPO with centralized critic, decentralized actors.

Auto-detects observation and action dimensions from the environment. For n_agents=1, this reduces to standard PPO on single-agent envs. For n_agents>1, uses MultiAgentCollector with PettingZoo parallel envs.

Parameters

env_id : str
    Gymnasium environment ID (for n_agents=1) or PettingZoo env factory.
n_agents : int
    Number of agents (default 1).
n_envs : int
    Number of parallel environments (default 4).
seed : int
    Random seed (default 42).
n_steps : int
    Rollout length per environment per update (default 64).
n_epochs : int
    Number of SGD passes per rollout (default 4).
batch_size : int
    Minibatch size for SGD (default 128).
learning_rate : float
    Adam learning rate (default 2.5e-4).
gamma : float
    Discount factor (default 0.99).
gae_lambda : float
    GAE lambda (default 0.95).
clip_eps : float
    PPO clipping range (default 0.2).
vf_coef : float
    Value loss coefficient (default 0.5).
ent_coef : float
    Entropy bonus coefficient (default 0.01).
max_grad_norm : float
    Maximum gradient norm for clipping (default 0.5).
env_fn : Callable, optional
    PettingZoo parallel env factory for n_agents > 1.
logger : LoggerCallback, optional
    Logger for metrics.
callbacks : list[Callback], optional
    Training callbacks.

train(total_timesteps: int) -> dict[str, float]

Run MAPPO training loop.

predict(obs: Any, deterministic: bool = True, agent_idx: int = 0) -> np.ndarray | int

Get action from the trained policy for a single agent.

Parameters

obs : array-like
    Observation.
deterministic : bool
    If True, return the mode of the action distribution.
agent_idx : int
    Which agent's policy to use (default 0).

Returns

Action as an int (discrete) or numpy array (continuous).

save(path: str) -> None

Save training checkpoint.

from_checkpoint(path: str, env_id: str | None = None) -> MAPPO classmethod

Restore MAPPO from a checkpoint.

Model-Based

rlox.algorithms.dreamer.DreamerV3

DreamerV3 agent with RSSM world model and latent actor-critic.

Components:

- RSSM world model (GRU dynamics + categorical latents)
- Symlog transforms for reward/value prediction
- Actor-critic in latent space with imagination rollouts
- Sequence replay from buffer
- Continuous and discrete action support
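The symlog transform compresses large magnitudes symmetrically around zero, which is what lets DreamerV3 predict rewards and values across widely varying scales with one network. A NumPy sketch of the pair of transforms:

```python
import numpy as np

def symlog(x):
    """Symmetric log squashing: sign(x) * log(1 + |x|)."""
    return np.sign(x) * np.log1p(np.abs(x))

def symexp(x):
    """Inverse of symlog: sign(x) * (exp(|x|) - 1)."""
    return np.sign(x) * np.expm1(np.abs(x))
```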

Parameters

env_id : str
    Gymnasium environment ID.
n_envs : int
    Number of parallel environments (default 1).
seed : int
    Random seed (default 42).
latent_dim : int
    Latent feature dimension (ignored; computed from RSSM dims).
imagination_horizon : int
    Steps to imagine for policy learning (default 5).
buffer_size : int
    Replay buffer capacity (default 10000).
batch_size : int
    Number of sequences per training batch (default 32).
seq_len : int
    Sequence length for world model training (default 16).
learning_rate : float
    Learning rate for all optimisers (default 3e-4).
gamma : float
    Discount factor (default 0.99).
obs_dim : int
    Observation dimensionality (default 4, auto-detected if possible).
n_actions : int
    Number of actions or action dim (default 2, auto-detected if possible).
deter_dim : int
    RSSM deterministic state dim (default 64).
stoch_dim : int
    Number of categorical distributions (default 8).
stoch_classes : int
    Classes per categorical (default 8).
kl_balance : float
    KL balancing coefficient (default 0.8).
free_nats : float
    Free nats for KL loss (default 1.0).

train(total_timesteps: int) -> dict[str, float]

Run DreamerV3 training loop.

save(path: str) -> None

Save training checkpoint.

from_checkpoint(path: str, env_id: str | None = None) -> DreamerV3 classmethod

Restore DreamerV3 from a checkpoint.

Distributed

rlox.algorithms.impala.IMPALA

IMPALA with V-trace off-policy correction.

Auto-detects observation and action dimensions from the environment. Actors collect data in parallel threads with the current policy snapshot. The learner applies V-trace corrected updates.

V-trace computation is vectorized: all environments are processed in a single batched call rather than looping per environment.
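For reference, the V-trace targets for a single environment can be sketched as a backward recursion over time-major arrays (rlox vectorises this across all environments in one batched call):

```python
import numpy as np

def vtrace(behaviour_logp, target_logp, rewards, values, bootstrap_value,
           dones, gamma=0.99, rho_bar=1.0, c_bar=1.0):
    """V-trace value targets for one environment; all inputs have length T.

    Recursion: vs[t] - V[t] = delta[t] + gamma_t * c[t] * (vs[t+1] - V[t+1]).
    """
    is_ratio = np.exp(target_logp - behaviour_logp)   # importance sampling ratio
    rho = np.minimum(is_ratio, rho_bar)               # truncated IS weights
    c = np.minimum(is_ratio, c_bar)                   # truncated trace coefficients
    T = len(rewards)
    values_next = np.append(values[1:], bootstrap_value)
    discounts = gamma * (1.0 - dones)
    deltas = rho * (rewards + discounts * values_next - values)
    vs_minus_v = np.zeros(T + 1)
    for t in reversed(range(T)):                      # backward accumulation
        vs_minus_v[t] = deltas[t] + discounts[t] * c[t] * vs_minus_v[t + 1]
    return values + vs_minus_v[:T]
```

In the on-policy case (identical log-probs, so rho = c = 1) the targets reduce to ordinary discounted returns bootstrapped from the last value.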

Parameters

env_id : str
    Gymnasium environment ID.
n_actors : int
    Number of actor threads (default 2).
n_envs : int
    Number of environments per actor (default 2).
seed : int
    Random seed (default 42).
n_steps : int
    Rollout length per actor per batch (default 32).
learning_rate : float
    RMSprop learning rate (default 5e-4).
gamma : float
    Discount factor (default 0.99).
rho_bar : float
    V-trace truncation for importance weights (default 1.0).
c_bar : float
    V-trace truncation for trace coefficients (default 1.0).
vf_coef : float
    Value loss coefficient (default 0.5).
ent_coef : float
    Entropy bonus coefficient (default 0.01).
max_grad_norm : float
    Maximum gradient norm for clipping (default 40.0).
logger : LoggerCallback, optional
    Logger for metrics.
callbacks : list[Callback], optional
    Training callbacks.
worker_addresses : list[str] | None
    If provided, each actor uses a RemoteEnvPool connecting to a subset of these gRPC worker addresses instead of a local VecEnv/GymVecEnv. The addresses are partitioned evenly across actors, so len(worker_addresses) >= n_actors is required.

train(total_timesteps: int) -> dict[str, float]

Run IMPALA training loop.

predict(obs: Any, deterministic: bool = True) -> np.ndarray | int

Get action from the trained policy.

Parameters

obs : array-like
    Observation.
deterministic : bool
    If True, return the mode of the action distribution.

Returns

Action as an int (discrete) or numpy array (continuous).

save(path: str) -> None

Save training checkpoint.

from_checkpoint(path: str, env_id: str | None = None) -> IMPALA classmethod

Restore IMPALA from a checkpoint.

Hybrid

rlox.algorithms.hybrid_ppo.HybridPPO

PPO with Rust-native collection via Candle.

The collection loop (env stepping + policy inference + GAE computation) runs in a background Rust thread with zero Python overhead. PyTorch handles only the training backward pass.

Parameters

env_id : str
    Environment ID. Currently only "CartPole-v1" (native Rust env).
n_envs : int
    Number of parallel environments (default 16).
seed : int
    Random seed.
hidden : int
    Hidden layer width for both Candle (collection) and PyTorch (training).
logger : LoggerCallback, optional
    Logger for metrics.
callbacks : list[Callback], optional
    Training callbacks.
**config_kwargs
    Override any PPOConfig fields (n_steps, learning_rate, etc.).

train(total_timesteps: int) -> dict[str, float]

Run hybrid PPO training.

Uses a collect-then-train pattern: the Rust collection thread is stopped during PyTorch training to avoid CPU thread contention between the Candle/Rayon and PyTorch/Accelerate thread pools on macOS.

predict(obs: np.ndarray, deterministic: bool = True) -> int

Get action from PyTorch policy.

timing_summary() -> dict[str, float]

Return collection vs training time breakdown.