Callbacks & Logging API¶
Callbacks¶
rlox.callbacks.Callback
¶
Base callback class. Override methods to hook into training events.
All on_* methods are no-ops by default. Override only the ones
you need. on_step should return True to continue training
or False to stop early.
on_training_start(**kwargs: Any) -> None
¶
Called once before the training loop begins.
on_step(**kwargs: Any) -> bool
¶
Called after each environment step. Return False to stop training.
on_rollout_end(**kwargs: Any) -> None
¶
Called after a complete rollout is collected.
on_train_batch(**kwargs: Any) -> None
¶
Called after each SGD minibatch update.
on_eval(**kwargs: Any) -> None
¶
Called after an evaluation episode completes.
on_training_end(**kwargs: Any) -> None
¶
Called once after the training loop finishes.
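A minimal sketch of a custom callback built on the hooks above. The `Callback` stand-in below mirrors the documented interface so the snippet is self-contained; in real code you would import `rlox.callbacks.Callback` instead. The `episode_reward` kwarg is illustrative, not part of the documented API.

```python
from typing import Any

# Stand-in mirroring the documented rlox.callbacks.Callback hooks.
# In real code: from rlox.callbacks import Callback
class Callback:
    def on_training_start(self, **kwargs: Any) -> None: ...
    def on_step(self, **kwargs: Any) -> bool:
        return True
    def on_training_end(self, **kwargs: Any) -> None: ...

class RewardThresholdCallback(Callback):
    """Stop training once an episode reward crosses a threshold.

    Assumes the training loop passes episode_reward in on_step() kwargs;
    that kwarg name is hypothetical.
    """
    def __init__(self, threshold: float) -> None:
        self.threshold = threshold
        self.best = float("-inf")

    def on_step(self, **kwargs: Any) -> bool:
        reward = kwargs.get("episode_reward")
        if reward is not None:
            self.best = max(self.best, reward)
        # Returning False stops training early, per the documented contract.
        return self.best < self.threshold

cb = RewardThresholdCallback(threshold=200.0)
print(cb.on_step(episode_reward=50.0))   # True: keep training
print(cb.on_step(episode_reward=250.0))  # False: stop early
```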
rlox.callbacks.CallbackList
¶
Run multiple callbacks in sequence.
rlox.callbacks.EvalCallback
¶
Bases: Callback
Periodic evaluation with optional best model saving.
Requires the training loop to pass algo=self in on_step() kwargs.
The algorithm must implement predict(obs, deterministic=True).
Parameters¶
eval_env : gymnasium.Env, optional
Environment for evaluation. If None, created from algo's env_id.
eval_freq : int
Evaluate every eval_freq steps (default 10000).
n_eval_episodes : int
Number of evaluation episodes (default 5).
best_model_path : str, optional
If set, save the best model weights to this path.
verbose : bool
Print evaluation results (default True).
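The core of what EvalCallback does every `eval_freq` steps can be sketched as a plain evaluation loop: roll out `n_eval_episodes` with deterministic actions and average the returns. This is an illustration of the documented `predict(obs, deterministic=True)` contract, not rlox's actual implementation; `DummyEnv` and `DummyAlgo` are hypothetical stand-ins.

```python
import statistics

def evaluate(algo, env, n_eval_episodes: int = 5) -> float:
    """Average return over deterministic evaluation episodes.

    `algo` must implement predict(obs, deterministic=True), matching
    the requirement EvalCallback documents.
    """
    returns = []
    for _ in range(n_eval_episodes):
        obs, _ = env.reset()
        done, ep_return = False, 0.0
        while not done:
            action = algo.predict(obs, deterministic=True)
            obs, reward, terminated, truncated, _ = env.step(action)
            ep_return += reward
            done = terminated or truncated
        returns.append(ep_return)
    return statistics.mean(returns)

# Hypothetical fixtures standing in for a gymnasium env and an algorithm.
class DummyEnv:
    def __init__(self): self.t = 0
    def reset(self): self.t = 0; return 0.0, {}
    def step(self, action):
        self.t += 1
        return float(self.t), 1.0, self.t >= 3, False, {}

class DummyAlgo:
    def predict(self, obs, deterministic=True): return 0

print(evaluate(DummyAlgo(), DummyEnv(), n_eval_episodes=2))  # 3.0
```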
rlox.callbacks.CheckpointCallback
¶
Bases: Callback
Periodic checkpoint saving.
Requires the training loop to pass algo=self in on_step() kwargs.
The algorithm must implement save(path).
Parameters¶
save_freq : int
Save a checkpoint every save_freq steps (default 10000).
save_path : str
Directory for checkpoint files (default "checkpoints").
verbose : bool
Print when saving (default True).
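A sketch of the checkpointing behavior described above, assuming only what the docs state: the loop passes `algo=self` to `on_step()`, and the algorithm implements `save(path)`. The filename pattern is invented for illustration.

```python
import os
from typing import Any

class CheckpointSketch:
    """Illustrative re-implementation of the documented behavior:
    call algo.save(path) every save_freq steps. Not rlox's actual code.
    """
    def __init__(self, save_freq: int = 10_000,
                 save_path: str = "checkpoints", verbose: bool = True):
        self.save_freq = save_freq
        self.save_path = save_path
        self.verbose = verbose
        self.n_calls = 0

    def on_step(self, **kwargs: Any) -> bool:
        self.n_calls += 1
        if self.n_calls % self.save_freq == 0:
            algo = kwargs["algo"]  # the training loop must pass algo=self
            # (a real implementation would also ensure save_path exists)
            path = os.path.join(self.save_path, f"ckpt_{self.n_calls}.pt")
            algo.save(path)
            if self.verbose:
                print(f"Saved checkpoint to {path}")
        return True  # never stops training on its own
```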
rlox.callbacks.TimingCallback
¶
rlox.callbacks.EarlyStoppingCallback
¶
Loggers¶
rlox.logging.LoggerCallback
¶
Base logger callback. Override methods to hook into training events.
Subclass this to implement custom logging (e.g. CSV, MLflow, etc.).
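As an example of the kind of subclass this enables, here is a CSV logger sketch. The hook name `log_metrics(metrics, step)` is hypothetical; check LoggerCallback itself for the actual methods to override.

```python
import csv
import io

class CsvLoggerSketch:
    """Sketch of a CSV logger in the spirit of LoggerCallback subclasses.

    Writes to an in-memory stream here so the example is self-contained;
    a real logger would write to a file. The log_metrics() hook name is
    an assumption, not the documented API.
    """
    def __init__(self, stream=None):
        self.stream = stream or io.StringIO()
        self._writer = None

    def log_metrics(self, metrics: dict, step: int) -> None:
        if self._writer is None:
            # Lazily build the header from the first metrics dict.
            self._writer = csv.DictWriter(
                self.stream, fieldnames=["step", *metrics])
            self._writer.writeheader()
        self._writer.writerow({"step": step, **metrics})
```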
rlox.logging.ConsoleLogger
¶
rlox.logging.WandbLogger
¶
rlox.logging.TensorBoardLogger
¶
Protocols¶
rlox.protocols.OnPolicyActor
¶
rlox.protocols.StochasticActor
¶
Bases: Protocol
Protocol for stochastic actors (SAC).
Any nn.Module implementing sample() and deterministic() can be used as a SAC actor.
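The "any nn.Module with the right methods" idea is structural typing. The sketch below reconstructs the protocol's shape with `typing.Protocol` to show that no inheritance is needed; the real protocol lives in `rlox.protocols`, its exact signatures may differ, and the torch `nn.Module` base is omitted here to keep the example dependency-free.

```python
import random
from typing import Protocol, runtime_checkable

@runtime_checkable
class StochasticActorLike(Protocol):
    """Hypothetical reconstruction of the StochasticActor shape."""
    def sample(self, obs): ...
    def deterministic(self, obs): ...

class GaussianActor:  # no inheritance needed: structural typing
    def sample(self, obs):
        # Stochastic action: mean plus Gaussian noise.
        return obs + random.gauss(0.0, 0.1)

    def deterministic(self, obs):
        # Deterministic action: just the mean.
        return obs

print(isinstance(GaussianActor(), StochasticActorLike))  # True
```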
rlox.protocols.DeterministicActor
¶
rlox.protocols.ExplorationStrategy
¶
rlox.protocols.ReplayBufferProtocol
¶
Bases: Protocol
Protocol for replay buffers.
Any buffer implementing push/sample/len can be used with off-policy algorithms.
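A toy buffer satisfying the documented push/sample/len contract, to show how small a conforming implementation can be. The tuple transition layout is illustrative; rlox's built-in buffers may store fields differently.

```python
import random
from collections import deque

class MinimalReplayBuffer:
    """Toy buffer implementing push(), sample(), and __len__().

    Uses a bounded deque so old transitions are evicted automatically
    once capacity is reached.
    """
    def __init__(self, capacity: int = 10_000):
        self.storage = deque(maxlen=capacity)

    def push(self, transition) -> None:
        self.storage.append(transition)

    def sample(self, batch_size: int):
        # Uniform sampling without replacement.
        return random.sample(self.storage, batch_size)

    def __len__(self) -> int:
        return len(self.storage)
```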
Offline RL¶
rlox.offline.base.OfflineAlgorithm
¶
Base class for offline RL algorithms.
Handles the shared training loop: sample → update → log → callback.
Subclasses implement _update() with algorithm-specific logic.
Parameters¶
dataset : OfflineDataset
Offline dataset to sample from.
batch_size : int
Minibatch size for SGD.
callbacks : list[Callback], optional
Training callbacks.
logger : LoggerCallback, optional
Logger for metrics.
train(n_gradient_steps: int) -> dict[str, float]
¶
Run offline training for n_gradient_steps.
Returns metrics from the last update step.
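The shared loop can be sketched as below. Only the sample → update → log → callback ordering, the `_update()` subclass responsibility, and the `sample(batch_size, seed)` dataset signature come from the docs; the logger's `log_metrics()` hook and the callback kwargs are hypothetical.

```python
class OfflineLoopSketch:
    """Illustrative version of the documented offline training loop.

    Subclasses implement _update() with algorithm-specific logic,
    as in rlox; everything else here is a simplified stand-in.
    """
    def __init__(self, dataset, batch_size, callbacks=(), logger=None):
        self.dataset = dataset
        self.batch_size = batch_size
        self.callbacks = list(callbacks)
        self.logger = logger

    def _update(self, batch) -> dict:
        raise NotImplementedError  # subclass responsibility

    def train(self, n_gradient_steps: int) -> dict:
        metrics = {}
        for step in range(n_gradient_steps):
            batch = self.dataset.sample(self.batch_size, seed=step)  # sample
            metrics = self._update(batch)                            # update
            if self.logger is not None:
                self.logger.log_metrics(metrics, step)               # log
            for cb in self.callbacks:                                # callback
                if cb.on_step(metrics=metrics, step=step) is False:
                    return metrics  # a callback requested early stop
        return metrics  # metrics from the last update
```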
rlox.offline.base.OfflineDataset
¶
Bases: Protocol
Protocol for offline datasets — users can bring their own.
Any object with sample(batch_size, seed) and __len__() works.
The built-in rlox.OfflineDatasetBuffer satisfies this protocol.
sample(batch_size: int, seed: int) -> dict[str, np.ndarray]
¶
Sample a batch of transitions.
Returns dict with keys: obs, next_obs, actions, rewards, terminated.
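A minimal in-memory object satisfying this protocol, using only the documented surface: `sample(batch_size, seed)`, `__len__()`, and the five batch keys. The constructor layout is illustrative.

```python
import numpy as np

class InMemoryDataset:
    """Minimal object satisfying the OfflineDataset protocol."""
    def __init__(self, obs, next_obs, actions, rewards, terminated):
        self.data = {
            "obs": np.asarray(obs),
            "next_obs": np.asarray(next_obs),
            "actions": np.asarray(actions),
            "rewards": np.asarray(rewards),
            "terminated": np.asarray(terminated),
        }

    def sample(self, batch_size: int, seed: int) -> dict:
        # Seeded sampling with replacement, so batches are reproducible.
        rng = np.random.default_rng(seed)
        idx = rng.integers(0, len(self), size=batch_size)
        return {k: v[idx] for k, v in self.data.items()}

    def __len__(self) -> int:
        return len(self.data["obs"])
```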