One Step of Gradient Descent is Provably the Optimal In-Context Learner with One Layer of Linear Self-Attention
Arvind Mahankali, Tatsunori B. Hashimoto, Tengyu Ma

TL;DR
This paper provides a theoretical analysis showing that a one-layer transformer with linear self-attention, trained on linear regression data, learns to perform one step of gradient descent, and that the form of the learned algorithm depends on the data distribution.
Contribution
It offers the first rigorous proof that the one-layer linear self-attention transformer minimizing the pre-training loss implements a single step of gradient descent, and it examines how the data distribution shapes the learned algorithm.
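The claim that a linear self-attention layer can implement one GD step can be checked numerically. This is a minimal sketch, not the authors' parametrization. It assumes context tokens z_i = (x_i, y_i), a query token z_q = (x_q, 0), a softmax-free readout equal to the last coordinate of W_PV · (1/n) Σ_i z_i z_i^T · W_KQ · z_q, and the least-squares loss L(w) = (1/2n) Σ_i (w^T x_i - y_i)^2 with initialization w_0 = 0. The function names and the matrices W_KQ and W_PV are hypothetical. When W_KQ holds ηI in its top-left block and W_PV reads out only the label coordinate, the attention prediction equals x_q^T w_1 with w_1 = (η/n) Σ_i y_i x_i.

```python
import numpy as np


def one_step_gd_prediction(X, y, x_q, eta):
    """Prediction x_q @ w1 after one GD step from w0 = 0 on L(w) = (1/2n) * sum_i (w @ x_i - y_i)**2."""
    n, d = X.shape
    w0 = np.zeros(d)
    grad = X.T @ (X @ w0 - y) / n  # gradient of L at w0; equals -(1/n) * X.T @ y
    w1 = w0 - eta * grad
    return x_q @ w1


def linear_self_attention_prediction(X, y, x_q, W_KQ, W_PV):
    """Hypothetical one-layer linear self-attention readout (no softmax, no MLP).

    Context tokens are z_i = (x_i, y_i); the query token is z_q = (x_q, 0).
    The prediction is the last coordinate of W_PV @ ((1/n) * sum_i z_i z_i^T) @ W_KQ @ z_q.
    """
    n = X.shape[0]
    Z = np.hstack([X, y[:, None]])  # n x (d + 1) matrix of context tokens
    z_q = np.append(x_q, 0.0)  # query token with its unknown label set to 0
    out = W_PV @ (Z.T @ Z / n) @ W_KQ @ z_q
    return out[-1]


# Illustrative data: noisy linear regression with standard Gaussian covariates.
rng = np.random.default_rng(0)
n, d, eta = 32, 4, 0.5
X = rng.standard_normal((n, d))
w_star = rng.standard_normal(d)
y = X @ w_star + 0.1 * rng.standard_normal(n)
x_q = rng.standard_normal(d)

# Hypothetical weights: W_KQ scales the x-block by eta, W_PV reads out only the label coordinate.
W_KQ = np.zeros((d + 1, d + 1))
W_KQ[:d, :d] = eta * np.eye(d)
W_PV = np.zeros((d + 1, d + 1))
W_PV[d, d] = 1.0

print(np.isclose(one_step_gd_prediction(X, y, x_q, eta),
                 linear_self_attention_prediction(X, y, x_q, W_KQ, W_PV)))  # True
```

Because w_1 depends on the context only through (1/n) Σ_i y_i x_i, a single bilinear attention readout is enough to express it; no softmax or MLP is needed for this particular construction.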
Findings
A one-layer transformer with linear self-attention learns to implement one step of GD.
The covariate distribution determines whether the learned GD step is preconditioned.
Changing only the response distribution has limited impact on the learned algorithm.
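To see why the covariate distribution matters, note that with w_0 = 0 one plain GD step gives w_1 = (η/n) X^T y, whose expectation is η Σ w* when the covariates are zero-mean with covariance Σ and the noise has zero mean. For non-isotropic Σ this distorts w*, and a preconditioner Γ can correct the distortion. The sketch below assumes a diagonal Σ and uses Γ = Σ^{-1} purely as an illustrative choice; it is not the optimal preconditioner derived in the paper. The function name and all constants are hypothetical.

```python
import numpy as np


def one_step_preconditioned_gd(X, y, Gamma):
    """w1 = w0 - Gamma @ grad L(w0) with w0 = 0 and L(w) = (1/2n) * ||X @ w - y||**2.

    Gamma = eta * I recovers plain one-step GD.
    """
    n = X.shape[0]
    grad_at_zero = -(X.T @ y) / n  # gradient of L at w0 = 0
    return -Gamma @ grad_at_zero


rng = np.random.default_rng(0)
n, d = 100_000, 3
Sigma = np.diag([4.0, 1.0, 0.25])  # assumed non-isotropic covariate covariance
X = rng.standard_normal((n, d)) * np.sqrt(np.diag(Sigma))  # rows ~ N(0, Sigma)
w_star = np.array([1.0, -2.0, 0.5])
y = X @ w_star + 0.1 * rng.standard_normal(n)

w_plain = one_step_preconditioned_gd(X, y, np.eye(d))  # close to Sigma @ w_star = [4, -2, 0.125]
w_pre = one_step_preconditioned_gd(X, y, np.linalg.inv(Sigma))  # close to w_star
print(np.round(w_plain, 2), np.round(w_pre, 2))
```

With isotropic covariates (Σ = I) the two choices coincide up to the step size, which is consistent with plain GD being learned in the standard Gaussian case.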
Abstract
Recent works have empirically analyzed in-context learning and shown that transformers trained on synthetic linear regression tasks can learn to implement ridge regression, which is the Bayes-optimal predictor, given sufficient capacity [Akyürek et al., 2023], while one-layer transformers with linear self-attention and no MLP layer will learn to implement one step of gradient descent (GD) on a least-squares linear regression objective [von Oswald et al., 2022]. However, the theory behind these observations remains poorly understood. We theoretically study transformers with a single layer of linear self-attention, trained on synthetic noisy linear regression data. First, we mathematically show that when the covariates are drawn from a standard Gaussian distribution, the one-layer transformer which minimizes the pre-training loss will implement a single step of GD on the least-squares…
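The abstract contrasts ridge regression, the Bayes-optimal predictor for this task, with one step of GD. As a hedged illustration: if w* ~ N(0, I) and the noise is N(0, σ^2), the posterior mean of w* is ridge regression with λ = σ^2, while one GD step from w_0 = 0 with η = 1 gives (1/n) X^T y. For standard Gaussian covariates X^T X / n approaches I as n grows, so the two estimators get closer, whereas for small n they can differ noticeably. The prior, σ, η = 1, and the function names are assumptions for this sketch, not the paper's setup or its optimal step size.

```python
import numpy as np


def ridge_weights(X, y, lam):
    """Ridge regression: argmin_w ||X @ w - y||**2 + lam * ||w||**2."""
    d = X.shape[1]
    return np.linalg.solve(X.T @ X + lam * np.eye(d), X.T @ y)


def one_step_gd_weights(X, y, eta):
    """One GD step from w0 = 0 on L(w) = (1/2n) * ||X @ w - y||**2."""
    return eta * (X.T @ y) / X.shape[0]


rng = np.random.default_rng(1)
d, sigma = 5, 0.5
w_star = rng.standard_normal(d)  # one draw from the assumed prior N(0, I)

rel_gaps = {}
for n in (10, 100, 10_000):
    X = rng.standard_normal((n, d))  # standard Gaussian covariates
    y = X @ w_star + sigma * rng.standard_normal(n)
    w_ridge = ridge_weights(X, y, lam=sigma**2)  # posterior mean under the assumed prior
    w_gd = one_step_gd_weights(X, y, eta=1.0)  # eta = 1.0 is an illustrative step size
    rel_gaps[n] = np.linalg.norm(w_ridge - w_gd) / np.linalg.norm(w_ridge)
    print(n, round(float(rel_gaps[n]), 3))
```

The identity w_ridge - w_gd = (I - X^T X / n - (λ/n) I) w_ridge shows why the relative gap is bounded by how far the scaled sample covariance is from the identity.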
Taxonomy
Topics: Domain Adaptation and Few-Shot Learning · Machine Learning and Data Classification · Advanced Neural Network Applications
Methods: Linear Regression
