ShardTensor: Domain Parallelism for Scientific Machine Learning
Corey Adams, Peter Harrington, Akshay Subramaniam, Mohammad Shoaib Abbas, Jaideep Pathak, Mike Pritchard, Sanjay Choudhry

TL;DR
ShardTensor introduces a new domain parallelism paradigm that allows scalable processing of high-resolution scientific data, overcoming traditional hardware constraints and improving training and inference efficiency.
Contribution
The paper presents ShardTensor, a generalized framework for domain parallelism that decouples input data size from per-device hardware limits in SciML workloads.
Findings
Achieves strong and weak scaling during training and inference.
Enables processing of larger inputs with improved latency.
Supports multiple dimensions of parallelization for extreme-scale inputs.
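For readers unfamiliar with the terms in the first finding, the usual definitions of strong and weak scaling efficiency can be sketched as below. These are the standard textbook formulas, not measurements or code from the paper; the function names and the sample timings are hypothetical.

```python
def strong_scaling_efficiency(t1: float, tn: float, n: int) -> float:
    """Fixed total problem size: the ideal time on n devices is t1 / n."""
    return t1 / (n * tn)


def weak_scaling_efficiency(t1: float, tn: float) -> float:
    """Problem size grows in proportion to n: the ideal time stays at t1."""
    return t1 / tn


# Hypothetical timings in seconds, chosen only to illustrate the formulas.
print(strong_scaling_efficiency(t1=8.0, tn=1.25, n=8))
print(weak_scaling_efficiency(t1=2.0, tn=2.5))
```

An efficiency of 1.0 means ideal scaling in either case; values below 1.0 reflect communication and synchronization overhead.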
Abstract
Scientific Machine Learning (SciML) faces unique challenges for extreme-resolution data, with mitigations that often fail to scale or degrade the accuracy of trained models. While some specialized methods have achieved remarkable results in training models or performing inference on massive spatial datasets with bespoke techniques, there is no generalized framework for parallelization over input data below batch size one per device. In this work we introduce ShardTensor: a novel paradigm of domain parallelism that enables flexible scaling of input data to arbitrary sizes. By decoupling the spatial dimensionality of input data from hardware constraints, ShardTensor enables scientific machine learning workloads to reach new levels of high fidelity training and inference. We demonstrate both strong and weak scaling of workloads during training and inference, showing improved latency with…
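To make "parallelization over input data below batch size one per device" concrete, here is a minimal, single-process sketch of the general idea behind domain parallelism: one sample is split along a spatial axis, each shard borrows a small halo from its neighbors, and a local stencil then reproduces the unsharded result. This is an illustration in plain Python, not the ShardTensor API; the function names, the zero-padded boundary, and the equal-size split are all assumptions made for the example.

```python
def stencil_full(x, kernel):
    """Reference: apply an odd-length stencil to the whole array with zero padding."""
    halo = len(kernel) // 2
    padded = [0.0] * halo + list(x) + [0.0] * halo
    return [sum(k * padded[i + j] for j, k in enumerate(kernel))
            for i in range(len(x))]


def stencil_sharded(x, kernel, num_shards):
    """Same stencil, computed shard by shard after a simulated halo exchange."""
    halo = len(kernel) // 2
    if len(kernel) % 2 != 1 or halo < 1:
        raise ValueError("kernel length must be odd and at least 3")
    if len(x) % num_shards != 0:
        raise ValueError("this sketch assumes equal-size shards")
    size = len(x) // num_shards
    if size < halo:
        raise ValueError("each shard must hold at least `halo` elements")

    # Each list stands in for the slice of one sample held by one device.
    shards = [list(x[r * size:(r + 1) * size]) for r in range(num_shards)]

    out = []
    for rank, local in enumerate(shards):
        # Halo exchange: copy edge values from neighbors, zeros at the global edges.
        left = shards[rank - 1][-halo:] if rank > 0 else [0.0] * halo
        right = shards[rank + 1][:halo] if rank < num_shards - 1 else [0.0] * halo
        extended = left + local + right
        # Purely local compute on the halo-extended shard.
        out.extend(sum(k * extended[i + j] for j, k in enumerate(kernel))
                   for i in range(len(local)))
    return out


x = [float(i) for i in range(16)]
kernel = [1.0, 2.0, 1.0]
assert stencil_sharded(x, kernel, num_shards=4) == stencil_full(x, kernel)
```

In a real multi-device setting the neighbor copies would be point-to-point communication, and the memory for `x` would be divided across devices, which is what lets the input grow beyond what a single device can hold.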
