Understanding Transformer Optimization via Gradient Heterogeneity
Akiyoshi Tomihari, Issei Sato

TL;DR
This paper investigates why Adam outperforms SGD when training Transformers. It analyzes gradient heterogeneity and shows that Adam's coordinate-wise normalization makes it less sensitive to this heterogeneity, which helps explain its empirical success.
Contribution
The study provides a theoretical framework linking gradient heterogeneity to optimizer performance in Transformers, highlighting the role of normalization placement and comparing sign-based methods to SGD.
Findings
Gradient heterogeneity degrades SGD convergence but affects Adam much less.
Post-LN architectures exhibit higher gradient heterogeneity.
Experiments on NLP and vision tasks support the theoretical analysis.
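To make the first finding concrete, here is a minimal sketch that treats gradient heterogeneity as variation in gradient norms across parameter blocks. It compares how large a plain SGD step and a sign-based step would be in each block. The block names, gradient values, learning rate, and the max/min `spread` ratio are all illustrative assumptions, not the paper's data or its heterogeneity metric.

```python
import math

# Hypothetical per-block gradients; names and values are made up for illustration.
grads = {
    "embedding": [0.002, -0.001, 0.003],
    "attention": [0.5, -0.4, 0.3],
    "mlp": [20.0, -15.0, 10.0],
}

def l2_norm(vec):
    """Euclidean norm of a list of floats."""
    return math.sqrt(sum(x * x for x in vec))

def sign(x):
    """Return -1, 0, or 1 according to the sign of x."""
    return (x > 0) - (x < 0)

def spread(values):
    """Illustrative heterogeneity measure (an assumption): largest value over smallest."""
    values = list(values)
    return max(values) / min(values)

lr = 0.01  # hypothetical learning rate shared by both update rules

# Gradient norm of each block.
block_norms = {name: l2_norm(g) for name, g in grads.items()}

# Size of the step each rule would take in each block.
sgd_steps = {name: l2_norm([lr * x for x in g]) for name, g in grads.items()}
sign_steps = {name: l2_norm([lr * sign(x) for x in g]) for name, g in grads.items()}

grad_spread = spread(block_norms.values())
sgd_spread = spread(sgd_steps.values())
sign_spread = spread(sign_steps.values())

print("gradient norm spread:", grad_spread)
print("SGD step spread:     ", sgd_spread)
print("sign step spread:    ", sign_spread)
```

In this toy setting the SGD step sizes inherit the full spread of the gradient norms, so a single learning rate is either too large for the high-norm block or too small for the low-norm one. The sign-based steps have the same size in every block because they discard gradient magnitude.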
Abstract
Transformers are difficult to optimize with stochastic gradient descent (SGD) and largely rely on adaptive optimizers such as Adam. Despite their empirical success, the reasons behind Adam's superior performance over SGD remain poorly understood. In this study, we analyze the optimization of Transformer models through the lens of gradient heterogeneity, defined as the variation in gradient norms across parameter blocks. We provide a theoretical analysis showing that gradient heterogeneity, together with Hessian heterogeneity, degrades the convergence of gradient-based methods such as SGD, while sign-based methods are substantially less sensitive to this effect. Adam's coordinate-wise normalization makes its update directions depend mainly on gradient signs, so Adam can be interpreted as a soft variant of SignSGD. Our analysis uses the fact that SGD and SignSGD follow steepest…
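The claim that Adam behaves like a soft variant of SignSGD can be checked by hand for the first update. The sketch below applies the standard Adam update rule with bias correction, starting from zero moment estimates, and compares it with a plain SGD step. The gradient values and hyperparameters are illustrative assumptions, and the function names are hypothetical helpers rather than any library's API.

```python
import math

def adam_first_step(grad, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """Parameter change from the first Adam step (t = 1), starting from zero moments."""
    updates = []
    for g in grad:
        m = (1 - beta1) * g        # first-moment estimate after one step
        v = (1 - beta2) * g * g    # second-moment estimate after one step
        m_hat = m / (1 - beta1)    # bias correction at t = 1 recovers g
        v_hat = v / (1 - beta2)    # bias correction at t = 1 recovers g^2
        updates.append(-lr * m_hat / (math.sqrt(v_hat) + eps))
    return updates

def sgd_step(grad, lr=1e-3):
    """Parameter change from one plain SGD step."""
    return [-lr * g for g in grad]

# Hypothetical gradient whose coordinates span several orders of magnitude.
grad = [1e-4, -3.0, 250.0]

adam_update = adam_first_step(grad)
sgd_update = sgd_step(grad)

print("Adam:", adam_update)  # every coordinate has magnitude close to lr
print("SGD: ", sgd_update)   # magnitudes scale with the gradient
```

On the first step each Adam coordinate moves by roughly `lr` in the direction opposite the gradient's sign, whatever the gradient's magnitude, which is the SignSGD update up to the `eps` term. On later steps the moment estimates average over the gradient history, so the match is only approximate, which is one way to read "soft" here.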
Taxonomy
Methods: Adam · Stochastic Gradient Descent
