Transformers Learn Nonlinear Features In Context: Nonconvex Mean-field Dynamics on the Attention Landscape
Juno Kim, Taiji Suzuki

TL;DR
This paper analyzes how Transformers learn nonlinear features in context. It studies the mean-field dynamics of a model made of an MLP followed by a linear attention layer, and shows that the loss landscape is benign and the training dynamics are stable in ways that help explain in-context learning.
Contribution
It introduces a mean-field theoretical framework for Transformer dynamics with nonlinear features, providing new insights into the optimization landscape and stability of in-context learning.
Findings
The loss landscape in the mean-field limit is benign despite being highly nonconvex.
Almost all Wasserstein gradient flows avoid saddle points.
New methods are established for analyzing convergence and improvement rates.
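In mean-field analyses, gradient descent on the neurons of a network whose output is an average over N neurons is a particle discretization of Wasserstein gradient flow on the distribution of neuron parameters. The minimal sketch below illustrates that correspondence. It assumes a toy scalar regression objective with a tanh network, not the paper's in-context loss with attention, and the names `forward`, `loss`, and `particle_step` are hypothetical. Each particle moves along the negative gradient of the first variation of the loss, evaluated at the current empirical measure.

```python
import math
import random


def forward(x, particles):
    # Mean-field network: f_mu(x) = (1/N) * sum_j a_j * tanh(w_j * x),
    # i.e. the integral of a * tanh(w * x) against the empirical measure mu.
    return sum(a * math.tanh(w * x) for a, w in particles) / len(particles)


def loss(particles, data):
    # Assumed toy objective (not the paper's): F(mu) = 0.5 * mean (f_mu(x) - y)^2.
    return 0.5 * sum((forward(x, particles) - y) ** 2 for x, y in data) / len(data)


def particle_step(particles, data, lr):
    """One explicit Euler step of the particle approximation of Wasserstein gradient flow.

    The first variation of F at mu is dF/dmu(a, w) = mean_x r(x) * a * tanh(w * x),
    with residual r(x) = f_mu(x) - y. Every particle moves along minus its gradient
    in (a, w), which equals N times the ordinary gradient of F w.r.t. that particle.
    """
    # Residuals are computed once so that all particles see the same measure mu.
    residuals = [(forward(x, particles) - y, x) for x, y in data]
    m = len(data)
    updated = []
    for a, w in particles:
        grad_a = grad_w = 0.0
        for r, x in residuals:
            t = math.tanh(w * x)
            grad_a += r * t
            grad_w += r * a * (1.0 - t * t) * x
        updated.append((a - lr * grad_a / m, w - lr * grad_w / m))
    return updated


if __name__ == "__main__":
    random.seed(0)
    particles = [(random.gauss(0.0, 1.0), random.gauss(0.0, 1.0)) for _ in range(100)]
    data = [(i / 10.0, math.sin(2.0 * i / 10.0)) for i in range(-10, 11)]
    print("initial loss:", loss(particles, data))
    for _ in range(50):
        particles = particle_step(particles, data, lr=0.2)
    print("final loss:", loss(particles, data))
```

The step size is not rescaled by N because the gradient is taken of the first variation rather than of the finite-N loss; this keeps the dynamics well defined as N grows, which is the regime the mean-field limit describes.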
Abstract
Large language models based on the Transformer architecture have demonstrated impressive capabilities to learn in context. However, existing theoretical studies on how this phenomenon arises are limited to the dynamics of a single layer of attention trained on linear regression tasks. In this paper, we study the optimization of a Transformer consisting of a fully connected layer followed by a linear attention layer. The MLP acts as a common nonlinear representation or feature map, greatly enhancing the power of in-context learning. We prove in the mean-field and two-timescale limit that the infinite-dimensional loss landscape for the distribution of parameters, while highly nonconvex, becomes quite benign. We also analyze the second-order stability of mean-field dynamics and show that Wasserstein gradient flow almost always avoids saddle points. Furthermore, we establish novel methods…
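The architecture in the abstract can be pictured with a small forward pass. This is a minimal pure-Python sketch under stated assumptions: the MLP is a mean-field average of N neurons with a tanh activation, and the linear attention layer reads out a bilinear form h(query)^T Gamma (1/n) sum_i y_i h(x_i) over the context pairs, a simplified form common in linear-attention analyses of in-context learning. The function names, the activation, and the matrix `gamma` are hypothetical and are not the paper's exact parameterization.

```python
import math
import random


def mlp_features(x, neurons):
    """Mean-field feature map h(x) = (1/N) * sum_j a_j * tanh(w_j . x).

    Each neuron is a pair (a, w): a is an output vector in R^k, w an input
    weight in R^d. tanh is an assumed activation chosen for illustration.
    """
    k = len(neurons[0][0])
    h = [0.0] * k
    for a, w in neurons:
        s = math.tanh(sum(wi * xi for wi, xi in zip(w, x)))
        for r in range(k):
            h[r] += a[r] * s
    return [v / len(neurons) for v in h]


def linear_attention_predict(context, query, neurons, gamma):
    """Assumed readout: y_hat = h(query)^T Gamma (1/n) sum_i y_i h(x_i)."""
    k = len(gamma)
    # Aggregate the context into one vector in R^k: (1/n) sum_i y_i h(x_i).
    agg = [0.0] * k
    for x_i, y_i in context:
        h_i = mlp_features(x_i, neurons)
        for r in range(k):
            agg[r] += y_i * h_i[r]
    agg = [v / len(context) for v in agg]
    h_q = mlp_features(query, neurons)
    # Bilinear form between the query features and the aggregated context.
    return sum(h_q[r] * gamma[r][c] * agg[c] for r in range(k) for c in range(k))


if __name__ == "__main__":
    random.seed(0)
    d, k, n_neurons = 3, 2, 50
    neurons = [([random.gauss(0, 1) for _ in range(k)],
                [random.gauss(0, 1) for _ in range(d)]) for _ in range(n_neurons)]
    gamma = [[1.0, 0.0], [0.0, 1.0]]
    context = [([random.gauss(0, 1) for _ in range(d)], random.gauss(0, 1)) for _ in range(8)]
    query = [random.gauss(0, 1) for _ in range(d)]
    print(linear_attention_predict(context, query, neurons, gamma))
```

Because the prediction is linear in the context labels y_i, the attention layer acts like a learned linear regressor on top of the features, while the nonlinearity comes entirely from the shared feature map.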
Taxonomy
Topics: Neural Networks and Applications
Methods: Attention Is All You Need · Layer Normalization · Absolute Position Encodings · Linear Layer · Byte Pair Encoding · Multi-Head Attention · Residual Connection · Dense Connections · Position-Wise Feed-Forward Layer · Label Smoothing
