Composable Effects for Flexible and Accelerated Probabilistic Programming in NumPyro
Du Phan, Neeraj Pradhan, Martin Jankowiak

TL;DR
This paper introduces composable effect handlers in NumPyro, enabling flexible probabilistic programming with accelerated inference methods such as an end-to-end JIT-compiled NUTS sampler that is much faster than existing alternatives on both small and large datasets.
Contribution
It presents an approach to composing Pyro-style effect handlers with JAX program transformations (JIT compilation, automatic differentiation, and vectorization) in NumPyro, making probabilistic programs both more flexible and more efficient.
Findings
An iterative formulation of NUTS can be JIT compiled end to end, significantly speeding up inference
Composable effect handlers enable flexible extension of NumPyro
The speedup over existing alternatives holds in both the small- and large-dataset regimes
Abstract
NumPyro is a lightweight library that provides an alternate NumPy backend to the Pyro probabilistic programming language with the same modeling interface, language primitives and effect handling abstractions. Effect handlers allow Pyro's modeling API to be extended to NumPyro despite its being built atop a fundamentally different JAX-based functional backend. In this work, we demonstrate the power of composing Pyro's effect handlers with the program transformations that enable hardware acceleration, automatic differentiation, and vectorization in JAX. In particular, NumPyro provides an iterative formulation of the No-U-Turn Sampler (NUTS) that can be end-to-end JIT compiled, yielding an implementation that is much faster than existing alternatives in both the small and large dataset regimes.
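Standard NUTS builds its trajectory with a recursive tree-doubling procedure, which does not map directly onto a single compiled loop. A minimal sketch of why an iterative formulation is possible, assuming the usual doubling scheme in which a trajectory of 2^d leapfrog states forms a perfect binary tree and a U-turn check runs whenever a subtree of two or more leaves is completed: the set and order of those checks can be recovered from a flat loop over leaf indices, using only the trailing zero bits of n + 1. The helpers `recursive_checks` and `iterative_checks` are hypothetical illustrations, not NumPyro's implementation, which must also carry the sampler state (positions, momenta, checkpoints, termination flags) through JAX-compatible control flow.

```python
def recursive_checks(start, size, out):
    """Recursive NUTS-style tree build: record (first_leaf, last_leaf) of every
    subtree with at least 2 leaves, in the order its build finishes (post-order)."""
    if size == 1:
        return  # a single leaf needs no U-turn check
    half = size // 2
    recursive_checks(start, half, out)         # build the left half
    recursive_checks(start + half, half, out)  # build the right half
    out.append((start, start + size - 1))      # check across the whole subtree


def iterative_checks(depth):
    """The same checks, produced by a flat loop over leaf indices (no recursion)."""
    out = []
    for n in range(2 ** depth):  # add leaves one at a time, left to right
        m, size = n + 1, 2
        # Each trailing zero bit of n + 1 means one more subtree has just closed.
        while m % 2 == 0:
            out.append((n + 1 - size, n))
            m //= 2
            size *= 2
    return out


if __name__ == "__main__":
    for d in range(6):
        rec = []
        recursive_checks(0, 2 ** d, rec)
        assert rec == iterative_checks(d)
    print(iterative_checks(2))  # [(0, 1), (2, 3), (0, 3)]
```

Because at most d subtrees are open at any moment, the loop version needs only a fixed-size buffer of saved states, so its state can have a static shape; that is the kind of property a JIT compiler such as XLA needs before it can compile the whole sampler as one program.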
Taxonomy
Topics: Parallel Computing and Optimization Techniques · Bayesian Modeling and Causal Inference · Formal Methods in Verification
