Optimal Transport Model Distributional Robustness
Van-Anh Nguyen, Trung Le, Anh Tuan Bui, Thanh-Toan Do, and Dinh Phung

TL;DR
This paper introduces an optimal transport-based distributional robustness framework in model space, enabling robust training of deep learning models against adversarial attacks and distribution shifts, with theoretical insights and practical improvements.
Contribution
It develops a novel optimal transport approach in model space, incorporating sharpness awareness and unifying SAM within a probabilistic framework, with extensive empirical validation.
Findings
Framework improves robustness over baselines
Incorporates sharpness awareness into model training
Unifies SAM as a special case of the proposed method
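The findings state that SAM (Sharpness-Aware Minimization) is a special case of the framework. The abstract ties this to choosing a Dirac delta, that is, a single model, as the center distribution. For reference, a standard first-order SAM update looks like the minimal sketch below. It is not the paper's implementation. The toy quadratic loss, the radius `rho`, the learning rate `lr`, and all function names are illustrative assumptions.

```python
import numpy as np

def loss(theta, A, b):
    # Hypothetical toy quadratic loss standing in for a network loss: 0.5 * theta^T A theta - b^T theta
    return 0.5 * theta @ A @ theta - b @ theta

def grad(theta, A, b):
    # Gradient of the toy quadratic loss (A is assumed symmetric)
    return A @ theta - b

def sam_step(theta, A, b, rho=0.05, lr=0.1):
    g = grad(theta, A, b)
    # Ascent step: first-order worst-case perturbation inside an L2 ball of radius rho
    eps = rho * g / (np.linalg.norm(g) + 1e-12)
    # Descent step: gradient taken at the perturbed weights, applied to the original weights
    g_adv = grad(theta + eps, A, b)
    return theta - lr * g_adv

A = np.array([[3.0, 0.0], [0.0, 1.0]])
b = np.array([1.0, 1.0])
theta = np.zeros(2)
for _ in range(200):
    theta = sam_step(theta, A, b)
```

The iterates settle near the minimizer `np.linalg.solve(A, b)`, which is `[1/3, 1]` here, to within a distance on the order of `rho`. In a distributional view, the single parameter vector `theta` plays the role of the Dirac delta center, and `theta + eps` is the worst-case model inside the ball.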
Abstract
Distributional robustness is a promising framework for training deep learning models that are less vulnerable to adversarial examples and data distribution shifts. Previous works have mainly focused on exploiting distributional robustness in the data space. In this work, we explore an optimal transport-based distributional robustness framework in model spaces. Specifically, we examine a model distribution within a Wasserstein ball centered on a given model distribution that maximizes the loss. We have developed theories that enable us to learn the optimal robust center model distribution. Interestingly, our developed theories allow us to flexibly incorporate the concept of sharpness awareness into training, whether it's a single model, ensemble models, or Bayesian Neural Networks, by considering specific forms of the center model distribution. These forms include a Dirac delta…
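The central object in the abstract is a Wasserstein ball over model distributions. Suppose a model distribution is represented by a finite set of equally weighted parameter vectors ("particles", for example an ensemble viewed as a uniform mixture of Dirac deltas). Then the Wasserstein distance between two such sets with the same number of particles reduces to an assignment problem, because an optimal coupling can be taken to be a permutation. The sketch below makes several assumptions that do not come from the paper: uniform weights, equal particle counts, a Euclidean cost raised to the power `p`, and the hypothetical function names `wasserstein_p` and `in_wasserstein_ball`. It uses SciPy's `linear_sum_assignment` as the solver.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

def wasserstein_p(particles_a, particles_b, p=2):
    """W_p between two uniform empirical distributions with the same number of particles."""
    # Pairwise transport cost ||theta_i - theta'_j||^p between model parameter vectors
    diff = particles_a[:, None, :] - particles_b[None, :, :]
    cost = np.linalg.norm(diff, axis=-1) ** p
    # With equal uniform weights, an optimal transport plan is a permutation matrix
    rows, cols = linear_sum_assignment(cost)
    return cost[rows, cols].mean() ** (1.0 / p)

def in_wasserstein_ball(center, candidate, rho, p=2):
    # True if the candidate model distribution lies within radius rho of the center
    return bool(wasserstein_p(center, candidate, p) <= rho)

# Two-particle "ensembles" in a 2-D parameter space (illustrative values)
center = np.array([[0.0, 0.0], [1.0, 1.0]])
candidate = center[::-1] + 0.1  # shift every particle by (0.1, 0.1) and reorder them
```

In this example each particle moves by `(0.1, 0.1)`, so `wasserstein_p(center, candidate)` equals `sqrt(0.02)`, about 0.141, regardless of particle order. The candidate therefore lies inside a ball of radius 0.15 but outside one of radius 0.1.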
Taxonomy
Topics: Adversarial Robustness in Machine Learning · Anomaly Detection Techniques and Applications · Advanced Neural Network Applications
Methods: Sharpness-Aware Minimization
