Diffusion Tree Sampling: Scalable inference-time alignment of diffusion models
Vineet Jain, Kusha Sareen, Mohammad Pedramfar, Siamak Ravanbakhsh

TL;DR
This paper introduces Diffusion Tree Sampling, a scalable inference-time alignment method for diffusion models that reuses past computations to improve sample quality efficiently, matching or surpassing baseline performance with less compute.
Contribution
The paper proposes a novel tree-based search method for diffusion models that reuses previous computations, enabling scalable and efficient inference-time alignment.
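To make "reusing previous computations" concrete, here is a minimal sketch under stated assumptions. The denoising chain is treated as a tree whose nodes are partially denoised states. Each terminal sample x_0 receives a reward, and every ancestor refreshes its value with a soft (log-mean-exp) backup over the children visited so far, using an inverse temperature `lam`. The names `Node`, `log_mean_exp`, `backup` and `lam` are hypothetical and chosen only for illustration. This is not the authors' code, and the paper's exact update rule may differ.

```python
import math
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class Node:
    """A partially denoised state x_t in the search tree (hypothetical structure)."""
    value: float = 0.0        # current soft-value estimate for this state
    visits: int = 0           # number of rollouts that have passed through it
    children: List["Node"] = field(default_factory=list)
    parent: Optional["Node"] = field(default=None, repr=False, compare=False)

    def add_child(self) -> "Node":
        """Attach a new child, e.g. one denoising step sampled from the base model."""
        child = Node(parent=self)
        self.children.append(child)
        return child


def log_mean_exp(xs: List[float], lam: float) -> float:
    """Soft aggregate (1/lam) * log(mean(exp(lam * x))), computed stably.

    Tends to the plain mean as lam -> 0 and to max(xs) as lam -> infinity.
    """
    m = max(xs)
    s = sum(math.exp(lam * (x - m)) for x in xs)
    return m + math.log(s / len(xs)) / lam


def backup(leaf: Node, reward: float, lam: float) -> None:
    """Propagate a terminal reward r(x_0) from a leaf back toward the root.

    Each ancestor recomputes its value from the children visited so far,
    so statistics gathered by earlier rollouts are reused by later ones.
    """
    leaf.value = reward
    leaf.visits += 1
    node = leaf.parent
    while node is not None:
        node.visits += 1
        visited = [c.value for c in node.children if c.visits > 0]
        node.value = log_mean_exp(visited, lam)
        node = node.parent


# Usage: two rollouts that end in different children of the root.
root = Node()
a, b = root.add_child(), root.add_child()
backup(a, reward=1.0, lam=1.0)  # root.value == 1.0 (only `a` visited so far)
backup(b, reward=0.0, lam=1.0)  # root.value == log((e + 1) / 2), about 0.620
```

This sketch averages uniformly over visited children as a simplifying choice. A visit-weighted average, or whatever estimator the paper actually uses, may be more appropriate in practice.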
Findings
DTS matches the FID of the best-performing baseline on MNIST and CIFAR-10 with 10x less compute.
DTS⋆, the search variant of DTS, finds high-reward samples with 5x less compute.
DTS is an anytime algorithm: it can be stopped at any point, and its samples keep improving as more compute is spent.
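The findings above separate two uses of the same tree. DTS samples from the aligned distribution, while DTS⋆ looks for high-reward samples. One way to picture that split, sketched here as an assumption rather than the paper's exact rule, is the child-selection step during descent. DTS-style selection samples a child with probability proportional to exp(lam * value). DTS⋆-style selection greedily takes the highest-value child. The functions `boltzmann_probs`, `select_boltzmann` and `select_greedy`, and the parameter `lam`, are hypothetical names used for illustration.

```python
import math
import random
from typing import List, Sequence


def boltzmann_probs(values: Sequence[float], lam: float) -> List[float]:
    """Probabilities proportional to exp(lam * value), shifted by the max for stability."""
    m = max(values)
    w = [math.exp(lam * (v - m)) for v in values]
    z = sum(w)
    return [x / z for x in w]


def select_boltzmann(values: Sequence[float], lam: float, rng: random.Random) -> int:
    """Sampling-style selection: draw a child index from the Boltzmann distribution."""
    probs = boltzmann_probs(values, lam)
    return rng.choices(range(len(values)), weights=probs, k=1)[0]


def select_greedy(values: Sequence[float]) -> int:
    """Search-style selection: take the child with the highest value (first on ties)."""
    return max(range(len(values)), key=lambda i: values[i])


# Usage: three children with current soft-value estimates.
values = [0.0, 1.0, 0.5]
print(select_greedy(values))             # 1
print(boltzmann_probs(values, lam=0.0))  # uniform: each 1/3
rng = random.Random(0)
picks = [select_boltzmann(values, lam=1.0, rng=rng) for _ in range(10_000)]
```

Sampling keeps every branch in play, which matters when the goal is the whole target distribution. Greedy selection concentrates compute on the most promising branch, which suits a search for a single high-reward output at the cost of diversity.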
Abstract
Adapting a pretrained diffusion model to new objectives at inference time remains an open problem in generative modeling. Existing steering methods suffer from inaccurate value estimation, especially at high noise levels, which biases guidance. Moreover, information from past runs is not reused to improve sample quality, resulting in inefficient use of compute. Inspired by the success of Monte Carlo Tree Search, we address these limitations by casting inference-time alignment as a search problem that reuses past computations. We introduce a tree-based approach that samples from the reward-aligned target density by propagating terminal rewards back through the diffusion chain and iteratively refining value estimates with each additional generation. Our proposed method, Diffusion Tree Sampling (DTS), produces asymptotically exact samples from the target distribution in the limit of…
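The "reward-aligned target density" in the abstract can be written down with a standard formulation. It is assumed here for illustration, with reward $r$, inverse temperature $\lambda$, pretrained reverse kernel $p_\theta(x_{t-1}\mid x_t)$, and $p_\theta(x_0)$ the marginal of the pretrained chain. The paper's own notation may differ. This formulation makes the abstract's strategy explicit: exact sampling from the target only requires a soft value $V_t$ for each intermediate state, and that value is defined recursively from the terminal reward.

```latex
\pi(x_0) \propto p_\theta(x_0)\,\exp\big(\lambda\, r(x_0)\big)

V_0(x_0) = r(x_0), \qquad
V_t(x_t) = \tfrac{1}{\lambda}\log \mathbb{E}_{p_\theta(x_{t-1}\mid x_t)}\big[\exp\big(\lambda\, V_{t-1}(x_{t-1})\big)\big]

\pi(x_T) \propto p(x_T)\,\exp\big(\lambda\, V_T(x_T)\big), \qquad
\pi(x_{t-1}\mid x_t) = p_\theta(x_{t-1}\mid x_t)\,\exp\big(\lambda\,[V_{t-1}(x_{t-1}) - V_t(x_t)]\big)
```

By the definition of $V_t$, each transition integrates to one. Multiplying the transitions telescopes the exponents to $\exp(\lambda[r(x_0) - V_T(x_T)])$, so with the tilted initial distribution the chain's marginal over $x_0$ is exactly $\pi(x_0)$. A tree-based method can approximate the expectation in $V_t$ by averaging over sampled children, which is one way to read "iteratively refining value estimates with each additional generation."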
