Normalization Layer Per-Example Gradients are Sufficient to Predict Gradient Noise Scale in Transformers
Gavia Gray, Aman Tiwari, Shane Bergsma, Joel Hestness

TL;DR
This paper introduces an efficient method for computing per-example gradient norms in transformers. It shows that the normalization layers alone predict the total gradient noise scale (GNS), and that tracking their GNS can guide a batch size schedule that shortens training.
Contribution
The paper proposes a minimal-FLOP method for computing per-example gradient norms, shows that the GNS of the normalization layers predicts the total GNS, and uses this result to build a faster batch size schedule.
Findings
Normalization layer GNS predicts total GNS accurately.
A batch size schedule guided by normalization-layer GNS reduces training time by 18%.
Per-example gradient norms can be computed efficiently during backpropagation.
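To make the last finding concrete, here is a minimal sketch of per-example squared gradient norms for the gain and bias of a single LayerNorm. It is not the paper's kernel: the function name, the plain nested-list inputs, and the assumption that `grad_out` holds the gradient of each example's own loss (not pre-scaled by a batch mean) are illustrative choices. The point is only that the per-example quantities fall out of the same token-wise sums the backward pass already performs, before those sums are reduced over the batch.

```python
def layernorm_per_example_sq_norms(x_hat, grad_out):
    """Per-example squared gradient norms for LayerNorm gain and bias.

    x_hat:    normalized activations, nested lists shaped [batch][tokens][features]
    grad_out: upstream gradients with the same shape (assumed to be per-example,
              i.e. not already divided by the batch size)
    Returns one squared norm per example.
    """
    sq_norms = []
    for x_b, g_b in zip(x_hat, grad_out):
        d = len(x_b[0])
        grad_gain = [0.0] * d  # d(loss_b)/d(gain) = sum over tokens of g * x_hat
        grad_bias = [0.0] * d  # d(loss_b)/d(bias) = sum over tokens of g
        for x_t, g_t in zip(x_b, g_b):
            for k in range(d):
                grad_gain[k] += g_t[k] * x_t[k]
                grad_bias[k] += g_t[k]
        # Squared L2 norm of this example's (gain, bias) gradient.
        sq_norms.append(sum(v * v for v in grad_gain) + sum(v * v for v in grad_bias))
    return sq_norms
```

Summing `grad_gain` and `grad_bias` over the batch would give the usual parameter gradients, so the norms cost only an extra reduction over the feature dimension per example.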
Abstract
Per-example gradient norms are a vital ingredient for estimating gradient noise scale (GNS) with minimal variance. Observing the tensor contractions required to compute them, we propose a method with minimal FLOPs in 3D or greater tensor regimes by simultaneously computing the norms while computing the parameter gradients. Using this method we are able to observe the GNS of different layers at higher accuracy than previously possible. We find that the total GNS of contemporary transformer models is predicted well by the GNS of only the normalization layers. As a result, focusing only on the normalization layer, we develop a custom kernel to compute the per-example gradient norms while performing the LayerNorm backward pass with zero throughput overhead. Tracking GNS on only those layers, we are able to guide a practical batch size schedule that reduces training time by 18% on a…
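One common way to turn per-example gradient norms into a GNS estimate is the pair of unbiased estimators from the large-batch training literature, with the small batch size fixed at 1 and the large batch size equal to the full batch. The sketch below assumes that formulation; the function name and the returned tuple are illustrative, and whether the paper uses exactly this form is not stated in the text here.

```python
def gns_from_norms(per_example_sq_norms, batch_grad_sq_norm):
    """Estimate the gradient noise scale from one batch.

    per_example_sq_norms: squared norm of each example's gradient (batch size 1)
    batch_grad_sq_norm:   squared norm of the gradient of the mean loss over the batch
    Returns (gns, grad_sq_estimate, noise_trace_estimate).
    """
    b = len(per_example_sq_norms)
    if b < 2:
        raise ValueError("need at least two examples to estimate the noise scale")
    small = sum(per_example_sq_norms) / b  # mean squared norm at batch size 1
    big = batch_grad_sq_norm               # squared norm at batch size b
    # Unbiased estimate of the squared norm of the true gradient.
    grad_sq = (b * big - small) / (b - 1)
    # Unbiased estimate of the trace of the per-example gradient covariance.
    noise_trace = (small - big) / (1.0 - 1.0 / b)
    return noise_trace / grad_sq, grad_sq, noise_trace
```

A single-batch ratio is noisy, and `grad_sq` can be close to zero or negative; in practice the two estimates are usually averaged over many steps separately before dividing.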
