SimpleFSDP: Simpler Fully Sharded Data Parallel with torch.compile
Ruisi Zhang, Tianyu Liu, Will Feng, Andrew Gu, Sanket Purandare, Wanchao Liang, Francisco Massa

TL;DR
SimpleFSDP introduces a PyTorch-native, compiler-friendly Fully Sharded Data Parallel framework that simplifies implementation, enhances performance through compiler optimizations, and reduces memory usage in large model training.
Contribution
It presents a novel, torch.compile-friendly FSDP implementation with IR node bucketing and reordering for better computation-communication overlap, improving distributed training efficiency.
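To make "bucketing and reordering" concrete, here is a toy sketch in plain Python. It is not TorchInductor's pass: `Node`, `bucket_all_gathers`, and `reorder_for_prefetch` are hypothetical names, the schedule is a flat list rather than real IR, and the fixed `bucket_size` is an assumption made only for illustration. It shows the general idea of merging several small all-gathers into one and issuing the next bucket's all-gather before the current bucket's compute so the two can overlap.

```python
from dataclasses import dataclass
from typing import List


@dataclass(frozen=True)
class Node:
    kind: str  # "all_gather" or "compute"
    name: str


def bucket_all_gathers(nodes: List[Node], bucket_size: int) -> List[Node]:
    """Merge every `bucket_size` consecutive all-gathers into one node.

    The merged node is placed ahead of all compute nodes it feeds.
    """
    out: List[Node] = []
    pending_ag: List[str] = []
    pending_compute: List[Node] = []

    def flush() -> None:
        if pending_ag:
            out.append(Node("all_gather", "+".join(pending_ag)))
        out.extend(pending_compute)
        pending_ag.clear()
        pending_compute.clear()

    for node in nodes:
        if node.kind == "all_gather":
            if len(pending_ag) == bucket_size:
                flush()
            pending_ag.append(node.name)
        else:
            pending_compute.append(node)
    flush()
    return out


def reorder_for_prefetch(nodes: List[Node]) -> List[Node]:
    """Issue each bucket's all-gather one bucket early.

    Assumes the schedule starts with an all-gather node.
    """
    segments: List[List[Node]] = []
    for node in nodes:
        if node.kind == "all_gather":
            segments.append([node])
        else:
            segments[-1].append(node)

    out: List[Node] = []
    for i, segment in enumerate(segments):
        if i == 0:
            out.append(segment[0])          # first all-gather cannot be hidden
        if i + 1 < len(segments):
            out.append(segments[i + 1][0])  # prefetch the next bucket
        out.extend(segment[1:])             # compute for the current bucket
    return out


# Unoptimized order: gather a layer's weight, then run that layer.
eager = []
for layer in range(4):
    eager.append(Node("all_gather", f"w{layer}"))
    eager.append(Node("compute", f"f{layer}"))

bucketed = bucket_all_gathers(eager, bucket_size=2)
schedule = reorder_for_prefetch(bucketed)
print([n.name for n in schedule])
```

A real compiler pass would also have to respect memory limits, since prefetching keeps more gathered parameters alive at once; this sketch ignores that trade-off.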
Findings
Up to 28.54% memory reduction in large models.
Up to 68.67% throughput improvement over FSDP2.
Effective use of compiler optimizations for distributed training.
Abstract
Distributed training of large models consumes enormous computation resources and requires substantial engineering efforts to compose various training techniques. This paper presents SimpleFSDP, a PyTorch-native compiler-based Fully Sharded Data Parallel (FSDP) framework, which has a simple implementation for maintenance and composability, allows full computation-communication graph tracing, and brings performance enhancement via compiler backend optimizations. SimpleFSDP's novelty lies in its unique torch.compile-friendly implementation of collective communications using existing PyTorch primitives, namely parametrizations, selective activation checkpointing, and DTensor. It also features the first-of-its-kind intermediate representation (IR) nodes bucketing and reordering in the TorchInductor backend for effective computation-communication overlapping. As a result, users can employ…
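The role of parametrizations can be illustrated with a single-process sketch. It is not the paper's implementation: the abstract names DTensor for the collective, whereas this example swaps in `torch.cat` over locally stored shards so it runs without a process group. `GatherOnAccess` and `WORLD_SIZE` are hypothetical names. The point is only that a parametrization lets a module store a shard while every read of `weight` yields the full tensor, with autograd routing the gradient back to the shard.

```python
import torch
import torch.nn as nn
from torch.nn.utils import parametrize

WORLD_SIZE = 2  # simulated number of ranks; no process group is created


class GatherOnAccess(nn.Module):
    """Hypothetical parametrization: rebuilds the full weight on every access."""

    def __init__(self, remote_shards: torch.Tensor):
        super().__init__()
        # Rows owned by the other simulated ranks. A real FSDP setup would
        # receive these through an all-gather instead of keeping them locally.
        self.register_buffer("remote_shards", remote_shards)

    def forward(self, local_shard: torch.Tensor) -> torch.Tensor:
        # Stand-in for all-gather: concatenate along the sharded dimension.
        return torch.cat([local_shard, self.remote_shards], dim=0)


torch.manual_seed(0)
linear = nn.Linear(4, 8, bias=False)
full_weight = linear.weight.detach().clone()

# "Rank 0" keeps only its slice of the rows as the trainable parameter.
shards = full_weight.chunk(WORLD_SIZE, dim=0)
linear.weight = nn.Parameter(shards[0].clone())

# unsafe=True because this parametrization changes the tensor's shape.
parametrize.register_parametrization(
    linear, "weight", GatherOnAccess(torch.cat(shards[1:], dim=0)), unsafe=True
)

out = linear(torch.randn(3, 4))  # reading linear.weight triggers the gather
out.sum().backward()

# The shard is the only trainable storage; its gradient has the shard's shape.
local_shard = linear.parametrizations.weight.original
print(tuple(linear.weight.shape), tuple(local_shard.shape))
```

Because the gather is ordinary module code rather than a hook, a tracer sees it as part of the forward graph, which is the property the abstract highlights as full computation-communication graph tracing.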
Taxonomy
Topics: Distributed and Parallel Computing Systems · Parallel Computing and Optimization Techniques
Methods: LLaMA
