
Permutation Equivariant Neural Functionals

About

This work studies the design of neural networks that can process the weights or gradients of other neural networks, which we refer to as neural functional networks (NFNs). Despite a wide range of potential applications, including learned optimization, processing implicit neural representations, network editing, and policy evaluation, there are few unifying principles for designing effective architectures that process the weights of other networks. We approach the design of neural functionals through the lens of symmetry, in particular by focusing on the permutation symmetries that arise in the weights of deep feedforward networks because hidden layer neurons have no inherent order. We introduce a framework for building permutation equivariant neural functionals, whose architectures encode these symmetries as an inductive bias. The key building blocks of this framework are NF-Layers (neural functional layers) that we constrain to be permutation equivariant through an appropriate parameter sharing scheme. In our experiments, we find that permutation equivariant neural functionals are effective on a diverse set of tasks that require processing the weights of MLPs and CNNs, such as predicting classifier generalization, producing "winning ticket" sparsity masks for initializations, and classifying or editing implicit neural representations (INRs). In addition, we provide code for our models and experiments at https://github.com/AllanYangZhou/nfn.
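To make the core idea concrete, here is a minimal sketch of permutation equivariance via parameter sharing, in the DeepSets style: every row (neuron) is transformed by the same weight matrix, plus a shared term computed from a permutation-invariant pooling. This is an illustrative toy, not the paper's NF-Layer, which shares parameters over the richer symmetry group acting on full MLP/CNN weight spaces; the names `A`, `B`, and `equivariant_layer` are chosen here for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d_in, d_out = 5, 3, 4

# Two shared weight matrices: one applied to each row independently,
# one applied to the mean-pooled (permutation-invariant) summary.
A = rng.normal(size=(d_in, d_out))
B = rng.normal(size=(d_in, d_out))

def equivariant_layer(x):
    # Same A for every row + pooled term: permuting rows of x
    # permutes rows of the output identically.
    return x @ A + x.mean(axis=0, keepdims=True) @ B

x = rng.normal(size=(n, d_in))
perm = rng.permutation(n)

# Equivariance check: layer(P x) == P layer(x)
assert np.allclose(equivariant_layer(x[perm]), equivariant_layer(x)[perm])
```

Because the pooled term is invariant to row order and the per-row map is shared, the constraint holds by construction, which is the same "appropriate parameter sharing" principle the NF-Layers encode for weight-space symmetries.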

Allan Zhou, Kaien Yang, Kaylee Burns, Adriano Cardace, Yiding Jiang, Samuel Sokota, J. Zico Kolter, Chelsea Finn • 2023

Related benchmarks

Task                             Dataset                                         Metric         Result  Rank
Image Classification             MNIST (test)                                    Accuracy       95      882
Image Classification             Fashion MNIST (test)                            Accuracy       68.94   568
Classification                   CIFAR10 (test)                                  Accuracy       33.41   266
Image Classification             FashionMNIST (test)                             Accuracy       75.6    218
Image Classification             MNIST-10 (test)                                 Test Accuracy  92.9    19
3D Object Classification         ScanNet 10                                      Accuracy       65.9    17
INR classification               F-MNIST Implicit Neural Representations (test)  Accuracy       68.94   15
Weight-space INR classification  MNIST (test)                                    Test Accuracy  92.9    13
INR editing (dilate)             MNIST (test)                                    MSE            0.068   8
INR Dilation                     MNIST INR                                       MSE            2.55    8

(Showing 10 of 31 benchmark rows.)

Other info

Code

https://github.com/AllanYangZhou/nfn
