Sparse Logit Sampling: Accelerating Knowledge Distillation in LLMs
About
Knowledge distillation can be a cost-effective technique for transferring knowledge from Large Language Models, provided the teacher's output logits can be pre-computed and cached. However, successfully applying this to pre-training remains largely unexplored. In this work, we prove that naive approaches to sparse knowledge distillation, such as caching Top-K probabilities, while intuitive, provide biased estimates of the teacher's probability distribution to the student, resulting in suboptimal performance and calibration. We propose an importance-sampling-based method, "Random Sampling Knowledge Distillation", which provides unbiased estimates, preserves the gradient in expectation, and requires storing significantly sparser logits. Our method enables faster training of student models with marginal overhead (<10%) compared to cross-entropy-based training, while maintaining competitive performance compared to full distillation, across a range of model sizes from 300M to 3B.
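The core idea can be illustrated with a short sketch. The KD loss at one position is -Σᵢ pᵢ log qᵢ, which equals -E_{i~p}[log qᵢ]; sampling token ids from the teacher distribution p and averaging -log qᵢ therefore gives an unbiased Monte Carlo estimate while storing only k ids per position. This is a minimal illustration of sampling-based sparse distillation, not the paper's exact implementation; the function names and the choice of k are hypothetical.

```python
import torch

def sample_teacher_ids(teacher_logits: torch.Tensor, k: int) -> torch.Tensor:
    """Sample k token ids per position from the teacher distribution
    (with replacement). Only these k ids need to be cached, instead of
    the full vocabulary-sized logit vector."""
    p = torch.softmax(teacher_logits, dim=-1)
    return torch.multinomial(p, k, replacement=True)

def sparse_kd_loss(student_logits: torch.Tensor, sampled_ids: torch.Tensor) -> torch.Tensor:
    """Monte Carlo estimate of -E_{i~p}[log q_i]: average the student's
    negative log-probabilities over the teacher-sampled ids. Because the
    ids were drawn from p, this estimate is unbiased in expectation."""
    logq = torch.log_softmax(student_logits, dim=-1)
    return -logq.gather(-1, sampled_ids).mean()
```

A Top-K cache, by contrast, always keeps the same highest-probability ids, which biases the estimate toward the head of the teacher distribution.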
Related benchmarks
| Task | Dataset | Result | Rank |
|---|---|---|---|
| Language Modeling | LAMBADA (test) | -- | 71 |
| Instruction Following | SelfInst | -- | 50 |
| Instruction Following | IFEval (test) | IFEval Score: 20.9 | 45 |
| Instruction Following | Dolly | Score: 71.3 | 18 |
| Instruction Following | Vicuna | Score: 58.2 | 18 |
| General Knowledge Evaluation | General-purpose benchmarks average (test) | Accuracy: 64.7 | 12 |
| Language Modeling | Fineweb-edu distillation 8B to 300M | LM Loss: 2.74 | 7 |
| Speculative Decoding | Fineweb-edu distillation 8B to 300M | Spec. Accept: 62% | 7 |
| Instruction Following | Instruction Following SFT 1.0 (eval) | SFT Score: 59.4 | 6 |
| Language Modeling | Fineweb-edu 1.0 (test) | LM Loss: 2.32 | 6 |