Squared Earth Mover's Distance-based Loss for Training Deep Neural Networks
Le Hou, Chen-Ping Yu, Dimitris Samaras

TL;DR
This paper introduces a squared Earth Mover's Distance loss for training deep neural networks that accounts for inter-class relationships, improving performance especially on datasets with ordered classes, and includes a method to learn this relationship automatically.
Contribution
It proposes a novel squared EMD loss for deep learning that incorporates class relationships and a method to learn the ground distance matrix during training.
Findings
Achieves state-of-the-art results on datasets with strong inter-class relationships, such as an ordering between classes.
Automatically learning the class dissimilarity matrix yields comparable performance.
Maintains high performance even without strong inter-class relationships.
Abstract
In the context of single-label classification, despite the huge success of deep learning, the commonly used cross-entropy loss function ignores the intricate inter-class relationships that often exist in real-life tasks such as age classification. In this work, we propose to leverage these relationships between classes by training deep nets with the exact squared Earth Mover's Distance (also known as Wasserstein distance) for single-label classification. The squared EMD loss uses the predicted probabilities of all classes and penalizes mispredictions according to a ground distance matrix that quantifies the dissimilarities between classes. We demonstrate that on datasets with strong inter-class relationships, such as an ordering between classes, our exact squared EMD losses yield new state-of-the-art results. Furthermore, we propose a method to automatically learn this matrix using…
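To make the idea in the abstract concrete, here is a minimal sketch of a squared EMD-style loss for the ordered-class case. It is not taken from the paper's code. It assumes the classes lie on a line in index order with unit spacing between neighbors, in which case the loss can be written as a sum of squared differences between the cumulative distributions of the prediction and the one-hot target. The function name `squared_emd_loss` and its parameters are hypothetical, chosen for illustration.

```python
import numpy as np


def squared_emd_loss(probs, target_index):
    """Squared EMD-style loss between predicted probabilities and a one-hot target.

    Assumption (not from the source): classes are ordered by index with unit
    ground distance between adjacent classes, so the loss reduces to the sum
    of squared differences between the two cumulative distributions.
    """
    probs = np.asarray(probs, dtype=float)
    target = np.zeros_like(probs)
    target[target_index] = 1.0
    # Compare cumulative distributions rather than per-class probabilities,
    # so probability mass placed far from the true class costs more.
    cdf_diff = np.cumsum(probs) - np.cumsum(target)
    return float(np.sum(cdf_diff ** 2))


# True class is index 1 out of four ordered classes (e.g. age groups).
exact = squared_emd_loss([0.0, 1.0, 0.0, 0.0], target_index=1)
near = squared_emd_loss([0.0, 0.0, 1.0, 0.0], target_index=1)  # off by one class
far = squared_emd_loss([0.0, 0.0, 0.0, 1.0], target_index=1)   # off by two classes
print(exact, near, far)
```

In this sketch a prediction that misses by two classes is penalized more than one that misses by one, whereas cross-entropy looks only at the probability assigned to the true class and so treats both mistakes identically.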
Taxonomy
Topics: Machine Learning and Data Classification · Domain Adaptation and Few-Shot Learning · Advanced Neural Network Applications
