Transformers learn to implement preconditioned gradient descent for in-context learning
Kwangjun Ahn, Xiang Cheng, Hadi Daneshmand, Suvrit Sra

TL;DR
This paper proves that linear transformers trained on random instances of linear regression can learn to implement preconditioned gradient descent, with the number of attention layers corresponding to the number of iterations implemented.
Contribution
To the authors' knowledge, it makes the first theoretical progress on whether transformers can learn to implement preconditioned gradient descent by training over random problem instances, via an analysis of the loss landscape for linear transformers.
Findings
For a single attention layer, the global minimum of the training objective implements one iteration of preconditioned gradient descent.
For a transformer with L attention layers, certain critical points of the training objective implement L iterations of the algorithm.
The preconditioning matrix adapts both to the input distribution and to the variance induced by data inadequacy.
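The findings above can be illustrated with a minimal sketch of preconditioned gradient descent on an in-context least-squares problem, where each loop iteration plays the role of one attention layer. This is an illustration under stated assumptions, not the paper's construction: the function names are hypothetical, and the preconditioner A is picked by hand here (the inverse of the empirical second-moment matrix of the inputs), whereas in the paper it is determined by training and its exact form is not given in this summary.

```python
def matvec(M, v):
    """Multiply matrix M (list of rows) by vector v."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]


def gradient(X, y, w):
    """Gradient of R(w) = (1/2n) * sum_i (w . x_i - y_i)^2 with respect to w."""
    n, d = len(X), len(w)
    g = [0.0] * d
    for x_i, y_i in zip(X, y):
        residual = sum(wj * xj for wj, xj in zip(w, x_i)) - y_i
        for j in range(d):
            g[j] += residual * x_i[j] / n
    return g


def preconditioned_gd(X, y, A, num_steps):
    """Run num_steps updates w <- w - A * grad R(w), starting from w = 0.

    Each step stands in for one attention layer (an assumption made for
    illustration); A is a fixed, hand-chosen preconditioner.
    """
    w = [0.0] * len(X[0])
    for _ in range(num_steps):
        step = matvec(A, gradient(X, y, w))
        w = [wj - sj for wj, sj in zip(w, step)]
    return w


# Toy in-context examples generated by w_true = [2, -1]; the two input
# coordinates have different scales, so the problem is badly conditioned.
X = [[1.0, 0.0], [0.0, 2.0]]
y = [2.0, -2.0]

# Hand-chosen preconditioner: inverse of (1/n) X^T X = diag(0.5, 2.0).
A_precond = [[2.0, 0.0], [0.0, 0.5]]
# Plain gradient descent is the special case A = step_size * identity.
A_plain = [[0.4, 0.0], [0.0, 0.4]]

w_precond = preconditioned_gd(X, y, A_precond, num_steps=1)
w_plain = preconditioned_gd(X, y, A_plain, num_steps=1)
```

On this toy problem a single preconditioned step recovers w_true exactly, while a single plain gradient step with step size 0.4 does not, which is the sense in which a preconditioner that adapts to the input distribution helps.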
Abstract
Several recent works demonstrate that transformers can implement algorithms like gradient descent. By a careful construction of weights, these works show that multiple layers of transformers are expressive enough to simulate iterations of gradient descent. Going beyond the question of expressivity, we ask: Can transformers learn to implement such algorithms by training over random problem instances? To our knowledge, we make the first theoretical progress on this question via an analysis of the loss landscape for linear transformers trained over random instances of linear regression. For a single attention layer, we prove the global minimum of the training objective implements a single iteration of preconditioned gradient descent. Notably, the preconditioning matrix not only adapts to the input distribution but also to the variance induced by data inadequacy. For a transformer with …
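As a reading aid, the update the abstract refers to can be written out as follows. The notation is assumed for illustration and may differ from the paper's: n in-context examples (x_i, y_i), a weight iterate w_k, a preconditioning matrix A_k, and a query input x_q.

```latex
% Least-squares objective over the in-context examples
R(w) = \frac{1}{2n} \sum_{i=1}^{n} \left( w^\top x_i - y_i \right)^2,
\qquad
\nabla R(w) = \frac{1}{n} \sum_{i=1}^{n} \left( w^\top x_i - y_i \right) x_i

% One iteration of preconditioned gradient descent
w_{k+1} = w_k - A_k \nabla R(w_k)

% Starting from w_0 = 0, a single iteration gives the prediction on the query
w_1 = A_0 \, \frac{1}{n} \sum_{i=1}^{n} y_i x_i,
\qquad
\hat{y}_q = w_1^\top x_q
```

Plain gradient descent is the special case where A_k is a scalar multiple of the identity.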
Taxonomy
Topics: Stochastic Gradient Optimization Techniques · Geophysical Methods and Applications · Advanced Neural Network Applications
