Fine-Tuning Language Models with Just Forward Passes

About

Fine-tuning language models (LMs) has yielded success on diverse downstream tasks, but as LMs grow in size, backpropagation requires a prohibitively large amount of memory. Zeroth-order (ZO) methods can in principle estimate gradients using only two forward passes but are theorized to be catastrophically slow for optimizing large models. In this work, we propose a memory-efficient zeroth-order optimizer (MeZO), adapting the classical ZO-SGD method to operate in-place, thereby fine-tuning LMs with the same memory footprint as inference. For example, with a single A100 80GB GPU, MeZO can train a 30-billion parameter model, whereas fine-tuning with backpropagation can train only a 2.7B LM with the same budget. We conduct comprehensive experiments across model types (masked and autoregressive LMs), model scales (up to 66B), and downstream tasks (classification, multiple-choice, and generation). Our results demonstrate that (1) MeZO significantly outperforms in-context learning and linear probing; (2) MeZO achieves comparable performance to fine-tuning with backpropagation across multiple tasks, with up to 12x memory reduction and up to 2x GPU-hour reduction in our implementation; (3) MeZO is compatible with both full-parameter and parameter-efficient tuning techniques such as LoRA and prefix tuning; (4) MeZO can effectively optimize non-differentiable objectives (e.g., maximizing accuracy or F1). We support our empirical findings with theoretical insights, highlighting how adequate pre-training and task prompts enable MeZO to fine-tune huge models, despite classical ZO analyses suggesting otherwise.

Sadhika Malladi, Tianyu Gao, Eshaan Nichani, Alex Damian, Jason D. Lee, Danqi Chen, Sanjeev Arora • 2023
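
The central idea in the abstract, estimating a gradient from two forward passes and applying the update in place, can be illustrated with a small sketch. This is not the authors' implementation: the toy quadratic `loss`, the helper names `perturb`, `zo_step`, and `train`, and the hyperparameter values are all assumptions chosen for illustration. The sketch assumes that the random direction can be regenerated from a stored seed, so no copy of the parameters or of the direction needs to be kept.

```python
import random


def loss(theta):
    # Toy objective standing in for one forward pass (an assumption, not a real LM loss).
    target = [1.0, -2.0, 0.5]
    return sum((t - g) ** 2 for t, g in zip(theta, target))


def perturb(theta, seed, scale):
    # Regenerate the same Gaussian direction z from `seed` and add scale * z in place.
    rng = random.Random(seed)
    for i in range(len(theta)):
        theta[i] += scale * rng.gauss(0.0, 1.0)


def zo_step(theta, loss_fn, lr, eps, seed):
    # Two forward passes at theta + eps*z and theta - eps*z.
    perturb(theta, seed, +eps)
    loss_plus = loss_fn(theta)
    perturb(theta, seed, -2.0 * eps)
    loss_minus = loss_fn(theta)
    perturb(theta, seed, +eps)  # restore theta (up to floating-point rounding)

    # Scalar estimate of the directional derivative along z.
    projected_grad = (loss_plus - loss_minus) / (2.0 * eps)

    # SGD-style update theta -= lr * projected_grad * z, reusing the seed for z.
    perturb(theta, seed, -lr * projected_grad)
    return projected_grad


def train(steps=2000, lr=0.01, eps=1e-3):
    theta = [0.0, 0.0, 0.0]
    for step in range(steps):
        zo_step(theta, loss, lr, eps, seed=step)  # a fresh direction per step
    return theta


if __name__ == "__main__":
    theta = train()
    print("final loss:", loss(theta))
```

Because `loss` is only ever called, never differentiated, the same loop would in principle accept a non-differentiable score such as negative accuracy, which is the property point (4) of the abstract refers to.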

Related benchmarks

| Task | Dataset | Result | Rank |
|---|---|---|---|
| Language Modeling | WikiText-2 | -- | 2320 |
| Physical Commonsense Reasoning | PIQA | Accuracy 84.3 | 696 |
| Natural Language Inference | SNLI (test) | Accuracy 81 | 694 |
| Natural Language Inference | RTE | Accuracy 73.9 | 590 |
| Question Answering | BoolQ | Accuracy 76.6 | 317 |
| Image Classification | CIFAR-100 | Accuracy 64.5 | 302 |
| Word Sense Disambiguation | WiC | Avg Accuracy 60.92 | 261 |
| Common Sense Reasoning | COPA | Accuracy 92 | 256 |
| Sentiment Classification | SST2 (test) | Accuracy 79 | 233 |
| Question Answering | BoolQ | Accuracy 63 | 201 |

Showing 10 of 119 rows
