
The Mamba in the Llama: Distilling and Accelerating Hybrid Models

About

Linear RNN architectures, such as Mamba, can be competitive with Transformer models in language modeling while having advantageous deployment characteristics. Given the focus on training large-scale Transformer models, we consider the challenge of converting these pretrained models for deployment. We demonstrate that, with academic GPU resources, it is feasible to distill large Transformers into linear RNNs by reusing the linear projection weights from their attention layers. The resulting hybrid model, which retains a quarter of the original attention layers, achieves performance comparable to the original Transformer on chat benchmarks. It also outperforms open-source hybrid Mamba models trained from scratch on trillions of tokens, on both chat and general benchmarks. Moreover, we introduce a hardware-aware speculative decoding algorithm that accelerates inference for Mamba and hybrid models. Overall, we show how, with limited computational resources, we can remove many of the original attention layers and generate from the resulting model more efficiently. Our top-performing model, distilled from Llama3-8B-Instruct, achieves a length-controlled win rate of 29.61 against GPT-4 on AlpacaEval 2 and a score of 7.35 on MT-Bench, surpassing the best instruction-tuned linear RNN model at the 8B scale. We also find that the distilled model has natural length extrapolation, showing almost perfect accuracy in the needle-in-a-haystack test at 20x the distillation length. Code and pre-trained checkpoints are open-sourced at https://github.com/jxiw/MambaInLlama and https://github.com/itsdaniele/speculative_mamba.
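Why attention weights can seed a linear RNN is easiest to see once the softmax is dropped. The Python sketch below assumes plain causal linear attention (no softmax, gating, normalization, or Mamba's input-dependent state decay), so it is a simplification rather than the paper's Mamba layer or distillation procedure. It evaluates the same projections W_Q, W_K, W_V two ways: as masked attention over all previous tokens, and as a recurrence over a fixed-size state. The function names (`attention_form`, `recurrent_form`, `random_matrix`) and the random stand-in weights are hypothetical and illustrative, not taken from the MambaInLlama code.

```python
"""Causal linear attention evaluated in parallel and recurrent form.

Simplifying assumption: softmax, gating, normalization and Mamba's
input-dependent state decay are all omitted; this is not the MambaInLlama code.
"""
import random
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def matvec(w: Matrix, x: Vector) -> Vector:
    """Return W x for a row-major (d_out x d_in) matrix W."""
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]


def dot(a: Vector, b: Vector) -> float:
    return sum(ai * bi for ai, bi in zip(a, b))


def attention_form(xs: List[Vector], w_q: Matrix, w_k: Matrix, w_v: Matrix) -> List[Vector]:
    """Masked attention without softmax: y_t = sum_{s<=t} (q_t . k_s) v_s."""
    qs = [matvec(w_q, x) for x in xs]
    ks = [matvec(w_k, x) for x in xs]
    vs = [matvec(w_v, x) for x in xs]
    outputs = []
    for t in range(len(xs)):
        y = [0.0] * len(w_v)
        for s in range(t + 1):  # causal mask: only positions s <= t contribute
            score = dot(qs[t], ks[s])
            y = [yj + score * vj for yj, vj in zip(y, vs[s])]
        outputs.append(y)
    return outputs


def recurrent_form(xs: List[Vector], w_q: Matrix, w_k: Matrix, w_v: Matrix) -> List[Vector]:
    """Same outputs via a linear RNN: H_t = H_{t-1} + k_t v_t^T, y_t = H_t^T q_t."""
    d_k, d_v = len(w_k), len(w_v)
    h = [[0.0] * d_v for _ in range(d_k)]  # fixed-size state, independent of t
    outputs = []
    for x in xs:
        q, k, v = matvec(w_q, x), matvec(w_k, x), matvec(w_v, x)
        for i in range(d_k):
            for j in range(d_v):
                h[i][j] += k[i] * v[j]  # accumulate the outer product k_t v_t^T
        outputs.append([sum(q[i] * h[i][j] for i in range(d_k)) for j in range(d_v)])
    return outputs


def random_matrix(rows: int, cols: int, rng: random.Random) -> Matrix:
    """Hypothetical helper: Gaussian stand-in weights for the demo."""
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


if __name__ == "__main__":
    rng = random.Random(0)
    d_model, d_k, d_v, seq_len = 4, 3, 2, 5
    # Hypothetical stand-ins for a pretrained attention layer's projections.
    w_q = random_matrix(d_k, d_model, rng)
    w_k = random_matrix(d_k, d_model, rng)
    w_v = random_matrix(d_v, d_model, rng)
    xs = random_matrix(seq_len, d_model, rng)
    par = attention_form(xs, w_q, w_k, w_v)
    rec = recurrent_form(xs, w_q, w_k, w_v)
    max_diff = max(abs(a - b) for ya, yb in zip(par, rec) for a, b in zip(ya, yb))
    print(f"max |parallel - recurrent| = {max_diff:.2e}")
```

Both functions use identical projection weights and agree up to floating-point rounding. The recurrent version carries only the d_k × d_v state between tokens, so the per-token cost of generation does not grow with context length, which is the deployment advantage attributed to linear RNNs above. Mamba additionally applies an input-dependent decay to that state, which this sketch leaves out.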

Junxiong Wang, Daniele Paliotta, Avner May, Alexander M. Rush, Tri Dao • 2024

Related benchmarks

| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Question Answering | SQuAD 2.0 | F1 | 24.1098 | 190 |
| Long-context language modeling | RULER | RULER Score | 0.8031 | 148 |
| Structured Web Data Extraction | SWDE | Performance | 87.4 | 120 |
| Long-context language modeling evaluation | FDA (test) | Score | 0.7178 | 120 |
| Chat | AlpacaEval 2.0 (test) | AlpacaEval (LC win %) | 29.61 | 46 |
| Language Modeling | LM Evaluation Harness (LM Eval) (test) | WG (Winograd Schema) | 74.11 | 22 |
| Chat | MT-Bench 1.0 (test) | MT-Bench Score | 7.7 | 19 |
| General Language Understanding | Open LLM Leaderboard (test) | ARC | 60.41 | 9 |
| Mathematical and Code Reasoning | ZeroEval (test) | GSM8K Accuracy | 67.85 | 8 |
