JaxPruner: A concise library for sparsity research
Joo Hyung Lee, Wonpyo Park, Nicole Mitchell, Jonathan Pilault, Johan Obando-Ceron, Han-Byul Kim, Namhoon Lee, Elias Frantar, Yun Long, Amir Yazdanbakhsh, Shivani Agrawal, Suvinay Subramanian, Xin Wang, Sheng-Chun Kao, Xingyao Zhang, Trevor Gale, Aart Bik, Woohyun Han

TL;DR
JaxPruner is an open-source JAX library that simplifies research on sparse neural networks by providing efficient, easy-to-integrate pruning and sparse training algorithms with minimal overhead.
Contribution
It offers a concise, unified implementation of pruning algorithms compatible with Optax, facilitating rapid experimentation and integration across various JAX-based machine learning frameworks.
Findings
Demonstrated seamless integration with four JAX-based codebases.
Provided baseline experiments on popular benchmarks.
Showcased minimal memory and latency overhead.
Abstract
This paper introduces JaxPruner, an open-source JAX-based pruning and sparse training library for machine learning research. JaxPruner aims to accelerate research on sparse neural networks by providing concise implementations of popular pruning and sparse training algorithms with minimal memory and latency overhead. Algorithms implemented in JaxPruner use a common API and work seamlessly with the popular optimization library Optax, which, in turn, enables easy integration with existing JAX-based libraries. We demonstrate this ease of integration by providing examples in four different codebases (Scenic, t5x, Dopamine, and FedJAX), and we provide baseline experiments on popular benchmarks.
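One way to picture a common API across pruning algorithms is to separate what differs between them (how each weight is scored) from what they share (turning scores into a mask at a target sparsity and applying it to a parameter pytree). The sketch below is hypothetical: `HypotheticalPruner`, `MagnitudePruner`, `RandomPruner`, and their methods are illustrative names, not classes from JaxPruner, and the example covers one-shot pruning only.

```python
import jax
import jax.numpy as jnp


class HypotheticalPruner:
    """Illustrative base class (not JaxPruner's API): subclasses define only a score."""

    def __init__(self, sparsity):
        self.sparsity = sparsity  # fraction of entries to zero in each array

    def scores(self, param, key):
        raise NotImplementedError

    def mask(self, param, key):
        """Keep the top (1 - sparsity) fraction of entries by score."""
        flat = self.scores(param, key).ravel()
        n_prune = int(round(self.sparsity * flat.size))
        order = jnp.argsort(flat)  # ascending: lowest scores are pruned first
        keep = jnp.ones_like(flat).at[order[:n_prune]].set(0.0)
        return keep.reshape(param.shape)

    def prune(self, params, key):
        """Apply a per-array mask to every leaf of a parameter pytree."""
        leaves, treedef = jax.tree_util.tree_flatten(params)
        keys = jax.random.split(key, len(leaves))
        masked = [p * self.mask(p, k) for p, k in zip(leaves, keys)]
        return jax.tree_util.tree_unflatten(treedef, masked)


class MagnitudePruner(HypotheticalPruner):
    def scores(self, param, key):
        return jnp.abs(param)  # small weights are considered least important


class RandomPruner(HypotheticalPruner):
    def scores(self, param, key):
        return jax.random.uniform(key, param.shape)  # random baseline
```

Under this design, adding another criterion only requires a new `scores` method; the sparsity target, mask construction, and pytree handling are reused unchanged.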
Taxonomy
Topics: Machine Learning and Algorithms · Machine Learning and Data Classification · Advanced Neural Network Applications
Methods: Lib · Pruning
