TL;DR
JAX MD is a differentiable physics simulation framework for molecular dynamics. It can embed neural networks in simulations, differentiate through whole trajectories for meta-optimization, and scale to large systems on GPUs.
Contribution
The paper introduces JAX MD, a differentiable physics simulation framework that scales efficiently on GPUs and lets neural networks be plugged into simulations without writing additional code.
Findings
Supports differentiation of entire trajectories for meta-optimization
Integrates graph neural networks into physics simulations
Scales to hundreds of thousands of particles on a single GPU
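To make "differentiating an entire trajectory" concrete, here is a minimal sketch in plain JAX rather than JAX MD's own API. Everything in it is an illustrative assumption: the Gaussian repulsion plus harmonic trap energy, the overdamped update rule, the `target_spread` objective, and every function name. The simulation is unrolled with `jax.lax.scan`, so `jax.grad` can backpropagate through every step to the interaction range `sigma`. A few clipped gradient steps on `sigma` then play the role of meta-optimization.

```python
import jax
import jax.numpy as jnp

# Illustrative sketch in plain JAX, not JAX MD's API. The energy model,
# objective, and all names below are hypothetical.

def energy(R, sigma):
    """Gaussian pair repulsion of range `sigma` plus a harmonic trap."""
    dR = R[:, None, :] - R[None, :, :]            # (N, N, dim) pair displacements
    r2 = jnp.sum(dR ** 2, axis=-1)                # squared distances; no sqrt, so no NaN grads at r = 0
    mask = 1.0 - jnp.eye(R.shape[0])              # exclude self-interaction terms
    pair = 0.5 * jnp.sum(mask * jnp.exp(-r2 / sigma ** 2))  # 0.5: each pair appears twice
    trap = 0.5 * jnp.sum(R ** 2)                  # keeps particles near the origin
    return pair + trap

grad_energy = jax.grad(energy, argnums=0)         # dE/dR (the force is its negative)

def simulate(R0, sigma, steps=300, dt=0.01):
    """Overdamped dynamics unrolled with lax.scan; differentiable in sigma."""
    def step(R, _):
        return R - dt * grad_energy(R, sigma), None
    R_final, _ = jax.lax.scan(step, R0, None, length=steps)
    return R_final

def loss(sigma, R0, target_spread):
    """Squared error between the final mean squared radius and a target."""
    R = simulate(R0, sigma)
    spread = jnp.mean(jnp.sum(R ** 2, axis=-1))
    return (spread - target_spread) ** 2

# d(loss)/d(sigma), backpropagated through every step of the trajectory.
loss_grad = jax.jit(jax.grad(loss))

R0 = 0.1 * jax.random.normal(jax.random.PRNGKey(0), (16, 2))

# Meta-optimization: a few clipped gradient-descent steps on sigma.
sigma = 0.5
for _ in range(20):
    g = loss_grad(sigma, R0, 1.0)
    sigma = sigma - 0.01 * jnp.clip(g, -1.0, 1.0)
```

Clipping the outer gradient keeps `sigma` in a range where the inner step size `dt` stays stable; a real workflow would choose the outer optimizer and objective to suit the problem.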
Abstract
We introduce JAX MD, a software package for performing differentiable physics simulations with a focus on molecular dynamics. JAX MD includes a number of physics simulation environments, as well as interaction potentials and neural networks that can be integrated into these environments without writing any additional code. Since the simulations themselves are differentiable functions, entire trajectories can be differentiated to perform meta-optimization. These features are built on primitive operations, such as spatial partitioning, that allow simulations to scale to hundreds of thousands of particles on a single GPU. These primitives are flexible enough that they can be used to scale up workloads outside of molecular dynamics. We present several examples that highlight the features of JAX MD including: integration of graph neural networks into traditional simulations,…
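The abstract credits spatial partitioning for scaling to large systems. The sketch below shows the general cell-list idea in plain Python and NumPy. Particles are binned into cells at least one cutoff wide, and only pairs in the same or adjacent cells are distance-tested. At fixed density this costs O(N) instead of the O(N²) all-pairs check. It assumes a square or cubic periodic box and is not JAX MD's implementation: code meant to run under `jax.jit` needs fixed-shape arrays and would be organized differently. The helper names are hypothetical.

```python
import itertools
import numpy as np

# Hypothetical helpers illustrating the generic cell-list technique
# (not JAX MD's implementation). Assumes a periodic box of side `box`
# with cutoff <= box / 2 so the minimum-image convention is unambiguous.

def min_image_dist2(a, b, box):
    d = a - b
    d -= box * np.round(d / box)          # minimum-image displacement
    return float(np.sum(d * d))

def neighbor_pairs_brute(R, box, cutoff):
    """Reference O(N^2) search over all pairs."""
    n = len(R)
    return {(i, j) for i in range(n) for j in range(i + 1, n)
            if min_image_dist2(R[i], R[j], box) < cutoff ** 2}

def neighbor_pairs_cells(R, box, cutoff):
    """Cell-list search: only same or adjacent cells are checked."""
    dim = R.shape[1]
    n_cells = max(1, int(np.floor(box / cutoff)))  # so each cell is >= cutoff wide
    cell_size = box / n_cells
    idx = np.floor(R / cell_size).astype(int) % n_cells
    cells = {}
    for p, c in enumerate(map(tuple, idx)):
        cells.setdefault(c, []).append(p)
    offsets = list(itertools.product((-1, 0, 1), repeat=dim))
    pairs = set()
    for c, members in cells.items():
        # A set removes duplicate neighbor cells when n_cells < 3 and offsets wrap.
        nbr_cells = {tuple((np.array(c) + o) % n_cells) for o in offsets}
        for nc in nbr_cells:
            for i in members:
                for j in cells.get(nc, ()):
                    if i < j and min_image_dist2(R[i], R[j], box) < cutoff ** 2:
                        pairs.add((i, j))
    return pairs
```

For random particles the two functions return the same pair set, for example `neighbor_pairs_cells(R, 10.0, 1.0) == neighbor_pairs_brute(R, 10.0, 1.0)` with `R` drawn uniformly from a 10 × 10 box.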
