LUT-GEMM: Quantized Matrix Multiplication based on LUTs for Efficient Inference in Large-Scale Generative Language Models
About
Recent advances in self-supervised learning and the Transformer architecture have significantly improved natural language processing (NLP), achieving remarkably low perplexity. However, the growing size of NLP models introduces a memory wall problem during the generation phase. To mitigate this issue, recent efforts have focused on quantizing model weights to sub-4-bit precision while preserving full precision for activations, resulting in practical speed-ups during inference on a single GPU. However, these improvements primarily stem from reduced memory movement, which necessitates a resource-intensive dequantization process rather than an actual reduction in computation. In this paper, we introduce LUT-GEMM, an efficient kernel for quantized matrix multiplication, which not only eliminates the resource-intensive dequantization process but also reduces computational costs compared to previous kernels for weight-only quantization. Furthermore, we propose group-wise quantization to offer a flexible trade-off between compression ratio and accuracy. The impact of LUT-GEMM comes from the high compression ratios attained through low-bit quantization combined with efficient LUT-based operations. We show experimentally that when applied to the OPT-175B model with 3-bit quantization, LUT-GEMM substantially accelerates token generation latency, achieving a remarkable 2.1$\times$ improvement on a single GPU compared to OPTQ, which relies on the costly dequantization process.
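To make the LUT idea concrete, below is a minimal NumPy sketch of lookup-table-based matrix-vector multiplication for binary-coding quantized weights, where the weight matrix is approximated as a sum of scaled binary matrices, W ≈ Σ_b α_b B_b with B_b ∈ {−1, +1}. For each length-`MU` sub-vector of the activation, all 2^MU possible signed partial sums are precomputed once, so binary weight patterns are resolved by table lookups instead of multiplications. The names `MU`, `build_luts`, and `lut_gemv` are illustrative assumptions, not the paper's actual CUDA kernel.

```python
import numpy as np

MU = 8  # sub-vector length; each LUT holds 2**MU precomputed partial sums

def build_luts(x):
    """For each length-MU sub-vector of activation x, precompute the dot
    product with every sign pattern b in {-1, +1}^MU (illustrative only)."""
    n = len(x) // MU
    luts = np.empty((n, 2 ** MU))
    for i in range(n):
        sub = x[i * MU:(i + 1) * MU]
        for key in range(2 ** MU):
            # bit j of `key` set -> +1, clear -> -1
            signs = np.array([1 if (key >> j) & 1 else -1 for j in range(MU)])
            luts[i, key] = signs @ sub
    return luts

def lut_gemv(luts, keys, alphas):
    """Compute y = sum_b alpha_b * (B_b @ x) using only table lookups.
    keys[b, r, i] encodes the sign pattern of bit-plane b, output row r,
    sub-vector i as an integer index into luts."""
    bits, rows, n_sub = keys.shape
    out = np.zeros(rows)
    for b in range(bits):
        for r in range(rows):
            acc = 0.0
            for i in range(n_sub):
                acc += luts[i, keys[b, r, i]]  # lookup replaces MU multiplies
            out[r] += alphas[b] * acc
    return out
```

Since the LUTs depend only on the activation vector, they are built once per input and reused across all output rows and bit-planes, which is where the computational savings over dequantize-then-multiply come from.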
Related benchmarks
| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Language Modeling | WikiText-2 (test) | PPL | 4.01 | 1541 |
| Language Modeling | WikiText-2 | Perplexity (PPL) | 4.01 | 841 |
| Language Modeling | WikiText2 v1 (test) | Perplexity | 5.06 | 341 |
| Inference Latency | OPT model family | Latency (ms) | 6.2 | 79 |
| Inference Latency | A100 GPU | Latency (ms) | 27.4 | 48 |
| Matrix Multiplication Latency | Llama-3-8B | Kernel-level latency (µs) | 160.1 | 8 |
| Matrix Multiplication Latency | Llama-3 70B | Kernel Latency (µs) | 299.9 | 8 |
| Language Understanding | Standard Downstream Tasks (ARC, COPA, BoolQ, PIQA, StoryCloze, RTE, MMLU) | ARC (Challenge) | 47.7 | 8 |
| Energy Consumption Estimation | OPT-125M | Energy Consumption (J) | 12.48 | 8 |
| Energy Consumption Estimation | OPT-350M | Energy (J) | 23.96 | 8 |