Interpretable and Generalizable Graph Learning via Stochastic Attention Mechanism
Siqi Miao, Miaoyuan Liu, Pan Li

TL;DR
This paper introduces Graph Stochastic Attention (GSAT), an inherently interpretable graph learning method that improves stability and accuracy by selecting task-relevant subgraphs through stochastic attention based on the information bottleneck principle.
Contribution
The paper proposes GSAT, a novel stochastic attention mechanism for graph neural networks that enhances interpretability and prediction accuracy by filtering task-irrelevant information.
Findings
GSAT outperforms state-of-the-art methods by up to 20% in interpretation AUC.
GSAT improves prediction accuracy by up to 5%.
Selected subgraphs are provably free of spuriously correlated patterns.
Abstract
Interpretable graph learning is needed because many scientific applications rely on learned models to gain insights from graph-structured data. Previous works mostly focused on post-hoc approaches that interpret pre-trained models (graph neural networks in particular), arguing against inherently interpretable models because their good interpretability often comes at the cost of prediction accuracy. However, these post-hoc methods often fail to provide stable interpretations and may extract features that are spuriously correlated with the task. In this work, we address these issues by proposing Graph Stochastic Attention (GSAT). Derived from the information bottleneck principle, GSAT injects stochasticity into the attention weights to block information from task-irrelevant graph components while learning stochasticity-reduced attention to select task-relevant…
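The abstract describes injecting stochasticity into attention weights under an information bottleneck objective. The following is a minimal pure-Python sketch of that idea, not the authors' code. It assumes each edge already has an attention probability p_e (in practice produced by a learned network, which is omitted here). It samples a relaxed Bernoulli edge mask with the binary concrete (Gumbel-sigmoid) trick and penalizes the KL divergence from Bern(p_e) to a fixed Bern(r) prior. The function names, the prior rate r, the weight beta, averaging the penalty over edges, and the scalar node features are all illustrative assumptions.

```python
import math
import random


def _stable_sigmoid(x):
    # Numerically stable logistic function (avoids overflow for large |x|).
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def _clamp(x, eps=1e-6):
    # Keep probabilities strictly inside (0, 1) so logs stay finite.
    return min(max(x, eps), 1.0 - eps)


def sample_edge_mask(probs, temperature=1.0, rng=None):
    """Draw one relaxed Bernoulli (binary concrete) sample per edge.

    probs: attention probabilities p_e in (0, 1), one per edge. Here they
    are given directly (assumption); a real model would learn them.
    """
    rng = rng or random.Random()
    mask = []
    for p in probs:
        p = _clamp(p)
        u = _clamp(rng.random())
        # logit(p) plus logistic noise, passed through a tempered sigmoid.
        logit = math.log(p / (1.0 - p)) + math.log(u / (1.0 - u))
        mask.append(_stable_sigmoid(logit / temperature))
    return mask


def bernoulli_kl(p, r):
    """KL(Bern(p) || Bern(r)) for scalars p, r in (0, 1)."""
    p, r = _clamp(p), _clamp(r)
    return p * math.log(p / r) + (1.0 - p) * math.log((1.0 - p) / (1.0 - r))


def stochastic_attention_loss(task_loss, probs, beta=1.0, r=0.5):
    """Task loss plus beta * mean KL to a Bernoulli(r) prior (assumed form)."""
    reg = sum(bernoulli_kl(p, r) for p in probs) / len(probs)
    return task_loss + beta * reg


def masked_aggregate(node_feats, edges, mask):
    """Sum source-node features into each target node, scaled by the edge mask.

    node_feats: list of floats (hypothetical scalar features).
    edges: list of (src, dst) index pairs, aligned with mask.
    """
    out = [0.0] * len(node_feats)
    for (src, dst), m in zip(edges, mask):
        out[dst] += m * node_feats[src]
    return out


# Usage: sample a mask, aggregate messages, and compute the regularized loss.
rng = random.Random(0)
probs = [0.9, 0.1, 0.5]
edges = [(0, 2), (1, 2), (2, 0)]
mask = sample_edge_mask(probs, temperature=0.5, rng=rng)
messages = masked_aggregate([1.0, 2.0, 3.0], edges, mask)
loss = stochastic_attention_loss(task_loss=0.7, probs=probs, beta=0.1, r=0.5)
```

Pulling every p_e toward the prior rate r limits how much information the mask can carry, so in this sketch only edges whose inclusion lowers the task loss enough to offset the KL penalty would keep a high probability. At inference time one would plausibly use p_e directly rather than a random sample, which makes the attention deterministic.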
Taxonomy
Topics: Explainable Artificial Intelligence (XAI) · Advanced Graph Neural Networks · Machine Learning in Healthcare
