Decision Transformer¶
Intuition¶
Decision Transformer recasts reinforcement learning as a sequence modeling problem. Instead of learning value functions or policy gradients, it trains a causal transformer on (return-to-go, state, action) triples from offline data. At inference time, you condition on a desired return-to-go and the model autoregressively generates actions that aim to achieve that target return. This makes it a purely supervised-learning approach to RL, inheriting the scalability and stability of transformer training.
Key Equations¶
The model takes as input a sequence of interleaved tokens:

\[
\tau = (R_1, s_1, a_1, R_2, s_2, a_2, \ldots, R_T, s_T, a_T),
\]

where \(R_t = \sum_{t'=t}^{T} r_{t'}\) is the return-to-go at timestep \(t\).
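The return-to-go for every timestep of an episode can be computed in a single backward pass over the rewards. A minimal sketch (plain Python, not part of the rlox API):

```python
def returns_to_go(rewards):
    """Compute R_t = sum of rewards from timestep t to the end of the episode."""
    rtg = [0.0] * len(rewards)
    running = 0.0
    # Walk backwards so each R_t reuses the suffix sum R_{t+1} + r_t.
    for t in reversed(range(len(rewards))):
        running += rewards[t]
        rtg[t] = running
    return rtg

print(returns_to_go([1.0, 2.0, 3.0]))  # → [6.0, 5.0, 3.0]
```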
The training objective is standard supervised learning on actions:

\[
\mathcal{L}(\theta) = \mathbb{E}_{\tau \sim \mathcal{D}} \left[ \sum_{t} \ell\big(\hat{a}_t, a_t\big) \right],
\]

where \(\hat{a}_t\) is the action predicted from the tokens up to \(s_t\), and \(\ell\) is cross-entropy for discrete actions or mean-squared error (MSE) for continuous actions.
At inference, the return-to-go is initialized to a target value and decremented by each observed reward:

\[
\hat{R}_0 = R_{\text{target}}, \qquad \hat{R}_{t+1} = \hat{R}_t - r_t.
\]
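The decrement above is the only state the conditioning variable needs during a rollout. A sketch of that bookkeeping, with the environment and the trained transformer replaced by hypothetical stand-in callables (`step_env`, `predict_action` are not rlox names):

```python
def rollout_return_conditioning(step_env, predict_action, target_return, max_steps):
    """Track the return-to-go during a rollout: R_0 = target, R_{t+1} = R_t - r_t."""
    rtg = [target_return]                 # R_0 = target return
    for t in range(max_steps):
        action = predict_action(rtg)      # condition on the remaining return
        reward, done = step_env(action)
        rtg.append(rtg[-1] - reward)      # R_{t+1} = R_t - r_t
        if done:
            break
    return rtg

# With a constant reward of 1 per step, a target of 200.0 counts down:
rtg = rollout_return_conditioning(
    step_env=lambda a: (1.0, False),
    predict_action=lambda rtg: 0,
    target_return=200.0,
    max_steps=3,
)
print(rtg)  # → [200.0, 199.0, 198.0, 197.0]
```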
Pseudocode¶
```
algorithm Decision Transformer:
    collect offline dataset D = {(s_t, a_t, r_t)_t}
    compute return-to-go R_t for each timestep in each episode
    initialize transformer model with context_length K
    for update = 1 to n_updates do
        sample batch of subsequences from D
        for each subsequence:
            input = interleave(R_{t:t+K}, s_{t:t+K}, a_{t:t+K})
            a_pred = transformer(input) at state positions
            loss = cross_entropy(a_pred, a_true) or MSE(a_pred, a_true)
        update model with Adam

    # Inference
    set R_0 = target_return
    for t = 0, 1, ... do
        a_t = transformer.predict(R_{0:t}, s_{0:t}, a_{0:t-1})
        execute a_t, observe r_t, s_{t+1}
        R_{t+1} = R_t - r_t
```
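The `interleave` step above stacks the three modalities into one sequence of length 3K, with actions read out at the state positions. A NumPy sketch of the token layout, using raw values in place of learned embeddings for illustration:

```python
import numpy as np

def interleave(rtg, states, actions):
    """Arrange (R_t, s_t, a_t) triples as one token sequence:
    [R_0, s_0, a_0, R_1, s_1, a_1, ...], one row per token."""
    K, d = states.shape
    tokens = np.empty((3 * K, d), dtype=states.dtype)
    tokens[0::3] = rtg      # positions 0, 3, 6, ... hold returns-to-go
    tokens[1::3] = states   # positions 1, 4, 7, ... hold states
    tokens[2::3] = actions  # positions 2, 5, 8, ... hold actions
    return tokens

K, d = 4, 2
seq = interleave(np.zeros((K, d)), np.ones((K, d)), 2 * np.ones((K, d)))
print(seq.shape)   # → (12, 2)
print(seq[:3, 0])  # → [0. 1. 2.]  (one R, s, a triple)
```

In the real model each row would be an embedding vector, and the action head predicts \(\hat{a}_t\) from the transformer output at the state positions (indices 1, 4, 7, ...), so the prediction can attend to \(R_t\) and \(s_t\) but not \(a_t\).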
Quick Start¶
```python
from rlox import Trainer

trainer = Trainer("dt", env="CartPole-v1", seed=42, config={
    "context_length": 20,
    "target_return": 200.0,
})
metrics = trainer.train(total_timesteps=50_000)
print(f"Loss: {metrics['loss']:.4f}")
```
For continuous control:
```python
trainer = Trainer("dt", env="HalfCheetah-v4", seed=42, config={
    "context_length": 20,
    "embed_dim": 128,
    "n_heads": 4,
    "n_layers": 3,
    "batch_size": 64,
    "target_return": 6000.0,
})
metrics = trainer.train(total_timesteps=1_000_000)
```
Hyperparameters¶
| Parameter | Default | Description |
|---|---|---|
| `context_length` | 20 | Number of timesteps in the transformer context window |
| `n_heads` | 4 | Number of attention heads |
| `n_layers` | 3 | Number of transformer layers |
| `embed_dim` | 128 | Embedding dimension |
| `learning_rate` | 1e-4 | Adam learning rate |
| `batch_size` | 64 | Minibatch size for training |
| `dropout` | 0.1 | Dropout rate |
| `target_return` | 200.0 | Desired return-to-go for evaluation |
| `warmup_steps` | 500 | Learning-rate warmup steps (DT trains on a fixed offline dataset, so no environment data is collected during training) |
| `seed` | 42 | Random seed |
When to Use¶
- Use Decision Transformer when: you have a large offline dataset and want to leverage transformer architectures, or when you want return-conditioned behavior (generate policies of varying quality).
- Prefer DT over CQL/IQL when: the dataset is large and diverse, and you want to scale with model size rather than algorithmic complexity.
- Do not use DT when: online interaction is cheap (on-policy methods like PPO will be simpler), or the offline dataset is small (prefer IQL or TD3+BC).
References¶
- Chen, L., Lu, K., Rajeswaran, A., Lee, K., Grover, A., Laskin, M., Abbeel, P., Srinivas, A., & Mordatch, I. (2021). Decision Transformer: Reinforcement Learning via Sequence Modeling. NeurIPS 2021. arXiv:2106.01345.