Equivariant Adaptation of Large Pretrained Models

About

Equivariant networks are specifically designed to ensure consistent behavior with respect to a set of input transformations, leading to higher sample efficiency and more accurate and robust predictions. However, redesigning each component of prevalent deep neural network architectures to achieve a chosen equivariance is difficult and can produce networks that are computationally expensive during both training and inference. A recently proposed alternative that removes these architectural constraints is to use a simple canonicalization network that transforms the input into a canonical form before feeding it to an unconstrained prediction network. We show here that this approach can effectively be used to make a large pretrained network equivariant. However, we observe that the canonical orientations it produces can be misaligned with those of the training distribution, which hinders performance. By using dataset-dependent priors to inform the canonicalization function, we are able to make large pretrained models equivariant while maintaining their performance. This significantly improves the robustness of these models to deterministic transformations of the data, such as rotations. We believe this equivariant adaptation of large pretrained models can help in their domain-specific applications with known symmetry priors.
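The canonicalization idea can be illustrated with a small sketch. The Python/NumPy code below assumes the discrete rotation group C4 (multiples of 90 degrees) and replaces the learned canonicalization network with a fixed, hypothetical filter `w`. The names `orientation_scores`, `canonicalize`, `prior_loss`, and `predict` are illustrative and do not come from the authors' code. The canonicalizer picks the orientation whose inverse rotation maximizes a score. Because that choice shifts together with any rotation of the input, every rotated copy maps to the same canonical image, so the wrapped model's output does not change. The `prior_loss` term is only an assumed form of a dataset-dependent prior: a cross-entropy that pushes the canonicalizer to leave upright training images at the identity orientation.

```python
import numpy as np


def rotate(x, k):
    """Rotate a square 2D image by k * 90 degrees counter-clockwise."""
    return np.rot90(x, k % 4)


def orientation_scores(x, w):
    """Score each candidate orientation j by <w, rot_{-j}(x)>.

    Hypothetical stand-in for a learned canonicalization network: w is a
    fixed filter with the same shape as x. Rotating x by k cyclically
    shifts these scores by k, which makes the argmax equivariant.
    """
    return np.array([np.sum(w * rotate(x, -j)) for j in range(4)])


def canonicalize(x, w):
    """Return the canonical image rot_{-h(x)}(x) and the chosen orientation h(x)."""
    k = int(np.argmax(orientation_scores(x, w)))
    return rotate(x, -k), k


def prior_loss(x, w):
    """ASSUMED prior: cross-entropy between softmax(scores) and the identity (j = 0).

    Minimizing this on upright training images encourages the canonicalizer
    to leave them untouched, so canonical inputs match the pretraining data.
    """
    s = orientation_scores(x, w)
    s = s - s.max()                          # numerical stability
    log_probs = s - np.log(np.sum(np.exp(s)))
    return -log_probs[0]


def predict(x, w, model):
    """Wrap an unconstrained (e.g. pretrained) model so its output is C4-invariant."""
    x_canon, _ = canonicalize(x, w)
    return model(x_canon)
```

If you rotate `x` by any multiple of 90 degrees before calling `predict`, you should get the same output, barring exact ties among the four scores. In an actual adaptation, `w` would be replaced by a trained equivariant network and the prior term added to the training objective. This sketch shows only the mechanics.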

Arnab Kumar Mondal, Siba Smarak Panigrahi, Sékou-Oumar Kaba, Sai Rajeswar, Siamak Ravanbakhsh • 2023

Related benchmarks

| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Instance Segmentation | COCO 2017 (val) | -- | -- | 1201 |
| Point Cloud Classification | ModelNet40 (test) | Accuracy | 66.3 | 229 |
| Shape Part Segmentation | ShapeNet (test) | Mean IoU | 79.39 | 164 |
| Classification | ModelNet40 (test) | -- | -- | 120 |
| Classification | ModelNet40 | Accuracy | 88.49 | 108 |
| Image Classification | CIFAR-10 original (test) | Accuracy | 96.19 | 87 |
| Image Classification | CIFAR100 original (test) | Accuracy | 84.27 | 20 |
| Image Classification | CIFAR-10 data-augmented (+) (test) | -- | -- | 16 |
| Image Classification | STL10 original (test) | Accuracy | 97.01 | 15 |
| Image Classification | CIFAR100 C8-augmented (test) | C8 Average Accuracy | 83.61 | 10 |

(10 of 15 benchmark rows shown.)
