Quest: Query-Aware Sparsity for Efficient Long-Context LLM Inference
About
As the demand for long-context large language models (LLMs) increases, models with context windows of up to 128K or 1M tokens are becoming increasingly prevalent. However, long-context LLM inference is challenging since inference speed decreases significantly as the sequence length grows. This slowdown is primarily caused by loading a large KV cache during self-attention. Previous works have shown that a small portion of critical tokens dominates the attention outcomes. However, we observe that the criticality of a token depends heavily on the query. To this end, we propose Quest, a query-aware KV cache selection algorithm. Quest keeps track of the minimal and maximal Key values in KV cache pages and estimates the criticality of a given page using the Query vectors. By loading only the Top-K critical KV cache pages for attention, Quest significantly speeds up self-attention without sacrificing accuracy. We show that Quest achieves up to 2.23x self-attention speedup, which reduces inference latency by 7.03x, with negligible accuracy loss on tasks with long dependencies. Code is available at http://github.com/mit-han-lab/Quest.
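The page-selection idea above can be sketched in a few lines. This is an illustrative NumPy sketch, not the Quest implementation: the function names are mine, and it ignores details such as batching, heads, and the pages the real system always keeps dense. The key property is that the per-page score, built from the channel-wise min/max Key metadata, upper-bounds the attention logit any token in that page can achieve for the given query.

```python
import numpy as np

def page_bounds(page_keys):
    """Per-channel min and max of the Key vectors in one KV-cache page."""
    return page_keys.min(axis=0), page_keys.max(axis=0)

def criticality(query, key_min, key_max):
    """Upper bound on the attention logit of any key in the page:
    per channel, q*k lies between q*key_min and q*key_max (order flips
    when q is negative), so summing the channel-wise max bounds q . k."""
    return np.maximum(query * key_min, query * key_max).sum()

def topk_pages(query, pages, k):
    """Indices of the k pages with the highest estimated criticality."""
    scores = np.array([criticality(query, *page_bounds(p)) for p in pages])
    return np.argsort(scores)[::-1][:k]
```

At decode time, attention would then be computed only over the Keys and Values of the selected pages; because the score is a true upper bound, a page containing the highest-scoring token can never be ranked below k pages with strictly smaller bounds.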
Related benchmarks
| Task | Dataset | Metric | Value | Rank |
|---|---|---|---|---|
| Mathematical Reasoning | AIME 2024 | Accuracy | 71.5 | 251 |
| Long-context Language Understanding | LongBench | M-Avg | 48.6 | 219 |
| Long-context Understanding | LongBench | Overall Average Score | 50 | 115 |
| Long-context Understanding | LongBench (test) | Avg Score | 46.92 | 80 |
| Long-context Language Understanding | InfiniteBench | En.Sum | 28.86 | 63 |
| End-to-end throughput | LLaMA-2-7B-Chat | Throughput (tokens/sec) | 410 | 60 |
| Attention Operator Latency | LLaMA-2 Chat 7B | Attention Latency (ms) | 0.2 | 60 |
| Long-context language modeling | LongBench-E 1.0 (test) | S-Doc QA Perf. | 30.76 | 37 |
| Long-context Understanding | LongBench v2 | Overall Score | 28 | 37 |
| Long-context language modeling | RULER | Accuracy (8K Context) | 88.81 | 34 |