TL;DR
JaxSGMC is a flexible, modular library in JAX that simplifies the implementation of stochastic gradient MCMC methods for uncertainty quantification in deep learning, promoting research and application in Bayesian neural networks.
Contribution
The paper introduces a modular, application-agnostic SG-MCMC library in JAX that makes samplers for Bayesian deep learning easy to implement and customize.
Findings
Supports multiple state-of-the-art SG-MCMC algorithms
Facilitates rapid development of new samplers
Accelerates research in uncertainty quantification
Abstract
We present JaxSGMC, an application-agnostic library for stochastic gradient Markov chain Monte Carlo (SG-MCMC) in JAX. SG-MCMC schemes are uncertainty quantification (UQ) methods that scale to large datasets and high-dimensional models, enabling trustworthy neural network predictions via Bayesian deep learning. JaxSGMC implements several state-of-the-art SG-MCMC samplers to promote UQ in deep learning by lowering the barrier to entry for switching from stochastic optimization to SG-MCMC sampling. Additionally, JaxSGMC allows users to build custom samplers from standard SG-MCMC building blocks. Due to this modular structure, we anticipate that JaxSGMC will accelerate research into novel SG-MCMC schemes and facilitate their application across a broad range of domains.
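To make "switching from stochastic optimization to SG-MCMC sampling" concrete, here is a minimal sketch of stochastic gradient Langevin dynamics (SGLD), one classic SG-MCMC scheme, written in plain JAX. It is not JaxSGMC's API: the toy Bayesian linear-regression model, the function names (`sgd_step`, `sgld_step`, `run_sgld`) and the hyperparameters are hypothetical choices for illustration only.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy model (not JaxSGMC's API): Bayesian linear regression
# with a standard normal prior and unit-variance Gaussian likelihood.

def log_prior(theta):
    return -0.5 * jnp.sum(theta ** 2)

def log_likelihood(theta, x, y):
    # Log-likelihood of a single observation (up to a constant)
    return -0.5 * (y - jnp.dot(x, theta)) ** 2

def stochastic_log_posterior(theta, x_batch, y_batch, n_total):
    # Unbiased minibatch estimate: rescale the batch likelihood by N / n
    n_batch = x_batch.shape[0]
    lik = jax.vmap(log_likelihood, in_axes=(None, 0, 0))(theta, x_batch, y_batch)
    return log_prior(theta) + (n_total / n_batch) * jnp.sum(lik)

grad_log_post = jax.grad(stochastic_log_posterior)

def sgd_step(theta, x_batch, y_batch, n_total, lr):
    # Stochastic optimization: gradient ascent toward the MAP estimate
    return theta + lr * grad_log_post(theta, x_batch, y_batch, n_total)

def sgld_step(key, theta, x_batch, y_batch, n_total, step_size):
    # SGLD: half-step gradient ascent plus Gaussian noise with variance step_size
    noise = jnp.sqrt(step_size) * jax.random.normal(key, theta.shape)
    drift = 0.5 * step_size * grad_log_post(theta, x_batch, y_batch, n_total)
    return theta + drift + noise

def run_sgld(key, theta0, x, y, n_steps, batch_size, step_size):
    n_total = x.shape[0]

    def body(theta, step_key):
        k_idx, k_noise = jax.random.split(step_key)
        # Draw a random minibatch (with replacement) for this step
        idx = jax.random.randint(k_idx, (batch_size,), 0, n_total)
        theta = sgld_step(k_noise, theta, x[idx], y[idx], n_total, step_size)
        return theta, theta  # carry the state and record it as a sample

    keys = jax.random.split(key, n_steps)
    _, samples = jax.lax.scan(body, theta0, keys)
    return samples

# Usage on synthetic data
key = jax.random.PRNGKey(0)
k_x, k_eps, k_run = jax.random.split(key, 3)
true_theta = jnp.array([1.0, -2.0])
x = jax.random.normal(k_x, (1000, 2))
y = x @ true_theta + jax.random.normal(k_eps, (1000,))

samples = run_sgld(k_run, jnp.zeros(2), x, y,
                   n_steps=2000, batch_size=100, step_size=1e-4)
burned = samples[500:]  # discard burn-in
posterior_mean = burned.mean(axis=0)
posterior_std = burned.std(axis=0)
```

The two update functions differ only in the halved gradient step and the injected Gaussian noise whose variance equals the step size; that noise turns a point estimate into a chain of approximate posterior samples, whose spread (`posterior_std`) is the uncertainty estimate. Rescaling the minibatch likelihood by N/n keeps the gradient estimate unbiased, which is what lets such samplers scale to large datasets.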
