Learning an Invertible Output Mapping Can Mitigate Simplicity Bias in Neural Networks
Sravanti Addepalli, Anshul Nasery, R. Venkatesh Babu, Praneeth Netrapalli, Prateek Jain

TL;DR
This paper introduces a method to reduce simplicity bias in neural networks by using an invertible output mapping, leading to improved out-of-distribution performance and more diverse feature utilization.
Contribution
The paper proposes an invertible output mapping technique to mitigate simplicity bias, enhancing feature diversity and robustness against distribution shifts in neural networks.
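The listing does not say how invertibility is enforced. One plausible reading is a penalty that asks the features to be recoverable from the outputs, so the output mapping cannot silently discard the complex features. The sketch below assumes a linear output map `W` and a separate linear reconstruction map `V` (both hypothetical names, not from the paper), with as many outputs as features so that an exact inverse can exist. It illustrates the idea only and is not the authors' formulation.

```python
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def matvec(m: Matrix, v: Vector) -> Vector:
    """Dense matrix-vector product."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def reconstruction_penalty(batch: List[Vector], W: Matrix, V: Matrix) -> float:
    """Mean squared error between features z and V(W z) over a batch.

    W: hypothetical output map (features -> outputs).
    V: hypothetical reconstruction map (outputs -> features).
    A zero penalty means the outputs retain every feature.
    """
    total = 0.0
    for z in batch:
        z_hat = matvec(V, matvec(W, z))
        total += sum((a - b) ** 2 for a, b in zip(z, z_hat))
    return total / len(batch)


batch = [[1.0, 0.0], [0.0, 1.0]]

# Invertible head: outputs mix both features, and V_inv is its exact inverse.
W_inv = [[1.0, 1.0], [1.0, -1.0]]
V_inv = [[0.5, 0.5], [0.5, -0.5]]

# Head that ignores feature 1: W_ign maps [0, 1] to [0, 0], so no linear V
# can recover that feature and the penalty stays positive.
W_ign = [[1.0, 0.0], [-1.0, 0.0]]
V_ign = [[0.5, -0.5], [0.0, 0.0]]

print(reconstruction_penalty(batch, W_inv, V_inv))  # 0.0
print(reconstruction_penalty(batch, W_ign, V_ign))  # 0.5
```

In a training loop, such a term would presumably be added to the task loss with some weight (for example `task_loss + lam * penalty`, where `lam` is a hypothetical hyperparameter), penalizing heads that collapse onto a subset of the features.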
Findings
Up to 15% improvement in OOD accuracy on semi-synthetic datasets.
Significant gains over SOTA methods on the DomainBed benchmark.
Effective mitigation of simplicity bias through an invertible output mapping.
Abstract
Deep Neural Networks (DNNs) are known to be brittle to even minor shifts away from the training distribution. While one line of work has demonstrated that the Simplicity Bias (SB) of DNNs (a bias towards learning only the simplest features) is a key reason for this brittleness, another recent line of work has surprisingly found that diverse/complex features are indeed learned by the backbone, and that the brittleness is instead due to the linear classification head relying primarily on the simplest features. To bridge the gap between these two lines of work, we first hypothesize and verify that while SB may not altogether preclude learning complex features, it amplifies simpler features over complex ones. Namely, simple features are replicated several times in the learned representations, while complex features might not be replicated. This phenomenon, which we term the Feature Replication Hypothesis,…
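The Feature Replication Hypothesis can be illustrated with a toy linear head. The setup below is hypothetical and is not an experiment from the paper. A "simple" feature is copied `r` times and a "complex" feature appears once. Both equal the label during training, and a logistic-regression head is trained by full-batch gradient descent from zero weights. Because every copy receives an identical gradient, each ends up with the same weight, so the simple feature's share of the head grows to r/(r+1).

```python
import math
from typing import List


def make_features(y: int, r: int, flip_simple: bool = False) -> List[float]:
    """Toy representation: r copies of a 'simple' feature plus one 'complex' feature.

    In training both features equal the label y (+1 or -1); flip_simple=True
    negates the simple copies to mimic a distribution shift at test time.
    """
    simple = float(-y if flip_simple else y)
    return [simple] * r + [float(y)]


def train_linear_head(xs: List[List[float]], ys: List[int],
                      lr: float = 0.5, steps: int = 200) -> List[float]:
    """Full-batch gradient descent on the logistic loss, zero init, no bias."""
    d, n = len(xs[0]), len(xs)
    w = [0.0] * d
    for _ in range(steps):
        grad = [0.0] * d
        for x, y in zip(xs, ys):
            margin = y * sum(wj * xj for wj, xj in zip(w, x))
            # d/d(margin) of log(1 + exp(-margin)) is -1 / (1 + exp(margin))
            coeff = -y / (1.0 + math.exp(margin))
            for j in range(d):
                grad[j] += coeff * x[j] / n
        w = [wj - lr * gj for wj, gj in zip(w, grad)]
    return w


def logit(w: List[float], x: List[float]) -> float:
    return sum(wj * xj for wj, xj in zip(w, x))


def simple_share(w: List[float], r: int) -> float:
    """Fraction of the total head weight placed on the r simple copies."""
    simple = sum(w[:r])
    return simple / (simple + w[r])


ys = [1, -1, 1, -1]
for r in (1, 3):
    w = train_linear_head([make_features(y, r) for y in ys], ys)
    shifted = make_features(1, r, flip_simple=True)  # true label is +1
    print(f"r={r}: simple share={simple_share(w, r):.2f}, "
          f"shifted logit={logit(w, shifted):+.3f}")
```

Flipping the simple feature at test time (a stand-in for distribution shift) leaves the r=1 head undecided, with a logit of exactly zero. The r=3 head, however, sides with the replicated simple feature and predicts the wrong class, even though both features were equally informative during training.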
Taxonomy
Topics: Domain Adaptation and Few-Shot Learning · Advanced Neural Network Applications · Machine Learning and Data Classification
Methods: Linear Layer · Stochastic Gradient Descent
