CaTs and DAGs: Integrating Directed Acyclic Graphs with Transformers for Causally Constrained Predictions
Matthew J. Vowels, Mathieu Rochat, Sina Akbari

TL;DR
This paper introduces Causal Transformers (CaTs), a new neural network model that incorporates causal structures via DAGs to enhance robustness, interpretability, and reliability in real-world applications.
Contribution
The paper proposes CaTs, a novel model integrating DAG-based causal constraints into transformers, improving their robustness and interpretability while maintaining strong function approximation.
Findings
CaTs adhere to causal constraints specified by DAGs.
CaTs demonstrate improved robustness under covariate shift.
CaTs offer enhanced interpretability in neural network predictions.
Abstract
Artificial Neural Networks (ANNs), including fully-connected networks and transformers, are highly flexible and powerful function approximators, widely applied in fields like computer vision and natural language processing. However, their inability to inherently respect causal structures can limit their robustness, making them vulnerable to covariate shift and difficult to interpret/explain. This poses significant challenges for their reliability in real-world applications. In this paper, we introduce Causal Transformers (CaTs), a general model class designed to operate under predefined causal constraints, as specified by a Directed Acyclic Graph (DAG). CaTs retain the powerful function approximation abilities of traditional neural networks while adhering to the underlying structural constraints, improving robustness, reliability, and interpretability at inference time. This approach…
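To make the idea of attention that operates under DAG constraints concrete, here is a minimal NumPy sketch. It is not the authors' implementation: the function name `dag_masked_attention`, the single-head layout, the adjacency convention (`adjacency[i, j] == 1` means an edge from variable i to variable j), and the choice to give parentless variables a zero output are all assumptions made for illustration. The sketch only shows how masking attention scores with a DAG stops a variable from drawing on anything other than its specified parents.

```python
import numpy as np


def dag_masked_attention(Q, K, V, adjacency):
    """Single-head attention restricted by a DAG (illustrative sketch only).

    Q, K, V: arrays of shape (n_variables, d), one embedding per variable.
    adjacency: (n_variables, n_variables) array; adjacency[i, j] == 1 means
        an edge i -> j, i.e. variable i is a parent of variable j (assumed
        convention, not taken from the paper).
    Returns (output, weights), where weights[j, i] is how much variable j
    attends to variable i.
    """
    d = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d)            # scores[j, i]: query j, key i
    mask = adjacency.T.astype(bool)          # mask[j, i] is True iff i -> j

    # Forbid attention to non-parents by sending their scores to -inf.
    masked = np.where(mask, scores, -np.inf)

    # Rows with no parents would be all -inf; replace them with zeros so the
    # max-subtraction below stays finite, then zero them out via the mask.
    has_parent = mask.any(axis=1, keepdims=True)
    safe = np.where(has_parent, masked, 0.0)
    safe = safe - safe.max(axis=1, keepdims=True)

    weights = np.exp(safe) * mask            # non-parents get exactly 0 weight
    denom = weights.sum(axis=1, keepdims=True)
    weights = np.divide(
        weights, denom, out=np.zeros_like(weights), where=denom > 0
    )
    return weights @ V, weights


# Hypothetical three-variable chain: X0 -> X1 -> X2.
chain = np.array([[0, 1, 0],
                  [0, 0, 1],
                  [0, 0, 0]])
rng = np.random.default_rng(0)
Q, K, V = (rng.normal(size=(3, 4)) for _ in range(3))
output, weights = dag_masked_attention(Q, K, V, chain)
```

In this chain, X1 can attend only to X0 and X2 only to X1, so changing the embedding of X2 leaves every output untouched. This is a hard constraint built into the computation, as opposed to a penalty added to the loss.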
Peer Reviews
Decision: ICLR 2026 Poster
1. Incorporating graphs into the attention mechanism ensures that the causal constraints encoded in the graphs (and therefore the resulting causal inferences) are guaranteed to hold. This is a strong advantage over methods that attempt to achieve causality through regularization. The adjustments, such as the omission of layer norm, are done carefully to preserve this causal integrity.
2. Assumptions are stated clearly.
3. The provided experimental results seem promising.
4. The contributions of the paper are largely empirical and leave much to be desired from the theoretical aspects of the models. Most notably, I believe it would be a core contribution to include a theorem that guarantees that causal constraints (such as those from Causal Bayesian Networks) are correctly enforced in the CaT architecture, and therefore, the estimated queries will be correct. Another important contribution would be to support the claim that causal models like CaT can resolve issue
* The authors provide clear motivation and demonstrate it with a compelling example that compares predictive power and treatment effect estimation.
* They propose a novel modification of the attention mechanism for embedding DAGs into the model.
* They introduce a formulation for incorporating DAG structure into neural networks.
* They benchmark their approach against other causal inference methods.
* The architecture needs to be trained for each DAG on each dataset, which can limit the method from being applied to larger datasets.
* While the authors describe the ability to handle multidimensional embeddings as an advantage, the experiments are conducted on tabular datasets. It would be valuable to see whether the transformer architecture can excel in such settings.
* The authors quantitatively justify not using normalization; however, it is not clear whether the pros would outweigh the
1. A key advantage highlighted is that by constraining the model to a (correct) causal DAG, it learns to ignore spurious correlations in the training data. This makes the model more robust and stable when faced with distributional shifts (i.e., covariate shift), a common failure point for traditional machine learning models.
2. The experiments show superior performance over the traditional random forest models even if the DAG is misspecified.
3. Unlike many causal inference methods designed f
1. The literature review part is not complete. There have been many works in the literature that consider using neural networks for counterfactual estimation. For instance, [1] establishes the connection between neural nets and causal models and constructs feed-forward networks for discrete causal models. [2] establishes similar results for causal models with mixed variables (continuous and discrete) and uses it for partial identification. Similar to this work, [3] proposes causal transformers f
Taxonomy
Topics: Mental Health Research Topics · Machine Learning in Healthcare · Explainable Artificial Intelligence (XAI)
