
On the Inductive Bias of Stacking Towards Improving Reasoning

About

Given the increasing scale of model sizes, novel training strategies like gradual stacking [Gong et al., 2019, Reddi et al., 2023] have garnered interest. Stacking enables efficient training by gradually growing the depth of a model in stages, using layers from the smaller model of an earlier stage to initialize the next stage. Although efficient for training, the model biases induced by such growing approaches are largely unexplored. In this work, we examine this fundamental aspect of gradual stacking, going beyond its efficiency benefits. We propose a variant of gradual stacking called MIDAS that can speed up language model training by up to 40%. Furthermore, we discover an intriguing phenomenon: MIDAS is not only training-efficient but surprisingly also has an inductive bias towards improving downstream tasks, especially tasks that require reasoning abilities like reading comprehension and math problems, despite having similar or slightly worse perplexity compared to baseline training. To further analyze this inductive bias, we construct reasoning primitives -- simple synthetic tasks that are building blocks for reasoning -- and find that a model pretrained with stacking is significantly better than standard pretraining on these primitives, both with and without fine-tuning. This provides stronger and more robust evidence for this inductive bias towards reasoning. These findings on training efficiency and inductive bias towards reasoning are verified on 1B, 2B, and 8B parameter language models. Finally, we conjecture the underlying reason for this inductive bias by exploring the connection of stacking to looped models, and provide strong supporting empirical analysis.
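To make the stagewise growing concrete, here is a minimal, framework-free sketch of gradual stacking: a model is represented as a list of per-layer parameter dicts, and each stage grows depth by copying existing layers to initialize the new ones. This is an illustrative assumption of the general recipe, not the paper's implementation; in particular, which layers are duplicated (top block vs. middle block) is the design choice that distinguishes variants such as MIDAS, and the function name and schedule below are hypothetical.

```python
import copy

def grow_by_stacking(layers, num_new):
    """Grow model depth by duplicating the last `num_new` layers on top.

    `layers` is a list of per-layer parameter dicts (stand-ins for
    transformer blocks). Gradual stacking initializes the new, deeper
    model from the shallower one instead of from scratch, so each
    stage starts from already-trained weights.
    """
    assert 1 <= num_new <= len(layers)
    return layers + copy.deepcopy(layers[-num_new:])

# Hypothetical stagewise schedule: 6 -> 12 -> 24 layers across stages,
# with (omitted) training between growth steps.
model = [{"id": i} for i in range(6)]   # stage 1: train a 6-layer model
model = grow_by_stacking(model, 6)      # stage 2: 12 layers, top half copied
model = grow_by_stacking(model, 12)     # stage 3: 24 layers
print(len(model))  # 24
```

Because the copied layers repeat trained blocks, the grown model resembles a looped (weight-tied) network at initialization, which is the connection to looped models the abstract alludes to.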

Nikunj Saunshi, Stefani Karp, Shankar Krishnan, Sobhan Miryoosefi, Sashank J. Reddi, Sanjiv Kumar • 2024

Related benchmarks

| Task | Dataset | Metric | Result | Rank |
| --- | --- | --- | --- | --- |
| Question Answering | OpenBookQA | Accuracy | 40.2 | 465 |
| Language Modeling | LAMBADA | Accuracy | 50.81 | 183 |
| Common Sense Reasoning | HellaSwag | Accuracy | 46.19 | 164 |
| Math Reasoning | GSM8K (test) | Accuracy | 18.7 | 155 |
| Synthetic Reasoning | Reasoning Primitives | Accuracy | 55.46 | 16 |
| Mathematical Reasoning | Math Word | Accuracy | 24.01 | 16 |
| Closed-book Question Answering | Closed Book QA (TriviaQA, TydiQA, NaturalQuestions, WebQuestions) | Accuracy | 21.8 | 14 |
| General Downstream Evaluation | All Downstream Tasks (15 tasks) | Average Accuracy (All Tasks) | 36.4 | 14 |
| Math Word Problem Solving | Math Word Problems (6 tasks) | Accuracy | 43.1 | 14 |
| Language Modeling | UL2 Pre-training (val) | Loss | 1.844 | 14 |

Showing 10 of 14 rows.
