Automatic Cross-Replica Sharding of Weight Update in Data-Parallel Training
Yuanzhong Xu, HyoukJoong Lee, Dehao Chen, Hongjun Choi, Blake Hechtman, Shibo Wang

TL;DR
This paper introduces an automatic method to shard weight update computations across replicas in data-parallel training, significantly improving performance and scalability without requiring model code modifications.
Contribution
It presents a static analysis and transformation approach to automatically shard weight updates, reducing bottlenecks in large-scale neural network training.
Findings
Achieves substantial speedups on image and language models on Cloud TPUs.
Enables efficient training with both Adam and SGD optimizers.
Contributes to state-of-the-art MLPerf 0.6 training performance.
Abstract
In data-parallel synchronous training of deep neural networks, different devices (replicas) run the same program with different partitions of the training batch, but the weight update computation is repeated on all replicas, because the weights do not have a batch dimension to partition. This can be a bottleneck for performance and scalability in typical language models with large weights, and in models with a small per-replica batch size, which is typical in large-scale training. This paper presents an approach to automatically shard the weight update computation across replicas with efficient communication primitives and data formatting, using static analysis and transformations on the training computation graph. We show this technique achieves substantial speedups on typical image and language models on Cloud TPUs, requiring no change to model code. This technique helps close the gap between…
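The idea in the abstract can be illustrated with a small single-process simulation. This is a minimal sketch, not the paper's implementation: it assumes the "efficient communication primitives" are a reduce-scatter followed by an all-gather, uses plain SGD on a flat weight vector, and models replicas as list entries. The function names (`replicated_sgd_step`, `sharded_sgd_step`) are hypothetical and chosen only for this example.

```python
import numpy as np


def replicated_sgd_step(weights, per_replica_grads, lr):
    """Baseline: all-reduce the gradients, then every replica repeats the full update."""
    summed = np.sum(per_replica_grads, axis=0)  # stands in for an all-reduce
    # Each replica performs the same full-size update (redundant work).
    return [weights - lr * summed for _ in per_replica_grads]


def sharded_sgd_step(weights, per_replica_grads, lr):
    """Sharded: each replica updates only its slice of the weights."""
    num_replicas = len(per_replica_grads)
    summed = np.sum(per_replica_grads, axis=0)
    # Reduce-scatter (assumed primitive): replica r keeps only shard r of the sum.
    grad_shards = np.array_split(summed, num_replicas)
    weight_shards = np.array_split(weights, num_replicas)
    # Each replica updates 1/num_replicas of the weights.
    updated_shards = [w - lr * g for w, g in zip(weight_shards, grad_shards)]
    # All-gather (assumed primitive): every replica receives all updated shards.
    full = np.concatenate(updated_shards)
    return [full.copy() for _ in range(num_replicas)]


# Toy setup: 4 replicas, 10 weights (not evenly divisible, to exercise uneven shards).
rng = np.random.default_rng(0)
weights = rng.normal(size=10)
grads = [rng.normal(size=10) for _ in range(4)]

replicated = replicated_sgd_step(weights, grads, lr=0.1)
sharded = sharded_sgd_step(weights, grads, lr=0.1)
results_match = all(np.allclose(a, b) for a, b in zip(replicated, sharded))
```

Both functions leave every replica with the same weights. The difference is that in the sharded version each replica computes the update for only a fraction of the elements, which works because the SGD update is elementwise.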
Taxonomy
Topics: Advanced Neural Network Applications · Domain Adaptation and Few-Shot Learning · Parallel Computing and Optimization Techniques
