Make Sharpness-Aware Minimization Stronger: A Sparsified Perturbation Approach

About

Deep neural networks often suffer from poor generalization caused by complex and non-convex loss landscapes. One popular solution is Sharpness-Aware Minimization (SAM), which smooths the loss landscape by minimizing the maximum change in training loss when a perturbation is added to the weights. However, we find that SAM's indiscriminate perturbation of all parameters is suboptimal and also results in excessive computation, i.e., double the overhead of common optimizers like Stochastic Gradient Descent (SGD). In this paper, we propose an efficient and effective training scheme called Sparse SAM (SSAM), which achieves sparse perturbation via a binary mask. To obtain the sparse mask, we provide two solutions, based on Fisher information and dynamic sparse training, respectively. In addition, we theoretically prove that SSAM can converge at the same rate as SAM, i.e., $O(\log T/\sqrt{T})$. Sparse SAM not only has the potential for training acceleration but also smooths the loss landscape effectively. Extensive experimental results on CIFAR10, CIFAR100, and ImageNet-1K confirm that our method is more efficient than SAM, and that performance is preserved or even improved with a perturbation of merely 50% sparsity. Code is available at https://github.com/Mi-Peng/Sparse-Sharpness-Aware-Minimization.
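
The masked perturbation described in the abstract can be illustrated with a minimal NumPy sketch. This is not the authors' implementation: the toy quadratic loss, the helper names (`loss_and_grad`, `topk_mask`, `sparse_sam_step`), the hyperparameter values, and the use of squared gradients as a rough stand-in for diagonal Fisher information are all assumptions made for illustration.

```python
import numpy as np

def loss_and_grad(w, A, b):
    """Toy quadratic loss 0.5 * ||A w - b||^2 and its gradient (illustrative only)."""
    r = A @ w - b
    return 0.5 * float(r @ r), A.T @ r

def topk_mask(scores, sparsity):
    """Binary mask keeping the (1 - sparsity) fraction of entries with the largest scores."""
    k = max(1, int(round((1.0 - sparsity) * scores.size)))
    mask = np.zeros_like(scores)
    mask[np.argsort(scores)[-k:]] = 1.0
    return mask

def sparse_sam_step(w, A, b, lr=0.01, rho=0.05, sparsity=0.5):
    """One SAM-style update whose weight perturbation is restricted by a binary mask."""
    _, g = loss_and_grad(w, A, b)
    # Assumption: squared gradient used as a proxy for diagonal Fisher information.
    mask = topk_mask(g ** 2, sparsity)
    # SAM perturbation of radius rho, zeroed outside the mask.
    eps = rho * g / (np.linalg.norm(g) + 1e-12) * mask
    # The gradient at the perturbed point drives the actual weight update.
    _, g_adv = loss_and_grad(w + eps, A, b)
    return w - lr * g_adv, mask

# Usage: a few steps on a small hypothetical problem.
A = np.diag([1.0, 2.0, 3.0, 4.0])
b = np.ones(4)
w = np.zeros(4)
for _ in range(100):
    w, mask = sparse_sam_step(w, A, b, lr=0.05)
```

With `sparsity=0.5`, only half of the coordinates are perturbed in each step, which is the setting the abstract reports as preserving or improving performance. The sketch still computes two gradients per step, so it shows the masking logic only and not any of the claimed speed-up.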

Peng Mi, Li Shen, Tianhe Ren, Yiyi Zhou, Xiaoshuai Sun, Rongrong Ji, Dacheng Tao • 2022

Related benchmarks

Task                            Dataset             Result                     Rank
Image Classification            CIFAR-100 (test)    --                         3518
Image Classification            ImageNet-1k (val)   --                         1453
Natural Language Understanding  GLUE (test dev)     MRPC Accuracy: 92.73       81
Image Classification            CIFAR10 (test)      Accuracy: 97.72            76
Natural Language Inference      MNLI                Max Memory (MB): 7.43e+3   5
Linguistic Acceptability        COLA                Max Memory (MB): 7.29e+3   5