Sophia: A Scalable Stochastic Second-order Optimizer for Language Model Pre-training

About

Given the massive cost of language model pre-training, a non-trivial improvement to the optimization algorithm would materially reduce the time and cost of training. Adam and its variants have been state-of-the-art for years, while more sophisticated second-order (Hessian-based) optimizers often incur too much per-step overhead. In this paper, we propose Sophia (Second-order Clipped Stochastic Optimization), a simple, scalable second-order optimizer that uses a lightweight estimate of the diagonal Hessian as the preconditioner. The update is the moving average of the gradients divided by the moving average of the estimated Hessian, followed by element-wise clipping. The clipping controls the worst-case update size and tames the negative impact of non-convexity and of rapid changes in the Hessian along the trajectory. Sophia estimates the diagonal Hessian only once every handful of iterations, so its average per-step time and memory overhead is negligible. On language modeling with GPT models ranging from 125M to 1.5B parameters, Sophia achieves a 2x speed-up over Adam in number of steps, total compute, and wall-clock time: it reaches the same perplexity with 50% fewer steps, less total compute, and reduced wall-clock time. Theoretically, we show that in a much simplified setting Sophia adapts to the heterogeneous curvatures of different parameter dimensions and thus has a run-time bound that does not depend on the condition number of the loss.
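The update rule described in the abstract can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not the authors' implementation: it assumes Hutchinson's estimator with Rademacher probes as the lightweight diagonal-Hessian estimate, seeds the Hessian moving average with its first estimate instead of zero, and omits weight decay. The names `hutchinson_diag` and `sophia_like`, the parameters `beta1`, `beta2`, `k`, `rho` and `eps`, and their defaults are all hypothetical. The toy objective is a diagonal quadratic, whose gradient is `a * theta` and whose Hessian-vector product is `a * u`.

```python
import numpy as np


def hutchinson_diag(hvp, dim, rng, num_samples=1):
    """Estimate diag(H) as the average of u * (H @ u) over Rademacher probes u.

    `hvp(u)` must return the Hessian-vector product H @ u. Hutchinson's
    estimator is an assumed choice of "lightweight" estimator here.
    """
    est = np.zeros(dim)
    for _ in range(num_samples):
        u = rng.choice([-1.0, 1.0], size=dim)  # Rademacher probe vector
        est += u * hvp(u)  # unbiased: E[u * (H u)] = diag(H)
    return est / num_samples


def sophia_like(grad, hvp, theta0, lr=0.5, beta1=0.9, beta2=0.99,
                k=10, rho=1.0, eps=1e-12, steps=200, seed=0):
    """Clipped, diagonally preconditioned update in the spirit of Sophia.

    Hypothetical sketch: names, defaults, seeding the Hessian average with the
    first estimate, and the absence of weight decay are illustrative choices.
    """
    rng = np.random.default_rng(seed)
    theta = np.array(theta0, dtype=float)
    m = np.zeros_like(theta)  # moving average of gradients
    h = np.zeros_like(theta)  # moving average of diagonal-Hessian estimates
    for t in range(steps):
        m = beta1 * m + (1 - beta1) * grad(theta)
        if t % k == 0:  # refresh the (relatively costly) Hessian estimate every k steps
            h_hat = hutchinson_diag(hvp, theta.size, rng)
            h = h_hat if t == 0 else beta2 * h + (1 - beta2) * h_hat
        # Divide by curvature (eps guards against non-positive estimates),
        # then clip element-wise so no coordinate moves more than lr * rho per step.
        step = np.clip(m / np.maximum(h, eps), -rho, rho)
        theta = theta - lr * step
    return theta


# Usage: a diagonal quadratic 0.5 * sum(a * theta**2) with condition number 1e4.
a = np.array([100.0, 1.0, 0.01])
theta = sophia_like(lambda th: a * th, lambda u: a * u, np.ones(3), steps=300)
print(theta)  # every coordinate ends close to 0, including the flat one
```

On this toy problem, dividing by the curvature makes every coordinate follow the same trajectory, so the coordinate with curvature 0.01 converges as fast as the one with curvature 100. Plain gradient descent would need a step size below 2/100 to stay stable on the stiff coordinate, and at that step size the flat coordinate barely moves in 300 steps. This loosely mirrors, in a far simpler setting, the abstract's point about adapting to heterogeneous curvatures.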

Hong Liu, Zhiyuan Li, David Hall, Percy Liang, Tengyu Ma • 2023

Related benchmarks

Task                  Dataset                 Result                 Rank
Image Classification  CIFAR-10 (test)         Accuracy 97.1          3381
Language Modeling     C4                      Perplexity 25.63       1422
Image Classification  Tiny ImageNet (test)    Accuracy 86.16         362
Image Classification  CIFAR-10 (test)         Accuracy 94.53         129
Language Modeling     C4                      --                     121
Image Classification  ImageNet-100 (test)     Clean Accuracy 88.2    119
Image Classification  Food-101 (test)         --                     89
Image Classification  Oxford-IIIT Pet (test)  Overall Accuracy 88.8  59
Image Classification  CIFAR-100 Dir-0.1       Accuracy 56.65         52
Image Classification  CIFAR-100 IID           Accuracy 62.17         42
Showing 10 of 31 rows
