Scaling Deep Learning Training with MPMD Pipeline Parallelism
Anxhelo Xhebraj, Sean Lee, Hanfeng Chen, Vinod Grover

TL;DR
JaxPP is a system that enables efficient large-scale deep learning training through flexible pipeline parallelism, automatic task distribution, and asynchronous execution, improving hardware utilization by up to 1.11x over the best-performing SPMD configuration.
Contribution
The paper introduces JaxPP, a novel system that combines flexible pipeline schedules, automatic task distribution, and an MPMD runtime for improved deep learning training scalability.
Findings
Up to 1.11x hardware utilization improvement over the best-performing SPMD configuration
Automatic distribution of pipeline stages over clusters
Seamless programming model for user-defined schedules
Abstract
We present JaxPP, a system for efficiently scaling the training of large deep learning models with flexible pipeline parallelism. We introduce a seamless programming model that allows implementing user-defined pipeline schedules for gradient accumulation. JaxPP automatically distributes tasks, corresponding to pipeline stages, over a cluster of nodes and automatically infers the communication among them. We implement an MPMD runtime for asynchronous execution of SPMD tasks. The pipeline parallelism implementation of JaxPP improves hardware utilization by up to 1.11x with respect to the best performing SPMD configuration.
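To make the notion of a pipeline schedule concrete, the following is a minimal plain-Python sketch and not JaxPP's API. It represents a schedule as a per-stage ordered list of forward and backward tasks over microbatches, then simulates it to estimate utilization. The function names (`gpipe_schedule`, `simulate`), the GPipe-style ordering, and the unit task costs are all assumptions chosen for illustration; communication time is ignored.

```python
from typing import Dict, List, Tuple

Task = Tuple[str, int]  # ("fwd" | "bwd", microbatch index)


def gpipe_schedule(num_stages: int, num_microbatches: int) -> List[List[Task]]:
    """Hypothetical GPipe-style schedule: each stage runs all forwards, then all backwards."""
    one_stage = [("fwd", m) for m in range(num_microbatches)]
    one_stage += [("bwd", m) for m in range(num_microbatches)]
    return [list(one_stage) for _ in range(num_stages)]


def simulate(schedule: List[List[Task]], fwd_time: float = 1.0,
             bwd_time: float = 2.0) -> Tuple[float, float]:
    """Return (makespan, utilization) for a per-stage task ordering.

    Assumed dependencies: fwd(s, m) needs fwd(s-1, m); bwd(s, m) needs
    fwd(s, m) and bwd(s+1, m). Each stage runs one task at a time, in order.
    """
    num_stages = len(schedule)
    finish: Dict[Tuple[str, int, int], float] = {}
    next_idx = [0] * num_stages
    stage_free = [0.0] * num_stages
    remaining = sum(len(tasks) for tasks in schedule)
    busy = 0.0
    while remaining:
        progressed = False
        for s, tasks in enumerate(schedule):
            if next_idx[s] == len(tasks):
                continue
            kind, m = tasks[next_idx[s]]
            deps = []
            if kind == "fwd":
                if s > 0:
                    deps.append(("fwd", s - 1, m))
            else:
                deps.append(("fwd", s, m))
                if s < num_stages - 1:
                    deps.append(("bwd", s + 1, m))
            if any(d not in finish for d in deps):
                continue  # inputs not ready yet; try again on a later sweep
            cost = fwd_time if kind == "fwd" else bwd_time
            start = max([stage_free[s]] + [finish[d] for d in deps])
            finish[(kind, s, m)] = start + cost
            stage_free[s] = start + cost
            busy += cost
            next_idx[s] += 1
            remaining -= 1
            progressed = True
        if not progressed:
            raise ValueError("schedule deadlocks: unmet dependency")
    makespan = max(stage_free)
    return makespan, busy / (num_stages * makespan)


makespan, utilization = simulate(gpipe_schedule(num_stages=4, num_microbatches=8))
print(f"makespan={makespan}, utilization={utilization:.3f}")
```

Under these assumptions the idle "bubble" shrinks as the number of microbatches grows relative to the number of stages. A different schedule can be tried by changing only the per-stage task order passed to `simulate`.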
Taxonomy
Topics: Neural Networks and Applications
