MPX: Mixed Precision Training for JAX
Alexander Gräfe, Sebastian Trimpe

TL;DR
MPX is a new toolbox that enables efficient mixed-precision training in JAX, simplifying integration with existing frameworks and maintaining model accuracy through dynamic loss-scaling and type management.
Contribution
The paper introduces MPX, described as the first comprehensive mixed-precision training toolkit for JAX, with integration into Equinox and Flax, dynamic loss-scaling, and precision management features.
Findings
Enables mixed-precision training in JAX with minimal code changes.
Maintains model accuracy through dynamic loss-scaling.
Provides automatic gradient and optimizer management for mixed precision.
Abstract
Mixed-precision training has emerged as an indispensable tool for enhancing the efficiency of neural network training in recent years. Concurrently, JAX has grown in popularity as a versatile machine learning toolbox. However, it currently lacks robust support for mixed-precision training. We propose MPX, a mixed-precision training toolbox for JAX that simplifies and accelerates the training of large-scale neural networks while preserving model accuracy. MPX seamlessly integrates with popular toolboxes such as Equinox and Flax, allowing users to convert full-precision pipelines to mixed-precision versions with minimal modifications. By casting both inputs and outputs to half precision, and introducing a dynamic loss-scaling mechanism, MPX alleviates issues like gradient underflow and overflow that commonly arise in half precision computations. Its design inherits critical features from…
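The dynamic loss-scaling idea mentioned in the abstract can be illustrated with a minimal sketch in plain JAX. This is not MPX's API: the function names (`loss_fn`, `scaled_grads`, `update_scale`, `train_step`), the toy linear model, and the constants (`LEARNING_RATE`, `GROWTH_INTERVAL`, `FACTOR`) are all assumptions chosen for illustration. The loss is multiplied by a scale before differentiation so that small gradients survive in float16; the gradients are then unscaled in float32, and the scale is halved whenever a non-finite gradient appears.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1     # hypothetical values, chosen only for this sketch
GROWTH_INTERVAL = 2000  # finite steps required before the scale is doubled
FACTOR = 2.0            # multiplicative growth / shrink factor

def loss_fn(params, x, y):
    # Forward pass in half precision: cast parameters and inputs to float16.
    p16 = jax.tree_util.tree_map(lambda a: a.astype(jnp.float16), params)
    pred = x.astype(jnp.float16) @ p16["w"] + p16["b"]
    # Evaluate the loss itself in float32.
    return jnp.mean((pred.astype(jnp.float32) - y) ** 2)

def scaled_grads(params, x, y, scale):
    # Differentiate the scaled loss so float16 gradients are less likely to underflow.
    grads = jax.grad(lambda p: loss_fn(p, x, y) * scale)(params)
    # Unscale in float32 and check every leaf for inf / nan.
    grads = jax.tree_util.tree_map(lambda g: g.astype(jnp.float32) / scale, grads)
    leaves = jax.tree_util.tree_leaves(grads)
    finite = jnp.all(jnp.array([jnp.all(jnp.isfinite(g)) for g in leaves]))
    return grads, finite

def update_scale(scale, good_steps, finite):
    # Count consecutive finite steps; reset the counter on overflow.
    good_steps = jnp.where(finite, good_steps + 1, 0)
    grow = good_steps >= GROWTH_INTERVAL
    # Shrink on overflow, grow after enough finite steps, otherwise keep.
    new_scale = jnp.where(finite, jnp.where(grow, scale * FACTOR, scale), scale / FACTOR)
    good_steps = jnp.where(grow, 0, good_steps)
    return new_scale, good_steps

@jax.jit
def train_step(params, x, y, scale, good_steps):
    grads, finite = scaled_grads(params, x, y, scale)
    # Apply a plain SGD update only when all gradients are finite.
    new_params = jax.tree_util.tree_map(
        lambda p, g: jnp.where(finite, p - LEARNING_RATE * g, p), params, grads
    )
    new_scale, good_steps = update_scale(scale, good_steps, finite)
    return new_params, new_scale, good_steps, finite
```

In this sketch the master parameters stay in float32 and only the forward and backward computations run in float16. A step whose gradients overflow leaves the parameters untouched and halves the scale, so training can continue at a safer scale.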
