Long-Tailed Classification by Keeping the Good and Removing the Bad Momentum Causal Effect
Kaihua Tang, Jianqiang Huang, Hanwang Zhang

TL;DR
This paper introduces a causal inference framework for long-tailed classification. It identifies SGD momentum as a confounder and derives a principled intervention that keeps the beneficial part of the momentum effect while removing its harmful bias toward head classes.
Contribution
The paper establishes a causal inference perspective on long-tailed classification. That perspective explains why earlier re-weighting and re-sampling heuristics work, and it separates the harmful and beneficial effects of SGD momentum to derive a principled solution.
Findings
Achieved state-of-the-art results on Long-tailed CIFAR-10/-100, ImageNet-LT, and LVIS benchmarks.
Identified SGD momentum as a confounder affecting tail class predictions.
Proposed causal interventions that enhance long-tailed visual recognition.
Abstract
As the class size grows, maintaining a balanced dataset across many classes is challenging because the data are long-tailed in nature; it is even impossible when the sample-of-interest co-exists with each other in one collectable unit, e.g., multiple visual instances in one image. Therefore, long-tailed classification is the key to deep learning at scale. However, existing methods are mainly based on re-weighting/re-sampling heuristics that lack a fundamental theory. In this paper, we establish a causal inference framework, which not only unravels the whys of previous methods, but also derives a new principled solution. Specifically, our theory shows that the SGD momentum is essentially a confounder in long-tailed classification. On one hand, it has a harmful causal effect that misleads the tail prediction biased towards the head. On the other hand, its induced mediation also benefits…
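To make the momentum claim concrete, the toy sketch below accumulates gradients with the classic momentum rule v <- mu * v + g over a long-tailed stream. It is an illustration only, not the authors' method or code: the two-dimensional "gradient directions" `head_dir` and `tail_dir`, the 19:1 head-to-tail ratio, and the helper names are all hypothetical choices made for this example.

```python
import numpy as np

def momentum_buffer(grads, mu=0.9):
    """Accumulate gradients with the classic SGD momentum rule v <- mu * v + g."""
    v = np.zeros_like(grads[0])
    for g in grads:
        v = mu * v + g
    return v

def cosine(a, b):
    """Cosine similarity between two vectors."""
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

# Hypothetical directions: assume head-class and tail-class samples push the
# parameters along two different (here orthogonal) directions.
head_dir = np.array([1.0, 0.0])
tail_dir = np.array([0.0, 1.0])

# Deterministic long-tailed stream: one tail sample in every 20 (assumed ratio).
grads = [tail_dir if i % 20 == 0 else head_dir for i in range(1000)]

v = momentum_buffer(grads)
print("cosine(v, head_dir):", cosine(v, head_dir))
print("cosine(v, tail_dir):", cosine(v, tail_dir))
```

Under these assumptions the momentum buffer ends up almost parallel to the head direction, which is one simple way to picture how an accumulated moving average can carry a head-class bias into every update, including those triggered by tail samples.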
Taxonomy
Topics: Domain Adaptation and Few-Shot Learning · Machine Learning and Algorithms · Machine Learning and Data Classification
Methods: Causal inference · Stochastic Gradient Descent
