GShard: Scaling Giant Models with Conditional Computation and Automatic Sharding
Dmitry Lepikhin, HyoukJoong Lee, Yuanzhong Xu, Dehao Chen, Orhan Firat, Yanping Huang, Maxim Krikun, Noam Shazeer, Zhifeng Chen

TL;DR
GShard combines conditional computation (Sparsely-Gated Mixture-of-Experts) with lightweight sharding annotation APIs and an XLA compiler extension, making it practical to train a multilingual translation model with more than 600 billion parameters.
Contribution
The paper presents GShard, a module made up of lightweight annotation APIs and an extension to the XLA compiler. It expresses a wide range of parallel computation patterns with minimal changes to existing model code and shards giant models automatically.
Findings
Scaled a multilingual translation Transformer with Sparsely-Gated Mixture-of-Experts beyond 600 billion parameters
Achieved high translation quality across 100 languages
Trained the model in 4 days on 2048 TPU v3 accelerators
Abstract
Neural network scaling has been critical for improving the model quality in many real-world machine learning applications with vast amounts of training data and compute. Although this trend of scaling is affirmed to be a sure-fire approach for better model quality, there are challenges on the path such as the computation cost, ease of programming, and efficient implementation on parallel devices. GShard is a module composed of a set of lightweight annotation APIs and an extension to the XLA compiler. It provides an elegant way to express a wide range of parallel computation patterns with minimal changes to the existing model code. GShard enabled us to scale up multilingual neural machine translation Transformer model with Sparsely-Gated Mixture-of-Experts beyond 600 billion parameters using automatic sharding. We demonstrate that such a giant model can efficiently be trained on 2048 TPU…
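The abstract names Sparsely-Gated Mixture-of-Experts as the conditional-computation mechanism but does not describe the routing on this page. The following is a minimal sketch of the general idea, assuming top-2 softmax gating over a small set of toy experts. The function names (`top_k_gate`, `moe_forward`), the scalar inputs, and the toy experts are all hypothetical and chosen for illustration; this is not the paper's implementation, and it omits expert capacity limits, auxiliary losses, and any sharding across devices.

```python
import math
from typing import Callable, List, Sequence, Tuple


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax over a list of gate logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k_gate(logits: Sequence[float], k: int = 2) -> List[Tuple[int, float]]:
    """Pick the k highest-probability experts and renormalize their weights.

    Assumption: top-k softmax gating; the page itself does not specify k.
    """
    probs = softmax(logits)
    top = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    norm = sum(probs[i] for i in top)
    return [(i, probs[i] / norm) for i in top]


def moe_forward(
    x: float,
    gate_logits: Sequence[float],
    experts: Sequence[Callable[[float], float]],
    k: int = 2,
) -> Tuple[float, List[int]]:
    """Evaluate only the selected experts and mix their outputs.

    The experts that are not selected are never called, which is the
    "conditional computation" part: cost depends on k, not on len(experts).
    """
    routes = top_k_gate(gate_logits, k)
    y = sum(weight * experts[i](x) for i, weight in routes)
    return y, [i for i, _ in routes]


# Hypothetical toy experts standing in for per-expert feed-forward networks.
experts = [
    lambda x: x + 1.0,
    lambda x: 2.0 * x,
    lambda x: -x,
    lambda x: x * x,
]

# Gate logits would normally come from a learned gating network; here they
# are fixed by hand so the routing decision is easy to follow.
output, used = moe_forward(3.0, [0.0, 2.0, -1.0, 2.0], experts)
print("experts used:", used, "output:", output)
```

Because only `k` experts run per input, the parameter count can grow with the number of experts while the per-input compute stays roughly constant, which is what makes models of this size tractable.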
Taxonomy
Topics: Topic Modeling · Advanced Neural Network Applications · Machine Learning and Data Classification
Methods: Linear Layer · Absolute Position Encodings · Position-Wise Feed-Forward Layer · GShard · Residual Connection · Label Smoothing · Multi-Head Attention · Adam · Dropout
