Fixing the NTK: From Neural Network Linearizations to Exact Convex Programs
Rajat Vadiraj Dwaraknath, Tolga Ergen, Mert Pilanci

TL;DR
This paper bridges neural tangent kernels and convex reformulations of ReLU networks, showing how to optimize the NTK via multiple kernel learning to achieve exact convex solutions with improved predictive performance.
Contribution
It establishes a connection between the NTK and multiple kernel learning (MKL) for gated ReLU networks, enabling exact convex optimization and improved kernel-based training.
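The NTK and MKL connection can be checked numerically with a minimal sketch. It relies on assumptions this summary does not state. The NTK is taken with respect to the first-layer weights of a gated ReLU layer f(x) = (1/sqrt(m)) Σ_j 1[g_j^T x ≥ 0] x^T w_j with unit output weights and Gaussian gate vectors g_j ~ N(0, I). Under these assumptions, the infinite-width Gram matrix has the closed form x_a^T x_b (π − θ_ab)/(2π), where θ_ab is the angle between x_a and x_b. The sketch weights each activation pattern D_i by the fraction of sampled gates that produce it. The resulting MKL kernel Σ_i η_i D_i X X^T D_i should then approach that closed form. The helper names (ntk_gram_analytic, mkl_gram, ntk_mask_weights) are hypothetical.

```python
import numpy as np


def ntk_gram_analytic(X):
    """Infinite-width NTK Gram matrix of a gated ReLU layer (first-layer weights only),
    assuming Gaussian gates: K[a, b] = x_a^T x_b * (pi - theta_ab) / (2 pi)."""
    G = X @ X.T
    norms = np.linalg.norm(X, axis=1)
    cos = np.clip(G / np.outer(norms, norms), -1.0, 1.0)
    theta = np.arccos(cos)
    return G * (np.pi - theta) / (2.0 * np.pi)


def mkl_gram(X, masks, eta):
    """MKL Gram matrix sum_i eta_i D_i X X^T D_i, where masks[:, i] is the diagonal of D_i."""
    # (D_i X X^T D_i)[a, b] = m_ai * m_bi * x_a^T x_b, so the weighted sum factorizes.
    return ((masks * eta) @ masks.T) * (X @ X.T)


def ntk_mask_weights(X, num_gates, rng):
    """Sample Gaussian gates, keep the distinct activation patterns, and weight each
    pattern by the fraction of gates that produce it (hypothetical helper)."""
    H = rng.standard_normal((X.shape[1], num_gates))
    patterns = (X @ H >= 0).astype(float)  # shape (n, num_gates)
    masks, counts = np.unique(patterns, axis=1, return_counts=True)
    return masks, counts / num_gates


rng = np.random.default_rng(0)
X = rng.standard_normal((6, 3))
masks, eta = ntk_mask_weights(X, 200_000, rng)
K_mkl = mkl_gram(X, masks, eta)
K_ntk = ntk_gram_analytic(X)
print("patterns:", masks.shape[1], "max entrywise gap:", np.max(np.abs(K_mkl - K_ntk)))
```

The gap shrinks at the usual Monte Carlo rate as more gates are sampled. The exact mask weights would be the Gaussian measure of each hyperplane-arrangement cell.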
Findings
The NTK is equivalent to a specific MKL kernel with fixed mask weights.
Iteratively reweighting these mask weights yields the optimal MKL kernel, which matches the solution of the convex reformulation.
Numerical simulations support the theoretical connection and the resulting improvements.
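The reweighting finding can be sketched with a standard variational update for group-lasso penalties. This update may differ in detail from the paper's own. For fixed mask weights η, fitting is kernel ridge regression with K(η) = Σ_i η_i D_i X X^T D_i. The implied first-layer weights are w_i = η_i X^T D_i α, where α = (K(η) + λI)^{-1} y. Setting η_i ← ||w_i||_2 then re-minimizes over η. The sketch starts from NTK-style weights, namely the pattern frequencies of sampled Gaussian gates. From there, the group-lasso objective of the gated ReLU convex program should not increase from one iteration to the next. Several choices here are illustrative assumptions:
- the masks come only from sampled gates, which gives a subsampled arrangement set;
- λ, the synthetic data, and the iteration count are arbitrary;
- the function names are hypothetical.

```python
import numpy as np


def mkl_gram(X, masks, eta):
    """K(eta) = sum_i eta_i D_i X X^T D_i, where masks[:, i] is the diagonal of D_i."""
    return ((masks * eta) @ masks.T) * (X @ X.T)


def reweight_mask_weights(X, y, masks, eta0, lam, iters):
    """Alternate kernel ridge regression with the update eta_i <- ||w_i||_2 (hypothetical helper).

    Records the group-lasso objective
    0.5 * ||sum_i D_i X w_i - y||^2 + lam * sum_i ||w_i||_2 after each ridge step.
    """
    eta = eta0.copy()
    history = []
    for _ in range(iters):
        K = mkl_gram(X, masks, eta)
        alpha = np.linalg.solve(K + lam * np.eye(len(y)), y)
        # w_i = eta_i X^T D_i alpha, so ||w_i|| = eta_i * ||X^T D_i alpha||.
        w_norms = eta * np.linalg.norm(X.T @ (masks * alpha[:, None]), axis=0)
        # The predictions sum_i D_i X w_i equal K @ alpha.
        history.append(0.5 * np.sum((K @ alpha - y) ** 2) + lam * np.sum(w_norms))
        # eta_i = ||w_i|| minimizes (||w_i||^2 / eta_i + eta_i) / 2 over eta_i > 0.
        eta = w_norms
    return eta, history


rng = np.random.default_rng(1)
n, d = 20, 3
X = rng.standard_normal((n, d))
y = np.maximum(X @ rng.standard_normal(d), 0.0) + 0.1 * rng.standard_normal(n)

# NTK-style starting weights: frequency of each activation pattern under Gaussian gates.
patterns = (X @ rng.standard_normal((d, 5000)) >= 0).astype(float)
masks, counts = np.unique(patterns, axis=1, return_counts=True)
eta_ntk = counts / counts.sum()

eta, history = reweight_mask_weights(X, y, masks, eta_ntk, lam=0.1, iters=50)
print(f"objective: NTK weights {history[0]:.4f} -> reweighted {history[-1]:.4f}")
```

Each step exactly minimizes the same upper bound, either over w or over η. The recorded objective therefore cannot increase. Patterns whose weights shrink toward zero are effectively pruned, and this is how the reweighted kernel departs from the fixed NTK weighting.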
Abstract
Recently, theoretical analyses of deep neural networks have broadly focused on two directions: 1) Providing insight into neural network training by SGD in the limit of infinite hidden-layer width and infinitesimally small learning rate (also known as gradient flow) via the Neural Tangent Kernel (NTK), and 2) Globally optimizing the regularized training objective via cone-constrained convex reformulations of ReLU networks. The latter research direction also yielded an alternative formulation of the ReLU network, called a gated ReLU network, that is globally optimizable via efficient unconstrained convex programs. In this work, we interpret the convex program for this gated ReLU network as a Multiple Kernel Learning (MKL) model with a weighted data masking feature map and establish a connection to the NTK. Specifically, we show that for a particular choice of mask weights that do not…
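The abstract's MKL interpretation can be written out as a worked reference. This sketch assumes the group-lasso form of the gated ReLU convex program from prior convex-reformulation work, with gate vectors g_i, diagonal masks D_i, and regularization λ > 0. The paper's exact notation and scaling may differ. First, each group norm is replaced by its variational form. Then minimizing over the weights w_i leaves a problem in the mask weights η alone. The kernel K(η) in that problem is induced by the weighted data masking feature map φ_η.

```latex
\begin{aligned}
&\min_{w_1,\dots,w_p}\ \tfrac12\Big\|\sum_{i=1}^{p} D_i X w_i - y\Big\|_2^2 + \lambda \sum_{i=1}^{p} \|w_i\|_2,
\qquad D_i = \operatorname{diag}\big(\mathbb{1}[X g_i \ge 0]\big),\\[4pt]
&\|w_i\|_2 = \min_{\eta_i > 0}\ \tfrac12\Big(\frac{\|w_i\|_2^2}{\eta_i} + \eta_i\Big),\\[4pt]
&\Longrightarrow\ \min_{\eta \ge 0}\ \frac{\lambda}{2}\Big( y^\top \big(K(\eta) + \lambda I\big)^{-1} y + \sum_{i=1}^{p} \eta_i \Big),
\qquad K(\eta) = \sum_{i=1}^{p} \eta_i D_i X X^\top D_i,\\[4pt]
&\phi_\eta(x) = \big(\sqrt{\eta_1}\,\mathbb{1}[g_1^\top x \ge 0]\,x,\ \dots,\ \sqrt{\eta_p}\,\mathbb{1}[g_p^\top x \ge 0]\,x\big).
\end{aligned}
```

Since φ_η(x)^T φ_η(x') = Σ_i η_i 1[g_i^T x ≥ 0] 1[g_i^T x' ≥ 0] x^T x', each η defines a valid kernel. Fixing η reduces the problem to kernel ridge regression with K(η).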
Taxonomy
Topics: Sparse and Compressive Sensing Techniques · Domain Adaptation and Few-Shot Learning · Stochastic Gradient Optimization Techniques
Methods: Neural Tangent Kernel · Stochastic Gradient Descent
