
Monomial Matrix Group Equivariant Neural Functional Networks

About

Neural functional networks (NFNs) have recently gained significant attention due to their diverse applications, ranging from predicting network generalization and editing networks to classifying implicit neural representations. Previous NFN designs often depend on the permutation symmetries of neural network weights, which arise from the unordered arrangement of neurons in hidden layers. However, these designs do not take into account the weight scaling symmetries of $\mathrm{ReLU}$ networks or the weight sign-flipping symmetries of $\sin$ and $\tanh$ networks. In this paper, we extend the study of the group action on network weights from the group of permutation matrices to the group of monomial matrices by incorporating scaling/sign-flipping symmetries. In particular, we encode these scaling/sign-flipping symmetries by designing corresponding equivariant and invariant layers. We name our new family of NFNs the Monomial Matrix Group Equivariant Neural Functional Networks (Monomial-NFN). Because it accounts for this larger symmetry group, Monomial-NFN has far fewer independent trainable parameters than the baseline NFNs in the literature, which improves the model's efficiency. Moreover, for fully connected and convolutional neural networks, we theoretically prove that all groups that leave these networks invariant while acting on their weight spaces are subgroups of the monomial matrix group. We provide empirical evidence demonstrating the advantages of our model over existing baselines, achieving competitive performance and efficiency.
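The scaling and sign-flipping symmetries described in the abstract can be checked numerically. The sketch below is a minimal illustration, not the authors' code: it assumes a hypothetical two-layer MLP written in pure Python, and the helper names (`mlp`, `apply_monomial`, `close`) are chosen only for this example. It acts on the hidden layer with a monomial matrix (a permutation combined with a diagonal scaling), using positive diagonal entries for ReLU and ±1 entries for tanh, and checks whether the network's output is unchanged.

```python
import math
import random

def matvec(W, x):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def mlp(params, x, act):
    """Hypothetical two-layer MLP: W2 @ act(W1 @ x + b1) + b2."""
    W1, b1, W2, b2 = params
    h = [act(z + b) for z, b in zip(matvec(W1, x), b1)]
    return [z + b for z, b in zip(matvec(W2, h), b2)]

def apply_monomial(params, perm, scale):
    """Act on the hidden layer with a monomial matrix (permutation + diagonal scaling).

    New hidden neuron i is old neuron perm[i] with its incoming weights and bias
    multiplied by scale[i]; its outgoing weights are divided by scale[i] so the
    output layer receives the same signal when the activation commutes with the scaling.
    """
    W1, b1, W2, b2 = params
    n = len(perm)
    W1_new = [[scale[i] * w for w in W1[perm[i]]] for i in range(n)]
    b1_new = [scale[i] * b1[perm[i]] for i in range(n)]
    W2_new = [[row[perm[i]] / scale[i] for i in range(n)] for row in W2]
    return W1_new, b1_new, W2_new, b2

def rand_mat(rows, cols):
    return [[random.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]

def close(a, b, tol=1e-9):
    return all(abs(u - v) < tol for u, v in zip(a, b))

def relu(z):
    return max(z, 0.0)

random.seed(0)
n_in, n_hid, n_out = 3, 4, 2
params = (rand_mat(n_hid, n_in), [random.gauss(0.0, 1.0) for _ in range(n_hid)],
          rand_mat(n_out, n_hid), [random.gauss(0.0, 1.0) for _ in range(n_out)])
x = [0.5, -1.2, 2.0]
perm = [2, 0, 3, 1]

# ReLU: relu(c * z) = c * relu(z) for c > 0, so positive scalings are symmetries.
relu_params = apply_monomial(params, perm, [0.5, 2.0, 3.0, 1.5])
# tanh is odd, tanh(-z) = -tanh(z), so +/-1 sign flips are symmetries.
tanh_params = apply_monomial(params, perm, [1.0, -1.0, -1.0, 1.0])

print(close(mlp(params, x, relu), mlp(relu_params, x, relu)))            # True
print(close(mlp(params, x, math.tanh), mlp(tanh_params, x, math.tanh)))  # True
# Sign flips are generally not a symmetry of a ReLU network.
print(close(mlp(params, x, relu), mlp(tanh_params, x, relu)))            # False
```

The last check shows why the admissible diagonal entries depend on the activation: positive scalings for ReLU, sign flips for odd activations such as tanh.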

Viet-Hoang Tran, Thieu N. Vo, Tho H. Tran, An T. Nguyen, Tan M. Nguyen• 2024

Related benchmarks

Task | Dataset | Metric | Result | Rank
Performance Prediction | Small CNN Zoo ReLU subset (test) | Kendall’s Tau | 0.923 | 35
INR classification | F-MNIST Implicit Neural Representations (test) | Accuracy (%) | 61.44 | 21
INR classification | CIFAR-10 (test) | Accuracy (%) | 34.26 | 13
INR editing (dilate) | MNIST (test) | MSE | 0.069 | 13
INR classification | MNIST (test) | Accuracy (%) | 68.87 | 13
Weight-space INR classification | CIFAR-10 (test) | Test Accuracy (%) | 34.26 | 10
INR classification | FashionMNIST INR | Accuracy (%) | 61.44 | 6
INR classification | MNIST INR | Accuracy (%) | 68.87 | 6
Weight space style editing (Contrast enhancement) | CIFAR-10 SIREN encoded (test) | MSE | 0.02 | 5
