Supervised learning pays attention
Erin Craig, Robert Tibshirani

TL;DR
This paper introduces an attention-based supervised learning method that creates personalized, interpretable models for each prediction, improving performance on heterogeneous data without sacrificing simplicity.
Contribution
It adapts attention mechanisms to supervised learning for tabular data, enabling local, interpretable models that adapt to data heterogeneity and distributional shifts.
Findings
Attention weighting improves predictive accuracy on real and simulated data.
The method provides interpretability by identifying key features and relevant training observations.
Theoretical analysis shows lower mean squared error compared to standard linear models in certain data settings.
Abstract
In-context learning with attention enables large neural networks to make context-specific predictions by selectively focusing on relevant examples. Here, we adapt this idea to supervised learning procedures such as lasso regression and gradient boosting, for tabular data. Our goals are to (1) flexibly fit personalized models for each prediction point and (2) retain model simplicity and interpretability. Our method fits a local model for each test observation by weighting the training data according to attention, a supervised similarity measure that emphasizes features and interactions that are predictive of the outcome. Attention weighting allows the method to adapt to heterogeneous data in a data-driven way, without requiring cluster or similarity pre-specification. Further, our approach is uniquely interpretable: for each test observation, we identify which features are most…
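The core idea in the abstract (weight the training data by a supervised similarity to the test point, then fit a simple local model) can be illustrated with a minimal sketch. This is not the authors' procedure. It makes several assumptions of its own: the similarity is a softmax over dot products weighted by the absolute coefficients of a global ridge fit, the local model is a weighted ridge regression swapped in for the lasso so that only NumPy is needed, and the data are synthetic. It also ignores interactions, which the abstract says the real attention measure accounts for. The names `ridge_fit`, `attention_weights`, `predict_local`, and `temperature` are hypothetical.

```python
import numpy as np

def ridge_fit(X, y, sample_weight=None, lam=1.0):
    """Weighted ridge regression with an unpenalized intercept (closed form)."""
    n, p = X.shape
    w = np.ones(n) if sample_weight is None else np.asarray(sample_weight, dtype=float)
    w = w * n / w.sum()                     # rescale so lam acts comparably across fits
    Xa = np.hstack([np.ones((n, 1)), X])    # prepend an intercept column
    penalty = lam * np.eye(p + 1)
    penalty[0, 0] = 0.0                     # do not shrink the intercept
    beta = np.linalg.solve(Xa.T @ (w[:, None] * Xa) + penalty, Xa.T @ (w * y))
    return beta[0], beta[1:]

def attention_weights(X_train, x_star, importance, temperature=1.0):
    """Softmax over importance-weighted dot products (an assumed similarity)."""
    scores = (X_train * importance) @ x_star / temperature
    scores = scores - scores.max()          # subtract the max for numerical stability
    w = np.exp(scores)
    return w / w.sum()

def predict_local(X_train, y_train, x_star, importance, lam=1.0):
    """Fit one attention-weighted model for a single test point."""
    w = attention_weights(X_train, x_star, importance)
    b0, b = ridge_fit(X_train, y_train, sample_weight=w, lam=lam)
    return b0 + x_star @ b, w, b

# Synthetic heterogeneous data: the slope on feature 2 flips sign between two groups,
# and feature 1 (which is predictive of y) encodes the group.
rng = np.random.default_rng(0)
n = 200
group = rng.choice([-1.0, 1.0], size=n)
X = np.column_stack([2.0 * group + 0.1 * rng.normal(size=n), rng.normal(size=n)])
y = 3.0 * X[:, 0] + 2.0 * group * X[:, 1] + 0.5 * rng.normal(size=n)

# A global fit supplies the feature importances used inside the similarity.
g0, g = ridge_fit(X, y)
importance = np.abs(g)

x_star = np.array([2.0, 1.0])               # test point from the positive group
truth = 3.0 * 2.0 + 2.0 * 1.0               # noiseless outcome at x_star
global_pred = g0 + x_star @ g
local_pred, w, local_coef = predict_local(X, y, x_star, importance)

print(f"truth={truth:.2f} global={global_pred:.2f} local={local_pred:.2f}")
print("most-attended training rows:", np.argsort(w)[::-1][:5])
```

In this toy setting the global model averages the two opposing slopes, while the attention-weighted fit concentrates on training rows from the test point's own group. The vector `w` and the local coefficients `local_coef` are the two pieces the summary describes as interpretable: which training observations mattered and which features drove the prediction.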
Taxonomy
Topics: Explainable Artificial Intelligence (XAI) · Generative Adversarial Networks and Image Synthesis · Machine Learning in Healthcare
