Rethinking Sharpness-Aware Minimization as Variational Inference
Szilvia Ujváry, Zsigmond Telek, Anna Kerekes, Anna Mészáros, Ferenc Huszár

TL;DR
This paper connects sharpness-aware minimization (SAM) with mean-field variational inference (MFVI), showing that both can be interpreted as optimizing a notion of flatness, and proposes algorithms that combine or interpolate between the two for neural network training.
Contribution
It establishes a theoretical link between SAM and MFVI, and introduces algorithms that interpolate between them, with empirical evaluation on benchmark datasets.
Findings
SAM and MFVI share similar flatness optimization interpretations
Proposed algorithms outperform or match existing methods on benchmarks
SAM-like updates can serve as a drop-in replacement for the reparametrisation trick
Abstract
Sharpness-aware minimization (SAM) aims to improve the generalisation of gradient-based learning by seeking out flat minima. In this work, we establish connections between SAM and Mean-Field Variational Inference (MFVI) of neural network parameters. We show that both these methods have interpretations as optimizing notions of flatness, and when using the reparametrisation trick, they both boil down to calculating the gradient at a perturbed version of the current mean parameter. This thinking motivates our study of algorithms that combine or interpolate between SAM and MFVI. We evaluate the proposed variational algorithms on several benchmark datasets, and compare their performance to variants of SAM. Taking a broader perspective, our work suggests that SAM-like updates can be used as a drop-in replacement for the reparametrisation trick.
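The shared structure described in the abstract, where both methods evaluate the gradient at a perturbed version of the current mean parameter, can be illustrated with a minimal sketch. This is not the authors' implementation: the toy quadratic loss, the function names, and the values of `rho`, `sigma`, and `lr` are all assumptions chosen for illustration. The standard SAM perturbation (a step of radius `rho` along the normalised gradient) and the standard reparametrisation-trick perturbation (Gaussian noise scaled by `sigma`) are used; the KL term and the update of `sigma` that a full MFVI step would include are omitted.

```python
import numpy as np

def loss_grad(theta):
    # Gradient of a toy quadratic loss 0.5 * theta^T A theta.
    # Hypothetical stand-in for a neural network loss gradient.
    A = np.diag([10.0, 1.0])
    return A @ theta

def sam_perturbation(theta, rho=0.05):
    # SAM-style perturbation: move a distance rho along the
    # normalised gradient (the locally worst-case direction).
    g = loss_grad(theta)
    return rho * g / (np.linalg.norm(g) + 1e-12)

def mfvi_perturbation(theta, sigma, rng):
    # Reparametrisation-trick perturbation: sigma * z with z ~ N(0, I),
    # so theta + eps is a sample from a mean-field Gaussian around theta.
    return sigma * rng.standard_normal(theta.shape)

def perturbed_step(theta, eps, lr=0.1):
    # Shared update: the gradient is taken at theta + eps,
    # but the step is applied to the unperturbed mean theta.
    return theta - lr * loss_grad(theta + eps)

rng = np.random.default_rng(0)
theta = np.array([1.0, 1.0])
sigma = np.array([0.1, 0.1])

theta_sam = perturbed_step(theta, sam_perturbation(theta))
theta_vi = perturbed_step(theta, mfvi_perturbation(theta, sigma, rng))
```

In this sketch the two updates differ only in how `eps` is produced, which is the sense in which a SAM-like perturbation could stand in for the sampled one.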
Taxonomy
Topics: Domain Adaptation and Few-Shot Learning · Human Pose and Action Recognition · Multimodal Machine Learning Applications
Methods: Variational Inference
