flowMC: Normalizing-flow enhanced sampling package for probabilistic inference in Jax
Kaze W. K. Wong, Marylou Gabrié, Daniel Foreman-Mackey

TL;DR
flowMC is a Python library that combines local gradient-based MCMC methods with normalizing flow models to efficiently sample complex posterior distributions, leveraging JAX for high performance and flexibility.
Contribution
It introduces a sampling framework that pairs traditional local MCMC kernels with a learned normalizing-flow proposal, targets multimodal and correlated distributions, and runs on accelerators such as GPUs and TPUs.
Findings
Handles multimodal and correlated distributions effectively
Supports gradient-based samplers such as HMC and MALA through JAX automatic differentiation
Achieves high performance with GPU/TPU acceleration
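To make the point about gradient-based samplers concrete, here is a minimal sketch of a single MALA step in which the gradient of the log-density comes from `jax.grad`. It is not flowMC's API: `log_prob`, `mala_step`, the standard-normal target, and the step size are all assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp


def log_prob(x):
    # Hypothetical target: unnormalized standard normal in any dimension.
    return -0.5 * jnp.sum(x**2)


def mala_step(key, x, step_size):
    """One Metropolis-adjusted Langevin step; returns (new_x, accepted)."""
    grad_fn = jax.grad(log_prob)  # gradient obtained by automatic differentiation
    key_noise, key_accept = jax.random.split(key)

    # Langevin proposal: drift along the gradient plus Gaussian noise.
    noise = jax.random.normal(key_noise, x.shape)
    mean_fwd = x + 0.5 * step_size**2 * grad_fn(x)
    x_new = mean_fwd + step_size * noise

    # Log-density of the Gaussian proposal (constants cancel in the ratio).
    def log_q(to, mean):
        return -0.5 * jnp.sum((to - mean) ** 2) / step_size**2

    # Metropolis-Hastings correction for the asymmetric proposal.
    mean_rev = x_new + 0.5 * step_size**2 * grad_fn(x_new)
    log_alpha = (log_prob(x_new) + log_q(x, mean_rev)
                 - log_prob(x) - log_q(x_new, mean_fwd))

    accept = jnp.log(jax.random.uniform(key_accept)) < log_alpha
    return jnp.where(accept, x_new, x), accept


# Advance four independent 2D chains in parallel with jax.vmap.
keys = jax.random.split(jax.random.PRNGKey(0), 4)
x0 = jnp.zeros((4, 2))
x1, accepted = jax.vmap(mala_step, in_axes=(0, 0, None))(keys, x0, 0.5)
```

Because the step is a pure function of its inputs, the same code can be vectorized over chains with `jax.vmap` and compiled with `jax.jit`, which is what makes accelerator execution straightforward.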
Abstract
flowMC is a Python library for accelerated Markov Chain Monte Carlo (MCMC) leveraging deep generative modeling. It is built on top of the machine learning libraries JAX and Flax. At its core, flowMC uses a local sampler and a learnable global sampler in tandem to efficiently sample posterior distributions. While multiple chains of the local sampler generate samples over the region of interest in the target parameter space, the package uses these samples to train a normalizing flow model, then uses it to propose global jumps across the parameter space. The flowMC sampler can handle non-trivial geometry, such as multimodal distributions and distributions with local correlations. The key features of flowMC are summarized in the following list:
* Since flowMC is built on top of JAX, it supports gradient-based samplers through automatic differentiation such as MALA and Hamiltonian Monte…
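The local-plus-global scheme described in the abstract can be sketched in a few dozen lines. This is an illustrative toy under stated assumptions and not flowMC's implementation: a single Gaussian fitted to the chain samples stands in for the normalizing flow, random-walk Metropolis stands in for MALA or HMC, and the names `log_target`, `local_step`, `global_step`, and `run` are hypothetical. The global move is an independence Metropolis-Hastings step, so the acceptance ratio includes the proposal density at both the current and the proposed point.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm


def log_target(x):
    # Hypothetical 1D bimodal target: equal mixture of N(-3, 1) and N(3, 1).
    parts = jnp.stack([norm.logpdf(x, -3.0, 1.0), norm.logpdf(x, 3.0, 1.0)])
    return logsumexp(parts, axis=0) - jnp.log(2.0)


@jax.jit
def local_step(key, x, step_size=0.5):
    # Random-walk Metropolis on every chain (stand-in for MALA or HMC).
    key_prop, key_acc = jax.random.split(key)
    x_new = x + step_size * jax.random.normal(key_prop, x.shape)
    log_alpha = log_target(x_new) - log_target(x)
    accept = jnp.log(jax.random.uniform(key_acc, x.shape)) < log_alpha
    return jnp.where(accept, x_new, x)


@jax.jit
def global_step(key, x, mu, sigma):
    # Independence Metropolis-Hastings: the proposal ignores the current state.
    key_prop, key_acc = jax.random.split(key)
    x_new = mu + sigma * jax.random.normal(key_prop, x.shape)
    log_alpha = (log_target(x_new) - norm.logpdf(x_new, mu, sigma)
                 - log_target(x) + norm.logpdf(x, mu, sigma))
    accept = jnp.log(jax.random.uniform(key_acc, x.shape)) < log_alpha
    return jnp.where(accept, x_new, x)


def run(key, n_chains=64, n_loops=20, n_train_loops=10, n_local=20, n_global=20):
    key, key_init = jax.random.split(key)
    # Chains start spread over the region of interest, covering both modes.
    x = jax.random.uniform(key_init, (n_chains,), minval=-5.0, maxval=5.0)
    mu, sigma = 0.0, 1.0
    kept = []
    for loop in range(n_loops):
        local_samples = []
        for _ in range(n_local):
            key, sub = jax.random.split(key)
            x = local_step(sub, x)
            local_samples.append(x)
        if loop < n_train_loops:
            # "Training": refit the Gaussian surrogate to the local samples.
            batch = jnp.stack(local_samples)
            mu, sigma = jnp.mean(batch), jnp.std(batch)
        global_samples = []
        for _ in range(n_global):
            key, sub = jax.random.split(key)
            x = global_step(sub, x, mu, sigma)
            global_samples.append(x)
        if loop >= n_train_loops:
            # Keep samples only once the proposal is frozen.
            kept.extend(local_samples + global_samples)
    return jnp.stack(kept)


samples = run(jax.random.PRNGKey(0))
```

The proposal is refit only during the first `n_train_loops` loops and then frozen, so the retained samples come from a fixed Markov kernel. A single Gaussian is adequate here only because the toy target is one-dimensional; an expressive model such as a normalizing flow is what lets the same idea propose useful global jumps in harder geometries.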
Taxonomy
Topics: Generative Adversarial Networks and Image Synthesis · Gaussian Processes and Bayesian Inference · Topic Modeling
