Learning Data Representations with Joint Diffusion Models
Kamil Deja, Tomasz Trzcinski, Jakub M. Tomczak

TL;DR
This paper introduces a joint diffusion model that combines data generation and classification in a single, stable framework, improving performance and enabling new interpretability methods.
Contribution
The work extends diffusion models with an integrated classifier for stable joint training, outperforming existing hybrid methods in both generation and classification tasks.
Findings
Outperforms recent state-of-the-art hybrid models in both classification and generation quality on all evaluated benchmarks
Enables effective visual counterfactual explanations
Achieves stable end-to-end training with shared parameters
Abstract
Joint machine learning models that allow synthesizing and classifying data often offer uneven performance between those tasks or are unstable to train. In this work, we depart from a set of empirical observations that indicate the usefulness of internal representations built by contemporary deep diffusion-based generative models not only for generating but also predicting. We then propose to extend the vanilla diffusion model with a classifier that allows for stable joint end-to-end training with shared parameterization between those objectives. The resulting joint diffusion model outperforms recent state-of-the-art hybrid methods in terms of both classification and generation quality on all evaluated benchmarks. On top of our joint training approach, we present how we can directly benefit from shared generative and discriminative representations by introducing a method for visual…
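To make the shared-parameterization idea more concrete, here is a minimal NumPy sketch of a joint objective. It is not the authors' implementation. Everything in it is a hypothetical stand-in: a single linear-tanh layer plays the role of the denoiser's shared encoder, a linear decoder predicts the added noise, a linear classifier head reads the same features, and noise comes from a DDPM-style linear schedule. The total loss is the noise-prediction MSE plus the classification cross-entropy, so one set of encoder weights receives both signals. Giving the classifier the features of the clean input is a simplification chosen here, not a detail taken from the paper.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes: input dim, shared hidden dim, number of classes, diffusion steps.
D, H, K, T = 16, 32, 3, 1000

# Linear noise schedule (DDPM-style); alpha_bar[t] = prod_{s<=t} (1 - beta_s).
betas = np.linspace(1e-4, 0.02, T)
alpha_bar = np.cumprod(1.0 - betas)

# Hypothetical parameters: one shared encoder, a denoising decoder, a classifier head.
params = {
    "W_enc": rng.normal(0, 0.1, (D + 1, H)),  # +1 input row for the timestep feature
    "W_dec": rng.normal(0, 0.1, (H, D)),
    "W_cls": rng.normal(0, 0.1, (H, K)),
}

def encode(x, t, p):
    # Shared representation z used by BOTH heads; t is appended as a scaled scalar.
    t_feat = (t / T)[:, None]
    return np.tanh(np.concatenate([x, t_feat], axis=1) @ p["W_enc"])

def log_softmax(logits):
    # Numerically stable log-softmax over the class axis.
    shifted = logits - logits.max(axis=1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))

def joint_loss(x0, y, p, rng):
    n = x0.shape[0]
    t = rng.integers(0, T, size=n)
    eps = rng.normal(size=x0.shape)
    ab = alpha_bar[t][:, None]
    x_t = np.sqrt(ab) * x0 + np.sqrt(1.0 - ab) * eps  # sample from q(x_t | x_0)
    # Generative head: predict the injected noise from the shared features.
    eps_hat = encode(x_t, t, p) @ p["W_dec"]
    diff_loss = np.mean((eps - eps_hat) ** 2)
    # Discriminative head: classify from the shared features of the clean input (t = 0).
    z0 = encode(x0, np.zeros(n), p)
    cls_loss = -np.mean(log_softmax(z0 @ p["W_cls"])[np.arange(n), y])
    return diff_loss + cls_loss, diff_loss, cls_loss

# Usage on a random toy batch.
x0 = rng.normal(size=(8, D))
y = rng.integers(0, K, size=8)
total, d_loss, c_loss = joint_loss(x0, y, params, rng)
print(f"total={total:.4f} diffusion={d_loss:.4f} classification={c_loss:.4f}")
```

In a real model, the encoder would be the downsampling half of a diffusion UNet and both terms would be minimized together with a gradient-based optimizer. The sketch only shows how a single representation feeds both objectives.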
Taxonomy
Topics: Generative Adversarial Networks and Image Synthesis · Domain Adaptation and Few-Shot Learning · Music and Audio Processing
Methods: Diffusion
