# Equivariance with Learned Canonicalization Functions

## About
Symmetry-based neural networks often constrain the architecture in order to achieve invariance or equivariance to a group of transformations. In this paper, we propose an alternative that avoids this architectural constraint by learning to produce canonical representations of the data. These canonicalization functions can readily be plugged into non-equivariant backbone architectures. We offer explicit ways to implement them for some groups of interest. We show that this approach enjoys universality while providing interpretable insights. Our main hypothesis, supported by our empirical results, is that learning a small neural network to perform canonicalization is better than using predefined heuristics. Our experiments show that learning the canonicalization function is competitive with existing techniques for learning equivariant functions across many tasks, including image classification, $N$-body dynamics prediction, point cloud classification and part segmentation, while being faster across the board.
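To make the idea concrete, here is a minimal sketch of canonicalization for the discrete rotation group C4 acting on images. The paper learns a small neural network to select the canonical pose; in this toy sketch the learned network is replaced by an arbitrary scoring function `energy` (a stand-in, not the authors' architecture). The key property illustrated is that composing any backbone with the canonicalizer yields a C4-invariant function, with no constraint on the backbone itself.

```python
import numpy as np

def energy(img):
    # Stand-in for the small "canonicalization network": any scalar
    # score over images works for illustration purposes.
    return float((img * np.arange(img.size).reshape(img.shape)).sum())

def canonicalize(img):
    # Evaluate all four C4 rotations and return the highest-scoring one.
    # Because rotating the input only permutes this candidate set, the
    # argmax selects the same canonical image (assuming a unique maximum).
    rots = [np.rot90(img, k) for k in range(4)]
    best = int(np.argmax([energy(r) for r in rots]))
    return rots[best]

def invariant_predict(img, backbone):
    # The backbone need not be equivariant: canonicalization makes the
    # composite map C4-invariant.
    return backbone(canonicalize(img))
```

For continuous groups such as rotations in $SO(2)$ or $SO(3)$, the same recipe applies, except the network regresses a group element directly and its inverse is applied to the input before the backbone.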
## Related benchmarks
| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Node Classification | PATTERN (test) | Test Accuracy | 86.534 | 88 |
| Image Classification | CIFAR-10 original (test) | Accuracy | 95 | 57 |
| Graph Classification | EXP (test) | Accuracy | 50 | 33 |
| Image Classification | CIFAR100 original (test) | Accuracy | 80.86 | 20 |
| Image Classification | CIFAR-10 data-augmented (+) (test) | -- | -- | 16 |
| Graph Separation | GRAPH8c random initialization | Non-Separated Pairs | 0.00e+0 | 11 |
| Graph Separation | EXP random initialization | Non-Separated Graph Pairs | 0.00e+0 | 11 |
| Image Classification | CIFAR100 C8-augmented (test) | C8 Average Accuracy | 80.48 | 10 |
| Image Classification | STL10 C8-augmented (test) | C8 Average Accuracy | 94.67 | 10 |
| Image Classification | STL10 original (test) | Accuracy | 95.3 | 10 |