How Transformers Learn Causal Structure with Gradient Descent
Eshaan Nichani, Alex Damian, Jason D. Lee

TL;DR
This paper studies how transformers learn causal structure through gradient descent. It proves that gradient descent on a simplified two-layer transformer, trained on an in-context learning task with a latent causal graph, encodes that graph in the first attention layer, and it supports the theory with experiments.
Contribution
The paper provides the first theoretical analysis of how gradient descent enables transformers to learn causal structures via attention mechanisms.
Findings
The gradient of the attention matrix encodes mutual information between tokens
The largest gradient entries correspond to edges of the causal graph
Transformers can recover various causal structures in practice
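The link between mutual information and causal edges can be shown without a transformer. The sketch below is an illustrative simplification, not the paper's method. It samples token sequences along a known causal graph and estimates the mutual information between every pair of positions. For each position it then picks the earlier position with the largest mutual information. Several choices are assumptions made only for this illustration: each position has a single parent, the transition matrix is fixed and symmetric, the estimator is the plug-in Shannon estimate in nats (the paper's analysis may use a different mutual-information variant), and the graph `parents` is hypothetical, as are the helper names. Under these assumptions, a token is conditionally independent of all other earlier tokens given its parent. The data processing inequality then gives I(x_i; x_j) <= I(x_i; x_{p(i)}), so the argmax should land on the true parent.

```python
import math
import random
from collections import Counter


def sample_chain_on_graph(parents, transition, rng):
    """Sample one sequence: roots are uniform, children follow `transition` from their parent."""
    vocab_size = len(transition)
    tokens = []
    for i, p in enumerate(parents):
        if p is None:
            tokens.append(rng.randrange(vocab_size))
        else:
            assert p < i, "each parent must be an earlier position"
            row = transition[tokens[p]]
            tokens.append(rng.choices(range(vocab_size), weights=row)[0])
    return tokens


def mutual_information(samples, i, j):
    """Plug-in estimate (in nats) of I(x_i; x_j) from empirical counts."""
    n = len(samples)
    joint = Counter((s[i], s[j]) for s in samples)
    marg_i = Counter(s[i] for s in samples)
    marg_j = Counter(s[j] for s in samples)
    # p(a,b) / (p(a) p(b)) = (c/n) / ((m_a/n)(m_b/n)) = c*n / (m_a*m_b)
    return sum((c / n) * math.log(c * n / (marg_i[a] * marg_j[b]))
               for (a, b), c in joint.items())


def recover_parents(samples, length):
    """For each position i >= 1, guess the earlier position with the largest MI."""
    return [max(range(i), key=lambda j: mutual_information(samples, i, j))
            for i in range(1, length)]


# Hypothetical setup: 3 tokens, stay with prob 0.8, otherwise move to each other token with prob 0.1.
transition = [[0.8 if a == b else 0.1 for b in range(3)] for a in range(3)]
parents = [None, 0, 0, 1, 2, 3, 1, 5]  # hypothetical single-parent causal graph
rng = random.Random(0)
samples = [sample_chain_on_graph(parents, transition, rng) for _ in range(20000)]
recovered = recover_parents(samples, len(parents))
print(recovered)
```

The transition matrix is doubly stochastic, so every marginal stays uniform. A one-step dependence (parent to child) therefore carries clearly more mutual information than a two-step one (grandparent or sibling), and the gap is far larger than the sampling noise at this sample size.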
Abstract
The incredible success of transformers on sequence modeling tasks can be largely attributed to the self-attention mechanism, which allows information to be transferred between different parts of a sequence. Self-attention allows transformers to encode causal structure which makes them particularly suitable for sequence modeling. However, the process by which transformers learn such causal structure via gradient-based training algorithms remains poorly understood. To better understand this process, we introduce an in-context learning task that requires learning latent causal structure. We prove that gradient descent on a simplified two-layer transformer learns to solve this task by encoding the latent causal graph in the first attention layer. The key insight of our proof is that the gradient of the attention matrix encodes the mutual information between tokens. As a consequence of the…
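The task in the abstract can be sketched as a data generator, assuming a specific structure. Each position has at most one parent that appears earlier in the sequence. Every sequence draws a fresh random transition matrix, whose rows are Dirichlet samples in this sketch, so the transition probabilities must be inferred in context. Only the causal graph is shared across sequences. The Dirichlet prior, the single-parent graph, and the names `sample_transition_matrix` and `sample_sequence` are illustrative assumptions, not the paper's exact construction.

```python
import random


def sample_transition_matrix(vocab_size, alpha, rng):
    """Draw each row of a vocab_size x vocab_size matrix from Dirichlet(alpha, ..., alpha)."""
    matrix = []
    for _ in range(vocab_size):
        # Normalizing independent Gamma(alpha, 1) draws yields a Dirichlet sample.
        draws = [rng.gammavariate(alpha, 1.0) for _ in range(vocab_size)]
        total = sum(draws)
        matrix.append([d / total for d in draws])
    return matrix


def sample_sequence(parents, vocab_size, alpha, rng):
    """Sample one sequence whose latent causal graph is given by `parents`.

    parents[i] is the index of position i's parent (must be < i),
    or None for a root position, which is drawn uniformly.
    """
    pi = sample_transition_matrix(vocab_size, alpha, rng)  # fresh matrix per sequence
    tokens = []
    for i, p in enumerate(parents):
        if p is None:
            tokens.append(rng.randrange(vocab_size))
        else:
            assert p < i, "each parent must be an earlier position"
            row = pi[tokens[p]]
            tokens.append(rng.choices(range(vocab_size), weights=row)[0])
    return tokens


rng = random.Random(0)
parents = [None, 0, 0, 1, 2, 3, 1, 5]  # hypothetical causal graph shared by all sequences
batch = [sample_sequence(parents, vocab_size=3, alpha=1.0, rng=rng) for _ in range(4)]
print(batch)
```

Resampling the transition matrix per sequence means a model cannot memorize token statistics. To predict well, it has to route information along the graph's edges, and in the paper's analysis this routing is what the first attention layer learns.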
Videos
How Transformers Learn Causal Structure with Gradient Descent (YouTube)
Taxonomy
Topics: Bayesian Modeling and Causal Inference
