SAINT: Improved Neural Networks for Tabular Data via Row Attention and Contrastive Pre-Training
Gowthami Somepalli, Micah Goldblum, Avi Schwarzschild, C. Bayan Bruss, Tom Goldstein

TL;DR
SAINT is a neural network architecture for tabular data that combines attention over both rows and columns with contrastive pre-training. It improves on previous deep learning models and, on average across a variety of benchmark tasks, outperforms gradient boosting methods.
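To make "row and column attention" concrete, here is a minimal NumPy sketch, not the authors' implementation. It assumes the batch is already embedded as an array of shape (n_rows, n_features, d). Column attention lets each row's features attend to one another; row attention flattens each row into a single vector of length n_features * d and lets the rows in the batch attend to one another. Single-head attention, randomly drawn projection weights, the stage ordering, and all helper names are illustrative assumptions.

```python
import numpy as np


def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def attention(x, wq, wk, wv):
    """Single-head scaled dot-product attention over axis -2 of x."""
    q, k, v = x @ wq, x @ wk, x @ wv
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(q.shape[-1])
    return softmax(scores, axis=-1) @ v


def random_projections(dim, rng):
    # Hypothetical stand-ins for learned query/key/value weights.
    return [rng.standard_normal((dim, dim)) / np.sqrt(dim) for _ in range(3)]


def column_attention(x, rng):
    """x: (n_rows, n_features, d). Each row's features attend to each other."""
    return attention(x, *random_projections(x.shape[-1], rng))


def row_attention(x, rng):
    """Rows (samples in the batch) attend to each other.

    Each row's feature embeddings are flattened into one vector of
    length n_features * d before attention, then reshaped back.
    """
    n, m, d = x.shape
    flat = x.reshape(n, m * d)
    out = attention(flat, *random_projections(m * d, rng))
    return out.reshape(n, m, d)


def saint_style_stage(x, rng):
    # Assumed ordering: column attention, then row attention, each with a
    # residual connection (layer norm and feed-forward sublayers omitted).
    x = x + column_attention(x, rng)
    return x + row_attention(x, rng)


rng = np.random.default_rng(0)
batch = rng.standard_normal((4, 3, 8))  # 4 rows, 3 features, 8-dim embeddings
out = saint_style_stage(batch, rng)
print(out.shape)  # (4, 3, 8)
```

The distinction shows up when one row is perturbed: column-attention outputs for the other rows stay the same, because each row is processed independently, while their row-attention outputs change, because every row now sees the others.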
Contribution
Introduces SAINT, a hybrid deep learning model for tabular data that performs attention over both rows and columns, uses an enhanced embedding method, and adds contrastive self-supervised pre-training.
Findings
SAINT outperforms previous deep learning models on benchmark datasets.
On average over a variety of benchmark tasks, SAINT outperforms gradient boosting methods, including XGBoost, CatBoost, and LightGBM.
Contrastive self-supervised pre-training helps when labels are scarce.
Abstract
Tabular data underpins numerous high-impact applications of machine learning from fraud detection to genomics and healthcare. Classical approaches to solving tabular problems, such as gradient boosting and random forests, are widely used by practitioners. However, recent deep learning methods have achieved a degree of performance competitive with popular techniques. We devise a hybrid deep learning approach to solving tabular data problems. Our method, SAINT, performs attention over both rows and columns, and it includes an enhanced embedding method. We also study a new contrastive self-supervised pre-training method for use when labels are scarce. SAINT consistently improves performance over previous deep learning methods, and it even outperforms gradient boosting methods, including XGBoost, CatBoost, and LightGBM, on average over a variety of benchmark tasks.
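The abstract does not spell out the pre-training objective. As a hedged illustration only, the sketch below assumes a common contrastive recipe, suggested by the Mixup and CutMix entries in the taxonomy: augment unlabeled rows with CutMix-style feature swapping in input space and mixup in embedding space, then apply an InfoNCE loss that pulls each row's clean and augmented embeddings together while pushing apart embeddings of different rows. The encoder, temperature, and mixing values are hypothetical, and the paper's actual objective may differ or include other terms.

```python
import numpy as np


def cutmix_features(x, p, rng):
    """Input-space augmentation: each entry is swapped, with probability p,
    for the same feature taken from another randomly chosen row."""
    partner = rng.permutation(x.shape[0])
    mask = rng.random(x.shape) < p
    return np.where(mask, x[partner], x)


def mixup_embeddings(z, lam, rng):
    """Embedding-space augmentation: convex combination with another row."""
    partner = rng.permutation(z.shape[0])
    return lam * z + (1.0 - lam) * z[partner]


def info_nce(z1, z2, tau=0.7):
    """InfoNCE loss: row i of z1 should match row i of z2 (the positive pair)
    rather than any other row j of z2 (the negatives)."""
    z1 = z1 / np.linalg.norm(z1, axis=1, keepdims=True)
    z2 = z2 / np.linalg.norm(z2, axis=1, keepdims=True)
    logits = z1 @ z2.T / tau
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return -np.mean(np.diag(log_probs))


rng = np.random.default_rng(0)
x = rng.standard_normal((8, 5))   # 8 unlabeled rows, 5 numeric features
W = rng.standard_normal((5, 16))  # hypothetical stand-in for a learned encoder


def encode(rows):
    return np.tanh(rows @ W)


z_clean = encode(x)
z_aug = mixup_embeddings(encode(cutmix_features(x, 0.3, rng)), 0.9, rng)
print(info_nce(z_clean, z_aug))
```

Because this loss needs no labels, it can be computed on every available row; the pretrained encoder can then be fine-tuned on the smaller labeled subset, which is the setting where the abstract says pre-training is meant to help.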
Taxonomy
Topics: Imbalanced Data Classification Techniques · Machine Learning and Data Classification · Anomaly Detection Techniques and Applications
Methods: Mixup · Dense Connections · CutMix · Feedforward Network · SAINT
