JAX-Fluids 2.0: Towards HPC for Differentiable CFD of Compressible Two-phase Flows
Deniz A. Bezgin, Aaron B. Buhendwa, Nikolaus A. Adams

TL;DR
JAX-Fluids 2.0 advances differentiable CFD by integrating HPC scalability, new two-phase flow models, and robust numerical schemes, enabling efficient large-scale simulations of complex compressible flows.
Contribution
This work introduces HPC-capable parallelization and new two-phase flow models into JAX-Fluids, enhancing its scalability, robustness, and modeling capabilities for differentiable CFD.
Findings
Scales efficiently on GPU and TPU HPC systems
Successfully simulates complex two-phase flows and shock interactions
Demonstrates stable gradient computation across extended trajectories
Abstract
In our effort to facilitate machine learning-assisted computational fluid dynamics (CFD), we introduce the second iteration of JAX-Fluids. JAX-Fluids is a Python-based, fully differentiable CFD solver designed for compressible single- and two-phase flows. In this work, the first version is extended to incorporate high-performance computing (HPC) capabilities. We introduce a parallelization strategy utilizing JAX primitive operations that scales efficiently on GPU (up to 512 NVIDIA A100 graphics cards) and TPU (up to 1024 TPU v3 cores) HPC systems. We further demonstrate the stable parallel computation of automatic differentiation gradients across extended integration trajectories. The new code version offers enhanced two-phase flow modeling capabilities. In particular, a five-equation diffuse-interface model is incorporated which complements the level-set sharp-interface model.…
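The abstract does not say which JAX primitives the parallelization uses. A common pattern for block-decomposed grid solvers in JAX is a halo (ghost-cell) exchange with the collective jax.lax.ppermute. The sketch below shows that generic pattern on a periodic 1D linear-advection problem, not the JAX-Fluids implementation. It assumes a Lax-Friedrichs update and hypothetical names (N_BLOCKS, AXIS, halo_exchange, rollout_blocks, rollout_global). It also checks that reverse-mode gradients propagate through the exchanges over a multi-step trajectory. Here jax.vmap with axis_name emulates the devices on a single host.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical setup: a periodic 1D grid split into N_BLOCKS equal subdomains,
# one per (emulated) device, all sharing the mapped axis name AXIS.
N_BLOCKS = 4
AXIS = "blocks"


def halo_exchange(u_local):
    """Add one ghost cell on each side of a block, taken from its periodic neighbours."""
    # Block j sends its last cell to block j+1, which uses it as its left ghost.
    left_ghost = jax.lax.ppermute(
        u_local[-1:], AXIS, perm=[(j, (j + 1) % N_BLOCKS) for j in range(N_BLOCKS)]
    )
    # Block j sends its first cell to block j-1, which uses it as its right ghost.
    right_ghost = jax.lax.ppermute(
        u_local[:1], AXIS, perm=[(j, (j - 1) % N_BLOCKS) for j in range(N_BLOCKS)]
    )
    return jnp.concatenate([left_ghost, u_local, right_ghost])


def lax_friedrichs(padded, cfl):
    """Lax-Friedrichs update for u_t + u_x = 0 on a padded array; cfl = dt / dx."""
    left, right = padded[:-2], padded[2:]
    return 0.5 * (left + right) - 0.5 * cfl * (right - left)


def rollout_blocks(u_blocks, cfl, n_steps):
    """Integrate the decomposed field; u_blocks has shape (N_BLOCKS, cells_per_block)."""
    # vmap with axis_name emulates N_BLOCKS devices on one host; jax.pmap could
    # map the same per-block function over real devices instead.
    step = jax.vmap(lambda u: lax_friedrichs(halo_exchange(u), cfl), axis_name=AXIS)
    u_final, _ = jax.lax.scan(lambda u, _: (step(u), None), u_blocks, None, length=n_steps)
    return u_final


def rollout_global(u, cfl, n_steps):
    """Single-domain reference: periodic padding instead of communication."""
    step = lambda v: lax_friedrichs(jnp.concatenate([v[-1:], v, v[:1]]), cfl)
    u_final, _ = jax.lax.scan(lambda v, _: (step(v), None), u, None, length=n_steps)
    return u_final


# Usage: both paths should agree, and gradients flow through the exchanges.
x = jnp.linspace(0.0, 1.0, 64, endpoint=False)
u0 = jnp.sin(2.0 * jnp.pi * x)
u_blocks = u0.reshape(N_BLOCKS, -1)

loss_blocks = lambda u: jnp.sum(rollout_blocks(u, 0.5, 50) ** 2)
loss_global = lambda u: jnp.sum(rollout_global(u, 0.5, 50) ** 2)
grad_blocks = jax.grad(loss_blocks)(u_blocks).reshape(-1)
grad_global = jax.grad(loss_global)(u0)
```

Because ppermute is linear, its transpose is simply the inverse permutation. Backpropagating through the rollout therefore reuses the same neighbour communication in reverse, and no solver-specific adjoint code is needed. A production solver would exchange wider halos in several directions and would likely use jax.pmap or shard_map over physical devices rather than vmap.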
Taxonomy
Topics: Fluid Dynamics and Mixing
