Gauss-Newton Unlearning for the LLM Era
About
Standard large language model training can create models that produce outputs their trainer deems unacceptable in deployment. The probability of these outputs can be reduced using methods such as LLM unlearning. However, unlearning a set of data (called the forget set) can degrade model performance on other distributions where the trainer wants to retain the model's behavior.

To improve this trade-off, we demonstrate that using the forget set to compute only a few uphill Gauss-Newton steps provides a conceptually simple, state-of-the-art unlearning approach for LLMs. While Gauss-Newton steps adapt Newton's method to non-linear models, it is non-trivial to efficiently and accurately compute such steps for LLMs. Hence, our approach crucially relies on parametric Hessian approximations such as Kronecker-Factored Approximate Curvature (K-FAC). We call this combined approach K-FADE (K-FAC for Distribution Erasure).

Our evaluation on the WMDP and ToFU benchmarks demonstrates that K-FADE suppresses outputs from the forget set and approximates, in output space, the results of retraining without the forget set. Critically, our method does this while altering the outputs on the retain set less than previous methods. This is because K-FADE transforms a constraint on the model's outputs across the entire retain set into a constraint on the model's weights, allowing the algorithm to minimally change the model's behavior on the retain set at each step. Moreover, the unlearning updates computed by K-FADE can be reapplied later if the model undergoes further training, allowing unlearning to be cheaply maintained.
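To make the core update concrete, here is a minimal NumPy sketch of an uphill Gauss-Newton step for a single linear layer, using Kronecker-factored curvature in the style of K-FAC. All variable names, the damping constant, and the learning rate are illustrative assumptions, not values from this work; a real implementation would accumulate these statistics across an LLM's layers.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d_in, d_out = 64, 16, 8

# Toy stand-ins for one linear layer's statistics on the forget set.
a = rng.standard_normal((n, d_in))       # layer input activations
g = rng.standard_normal((n, d_out))      # backpropagated output gradients
W = rng.standard_normal((d_out, d_in))   # layer weights

# Plain gradient of the forget-set loss w.r.t. W.
grad_W = g.T @ a / n

# Kronecker factors of the layer's curvature (K-FAC style), with damping
# to keep the factors invertible.
damp = 1e-2
A = a.T @ a / n + damp * np.eye(d_in)    # input covariance factor
G = g.T @ g / n + damp * np.eye(d_out)   # output-gradient covariance factor

# Curvature-preconditioned direction: G^{-1} grad_W A^{-1}.
step = np.linalg.solve(G, grad_W) @ np.linalg.inv(A)

# Uphill (ascent) step on the forget loss: increase loss on the forget set.
lr = 0.1
W_unlearned = W + lr * step
```

The preconditioning by the inverse Kronecker factors is what lets the step raise loss on the forget distribution while moving little in directions the curvature marks as important elsewhere; the saved `lr * step` update could also be re-added after further training, matching the maintenance property described above.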
Related benchmarks
| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Instruction Following | MT-Bench | -- | -- | 189 |
| Knowledge Unlearning | WMDP bio | Accuracy | 30.1 | 20 |
| Model Unlearning | TOFU 1.0 (Forget 10%) | Forget Quality | 85 | 7 |
| Model Unlearning | TOFU 1.0 (Forget 5%) | Forget Quality | 0.87 | 6 |
| Knowledge Suppression | WMDP cyber | Accuracy | 27.7 | 4 |
| Output Specificity | Alpaca | KL Divergence (Specificity) | 0.029 | 4 |
| Language Understanding | MMLU | MMLU Knowledge Score | 57.2 | 4 |