Counterfactual Generative Models for Time-Varying Treatments
Shenghao Wu, Wenbin Zhou, Minshuo Chen, Shixiang Zhu

TL;DR
This paper introduces a conditional generative framework for sampling high-dimensional counterfactual outcomes under time-varying treatments. It addresses the mismatch between the observed and counterfactual distributions and outperforms existing methods on synthetic and real-world data.
Contribution
A generative approach that produces counterfactual outcome samples under time-varying treatments without explicit density estimation, improving counterfactual sample quality.
Findings
Generates high-quality counterfactual samples.
Outperforms state-of-the-art baselines.
Effective on synthetic and real-world data.
Abstract
Estimating the counterfactual outcome of treatment is essential for decision-making in public health and clinical science, among others. Often, treatments are administered in a sequential, time-varying manner, leading to an exponentially increased number of possible counterfactual outcomes. Furthermore, in modern applications, the outcomes are high-dimensional and conventional average treatment effect estimation fails to capture disparities in individuals. To tackle these challenges, we propose a novel conditional generative framework capable of producing counterfactual samples under time-varying treatment, without the need for explicit density estimation. Our method carefully addresses the distribution mismatch between the observed and counterfactual distributions via a loss function based on inverse probability re-weighting, and supports integration with state-of-the-art conditional…
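The abstract's central idea, re-weighting a conditional generative model's training loss by inverse treatment probabilities, can be illustrated with a minimal NumPy sketch. It assumes per-step propensity estimates (the probability of the treatment actually received, given history) are already available and that only a window of the last d steps enters each weight. The helper names `iptw_weights` and `weighted_loss` are hypothetical, and the self-normalized average is one common choice, not necessarily the authors' exact loss.

```python
import numpy as np


def iptw_weights(propensities, window):
    """Hypothetical helper: inverse probability of treatment weights.

    propensities[i, tau] is an assumed estimate of the probability that
    sample i received the treatment it actually received at step tau,
    given its history. Only the last `window` steps (t-d+1, ..., t)
    enter the weight.
    """
    p = np.asarray(propensities, dtype=float)
    recent = p[:, -window:]
    return 1.0 / np.prod(recent, axis=1)


def weighted_loss(per_sample_loss, weights):
    """Self-normalized weighted average of per-sample generative losses."""
    loss = np.asarray(per_sample_loss, dtype=float)
    w = np.asarray(weights, dtype=float)
    return float(np.sum(w * loss) / np.sum(w))


# Toy usage: sample 0 received unlikely treatments, so it is up-weighted.
props = np.array([[0.9, 0.5, 0.5],
                  [0.9, 1.0, 1.0]])
w = iptw_weights(props, window=2)          # array([4., 1.])
print(weighted_loss([1.0, 0.0], w))        # 0.8
```

In practice, `per_sample_loss` would be whatever per-sample objective the chosen conditional generator is trained with (for example, a diffusion model's denoising loss); the weights shift the training distribution toward treatment histories that are rare in the observed data.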
Peer Reviews
Decision · Submitted to ICLR 2024
The proposed framework is simple, easy to use, and accessible. Judging from the experiments, the performance of the proposed method seems to be good.
- Proposition 1 doesn’t make sense. How could $\bar{a}$, a fixed value, be drawn from $\mathcal{D}$? Why is it a sum instead of an average? Wouldn’t the RHS of (4) be the same for all $\bar{a}$ while the LHS is supposed to be different? And why is the index in (5) from $t-d$ instead of from $t-d+1$? In the proof of Proposition 1, where does the expectation over $\bar{a}$ come from? I would be concerned if the authors actually used this formula in their experiments.
- I don’t seem to understand the
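The indexing question raised above comes down to counting. Assuming the history window is meant to cover exactly the $d$ most recent time steps ending at $t$, the range must start at $t-d+1$, because starting at $t-d$ includes one extra step:

$$\sum_{\tau=t-d+1}^{t} 1 = d, \qquad \sum_{\tau=t-d}^{t} 1 = d+1.$$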
- The paper tackles the complex issue of estimating counterfactual outcomes in the face of time-varying treatment effects.
- The proposed method adeptly handles high-dimensional outcomes.
- It is capable of generating counterfactual samples without imposing rigid assumptions on the distribution of the counterfactual outcome.
IPTW values can be notably small and, as highlighted by the authors, require precise definition. This problem can be exacerbated in sequential treatment scenarios. Given that we are handling a treatment sequence, the counterfactual treatment is not unique; however, the notation used does not reflect this. I believe the following two papers could also serve as baseline references: 1. "Disentangled Counterfactual Recurrent Networks for Treatment Effect Inference Over Time"
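The concern about extreme weights in sequential settings can be made concrete: when each weight is a product of per-step inverse propensities, a few rare treatment histories can dominate as the number of steps grows. The NumPy sketch below measures this with the Kish effective sample size and applies quantile truncation, a common stabilization heuristic that is not claimed to be part of the paper's method. The uniform propensities in the demo are synthetic and purely illustrative, and the helper names are hypothetical.

```python
import numpy as np


def effective_sample_size(weights):
    """Kish effective sample size: (sum w)^2 / sum w^2."""
    w = np.asarray(weights, dtype=float)
    return float(w.sum() ** 2 / np.sum(w ** 2))


def truncate_weights(weights, upper_quantile=0.99):
    """Cap weights at an empirical quantile to limit their variance."""
    w = np.asarray(weights, dtype=float)
    return np.minimum(w, np.quantile(w, upper_quantile))


# Weights are products of per-step inverse propensities, so the effective
# sample size typically shrinks as the treatment sequence gets longer.
rng = np.random.default_rng(0)
for steps in (1, 5, 10):
    props = rng.uniform(0.1, 1.0, size=(1000, steps))  # synthetic propensities
    w = 1.0 / props.prod(axis=1)
    print(steps,
          round(effective_sample_size(w)),
          round(effective_sample_size(truncate_weights(w))))
```

Truncation trades a small bias for a large reduction in variance; the cap (here the 99th percentile) is an arbitrary illustrative choice.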
The paper is technically sound. Generally, it is not hard for readers to follow. The ideas are presented well, but still, the clarity of the paper can be further improved.
Although readers should be able to follow and understand the notions presented in the paper, the paper is not well organized. For instance, the authors defer the standard causal assumptions to the Appendix. The main paper need not explain the causal assumptions in detail, but it should at least name them. Please refer to the questions for further weaknesses.
Taxonomy
Topics: Machine Learning in Healthcare · Advanced Causal Inference Techniques · Statistical Methods and Inference
Methods: Diffusion · Test
