A projection-based framework for gradient-free and parallel learning
Andreas Bergmeister, Manish Krishan Lal, Stefanie Jegelka, Suvrit Sra

TL;DR
This paper introduces a projection-based, gradient-free framework for neural network training. It relies on feasibility algorithms, which enable parallelism and handle non-differentiable operations, and it is demonstrated on several architectures and benchmarks.
Contribution
It presents a novel feasibility-seeking approach and a JAX-based software framework, PJAX, for parallel, gradient-free neural network training.
Findings
Successfully trained diverse architectures on standard benchmarks.
Demonstrated advantages in parallelism and non-differentiable operations.
Provided a flexible, GPU/TPU-compatible implementation.
Abstract
We present a feasibility-seeking approach to neural network training. This mathematical optimization framework is distinct from conventional gradient-based loss minimization and uses projection operators and iterative projection algorithms. We reformulate training as a large-scale feasibility problem: finding network parameters and states that satisfy local constraints derived from its elementary operations. Training then involves projecting onto these constraints, a local operation that can be parallelized across the network. We introduce PJAX, a JAX-based software framework that enables this paradigm. PJAX composes projection operators for elementary operations, automatically deriving the solution operators for the feasibility problems (akin to autodiff for derivatives). It inherently supports GPU/TPU acceleration, provides a familiar NumPy-like API, and is extensible. We train…
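To make the idea of a local constraint derived from an elementary operation concrete, here is a minimal sketch of a projection onto the graph of a single non-differentiable operation, ReLU. It is written in plain JAX and is not the PJAX API; the function name `project_relu_graph` and the sample inputs are hypothetical, chosen only for illustration. The graph of ReLU is the union of two rays, so the sketch projects onto each ray and keeps the closer point.

```python
import jax.numpy as jnp

def project_relu_graph(x0, y0):
    """Elementwise Euclidean projection of (x0, y0) onto {(x, y): y = max(x, 0)}.

    Illustrative helper only; not taken from PJAX.
    """
    # Branch 1: the ray {(x, 0): x <= 0}.
    x1 = jnp.minimum(x0, 0.0)
    y1 = jnp.zeros_like(y0)
    d1 = (x0 - x1) ** 2 + (y0 - y1) ** 2

    # Branch 2: the ray {(t, t): t >= 0}.
    t = jnp.maximum((x0 + y0) / 2.0, 0.0)
    d2 = (x0 - t) ** 2 + (y0 - t) ** 2

    # Keep whichever branch is closer to the input point.
    use_first = d1 <= d2
    return jnp.where(use_first, x1, t), jnp.where(use_first, y1, t)

# Hypothetical pre-activation / post-activation pairs that violate y = relu(x).
x0 = jnp.array([-2.0, 1.0, 3.0])
y0 = jnp.array([0.5, 1.0, 0.0])
x, y = project_relu_graph(x0, y0)
```

The projection acts independently on each pair of values, which is the kind of locality that lets such operations run in parallel across a network. No derivative of ReLU is needed at any point.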
