Simplicity Bias of Two-Layer Networks beyond Linearly Separable Data
Nikita Tsoy, Nikola Konstantinov

TL;DR
This paper investigates simplicity bias in two-layer neural networks beyond linearly separable data. It shows that early training favors a few simple features, that simplicity bias intensifies later in training, and that features learned in the middle stages of training may transfer better out of distribution.
Contribution
It provides a theoretical analysis of simplicity bias for general datasets in two-layer networks initialized with small weights and trained with gradient flow, extending previous work that was limited to linearly separable data.
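For orientation, here is one standard way to write down this setting. The ReLU activation, the margin loss ℓ, and all symbols are our own illustrative assumptions and may not match the paper's exact notation or assumptions.

```latex
% Two-layer network with m hidden units (ReLU assumed for illustration)
f(x;\theta) = \sum_{j=1}^{m} a_j\, \sigma(w_j^\top x), \qquad \sigma(z) = \max(z, 0)

% Empirical risk with a margin loss \ell (e.g. logistic, \ell(z) = \log(1 + e^{-z}))
L(\theta) = \frac{1}{n} \sum_{i=1}^{n} \ell\big(y_i f(x_i;\theta)\big)

% Small initialization and gradient flow
\theta(0) = \varepsilon\, \theta_0 \ (\varepsilon \to 0), \qquad \frac{d\theta}{dt} = -\nabla_\theta L(\theta(t))

% Early phase: f \approx 0, so \ell'(y_i f) \approx \ell'(0) (= -1/2 for the logistic loss)
-\nabla_{w_j} L(\theta) \approx -\frac{\ell'(0)}{n}\, a_j \sum_{i:\, w_j^\top x_i > 0} y_i x_i
```

In this early-phase approximation, the update of w_j depends on the data only through which samples activate the unit and on the sign of a_j, not on the width m. This offers one intuition for why feature directions w_j / ‖w_j‖ can cluster around a few width-independent directions early in training.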
Findings
Early in training, features cluster around a few directions that do not depend on the size of the hidden layer.
For XOR-like datasets, simplicity bias intensifies during later training stages.
Features learned in the middle stages of training may be more useful for out-of-distribution transfer.
Abstract
Simplicity bias, the propensity of deep models to over-rely on simple features, has been identified as a potential reason for limited out-of-distribution generalization of neural networks (Shah et al., 2020). Despite the important implications, this phenomenon has been theoretically confirmed and characterized only under strong dataset assumptions, such as linear separability (Lyu et al., 2021). In this work, we characterize simplicity bias for general datasets in the context of two-layer neural networks initialized with small weights and trained with gradient flow. Specifically, we prove that in the early training phases, network features cluster around a few directions that do not depend on the size of the hidden layer. Furthermore, for datasets with an XOR-like pattern, we precisely identify the learned features and demonstrate that simplicity bias intensifies during later training…
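The following pure-Python simulation illustrates the early-phase clustering claim. It is a sketch under our own assumptions, not the authors' experimental setup. It trains f(x) = Σ_j a_j ReLU(w_j · x) on a hypothetical four-point XOR-like dataset (label = sign(x1·x2)) using plain gradient descent on the mean logistic loss. A small step size stands in for gradient flow, the initialization is small and balanced (|a_j| = ‖w_j‖ with a random sign), and the hyperparameters are illustrative only. The helpers train_two_layer and count_direction_clusters are hypothetical names. For this toy data, a short calculation suggests that each w_j is pushed toward a maximizer (if a_j > 0) or minimizer (if a_j < 0) of Σ_i y_i ReLU(w · x_i) on the unit circle. That would give at most four directions, ±(1, 1) and ±(1, −1), whatever the width.

```python
import math
import random

# Hypothetical toy XOR-like dataset: label = sign(x1 * x2).
XOR_DATA = [((1.0, 1.0), 1.0), ((-1.0, -1.0), 1.0),
            ((1.0, -1.0), -1.0), ((-1.0, 1.0), -1.0)]


def relu(z):
    return z if z > 0.0 else 0.0


def train_two_layer(width, init_scale=1e-3, lr=0.25, steps=100, seed=0):
    """Gradient descent (small steps as a proxy for gradient flow) on the
    mean logistic loss of f(x) = sum_j a_j * relu(w_j . x)."""
    rng = random.Random(seed)
    ws, a = [], []
    for _ in range(width):
        w = [init_scale * rng.gauss(0.0, 1.0), init_scale * rng.gauss(0.0, 1.0)]
        ws.append(w)
        # Balanced small init (an assumption): |a_j| = ||w_j||, random sign.
        a.append(rng.choice((-1.0, 1.0)) * math.hypot(w[0], w[1]))
    n = len(XOR_DATA)
    for _ in range(steps):
        # c_i = -(1/n) * dloss/df at x_i = y_i * sigmoid(-y_i f(x_i)) / n.
        coeffs = []
        for x, y in XOR_DATA:
            f = sum(aj * relu(w[0] * x[0] + w[1] * x[1]) for aj, w in zip(a, ws))
            coeffs.append(y / (1.0 + math.exp(y * f)) / n)
        new_a, new_ws = [], []
        for aj, w in zip(a, ws):
            # Negative gradients of the loss w.r.t. a_j and w_j.
            ga, gw0, gw1 = 0.0, 0.0, 0.0
            for (x, _label), c in zip(XOR_DATA, coeffs):
                pre = w[0] * x[0] + w[1] * x[1]
                if pre > 0.0:  # ReLU derivative, taken as 0 at pre == 0
                    ga += c * pre
                    gw0 += c * aj * x[0]
                    gw1 += c * aj * x[1]
            new_a.append(aj + lr * ga)
            new_ws.append([w[0] + lr * gw0, w[1] + lr * gw1])
        a, ws = new_a, new_ws
    return a, ws


def angle_dist(t1, t2):
    """Smallest absolute difference between two angles, in radians."""
    d = (t1 - t2) % (2.0 * math.pi)
    return min(d, 2.0 * math.pi - d)


def count_direction_clusters(ws, tol=0.3, rel_norm=0.05):
    """Greedily group hidden-unit directions w_j / ||w_j|| whose angles lie
    within `tol` of a cluster representative; tiny units are ignored."""
    norms = [math.hypot(w[0], w[1]) for w in ws]
    cutoff = rel_norm * max(norms)
    reps = []
    for w, r in zip(ws, norms):
        if r < cutoff:
            continue
        theta = math.atan2(w[1], w[0])
        if not any(angle_dist(theta, rep) <= tol for rep in reps):
            reps.append(theta)
    return len(reps)


if __name__ == "__main__":
    for width in (20, 80):
        _, ws = train_two_layer(width)
        print(f"width={width}: {count_direction_clusters(ws)} direction clusters")
```

The usage loop compares widths 20 and 80. If the sketch behaves as the calculation above suggests, the cluster count stays small and does not grow with the width. The angular tolerance and the relative-norm cutoff are arbitrary choices for illustration.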
