laplax -- Laplace Approximations with JAX
Tobias Weber, Bálint Mucsányi, Lenard Rommel, Thomas Christie, Lars Kasüschke, Marvin Pförtner, Philipp Hennig

TL;DR
laplax is an open-source Python package that simplifies applying Laplace approximations to neural networks using JAX, supporting Bayesian uncertainty quantification and model selection.
Contribution
Introduces laplax, a modular, purely functional Python library built on JAX for efficient Laplace approximations in deep learning, with minimal external dependencies.
Findings
Enables scalable weight-space uncertainty estimation in deep neural networks
Facilitates research on Bayesian neural networks and on improved Laplace approximation techniques
Provides a flexible, researcher-friendly tool for rapid prototyping and experimentation
Abstract
The Laplace approximation provides a scalable and efficient means of quantifying weight-space uncertainty in deep neural networks, enabling the application of Bayesian tools such as predictive uncertainty and model selection via Occam's razor. In this work, we introduce laplax, a new open-source Python package for performing Laplace approximations with JAX. Designed with a modular and purely functional architecture and minimal external dependencies, laplax offers a flexible and researcher-friendly framework for rapid prototyping and experimentation. Its goal is to facilitate research on Bayesian neural networks, uncertainty quantification for deep learning, and the development of improved Laplace approximation techniques.
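To make the idea in the abstract concrete, the following is a minimal sketch of a Laplace approximation written in plain JAX. It does not use the laplax API, which is not described on this page. The toy linear model, the data, and the names `model`, `log_joint`, `NOISE_STD`, and `PRIOR_PREC` are assumptions chosen purely for illustration. The sketch finds the MAP estimate, takes the Hessian of the negative log joint at that point as the posterior precision, propagates the resulting Gaussian to a predictive variance by linearizing the model, and evaluates the Laplace estimate of the log marginal likelihood that is used for model selection.

```python
import jax
import jax.numpy as jnp

# Toy 1-D regression data (illustrative only, not taken from the paper).
key = jax.random.PRNGKey(0)
x = jnp.linspace(-1.0, 1.0, 20)
y = 2.0 * x + 0.5 + 0.1 * jax.random.normal(key, x.shape)

NOISE_STD = 0.1   # assumed known observation noise
PRIOR_PREC = 1.0  # assumed isotropic Gaussian prior precision
n, d = x.shape[0], 2


def model(theta, x_in):
    """Hypothetical two-parameter model: slope theta[0], intercept theta[1]."""
    return theta[0] * x_in + theta[1]


def log_joint(theta):
    """log p(y | theta) + log p(theta), including normalizing constants."""
    resid = y - model(theta, x)
    log_lik = (-0.5 * jnp.sum(resid**2) / NOISE_STD**2
               - 0.5 * n * jnp.log(2 * jnp.pi * NOISE_STD**2))
    log_prior = (-0.5 * PRIOR_PREC * jnp.sum(theta**2)
                 + 0.5 * d * jnp.log(PRIOR_PREC / (2 * jnp.pi)))
    return log_lik + log_prior


def neg_log_joint(theta):
    return -log_joint(theta)


grad_fn = jax.grad(neg_log_joint)
hess_fn = jax.hessian(neg_log_joint)

# Step 1: find the MAP estimate with a few Newton steps.
theta = jnp.zeros(d)
for _ in range(5):
    theta = theta - jnp.linalg.solve(hess_fn(theta), grad_fn(theta))
theta_map = theta

# Step 2: the Hessian at the MAP is the precision of the Gaussian
# approximation N(theta_map, precision^{-1}) to the posterior.
precision = hess_fn(theta_map)
covariance = jnp.linalg.inv(precision)


def predict(x_new):
    """Predictive mean and variance of the model output via linearization."""
    jac = jax.grad(model)(theta_map, x_new)  # d model / d theta at the MAP
    return model(theta_map, x_new), jac @ covariance @ jac


# Step 3: Laplace estimate of the log marginal likelihood (evidence).
_, logdet = jnp.linalg.slogdet(precision)
log_evidence = log_joint(theta_map) + 0.5 * d * jnp.log(2 * jnp.pi) - 0.5 * logdet

mean, var = predict(0.3)
print("MAP:", theta_map, "predictive mean/var:", mean, var)
print("log evidence:", log_evidence)
```

For this linear-Gaussian toy problem the posterior is exactly Gaussian, so the approximation is exact. For a neural network the full Hessian is usually too large to form or invert, which is why practical Laplace approximations typically rely on curvature approximations such as the generalized Gauss-Newton matrix or on diagonal and low-rank structure.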
Taxonomy
Topics: Mathematics and Applications · Advanced Numerical Analysis Techniques
