Sharpness-Aware Machine Unlearning
About
We characterize the effectiveness of sharpness-aware minimization (SAM) in the machine unlearning setting, where unlearning forget signals interferes with learning retain signals. While previous work proves that SAM improves generalization by preventing noise memorization, we show that SAM abandons this denoising property when fitting the forget set, altering generalization in a way that depends on signal strength. We further characterize SAM's signal surplus in terms of the order of signal strength, which enables learning from fewer retain signals while maintaining model performance and placing more weight on unlearning the forget set. Empirical studies show that SAM outperforms SGD under relaxed requirements on retain signals and can enhance various unlearning methods, whether used as the pretraining or the unlearning algorithm. Motivated by our refined characterization of SAM unlearning, and observing that overfitting can benefit more stringent sample-specific unlearning, we propose Sharp MinMax, which splits the model in two: one part learns retain signals with SAM, while the other unlearns forget signals via sharpness maximization, achieving the best performance. Extensive experiments show that SAM enhances unlearning across varying difficulties measured by memorization, yielding decreased feature entanglement between retain and forget sets, stronger resistance to membership inference attacks, and a flatter loss landscape. Our observations generalize to noisier data, different optimizers, and different architectures.
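To make the two ingredients concrete, here is a minimal NumPy sketch of one SAM update (perturb the weights toward the locally worst-case point, then descend using the gradient taken there), together with a hypothetical sign-flipped "sharpness maximization" step in the spirit of the forget-set branch of Sharp MinMax. The function names, the toy quadratic loss, and the ascent-based unlearning variant are illustrative assumptions, not the paper's exact formulation.

```python
import numpy as np

def sam_step(w, grad_fn, lr=0.1, rho=0.05):
    """One SAM update: ascend to the approximate worst-case point within
    an L2 ball of radius rho, then descend with the gradient taken there."""
    g = grad_fn(w)
    eps = rho * g / (np.linalg.norm(g) + 1e-12)  # ascent direction
    return w - lr * grad_fn(w + eps)             # descend on perturbed gradient

def sharp_max_step(w, grad_fn, lr=0.1, rho=0.05):
    """Hypothetical sharpness-maximization step for a forget-set loss:
    gradient *ascent* on the perturbed loss (an illustrative assumption)."""
    g = grad_fn(w)
    eps = rho * g / (np.linalg.norm(g) + 1e-12)
    return w + lr * grad_fn(w + eps)

# Toy retain loss L(w) = 0.5 * ||w||^2, whose gradient is simply w.
grad = lambda w: w
w = np.array([1.0, -2.0])
for _ in range(200):
    w = sam_step(w, grad)
# w is driven toward the flat minimum at the origin
```

In a real unlearning run, `sam_step` would be applied to batches from the retain set and `sharp_max_step` to batches from the forget set, on separate parts of the split model.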
Related benchmarks
| Task | Dataset | Result | Rank |
|---|---|---|---|
| Machine Unlearning | CIFAR-100 | -- | 48 |
| Machine Unlearning | CIFAR-100 (test) | -- | 45 |
| Membership Inference Attack | CIFAR-100 High forget set | -- | 40 |
| Membership Inference Attack | CIFAR-100 Mid forget set | -- | 40 |
| Membership Inference Attack | CIFAR-100 Low forget set | -- | 40 |
| Membership Inference Attack | CIFAR-100 AVG forget set | -- | 40 |
| Machine Unlearning | CIFAR-100 (forget set) | Avg Increase in Forget Accuracy: 18.511 | 36 |
| Machine Unlearning | ImageNet-1K | -- | 32 |
| Machine Unlearning | CIFAR-100 (forget) | -- | 12 |
| Machine Unlearning | CIFAR-100 (retain) | -- | 12 |