Optimal Inference Schedules for Masked Diffusion Models
Sitan Chen, Kevin Cong, Jerry Li

TL;DR
This paper provides a rigorous analysis of parallel sampling in masked diffusion models, establishing bounds and schedules that optimize inference efficiency while maintaining sampling quality.
Contribution
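It gives an exact characterization of the expected divergence between the true distribution and the sampled distribution, valid for any distribution and any unmasking schedule. It links this divergence to univariate function approximation and proposes bounds and schedules based on information-theoretic properties of the distribution.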
Findings
The optimal unmasking schedule is derived from function approximation theory.
Parallel sampling in $O(\log n)$ steps is possible under certain distribution properties.
The new bounds and schedules clarify how much parallel sampling masked diffusion inference can tolerate without noticeable degradation.
Abstract
A major bottleneck of standard auto-regressive large language models is that their inference process is inherently sequential, resulting in very long and costly inference times. To circumvent this, practitioners proposed a class of language models called diffusion language models, of which the masked diffusion model (MDM) is the most successful. The MDM is able to sample tokens out-of-order and, ostensibly, many tokens at once and in parallel. However, there is very limited rigorous understanding of how much parallel sampling these models can perform without noticeable degradation in their sampling performance. Prior work of Li and Cai obtained some preliminary bounds, but these are not tight for many natural classes of distributions. In this work, we give a new, exact characterization of the expected divergence between the true distribution and the sampled distribution, for any…
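To make the notion of an unmasking schedule concrete, here is a minimal Python sketch of parallel sampling under a fixed schedule. It is not the paper's algorithm: the names `sample_with_schedule` and `copy_marginal` are hypothetical, the schedule is simply a list of how many positions to unmask at each step, and the learned model is replaced by a toy oracle that returns exact conditional marginals for a distribution in which all tokens are copies of one uniform bit.

```python
import random
from typing import Callable, Dict, List, Optional, Sequence

# Hypothetical type: maps a token value to its probability.
Marginal = Dict[int, float]


def sample_with_schedule(
    n: int,
    schedule: Sequence[int],
    conditional_marginal: Callable[[List[Optional[int]], int], Marginal],
    rng: random.Random,
) -> List[int]:
    """Unmask n positions in len(schedule) steps (illustrative sketch).

    schedule[i] is the number of positions unmasked at step i. Within a
    step, every chosen position is sampled independently from its marginal
    conditioned on the tokens revealed in earlier steps only.
    """
    if sum(schedule) != n or any(k <= 0 for k in schedule):
        raise ValueError("schedule must be positive step sizes summing to n")

    tokens: List[Optional[int]] = [None] * n  # None marks a masked position
    masked = list(range(n))

    for k in schedule:
        chosen = rng.sample(masked, k)  # positions sampled out of order
        snapshot = list(tokens)         # same context for the whole step
        for pos in chosen:
            dist = conditional_marginal(snapshot, pos)
            values = list(dist.keys())
            weights = list(dist.values())
            tokens[pos] = rng.choices(values, weights=weights, k=1)[0]
        chosen_set = set(chosen)
        masked = [p for p in masked if p not in chosen_set]

    return tokens


def copy_marginal(partial: List[Optional[int]], pos: int) -> Marginal:
    """Toy oracle (assumption): all tokens equal a single uniform bit."""
    revealed = [t for t in partial if t is not None]
    if revealed:
        return {revealed[0]: 1.0}
    return {0: 0.5, 1: 0.5}


if __name__ == "__main__":
    n = 8
    # One token first, then the rest in parallel: always a valid sample.
    print(sample_with_schedule(n, [1, n - 1], copy_marginal, random.Random(0)))
    # Everything in one step: bits are independent, so usually inconsistent.
    print(sample_with_schedule(n, [n], copy_marginal, random.Random(0)))
```

In this toy setting the two-step schedule `[1, n - 1]` already samples exactly, while the one-step schedule `[n]` generally does not, which illustrates why the achievable amount of parallelism depends on the structure of the distribution.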
Taxonomy
Topics: Generative Adversarial Networks and Image Synthesis · Machine Learning and Algorithms · Stochastic Gradient Optimization Techniques
