Neural Functional Transformers

About

The recent success of neural networks as implicit representations of data has driven growing interest in neural functionals: models that can process other neural networks as input by operating directly over their weight spaces. Nevertheless, constructing expressive and efficient neural functional architectures that can handle high-dimensional weight-space objects remains challenging. This paper uses the attention mechanism to define a novel set of permutation equivariant weight-space layers and composes them into deep equivariant models called neural functional Transformers (NFTs). NFTs respect weight-space permutation symmetries while incorporating the advantages of attention, which has exhibited remarkable success across multiple domains. In experiments processing the weights of feedforward MLPs and CNNs, we find that NFTs match or exceed the performance of prior weight-space methods. We also leverage NFTs to develop Inr2Array, a novel method for computing permutation invariant latent representations from the weights of implicit neural representations (INRs). Our proposed method improves INR classification accuracy by up to +17% over existing methods. We provide an implementation of our layers at https://github.com/AllanYangZhou/nfn.
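The symmetry the abstract relies on can be checked numerically. Reordering the hidden neurons of an MLP leaves the function it computes unchanged, and softmax attention without positional encodings, applied across those neurons, permutes its outputs in exactly the same way. The NumPy sketch below is illustrative only and is not the NFT implementation from the linked repository. The two-layer MLP, the per-neuron tokenization in `neuron_tokens`, and the single-head `self_attention` are hypothetical choices made for this example.

```python
import numpy as np

rng = np.random.default_rng(0)

def mlp_forward(params, x):
    """Two-layer MLP: x -> relu(W1 x + b1) -> W2 h + b2."""
    W1, b1, W2, b2 = params
    h = np.maximum(W1 @ x + b1, 0.0)
    return W2 @ h + b2

def permute_hidden(params, perm):
    """Reorder hidden neurons: rows of W1 and b1 plus the matching columns of W2."""
    W1, b1, W2, b2 = params
    return W1[perm], b1[perm], W2[:, perm], b2

def neuron_tokens(params):
    """Hypothetical tokenization (illustration only): one token per hidden neuron,
    built from its incoming weights, its bias, and its outgoing weights."""
    W1, b1, W2, _ = params
    return np.concatenate([W1, b1[:, None], W2.T], axis=1)  # shape (hidden, d_in + 1 + d_out)

def self_attention(tokens, Wq, Wk, Wv):
    """Illustrative single-head softmax self-attention with no positional encoding."""
    q, k, v = tokens @ Wq, tokens @ Wk, tokens @ Wv
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)  # subtract row max for numerical stability
    attn = np.exp(scores)
    attn /= attn.sum(axis=-1, keepdims=True)
    return attn @ v

d_in, hidden, d_out, d_model = 3, 8, 2, 16
params = (rng.normal(size=(hidden, d_in)), rng.normal(size=hidden),
          rng.normal(size=(d_out, hidden)), rng.normal(size=d_out))
perm = rng.permutation(hidden)
permuted = permute_hidden(params, perm)

# 1) Symmetry: the original and permuted weights compute the same function.
x = rng.normal(size=d_in)
same_function = np.allclose(mlp_forward(params, x), mlp_forward(permuted, x))

# 2) Equivariance: permuting the neurons permutes the attention outputs identically.
d_tok = d_in + 1 + d_out
Wq, Wk, Wv = (rng.normal(size=(d_tok, d_model)) for _ in range(3))
out = self_attention(neuron_tokens(params), Wq, Wk, Wv)
out_perm = self_attention(neuron_tokens(permuted), Wq, Wk, Wv)
equivariant = np.allclose(out[perm], out_perm)

# 3) Invariance: mean-pooling equivariant features gives a permutation invariant summary.
invariant = np.allclose(out.mean(axis=0), out_perm.mean(axis=0))

print(same_function, equivariant, invariant)  # True True True
```

Mean pooling in step 3 is one simple way to turn equivariant features into an invariant summary. It is shown only to illustrate the kind of invariance that Inr2Array targets, not how Inr2Array is built.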

Allan Zhou, Kaien Yang, Yiding Jiang, Kaylee Burns, Winnie Xu, Samuel Sokota, J. Zico Kolter, Chelsea Finn • 2023

Related benchmarks

Task | Dataset | Metric | Result | Leaderboard rank
Weight-space INR classification | MNIST (test) | Test accuracy (%) | 98.5 | 13
INR editing (dilate) | MNIST (test) | MSE | 0.051 | 8
Weight-space INR classification | FashionMNIST (test) | Test accuracy (%) | 79.3 | 5
Weight-space INR classification | CIFAR-10 (test) | Test accuracy (%) | 63.4 | 4
Predicting CNN classifier generalization | Small CNN Zoo CIFAR-10-GS (test) | Rank correlation (Kendall's tau) | 0.926 | 4
Predicting CNN classifier generalization | Small CNN Zoo SVHN-GS (test) | Rank correlation (Kendall's tau) | 0.858 | 4
INR editing (contrast) | CIFAR (test) | MSE | 0.02 | 3
INR editing (erode) | MNIST (test) | MSE | 0.0194 | 3
INR editing (gradient) | MNIST (test) | MSE | 0.0484 | 3
INR editing (gradient) | FashionMNIST (test) | MSE | 0.08 | 3
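For the generalization-prediction rows, the rank correlation measures how well the ordering of predicted test accuracies agrees with the ordering of measured ones. The sketch below computes Kendall's tau-a in plain Python under the assumption that there are no ties. This page does not say which tau variant the benchmark uses (tau-b, for example, corrects for ties), and the accuracy values are hypothetical.

```python
from itertools import combinations

def kendall_tau(predicted, actual):
    """Kendall's tau-a: (concordant - discordant) / number of pairs.
    Assumes no ties; tied pairs are simply counted as neither."""
    if len(predicted) != len(actual) or len(predicted) < 2:
        raise ValueError("need two equal-length sequences with at least 2 items")
    concordant = discordant = 0
    for i, j in combinations(range(len(actual)), 2):
        # Same sign of difference in both orderings means the pair is concordant.
        sign = (predicted[i] - predicted[j]) * (actual[i] - actual[j])
        if sign > 0:
            concordant += 1
        elif sign < 0:
            discordant += 1
    n_pairs = len(actual) * (len(actual) - 1) // 2
    return (concordant - discordant) / n_pairs

# Hypothetical predicted vs. measured test accuracies for four CNNs.
predicted = [0.61, 0.55, 0.72, 0.40]
actual = [0.56, 0.57, 0.70, 0.45]
print(kendall_tau(predicted, actual))  # 5 concordant, 1 discordant pair -> 4/6, about 0.667
```

A tau of 1 means the predicted ranking matches the true ranking exactly, and -1 means it is fully reversed.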
