AdaGC: Improving Training Stability for Large Language Model Pretraining
About
Loss spikes remain a persistent obstacle in large-scale language model pretraining. While previous research has attempted to identify the root cause of loss spikes by investigating individual factors, we observe that, in practice, such spikes are typically triggered by the confluence of heterogeneous factors. Empirically, loss spikes may arise from a combination of data outliers, hardware or transient computational faults, numerical precision issues, and hyperparameter settings. Regardless of the underlying cause, these spikes manifest as unstable optimizer updates, as abnormal gradients contaminate both first- and second-moment states. In this paper, we propose a principled gradient-centric remedy: AdaGC, an adaptive per-tensor gradient clipping scheme that mitigates such contamination by bounding gradient norms relative to a tensor-wise exponential moving average of their historical clipped values. AdaGC is optimizer-agnostic, introduces negligible memory overhead, and reduces communication costs compared to GlobalGC, particularly in hybrid-parallel distributed training environments. Experiments on Llama-2 7B, Mixtral 8x1B, and ERNIE 10B-A1.4B demonstrate that AdaGC robustly eliminates training instabilities, consistently reducing spike scores to zero for all models and improving downstream accuracy over GlobalGC by 1.32%, 1.27%, and 2.48%, respectively. Furthermore, AdaGC seamlessly integrates with optimizers such as Muon and Lion, consistently yielding higher average accuracy and zero spike scores.
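The abstract's core mechanism — clipping each tensor's gradient against a threshold derived from an exponential moving average of its own historical clipped norms — can be sketched in a few lines. This is a minimal illustration, not the paper's implementation: the hyperparameter names (`beta`, `lam`), the warm-up choice for the first step, and the exact threshold rule `lam * ema` are assumptions for demonstration.

```python
import math

def adagc_step(grad, state, beta=0.98, lam=1.05, eps=1e-6):
    """Per-tensor adaptive gradient clipping (illustrative sketch of AdaGC).

    grad  : flattened gradient of one tensor, as a list of floats
    state : per-tensor dict holding the EMA of clipped gradient norms
    """
    norm = math.sqrt(sum(g * g for g in grad))
    if state.get("ema") is None:
        # First step: no history yet, pass the gradient through
        # (warm-up handling here is an assumption, not the paper's rule).
        clipped, clipped_norm = list(grad), norm
    else:
        # Bound the norm relative to the EMA of past *clipped* norms.
        threshold = lam * state["ema"]
        scale = min(1.0, threshold / (norm + eps))
        clipped = [g * scale for g in grad]
        clipped_norm = norm * scale
    # Update the EMA with the clipped norm, so a spike cannot
    # contaminate the statistic used for future thresholds.
    prev = state.get("ema")
    state["ema"] = clipped_norm if prev is None else beta * prev + (1 - beta) * clipped_norm
    return clipped
```

Because the threshold is maintained per tensor rather than over the global gradient, each parameter tensor needs only one extra scalar of state, which is consistent with the negligible memory overhead claimed above.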
Related benchmarks
| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Mathematical Reasoning | GSM8K (test) | Accuracy | 36.01 | 797 |
| Commonsense Reasoning | WinoGrande | Accuracy | 58.09 | 776 |
| Code Generation | HumanEval (test) | -- | -- | 444 |
| Boolean Question Answering | BoolQ | Accuracy | 58.93 | 307 |
| Question Answering | ARC-E | Accuracy | 49.58 | 242 |
| Multitask Language Understanding | MMLU | Accuracy | 23.62 | 206 |
| Language Understanding | MMLU (test) | Average Accuracy | 48.7 | 136 |
| Science Question Answering | SciQ | Normalized Accuracy | 76.6 | 44 |
| Physical Commonsense Reasoning | PIQA | Accuracy | 74.32 | 41 |
| Reasoning | BBH (test) | Accuracy | 31.38 | 40 |