What learning algorithm is in-context learning? Investigations with linear models
Ekin Akyürek, Dale Schuurmans, Jacob Andreas, Tengyu Ma, Denny Zhou

TL;DR
This paper investigates whether transformer-based in-context learners implicitly implement standard linear learning algorithms, providing theoretical proofs and empirical evidence that they approximate gradient descent, ridge regression, and Bayesian estimators.
Contribution
It demonstrates that transformers can encode and execute classical linear learning algorithms internally, bridging neural in-context learning with traditional statistical methods.
Findings
Transformers can implement linear learning algorithms like gradient descent and ridge regression.
Trained in-context learners closely match predictors from classical algorithms.
Learners' late layers non-linearly encode weight vectors and moment matrices like those computed by classical estimators.
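The two classical learners in these findings, gradient descent and closed-form ridge regression, can be sketched in a few lines of NumPy. This is a minimal illustration under assumed settings: Gaussian inputs and weights, noiseless labels, and the hypothetical helpers `ridge_from_moments`, `gradient_descent`, and `squared_prediction_difference`. It is not the paper's transformer construction or its evaluation code. Ridge regression is written in terms of the moment matrices X^T X and X^T y. Gradient descent on the same ridge objective then runs for 1 and for 500 steps, and a simple mean squared prediction difference shows how closely each GD predictor matches the closed-form one.

```python
import numpy as np


def ridge_from_moments(X, y, lam):
    """Closed-form ridge regression expressed through the moment matrices:
    w = (X^T X + lam * I)^{-1} X^T y."""
    d = X.shape[1]
    xtx = X.T @ X  # second-moment (Gram) matrix of the inputs
    xty = X.T @ y  # input-label moment vector
    return np.linalg.solve(xtx + lam * np.eye(d), xty)


def gradient_descent(X, y, lam, lr, steps):
    """Full-batch gradient descent on the ridge objective
    0.5 * ||Xw - y||^2 + 0.5 * lam * ||w||^2, starting from w = 0."""
    w = np.zeros(X.shape[1])
    for _ in range(steps):
        grad = X.T @ (X @ w - y) + lam * w
        w = w - lr * grad
    return w


def squared_prediction_difference(w_a, w_b, X_query):
    """Hypothetical agreement measure: mean squared difference between
    two linear predictors' outputs on a set of query inputs."""
    return float(np.mean((X_query @ w_a - X_query @ w_b) ** 2))


rng = np.random.default_rng(0)
d, n, lam = 4, 16, 0.1
w_true = rng.standard_normal(d)   # assumed prior: w ~ N(0, I)
X = rng.standard_normal((n, d))   # assumed inputs: x ~ N(0, I)
y = X @ w_true                    # noiseless labels (assumption)

w_ridge = ridge_from_moments(X, y, lam)

# Step size 1/L, where L is the largest eigenvalue of the Hessian X^T X + lam*I.
lr = 1.0 / (np.linalg.eigvalsh(X.T @ X).max() + lam)
w_gd_one = gradient_descent(X, y, lam, lr, steps=1)
w_gd_many = gradient_descent(X, y, lam, lr, steps=500)

X_query = rng.standard_normal((100, d))
spd_one_step = squared_prediction_difference(w_gd_one, w_ridge, X_query)
spd_many_steps = squared_prediction_difference(w_gd_many, w_ridge, X_query)
print(spd_one_step, spd_many_steps)
```

Because the ridge objective is strongly convex and the step size is 1/L, the GD iterates converge to the ridge solution. The second printed difference should therefore be far smaller than the first, which mirrors in miniature how a predictor's behavior can move between "a few GD steps" and "closed-form ridge".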
Abstract
Neural sequence models, especially transformers, exhibit a remarkable capacity for in-context learning. They can construct new predictors from sequences of labeled examples presented in the input without further parameter updates. We investigate the hypothesis that transformer-based in-context learners implement standard learning algorithms implicitly, by encoding smaller models in their activations, and updating these implicit models as new examples appear in the context. Using linear regression as a prototypical problem, we offer three sources of evidence for this hypothesis. First, we prove by construction that transformers can implement learning algorithms for linear models based on gradient descent and closed-form ridge regression. Second, we show that trained in-context learners closely match the predictors computed by gradient descent, ridge regression, and exact…
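One way to picture an implicit model that is "updated as new examples appear in the context" is a learner whose whole state is a pair of running moments, refreshed once per labeled example. The sketch below is hypothetical: the class `IncrementalRidge`, the Gaussian task distribution, and the value of λ are assumptions chosen for illustration. It makes no claim about how a trained transformer actually stores such state in its activations. It only shows that a ridge predictor can be built from a stream of (x, y) pairs with no parameter updates to any outer model, and that it matches the batch ridge solution once every example has been absorbed.

```python
import numpy as np


class IncrementalRidge:
    """Hypothetical implicit model updated one in-context example at a time.
    Its state is A = lam * I + sum(x x^T) and b = sum(y x)."""

    def __init__(self, d, lam):
        self.A = lam * np.eye(d)
        self.b = np.zeros(d)

    def update(self, x, y):
        # Absorb one labeled example into the running moments.
        self.A += np.outer(x, x)
        self.b += y * x

    def weights(self):
        # Ridge weights implied by all examples seen so far.
        return np.linalg.solve(self.A, self.b)

    def predict(self, x_query):
        return float(self.weights() @ x_query)


rng = np.random.default_rng(1)
d, n, lam = 3, 10, 1e-6
w_true = rng.standard_normal(d)     # assumed prior: w ~ N(0, I)
xs = rng.standard_normal((n, d))    # in-context inputs
ys = xs @ w_true                    # noiseless labels (assumption)
x_query = rng.standard_normal(d)    # the final, unlabeled query

model = IncrementalRidge(d, lam)
errors = []
for x, y in zip(xs, ys):  # examples arrive in order, as in a prompt
    model.update(x, y)
    errors.append(abs(model.predict(x_query) - float(w_true @ x_query)))

batch_w = np.linalg.solve(xs.T @ xs + lam * np.eye(d), xs.T @ ys)
print(errors[0], errors[-1])
```

With noiseless labels and more examples than input dimensions, the query error collapses once at least d linearly independent examples have been seen. The incrementally maintained weights coincide with the batch ridge solution.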
Taxonomy
Topics: Neural Networks and Applications
Methods: Linear Regression
