
Frame Averaging for Invariant and Equivariant Network Design

About

Many machine learning tasks involve learning functions that are known to be invariant or equivariant to certain symmetries of the input data. However, it is often challenging to design neural network architectures that respect these symmetries while being expressive and computationally efficient; examples include graph and point cloud neural networks that are invariant or equivariant to Euclidean motions. We introduce Frame Averaging (FA), a general-purpose and systematic framework for adapting known (backbone) architectures to become invariant or equivariant to new symmetry types. Our framework builds on the well-known group averaging operator, which guarantees invariance or equivariance but is intractable. In contrast, we observe that for many important classes of symmetries, this operator can be replaced with an averaging operator over a small subset of the group elements, called a frame. We show that averaging over a frame guarantees exact invariance or equivariance while often being much simpler to compute than averaging over the entire group. Furthermore, we prove that FA-based models have maximal expressive power in a broad setting and in general preserve the expressive power of their backbone architectures. Using frame averaging, we propose a new class of universal Graph Neural Networks (GNNs), universal Euclidean-motion-invariant point cloud networks, and Euclidean-motion-invariant Message Passing (MP) GNNs. We demonstrate the practical effectiveness of FA on several applications, including point cloud normal estimation, beyond $2$-WL graph separation, and $n$-body dynamics prediction, achieving state-of-the-art results in all of these benchmarks.
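The idea of averaging over a frame instead of the whole group can be illustrated with a small NumPy sketch. It assumes the Euclidean group E(3) (rotations, reflections, translations) acting on an n×3 point cloud, and builds a PCA-style frame: the centroid plus the covariance eigenvectors taken with all 8 sign choices. This construction assumes the covariance has distinct eigenvalues. The backbones `phi` and `psi`, their random weights, and the helper names are hypothetical stand-ins chosen for illustration, not the authors' implementation.

```python
import itertools

import numpy as np


def pca_frame(X):
    """Frame F(X) for E(3) acting on an (n, 3) point cloud X (rows are points).

    Each element is a pair (R, t): t is the centroid and R is an orthogonal
    3x3 matrix whose columns are the covariance eigenvectors, taken with all
    2**3 = 8 sign choices. Assumes distinct covariance eigenvalues.
    """
    t = X.mean(axis=0)
    centered = X - t
    _, V = np.linalg.eigh(centered.T @ centered)  # eigenvectors as columns
    # V * signs flips the sign of each column independently.
    return [(V * np.array(signs), t)
            for signs in itertools.product([1.0, -1.0], repeat=3)]


def fa_invariant(X, phi):
    """Invariant FA: mean over g = (R, t) in F(X) of phi(g^{-1} . X).

    With the action g . Y = Y @ R.T + t, the inverse is g^{-1} . Y = (Y - t) @ R.
    """
    return np.mean([phi((X - t) @ R) for R, t in pca_frame(X)], axis=0)


def fa_equivariant(X, psi):
    """Equivariant FA: mean over g in F(X) of g . psi(g^{-1} . X)."""
    return np.mean([psi((X - t) @ R) @ R.T + t for R, t in pca_frame(X)],
                   axis=0)


# Hypothetical backbones with random weights; neither is E(3) invariant or
# equivariant on its own.
rng = np.random.default_rng(0)
W1 = rng.normal(size=(3, 16))
w2 = rng.normal(size=16)
W2 = rng.normal(size=(16, 3))


def phi(Y):
    # (n, 3) -> scalar: sum-pooled features, permutation invariant only.
    return np.tanh(Y @ W1).sum(axis=0) @ w2


def psi(Y):
    # (n, 3) -> (n, 3): per-point coordinate update.
    return np.tanh(Y @ W1) @ W2


# Move the cloud by a random orthogonal Q (possibly a reflection) and a shift s.
X = rng.normal(size=(20, 3))
Q, _ = np.linalg.qr(rng.normal(size=(3, 3)))
s = rng.normal(size=3)
X_moved = X @ Q.T + s

print(np.allclose(fa_invariant(X_moved, phi), fa_invariant(X, phi)))
print(np.allclose(fa_equivariant(X_moved, psi), fa_equivariant(X, psi) @ Q.T + s))
```

Both checks are expected to print True up to floating-point tolerance: moving the input by (Q, s) maps each frame element to (QR, tQᵀ + s) up to the sign choices, so the set of canonicalized inputs (X − t)R is unchanged. Averaging over these 8 elements replaces averaging over the infinite group E(3); when eigenvalues repeat, this simple frame is not well defined and would need extra handling.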

Omri Puny, Matan Atzmon, Heli Ben-Hamu, Ishan Misra, Aditya Grover, Edward J. Smith, Yaron Lipman • 2021

Related benchmarks

| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Node Classification | PATTERN (test) | Test Accuracy | 80.015 | 88 |
| Graph Classification | EXP (test) | Accuracy | 100 | 33 |
| Antibody Generation | Paired OAS (test) | W1 (Natural) | 0.4141 | 16 |
| Antibody Binder Generation | Trastuzumab CDR H3 mutant dataset (test) | W1 (Natural) | 0.0018 | 13 |
| Aptamer Screening | GFP | Top-10 Precision | 0.3 | 12 |
| Graph Separation | GRAPH8c, random initialization | Non-separated Pairs | 0 | 11 |
| Graph Separation | EXP, random initialization | Non-separated Graph Pairs | 0 | 11 |
| Aptamer Screening | HNRNPC | Top-10 Precision | 6.66 | 10 |
| Position Regression | n-body (test) | Position MSE | 0.0057 | 9 |
| Aptamer Screening | NELF | Top-10 Precision | 16.66 | 9 |

10 of 17 rows shown.
