
A projection-based framework for gradient-free and parallel learning

About

We present a feasibility-seeking approach to neural network training. This mathematical optimization framework is distinct from conventional gradient-based loss minimization: it uses projection operators and iterative projection algorithms. We reformulate training as a large-scale feasibility problem: finding network parameters and states that satisfy local constraints derived from the network's elementary operations. Training then consists of projecting onto these constraints, a local operation that can be parallelized across the network. We introduce PJAX, a JAX-based software framework that enables this paradigm. PJAX composes projection operators for elementary operations, automatically deriving the solution operators for the feasibility problems (akin to autodiff for derivatives). It inherently supports GPU/TPU acceleration, provides a familiar NumPy-like API, and is extensible. We train diverse architectures (MLPs, CNNs, RNNs) on standard benchmarks using PJAX, demonstrating its functionality and generality. Our results show that this approach is a compelling alternative to gradient-based training, with clear advantages in parallelism and the ability to handle non-differentiable operations.
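To make the feasibility formulation concrete, here is a minimal sketch in plain JAX. It is not PJAX's API, and the function names (`project_linear`, `project_relu`, `averaged_projection_step`, `residual`) are hypothetical. It assumes a toy one-unit network with fixed inputs `X`, parameters `w`, pre-activation state `y`, post-activation state `z`, and targets `t`, which gives three local constraints: y = Xw, z = relu(y), and z = t. Each constraint has a closed-form Euclidean projection, including the non-differentiable ReLU, whose graph is a union of two rays. The iterative algorithm shown is the classical method of averaged projections, chosen here for illustration; the abstract does not say which projection algorithm PJAX uses. Because the ReLU graph is nonconvex, the iteration is not guaranteed to reach a feasible point.

```python
import jax
import jax.numpy as jnp


def project_linear(w0, y0, X):
    """Project (w0, y0) onto the graph {(w, y) : y = X @ w}.

    Minimizing ||w - w0||^2 + ||X @ w - y0||^2 gives the normal equations
    (I + X^T X) w = w0 + X^T y0; the projected state is then y = X @ w.
    """
    n = X.shape[1]
    w = jnp.linalg.solve(jnp.eye(n) + X.T @ X, w0 + X.T @ y0)
    return w, X @ w


def project_relu(y0, z0):
    """Elementwise projection of (y0, z0) onto the graph {(y, z) : z = relu(y)}.

    The graph is the union of two rays, so project onto each ray and keep the
    nearer point. No derivative of relu is used.
    """
    # Ray A: {(y, 0) : y <= 0}
    ya, za = jnp.minimum(y0, 0.0), jnp.zeros_like(z0)
    # Ray B: {(s, s) : s >= 0}; projection onto the diagonal, clipped at 0
    s = jnp.maximum((y0 + z0) / 2.0, 0.0)
    da = (y0 - ya) ** 2 + (z0 - za) ** 2
    db = (y0 - s) ** 2 + (z0 - s) ** 2
    use_a = da <= db
    return jnp.where(use_a, ya, s), jnp.where(use_a, za, s)


def averaged_projection_step(w, y, z, X, t):
    """One step of the method of averaged projections on three constraint sets.

    C_lin: y = X @ w, C_act: z = relu(y), C_out: z = t. Each projection only
    reads the current point, so the three could be evaluated in parallel.
    Variables a constraint does not involve are passed through unchanged.
    """
    w1, y1 = project_linear(w, y, X)  # C_lin (z unchanged)
    y2, z2 = project_relu(y, z)       # C_act (w unchanged)
    z3 = t                            # C_out (w, y unchanged)
    return (w1 + w + w) / 3.0, (y1 + y2 + y) / 3.0, (z + z2 + z3) / 3.0


def residual(w, y, z, X, t):
    """Total squared violation of the three constraints."""
    return (jnp.sum((y - X @ w) ** 2)
            + jnp.sum((z - jax.nn.relu(y)) ** 2)
            + jnp.sum((z - t) ** 2))


# Toy problem: targets generated by a ReLU unit, so a feasible point exists.
X = jax.random.normal(jax.random.PRNGKey(0), (8, 3))
t = jax.nn.relu(X @ jnp.array([1.0, -2.0, 0.5]))
w, y, z = jnp.zeros(3), jnp.zeros(8), jnp.zeros(8)

step = jax.jit(averaged_projection_step)
for _ in range(200):
    w, y, z = step(w, y, z, X, t)
print("constraint violation:", residual(w, y, z, X, t))
```

Replacing the average with sequential projections (C_lin, then C_act, then C_out) gives the cyclic variant. The averaged form is shown because its three projections are independent of one another, which is the kind of locality that makes projection-based training parallelizable.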

Andreas Bergmeister, Manish Krishan Lal, Stefanie Jegelka, Suvrit Sra · 2025

Related benchmarks

Task                              | Dataset             | Accuracy (%) | Rank
--------------------------------- | ------------------- | ------------ | ----
Image Classification              | CIFAR-10 (train)    | 55.6         | 144
Image Classification              | MNIST (train)       | 96.9         | 107
Image Classification              | MNIST (test)        | 95.8         | 44
Classification                    | Higgs (train)       | 67           | 21
Image Classification              | CIFAR-10 (test)     | 33.7         | 19
Character-level Language Modeling | Shakespeare (test)  | 32.4         | 12
Character-level Language Modeling | Shakespeare (train) | 34           | 12
