
Learning Partial Equivariances from Data

About

Group Convolutional Neural Networks (G-CNNs) constrain learned features to respect the symmetries in the selected group, and lead to better generalization when these symmetries appear in the data. If this is not the case, however, equivariance leads to overly constrained models and worse performance. Frequently, transformations occurring in data can be better represented by a subset of a group than by a group as a whole, e.g., rotations in $[-90^{\circ}, 90^{\circ}]$. In such cases, a model that respects equivariance $\textit{partially}$ is better suited to represent the data. In addition, relevant transformations may differ for low- and high-level features. For instance, full rotation equivariance is useful to describe edge orientations in a face, but partial rotation equivariance is better suited to describe face poses relative to the camera. In other words, the optimal level of equivariance may differ per layer. In this work, we introduce $\textit{Partial G-CNNs}$: G-CNNs able to learn layer-wise levels of partial and full equivariance to discrete groups, continuous groups, and combinations thereof as part of training. Partial G-CNNs retain full equivariance when beneficial, e.g., for rotated MNIST, but adjust it whenever it becomes harmful, e.g., for classification of 6/9 digits or natural images. We empirically show that Partial G-CNNs match G-CNNs when full equivariance is advantageous, and outperform them otherwise.
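The 6/9 example can be made concrete with a toy invariant feature. This is a sketch, not the authors' model: take a filter response and max-pool it over 90-degree rotations of the input. Pooling over the full group C4 makes a pattern and its 180-degree rotation indistinguishable, while pooling only over the subset {-90°, 0°, +90°}, a discrete stand-in for the $[-90^{\circ}, 90^{\circ}]$ range above, keeps them apart. The function name `rotation_pooled_response` and the 2x2 pattern standing in for a digit are hypothetical.

```python
import numpy as np

def rotation_pooled_response(x, w, ks=(0, 1, 2, 3)):
    """Max of the filter response <rot90(x, k), w> over the rotations k in ks.

    k counts counterclockwise 90-degree turns. ks=(0, 1, 2, 3) pools over the
    full group C4, so the result is invariant to every 90-degree rotation of x.
    ks=(3, 0, 1) pools only over -90, 0 and +90 degrees (partial invariance).
    """
    return max(float(np.sum(np.rot90(x, k) * w)) for k in ks)

# Hypothetical toy "digit": an asymmetric 2x2 pattern and its 180-degree rotation.
six = np.array([[1.0, 0.0],
                [0.0, 0.0]])
nine = np.rot90(six, 2)
w = six.copy()  # filter matched to the "6" pattern

full_six = rotation_pooled_response(six, w)                # 1.0
full_nine = rotation_pooled_response(nine, w)              # 1.0: full C4 cannot tell them apart
part_six = rotation_pooled_response(six, w, (3, 0, 1))     # 1.0
part_nine = rotation_pooled_response(nine, w, (3, 0, 1))   # 0.0: 180 degrees is outside the subset
```

With the full set of rotations the feature is blind to the 6/9 distinction; shrinking the pooled set restores it, at the cost of no longer being invariant to 180-degree turns.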
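To let the level of rotation equivariance be learned per layer, one option is to give each layer a learnable half-width θ and sample rotation angles uniformly from $[-\theta, \theta]$ with the reparameterization $\text{angle} = \theta(2\epsilon - 1)$, $\epsilon \sim U(0, 1)$. This parameterization is an assumption for illustration; the abstract does not specify how Partial G-CNNs represent the learned subset. Setting θ = π covers all of SO(2), while θ = π/2 gives the $[-90^{\circ}, 90^{\circ}]$ subset. The function name and per-layer values below are hypothetical.

```python
import numpy as np

def sample_partial_rotations(theta, n, rng):
    """Sample n rotation angles (radians) from U[-theta, theta].

    Uses the reparameterization angle = theta * (2 * eps - 1), eps ~ U(0, 1), so
    each angle is differentiable in theta with d(angle)/d(theta) = 2 * eps - 1.
    Returns the angles and that per-sample derivative.
    """
    eps = rng.uniform(0.0, 1.0, size=n)
    unit = 2.0 * eps - 1.0  # lies in [-1, 1)
    return theta * unit, unit

rng = np.random.default_rng(0)
# Hypothetical per-layer half-widths: full rotation equivariance for an early
# (edge-level) layer, +/-90 degrees for a late (pose-level) layer.
layer_theta = {"early": np.pi, "late": np.pi / 2}
for name, theta in layer_theta.items():
    angles, dangle_dtheta = sample_partial_rotations(theta, 8, rng)
    print(name, np.degrees(angles).round(1))
```

In a training loop, the gradient of the loss with respect to θ would flow through these sampled angles (for example, through filters rotated by them), which is what would let each layer widen or shrink its range of equivariance.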

David W. Romero, Suhas Lohit • 2021

Related benchmarks

Task                  Dataset                Metric                   Result (%)  Rank
Image Classification  CIFAR-100 (test)       Accuracy                 56.53       3518
Image Classification  CIFAR-10 (test)        Accuracy                 90.12       906
Image Classification  CIFAR-100 (test)       Top-1 Acc                61.46       275
Classification        CIFAR10 (test)         Accuracy                 91.66       266
Classification        CIFAR-100 (test)       Accuracy                 69.66       129
Classification        CIFAR-10               Accuracy                 90.12       80
Classification        CIFAR100               Accuracy                 61.46       66
Classification        RotMNIST (test)        Classification Accuracy  99.28       32
Image Classification  PatchCamelyon (test)   Accuracy                 90.31       28
Classification        RotMNIST               Accuracy                 99.23       8
