Supervised Pretraining Can Learn In-Context Reinforcement Learning
Jonathan N. Lee, Annie Xie, Aldo Pacchiano, Yash Chandak, Chelsea Finn, Ofir Nachum, Emma Brunskill

TL;DR
This paper demonstrates that supervised pretraining of transformers on decision-making tasks enables in-context reinforcement learning, allowing models to adapt to new tasks, perform exploration, and achieve sample-efficient learning with theoretical guarantees.
Contribution
Introduces Decision-Pretrained Transformer (DPT), a supervised pretraining method in which a transformer learns to predict the optimal action from a query state and an in-context dataset of interactions, enabling in-context RL.
Findings
DPT can solve RL problems in-context, showing exploration and offline conservatism.
The model generalizes to new tasks beyond the pretraining distribution.
Theoretically, DPT is connected to Bayesian posterior sampling, which comes with regret guarantees.
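To make the posterior sampling finding concrete, here is a minimal sketch of classical Thompson sampling on a Bernoulli bandit. It is not DPT and not the authors' code; it only illustrates the kind of algorithm DPT is related to. The Bernoulli rewards, the uniform Beta(1, 1) prior, and the names `thompson_step` and `run_bandit` are assumptions made for this example.

```python
import random


def thompson_step(successes, failures, rng):
    """Draw one plausible mean per arm from its Beta posterior, then act greedily on the draw."""
    samples = [rng.betavariate(s + 1, f + 1) for s, f in zip(successes, failures)]
    return max(range(len(samples)), key=samples.__getitem__)


def run_bandit(true_means, horizon, seed=0):
    """Run Thompson sampling for `horizon` rounds and return the pull count of each arm."""
    rng = random.Random(seed)
    k = len(true_means)
    successes, failures, pulls = [0] * k, [0] * k, [0] * k
    for _ in range(horizon):
        arm = thompson_step(successes, failures, rng)
        # Bernoulli reward with the (hidden) true mean of the chosen arm.
        reward = 1 if rng.random() < true_means[arm] else 0
        successes[arm] += reward
        failures[arm] += 1 - reward
        pulls[arm] += 1
    return pulls


if __name__ == "__main__":
    print(run_bandit([0.1, 0.9], horizon=500))
```

Early on, the posteriors are wide, so both arms get sampled (exploration); as evidence accumulates, the draws concentrate and the better arm dominates.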
Abstract
Large transformer models trained on diverse datasets have shown a remarkable ability to learn in-context, achieving high few-shot performance on tasks they were not explicitly trained to solve. In this paper, we study the in-context learning capabilities of transformers in decision-making problems, i.e., reinforcement learning (RL) for bandits and Markov decision processes. To do so, we introduce and study Decision-Pretrained Transformer (DPT), a supervised pretraining method where the transformer predicts an optimal action given a query state and an in-context dataset of interactions, across a diverse set of tasks. This procedure, while simple, produces a model with several surprising capabilities. We find that the pretrained transformer can be used to solve a range of RL problems in-context, exhibiting both exploration online and conservatism offline, despite not being explicitly…
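The pretraining setup described in the abstract (predict an optimal action given a query state and an in-context dataset, across many tasks) can be sketched for a toy bandit as follows. This is an illustration under stated assumptions, not the paper's data pipeline: the uniform prior over arm means, the Gaussian reward noise, the uniformly random data-collection policy, and the names `make_pretraining_example` and `cross_entropy` are all hypothetical.

```python
import math
import random


def make_pretraining_example(num_arms, context_len, rng):
    """Build one supervised example: (in-context dataset, query state) -> optimal action."""
    # Sample a task. Assumption: arm means are drawn uniformly from [0, 1].
    means = [rng.random() for _ in range(num_arms)]

    # Collect an in-context dataset of interactions with a uniformly random policy (assumption).
    context = []
    for _ in range(context_len):
        action = rng.randrange(num_arms)
        reward = rng.gauss(means[action], 0.3)  # Gaussian noise is an assumption
        context.append((action, reward))

    # The supervised label is the optimal action for the sampled task.
    optimal_action = max(range(num_arms), key=means.__getitem__)

    # A bandit has a single state, so the query state is a placeholder here.
    return {"context": context, "query_state": 0, "optimal_action": optimal_action}


def cross_entropy(predicted_probs, label):
    """Supervised loss for one example: negative log-probability of the optimal action."""
    return -math.log(predicted_probs[label])


if __name__ == "__main__":
    example = make_pretraining_example(num_arms=5, context_len=20, rng=random.Random(0))
    uniform = [1.0 / 5] * 5
    print(example["optimal_action"], cross_entropy(uniform, example["optimal_action"]))
```

A model trained on many such examples would take the context and query state as input and output a distribution over actions; the loss above is what supervised pretraining would minimize in this sketch.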
Taxonomy
Topics: Advanced Bandit Algorithms Research · Data Stream Mining Techniques · Machine Learning and Data Classification
Methods: Attention Is All You Need · Convolution · Dense Connections · Dropout · Byte Pair Encoding · Softmax · Layer Normalization · Linear Layer · Position-Wise Feed-Forward Layer · Absolute Position Encodings
