Disentangling Feature Structure: A Mathematically Provable Two-Stage Training Dynamics in Transformers
Zixuan Gong, Shijia Li, Yong Liu, Jiaye Teng

TL;DR
This paper provides a theoretical analysis of the two-stage training dynamics in transformers, showing how disentangled feature structures like syntax and semantics influence learning, supported by empirical spectral analysis.
Contribution
It introduces the first rigorous theoretical framework explaining feature-level two-stage optimization in transformers based on disentangled features.
Findings
Two-stage training dynamics are linked to spectral properties of attention weights.
Disentangled features like syntax and semantics influence the learning process.
The analysis is grounded in a simplified model with structured data.
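To make the first finding concrete, here is a minimal sketch of the kind of spectral summary one could compute for an attention weight matrix at successive checkpoints. The function name `spectral_summary`, the `top_k` parameter, and the random test matrices are illustrative assumptions; the summary above does not specify which spectral statistics the authors track.

```python
import numpy as np


def spectral_summary(weight, top_k=5):
    """Summarize the singular value spectrum of a 2-D weight matrix.

    Hypothetical helper: the statistics below are common choices,
    not the paper's stated methodology.
    """
    # Singular values come back sorted in descending order.
    singular_values = np.linalg.svd(weight, compute_uv=False)
    energy = singular_values ** 2
    total = energy.sum()
    # Share of squared spectral mass carried by the top_k directions.
    top_share = float(energy[:top_k].sum() / total) if total > 0 else 0.0
    return {
        "spectral_norm": float(singular_values[0]),
        "top_k_energy_share": top_share,
        "singular_values": singular_values,
    }


# Illustrative inputs only: a dense random matrix versus a rank-2 matrix.
rng = np.random.default_rng(0)
dense = rng.normal(size=(32, 32))
low_rank = rng.normal(size=(32, 2)) @ rng.normal(size=(2, 32))

dense_stats = spectral_summary(dense, top_k=2)
low_rank_stats = spectral_summary(low_rank, top_k=2)
```

Applied to saved checkpoints, a shift in `top_k_energy_share` over training would be one way to see spectral mass concentrating in a few directions. Whether that shift lines up with the two stages described here is a question for the paper's own experiments.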
Abstract
Transformers can exhibit two-stage training dynamics in real-world training. For instance, when GPT-2 is trained on the Counterfact dataset, its answers progress from syntactically incorrect, to syntactically correct, to semantically correct. However, existing theoretical analyses rarely account for this feature-level two-stage phenomenon, which originates from a disentangled structure of two feature types, such as syntax and semantics. In this paper, we theoretically demonstrate how two-stage training dynamics can arise in transformers. Specifically, we analyze the feature learning dynamics induced by this disentangled two-type feature structure, grounding our analysis in a simplified yet illustrative setting that comprises a normalized ReLU self-attention layer and structured data. Such disentanglement of feature structure is general in practice, e.g., natural…
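The abstract names a normalized ReLU self-attention layer without defining it here. The sketch below shows one plausible reading, in which the softmax is replaced by an elementwise ReLU and the scores are divided by the sequence length. That normalization, the function name, and the parameter names `w_q`, `w_k`, `w_v` are assumptions for illustration and may differ from the paper's exact definition.

```python
import numpy as np


def normalized_relu_attention(x, w_q, w_k, w_v):
    """Single-head self-attention with ReLU scores instead of softmax.

    Assumed form: (1/n) * ReLU(X W_Q (X W_K)^T) X W_V, where n is the
    number of tokens. This is a guess at the layer, not the paper's
    confirmed formulation.
    """
    n = x.shape[0]
    queries = x @ w_q
    keys = x @ w_k
    values = x @ w_v
    # ReLU zeroes out negative query-key scores; dividing by n keeps
    # the output scale independent of sequence length.
    scores = np.maximum(queries @ keys.T, 0.0) / n
    return scores @ values


# Tiny example: two one-hot tokens and identity projections.
tokens = np.eye(2)
identity = np.eye(2)
output = normalized_relu_attention(tokens, identity, identity, identity)
```

Unlike softmax attention, the rows of `scores` here do not sum to one, so a token whose query-key scores are all non-positive contributes a zero output row.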
