Transformers learn in-context by gradient descent
Johannes von Oswald, Eyvind Niklasson, Ettore Randazzo, João Sacramento, Alexander Mordvintsev, Andrey Zhmoginov, Max Vladymyrov

TL;DR
This paper shows that Transformers can learn in-context by effectively performing gradient descent in their forward pass. This gives a mechanistic account of in-context learning, at least for regression tasks.
Contribution
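It demonstrates that the data transformation induced by a single linear self-attention layer is equivalent to one step of gradient descent on a regression loss. It also shows that trained Transformers act as mesa-optimizers, i.e., they learn models by gradient descent in their forward pass.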
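The equivalence can be written out explicitly. What follows is a sketch of the single-layer construction as we understand it from the paper. The notation is ours: tokens e_i = (x_i, y_i), projections W_K, W_Q, W_V, output projection P, initial weights W_0, and learning rate η. The layer is a softmax-free (linear) self-attention layer whose sum runs over the N context tokens.

```latex
% Regression loss and one GD step starting from W_0
L(W) = \frac{1}{2N}\sum_{i=1}^{N}\lVert W x_i - y_i\rVert^2, \qquad
\Delta W = -\eta\,\nabla_W L(W_0) = -\frac{\eta}{N}\sum_{i=1}^{N}(W_0 x_i - y_i)\,x_i^\top

% Linear self-attention update of token e_j (sum over context tokens only)
e_j \leftarrow e_j + P\sum_{i=1}^{N} W_V e_i\,(W_K e_i)^\top W_Q e_j

% Weight construction
W_K = W_Q = \begin{pmatrix} I_x & 0 \\ 0 & 0 \end{pmatrix}, \qquad
W_V = \begin{pmatrix} 0 & 0 \\ W_0 & -I_y \end{pmatrix}, \qquad
P = \frac{\eta}{N}\, I

% Resulting update: (W_K e_i)^\top W_Q e_j = x_i^\top x_j and W_V e_i = (0,\; W_0 x_i - y_i)
e_j \leftarrow \begin{pmatrix} x_j \\ y_j + \frac{\eta}{N}\sum_{i=1}^{N}(W_0 x_i - y_i)\,x_i^\top x_j \end{pmatrix}
      = \begin{pmatrix} x_j \\ y_j - \Delta W\, x_j \end{pmatrix}
```

The transformed target satisfies W_0 x_j − (y_j − ΔW x_j) = (W_0 + ΔW) x_j − y_j. The loss of W_0 on the transformed tokens therefore equals the loss of the GD-updated model W_0 + ΔW on the original data. With W_0 = 0 and the query token's target slot set to zero, that slot becomes −ΔW x_test, whose negation is the one-step GD prediction.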
Findings
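A linear self-attention layer can be constructed so that its forward pass performs a gradient-descent step on a regression loss.
Trained Transformers learn an iterative curvature correction that lets them outperform plain gradient descent on regression tasks.
The study draws parallels between in-context learning by gradient descent and induction heads, a mechanism previously identified as crucial for in-context learning.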
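The curvature-correction finding can be illustrated in isolation. GD on a least-squares loss converges at a rate governed by the condition number of the input covariance C = (1/N) Σ x_i x_i^T, so a layer that shrinks this condition number speeds up later GD steps. The sketch below assumes the correction acts on the inputs as x ← (I − γC) x. This is our reading of the paper's GD++ variant, not its verified implementation. The function names, the anisotropic test data, and the choice γ = 1/(3 λ_max) are hypothetical and ours. That value of γ keeps the map λ ↦ λ(1 − γλ)² increasing over the spectrum, so each step provably lowers the condition number.

```python
import numpy as np


def input_covariance(X):
    """Empirical input covariance C = (1/N) sum_i x_i x_i^T for the rows x_i of X."""
    return X.T @ X / X.shape[0]


def condition_number(X):
    """Ratio of the largest to the smallest eigenvalue of the input covariance."""
    eig = np.linalg.eigvalsh(input_covariance(X))
    return eig.max() / eig.min()


def curvature_correction(X, gamma):
    """Hypothetical GD++-style step (our assumption): map every input x to (I - gamma * C) x."""
    d = X.shape[1]
    A = np.eye(d) - gamma * input_covariance(X)  # symmetric, so the rows transform as X @ A
    return X @ A


rng = np.random.default_rng(1)
N, d = 200, 5
# Anisotropic Gaussian inputs give an ill-conditioned covariance.
X = rng.normal(size=(N, d)) * np.array([1.0, 0.8, 0.5, 0.3, 0.1])

conds = [condition_number(X)]
for _ in range(3):
    lam_max = np.linalg.eigvalsh(input_covariance(X)).max()
    # gamma = 1 / (3 * lam_max) keeps lambda * (1 - gamma * lambda)^2 increasing on the spectrum
    X = curvature_correction(X, gamma=1.0 / (3.0 * lam_max))
    conds.append(condition_number(X))

print(conds)
```

Each step maps an eigenvalue λ of C to λ(1 − γλ)². This compresses the large eigenvalues more than the small ones, so the printed list of condition numbers is strictly decreasing.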
Abstract
At present, the mechanisms of in-context learning in Transformers are not well understood and remain mostly an intuition. In this paper, we suggest that training Transformers on auto-regressive objectives is closely related to gradient-based meta-learning formulations. We start by providing a simple weight construction that shows the equivalence of data transformations induced by 1) a single linear self-attention layer and by 2) gradient-descent (GD) on a regression loss. Motivated by that construction, we show empirically that when training self-attention-only Transformers on simple regression tasks either the models learned by GD and Transformers show great similarity or, remarkably, the weights found by optimization match the construction. Thus we show how trained Transformers become mesa-optimizers, i.e., learn models by gradient descent in their forward pass. This allows us, at least…
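The weight construction mentioned in the abstract can be checked numerically. Below is a minimal NumPy sketch under several assumptions. The initial weights are W_0 = 0. The attention layer is softmax-free, has a residual connection, and lets every token attend only to the N context tokens. The prediction is read out as the negated target slot of the query token. Under these assumptions, stacking the same layer several times reproduces the same number of GD steps on the original loss, which is how we read the multi-layer extension. All helper names (`lsa_layer`, `gd_weights`, `gd_predict`), the data distribution, and the step size are hypothetical choices for illustration.

```python
import numpy as np


def gd_predict(X, Y, x_test, eta, steps):
    """Plain GD on L(W) = 1/(2N) sum_i ||W x_i - y_i||^2 from W = 0; returns W x_test."""
    N = X.shape[0]
    W = np.zeros((Y.shape[1], X.shape[1]))
    for _ in range(steps):
        grad = (W @ X.T - Y.T) @ X / N  # (1/N) sum_i (W x_i - y_i) x_i^T
        W = W - eta * grad
    return W @ x_test


def lsa_layer(E, n_ctx, W_K, W_Q, W_V, P):
    """Softmax-free self-attention with a residual; every token attends to the first n_ctx tokens."""
    keys = E[:n_ctx] @ W_K.T     # rows: W_K e_i
    values = E[:n_ctx] @ W_V.T   # rows: W_V e_i
    queries = E @ W_Q.T          # rows: W_Q e_j
    # Row j of the update is sum_i (W_K e_i . W_Q e_j) * P W_V e_i
    update = (queries @ keys.T) @ values @ P.T
    return E + update


def gd_weights(d_x, d_y, eta, N):
    """Hypothetical helper building the construction with W_0 = 0."""
    d = d_x + d_y
    W_K = np.zeros((d, d))
    W_K[:d_x, :d_x] = np.eye(d_x)      # keys and queries read only the x part
    W_Q = W_K.copy()
    W_V = np.zeros((d, d))
    W_V[d_x:, d_x:] = -np.eye(d_y)     # values carry -y_i (the W_0 x_i term vanishes)
    P = (eta / N) * np.eye(d)
    return W_K, W_Q, W_V, P


rng = np.random.default_rng(0)
N, d_x, d_y, eta, steps = 20, 5, 1, 0.1, 3
X = rng.uniform(-1, 1, size=(N, d_x))
W_true = rng.normal(size=(d_y, d_x))
Y = X @ W_true.T
x_test = rng.uniform(-1, 1, size=d_x)

# Context tokens e_i = (x_i, y_i); the query token carries a zero target slot.
E = np.vstack([np.hstack([X, Y]), np.hstack([x_test, np.zeros(d_y)])])
W_K, W_Q, W_V, P = gd_weights(d_x, d_y, eta, N)
for _ in range(steps):
    E = lsa_layer(E, N, W_K, W_Q, W_V, P)

y_transformer = -E[-1, d_x:]  # negated target slot of the query token
y_gd = gd_predict(X, Y, x_test, eta, steps)
print(np.allclose(y_transformer, y_gd))
```

The check prints `True`. Because the layer has no softmax, each update is exactly bilinear in the tokens, and this is why the match is exact up to floating-point error rather than approximate.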
Taxonomy
Topics: Domain Adaptation and Few-Shot Learning · Generative Adversarial Networks and Image Synthesis · Advanced Neural Network Applications
