MPAX: Mathematical Programming in JAX
Haihao Lu, Zedong Peng, Jinwen Yang

TL;DR
MPAX is an open-source JAX-based solver for large-scale LP and QP problems that leverages modern machine learning infrastructure for efficient, scalable, and differentiable optimization across various hardware platforms.
Contribution
Introduces MPAX, a novel JAX-native solver integrating advanced algorithms and hardware acceleration for large-scale mathematical programming in machine learning workflows.
Findings
Significant GPU speedups over CPU baselines.
Competitive performance with existing GPU-based solvers.
Effective multi-GPU scaling and differentiable optimization capabilities.
Abstract
We present MPAX (Mathematical Programming in JAX), an open-source first-order solver for large-scale linear programming (LP) and convex quadratic programming (QP) built natively in JAX. The primary goal of MPAX is to exploit modern machine learning infrastructure for large-scale mathematical programming, while also providing advanced mathematical programming algorithms that are easy to integrate into machine learning workflows. MPAX implements two PDHG variants, r2HPDHG for LP and rAPDHG for QP, together with diagonal preconditioning, adaptive restarts, adaptive step sizes, primal-weight updates, infeasibility detection, and feasibility polishing. Leveraging JAX's compilation and parallelization ecosystem, MPAX provides across-hardware portability, batched solving, distributed optimization, and automatic differentiation. We evaluate MPAX on CPUs, NVIDIA GPUs, and Google TPUs, observing…
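The abstract names PDHG as the core method and cites JAX compilation, batching, and automatic differentiation. The sketch below is illustrative only and assumes a standard-form LP, min c^T x s.t. Ax = b, x ≥ 0. It is plain, unpreconditioned PDHG, not MPAX's API. It omits the r2HPDHG and rAPDHG refinements, restarts, adaptive step sizes, primal weights, infeasibility detection, and polishing. The function name `pdhg_lp` and the fixed iteration count are hypothetical choices made for this example. The sketch shows how `jax.jit`, `jax.vmap` (batched solving over right-hand sides), and `jax.grad` (differentiating the optimal value) compose with such a loop.

```python
import jax
import jax.numpy as jnp


def pdhg_lp(c, A, b, num_iters=5000):
    """Hypothetical plain PDHG for  min c^T x  s.t.  A x = b, x >= 0.

    Illustrative sketch only: no preconditioning, restarts, adaptive steps,
    or termination test. It runs a fixed number of iterations so the loop
    stays compatible with jit, vmap, and reverse-mode grad.
    """
    # Constant step sizes chosen so that tau * sigma * ||A||_2^2 = 0.81 < 1.
    step = 0.9 / jnp.linalg.norm(A, ord=2)
    tau = sigma = step

    def body(_, state):
        x, y = state
        # Primal step: gradient step on the Lagrangian c^T x - y^T (A x - b),
        # then projection onto the nonnegative orthant.
        x_new = jnp.maximum(x - tau * (c - A.T @ y), 0.0)
        # Dual step evaluated at the extrapolated primal point 2 * x_new - x.
        y_new = y + sigma * (b - A @ (2.0 * x_new - x))
        return x_new, y_new

    x0 = jnp.zeros(A.shape[1])
    y0 = jnp.zeros(A.shape[0])
    # Static trip count, so fori_loop lowers to a scan and supports grad.
    return jax.lax.fori_loop(0, num_iters, body, (x0, y0))


# Toy LP: min x1 + x2  s.t.  x1 + 2 x2 = b,  x >= 0.
c = jnp.array([1.0, 1.0])
A = jnp.array([[1.0, 2.0]])

# Batched solve: one LP per right-hand side, vectorized with vmap.
bs = jnp.array([[2.0], [4.0]])
xs, ys = jax.jit(jax.vmap(pdhg_lp, in_axes=(None, None, 0)))(c, A, bs)


def optimal_value(b):
    x, _ = pdhg_lp(c, A, b)
    return c @ x


# Sensitivity of the optimal value to b, by differentiating the unrolled loop.
grad_b = jax.grad(optimal_value)(jnp.array([2.0]))
```

For this toy LP the optimum is x* = (0, b/2) with dual y* = 1/2, so the batched primal iterates should approach (0, 1) and (0, 2). By LP duality, the gradient of the optimal value with respect to b should approach y* = 0.5. The fixed-length `fori_loop` is a deliberate choice here. A production solver would stop on a convergence test with `lax.while_loop`, but `while_loop` does not support reverse-mode differentiation, so this sketch unrolls a fixed number of steps instead.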
Taxonomy
Topics: Distributed and Parallel Computing Systems · Parallel Computing and Optimization Techniques
