Out-of-distribution Generalization for Total Variation based Invariant Risk Minimization
Yuanchao Wang, Zhao-Rong Lai, Tianqi Zhong

TL;DR
This paper introduces OOD-TV-IRM, a primal-dual optimization framework that enhances out-of-distribution generalization in invariant risk minimization by balancing training risk and OOD robustness.
Contribution
It extends IRM-TV to a Lagrangian multiplier model, providing a novel primal-dual optimization approach for better OOD generalization.
Findings
OOD-TV-IRM outperforms IRM-TV in most experiments.
The model achieves a semi-Nash equilibrium between training loss and OOD robustness.
A convergent primal-dual algorithm facilitates adversarial learning.
Abstract
Invariant risk minimization is an important general machine learning framework that has recently been interpreted as a total variation model (IRM-TV). However, how to improve out-of-distribution (OOD) generalization in the IRM-TV setting remains unsolved. In this paper, we extend IRM-TV to a Lagrangian multiplier model named OOD-TV-IRM. We find that the autonomous TV penalty hyperparameter is exactly the Lagrangian multiplier. Thus OOD-TV-IRM is essentially a primal-dual optimization model, where the primal optimization minimizes the entire invariant risk and the dual optimization strengthens the TV penalty. The objective is to reach a semi-Nash equilibrium where the balance between the training loss and OOD generalization is maintained. We also develop a convergent primal-dual algorithm that facilitates an adversarial learning scheme. Experimental results show that OOD-TV-IRM…
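The primal-dual idea in the abstract can be illustrated with a toy sketch. This is not the authors' OOD-TV-IRM algorithm. It is a generic projected gradient descent-ascent on a small Lagrangian, and every name and value in it is an assumption made for illustration: the quadratic `risk`, the squared-difference `penalty` standing in for a TV term, the tolerance `eps`, and both step sizes. The primal step lowers the penalized risk, and the dual step raises the multiplier `lam` whenever the penalty exceeds the tolerance.

```python
# Toy primal-dual (gradient descent-ascent) sketch on
#   L(theta, lam) = R(theta) + lam * (P(theta) - eps),  lam >= 0.
# All functions and constants are hypothetical stand-ins, not the paper's model.

TARGET = (1.0, -1.0)  # assumed minimizer of the unpenalized toy risk


def risk_grad(theta):
    """Gradient of the toy risk R(theta) = 0.5 * ||theta - TARGET||^2."""
    return [theta[0] - TARGET[0], theta[1] - TARGET[1]]


def penalty(theta):
    """Toy variation penalty: squared difference of the two coordinates."""
    return (theta[0] - theta[1]) ** 2


def penalty_grad(theta):
    """Gradient of the toy variation penalty."""
    diff = theta[0] - theta[1]
    return [2.0 * diff, -2.0 * diff]


def primal_dual(steps=2000, lr_primal=0.05, lr_dual=0.5, eps=0.04):
    """Alternate a primal descent step with a projected dual ascent step."""
    theta = [0.0, 0.0]
    lam = 0.0
    for _ in range(steps):
        # Primal step: descend on the Lagrangian with the multiplier held fixed.
        g_r = risk_grad(theta)
        g_p = penalty_grad(theta)
        theta = [t - lr_primal * (gr + lam * gp)
                 for t, gr, gp in zip(theta, g_r, g_p)]
        # Dual step: ascend on the multiplier, then project onto lam >= 0.
        lam = max(0.0, lam + lr_dual * (penalty(theta) - eps))
    return theta, lam


theta, lam = primal_dual()
print("theta:", theta, "lam:", lam, "penalty:", penalty(theta))
```

Under these toy settings the iterates settle near theta ≈ (0.1, -0.1) with lam ≈ 2.25, the point where the penalty equals the tolerance and the gradient of the Lagrangian vanishes. The multiplier is learned by the dual step rather than hand-tuned, which is the role the abstract assigns to the TV penalty hyperparameter.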
Taxonomy
Topics: Stochastic Gradient Optimization Techniques · Generative Adversarial Networks and Image Synthesis · Adversarial Robustness in Machine Learning
