Streamlining Prediction in Bayesian Deep Learning
Rui Li, Marcus Klasson, Arno Solin, Martin Trapp

TL;DR
This paper introduces a method for efficient prediction in Bayesian deep learning that uses local linearisation and Gaussian approximations to enable single-pass inference without sampling; it applies to models such as MLPs and transformers.
Contribution
It proposes a novel analytical approximation technique for Bayesian predictions that reduces computational cost and complexity in deep learning models.
Findings
Accurately approximates posterior predictive distributions
Reduces inference time by avoiding sampling
Effective on both regression and classification tasks
Abstract
The rising interest in Bayesian deep learning (BDL) has led to a plethora of methods for estimating the posterior distribution. However, efficient computation of inferences, such as predictions, has been largely overlooked, with Monte Carlo integration remaining the standard. In this work we examine streamlining prediction in BDL through a single forward pass without sampling. For this, we use local linearisation on activation functions and local Gaussian approximations at linear layers, which allows us to analytically compute an approximation to the posterior predictive distribution. We showcase our approach for both MLPs and transformers, such as ViT and GPT-2, and assess its performance on regression and classification tasks. Open-source library: https://github.com/AaltoML/SUQ
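The two ingredients named in the abstract can be illustrated with a minimal pure-Python sketch. It assumes mean-field (diagonal) Gaussian weights, independence between weights and inputs, ReLU activations, and per-unit variances only; the helpers `linear_moments` and `relu_linearised` are hypothetical and are not the SUQ library's API. One deterministic pass yields a Gaussian mean and variance for the output, with no weight samples drawn.

```python
from typing import List, Tuple

Vec = List[float]
Mat = List[List[float]]


def linear_moments(mu: Vec, var: Vec, w_mean: Mat, w_var: Mat,
                   b_mean: Vec, b_var: Vec) -> Tuple[Vec, Vec]:
    """Propagate a diagonal Gaussian input through y = W x + b.

    Assumes mean-field Gaussian weights and biases, independent of each other
    and of x. Per-output means and variances are exact under these assumptions;
    covariances between outputs are dropped (diagonal sketch).
    """
    out_mu, out_var = [], []
    for m_row, v_row, bm, bv in zip(w_mean, w_var, b_mean, b_var):
        mean, variance = bm, bv
        for m, wv, xm, xv in zip(m_row, v_row, mu, var):
            mean += m * xm
            # Var(w x) = m^2 Var(x) + Var(w) E[x]^2 + Var(w) Var(x)
            variance += m * m * xv + wv * xm * xm + wv * xv
        out_mu.append(mean)
        out_var.append(variance)
    return out_mu, out_var


def relu_linearised(mu: Vec, var: Vec) -> Tuple[Vec, Vec]:
    """Linearise ReLU at the input mean: relu(h) ~ relu(mu) + relu'(mu) (h - mu)."""
    out_mu = [max(m, 0.0) for m in mu]
    out_var = [v if m > 0.0 else 0.0 for m, v in zip(mu, var)]
    return out_mu, out_var


# Hypothetical 2-2-1 network; all numbers are illustrative, not from the paper.
x_mu, x_var = [1.0, -2.0], [0.0, 0.0]  # deterministic test input
h_mu, h_var = linear_moments(x_mu, x_var,
                             w_mean=[[1.0, 0.25], [1.0, 1.0]],
                             w_var=[[0.01, 0.01], [0.04, 0.0]],
                             b_mean=[0.0, 0.0], b_var=[0.0, 0.0])
a_mu, a_var = relu_linearised(h_mu, h_var)
y_mu, y_var = linear_moments(a_mu, a_var,
                             w_mean=[[2.0, 3.0]], w_var=[[0.1, 0.1]],
                             b_mean=[0.1], b_var=[0.01])
print(y_mu, y_var)  # one deterministic pass: mean ~1.1, variance ~0.24
```

Dropping cross-unit covariances keeps the cost of each layer proportional to its number of weights; richer covariance structures (the reviews below mention Kronecker-factored covariance) are not reproduced in this sketch.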
Peer Reviews
Decision: ICLR 2025 Poster
The paper introduces a novel approach to Bayesian Deep Learning (BDL) by eliminating the need for Monte Carlo sampling through local approximations. It nicely blends local linearisation with Gaussian methods and offers interesting solutions for managing transformer architectures and Kronecker-factored covariance. Quality-wise, the work is solidly validated with extensive experiments across various tasks and architectures. It thoroughly benchmarks against established baselines, clearly showcasin…
The treatment of attention layers (using deterministic queries/keys) needs stronger justification. The scaling factor for the predictive variance is tuned on validation data, which seems ad hoc. There is no discussion of potential failure modes or limitations of the local linearisation assumption.
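To make the concern about deterministic queries/keys concrete, the following pure-Python sketch shows what that choice implies for a single attention head. It assumes attention weights computed from query/key means only, value vectors with diagonal Gaussian uncertainty that is independent across tokens, and a hypothetical helper `attention_moments` that is not the paper's code. Only the value variance reaches the output, scaled by the squared attention weights; uncertainty in the attention pattern itself is ignored, which is the part the reviewer asks to justify.

```python
import math
from typing import List, Tuple

Mat = List[List[float]]


def softmax(row: List[float]) -> List[float]:
    """Numerically stable softmax over one row of attention scores."""
    top = max(row)
    exps = [math.exp(r - top) for r in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention_moments(q_mean: Mat, k_mean: Mat,
                      v_mean: Mat, v_var: Mat) -> Tuple[Mat, Mat]:
    """Scaled dot-product attention with deterministic queries/keys.

    Attention weights A come from the query/key means only, so they carry no
    uncertainty. With values independent across tokens, out = A v has
    mean A mu and per-dimension variance (A ** 2) var.
    """
    scale = 1.0 / math.sqrt(len(q_mean[0]))
    n_tokens, d_v = len(v_mean), len(v_mean[0])
    out_mu, out_var = [], []
    for q in q_mean:
        scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in k_mean]
        a = softmax(scores)
        out_mu.append([sum(a[s] * v_mean[s][j] for s in range(n_tokens))
                       for j in range(d_v)])
        out_var.append([sum(a[s] ** 2 * v_var[s][j] for s in range(n_tokens))
                        for j in range(d_v)])
    return out_mu, out_var


# Illustrative numbers: a zero query gives uniform weights over two tokens.
mu, var = attention_moments(q_mean=[[0.0, 0.0]], k_mean=[[1.0, 0.0], [0.0, 1.0]],
                            v_mean=[[1.0, 2.0], [3.0, 4.0]],
                            v_var=[[0.4, 0.8], [0.4, 0.0]])
print(mu, var)  # mean [[2.0, 3.0]], variance ~[[0.2, 0.2]]
```

Because each row of attention weights sums to one, the sum of squared weights is at most one, so under these assumptions the output variance never exceeds the largest value variance it attends to.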
The experiment sections cover a variety of tasks.
The main claim of the paper is to provide a scalable alternative to sampling, but the authors fail to provide sufficient evidence for this claim. Not all experimental details are reported; first of all, they do NOT specify the number of samples used in the MC sampling baseline. For any approximate posterior $q$ with fixed parameters, the proposed method is an approximation of the predictive distribution, and that predictive distribution is exactly what we get in the limit of an infinite number of samples…
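The reviewer's point about the sample count can be illustrated with the one case where the two quantities coincide: a linear model with Gaussian weights, whose predictive distribution is exactly Gaussian. This pure-Python sketch assumes that toy model; the helpers and numbers are hypothetical. The Monte Carlo estimate approaches the closed form as the number of samples grows, with error shrinking roughly as one over the square root of the sample count.

```python
import math
import random
from typing import List, Tuple


def analytic_predictive(m: List[float], s: List[float],
                        x: List[float]) -> Tuple[float, float]:
    """Closed-form predictive of y = w . x with w ~ N(m, diag(s))."""
    mean = sum(mi * xi for mi, xi in zip(m, x))
    var = sum(si * xi * xi for si, xi in zip(s, x))
    return mean, var


def mc_predictive(m: List[float], s: List[float], x: List[float],
                  num_samples: int, rng: random.Random) -> Tuple[float, float]:
    """Monte Carlo estimate of the same predictive from weight samples."""
    ys = []
    for _ in range(num_samples):
        # random.gauss takes a standard deviation, so take sqrt of the variance.
        w = [rng.gauss(mi, math.sqrt(si)) for mi, si in zip(m, s)]
        ys.append(sum(wi * xi for wi, xi in zip(w, x)))
    mean = sum(ys) / num_samples
    var = sum((y - mean) ** 2 for y in ys) / (num_samples - 1)
    return mean, var


m, s, x = [0.5, -1.0, 2.0], [0.1, 0.2, 0.05], [1.0, 0.5, -1.0]
exact = analytic_predictive(m, s, x)  # (-2.0, 0.2) up to rounding
rng = random.Random(0)
for num_samples in (10, 100, 10_000):
    est = mc_predictive(m, s, x, num_samples, rng)
    print(num_samples, est, "abs. error in mean:", abs(est[0] - exact[0]))
```

For a nonlinear network, MC converges to the true predictive under $q$, while a single-pass linearised method remains an approximation of it; this is why the number of samples in the baseline determines how meaningful a cost or accuracy comparison is.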
Overall, I recommend accepting the paper since it 1) proposes a simple yet effective solution to the problem of efficient Bayesian neural network prediction, 2) appears to be technically correct and is well presented, and 3) provides extensive empirical evidence for the claim that the method performs at least comparably to more expensive alternatives. The proposed method provides an intuitive solution to the problem of efficient BNN prediction, assuming that we want to consider a posterior d…
The main weakness of the paper is that it does not consider alternative methods for approximating the predictive distribution of a BNN with a single forward pass. You do not compare the method to last-layer variants of the Laplace approximation and MFVI. This is important to determine whether there is any benefit in linearising the neural network to allow for more efficient predictions vs. just treating the network up to the last layer deterministically. For example, when using a last-layer L…
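A last-layer baseline of the kind the reviewer asks about can be sketched in a few lines. This pure-Python sketch assumes a deterministic network body producing features, a last layer with mean-field Gaussian weights (for example from a last-layer Laplace approximation or MFVI), and the probit (MacKay) scaling to approximate the expected softmax, a standard trick swapped in here rather than taken from the paper; `last_layer_predictive` and all numbers are hypothetical. It also yields a predictive in a single forward pass, which is why the comparison matters.

```python
import math
from typing import List


def softmax(z: List[float]) -> List[float]:
    """Numerically stable softmax over class logits."""
    top = max(z)
    exps = [math.exp(v - top) for v in z]
    total = sum(exps)
    return [e / total for e in exps]


def last_layer_predictive(phi: List[float], w_mean: List[List[float]],
                          w_var: List[List[float]]) -> List[float]:
    """Approximate class probabilities with a Gaussian last layer.

    phi: features from the deterministic network body.
    w_mean / w_var: per-class mean-field Gaussian weights (cross-class
    covariance ignored). The expected softmax is approximated by scaling
    each logit mean as mu / sqrt(1 + pi * var / 8) (probit approximation).
    """
    scaled = []
    for m_c, v_c in zip(w_mean, w_var):
        mu = sum(m * p for m, p in zip(m_c, phi))
        var = sum(v * p * p for v, p in zip(v_c, phi))
        scaled.append(mu / math.sqrt(1.0 + math.pi * var / 8.0))
    return softmax(scaled)


phi = [1.0, 2.0]  # hypothetical feature vector
w_mean = [[1.0, 0.5], [0.0, 0.0]]
p_point = last_layer_predictive(phi, w_mean, [[0.0, 0.0], [0.0, 0.0]])
p_bayes = last_layer_predictive(phi, w_mean, [[0.5, 0.5], [0.5, 0.5]])
print(p_point, p_bayes)  # weight uncertainty pulls the top class toward 0.5
```

With zero weight variance the function reduces to a plain softmax over the point-estimate logits, so the difference between the two outputs isolates the effect of last-layer uncertainty.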
Taxonomy
Topics: Anomaly Detection Techniques and Applications · Gaussian Processes and Bayesian Inference · Data Stream Mining Techniques
Methods: Attention Is All You Need · Dense Connections · Dropout · Discriminative Fine-Tuning · Linear Layer · Cosine Annealing · Attention Dropout · Layer Normalization · Byte Pair Encoding
