Which Features are Learnt by Contrastive Learning? On the Role of Simplicity Bias in Class Collapse and Feature Suppression
Yihao Xue, Siddharth Joshi, Eric Gan, Pin-Yu Chen, Baharan Mirzasoleiman

TL;DR
This paper provides a theoretical framework explaining how contrastive learning tends to favor simpler features, leading to class collapse and feature suppression, and proposes remedies such as increasing the embedding dimensionality and using better data augmentations.
Contribution
It offers the first rigorous theory on feature learning in contrastive learning, linking simplicity bias to class collapse and feature suppression, and suggests practical improvements.
Findings
The bias of (stochastic) gradient descent towards simpler solutions is a key factor in class collapse.
Increasing embedding size mitigates feature suppression.
Better data augmentations improve feature learning.
Abstract
Contrastive learning (CL) has emerged as a powerful technique for representation learning, with or without label supervision. However, supervised CL is prone to collapsing representations of subclasses within a class by not capturing all their features, and unsupervised CL may suppress harder class-relevant features by focusing on learning easy class-irrelevant features; both significantly compromise representation quality. Yet, there is no theoretical understanding of class collapse or feature suppression at test time. We provide the first unified theoretically rigorous framework to determine which features are learnt by CL. Our analysis indicates that, perhaps surprisingly, bias of (stochastic) gradient descent towards finding simpler solutions is a key factor in collapsing subclass representations and suppressing harder class-relevant features.…
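The abstract attributes both failure modes to the bias of gradient descent towards simpler solutions. As a rough illustration of one well-known form of such a bias (not the paper's contrastive setup), the sketch below runs plain gradient descent from a zero initialization on a tiny underdetermined least-squares problem and compares the result with the minimum-norm interpolating solution. The data, the function names, and the choice of "minimum norm" as the notion of simplicity are all assumptions made for this example.

```python
# Toy illustration only: linear least squares, not a contrastive loss.
# Two samples, three parameters, so infinitely many weight vectors fit the data.
X = [[1.0, 0.0, 1.0],
     [0.0, 1.0, 1.0]]
y = [1.0, 2.0]


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def predict(X, w):
    return [dot(row, w) for row in X]


def gradient_descent(X, y, steps=500):
    """Minimize 0.5 * ||Xw - y||^2 with gradient descent started at w = 0."""
    n_features = len(X[0])
    w = [0.0] * n_features
    # 1 / (sum of squared entries) is at most 1 / (largest eigenvalue of X^T X),
    # so this step size keeps the iteration stable.
    lr = 1.0 / sum(v * v for row in X for v in row)
    for _ in range(steps):
        residual = [p - t for p, t in zip(predict(X, w), y)]
        grad = [sum(X[i][j] * residual[i] for i in range(len(X)))
                for j in range(n_features)]
        w = [w_j - lr * g_j for w_j, g_j in zip(w, grad)]
    return w


def min_norm_solution(X, y):
    """Closed form w = X^T (X X^T)^{-1} y for a two-row X."""
    a, b, d = dot(X[0], X[0]), dot(X[0], X[1]), dot(X[1], X[1])
    det = a * d - b * b
    alpha0 = (d * y[0] - b * y[1]) / det
    alpha1 = (-b * y[0] + a * y[1]) / det
    return [alpha0 * X[0][j] + alpha1 * X[1][j] for j in range(len(X[0]))]


w_gd = gradient_descent(X, y)
w_min = min_norm_solution(X, y)
w_other = [1.0, 2.0, 0.0]  # another exact fit, chosen by hand, with larger norm
```

Here `w_gd` and `w_min` agree up to floating-point error, while `w_other` fits the same two samples with a larger norm. Gradient descent never reaches it from a zero start, because every update lies in the span of the rows of `X`.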
Taxonomy
Topics: Domain Adaptation and Few-Shot Learning · Machine Learning and ELM
