JetSCI: A Hybrid JAX-PETSc Framework for Scalable Differentiable Simulation
Alberto Cattaneo, M Keith Ballard, Robert M. Kirby, Varun Shankar

TL;DR
JetSCI is a hybrid framework combining JAX and PETSc to enable scalable, differentiable simulations on large HPC systems, improving efficiency and accuracy for complex micromechanics problems.
Contribution
It introduces a novel hybrid JAX-PETSc framework that unifies GPU-accelerated differentiable modeling with scalable, MPI-based HPC solvers.
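For this unification to remain differentiable, gradients must pass through a solver that JAX cannot trace. The sketch below shows one general way to do that. It wraps an external linear solve in `jax.custom_vjp` so that the backward pass becomes an adjoint solve with A^T. It assumes a NumPy dense solve, called through `jax.pure_callback`, as a stand-in for an external PETSc solver. The names `host_solve`, `solve`, and `make_loss` are hypothetical. This is not JetSCI's API and not necessarily the mechanism the paper uses.

```python
import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_enable_x64", True)  # double precision for the solves

def host_solve(A, b):
    """Hypothetical stand-in for an external (non-JAX) solver such as a PETSc KSP.
    The solve runs on the host in NumPy and is invisible to JAX's autodiff."""
    out = jax.ShapeDtypeStruct(b.shape, b.dtype)
    return jax.pure_callback(
        lambda A_, b_: np.linalg.solve(A_, b_).astype(b_.dtype), out, A, b)

@jax.custom_vjp
def solve(A, b):
    # Primal: x = A^{-1} b, computed by the external solver.
    return host_solve(A, b)

def solve_fwd(A, b):
    x = host_solve(A, b)
    return x, (A, x)  # save what the adjoint needs

def solve_bwd(res, g):
    A, x = res
    lam = host_solve(A.T, g)          # adjoint solve: A^T lam = g
    return -jnp.outer(lam, x), lam    # dL/dA = -lam x^T, dL/db = lam

solve.defvjp(solve_fwd, solve_bwd)

# Usage: differentiate a loss through the external solve w.r.t. a parameter theta.
K = jnp.array([[4.0, 1.0, 0.0],
               [1.0, 3.0, 1.0],
               [0.0, 1.0, 2.0]])
b = jnp.array([1.0, 2.0, 3.0])

def make_loss(solver):
    def loss(theta):
        A = K + theta * jnp.eye(3)    # system matrix depends on theta
        return jnp.sum(solver(A, b) ** 2)
    return loss

g_ext = jax.grad(make_loss(solve))(0.5)             # gradient via adjoint solve
g_ref = jax.grad(make_loss(jnp.linalg.solve))(0.5)  # pure-JAX reference
print(float(g_ext), float(g_ref))  # should agree to round-off
```

The design choice here is that only two external solves are needed per gradient: one forward and one adjoint. JAX never has to trace into the solver itself.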
Findings
JetSCI outperforms JAX-only implementations in efficiency.
JetSCI achieves higher accuracy in finite element micromechanics simulations.
JetSCI effectively combines GPU acceleration with distributed-memory parallelism.
Abstract
The rapid rise of scientific machine learning (SciML) has expanded the role of differentiable modeling, surrogate modeling, and data-driven constitutive laws in large-scale simulation. The JAX framework provides an attractive environment for these workflows through automatically differentiable programs, vectorization, and GPU acceleration, while enabling seamless learning of surrogate models. However, large-scale simulation still relies on mature HPC infrastructure. Libraries such as PETSc provide scalable MPI-based parallelism, robust linear and nonlinear solvers, and advanced preconditioning capabilities that remain difficult to reproduce in JAX-only workflows. We present JetSCI, a hybrid JAX-PETSc framework that unifies these complementary strengths. JetSCI uses JAX for GPU-parallel differentiable discretizations and PETSc for robust, scalable solution of the resulting systems on…
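The division of labor described above can be illustrated on a toy problem: JAX builds the discretization and its exact Jacobian, and a separate solver handles the linear systems. This is a minimal sketch under stated assumptions. A 1D finite-difference discretization of -u'' + u^3 = f with zero Dirichlet boundaries stands in for a finite element micromechanics model, and `np.linalg.solve` stands in for a PETSc KSP linear solver. The helper names (`residual`, `external_linear_solve`, `newton`) are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_enable_x64", True)  # double precision, typical for PDE solves

N = 64                      # number of interior grid points (toy size)
h = 1.0 / (N + 1)           # uniform grid spacing on (0, 1)

def residual(u, f):
    """Discrete residual of -u'' + u**3 = f with u(0) = u(1) = 0 (finite differences)."""
    u_pad = jnp.pad(u, 1)   # add the zero Dirichlet boundary values
    lap = (u_pad[:-2] - 2.0 * u_pad[1:-1] + u_pad[2:]) / h**2
    return -lap + u**3 - f

# JAX side: jit-compiled residual and its exact Jacobian via forward-mode AD.
res_fn = jax.jit(residual)
jac_fn = jax.jit(jax.jacfwd(residual, argnums=0))

def external_linear_solve(J, r):
    """Hypothetical stand-in for the HPC solver side (e.g. a PETSc KSP): dense NumPy solve."""
    return np.linalg.solve(np.asarray(J), np.asarray(r))

def newton(f, tol=1e-10, max_iter=20):
    """Newton iteration: JAX supplies r and J, the external solver computes the update."""
    u = jnp.zeros(N)
    for _ in range(max_iter):
        r = res_fn(u, f)
        if float(jnp.linalg.norm(r)) < tol:
            break
        du = external_linear_solve(jac_fn(u, f), r)  # solve J du = r outside JAX
        u = u - du
    return u

x = jnp.linspace(h, 1.0 - h, N)   # interior grid points
f = jnp.sin(jnp.pi * x)
u = newton(f)
print(float(jnp.linalg.norm(res_fn(u, f))))  # final residual norm
```

In a PETSc-backed setup, the dense Jacobian would typically be replaced. One option is a sparse, distributed matrix. Another is a matrix-free operator built from `jax.jvp`. Either would be handed to PETSc's linear or nonlinear solvers through petsc4py. The sketch only shows where the boundary between the two sides could sit.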
