Bringing PDEs to JAX with forward and reverse modes automatic differentiation
Ivan Yashchuk

TL;DR
This paper introduces an extension to the JAX automatic differentiation library that integrates with the Firedrake finite element library, enabling efficient differentiation of PDE solutions for scientific computing applications.
Contribution
The paper provides a high-level interface for differentiating PDE solutions with JAX and Firedrake that avoids differentiating through the low-level iterations of nonlinear solvers.
Findings
Enables differentiation of PDE solutions with JAX and Firedrake
Supports tangent-linear and adjoint differentiation methods
Facilitates composition of finite element solvers with differentiable programs
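Both methods can be derived by differentiating the discrete residual equation F(u, m) = 0 that a nonlinear solver drives to zero. In the block below, u is the discrete solution, m the input parameters, and J(m) = j(u(m), m) a scalar functional of the solution. This notation is introduced here for illustration and is not taken from the paper. Implicit differentiation gives one linear solve per input direction (tangent-linear) or one per scalar output (adjoint), whatever number of nonlinear iterations produced u:

```latex
F(u, m) = 0, \qquad u = u(m), \qquad J(m) = j(u(m), m)

% Tangent-linear (forward mode): given \dot m, solve for \dot u
\frac{\partial F}{\partial u}\,\dot u = -\frac{\partial F}{\partial m}\,\dot m

% Adjoint (reverse mode): solve for \lambda, then assemble the gradient
\left(\frac{\partial F}{\partial u}\right)^{\!\top}\lambda = \left(\frac{\partial j}{\partial u}\right)^{\!\top},
\qquad
\frac{\mathrm{d}J}{\mathrm{d}m} = \frac{\partial j}{\partial m} - \lambda^{\top}\frac{\partial F}{\partial m}
```

The adjoint form follows from du/dm = -(dF/du)^{-1} dF/dm: substituting this into dJ/dm = dj/dm + (dj/du)(du/dm) and grouping (dj/du)(dF/du)^{-1} as λ^T gives the second line.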
Abstract
Partial differential equations (PDEs) are used to describe a variety of physical phenomena. Often these equations do not have analytical solutions, and numerical approximations are used instead. One of the common methods for solving PDEs is the finite element method. Computing derivative information of the solution with respect to the input parameters is important in many tasks in scientific computing. We extend the JAX automatic differentiation library with an interface to the Firedrake finite element library. The high-level symbolic representation of PDEs makes it possible to avoid differentiating through the low-level, and possibly numerous, iterations of the underlying nonlinear solvers. Differentiation through Firedrake solvers is done using tangent-linear and adjoint equations. This enables the efficient composition of finite element solvers with arbitrary differentiable programs. The code is available at…
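The following is a minimal JAX sketch of the idea, built on several assumptions. It does not use Firedrake or the paper's package. A 1D finite-difference discretization of -u'' + u^3 = m stands in for a finite element solver, and every name in it (newton_solve, solve_tlm, solve_adj, loss) is hypothetical. The solver runs plain Newton iterations. jax.custom_jvp and jax.custom_vjp then replace differentiation through that loop with a single tangent-linear or adjoint linear solve at the converged solution, so the solver composes with an ordinary JAX loss.

```python
import jax
import jax.numpy as jnp

# Double precision keeps finite-difference checks meaningful.
jax.config.update("jax_enable_x64", True)

N = 32                  # number of interior grid points (illustrative)
H = 1.0 / (N + 1)       # grid spacing on the unit interval
# Matrix of -u'' with zero Dirichlet boundary values. Finite differences
# stand in here for a finite element discretization (an assumption).
A = (2.0 * jnp.eye(N) - jnp.eye(N, k=1) - jnp.eye(N, k=-1)) / H**2


def residual(u, m):
    """Discrete residual F(u, m) = A u + u^3 - m of -u'' + u^3 = m."""
    return A @ u + u**3 - m


def jacobian_u(u):
    """dF/du = A + diag(3 u^2); for this residual dF/dm = -I."""
    return A + jnp.diag(3.0 * u**2)


def newton_solve(m, iters=20):
    """Plain Newton iterations; JAX never differentiates through this loop."""
    u = jnp.zeros_like(m)
    for _ in range(iters):
        u = u - jnp.linalg.solve(jacobian_u(u), residual(u, m))
    return u


# Forward mode: tangent-linear equation (dF/du) u_dot = -(dF/dm) m_dot = m_dot.
@jax.custom_jvp
def solve_tlm(m):
    return newton_solve(m)


@solve_tlm.defjvp
def solve_tlm_jvp(primals, tangents):
    (m,), (m_dot,) = primals, tangents
    u = newton_solve(m)
    u_dot = jnp.linalg.solve(jacobian_u(u), m_dot)
    return u, u_dot


# Reverse mode: adjoint equation (dF/du)^T lam = u_bar,
# then m_bar = -(dF/dm)^T lam, which equals lam because dF/dm = -I.
@jax.custom_vjp
def solve_adj(m):
    return newton_solve(m)


def solve_adj_fwd(m):
    u = newton_solve(m)
    return u, u  # save the converged solution for the backward pass


def solve_adj_bwd(u, u_bar):
    lam = jnp.linalg.solve(jacobian_u(u).T, u_bar)
    return (lam,)


solve_adj.defvjp(solve_adj_fwd, solve_adj_bwd)


def loss(m):
    """A downstream differentiable program: J(m) = 0.5 * H * sum(u(m)^2)."""
    u = solve_adj(m)
    return 0.5 * H * jnp.sum(u**2)
```

For example, jax.jvp(solve_tlm, (m,), (m_dot,)) returns the solution together with its tangent, and jax.grad(loss)(m) returns the gradient through one adjoint solve. The cost of either derivative does not grow with the number of Newton iterations. For a general residual, the backward rule would return -(dF/dm)^T lam instead of lam.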
