Dynamics of Transient Structure in In-Context Linear Regression Transformers
Liam Carroll, Jesse Hoogland, Matthew Farrugia-Roberts, Daniel Murfet

TL;DR
This paper studies the transient ridge phenomenon: transformers trained on in-context linear regression with intermediate task diversity initially behave like ridge regression before specializing to the tasks in their training distribution. The authors explain this with an evolving tradeoff between loss and complexity.
Contribution
It explores the transient ridge phenomenon in transformers and suggests a general explanation for transient structure, based on Bayesian internal model selection and an evolving tradeoff between loss and model complexity.
Findings
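With intermediate task diversity, transformers initially behave like ridge regression before specializing to the tasks in their training distribution.
The transition from a general solution to a specialized solution is revealed by joint trajectory principal component analysis.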
Model complexity measurements support the proposed theoretical explanation.
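
The ridge-regression behaviour named in the findings can be illustrated with a minimal sketch. It assumes a standard Gaussian prior on the task vector and Gaussian label noise, so that the ridge penalty equals the noise variance. The function name, the dimensions, and the noise level are hypothetical choices for illustration and are not taken from the paper.

```python
import numpy as np

def ridge_in_context_prediction(xs, ys, x_query, noise_variance=0.25):
    """Predict the label of x_query from context pairs (xs, ys) with ridge regression.

    Assumption (not from the source): tasks w ~ N(0, I) and labels
    y = w . x + noise with noise ~ N(0, noise_variance). Under that
    assumption the ridge solution below is the posterior mean of w.
    """
    xs = np.asarray(xs, dtype=float)
    ys = np.asarray(ys, dtype=float)
    d = xs.shape[1]
    # Regularized normal equations: (X^T X + sigma^2 I) w = X^T y
    gram = xs.T @ xs + noise_variance * np.eye(d)
    w_hat = np.linalg.solve(gram, xs.T @ ys)
    return float(np.asarray(x_query, dtype=float) @ w_hat)

# Hypothetical example: one task, 16 context pairs in 4 dimensions.
rng = np.random.default_rng(0)
d, k = 4, 16
w_true = rng.normal(size=d)
xs = rng.normal(size=(k, d))
ys = xs @ w_true + 0.5 * rng.normal(size=k)
x_query = rng.normal(size=d)
prediction = ridge_in_context_prediction(xs, ys, x_query)
```

A predictor of this kind does not depend on which tasks appeared during training, which is the sense in which it is a general solution. A specialized solution would instead use knowledge of the finite set of training tasks.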
Abstract
Modern deep neural networks display striking examples of rich internal computational structure. Uncovering principles governing the development of such structure is a priority for the science of deep learning. In this paper, we explore the transient ridge phenomenon: when transformers are trained on in-context linear regression tasks with intermediate task diversity, they initially behave like ridge regression before specializing to the tasks in their training distribution. This transition from a general solution to a specialized solution is revealed by joint trajectory principal component analysis. Further, we draw on the theory of Bayesian internal model selection to suggest a general explanation for the phenomena of transient structure in transformers, based on an evolving tradeoff between loss and complexity. We empirically validate this explanation by measuring the model complexity…
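
One standard way to make a loss-complexity tradeoff concrete comes from singular learning theory: the asymptotic expansion of the Bayesian free energy near a candidate solution $w^*$. The notation below ($F_n$, $L_n$, $\lambda$) is assumed here for illustration, and the paper's own formulation is not reproduced in this summary.

```latex
F_n(w^*) \approx n \, L_n(w^*) + \lambda(w^*) \log n
```

Here $n$ is the number of samples, $L_n(w^*)$ is the empirical loss (average negative log-likelihood) at $w^*$, and $\lambda(w^*)$ is the local learning coefficient, a measure of complexity. A lower free energy means a higher posterior preference. For small $n$ the $\log n$ term can favor a simpler solution even when its loss is higher. As $n$ grows the linear term dominates, so a lower-loss but more complex solution is eventually preferred.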
Taxonomy
Topics: Neural Networks and Applications · Fault Detection and Control Systems
Methods: Linear Regression
