
Scaling Law with Learning Rate Annealing

About

We find that the cross-entropy loss curves of neural language models empirically adhere to a scaling law with learning rate (LR) annealing over training steps: $$L(s) = L_0 + A\cdot S_1^{-\alpha} - C\cdot S_2,$$ where $L(s)$ is the validation loss at step $s$, $S_1$ is the area under the LR curve, $S_2$ is the LR annealing area, and $L_0$, $A$, $C$, $\alpha$ are constant parameters. This formulation accounts for two factors: (1) power-law scaling with data size, and (2) the additional loss reduction during LR annealing. It can therefore describe the full loss curve at every step, rather than only the single loss value at the end of training. By fitting the scaling law with LR annealing to only one or two training curves, we can accurately predict the loss at any given step under any learning rate scheduler (LRS). This approach substantially reduces the computational cost of formulating scaling laws while offering greater accuracy and expressiveness in describing training dynamics. Extensive experiments demonstrate that our findings hold across a range of hyper-parameters and model architectures, and that our equation extends to the scaling effect of model size. Moreover, our formulation provides accurate theoretical verification and explanation for empirical results observed in numerous previous studies, particularly those focusing on LR schedules and annealing. We believe this work can deepen the understanding of LLM training dynamics while greatly democratizing scaling laws, and that it can guide researchers in refining training strategies (e.g., critical LRS) for future LLMs.
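The scaling law can be evaluated step by step from any per-step LR schedule. The sketch below is a minimal illustration under stated assumptions: $S_1$ is taken as the cumulative sum of per-step learning rates, and $S_2$ is assumed to accumulate LR drops through a momentum-like term $m_i = \lambda m_{i-1} + (\eta_{i-1} - \eta_i)$ with a decay factor $\lambda$ (0.999 here). The abstract does not spell out $S_2$ at this level of detail, so treat that discretization, the function names `lr_areas` and `predict_loss`, and all parameter values in the usage lines as hypothetical. In practice $L_0$, $A$, $\alpha$, $C$ would be fitted to one or two observed loss curves.

```python
def lr_areas(lrs, decay=0.999):
    """Return per-step (S1, S2) lists for a per-step LR schedule `lrs`.

    S1: cumulative sum of learning rates (area under the LR curve).
    S2: ASSUMED annealing area, summing a momentum-like term
        m_i = decay * m_{i-1} + (lr_{i-1} - lr_i); stays zero while LR is constant.
    """
    s1_vals, s2_vals = [], []
    s1 = s2 = momentum = 0.0
    prev_lr = lrs[0]  # treat the first step's LR as the starting LR
    for lr in lrs:
        s1 += lr
        momentum = decay * momentum + (prev_lr - lr)
        s2 += momentum
        prev_lr = lr
        s1_vals.append(s1)
        s2_vals.append(s2)
    return s1_vals, s2_vals


def predict_loss(lrs, L0, A, alpha, C, decay=0.999):
    """Evaluate L(s) = L0 + A * S1**(-alpha) - C * S2 at every step.

    Assumes lrs[0] > 0 so that S1 > 0 and S1**(-alpha) is defined.
    """
    s1_vals, s2_vals = lr_areas(lrs, decay)
    return [L0 + A * s1 ** (-alpha) - C * s2 for s1, s2 in zip(s1_vals, s2_vals)]


if __name__ == "__main__":
    # Hypothetical WSD-style schedule: 800 constant steps, then 200 linear-decay steps.
    peak = 1e-3
    wsd = [peak] * 800 + [peak * (1 - i / 200) for i in range(1, 201)]
    # Hypothetical constants for illustration only (not fitted values).
    losses = predict_loss(wsd, L0=2.0, A=0.5, alpha=0.5, C=100.0)
    print(f"step {len(losses)}: predicted loss {losses[-1]:.4f}")
```

Under this assumed definition, a constant LR keeps $S_2$ at zero, so the prediction reduces to the power-law term alone; a decay phase trades slower growth of $S_1$ for a positive $S_2$.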

Howe Tissue, Venus Wang, Lu Wang · 2024

Related benchmarks

| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Loss Curve Prediction | Dense Model Loss Curve Prediction, WSD to Cosine transfer | MAPE | 0.202 | 9 |
| Loss Curve Prediction | Dense Model Loss Curve Prediction, Cosine to WSD transfer | MAPE | 0.404 | 9 |
| Loss curve fitting across batch sizes | Model loss data (train) | -- | -- | 7 |
| Loss prediction and hyper-parameter configuration ranking | LLM training loss grids, 5-fold configuration-level cross-validation (held-out) | R² | 0.973 | 6 |
| Pre-training extrapolation | Table 2 configurations, 300B single-phase | Mean Relative Error | 11.3 | 4 |
| Pre-training extrapolation | Table 2 configurations, 100B multi-phase | Mean Relative Error | 9.7 | 4 |
| Pre-training extrapolation | Table 2 configurations, 100B loss-equivalent | Mean Relative Error | 10.2 | 4 |
| Pre-training extrapolation | Table 2 configurations, all 10 | Mean Relative Error | 10.3 | 4 |
| Loss curve fitting across model sizes | Dense models | -- | -- | 3 |
| Loss curve fitting across model sizes | MoE models (various sizes) | -- | -- | 3 |
