
SOAP: Improving and Stabilizing Shampoo using Adam

About

There is growing evidence that Shampoo, a higher-order preconditioning method, is more effective than Adam in deep learning optimization tasks. However, compared to Adam, which only updates running averages of first- and second-moment quantities, Shampoo introduces additional hyperparameters and computational overhead. This work establishes a formal connection between Shampoo (implemented with the 1/2 power) and Adafactor, a memory-efficient approximation of Adam, showing that Shampoo is equivalent to running Adafactor in the eigenbasis of Shampoo's preconditioner. This insight leads to the design of a simpler, computationally efficient algorithm: $\textbf{S}$hampo$\textbf{O}$ with $\textbf{A}$dam in the $\textbf{P}$reconditioner's eigenbasis (SOAP).

The most straightforward way to improve Shampoo's computational efficiency would be to compute its eigendecomposition less frequently. Unfortunately, as our empirical results show, this leads to performance degradation that worsens as the eigendecomposition is computed less often. SOAP mitigates this degradation by continually updating the running average of the second moment, just as Adam does, but in the current (slowly changing) coordinate basis. Furthermore, since SOAP is equivalent to running Adam in a rotated space, it introduces only one additional hyperparameter compared to Adam: the preconditioning frequency.

We empirically evaluate SOAP on language model pre-training with 360M- and 660M-parameter models. In the large-batch regime, SOAP reduces the number of iterations by over 40% and wall-clock time by over 35% compared to AdamW, with approximately 20% improvements in both metrics compared to Shampoo. An implementation of SOAP is available at https://github.com/nikhilvyas/SOAP.
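The two technical claims above, the Shampoo/Adafactor connection and the SOAP update, can be illustrated numerically. The NumPy sketch below is a hypothetical illustration, not the authors' implementation (that lives in the repository linked above). `shampoo_half_vs_adafactor` averages L = E[G Gᵀ] and R = E[Gᵀ G] over a small batch of random gradients (an assumption) and compares Shampoo with the 1/2 power, L^{-1/2} G R^{-1/2}, against Adafactor's factored second-moment update computed in the eigenbasis of L and R; in this normalization the two agree up to the scalar √tr(L), which a learning-rate rescaling absorbs. `SOAPSketch` is a hypothetical single-matrix SOAP step that accumulates Shampoo's factors, refreshes their eigenvectors only every `precond_freq` steps, and runs Adam with the second moment kept in the current rotated basis. The hyperparameter values, the reuse of β2 for the factor averages, the full `numpy.linalg.eigh` at each refresh, the omission of weight decay, and not re-projecting the second moment when the basis changes are all simplifying assumptions.

```python
import numpy as np


def shampoo_half_vs_adafactor(grads, g):
    """Hypothetical helper: precondition g two ways, with L and R averaged over `grads`.

    Returns (Shampoo update with the 1/2 power, Adafactor update in Shampoo's
    eigenbasis, tr(L)).
    """
    L = np.mean([gk @ gk.T for gk in grads], axis=0)  # E[G G^T]
    R = np.mean([gk.T @ gk for gk in grads], axis=0)  # E[G^T G]
    lam_l, q_l = np.linalg.eigh(L)
    lam_r, q_r = np.linalg.eigh(R)
    # Shampoo with the 1/2 power: L^{-1/2} g R^{-1/2}.
    l_inv_sqrt = q_l @ np.diag(lam_l ** -0.5) @ q_l.T
    r_inv_sqrt = q_r @ np.diag(lam_r ** -0.5) @ q_r.T
    shampoo = l_inv_sqrt @ g @ r_inv_sqrt
    # Adafactor in the rotated space: factored second moment from row/column sums.
    rotated = [q_l.T @ gk @ q_r for gk in grads]
    row = np.mean([(gr ** 2).sum(axis=1) for gr in rotated], axis=0)  # equals lam_l
    col = np.mean([(gr ** 2).sum(axis=0) for gr in rotated], axis=0)  # equals lam_r
    v_hat = np.outer(row, col) / row.sum()
    adafactor = q_l @ ((q_l.T @ g @ q_r) / np.sqrt(v_hat)) @ q_r.T
    return shampoo, adafactor, np.trace(L)


class SOAPSketch:
    """Hypothetical single-matrix SOAP: Adam in the eigenbasis of Shampoo's factors,
    with the eigenbasis refreshed only every `precond_freq` steps."""

    def __init__(self, shape, lr=3e-3, betas=(0.95, 0.95), eps=1e-8, precond_freq=10):
        m, n = shape
        self.lr, self.eps, self.freq = lr, eps, precond_freq  # illustrative values
        self.b1, self.b2 = betas
        self.L, self.R = np.zeros((m, m)), np.zeros((n, n))
        self.q_l, self.q_r = np.eye(m), np.eye(n)
        self.M = np.zeros(shape)  # first moment, stored in the original coordinates
        self.V = np.zeros(shape)  # second moment, stored in the rotated coordinates
        self.t = 0

    def step(self, W, G):
        self.t += 1
        # Running averages of Shampoo's Kronecker factors (beta2 reused: an assumption).
        self.L = self.b2 * self.L + (1 - self.b2) * G @ G.T
        self.R = self.b2 * self.R + (1 - self.b2) * G.T @ G
        # Recompute the eigenbasis on step 1 and then every `precond_freq` steps.
        if (self.t - 1) % self.freq == 0:
            self.q_l = np.linalg.eigh(self.L)[1]
            self.q_r = np.linalg.eigh(self.R)[1]
        # Adam, with the second moment accumulated in the current rotated basis.
        g_rot = self.q_l.T @ G @ self.q_r
        self.M = self.b1 * self.M + (1 - self.b1) * G
        self.V = self.b2 * self.V + (1 - self.b2) * g_rot ** 2
        m_hat = (self.q_l.T @ self.M @ self.q_r) / (1 - self.b1 ** self.t)
        v_hat = self.V / (1 - self.b2 ** self.t)
        step_rot = m_hat / (np.sqrt(v_hat) + self.eps)
        # Rotate the normalized step back to the original coordinates.
        return W - self.lr * (self.q_l @ step_rot @ self.q_r.T)


rng = np.random.default_rng(0)

# Check: Adafactor in the eigenbasis equals Shampoo (1/2 power) up to sqrt(tr(L)).
grads = [rng.standard_normal((4, 3)) for _ in range(8)]
shampoo, adafactor, tr_L = shampoo_half_vs_adafactor(grads, grads[0])
print(np.allclose(adafactor, np.sqrt(tr_L) * shampoo))  # True

# Toy run: minimize 0.5 * ||W - W_star||_F^2, whose gradient is W - W_star.
W_star = rng.standard_normal((8, 6))
W = np.zeros((8, 6))
opt = SOAPSketch(W.shape, lr=0.05)
loss0 = 0.5 * np.sum((W - W_star) ** 2)
for _ in range(500):
    W = opt.step(W, W - W_star)
loss = 0.5 * np.sum((W - W_star) ** 2)
print(f"loss: {loss0:.3f} -> {loss:.6f}")
```

Keeping the first moment in the original coordinates lets it be projected into whichever basis is current, while the second moment is updated every step in that basis; this mirrors the mechanism the abstract credits with tolerating infrequent eigendecompositions. The quadratic run is only a smoke test, not a reproduction of the paper's experiments.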

Nikhil Vyas, Depen Morwani, Rosie Zhao, Mujin Kwun, Itai Shapira, David Brandfonbrener, Lucas Janson, Sham Kakade • 2024

Related benchmarks

| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Image Classification | Tiny ImageNet (test) | Accuracy | 87.72 | 265 |
| Language Modeling | C4 | -- | -- | 73 |
| Language Model Pre-training | C4 Llama 2 pre-training (val) | Perplexity | 14.52 | 47 |
| Image Classification | CIFAR-100 IID | Accuracy | 71.98 | 37 |
| Image Classification | Tiny-ImageNet Dirichlet-0.05 (test) | Accuracy | 50.02 | 32 |
| Image Classification | Tiny-ImageNet Dirichlet alpha=0.1 (test) | Test Accuracy | 54.42 | 30 |
| Image Classification | CIFAR-100 Dir-0.1 | Accuracy | 68.44 | 28 |
| Image Classification | CIFAR-100 Dirichlet-0.1 (test) | Accuracy | 68.44 | 20 |
| Image Classification | CIFAR-100 Dirichlet-0.05 (test) | Accuracy | 58.16 | 20 |
| Image Classification | CIFAR-100 Dir-0.5 | Accuracy | 71.19 | 12 |

Showing 10 of 17 rows.
