JAXMg: A multi-GPU linear solver in JAX
Roeland Wiersema

TL;DR
JAXMg is a multi-GPU dense linear algebra library for JAX. By interfacing with NVIDIA's cuSOLVERMg, it enables scalable dense linear solves and eigenvalue computations inside JAX workflows.
Contribution
JAXMg provides the first JAX-compatible multi-GPU dense linear algebra primitives, integrating high-performance GPU solvers into JAX's composable ecosystem.
Findings
Supports matrices larger than single-GPU memory limits.
Enables end-to-end multi-GPU scientific workflows in JAX.
Preserves JAX's composability and JIT compatibility.
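Some arithmetic makes the first finding concrete. A dense n × n float64 matrix needs 8n² bytes, so n = 100,000 already takes 8 × 10¹⁰ bytes (80 GB), which is more than the memory of many single GPUs. cuSOLVERMg distributes a matrix across devices in a 1-D column block-cyclic layout. The sketch below counts how many bytes each device would hold under that layout. The helpers `column_owner` and `bytes_per_device`, the 256-column block size, and the four-device setup are hypothetical illustration choices, not JAXMg's API or defaults. The count also ignores solver workspace.

```python
# Hypothetical helpers for illustration; not part of JAXMg or cuSOLVERMg.


def column_owner(col, block_size, num_devices):
    """Device index owning global column `col` in a 1-D column block-cyclic layout."""
    return (col // block_size) % num_devices


def bytes_per_device(n, block_size, num_devices, itemsize=8):
    """Bytes of an n x n matrix held by each device (matrix storage only, no workspace)."""
    cols = [0] * num_devices
    for start in range(0, n, block_size):
        width = min(block_size, n - start)  # the last block may be partial
        cols[column_owner(start, block_size, num_devices)] += width
    return [c * n * itemsize for c in cols]  # each column holds n entries


n = 100_000                # matrix dimension, float64 (8-byte) entries
total = n * n * 8          # bytes needed for the whole matrix
per_device = bytes_per_device(n, block_size=256, num_devices=4)
print(total, per_device, max(per_device))
```

With 256-column blocks spread over four devices, each device holds about 2.0 × 10¹⁰ bytes (roughly 20 GB) instead of the full 80 GB. The cyclic assignment keeps the per-device share nearly even even when n is not a multiple of the block size.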
Abstract
Solving large dense linear systems and eigenvalue problems is a core requirement in many areas of scientific computing, but scaling these operations beyond a single GPU remains challenging within modern programming frameworks. While highly optimized multi-GPU solver libraries exist, they are typically difficult to integrate into composable, just-in-time (JIT) compiled Python workflows. JAXMg provides multi-GPU dense linear algebra for JAX, enabling Cholesky-based linear solves and symmetric eigendecompositions for matrices that exceed single-GPU memory limits. By interfacing JAX with NVIDIA's cuSOLVERMg through an XLA Foreign Function Interface, JAXMg exposes distributed GPU solvers as JIT-compatible JAX primitives. This design allows scalable linear algebra to be embedded directly within JAX programs, preserving composability with JAX transformations and enabling multi-GPU execution in…
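The abstract names two operations: Cholesky-based linear solves and symmetric eigendecompositions. JAXMg's own function names are not given here, so this sketch does not use them. Instead it shows what those two operations compute, using standard single-device JAX (`jax.scipy.linalg.cho_factor` / `cho_solve` and `jnp.linalg.eigh`). Both are wrapped in `jax.jit` and differentiated with `jax.grad` to show the kind of transformation composability the abstract says JAXMg preserves. The names `spd_solve`, `sym_eig` and `make_spd` are hypothetical. This is a single-GPU (or CPU) reference, not a multi-GPU implementation.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

jax.config.update("jax_enable_x64", True)  # use float64, as is common in scientific work


@jax.jit
def spd_solve(a, b):
    # Factor the symmetric positive-definite A = L L^T, then do two triangular solves.
    c_and_lower = cho_factor(a, lower=True)
    return cho_solve(c_and_lower, b)


@jax.jit
def sym_eig(a):
    # Ascending eigenvalues w and orthonormal eigenvectors (columns of v), A = V diag(w) V^T.
    return jnp.linalg.eigh(a)


def make_spd(key, n):
    # Hypothetical test matrix: M M^T + n I is symmetric with eigenvalues >= n.
    m = jax.random.normal(key, (n, n))
    return m @ m.T + n * jnp.eye(n)


a = make_spd(jax.random.PRNGKey(0), 64)
b = jnp.ones(64)
x = spd_solve(a, b)
w, v = sym_eig(a)

# Composability: differentiate through the jitted solve.
# d/db sum(A^{-1} b) = A^{-T} 1, which equals A^{-1} 1 because A is symmetric.
g = jax.grad(lambda bb: jnp.sum(spd_solve(a, bb)))(b)
```

A distributed primitive that can be dropped into code like this, and that works under `jax.jit` and `jax.grad`, is the kind of integration the abstract describes for matrices too large for one device.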
Taxonomy
Topics: Parallel Computing and Optimization Techniques · Numerical Methods and Algorithms · Scientific Computing and Data Management
