Don't blame Dataset Shift! Shortcut Learning due to Gradients and Cross Entropy
Aahlad Puli, Lily Zhang, Yoav Wald, Rajesh Ranganath

TL;DR
This paper investigates why gradient-based cross-entropy training favors shortcut solutions even when stable features are sufficient, revealing the role of margin maximization bias and proposing new loss functions to promote reliance on stable features.
Contribution
It identifies the implicit max-margin bias of default-ERM as a cause of shortcut learning and introduces margin control (MARG-CTRL), a family of losses that promotes dependence on the stable feature in perception tasks.
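For context, the max-margin bias has a standard formal statement for linear models (Soudry et al., 2018). On linearly separable data, gradient descent on the logistic loss with a sufficiently small step size drives the weight norm to infinity, while the weight direction converges to the L2 max-margin separator. The notation below is assumed here and is not taken from this page:

```latex
\mathcal{L}(w) = \frac{1}{n}\sum_{i=1}^{n} \log\!\left(1 + e^{-y_i w^\top x_i}\right),
\qquad
w_{t+1} = w_t - \eta\,\nabla \mathcal{L}(w_t),

\lim_{t\to\infty} \frac{w_t}{\lVert w_t \rVert}
  = \frac{\hat{w}}{\lVert \hat{w} \rVert},
\qquad
\hat{w} = \arg\min_{w} \lVert w \rVert_2
  \ \ \text{s.t.}\ \ y_i\, w^\top x_i \ge 1 \quad \forall i .
```

When a shortcut feature has a large scale, the max-margin direction can put substantial weight on it even though the stable feature alone separates the data, which is the kind of preference the paper attributes to default-ERM.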
Findings
Default-ERM prefers shortcut solutions due to margin maximization bias.
MARG-CTRL encourages uniform-margin solutions, reducing shortcut reliance.
MARG-CTRL improves performance on vision and language tasks by focusing on stable features.
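To illustrate the uniform-margin idea in the findings above, here is a minimal sketch of one possible margin-controlled loss. The penalty λ·log(1 + f²) added to the logistic loss, the value λ = 0.1, and the function names are assumptions chosen for illustration, not necessarily the exact losses used in the paper. Written in terms of the margin m = y·f(x), the plain logistic loss keeps decreasing as m grows, so default-ERM always gains from pushing margins further out; the penalized loss instead has a finite minimizer m*.

```python
import numpy as np


def logistic_loss(margin):
    """Default binary cross-entropy written in terms of the margin m = y * f(x)."""
    # log(1 + exp(-m)), computed stably.
    return np.logaddexp(0.0, -margin)


def margin_penalized_loss(margin, lam=0.1):
    """Hypothetical margin-controlled loss: logistic loss plus lam * log(1 + f^2).

    Since y is +/-1, f^2 equals margin^2, so the penalty can be written in the margin.
    The exact penalty form and lam are illustrative assumptions.
    """
    return np.logaddexp(0.0, -margin) + lam * np.log1p(margin ** 2)


# Scan non-negative margins on a fine grid and locate each loss's minimizer.
margins = np.linspace(0.0, 50.0, 50001)
best_default = margins[np.argmin(logistic_loss(margins))]
best = margins[np.argmin(margin_penalized_loss(margins))]

print(f"logistic loss minimizer on grid:  m = {best_default:.3f} (grid edge)")
print(f"penalized loss minimizer on grid: m = {best:.3f} (finite)")
```

With λ = 0.1 the penalized minimizer lands close to m* ≈ 2.65, while the logistic loss is minimized at the edge of the grid. In the idealized view where each example's margin can be adjusted freely, every example shares the same finite optimum, so training pulls all examples toward a common margin instead of letting well-fit majority examples keep growing theirs.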
Abstract
Common explanations for shortcut learning assume that the shortcut improves prediction under the training distribution but not in the test distribution. Thus, models trained via the typical gradient-based optimization of cross-entropy, which we call default-ERM, utilize the shortcut. However, even when the stable feature determines the label in the training distribution and the shortcut does not provide any additional information, like in perception tasks, default-ERM still exhibits shortcut learning. Why are such solutions preferred when the loss for default-ERM can be driven to zero using the stable feature alone? By studying a linear perception task, we show that default-ERM's preference for maximizing the margin leads to models that depend more on the shortcut than the stable feature, even without overparameterization. This insight suggests that default-ERM's implicit inductive bias…
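A minimal NumPy sketch of a linear perception task in the spirit of the abstract follows. The data model is an assumption for illustration: x = [B·z, y, δ], where the label y itself is the stable feature (so it determines the label exactly), z is a shortcut that agrees with y with probability ρ, and δ is isotropic Gaussian noise. All parameter values (B = 10, ρ = 0.9, n = 100, 500 noise dimensions, step size, step count) are likewise assumptions. In this sketch, the noise dimensions let the model fit the minority examples (where z disagrees with y) by memorization, which is what allows a shortcut-heavy solution to reach a large margin.

```python
import numpy as np


def sample_linear_perception(n, rho, B, noise_dim, rng):
    """Assumed data model: x = [B*z, y, delta].

    y is a Rademacher label (the stable feature), z agrees with y with
    probability rho (the shortcut), and delta is standard Gaussian noise.
    """
    y = rng.choice([-1.0, 1.0], size=n)
    agree = rng.random(n) < rho
    z = np.where(agree, y, -y)
    delta = rng.standard_normal((n, noise_dim))
    x = np.column_stack([B * z, y, delta])
    return x, y


def train_default_erm(x, y, lr=0.02, steps=5000):
    """Full-batch gradient descent on the mean logistic (cross-entropy) loss."""
    w = np.zeros(x.shape[1])
    for _ in range(steps):
        margins = y * (x @ w)
        # sigmoid(-m) = 0.5 * (1 - tanh(m / 2)), a numerically stable form.
        weights = 0.5 * (1.0 - np.tanh(margins / 2.0))
        # Gradient of mean log(1 + exp(-m_i)) with respect to w.
        grad = -(x * (y * weights)[:, None]).mean(axis=0)
        w -= lr * grad
    return w


rng = np.random.default_rng(0)
B = 10.0
x_train, y_train = sample_linear_perception(n=100, rho=0.9, B=B, noise_dim=500, rng=rng)
w = train_default_erm(x_train, y_train)

# Each feature's contribution to the logit: the shortcut column is B*z,
# so its weight w[0] is scaled by B; the stable column is y with weight w[1].
shortcut_contrib = B * w[0]
stable_contrib = w[1]
print(f"shortcut contribution B*w_z = {shortcut_contrib:.3f}")
print(f"stable contribution   w_y   = {stable_contrib:.3f}")

# Shifted test set: the shortcut now disagrees with the label 90% of the time.
x_test, y_test = sample_linear_perception(n=1000, rho=0.1, B=B, noise_dim=500, rng=rng)
shifted_acc = np.mean(np.sign(x_test @ w) == y_test)
print(f"accuracy under reversed shortcut correlation: {shifted_acc:.3f}")
```

Because the shortcut column is scaled by B, comparing raw weights would understate its influence, so the sketch compares each feature's contribution to the logit instead. The last line reports accuracy on a test distribution where the shortcut correlation is reversed, which shows how the learned dependence on z behaves under the shift.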
Taxonomy
Topics: Domain Adaptation and Few-Shot Learning · Neural Networks and Applications · Multimodal Machine Learning Applications
