Equivariant Architectures for Learning in Deep Weight Spaces
About
Designing machine learning architectures for processing neural networks in their raw weight matrix form is a newly introduced research direction. Unfortunately, the unique symmetry structure of deep weight spaces makes this design very challenging. If successful, such architectures would be capable of performing a wide range of intriguing tasks, from adapting a pre-trained network to a new domain to editing objects represented as functions (INRs or NeRFs). As a first step towards this goal, we present here a novel network architecture for learning in deep weight spaces. It takes as input a concatenation of weights and biases of a pre-trained MLP and processes it using a composition of layers that are equivariant to the natural permutation symmetry of the MLP's weights: Changing the order of neurons in intermediate layers of the MLP does not affect the function it represents. We provide a full characterization of all affine equivariant and invariant layers for these symmetries and show how these layers can be implemented using three basic operations: pooling, broadcasting, and fully connected layers applied to the input in an appropriate manner. We demonstrate the effectiveness of our architecture and its advantages over natural baselines in a variety of learning tasks.
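The permutation symmetry described above can be checked numerically. The following is a minimal sketch, not the paper's implementation: it assumes a two-layer MLP with a tanh activation, and the helper names (`mlp_forward`, `permute_hidden`, `equivariant_layer`) are hypothetical. The last helper is a simplified DeepSets-style layer that only illustrates how pooling and broadcasting can produce a permutation-equivariant map; it is not the full layer characterization given in the paper.

```python
import math
import random


def mlp_forward(x, W1, b1, W2, b2):
    """Two-layer MLP on plain lists: y = W2 * tanh(W1 * x + b1) + b2."""
    h = [math.tanh(sum(w * xi for w, xi in zip(row, x)) + b)
         for row, b in zip(W1, b1)]
    return [sum(w * hi for w, hi in zip(row, h)) + b
            for row, b in zip(W2, b2)]


def permute_hidden(W1, b1, W2, perm):
    """Reorder hidden neurons: rows of W1 and b1, and matching columns of W2."""
    W1p = [W1[j] for j in perm]
    b1p = [b1[j] for j in perm]
    W2p = [[row[j] for j in perm] for row in W2]
    return W1p, b1p, W2p


def equivariant_layer(features, a, b):
    """Illustrative (assumed) equivariant map on per-neuron scalars:
    f_i -> a * f_i + b * mean(f)."""
    pooled = sum(features) / len(features)          # pooling over neurons
    return [a * f + b * pooled for f in features]   # broadcast pooled value back


rng = random.Random(0)
d_in, d_hidden, d_out = 3, 5, 2
W1 = [[rng.uniform(-1, 1) for _ in range(d_in)] for _ in range(d_hidden)]
b1 = [rng.uniform(-1, 1) for _ in range(d_hidden)]
W2 = [[rng.uniform(-1, 1) for _ in range(d_hidden)] for _ in range(d_out)]
b2 = [rng.uniform(-1, 1) for _ in range(d_out)]
x = [rng.uniform(-1, 1) for _ in range(d_in)]

# Pick a random reordering of the hidden neurons.
perm = list(range(d_hidden))
rng.shuffle(perm)
W1p, b1p, W2p = permute_hidden(W1, b1, W2, perm)

# Invariance of the represented function: outputs agree up to float rounding.
y = mlp_forward(x, W1, b1, W2, b2)
yp = mlp_forward(x, W1p, b1p, W2p, b2)
max_diff = max(abs(u - v) for u, v in zip(y, yp))

# Equivariance of the toy layer: permuting the input permutes the output.
out = equivariant_layer(b1, 0.5, -1.0)
out_p = equivariant_layer(b1p, 0.5, -1.0)
equiv_diff = max(abs(out[j] - op) for j, op in zip(perm, out_p))

print("function difference after permutation:", max_diff)
print("equivariance error of toy layer:", equiv_diff)
```

Both printed values should be at the level of floating-point rounding. The first confirms that reordered weights represent the same function; the second shows why a layer built from pooling and broadcasting commutes with neuron reordering.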
Related benchmarks
| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Image Classification | MNIST (test) | Accuracy | 85.71 | 882 |
| Image Classification | Fashion MNIST (test) | Accuracy | 67.06 | 568 |
| Classification | CIFAR10 (test) | Accuracy | 34.45 | 266 |
| INR classification | F-MNIST Implicit Neural Representations (test) | Accuracy | 67.06 | 15 |
| Weight-space INR classification | MNIST (test) | Test Accuracy | 85.71 | 13 |
| INR Dilation | MNIST INR | MSE | 2.58 | 8 |
| INR classification | CIFAR-10 Implicit Neural Representations (test) | Accuracy | 34.45 | 7 |
| INR classification | Augmented CIFAR-10 Implicit Neural Representations (test) | Accuracy | 41.27 | 7 |
| Weight-space INR classification | FashionMNIST (test) | Test Accuracy | 65.5 | 5 |