On Feature Learning in Neural Networks with Global Convergence Guarantees
Zhengdao Chen, Eric Vanden-Eijnden, Joan Bruna

TL;DR
This paper proves that gradient flow on wide neural networks drives the training loss to zero at a linear rate in setups that allow feature learning. It also shows empirically that its multi-layer model can generalize better than its NTK counterpart.
Contribution
It provides a non-asymptotic global convergence analysis of gradient-flow training for wide NNs in setups that allow feature learning. The analysis covers both shallow and multi-layer models.
Findings
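Under mean-field scaling, the training loss of wide shallow NNs with a general class of activation functions converges to zero at a linear rate, provided the input dimension is at least the size of the training set.
For a model of wide multi-layer NNs whose second-to-last layer is trained via gradient flow, the training loss also converges to zero at a linear rate, regardless of the input dimension.
Empirically, unlike in the NTK regime, this multi-layer model exhibits feature learning and can achieve better generalization than its NTK counterpart.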
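The display below gives a reference point for the findings above. It writes out the mean-field and NTK parametrizations of a width-$m$ shallow network and states what a linear rate means for the loss under gradient flow. The notation is introduced here for illustration and is not copied from the paper; this includes $a_i$, $w_i$, $\sigma$, the $\frac{1}{2n}$ loss normalization and the constant $c$.

```latex
\begin{aligned}
f^{\mathrm{MF}}_m(x) &= \frac{1}{m}\sum_{i=1}^{m} a_i\,\sigma(w_i^\top x),
\qquad
f^{\mathrm{NTK}}_m(x) = \frac{1}{\sqrt{m}}\sum_{i=1}^{m} a_i\,\sigma(w_i^\top x),\\
L(t) &= \frac{1}{2n}\sum_{k=1}^{n}\bigl(f_m(x_k;\theta(t)) - y_k\bigr)^2,
\qquad \dot\theta(t) = -\nabla_\theta L(\theta(t)),\\
\text{linear rate:}\quad L(t) &\le L(0)\,e^{-ct}\quad \text{for some } c>0 \text{ and all } t \ge 0.
\end{aligned}
```

Under the $1/m$ scaling, standard analyses run gradient flow on a time scale multiplied by $m$, so that each neuron moves by $O(1)$ and the hidden features can change. Under the $1/\sqrt{m}$ scaling, by contrast, the weights stay close to their initialization as $m \to \infty$, which is the lazy (NTK) regime.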
Abstract
We study the optimization of wide neural networks (NNs) via gradient flow (GF) in setups that allow feature learning while admitting non-asymptotic global convergence guarantees. First, for wide shallow NNs under the mean-field scaling and with a general class of activation functions, we prove that when the input dimension is no less than the size of the training set, the training loss converges to zero at a linear rate under GF. Building upon this analysis, we study a model of wide multi-layer NNs whose second-to-last layer is trained via GF, for which we also prove a linear-rate convergence of the training loss to zero, but regardless of the input dimension. We also show empirically that, unlike in the Neural Tangent Kernel (NTK) regime, our multi-layer model exhibits feature learning and can achieve better generalization performance than its NTK counterpart.
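The shallow-network claim can be probed numerically with a small sketch. It is not the authors' code or their exact setup. The following are all illustrative assumptions:

- the hypothetical function `train_mean_field_shallow`,
- the tanh activation,
- training both layers,
- rescaling each neuron's gradient by $m$ (the mean-field time scaling),
- the step size, seed and random data.

Full-batch gradient descent with a small step stands in for gradient flow. The inputs satisfy $n \le d$, as in the abstract's condition.

```python
import math
import random


def train_mean_field_shallow(n=5, d=8, m=100, lr=1.5, steps=400, seed=0):
    """Full-batch gradient descent (a forward-Euler discretization of gradient
    flow) on f(x) = (1/m) * sum_i a_i * tanh(w_i . x) with squared loss.

    Illustrative assumptions, not taken from the paper: tanh activation,
    both layers trained, per-neuron gradients rescaled by m (mean-field
    time scaling), and n <= d as in the shallow-case condition.
    Returns the training loss recorded before each step.
    """
    assert n <= d, "the shallow-case guarantee assumes input dim >= n"
    rng = random.Random(seed)
    # Training inputs on the unit sphere in R^d, random targets in [-1, 1].
    xs = []
    for _ in range(n):
        v = [rng.gauss(0.0, 1.0) for _ in range(d)]
        norm = math.sqrt(sum(t * t for t in v))
        xs.append([t / norm for t in v])
    ys = [rng.uniform(-1.0, 1.0) for _ in range(n)]

    # Every neuron starts at O(1) size; the 1/m prefactor in f does the averaging.
    a = [rng.gauss(0.0, 1.0) for _ in range(m)]
    w = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(m)]

    losses = []
    for _ in range(steps):
        # th[i][k] = tanh(w_i . x_k); then outputs f(x_k) and residuals r_k.
        th = [[math.tanh(sum(wi[j] * x[j] for j in range(d))) for x in xs] for wi in w]
        f = [sum(a[i] * th[i][k] for i in range(m)) / m for k in range(n)]
        r = [f[k] - ys[k] for k in range(n)]
        losses.append(sum(rk * rk for rk in r) / (2 * n))

        # Update each neuron with m * dL/da_i and m * dL/dw_i (old values used).
        for i in range(m):
            grad_a = sum(r[k] * th[i][k] for k in range(n)) / n
            grad_w = [0.0] * d
            for k in range(n):
                coef = r[k] * a[i] * (1.0 - th[i][k] ** 2) / n
                for j in range(d):
                    grad_w[j] += coef * xs[k][j]
            a[i] -= lr * grad_a
            for j in range(d):
                w[i][j] -= lr * grad_w[j]
    return losses
```

In this configuration, `losses[-1]` should be several orders of magnitude below `losses[0]`. Plotting the log of the loss against the step index is a quick way to eyeball whether the decay in a given run looks like a linear rate. The sketch sticks to the standard library to stay self-contained; a NumPy port would mainly vectorize the per-neuron loops.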
Taxonomy
Topics: Stochastic Gradient Optimization Techniques · Model Reduction and Neural Networks · Neural Networks and Applications
Methods: Neural Tangent Kernel
