Kernelized Wasserstein Natural Gradient
Michael Arbel, Arthur Gretton, Wuchen Li, Guido Montufar

TL;DR
This paper introduces a kernelized Wasserstein natural gradient method that efficiently approximates the natural gradient in Wasserstein space, improving optimization in probabilistic models with theoretical guarantees and empirical validation.
Contribution
It proposes a novel framework for approximating Wasserstein natural gradients using kernel methods, balancing accuracy and computational efficiency.
Findings
The estimator accurately approximates the natural gradient in simple examples.
Using the estimator improves classification performance on CIFAR-10 and CIFAR-100.
The approach offers theoretical guarantees on approximation quality.
Abstract
Many machine learning problems can be expressed as the optimization of some cost functional over a parametric family of probability distributions. It is often beneficial to solve such optimization problems using natural gradient methods. These methods are invariant to the parametrization of the family, and thus can yield more effective optimization. Unfortunately, computing the natural gradient is challenging as it requires inverting a high dimensional matrix at each iteration. We propose a general framework to approximate the natural gradient for the Wasserstein metric, by leveraging a dual formulation of the metric restricted to a Reproducing Kernel Hilbert Space. Our approach leads to an estimator for the gradient direction that can trade off accuracy and computational cost, with theoretical guarantees. We verify its accuracy on simple examples, and show the advantage of using such an…
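A natural gradient step preconditions the ordinary gradient with the inverse of a metric tensor; this is what makes it invariant to the parametrization, and also what makes it expensive. The toy sketch below is not the authors' kernelized estimator. It assumes a one-dimensional Gaussian family N(mu, sigma^2), for which the Wasserstein-2 metric tensor is known in closed form (the identity in (mu, sigma) coordinates, since W2^2 = (mu1 - mu2)^2 + (sigma1 - sigma2)^2), together with a hypothetical quadratic loss. It computes the natural gradient with a linear solve instead of an explicit inverse and checks that the (mu, sigma) and (mu, log sigma) parametrizations produce the same update direction, while the plain gradient does not. The `damping` argument is an illustrative assumption, not something taken from the paper.

```python
import numpy as np

def natural_gradient(grad, metric, damping=0.0):
    """Return (G + damping * I)^{-1} grad via a linear solve.

    `damping` is a hypothetical regularizer added for illustration only.
    """
    G = metric + damping * np.eye(metric.shape[0])
    return np.linalg.solve(G, grad)

# Hypothetical toy loss on N(mu, sigma^2): L = (mu - 2)^2 + (sigma - 3)^2
def grad_mu_sigma(mu, sigma):
    return np.array([2.0 * (mu - 2.0), 2.0 * (sigma - 3.0)])

mu, sigma = 0.5, 1.5

# Coordinates (mu, sigma): the W2 metric tensor is the identity.
g_ms = grad_mu_sigma(mu, sigma)
nat_ms = natural_gradient(g_ms, np.eye(2))

# Coordinates (mu, s = log sigma): chain rule gives dL/ds = sigma * dL/dsigma,
# and the pulled-back metric tensor is diag(1, sigma^2).
g_mls = np.array([g_ms[0], sigma * g_ms[1]])
nat_mls = natural_gradient(g_mls, np.diag([1.0, sigma**2]))

# Map directions in (mu, s) back to (mu, sigma) velocities: dsigma = sigma * ds.
velocity_from_log = np.array([nat_mls[0], sigma * nat_mls[1]])
euclid_from_log = np.array([g_mls[0], sigma * g_mls[1]])

print(np.allclose(velocity_from_log, nat_ms))  # True: natural gradient agrees
print(np.allclose(euclid_from_log, nat_ms))    # False: plain gradient does not
```

The linear solve stands in for the "inverting a high dimensional matrix" step mentioned in the abstract; in realistic models the metric tensor is neither known in closed form nor small, which is the cost the kernelized estimator is designed to reduce.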
Taxonomy
Topics: Advanced Neuroimaging Techniques and Applications · Sparse and Compressive Sensing Techniques · Topological and Geometric Data Analysis
