
Custom Components Tutorial

rlox is designed to be extensible. You can bring your own networks, exploration strategies, and loss functions, and compose them using the builder pattern.

Protocols — What Your Custom Components Must Implement

rlox uses Python Protocols (structural subtyping) — your custom class just needs to implement the right methods. No inheritance required.

graph TD
    subgraph On-Policy[On-Policy Protocols]
        A[OnPolicyActor] --> A1[get_action_and_logprob]
        A --> A2[get_value]
        A --> A3[get_logprob_and_entropy]
    end

    subgraph Off-Policy[Off-Policy Protocols]
        B[StochasticActor] --> B1[sample]
        B --> B2[deterministic]
        C[DeterministicActor] --> C1[forward]
        D[QFunction] --> D1[forward - obs, action]
        E[DiscreteQFunction] --> E1[forward - obs]
    end

    subgraph Shared[Shared Protocols]
        F[ExplorationStrategy] --> F1[select_action]
        F --> F2[reset]
        G[ReplayBufferProtocol] --> G1[push]
        G --> G2[sample]
        G --> G3[__len__]
        H[CollectorProtocol] --> H1[reset]
        H --> H2[collect_step]
        H --> H3[n_envs]
    end
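
Under the hood, each protocol in the diagram above is a standard typing.Protocol that simply lists the required methods. As a rough sketch (the real definitions live in rlox.protocols and may use more precise type hints), OnPolicyActor looks something like this:

from typing import Protocol, runtime_checkable
import torch

@runtime_checkable
class OnPolicyActor(Protocol):
    """Structural type: any object with these three methods conforms."""

    def get_action_and_logprob(self, obs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        ...

    def get_value(self, obs: torch.Tensor) -> torch.Tensor:
        ...

    def get_logprob_and_entropy(
        self, obs: torch.Tensor, actions: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        ...

Because the check is structural, isinstance() only verifies that the methods exist; it does not validate argument or return types.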

Example 1: Custom CNN Policy for PPO

import torch
import torch.nn as nn

class MyCNNPolicy(nn.Module):
    """Custom CNN policy that satisfies OnPolicyActor protocol."""

    def __init__(self, obs_shape, n_actions):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(obs_shape[0], 32, 8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2),
            nn.ReLU(),
            nn.Flatten(),
        )
        # Compute feature size
        with torch.no_grad():
            dummy = torch.zeros(1, *obs_shape)
            feat_size = self.features(dummy).shape[1]

        self.actor = nn.Linear(feat_size, n_actions)
        self.critic = nn.Linear(feat_size, 1)

    def get_action_and_logprob(self, obs):
        features = self.features(obs)
        logits = self.actor(features)
        dist = torch.distributions.Categorical(logits=logits)
        action = dist.sample()
        return action, dist.log_prob(action)

    def get_value(self, obs):
        features = self.features(obs)
        return self.critic(features).squeeze(-1)

    def get_logprob_and_entropy(self, obs, actions):
        features = self.features(obs)
        logits = self.actor(features)
        dist = torch.distributions.Categorical(logits=logits)
        return dist.log_prob(actions), dist.entropy()

# Use with PPO — protocol check happens automatically.
# Note: the CNN expects image observations, so obs_shape must match the env's
# (channels, height, width) observation shape; adapt the shapes and env_id to
# your image-based environment.
from rlox.algorithms.ppo import PPO

ppo = PPO(env_id="CartPole-v1", policy=MyCNNPolicy((4,), 2))
ppo.train(total_timesteps=50_000)

Example 2: Custom Networks for Off-Policy Algorithms

SAC, TD3, and DQN accept custom networks and replay buffers via the actor=, critic=, q_network=, and buffer= parameters:

from rlox.algorithms.sac import SAC
from rlox.algorithms.td3 import TD3
from rlox.algorithms.dqn import DQN
import torch
import torch.nn as nn

# --- SAC with custom CNN actor and critic ---

class MyCNNActor(nn.Module):
    """Custom CNN actor that satisfies StochasticActor protocol."""
    def __init__(self, obs_shape, act_dim):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(obs_shape[0], 32, 8, stride=4), nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2), nn.ReLU(),
            nn.Flatten(),
        )
        with torch.no_grad():
            feat_size = self.features(torch.zeros(1, *obs_shape)).shape[1]
        self.mean = nn.Linear(feat_size, act_dim)
        self.log_std = nn.Parameter(torch.zeros(act_dim))

    def sample(self, obs):
        features = self.features(obs)
        mean = torch.tanh(self.mean(features))
        std = self.log_std.exp().expand_as(mean)
        dist = torch.distributions.Normal(mean, std)
        action = dist.rsample()
        log_prob = dist.log_prob(action).sum(-1)
        return action, log_prob

    def deterministic(self, obs):
        return torch.tanh(self.mean(self.features(obs)))

# Inject custom actor — SAC handles everything else.
# As above, a CNN actor needs image observations; match obs_shape and env_id
# to an image-based environment.
sac = SAC(env_id="Pendulum-v1", actor=MyCNNActor((3,), 1), learning_starts=1000)
sac.train(total_timesteps=50_000)

# --- TD3 with custom critic ---

class MyWideCritic(nn.Module):
    """Wider critic network."""
    def __init__(self, obs_dim, act_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim + act_dim, 512), nn.ReLU(),
            nn.Linear(512, 512), nn.ReLU(),
            nn.Linear(512, 1),
        )
    def forward(self, obs, action):
        return self.net(torch.cat([obs, action], dim=-1))

td3 = TD3(env_id="Pendulum-v1", critic=MyWideCritic(3, 1))

# --- DQN with custom Q-network ---

class MyDuelingNet(nn.Module):
    """Custom dueling architecture with attention."""
    def __init__(self, obs_dim, n_actions):
        super().__init__()
        self.shared = nn.Sequential(nn.Linear(obs_dim, 128), nn.ReLU())
        self.value = nn.Linear(128, 1)
        self.advantage = nn.Linear(128, n_actions)

    def forward(self, obs):
        features = self.shared(obs)
        v = self.value(features)
        a = self.advantage(features)
        return v + a - a.mean(dim=-1, keepdim=True)

dqn = DQN(env_id="CartPole-v1", q_network=MyDuelingNet(4, 2))

# --- Custom replay buffer ---
import rlox
mmap_buf = rlox.MmapReplayBuffer(
    hot_capacity=10_000, total_capacity=500_000,
    obs_dim=4, act_dim=1, cold_path="/tmp/replay.bin"
)
dqn = DQN(env_id="CartPole-v1", buffer=mmap_buf)
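
You can also write a buffer from scratch: any object with push, sample, and __len__ satisfies ReplayBufferProtocol. The sketch below is illustrative only; the exact push arguments and the format returned by sample (a tuple of tensors here) are assumptions, so check the protocol definition in rlox before relying on them.

import numpy as np
import torch

class SmallNumpyBuffer:
    """Minimal in-memory ring buffer sketch (signatures assumed, not rlox-verified)."""

    def __init__(self, capacity, obs_dim, act_dim):
        self.capacity = capacity
        self.obs = np.zeros((capacity, obs_dim), dtype=np.float32)
        self.actions = np.zeros((capacity, act_dim), dtype=np.float32)
        self.rewards = np.zeros(capacity, dtype=np.float32)
        self.next_obs = np.zeros((capacity, obs_dim), dtype=np.float32)
        self.dones = np.zeros(capacity, dtype=np.float32)
        self.pos = 0
        self.full = False

    def push(self, obs, action, reward, next_obs, done):
        # Overwrite the oldest transition once the buffer wraps around.
        i = self.pos
        self.obs[i], self.actions[i] = obs, action
        self.rewards[i], self.next_obs[i], self.dones[i] = reward, next_obs, done
        self.pos = (self.pos + 1) % self.capacity
        self.full = self.full or self.pos == 0

    def sample(self, batch_size):
        idx = np.random.randint(0, len(self), size=batch_size)
        return tuple(
            torch.as_tensor(arr[idx])
            for arr in (self.obs, self.actions, self.rewards, self.next_obs, self.dones)
        )

    def __len__(self):
        return self.capacity if self.full else self.pos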

Example 3: Multi-Environment Collectors

All off-policy algorithms (SAC, TD3, DQN) support multi-env data collection via OffPolicyCollector. You can use the built-in n_envs parameter or bring your own collector.

Quick: Just Set n_envs

from rlox.algorithms.sac import SAC
from rlox.algorithms.td3 import TD3
from rlox.algorithms.dqn import DQN

# SAC with 4 parallel environments — collector auto-created
sac = SAC(env_id="HalfCheetah-v4", n_envs=4, learning_starts=10_000)
sac.train(total_timesteps=1_000_000)

# Works the same for TD3 and DQN
td3 = TD3(env_id="Pendulum-v1", n_envs=4)
dqn = DQN(env_id="CartPole-v1", n_envs=8)

Custom: Inject Your Own Collector

import rlox
from rlox.algorithms.sac import SAC
from rlox.off_policy_collector import OffPolicyCollector
from rlox.exploration import GaussianNoise

# Create a shared buffer
buf = rlox.ReplayBuffer(1_000_000, obs_dim=3, act_dim=1)

# Create collector with custom exploration
collector = OffPolicyCollector(
    env_id="Pendulum-v1",
    n_envs=4,
    buffer=buf,
    exploration=GaussianNoise(sigma=0.1, clip=0.3),
)

# Inject into algorithm — buffer must be shared
sac = SAC(env_id="Pendulum-v1", buffer=buf, collector=collector)
sac.train(total_timesteps=50_000)

Build Your Own Collector

Any class satisfying CollectorProtocol can be used:

from rlox.algorithms.sac import SAC
from rlox.off_policy_collector import CollectorProtocol
import numpy as np

class MyCollector:
    """Custom collector — e.g., for sim-to-real or domain randomization."""

    @property
    def n_envs(self) -> int:
        return 1

    def reset(self) -> np.ndarray:
        # Return initial observations (n_envs, obs_dim)
        ...

    def collect_step(self, get_action, step, total_steps):
        # Step envs, store transitions, return (next_obs, rewards, mean_ep_reward)
        ...

# Protocol check
assert isinstance(MyCollector(), CollectorProtocol)

# Use with any off-policy algorithm
sac = SAC(env_id="Pendulum-v1", collector=MyCollector())
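
To go beyond the stub, here is a single-environment collector sketch built on Gymnasium. The assumptions are flagged in comments: the buffer.push signature, how get_action is called, and the exact return tuple of collect_step are illustrative, so verify them against CollectorProtocol in rlox.off_policy_collector.

import gymnasium as gym
import numpy as np

class SingleEnvCollector:
    """Illustrative collector sketch (interfaces assumed, not rlox-verified)."""

    def __init__(self, env_id, buffer):
        self.env = gym.make(env_id)
        self.buffer = buffer
        self.obs = None
        self.ep_reward = 0.0
        self.last_ep_reward = 0.0

    @property
    def n_envs(self) -> int:
        return 1

    def reset(self) -> np.ndarray:
        obs, _ = self.env.reset()
        self.obs = obs
        self.ep_reward = 0.0
        return obs[None, :]  # shape (n_envs, obs_dim)

    def collect_step(self, get_action, step, total_steps):
        # Assumption: get_action maps a batch of observations to a batch of actions.
        action = get_action(self.obs[None, :])[0]
        next_obs, reward, terminated, truncated, _ = self.env.step(action)
        done = terminated or truncated
        # Assumed push signature: (obs, action, reward, next_obs, done).
        self.buffer.push(self.obs, action, reward, next_obs, float(done))
        self.ep_reward += float(reward)
        if done:
            self.last_ep_reward = self.ep_reward
            next_obs, _ = self.env.reset()
            self.ep_reward = 0.0
        self.obs = next_obs
        return next_obs[None, :], np.array([reward], dtype=np.float32), self.last_ep_reward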

The diagram below shows how a collector plugs into the off-policy training loop:

graph LR
    subgraph Collection
        C[OffPolicyCollector] --> V[GymVecEnv]
        C --> E[Exploration Strategy]
        C --> B[ReplayBuffer]
    end
    subgraph Learning
        A[SAC / TD3 / DQN] --> U[_update]
        U --> B
    end
    A -->|n_envs > 1| C
    A -->|n_envs = 1| S[Single Env Loop]

Example 4: Custom Exploration Strategies

from rlox.exploration import OUNoise, GaussianNoise, EpsilonGreedy

# Ornstein-Uhlenbeck noise (temporally correlated)
noise = OUNoise(action_dim=1, sigma=0.3, theta=0.15)

# Gaussian noise (i.i.d.)
noise = GaussianNoise(sigma=0.1, clip=0.3)

# Epsilon-greedy with custom decay
noise = EpsilonGreedy(n_actions=4, eps_start=1.0, eps_end=0.01, decay_fraction=0.2)
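
Because ExplorationStrategy is itself a protocol (select_action plus reset), you can supply your own. The sketch below assumes select_action receives the policy's proposed action as a NumPy array and returns the perturbed action; the real call signature may differ, so check rlox.exploration before using it.

import numpy as np

class LinearDecayGaussianNoise:
    """Gaussian noise whose scale anneals linearly (call signature assumed)."""

    def __init__(self, sigma_start=0.3, sigma_end=0.05, decay_steps=100_000):
        self.sigma_start = sigma_start
        self.sigma_end = sigma_end
        self.decay_steps = decay_steps
        self.t = 0

    def select_action(self, action: np.ndarray) -> np.ndarray:
        # Anneal sigma from sigma_start to sigma_end over decay_steps calls.
        frac = min(self.t / self.decay_steps, 1.0)
        sigma = self.sigma_start + frac * (self.sigma_end - self.sigma_start)
        self.t += 1
        return action + np.random.normal(0.0, sigma, size=action.shape)

    def reset(self) -> None:
        # Nothing episodic to reset for i.i.d. noise.
        pass

A custom strategy like this can be injected the same way as the built-ins, for example through the builder's exploration(...) call shown in the next example.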

Example 5: Builder Pattern

from rlox.builders import SACBuilder, PPOBuilder, DQNBuilder
from rlox.exploration import OUNoise
from rlox.callbacks import EvalCallback, ProgressBarCallback

# Simple SAC
sac = SACBuilder().env("Pendulum-v1").build()

# Customized SAC
sac = (SACBuilder()
    .env("Pendulum-v1")
    .learning_rate(1e-4)
    .hidden(256)
    .tau(0.01)
    .exploration(OUNoise(action_dim=1, sigma=0.2))
    .callbacks([EvalCallback(eval_freq=5000), ProgressBarCallback()])
    .compile(True)
    .build())

sac.train(total_timesteps=50_000)

# PPO with custom policy
from rlox.policies import DiscretePolicy
ppo = (PPOBuilder()
    .env("CartPole-v1")
    .policy(DiscretePolicy(obs_dim=4, n_actions=2, hidden=128))
    .n_envs(16)
    .n_steps(256)
    .learning_rate(2.5e-4)
    .compile(True)
    .build())

# DQN with Rainbow
dqn = (DQNBuilder()
    .env("CartPole-v1")
    .double_dqn(True)
    .dueling(True)
    .prioritized(True)
    .build())

Example 6: Composable Losses

import torch
from rlox.losses import LossComponent, CompositeLoss, PPOLoss

# Custom auxiliary loss
class RepresentationLoss(LossComponent):
    """Encourage diverse feature representations."""

    def __init__(self, coef=0.01):
        self.coef = coef

    def compute(self, **kwargs):
        obs = kwargs.get("obs")
        if obs is None:
            return torch.tensor(0.0), {}
        # Encourage high variance in features
        variance = obs.var(dim=0).mean()
        loss = -self.coef * variance  # maximize variance
        return loss, {"repr_var": variance.item()}

# A second auxiliary loss to compose alongside the main PPO loss
class KLPenalty(LossComponent):
    """KL divergence penalty against a reference policy."""

    def __init__(self, ref_policy, coef=0.1):
        self.ref_policy = ref_policy
        self.coef = coef

    def compute(self, **kwargs):
        obs = kwargs.get("obs")
        actions = kwargs.get("actions")
        if obs is None or actions is None:
            return torch.tensor(0.0), {}
        with torch.no_grad():
            ref_log_probs, _ = self.ref_policy.get_logprob_and_entropy(obs, actions)
        curr_log_probs = kwargs.get("log_probs", ref_log_probs)
        kl = (curr_log_probs - ref_log_probs).mean()
        return self.coef * kl, {"kl_penalty": kl.item()}

# Combine losses (ref_policy, obs_batch, and action_batch come from your training setup)
combined = CompositeLoss([
    (1.0, RepresentationLoss(coef=0.01)),
    (0.1, KLPenalty(ref_policy, coef=0.1)),
])
loss, metrics = combined.compute(obs=obs_batch, actions=action_batch)

Example 7: Custom Training Loop with rlox Components

import rlox
from rlox import RolloutCollector, PPOLoss, RolloutBatch
from rlox.policies import DiscretePolicy
import torch

# Mix and match rlox components in your own loop
policy = DiscretePolicy(obs_dim=4, n_actions=2)
collector = RolloutCollector(
    env_id="CartPole-v1", n_envs=8, seed=0,
    gamma=0.99, gae_lambda=0.95,
)
loss_fn = PPOLoss(clip_eps=0.2, vf_coef=0.5, ent_coef=0.01)
optimizer = torch.optim.Adam(policy.parameters(), lr=2.5e-4)

for update in range(100):
    # Collect with rlox (Rust VecEnv + batched GAE)
    batch = collector.collect(policy, n_steps=128)

    # Your custom training logic
    for epoch in range(4):
        for mb in batch.sample_minibatches(batch_size=256):
            adv = (mb.advantages - mb.advantages.mean()) / (mb.advantages.std() + 1e-8)
            loss, metrics = loss_fn(policy, mb.obs, mb.actions, mb.log_probs,
                                     adv, mb.returns, mb.values)

            # Add your custom loss terms here
            # loss = loss + 0.01 * my_custom_loss(mb)

            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(policy.parameters(), 0.5)
            optimizer.step()

    if update % 10 == 0:
        print(f"Update {update}: entropy={metrics['entropy']:.3f}")

Protocol Verification

You can check if your custom class satisfies a protocol at runtime:

from rlox.protocols import OnPolicyActor, StochasticActor

# Check conformance
assert isinstance(my_policy, OnPolicyActor), "Missing required methods!"

# This works with any class that has the right methods
class MinimalPolicy:
    def get_action_and_logprob(self, obs): ...
    def get_value(self, obs): ...
    def get_logprob_and_entropy(self, obs, actions): ...

assert isinstance(MinimalPolicy(), OnPolicyActor)  # True!
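
Conversely, a class that is missing any required method fails the check, which is how mis-specified components are caught before training starts:

class IncompletePolicy:
    def get_action_and_logprob(self, obs): ...
    # get_value and get_logprob_and_entropy are missing

assert not isinstance(IncompletePolicy(), OnPolicyActor)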