Sharpness-Aware Minimization Enhances Feature Quality via Balanced Learning
Jacob Mitchell Springer, Vaishnavh Nagarajan, Aditi Raghunathan

TL;DR
This paper reveals that Sharpness-Aware Minimization (SAM) improves feature quality by balancing how well diverse features are learned, especially in datasets with redundant or spurious features, leading to better out-of-distribution generalization.
Contribution
The paper uncovers a new mechanism of SAM, distinct from its bias toward flat minima, that balances the quality of diverse features and thereby improves learning in datasets with redundant or spurious features.
Findings
SAM improves feature quality in datasets with redundant or spurious features.
SAM outperforms SGD on datasets such as CelebA, Waterbirds, CIFAR-MNIST, and DomainBed.
SAM enhances out-of-distribution generalization by balancing feature learning.
Abstract
Sharpness-Aware Minimization (SAM) has emerged as a promising alternative optimizer to stochastic gradient descent (SGD). The original motivation behind SAM was to bias neural networks towards flatter minima, which are believed to generalize better. However, recent studies have shown conflicting evidence on the relationship between flatness and generalization, suggesting that flatness does not fully explain SAM's success. Sidestepping this debate, we identify an orthogonal effect of SAM that is beneficial out-of-distribution: we argue that SAM implicitly balances the quality of diverse features. SAM achieves this effect by adaptively suppressing well-learned features, which gives the remaining features an opportunity to be learned. We show that this mechanism is beneficial in datasets that contain redundant or spurious features where SGD falls for the simplicity bias and would not…
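For readers comparing the two optimizers, the sketch below shows the standard SAM update as introduced by Foret et al. (2021), not code from this paper. SAM approximately solves min over w of the worst-case loss max over ||e|| <= rho of L(w + e): it first steps to the (first-order) worst-case point w + rho * g / ||g||, then descends using the gradient taken there. The function names `sam_step`, `sgd_step`, and `grad_fn` and the values of `lr` and `rho` are hypothetical choices for illustration. The feature-balancing effect described in the abstract is an emergent property of training with this update on real data and is not demonstrated by this toy snippet.

```python
import numpy as np


def sgd_step(w, grad_fn, lr=0.1):
    """Plain gradient-descent update: w <- w - lr * grad L(w)."""
    return w - lr * grad_fn(w)


def sam_step(w, grad_fn, lr=0.1, rho=0.05, eps=1e-12):
    """One Sharpness-Aware Minimization update (Foret et al., 2021).

    w       : current parameters (np.ndarray)
    grad_fn : hypothetical callable returning the loss gradient at given parameters
    lr      : step size
    rho     : radius of the neighborhood searched for the worst-case loss
    eps     : small constant guarding against division by a zero gradient norm
    """
    g = grad_fn(w)
    # First-order solution of max_{||e|| <= rho} L(w + e): move along the normalized gradient.
    e = rho * g / (np.linalg.norm(g) + eps)
    # Descend from w using the gradient evaluated at the perturbed point w + e.
    return w - lr * grad_fn(w + e)


if __name__ == "__main__":
    # Toy quadratic loss L(w) = 0.5 * w^T A w with one flat and one sharp direction.
    A = np.diag([1.0, 10.0])
    grad = lambda w: A @ w
    w_sgd = w_sam = np.array([1.0, 1.0])
    for _ in range(5):
        w_sgd = sgd_step(w_sgd, grad)
        w_sam = sam_step(w_sam, grad)
    print("SGD:", w_sgd)
    print("SAM:", w_sam)
```

Each SAM step costs two gradient evaluations instead of one, which is the main practical overhead relative to SGD; setting `rho=0` recovers plain gradient descent.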
Taxonomy
Topics: Industrial Vision Systems and Defect Detection
Methods: Stochastic Gradient Descent · Sharpness-Aware Minimization
