Root Mean Square Layer Normalization

About

Layer normalization (LayerNorm) has been successfully applied to various deep neural networks to help stabilize training and boost model convergence, because it handles re-centering and re-scaling of both inputs and weight matrices. However, the computational overhead introduced by LayerNorm makes these improvements expensive and significantly slows the underlying network, particularly RNNs. In this paper, we hypothesize that re-centering invariance in LayerNorm is dispensable and propose root mean square layer normalization, or RMSNorm. RMSNorm regularizes the summed inputs to a neuron in one layer according to their root mean square (RMS), giving the model the re-scaling invariance property and an implicit learning rate adaptation ability. RMSNorm is computationally simpler and thus more efficient than LayerNorm. We also present partial RMSNorm, or pRMSNorm, where the RMS is estimated from p% of the summed inputs without breaking the above properties. Extensive experiments on several tasks using diverse network architectures show that RMSNorm achieves performance comparable to LayerNorm while reducing running time by 7% to 64% on different models. Source code is available at https://github.com/bzhangGo/rmsnorm.
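The abstract describes RMSNorm as dividing a layer's summed inputs a_1..a_n by their root mean square, RMS(a) = sqrt((1/n) * sum_i a_i^2), with no mean subtraction, and pRMSNorm as estimating that RMS from only p% of the inputs. The following is a minimal pure-Python sketch of that idea, not the authors' code (their implementation is in the linked repository). Assumptions made here for illustration: the function name `rms_norm`, the optional per-unit `gain`, the `eps` stability constant, and taking the pRMSNorm subset to be the first ceil(n * p) elements.

```python
import math
from typing import List, Optional, Sequence


def rms_norm(
    a: Sequence[float],
    gain: Optional[Sequence[float]] = None,
    p: float = 1.0,
    eps: float = 1e-8,
) -> List[float]:
    """Hypothetical sketch of RMSNorm / pRMSNorm over one vector of summed inputs.

    gain: optional per-unit scale (defaults to all ones).
    p:    fraction of inputs used to estimate the RMS; p = 1.0 is plain RMSNorm.
          Assumption: the subset is the first ceil(n * p) elements.
    eps:  assumed small constant for numerical stability (not from the abstract).
    """
    n = len(a)
    if n == 0:
        raise ValueError("a must be non-empty")
    if not 0.0 < p <= 1.0:
        raise ValueError("p must be in (0, 1]")
    g = list(gain) if gain is not None else [1.0] * n
    if len(g) != n:
        raise ValueError("gain must have the same length as a")

    # Number of inputs used for the RMS estimate (all of them when p == 1.0).
    k = math.ceil(n * p)
    # Root mean square of the (possibly partial) inputs. No mean is subtracted:
    # this is the re-centering step RMSNorm drops relative to LayerNorm.
    rms = math.sqrt(sum(x * x for x in a[:k]) / k + eps)
    # Every input (not just the first k) is divided by the same RMS, then scaled by the gain.
    return [gi * x / rms for gi, x in zip(g, a)]


if __name__ == "__main__":
    x = [1.0, 2.0, 3.0, 4.0]
    print(rms_norm(x))
    # Scaling all inputs by a positive constant leaves the output (almost) unchanged.
    print(rms_norm([3.0 * v for v in x]))
    # pRMSNorm-style: RMS estimated from the first half of the inputs only.
    print(rms_norm(x, p=0.5))
```

Multiplying every input by the same positive factor leaves the output unchanged up to `eps`, which is the re-scaling invariance the abstract refers to; adding a constant to every input does change the output, because RMSNorm skips LayerNorm's mean computation. That skipped mean is also where the computational saving comes from.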

Biao Zhang, Rico Sennrich · 2019

Related benchmarks

Task | Dataset | Metric | Result | Rank
Image Classification | ImageNet-1k 1.0 (test) | Top-1 Accuracy | 83 | 251
Image Generation | ImageNet-1k (val) | FID | 20.76 | 106
PDE solving | Darcy Flow 64x64 resolution | Relative L2 Error | 0.0731 | 27
PDE solving | Darcy Flow 32x32 resolution | Relative L2 Error | 0.0359 | 17
PDE solving | Darcy Flow 128x128 resolution | Relative L2 Error | 10.13 | 17
PDE solving | Darcy Flow 256x256 resolution | Relative L2 Error | 11.55 | 17
DNA classification | GenomicBenchmarks | Accuracy | 86.9 | 14
Question Answering and Reasoning | Downstream Reasoning Suite (Arc-e, PIQA, Hellaswag, OpenBookQA, Winogrande, MMLU, BoolQ) | ARC-e | 37.67 | 14
Language Modeling | Pretraining Dataset | Train Loss (PT) | 3.203 | 14
Speech pretraining | LibriSpeech (val) | Validation Loss | 1.93 | 14
Showing 10 of 18 rows
