Provable Multi-Task Representation Learning by Two-Layer ReLU Neural Networks
Liam Collins, Hamed Hassani, Mahdi Soltanolkotabi, Aryan Mokhtari, Sanjay Shakkottai

TL;DR
This paper proves that multi-task pretraining with two-layer ReLU neural networks learns the underlying data projection, recovering the true features where single-task training generally fails, with sample complexity independent of the input dimension.
Contribution
It provides the first theoretical proof that nonlinear multi-task training induces feature learning, specifically recovering data projections in high-dimensional spaces.
Findings
Multi-task pretraining induces a pseudo-contrastive loss.
Gradient-based training recovers the true data projection.
Single-task training generally fails to learn all ground-truth features.
Abstract
An increasingly popular machine learning paradigm is to pretrain a neural network (NN) on many tasks offline, then adapt it to downstream tasks, often by re-training only the last linear layer of the network. This approach yields strong downstream performance in a variety of contexts, demonstrating that multitask pretraining leads to effective feature learning. Although several recent theoretical studies have shown that shallow NNs learn meaningful features when either (i) they are trained on a single task or (ii) they are linear, very little is known about the closer-to-practice case of nonlinear NNs trained on multiple tasks. In this work, we present the first results proving that feature learning occurs during training with a nonlinear model on multiple tasks. Our key insight is that multi-task pretraining induces a pseudo-contrastive loss that favors…
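The pretrain-then-adapt pipeline described in the abstract can be illustrated with a minimal pure-Python sketch. It is an illustration under stated assumptions, not the paper's algorithm or analysis: the squared loss, the plain gradient step, and all function names (`relu_features`, `multitask_step`, `fit_head_only`) are hypothetical choices made here. A two-layer ReLU network shares its first layer `W` across tasks, each task has its own linear head, and downstream adaptation re-trains only a fresh head on frozen features.

```python
def relu_features(W, x):
    """Hidden-layer output: ReLU of each first-layer unit's pre-activation."""
    return [max(0.0, sum(wi * xi for wi, xi in zip(w, x))) for w in W]


def predict(W, a, x):
    """Two-layer network output: linear head a applied to ReLU features."""
    return sum(aj * hj for aj, hj in zip(a, relu_features(W, x)))


def mean_loss(W, a, data):
    """Mean of 0.5 * squared error (an assumed loss, chosen for simplicity)."""
    return sum(0.5 * (predict(W, a, x) - y) ** 2 for x, y in data) / len(data)


def multitask_step(W, heads, tasks, lr):
    """One gradient step on the sum of per-task mean losses.

    W is shared by all tasks; heads[t] is the linear head of task t;
    tasks[t] is a list of (x, y) pairs. Returns updated (W, heads).
    """
    grad_W = [[0.0] * len(W[0]) for _ in W]
    new_heads = []
    for a, data in zip(heads, tasks):
        grad_a = [0.0] * len(a)
        for x, y in data:
            h = relu_features(W, x)
            err = (sum(aj * hj for aj, hj in zip(a, h)) - y) / len(data)
            for j, hj in enumerate(h):
                grad_a[j] += err * hj
                if hj > 0.0:  # ReLU passes gradient only through active units
                    for i, xi in enumerate(x):
                        grad_W[j][i] += err * a[j] * xi
        new_heads.append([aj - lr * g for aj, g in zip(a, grad_a)])
    new_W = [[w_ji - lr * g for w_ji, g in zip(w, gw)] for w, gw in zip(W, grad_W)]
    return new_W, new_heads


def fit_head_only(W, data, steps, lr):
    """Downstream adaptation: keep W frozen and train only a fresh linear head."""
    feats = [(relu_features(W, x), y) for x, y in data]
    a = [0.0] * len(W)
    for _ in range(steps):
        grad = [0.0] * len(a)
        for h, y in feats:
            err = (sum(aj * hj for aj, hj in zip(a, h)) - y) / len(feats)
            for j, hj in enumerate(h):
                grad[j] += err * hj
        a = [aj - lr * g for aj, g in zip(a, grad)]
    return a
```

In this sketch, pretraining would call `multitask_step` repeatedly over all task datasets, and a downstream task would then call `fit_head_only` with the resulting `W`. Because the head-only problem is a least-squares fit over fixed features, it is convex, which is one reason re-training only the last layer is a cheap adaptation step.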
Taxonomy
Topics: Domain Adaptation and Few-Shot Learning · Stochastic Gradient Optimization Techniques · Adversarial Robustness in Machine Learning
Methods: Linear Layer · ALIGN
