Attn-QAT: 4-Bit Attention With Quantization-Aware Training
About
Achieving reliable 4-bit attention is a prerequisite for end-to-end FP4 computation on emerging FP4-capable GPUs, yet attention remains the main obstacle due to FP4's tiny dynamic range and attention's heavy-tailed activations. This paper presents the first systematic study of 4-bit quantization-aware training (QAT) for attention. We find that "drop-in" QAT, which naively combines an FP4 forward pass with a high-precision Flash Attention (FA)-style backward pass, leads to training instability. We identify two key principles for stable FP4 attention: (1) the backward pass must recompute attention scores in matching low precision, and (2) implicit precision assumptions in FA's gradient calculation must be resolved. Based on these insights, we propose Attn-QAT and implement fused Triton kernels for training as well as FP4 inference kernels. Across diffusion and language models, Attn-QAT recovers the quality lost to FP4 attention without the explicit outlier-mitigation heuristics used in prior FP4 attention methods, and delivers up to a 1.5x speedup on an RTX 5090. Video demos can be found at https://drive.google.com/drive/folders/190F6xbBDUF2kGQYIcXBt3ehSYij5jlim?usp=sharing.
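To make principle (1) concrete, below is a minimal PyTorch sketch of fake-quantized attention whose backward pass recomputes the scores from the same quantized tensors the forward used, rather than re-deriving them from full-precision inputs. The E2M1 level grid, per-tensor absmax scaling, straight-through estimator, and the names `fake_quant_fp4` and `QATAttention` are illustrative assumptions for this sketch; the paper's actual Attn-QAT uses fused Triton kernels and its exact quantization scheme is not reproduced here.

```python
import math

import torch

# Representable non-negative magnitudes of the E2M1 FP4 format. Treating
# quantization as per-tensor absmax scaling onto this grid is an assumption
# made for illustration, not the paper's scheme.
_FP4_LEVELS = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])


def fake_quant_fp4(x: torch.Tensor) -> torch.Tensor:
    """Fake-quantize x onto an E2M1-style FP4 grid (per-tensor absmax scale)."""
    levels = _FP4_LEVELS.to(device=x.device, dtype=x.dtype)
    scale = x.abs().amax().clamp(min=1e-12) / levels[-1]
    mags = (x / scale).abs()
    # Nearest representable level per element (dense comparison: fine for a
    # sketch, too memory-hungry for a production kernel).
    idx = (mags.unsqueeze(-1) - levels).abs().argmin(dim=-1)
    return torch.sign(x) * levels[idx] * scale


class QATAttention(torch.autograd.Function):
    """Fake-quantized attention whose backward recomputes scores from the same
    quantized tensors the forward used, mirroring principle (1) above.
    No causal mask or dropout, for brevity."""

    @staticmethod
    def forward(ctx, q, k, v):
        qq, kq, vq = fake_quant_fp4(q), fake_quant_fp4(k), fake_quant_fp4(v)
        s = qq @ kq.transpose(-1, -2) / math.sqrt(q.shape[-1])
        p = torch.softmax(s, dim=-1)
        # Save the *quantized* inputs: the backward rebuilds exactly the
        # scores the forward saw, instead of recomputing them from the
        # full-precision q/k and silently assuming high precision.
        ctx.save_for_backward(qq, kq, vq)
        return p @ vq

    @staticmethod
    def backward(ctx, do):
        qq, kq, vq = ctx.saved_tensors
        d = qq.shape[-1]
        # Matched low-precision recomputation of the attention scores.
        s = qq @ kq.transpose(-1, -2) / math.sqrt(d)
        p = torch.softmax(s, dim=-1)
        dv = p.transpose(-1, -2) @ do
        dp = do @ vq.transpose(-1, -2)
        # Standard softmax backward: dS = P * (dP - rowsum(dP * P)).
        ds = p * (dp - (dp * p).sum(dim=-1, keepdim=True))
        dq = ds @ kq / math.sqrt(d)
        dk = ds.transpose(-1, -2) @ qq / math.sqrt(d)
        # Straight-through estimator: gradients pass through the quantizers.
        return dq, dk, dv


# Quick smoke test on random (batch, heads, seq, dim) tensors.
q, k, v = (torch.randn(2, 4, 128, 64, requires_grad=True) for _ in range(3))
out = QATAttention.apply(q, k, v)
out.sum().backward()
print(out.shape, q.grad.shape)
```

By contrast, a "drop-in" QAT baseline would save the full-precision q and k and recompute softmax(qkᵀ/√d) in high precision during the backward pass, creating exactly the forward/backward mismatch the abstract identifies as a source of training instability.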
Related benchmarks
| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Mathematical Reasoning | GSM8K | Accuracy | 92.95 | 1362 |
| Commonsense Reasoning | WinoGrande | Accuracy | 79.4 | 1085 |
| Language Understanding | MMLU | Accuracy | 80.44 | 825 |
| Language Modeling | WikiText | PPL | 0.3076 | 732 |
| Instruction Following | IFEval | IFEval Accuracy | 86.37 | 625 |
| Science Question Answering | ARC-C | Accuracy | 61.53 | 193 |
| Commonsense Inference | HellaSwag | Accuracy | 85.57 | 91 |
| Physical Commonsense Reasoning | PIQA | Accuracy | 83.51 | 78 |
| Graduate-Level Reasoning | GPQA Diamond | -- | -- | 24 |