Beyond ReinMax: Low-Variance Gradient Estimators for Discrete Latent Variables
Daniel Wang, Thang D. Bui

TL;DR
This paper introduces new low-variance gradient estimators for discrete latent variables. It incorporates Rao-Blackwellisation and control variates into the ReinMax estimator, which improves the training of models such as variational autoencoders.
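Rao-Blackwellisation replaces a noisy estimator with its conditional expectation given the discrete sample D. This keeps the mean unchanged and can only lower the variance. The minimal NumPy sketch below applies the idea in the style of the Gumbel-Rao treatment of the Straight-Through Gumbel-Softmax estimator. It averages the gradient over fresh Gumbel noise that is conditioned to produce the same D. Whether ReinMax-Rao conditions on exactly these quantities is an assumption here, and every function name is hypothetical.

```python
import numpy as np

def softmax(x):
    z = np.exp(x - x.max())
    return z / z.sum()

def conditional_gumbels(theta, k, rng):
    """Sample perturbed logits theta + G conditioned on argmax(theta + G) == k."""
    m = theta.max()
    log_z = np.log(np.sum(np.exp(theta - m))) + m
    # The maximum of the perturbed logits is Gumbel(log_z), regardless of which index wins.
    top = log_z - np.log(rng.exponential())
    # The other coordinates are Gumbel(theta_i) truncated to lie below that maximum.
    x = -np.log(np.exp(-top) + np.exp(-theta) * rng.exponential(size=theta.shape))
    x[k] = top
    return x

def st_gumbel_softmax_grad(perturbed, d, grad_f, tau):
    # Backward pass through y = softmax(perturbed / tau); the forward pass used the hard sample d.
    y = softmax(perturbed / tau)
    return ((np.diag(y) - np.outer(y, y)) / tau) @ grad_f(d)

def single_sample_estimate(theta, grad_f, tau, rng):
    perturbed = theta + rng.gumbel(size=theta.shape)
    d = np.eye(len(theta))[int(np.argmax(perturbed))]
    return st_gumbel_softmax_grad(perturbed, d, grad_f, tau)

def rao_blackwellised_estimate(theta, grad_f, tau, rng, num_inner=10):
    perturbed = theta + rng.gumbel(size=theta.shape)
    k = int(np.argmax(perturbed))
    d = np.eye(len(theta))[k]
    # Average over noise consistent with the same D = k (a Monte Carlo estimate of E[grad | D]).
    grads = [st_gumbel_softmax_grad(conditional_gumbels(theta, k, rng), d, grad_f, tau)
             for _ in range(num_inner)]
    return np.mean(grads, axis=0)

# Usage on a toy loss f(d) = ||d - target||^2 (hypothetical example values).
rng = np.random.default_rng(0)
theta = np.array([0.5, -0.3, 1.0])
target = np.array([0.0, 1.0, 0.0])
grad_f = lambda d: 2.0 * (d - target)
single = np.array([single_sample_estimate(theta, grad_f, 1.0, rng) for _ in range(2000)])
rb = np.array([rao_blackwellised_estimate(theta, grad_f, 1.0, rng) for _ in range(2000)])
print("total variance, single:", single.var(axis=0).sum())
print("total variance, Rao-Blackwellised:", rb.var(axis=0).sum())
```

Both estimators have the same expectation by the tower property. The extra cost is `num_inner` cheap backward evaluations per sample, and no additional evaluations of the model's forward pass are needed.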
Contribution
It proposes the ReinMax-Rao and ReinMax-CV estimators, which reduce the variance of ReinMax's gradient estimates, and offers a new numerical-integration perspective on ReinMax.
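The estimators that ReinMax-Rao and ReinMax-CV build on can be compared on a single categorical variable. Below is a minimal NumPy sketch of Straight-Through and ReinMax. It assumes the published ReinMax formulation, in which the backward pass uses the surrogate Jacobian 2·J(π1) − ½·J(π0) with π0 = softmax(θ) and π1 = (D + softmax(θ/τ))/2. The helper names are hypothetical. The Jacobians are written out explicitly rather than through an autodiff framework's stop-gradient trick. This is not the authors' code.

```python
import numpy as np

def softmax(x):
    z = np.exp(x - x.max())
    return z / z.sum()

def softmax_jacobian(p):
    # Jacobian of softmax w.r.t. its logits, written in terms of the output p: diag(p) - p p^T
    return np.diag(p) - np.outer(p, p)

def straight_through_grad(theta, grad_f_at_d):
    # Straight-Through: backpropagate through pi0 = softmax(theta) as if it were the sample D
    pi0 = softmax(theta)
    return softmax_jacobian(pi0) @ grad_f_at_d

def reinmax_grad(theta, d, grad_f_at_d, tau=1.0):
    # ReinMax surrogate Jacobian: 2 J(pi1) - 0.5 J(pi0), with pi1 = (D + softmax(theta / tau)) / 2
    pi0 = softmax(theta)
    pi1 = 0.5 * (d + softmax(theta / tau))
    surrogate_jac = 2.0 * softmax_jacobian(pi1) - 0.5 * softmax_jacobian(pi0)
    # Both Jacobians are symmetric, so no transpose is needed
    return surrogate_jac @ grad_f_at_d

# Usage: one draw D ~ Cat(softmax(theta)) and a linear loss f(D) = w . D, so grad f(D) = w
rng = np.random.default_rng(0)
theta = np.array([0.3, -1.2, 0.8])
w = np.array([1.0, -2.0, 0.5])
k = rng.choice(3, p=softmax(theta))
d = np.eye(3)[k]
print(straight_through_grad(theta, w), reinmax_grad(theta, d, w))
```

For a linear loss and τ = 1, averaging `reinmax_grad` over the outcomes of D (weighted by their probabilities) reproduces the exact gradient J(π0)·w, as Straight-Through also does. The two estimators differ once f is nonlinear, and ReinMax pays for its lower bias with a sample-dependent Jacobian, which is the source of the extra variance that the new estimators target.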
Findings
ReinMax-Rao and ReinMax-CV outperform previous estimators when training variational autoencoders with discrete latent spaces.
The new estimators achieve lower variance and better training stability.
Alternative numerical methods may yield more accurate gradient approximations.
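The numerical-integration view treats f(D) − f(π) as the integral of the directional derivative along the segment from π to D. Different quadrature rules then give approximations of different accuracy. The sketch below is generic and assumes nothing about which rules the paper actually evaluates. It compares a one-point Euler rule, Heun's (trapezoid) rule, the midpoint rule, and Simpson's rule on toy quadratic and cubic losses. All names and values are hypothetical.

```python
import numpy as np

def directional_integrand(grad_f, a, b):
    """Return t -> grad_f(a + t (b - a)) . (b - a); its integral over [0, 1] equals f(b) - f(a)."""
    delta = b - a
    return lambda t: float(grad_f(a + t * delta) @ delta)

def euler(g):     # one slope at the left endpoint: first-order accurate
    return g(0.0)

def heun(g):      # trapezoid rule, averaging the two endpoint slopes: second-order accurate
    return 0.5 * (g(0.0) + g(1.0))

def midpoint(g):  # one slope at the midpoint: also second-order accurate
    return g(0.5)

def simpson(g):   # Simpson's rule: exact when the integrand is a polynomial of degree <= 3
    return (g(0.0) + 4.0 * g(0.5) + g(1.0)) / 6.0

a = np.array([0.2, 0.5, 0.3])   # plays the role of a probability vector pi
b = np.array([0.0, 1.0, 0.0])   # plays the role of a one-hot sample D
A = np.array([[2.0, 0.5, 0.0], [0.5, 1.0, -0.3], [0.0, -0.3, 3.0]])
f_quad, grad_quad = (lambda x: 0.5 * x @ A @ x), (lambda x: A @ x)
f_cubic, grad_cubic = (lambda x: np.sum(x ** 3)), (lambda x: 3.0 * x ** 2)

for name, f, gf in [("quadratic", f_quad, grad_quad), ("cubic", f_cubic, grad_cubic)]:
    g = directional_integrand(gf, a, b)
    exact = f(b) - f(a)
    for rule in (euler, heun, midpoint, simpson):
        print(f"{name:9s} {rule.__name__:8s} abs error = {abs(rule(g) - exact):.2e}")
```

For the quadratic loss the integrand is linear in t, so Heun's and the midpoint rule are exact while Euler is not. For the cubic loss only Simpson's rule remains exact, which is the kind of gain higher-order rules could offer.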
Abstract
Machine learning models involving discrete latent variables require gradient estimators to facilitate backpropagation in a computationally efficient manner. The most recent addition to the Straight-Through family of estimators, ReinMax, can be viewed from a numerical ODE perspective as incorporating an approximation via Heun's method to reduce bias, but at the cost of high variance. In this work, we introduce the ReinMax-Rao and ReinMax-CV estimators which incorporate Rao-Blackwellisation and control variate techniques into ReinMax to reduce its variance. Our estimators demonstrate superior performance on training variational autoencoders with discrete latent spaces. Furthermore, we investigate the possibility of leveraging alternative numerical methods for constructing more accurate gradient approximations and present an alternative view of ReinMax from a simpler numerical integration…
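A control variate subtracts a correlated quantity with a known mean, scaled by a coefficient c, to cancel part of an estimator's noise. The minimal NumPy sketch below is a generic illustration of the technique, not ReinMax-CV itself. It applies a per-coordinate control variate to the score-function (REINFORCE) gradient of E[f(D)] for a categorical D, using the zero-mean score D − π as the control. The loss values and helper names are hypothetical.

```python
import numpy as np

def softmax(x):
    z = np.exp(x - x.max())
    return z / z.sum()

def apply_control_variate(g, h, h_mean):
    """Per-coordinate control variate: g - c * (h - E[h]) with c = Cov(g, h) / Var(h).
    Estimating c from the same samples adds an O(1/n) bias; a held-out batch avoids it."""
    g_centred = g - g.mean(axis=0)
    h_centred = h - h.mean(axis=0)
    c = (g_centred * h_centred).sum(axis=0) / (h_centred * h_centred).sum(axis=0)
    return g - c * (h - h_mean)

rng = np.random.default_rng(0)
theta = np.array([0.5, -0.3, 1.0])
pi = softmax(theta)
f_values = np.array([1.0, 3.0, -2.0])     # f(e_k) for each one-hot outcome (toy loss)
idx = rng.choice(len(pi), size=5000, p=pi)
d = np.eye(len(pi))[idx]                  # one-hot samples D ~ Cat(pi)
score = d - pi                            # grad_theta log p(D); its mean is zero
g = f_values[idx][:, None] * score        # REINFORCE samples of grad_theta E[f(D)]
g_cv = apply_control_variate(g, score, np.zeros(len(pi)))
exact = (np.diag(pi) - np.outer(pi, pi)) @ f_values
print("means:", g.mean(axis=0), g_cv.mean(axis=0), "exact:", exact)
print("variances:", g.var(axis=0), g_cv.var(axis=0))
```

Because c minimises the sample variance of g − c·h, the adjusted samples never have a larger sample variance than the originals, and their mean stays close to the exact gradient.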
Taxonomy
Topics: Face recognition and analysis · Generative Adversarial Networks and Image Synthesis · Stochastic Gradient Optimization Techniques
