Callbacks & Logging API

Callbacks

rlox.callbacks.Callback

Base callback class. Override methods to hook into training events.

All on_* methods are no-ops by default. Override only the ones you need. on_step should return True to continue training or False to stop early.

on_training_start(**kwargs: Any) -> None

Called once before the training loop begins.

on_step(**kwargs: Any) -> bool

Called after each environment step. Return False to stop training.

on_rollout_end(**kwargs: Any) -> None

Called after a complete rollout is collected.

on_train_batch(**kwargs: Any) -> None

Called after each SGD minibatch update.

on_eval(**kwargs: Any) -> None

Called after an evaluation episode completes.

on_training_end(**kwargs: Any) -> None

Called once after the training loop finishes.

rlox.callbacks.CallbackList

Run multiple callbacks in sequence.

rlox.callbacks.EvalCallback

Bases: Callback

Periodic evaluation with optional best model saving.

Requires the training loop to pass algo=self in on_step() kwargs. The algorithm must implement predict(obs, deterministic=True).

Parameters

eval_env : gymnasium.Env, optional
    Environment for evaluation. If None, created from algo's env_id.
eval_freq : int
    Evaluate every eval_freq steps (default 10000).
n_eval_episodes : int
    Number of evaluation episodes (default 5).
best_model_path : str, optional
    If set, save the best model weights to this path.
verbose : bool
    Print evaluation results (default True).

rlox.callbacks.CheckpointCallback

Bases: Callback

Periodic checkpoint saving.

Requires the training loop to pass algo=self in on_step() kwargs. The algorithm must implement save(path).

Parameters

save_freq : int
    Save a checkpoint every save_freq steps (default 10000).
save_path : str
    Directory for checkpoint files (default "checkpoints").
verbose : bool
    Print when saving (default True).

rlox.callbacks.ProgressBarCallback

Bases: Callback

Display a tqdm progress bar during training.

rlox.callbacks.TimingCallback

Bases: Callback

Measure wall-clock time of each training phase.

summary() -> dict[str, float]

Return percentage of time spent in each phase.

rlox.callbacks.EarlyStoppingCallback

Bases: Callback

Stop training when reward plateaus.

Parameters

patience : int
    Number of steps without improvement before stopping (default 10).
min_delta : float
    Minimum improvement to count as progress (default 0.0).

Loggers

rlox.logging.LoggerCallback

Base logger callback. Override methods to hook into training events.

Subclass this to implement custom logging (e.g. CSV or MLflow).

rlox.logging.ConsoleLogger

Bases: LoggerCallback

Lightweight logger that prints training progress to stdout.

rlox.logging.WandbLogger

Bases: LoggerCallback

Weights & Biases logger (lazy import).

rlox.logging.TensorBoardLogger

Bases: LoggerCallback

TensorBoard logger (lazy import).

Protocols

rlox.protocols.OnPolicyActor

Bases: Protocol

Protocol for on-policy actor-critics (PPO, A2C).

Any nn.Module implementing these three methods can be used as a PPO/A2C policy.

get_action_and_logprob(obs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]

Sample actions and compute log-probabilities.

Parameters

obs : (B, obs_dim) tensor

Returns

actions : (B,) or (B, act_dim) tensor
log_probs : (B,) tensor

get_value(obs: torch.Tensor) -> torch.Tensor

Compute value estimates.

Parameters

obs : (B, obs_dim) tensor

Returns

values : (B,) tensor

get_logprob_and_entropy(obs: torch.Tensor, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]

Evaluate log-probability and entropy for given actions.

Parameters

obs : (B, obs_dim) tensor
actions : (B,) or (B, act_dim) tensor

Returns

log_probs : (B,) tensor
entropy : (B,) tensor
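Any nn.Module with these three methods satisfies the protocol. A minimal sketch for a discrete action space (the class name, sizes, and architecture here are illustrative, not part of rlox):

```python
import torch
import torch.nn as nn
from torch.distributions import Categorical

class MLPActorCritic(nn.Module):
    """Small shared-body actor-critic satisfying OnPolicyActor."""

    def __init__(self, obs_dim: int, n_actions: int, hidden: int = 64):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(obs_dim, hidden), nn.Tanh())
        self.pi = nn.Linear(hidden, n_actions)  # policy logits
        self.v = nn.Linear(hidden, 1)           # value head

    def get_action_and_logprob(self, obs):
        dist = Categorical(logits=self.pi(self.body(obs)))
        actions = dist.sample()
        return actions, dist.log_prob(actions)

    def get_value(self, obs):
        return self.v(self.body(obs)).squeeze(-1)  # (B, 1) -> (B,)

    def get_logprob_and_entropy(self, obs, actions):
        dist = Categorical(logits=self.pi(self.body(obs)))
        return dist.log_prob(actions), dist.entropy()
```

For continuous actions, the same structure applies with a Normal distribution and a (B, act_dim) action head.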

rlox.protocols.StochasticActor

Bases: Protocol

Protocol for stochastic actors (SAC).

Any nn.Module implementing sample() and deterministic() can be used as a SAC actor.

sample(obs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]

Sample actions with reparameterization.

Returns

actions : (B, act_dim) tensor, in [-1, 1] (pre-scaling)
log_probs : (B,) tensor

deterministic(obs: torch.Tensor) -> torch.Tensor

Return deterministic (mean) actions.

Returns

actions : (B, act_dim) tensor
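A common way to satisfy this protocol is a tanh-squashed Gaussian policy. A minimal sketch (class name, sizes, and clamping constants are illustrative, not part of rlox):

```python
import torch
import torch.nn as nn

class TanhGaussianActor(nn.Module):
    """Tanh-squashed Gaussian actor satisfying StochasticActor."""

    def __init__(self, obs_dim: int, act_dim: int, hidden: int = 64):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(obs_dim, hidden), nn.ReLU())
        self.mu = nn.Linear(hidden, act_dim)
        self.log_std = nn.Linear(hidden, act_dim)

    def sample(self, obs):
        h = self.body(obs)
        mu, log_std = self.mu(h), self.log_std(h).clamp(-20, 2)
        dist = torch.distributions.Normal(mu, log_std.exp())
        u = dist.rsample()           # reparameterized sample (gradients flow)
        actions = torch.tanh(u)      # squash to [-1, 1]
        # tanh change-of-variables correction, summed over action dims
        log_probs = (dist.log_prob(u)
                     - torch.log1p(-actions.pow(2) + 1e-6)).sum(-1)
        return actions, log_probs

    def deterministic(self, obs):
        return torch.tanh(self.mu(self.body(obs)))  # mean action
```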

rlox.protocols.DeterministicActor

Bases: Protocol

Protocol for deterministic actors (TD3).

Any nn.Module implementing forward() can be used as a TD3 actor.

forward(obs: torch.Tensor) -> torch.Tensor

Return deterministic actions.

Returns

actions : (B, act_dim) tensor
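Since only forward() is required, a plain MLP with a tanh output layer suffices. A minimal sketch (name and sizes are illustrative, not part of rlox):

```python
import torch
import torch.nn as nn

class TD3Actor(nn.Module):
    """Deterministic MLP actor satisfying DeterministicActor."""

    def __init__(self, obs_dim: int, act_dim: int, hidden: int = 64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, act_dim), nn.Tanh(),  # bound actions to [-1, 1]
        )

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        return self.net(obs)
```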

rlox.protocols.ExplorationStrategy

Bases: Protocol

Protocol for exploration strategies.

Used by off-policy algorithms to add noise to actions.

select_action(action: np.ndarray, step: int, total_steps: int) -> np.ndarray

Add exploration noise to an action.

Parameters

action : raw action from the policy
step : current training step
total_steps : total training steps

Returns

noisy_action : action with exploration noise

reset() -> None

Reset any internal state (e.g., OU noise).
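A hypothetical Gaussian-noise strategy satisfying this protocol (the class name, decay schedule, and action bounds are illustrative, not part of rlox):

```python
import numpy as np

class GaussianNoise:
    """Gaussian action noise with linear sigma decay over training."""

    def __init__(self, sigma_start: float = 0.3, sigma_end: float = 0.05,
                 low: float = -1.0, high: float = 1.0):
        self.sigma_start, self.sigma_end = sigma_start, sigma_end
        self.low, self.high = low, high

    def select_action(self, action: np.ndarray, step: int,
                      total_steps: int) -> np.ndarray:
        # Linearly anneal sigma from sigma_start to sigma_end
        frac = min(step / max(total_steps, 1), 1.0)
        sigma = self.sigma_start + frac * (self.sigma_end - self.sigma_start)
        noisy = action + np.random.normal(0.0, sigma, size=action.shape)
        return np.clip(noisy, self.low, self.high)

    def reset(self) -> None:
        pass  # stateless; nothing to reset between episodes
```

Stateful strategies such as OU noise would carry their process state across select_action calls and clear it in reset().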

rlox.protocols.ReplayBufferProtocol

Bases: Protocol

Protocol for replay buffers.

Any buffer implementing push/sample/len can be used with off-policy algorithms.

push(*args: Any, **kwargs: Any) -> None

Store a transition.

sample(batch_size: int, seed: int) -> dict[str, Any]

Sample a batch of transitions.

__len__() -> int

Return number of stored transitions.
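A minimal ring-buffer implementation of this protocol (the class name and the exact keys of the sampled dict are illustrative, not part of rlox):

```python
import numpy as np

class SimpleReplayBuffer:
    """Fixed-capacity ring buffer satisfying ReplayBufferProtocol."""

    def __init__(self, capacity: int):
        self.capacity = capacity
        self.storage: list = []
        self.pos = 0  # next write position once full

    def push(self, obs, action, reward, next_obs, terminated) -> None:
        item = (obs, action, reward, next_obs, terminated)
        if len(self.storage) < self.capacity:
            self.storage.append(item)
        else:
            self.storage[self.pos] = item  # overwrite oldest transition
        self.pos = (self.pos + 1) % self.capacity

    def sample(self, batch_size: int, seed: int) -> dict:
        rng = np.random.default_rng(seed)
        idx = rng.integers(len(self.storage), size=batch_size)
        obs, act, rew, nobs, term = zip(*(self.storage[i] for i in idx))
        return {"obs": np.array(obs), "actions": np.array(act),
                "rewards": np.array(rew), "next_obs": np.array(nobs),
                "terminated": np.array(term)}

    def __len__(self) -> int:
        return len(self.storage)
```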

Offline RL

rlox.offline.base.OfflineAlgorithm

Base class for offline RL algorithms.

Handles the shared training loop: sample → update → log → callback. Subclasses implement _update() with algorithm-specific logic.

Parameters

dataset : OfflineDataset
    Offline dataset to sample from.
batch_size : int
    Minibatch size for SGD.
callbacks : list[Callback], optional
    Training callbacks.
logger : LoggerCallback, optional
    Logger for metrics.

train(n_gradient_steps: int) -> dict[str, float]

Run offline training for n_gradient_steps.

Returns metrics from the last update step.

rlox.offline.base.OfflineDataset

Bases: Protocol

Protocol for offline datasets — users can bring their own.

Any object with sample(batch_size, seed) and __len__() works. The built-in rlox.OfflineDatasetBuffer satisfies this protocol.

sample(batch_size: int, seed: int) -> dict[str, np.ndarray]

Sample a batch of transitions.

Returns dict with keys: obs, next_obs, actions, rewards, terminated.
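A bring-your-own dataset can be as simple as a wrapper over in-memory arrays. A sketch satisfying the protocol (the class name is illustrative, not part of rlox; the dict keys follow the protocol above):

```python
import numpy as np

class ArrayDataset:
    """In-memory offline dataset satisfying the OfflineDataset protocol."""

    def __init__(self, obs, actions, rewards, next_obs, terminated):
        self.data = {"obs": obs, "next_obs": next_obs, "actions": actions,
                     "rewards": rewards, "terminated": terminated}
        self.n = len(obs)

    def sample(self, batch_size: int, seed: int) -> dict:
        # Seeded generator makes sampling reproducible per gradient step
        idx = np.random.default_rng(seed).integers(self.n, size=batch_size)
        return {k: v[idx] for k, v in self.data.items()}

    def __len__(self) -> int:
        return self.n
```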