
Preserving Diversity in Supervised Fine-Tuning of Large Language Models

About

Large Language Models (LLMs) typically rely on Supervised Fine-Tuning (SFT) to specialize in downstream tasks, with the Cross Entropy (CE) loss being the de facto choice. However, CE maximizes the likelihood of observed data without accounting for alternative possibilities. As such, CE usually leads to reduced diversity in the model's outputs, which hinders downstream development that relies on sampling to explore better responses. To address this limitation, this paper introduces a new game-theoretic formulation for SFT. In this framework, an auxiliary variable is introduced to regulate the learning process. We prove that the proposed game-theoretic approach connects to the problem of reverse KL minimization with entropy regularization. This regularization prevents over-memorization of training data and promotes output diversity. To implement this framework, we develop GEM, a new training algorithm that is as computationally efficient as CE by leveraging some unique properties of LLMs. Empirical studies of pre-trained models from 3B to 70B parameters show that GEM achieves downstream performance comparable to CE while significantly enhancing output diversity. This increased diversity translates to performance gains in test-time compute scaling for chat and code generation tasks. Moreover, we observe that preserving output diversity has the added benefit of mitigating forgetting, as maintaining diverse outputs encourages models to retain pre-trained knowledge throughout the training process.
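The contrast the abstract draws between plain CE and an entropy-regularized objective can be illustrated with a minimal NumPy sketch. This is not the paper's actual GEM update, only the generic idea of adding an entropy bonus so the model is not pushed toward a one-hot output distribution; `beta` is a hypothetical regularization weight chosen for illustration.

```python
import numpy as np

def softmax(z):
    # Numerically stable softmax over the last axis.
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def ce_loss(logits, targets):
    # Standard cross-entropy: maximizes the likelihood of the observed
    # token only, with no incentive to keep probability mass elsewhere.
    p = softmax(logits)
    return -np.log(p[np.arange(len(targets)), targets]).mean()

def entropy_regularized_loss(logits, targets, beta=0.1):
    # Illustrative entropy-regularized variant: the entropy bonus rewards
    # spreading probability mass, discouraging over-memorization.
    p = softmax(logits)
    nll = -np.log(p[np.arange(len(targets)), targets]).mean()
    entropy = -(p * np.log(p + 1e-12)).sum(axis=-1).mean()
    return nll - beta * entropy
```

Because the entropy term is non-negative, the regularized loss is never larger than plain CE for the same logits; the trade-off between fitting the data and preserving diversity is controlled by `beta`.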

Ziniu Li, Congliang Chen, Tian Xu, Zeyu Qin, Jiancong Xiao, Zhi-Quan Luo, Ruoyu Sun • 2024

Related benchmarks

Task                                              | Dataset      | Result                | Rank
--------------------------------------------------|--------------|-----------------------|-----
Instruction Following                             | IFEval       | -                     | 625
Mathematical Reasoning                            | CollegeMATH  | Accuracy 48.7         | 276
Mathematical Reasoning                            | MATH 500     | pass@1 87             | 239
Mathematical Multimodal Reasoning                 | MathVerse    | Accuracy 43.12        | 221
Mathematical Multimodal Reasoning                 | MathVista    | Accuracy 72.7         | 218
Question Answering                                | TruthfulQA   | Accuracy 77.89        | 152
Massive Multi-discipline Multimodal Understanding | MMMU         | Accuracy 40.78        | 152
Mathematical Reasoning                            | AMC          | Accuracy (%) 44.69    | 134
Mathematical Reasoning                            | Minerva      | Pass@1 Accuracy 31.85 | 90
LLM Alignment Evaluation                          | AlpacaEval 2 | LC Win Rate 12.24     | 86

Showing 10 of 38 rows.
