TaskMet: Task-Driven Metric Learning for Model Learning
Dishank Bansal, Ricky T. Q. Chen, Mustafa Mukadam, Brandon Amos

TL;DR
TaskMet introduces task-driven metric learning: the downstream task loss is used to learn the loss function the model is trained on, aligning training with downstream task requirements without changing the optimal prediction model.
Contribution
It proposes using task loss signals to learn a metric in the prediction space that parameterizes the model's training loss. This improves downstream task performance while leaving the optimal prediction model unchanged.
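The following sketch shows why learning a metric need not change the optimal predictor. It assumes, in our own notation rather than the paper's, that the metric enters as a Mahalanobis-style squared error with an input-dependent positive-definite matrix Λ_φ(x). Conditioning on x and adding and subtracting μ(x) = E[y | x] gives:

```latex
\ell_\phi(\hat y, y; x) = (\hat y - y)^\top \Lambda_\phi(x)\,(\hat y - y), \qquad \Lambda_\phi(x) \succ 0

\mathbb{E}\left[\ell_\phi(f(x), y; x) \mid x\right]
  = \bigl(f(x) - \mu(x)\bigr)^\top \Lambda_\phi(x)\,\bigl(f(x) - \mu(x)\bigr)
  + \mathbb{E}\left[\bigl(y - \mu(x)\bigr)^\top \Lambda_\phi(x)\,\bigl(y - \mu(x)\bigr) \,\middle|\, x\right],
  \qquad \mu(x) = \mathbb{E}[y \mid x]
```

The cross term drops out because E[y - μ(x) | x] = 0. The first term is zero only at f(x) = μ(x), and the second term does not depend on f. The pointwise minimizer is therefore μ(x) for every φ. Under these assumptions, the metric matters only when the model class cannot represent μ exactly. In that case it determines which prediction errors are cheaper to make.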
Findings
Enhanced performance in portfolio optimization tasks
Improved reinforcement learning in noisy environments
Effective alignment of training with downstream objectives
Abstract
Deep learning models are often deployed in downstream tasks that the training procedure may not be aware of. For example, models trained solely to achieve accurate predictions may struggle to perform well on downstream tasks because seemingly small prediction errors may incur drastic task errors. The standard end-to-end learning approach is to make the task loss differentiable or to introduce a differentiable surrogate that the model can be trained on. In these settings, the task loss needs to be carefully balanced with the prediction loss because they may have conflicting objectives. We propose to take the task loss signal one level deeper than the parameters of the model and use it to learn the parameters of the loss function the model is trained on, which can be done by learning a metric in the prediction space. This approach does not alter the optimal prediction model itself, but…
Taxonomy
Topics: Explainable Artificial Intelligence (XAI) · Adversarial Robustness in Machine Learning · Machine Learning and Data Classification
Methods: Attentive Walk-Aggregating Graph Neural Network
