Transformers Trained via Gradient Descent Can Provably Learn a Class of Teacher Models
Chenyang Zhang, Qingyue Zhao, Quanquan Gu, Yuan Cao

TL;DR
This paper provides a theoretical analysis showing that one-layer transformers trained via gradient descent can learn a broad class of teacher models, including convolutional and statistical models, with provable guarantees.
Contribution
The paper introduces a theoretical framework showing that transformers with simplified attention mechanisms can learn diverse teacher models, and it provides generalization guarantees.
Findings
Transformers can recover teacher model parameters with optimal loss.
One-layer transformers with position-only attention can learn various teacher models.
Transformers generalize well to out-of-distribution data under mild conditions.
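To make "position-only attention" concrete, the sketch below shows a one-layer student whose attention weights come from a learnable score matrix over positions and never read the token values. This is a minimal illustration under stated assumptions, not the paper's parameterization: the names `scores`, `value`, and `student_forward`, and the scalar-per-position output, are all hypothetical.

```python
import math
import random


def softmax(row):
    """Numerically stable softmax over a list of floats."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def position_only_attention(scores):
    """Row-wise softmax of a D x D score matrix.

    Assumption: `scores` stands in for whatever position-based score the
    model learns. It is a function of positions only, never of the tokens.
    """
    return [softmax(row) for row in scores]


def student_forward(tokens, scores, value):
    """Hypothetical one-layer student.

    tokens: D token vectors of length d; value: length-d weight vector.
    Output at position i is sum_j attn[i][j] * <value, tokens[j]>.
    """
    attn = position_only_attention(scores)
    projected = [sum(v * x for v, x in zip(value, tok)) for tok in tokens]
    return [sum(a * p for a, p in zip(row, projected)) for row in attn]


random.seed(0)
D, d = 4, 3  # sequence length and token dimension (illustrative sizes)
scores = [[random.gauss(0, 1) for _ in range(D)] for _ in range(D)]
value = [random.gauss(0, 1) for _ in range(d)]
x1 = [[random.gauss(0, 1) for _ in range(d)] for _ in range(D)]
x2 = [[random.gauss(0, 1) for _ in range(d)] for _ in range(D)]

# The attention matrix is computed once and is identical for every input.
attn = position_only_attention(scores)
out1 = student_forward(x1, scores, value)
out2 = student_forward(x2, scores, value)
```

Because the attention matrix is the same for every input, this sketched student is linear in the tokens once `scores` and `value` are fixed: the output for `x1 + x2` equals `out1 + out2`.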
Abstract
Transformers have achieved great success across a wide range of applications, yet the theoretical foundations underlying their success remain largely unexplored. To demystify the strong capacities of transformers applied to versatile scenarios and tasks, we theoretically investigate utilizing transformers as students to learn from a class of teacher models. Specifically, the teacher models covered in our analysis include convolution layers with average pooling, graph convolution layers, and various classic statistical learning models, including a variant of sparse token selection models [Sanford et al., 2023, Wang et al., 2024] and group-sparse linear predictors [Zhang et al., 2025]. When learning from this class of teacher models, we prove that one-layer transformers with simplified "position-only" attention can successfully recover all parameter blocks of the teacher models, thus…
Peer Reviews
Decision: ICLR 2026 Poster
The paper is well written and well organized. The proofs are complex and they seem sensible to me, even though I could only superficially go over them.
My main concerns with the paper are the strong theoretical assumptions and what I perceive as a somewhat unfair comparison with previous works. Regarding the first point, the assumption that the attention matrix only depends on the position of the tokens is extremely strong. In fact, this simply means that the attention matrix does not depend on the input at all, since the positions are the same for all inputs. (In Eq. (2.4), P is fixed and independent of X). With this assumption, the paper ne
- The results provide a tight convergence rate of $\Theta(1/T)$ when training on the population loss, improving previous analyses while generalizing the setting.
- Experimental results for different examples of teacher models are provided, which support the theoretical analysis.
- The setting is restricted to the population loss, as well as Gaussian data.
- The generalization of teacher models is nice, but ultimately it is still the idea of sparse selection from the input.
1. Unified theoretical framework: The paper identifies a fundamental bilinear structure shared across diverse learning tasks, enabling unified learning guarantees. This is a significant conceptual contribution that connects previously studied disparate settings.
2. Tight convergence guarantees: The paper establishes matching upper and lower bounds for the convergence rate of Θ(1/T), improving upon prior work.
3. Empirical alignment with theory. Synthetic experiments match predicted slopes and s
1. **Limited model complexity**: The analysis is restricted to one-layer transformers with simplified "position-only" attention. While the authors justify this simplification empirically, the gap between this architecture and practical multi-layer transformers with full attention is substantial.
2. **Notation density**: The paper is extremely dense with notation, making it difficult to follow. A notation table and more intuitive explanations would improve accessibility.
3. **Scalability concerns
Taxonomy
Topics: Advanced Graph Neural Networks · Advanced Neural Network Applications · Stochastic Gradient Optimization Techniques
