
signSGD: Compressed Optimisation for Non-Convex Problems

About

Training large neural networks requires distributing learning across multiple workers, where the cost of communicating gradients can be a significant bottleneck. signSGD alleviates this problem by transmitting just the sign of each minibatch stochastic gradient. We prove that it can get the best of both worlds: compressed gradients and an SGD-level convergence rate. The relative $\ell_1/\ell_2$ geometry of gradients, noise and curvature informs whether signSGD or SGD is theoretically better suited to a particular problem. On the practical side, we find that the momentum counterpart of signSGD is able to match the accuracy and convergence speed of Adam on deep ImageNet models. We extend our theory to the distributed setting, where the parameter server uses majority vote to aggregate gradient signs from each worker, enabling 1-bit compression of worker-server communication in both directions. Using a theorem by Gauss, we prove that majority vote can achieve the same reduction in variance as full-precision distributed SGD. Thus, there is great promise for sign-based optimisation schemes to achieve fast communication and fast convergence. Code to reproduce the experiments is available at https://github.com/jxbz/signSGD.

Jeremy Bernstein, Yu-Xiang Wang, Kamyar Azizzadenesheli, Anima Anandkumar • 2018
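
The majority-vote scheme described in the abstract can be illustrated with a minimal sketch. This is not the authors' implementation (their code is at the repository linked above). The function names `sign`, `majority_vote`, `sign_sgd_step`, `noisy_grad` and `run_demo`, the toy quadratic objective, and the choice to leave a coordinate unchanged on a tied vote are all assumptions made for illustration. Each worker sends only the sign of its stochastic gradient, the server takes a coordinate-wise majority vote, and every parameter moves by a fixed step in the direction opposite to the voted sign.

```python
import random


def sign(x):
    """Return 1, -1 or 0 according to the sign of x."""
    return (x > 0) - (x < 0)


def majority_vote(worker_signs):
    """Coordinate-wise majority vote over per-worker sign vectors.

    A tied vote yields 0, so that coordinate is left unchanged
    (an assumption of this sketch).
    """
    return [sign(sum(votes)) for votes in zip(*worker_signs)]


def sign_sgd_step(params, worker_grads, lr):
    """One distributed signSGD step with majority vote."""
    # Worker -> server: each worker transmits only the gradient signs.
    worker_signs = [[sign(g) for g in grad] for grad in worker_grads]
    # Server -> workers: the aggregated vote is again one sign per coordinate.
    vote = majority_vote(worker_signs)
    return [p - lr * v for p, v in zip(params, vote)]


def noisy_grad(params, rng, noise=0.1):
    """Stochastic gradient of the toy objective 0.5 * sum(p ** 2)."""
    return [p + rng.gauss(0.0, noise) for p in params]


def run_demo(num_workers=5, steps=400, lr=0.01, seed=0):
    """Minimise the toy quadratic with simulated workers."""
    rng = random.Random(seed)
    params = [1.0, -2.0, 3.0]
    for _ in range(steps):
        grads = [noisy_grad(params, rng) for _ in range(num_workers)]
        params = sign_sgd_step(params, grads, lr)
    return params


if __name__ == "__main__":
    print(run_demo())
```

Because only signs are exchanged, a single worker reporting a very large gradient magnitude cannot dominate the update; it contributes one vote per coordinate like every other worker.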

Related benchmarks

| Task | Dataset | Result | Rank |
|---|---|---|---|
| Image Classification | TinyImageNet (val) | Accuracy 78.885 | 289 |
| Language Modeling | WikiText-103 (val) | -- | 261 |
| Language Modeling | FineWeb (val) | Validation Loss 3.197 | 217 |
| Language Modeling | C4 LLaMA-130M (val) | Perplexity 19.693 | 40 |
| Language Modeling | (val) | Validation Loss 2.68 | 38 |
| Language Modeling | SlimPajama latest (val) | Validation Loss 3.145 | 26 |
| Stochastic Optimization Convergence | Theoretical Analysis | Convergence Rate Bound 4 | 23 |
| Next-Character Prediction | Multilingual NLI | Character-level Accuracy 59.71 | 16 |
| Language Modeling Pre-training | C4 (val) | -- | 14 |
| Language Modeling | LLaMA-350M pre-training (val) | Validation Loss 2.717 | 10 |
Showing 10 of 18 rows
