Query Training: Learning a Worse Model to Infer Better Marginals in Undirected Graphical Models with Hidden Variables
Miguel Lázaro-Gredilla, Wolfgang Lehrach, Nishad Gothoskar, Guangyao Zhou, Antoine Dedieu, Dileep George

TL;DR
This paper introduces query training (QT), a novel method that trains undirected graphical models to produce more accurate marginals with approximate inference, improving performance on complex models with hidden variables.
Contribution
QT is a new training approach that optimizes PGMs for better approximate inference, maintaining query flexibility and outperforming existing methods on challenging models.
Findings
QT outperforms AdVIL on multiple datasets
It effectively learns complex grid Markov random fields
QT improves marginal estimation accuracy
Abstract
Probabilistic graphical models (PGMs) provide a compact representation of knowledge that can be queried in a flexible way: after learning the parameters of a graphical model once, new probabilistic queries can be answered at test time without retraining. However, when using undirected PGMs with hidden variables, two sources of error typically compound in all but the simplest models: (a) learning error (both computing the partition function and integrating out the hidden variables are intractable); and (b) prediction error (exact inference is also intractable). Here we introduce query training (QT), a mechanism to learn a PGM that is optimized for the approximate inference algorithm that will be paired with it. The resulting PGM is a worse model of the data (as measured by the likelihood), but it is tuned to produce better marginals for a given inference algorithm. Unlike prior works, our…
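The central idea of the abstract is to train the parameters through the same approximate inference algorithm that will later answer queries. The sketch below illustrates that idea under stated assumptions and is not the paper's implementation. Every choice in it is hypothetical and made for illustration: a tiny binary Boltzmann machine (4 visible units on a 2x2 grid plus one hidden unit), loopy belief propagation unrolled for a fixed 5 sweeps, random query masks that hide a subset of visible units, and a loss equal to the negative log BP-marginal of the hidden-away units. Gradients come from central finite differences rather than backpropagation so that the sketch depends only on NumPy. The names `bp_marginals`, `query_loss` and `num_grad` are illustrative.

```python
import numpy as np

# Hypothetical toy model (not the paper's): a binary Boltzmann machine with 4 visible
# units on a 4-cycle (a 2x2 grid) plus 1 hidden unit wired to all of them.
N_VIS, N = 4, 5
EDGES = [(0, 1), (1, 3), (3, 2), (2, 0)] + [(4, i) for i in range(N_VIS)]
# NBRS[i] lists (edge index, neighbor) pairs for unit i.
NBRS = {i: [(e, b if a == i else a) for e, (a, b) in enumerate(EDGES) if i in (a, b)]
        for i in range(N)}


def bp_marginals(params, evidence, n_iters=5):
    """Loopy sum-product BP unrolled for n_iters parallel sweeps (log domain).

    params = [b_0..b_4, w_0..w_7]; log-potential = sum_i b_i x_i + sum_e w_e x_a x_b,
    x in {0, 1}. evidence maps unit index -> observed value. Returns (N, 2) marginals.
    """
    bias, w = params[:N], params[N:]
    unary = np.column_stack([np.zeros(N), bias])
    for i, v in evidence.items():
        unary[i] = [0.0, -1e9] if v == 0 else [-1e9, 0.0]  # clamp observed units
    tables = [np.array([[0.0, 0.0], [0.0, we]]) for we in w]  # symmetric in (x_a, x_b)
    msgs = {(i, j): np.zeros(2) for i in range(N) for _, j in NBRS[i]}
    for _ in range(n_iters):
        new = {}
        for i in range(N):
            for e, j in NBRS[i]:
                # Cavity belief at i: everything except the message coming from j.
                cavity = unary[i] + sum(msgs[(k, i)] for _, k in NBRS[i] if k != j)
                m = np.logaddexp.reduce(cavity[:, None] + tables[e], axis=0)
                new[(i, j)] = m - np.logaddexp.reduce(m)  # normalize for stability
        msgs = new
    beliefs = unary + np.stack([sum(msgs[(k, i)] for _, k in NBRS[i]) for i in range(N)])
    beliefs -= np.logaddexp.reduce(beliefs, axis=1, keepdims=True)
    return np.exp(beliefs)


def query_loss(params, queries):
    """Mean negative log BP-marginal of the masked visible units, with the unmasked
    visible units as evidence; the hidden unit is never observed or scored."""
    nll = []
    for x, mask in queries:
        marg = bp_marginals(params, {i: x[i] for i in range(N_VIS) if not mask[i]})
        nll += [-np.log(marg[i, x[i]] + 1e-12) for i in range(N_VIS) if mask[i]]
    return float(np.mean(nll))


def num_grad(f, p, eps=1e-5):
    """Central finite differences, standing in for backprop through the unrolled BP."""
    g = np.zeros_like(p)
    for k in range(p.size):
        d = np.zeros_like(p)
        d[k] = eps
        g[k] = (f(p + d) - f(p - d)) / (2 * eps)
    return g


rng = np.random.default_rng(0)
# Toy data: every visible unit copies one shared random bit, flipped w.p. 0.1.
bits = rng.integers(0, 2, size=(12, 1))
data = np.where(rng.random((12, N_VIS)) < 0.1, 1 - bits, bits)
# One fixed random query per example (a QT-style loop could resample every step):
# each visible unit is masked w.p. 0.5, and at least one is always masked.
queries = []
for x in data:
    mask = rng.random(N_VIS) < 0.5
    mask[rng.integers(N_VIS)] = True
    queries.append((x, mask))

params = 0.01 * rng.standard_normal(N + len(EDGES))
loss0 = query_loss(params, queries)
for _ in range(20):
    params = params - 0.5 * num_grad(lambda p: query_loss(p, queries), params)
loss1 = query_loss(params, queries)
print(f"query loss before: {loss0:.3f}, after: {loss1:.3f}")
```

Because BP is unrolled for a fixed number of sweeps, the query loss is an ordinary differentiable function of the parameters, so an autodiff framework could replace `num_grad`. The learned parameters are whatever makes this particular BP schedule answer queries well, not necessarily the maximum-likelihood ones. In spirit, this reflects the abstract's trade-off of a worse model for better marginals.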
Taxonomy
Topics: Bayesian Modeling and Causal Inference · Machine Learning and Algorithms · Machine Learning and Data Classification
Methods: Probability Guided Maxout
