A Mean Field Theory of Batch Normalization
Greg Yang, Jeffrey Pennington, Vinay Rao, Jascha Sohl-Dickstein, and Samuel S. Schoenholz

TL;DR
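This paper develops a mean field theory of batch normalization in fully-connected feedforward networks. It shows that batch normalization causes gradient explosion at large depths and that the explosion can be reduced, though not eliminated, by tuning the network close to the linear regime. It also analyzes the learning dynamics of batch-normalized networks.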
Contribution
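It provides a precise theoretical analysis of signal propagation and gradient backpropagation in wide batch-normalized networks at initialization, identifies batch normalization itself as the source of gradient explosion, and shows that tuning the network toward the linear regime improves trainability.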
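To see where batch normalization enters the backward pass, here are the standard per-feature batch-normalization equations for one feature over a batch of size B. The notation is chosen here for illustration rather than copied from the paper, and it ignores the learned scale and shift and the small stabilizing constant ε.

```latex
\begin{aligned}
\mu &= \frac{1}{B}\sum_{\alpha=1}^{B} h_\alpha, \qquad
\sigma^2 = \frac{1}{B}\sum_{\alpha=1}^{B} (h_\alpha - \mu)^2, \qquad
\hat h_\alpha = \frac{h_\alpha - \mu}{\sigma}, \\
\frac{\partial L}{\partial h_\alpha}
  &= \frac{1}{\sigma}\left( g_\alpha - \frac{1}{B}\sum_{\beta=1}^{B} g_\beta
     - \hat h_\alpha \, \frac{1}{B}\sum_{\beta=1}^{B} g_\beta \hat h_\beta \right),
  \qquad g_\alpha = \frac{\partial L}{\partial \hat h_\alpha}.
\end{aligned}
```

Because \hat h has zero batch mean and unit batch variance, the bracket is an orthogonal projection in batch space that removes the components of g along the all-ones vector and along \hat h; the result is then rescaled by the batch-dependent factor 1/σ. Informally, these per-layer projections and rescalings are what compound with depth, and the paper's mean field theory characterizes that compounding precisely.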
Findings
Gradient signals grow exponentially with depth.
Batch normalization causes gradient explosion.
Tuning the network near the linear regime reduces gradient explosion and improves trainability.
Abstract
We develop a mean field theory for batch normalization in fully-connected feedforward neural networks. In so doing, we provide a precise characterization of signal propagation and gradient backpropagation in wide batch-normalized networks at initialization. Our theory shows that gradient signals grow exponentially in depth and that these exploding gradients cannot be eliminated by tuning the initial weight variances or by adjusting the nonlinear activation function. Indeed, batch normalization itself is the cause of gradient explosion. As a result, vanilla batch-normalized networks without skip connections are not trainable at large depths for common initialization schemes, a prediction that we verify with a variety of empirical simulations. While gradient explosion cannot be eliminated, it can be reduced by tuning the network close to the linear regime, which improves the trainability…
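The abstract's central claim, that gradients grow exponentially with depth in a batch-normalized network at initialization, can be checked numerically with a small simulation. The sketch below is a dependency-free illustration under assumptions chosen for this example, not the paper's experimental setup: a fully-connected network with no biases, weights drawn i.i.d. from N(0, 1/width), per-feature batch normalization followed by a scale `gamma` (shift fixed at 0) and an activation, and a linear loss L = sum(G * x_out) against a fixed random matrix G. The names `weight_grad_norms`, `bn_forward`, `bn_backward` and the default sizes (depth 30, width 32, batch 8) are hypothetical. Treating a small `gamma` as the way to keep tanh near its linear regime is also an assumption of this sketch.

```python
import math
import random


def matmul(a, b):
    """Plain-Python product of row-major matrices a (n x k) and b (k x m)."""
    cols = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]


def transpose(m):
    return [list(r) for r in zip(*m)]


def bn_forward(h, eps=1e-5):
    """Standardize each feature (column) of h over the batch (rows)."""
    batch, width = len(h), len(h[0])
    xhat = [[0.0] * width for _ in range(batch)]
    inv_std = []
    for j in range(width):
        col = [h[b][j] for b in range(batch)]
        mu = sum(col) / batch
        var = sum((v - mu) ** 2 for v in col) / batch
        inv_std.append(1.0 / math.sqrt(var + eps))
        for b in range(batch):
            xhat[b][j] = (col[b] - mu) * inv_std[j]
    return xhat, inv_std


def bn_backward(g, xhat, inv_std):
    """Map dL/dxhat to dL/dh: remove the batch mean and xhat component, rescale by 1/sigma."""
    batch, width = len(g), len(g[0])
    dh = [[0.0] * width for _ in range(batch)]
    for j in range(width):
        mean_g = sum(g[b][j] for b in range(batch)) / batch
        mean_gx = sum(g[b][j] * xhat[b][j] for b in range(batch)) / batch
        for b in range(batch):
            dh[b][j] = inv_std[j] * (g[b][j] - mean_g - xhat[b][j] * mean_gx)
    return dh


# Activation and its derivative.
ACTIVATIONS = {
    "linear": (lambda z: z, lambda z: 1.0),
    "tanh": (math.tanh, lambda z: 1.0 - math.tanh(z) ** 2),
}


def weight_grad_norms(depth=30, width=32, batch=8, gamma=1.0, act="tanh", seed=0):
    """Frobenius norm of dL/dW for every layer at initialization, input side first."""
    rng = random.Random(seed)
    phi, dphi = ACTIVATIONS[act]
    x = [[rng.gauss(0.0, 1.0) for _ in range(width)] for _ in range(batch)]
    cache = []
    for _ in range(depth):
        # Hypothetical init: W_ij ~ N(0, 1/width), no bias.
        w = [[rng.gauss(0.0, 1.0 / math.sqrt(width)) for _ in range(width)]
             for _ in range(width)]
        xhat, inv_std = bn_forward(matmul(x, w))
        z = [[gamma * v for v in row] for row in xhat]  # BN scale gamma, shift 0
        cache.append((x, w, xhat, inv_std, z))
        x = [[phi(v) for v in row] for row in z]
    # Hypothetical loss L = sum(G * x_out) with fixed random G, so dL/dx_out = G.
    g = [[rng.gauss(0.0, 1.0) for _ in range(width)] for _ in range(batch)]
    norms = []
    for x_in, w, xhat, inv_std, z in reversed(cache):
        g_xhat = [[gamma * gv * dphi(zv) for gv, zv in zip(gr, zr)]
                  for gr, zr in zip(g, z)]
        dh = bn_backward(g_xhat, xhat, inv_std)
        dw = matmul(transpose(x_in), dh)           # dL/dW for this layer
        norms.append(math.sqrt(sum(v * v for row in dw for v in row)))
        g = matmul(dh, transpose(w))               # dL/dx for the layer below
    return norms[::-1]


if __name__ == "__main__":
    for act, gamma in [("linear", 1.0), ("tanh", 0.1), ("tanh", 1.0)]:
        n = weight_grad_norms(act=act, gamma=gamma)
        print(f"{act:6s} gamma={gamma}: |dW_1| / |dW_L| = {n[0] / n[-1]:.1f}")
```

Running the script prints the first-to-last-layer gradient ratio for a linear network, a near-linear tanh network (`gamma=0.1`), and a tanh network with `gamma=1.0`; a ratio well above 1 indicates gradient explosion. The width here is small, so individual runs fluctuate; comparing trends across several seeds is more informative than a single number. Pure Python keeps the sketch self-contained, while a NumPy or JAX version would be the natural choice for wider networks.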
Taxonomy
Topics: Fault Detection and Control Systems · Control Systems and Identification · Advanced Control Systems Optimization
Methods: Batch Normalization
