EMO: Earth Mover Distance Optimization for Auto-Regressive Language Modeling

About

Neural language models are probabilistic models of human text. They are predominantly trained using maximum likelihood estimation (MLE), which is equivalent to minimizing the forward cross-entropy between the empirical data distribution and the model distribution. However, various degeneration phenomena are still widely observed when decoding from the distributions learned by such models. We establish that the forward cross-entropy is suboptimal as a distance metric for aligning the human and model distributions due to its (1) recall-prioritization, (2) negative diversity ignorance, and (3) train-test mismatch. In this paper, we propose Earth Mover Distance Optimization (EMO) for auto-regressive language modeling. EMO capitalizes on the inherent properties of earth mover distance to address the aforementioned challenges. Because direct computation is highly complex, we further introduce a feasible upper bound for EMO to ease end-to-end training. Upon extensive evaluation of language models trained using EMO and MLE, we find that EMO demonstrates consistently better language modeling performance than MLE across domains. Moreover, EMO yields notable improvements in downstream performance after minimal fine-tuning on merely 25,000 sentences. This highlights the tremendous potential of EMO as a lightweight calibration method for enhancing large-scale pre-trained language models.
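To make the "negative diversity ignorance" point concrete, here is a minimal sketch built entirely on assumptions: a hypothetical four-token vocabulary, invented 2-d token embeddings, and a cosine-distance transport cost (none of these come from the paper). Two model distributions give the gold token the same probability, so the forward cross-entropy cannot tell them apart, while an expected transport cost does, because it depends on where the non-gold mass goes. The cost is computed under the independent coupling Q × P, which is always a feasible transport plan and therefore upper-bounds the earth mover distance (and equals it when P is one-hot). This is one generic bound, not a claim about the exact bound EMO uses.

```python
import math

# Hypothetical toy vocabulary with invented 2-d embeddings (illustration only).
EMB = {
    "cat": (1.0, 0.0),
    "kitten": (0.9, 0.1),
    "car": (0.0, 1.0),
    "banana": (-1.0, 0.0),
}


def cosine_cost(a, b):
    """Assumed transport cost C(a, b) = 1 - cos(e_a, e_b), which lies in [0, 2]."""
    ea, eb = EMB[a], EMB[b]
    dot = sum(x * y for x, y in zip(ea, eb))
    norm = math.hypot(*ea) * math.hypot(*eb)
    return 1.0 - dot / norm


def forward_cross_entropy(q, gold):
    """Per-step MLE loss -log Q(gold); it only looks at the gold token's probability."""
    return -math.log(q[gold])


def expected_cost_under_independent_plan(q, p):
    """Sum over i, j of Q(i) * P(j) * C(i, j).

    The product coupling Q x P is a feasible transport plan, so this value is
    an upper bound on EMD(Q, P); with a one-hot P it is the only feasible plan,
    so the bound is tight.
    """
    return sum(q[i] * p[j] * cosine_cost(i, j) for i in q for j in p)


# One-hot data distribution: the gold next token is "cat".
p = {"cat": 1.0}

# Same probability on the gold token, different placement of the remaining mass.
q_near = {"cat": 0.5, "kitten": 0.5, "car": 0.0, "banana": 0.0}
q_far = {"cat": 0.5, "kitten": 0.0, "car": 0.0, "banana": 0.5}

if __name__ == "__main__":
    for name, q in [("near", q_near), ("far", q_far)]:
        print(name,
              round(forward_cross_entropy(q, "cat"), 4),
              round(expected_cost_under_independent_plan(q, p), 4))
```

Both distributions get the same cross-entropy (log 2), but the one that spreads its remaining mass onto the semantically close "kitten" gets a much lower transport cost than the one that puts it on "banana".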

Siyu Ren, Zhiyong Wu, Kenny Q. Zhu (2023)

Related benchmarks

| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Mathematical Reasoning | Mathematical Reasoning Suite (GSM8K, MATH, SVAMP, SimulEq, AQuA, SAT, MMLU) | Accuracy (Aggregate) | 68.5 | 40 |
| Language Modeling | WebText | MAUVE | 58 | 33 |
| Language Modeling | WikiText-2 | MAUVE | 0.85 | 33 |
| Language Modeling | WritingPrompts | MAUVE | 13 | 33 |
| Language Modeling | WikiText-103 | MAUVE | 85 | 18 |
| Numerical Reasoning | GSM8K (test) | MAE (Scale 1) | 1.13 | 6 |
| Mathematical Reasoning | GSM8K | Accuracy (1 sample) | 69.3 | 6 |
| Numerical Reasoning | GSM8K (test) | Accuracy (Error <= 1) | 69.3 | 6 |
