Transformers Can Do Bayesian Inference
Samuel Müller, Noah Hollmann, Sebastian Pineda Arango, Josif Grabocka, Frank Hutter

TL;DR
This paper introduces Prior-Data Fitted Networks (PFNs), which use in-context learning with transformers to approximate Bayesian inference efficiently across diverse tasks. PFNs closely mimic Gaussian processes and make fast, accurate predictions.
Contribution
PFNs leverage in-context learning to approximate a wide range of posteriors, providing a general, fast, and accurate method for Bayesian inference in various settings.
Findings
PFNs can nearly perfectly mimic Gaussian processes.
PFNs enable over 200-fold speedups in multiple setups compared to current Bayesian inference methods.
PFNs perform well in diverse areas like Gaussian process regression, Bayesian neural networks, and few-shot image classification.
Abstract
Currently, it is hard to reap the benefits of deep learning for Bayesian methods, which allow the explicit specification of prior knowledge and accurately capture model uncertainty. We present Prior-Data Fitted Networks (PFNs). PFNs leverage in-context learning in large-scale machine learning techniques to approximate a large set of posteriors. The only requirement for PFNs to work is the ability to sample from a prior distribution over supervised learning tasks (or functions). Our method restates the objective of posterior approximation as a supervised classification problem with a set-valued input: it repeatedly draws a task (or function) from the prior, draws a set of data points and their labels from it, masks one of the labels and learns to make probabilistic predictions for it based on the set-valued input of the rest of the data points. Presented with a set of samples from a new…
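The data-generation loop described in the abstract (draw a task from the prior, draw data points and labels, mask one label) can be sketched as follows. This is a minimal illustration, not the authors' implementation: the Gaussian process prior with an RBF kernel, the function names `sample_prior_dataset` and `make_training_example`, and all hyperparameter values are assumptions chosen for the example. The transformer that would consume these examples is omitted.

```python
import numpy as np


def sample_prior_dataset(rng, n_points=10, length_scale=0.5, noise_std=0.1):
    """Draw one synthetic regression task from an assumed GP prior (RBF kernel)."""
    x = rng.uniform(0.0, 1.0, size=n_points)
    # RBF kernel matrix over the sampled inputs.
    diff = x[:, None] - x[None, :]
    kernel = np.exp(-0.5 * (diff / length_scale) ** 2)
    # Small jitter on the diagonal keeps the covariance numerically stable.
    cov = kernel + 1e-6 * np.eye(n_points)
    f = rng.multivariate_normal(np.zeros(n_points), cov)
    # Observed labels are the latent function values plus Gaussian noise.
    y = f + noise_std * rng.normal(size=n_points)
    return x, y


def make_training_example(rng, **prior_kwargs):
    """Hold out one label; the remaining points form the set-valued input."""
    x, y = sample_prior_dataset(rng, **prior_kwargs)
    held_out = rng.integers(len(x))
    keep = np.arange(len(x)) != held_out
    context = (x[keep], y[keep])
    return context, x[held_out], y[held_out]


rng = np.random.default_rng(0)
(context_x, context_y), query_x, target_y = make_training_example(rng, n_points=8)
print(context_x.shape, context_y.shape)  # 7 context points, 1 held-out query
```

A model trained on many such examples would receive the context set and `query_x` as input and be fitted to assign high probability to `target_y`. Swapping in a different prior only requires changing `sample_prior_dataset`, which reflects the abstract's point that sampling from the prior is the only requirement.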
Taxonomy
Topics: Gaussian Processes and Bayesian Inference · Machine Learning and Algorithms · Machine Learning and Data Classification
Methods: Gaussian Process
