Primal-Attention: Self-attention through Asymmetric Kernel SVD in Primal Representation
About
Recently, a new line of work has emerged to understand and improve self-attention in Transformers by treating it as a kernel machine. However, existing works apply methods designed for symmetric kernels to the asymmetric self-attention, resulting in a nontrivial gap between the analytical understanding and the numerical implementation. In this paper, we provide a new perspective to represent and optimize self-attention through asymmetric Kernel Singular Value Decomposition (KSVD), which is also motivated by the low-rank property of self-attention commonly observed in deep layers. Through asymmetric KSVD, $i$) a primal-dual representation of self-attention is formulated, where the optimization objective is cast to maximize the projection variances in the attention outputs; $ii$) a novel attention mechanism, i.e., Primal-Attention, is proposed via the primal representation of KSVD, avoiding explicit computation of the kernel matrix in the dual; $iii$) with KKT conditions, we prove that the stationary solution to the KSVD optimization in Primal-Attention yields a zero-value objective. In this manner, KSVD optimization can be implemented by simply minimizing a regularization loss, so that the low-rank property is promoted without extra decomposition. Numerical experiments show state-of-the-art performance of our Primal-Attention with improved efficiency. Moreover, we demonstrate that the deployed KSVD optimization regularizes Primal-Attention with a sharper singular value decay than that of the canonical self-attention, further verifying the great potential of our method. To the best of our knowledge, this is the first work that provides a primal-dual representation for the asymmetric kernel in self-attention and successfully applies it to modeling and optimization.
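The two properties the abstract builds on, the asymmetry of the attention matrix and its singular value decay, can be checked numerically. The following is a minimal NumPy sketch under stated assumptions: it is not the Primal-Attention mechanism or the authors' code. The helper names (`softmax_attention`, `truncated_svd`), the random inputs, and the dimensions are hypothetical, chosen only to illustrate the canonical softmax attention matrix and an ordinary SVD of it.

```python
import numpy as np

def softmax_attention(X, W_q, W_k):
    """Canonical softmax attention matrix for one head (illustrative helper)."""
    Q = X @ W_q                      # queries, shape (n, d_head)
    K = X @ W_k                      # keys, shape (n, d_head)
    scores = Q @ K.T / np.sqrt(Q.shape[1])
    scores = scores - scores.max(axis=1, keepdims=True)  # numerical stability
    A = np.exp(scores)
    return A / A.sum(axis=1, keepdims=True)              # each row sums to 1

def truncated_svd(A, rank):
    """Best rank-`rank` approximation of A in Frobenius norm (illustrative helper)."""
    U, s, Vt = np.linalg.svd(A, full_matrices=False)
    return (U[:, :rank] * s[:rank]) @ Vt[:rank]

# Hypothetical toy sizes and random data, not taken from the paper.
rng = np.random.default_rng(0)
n, d_model, d_head = 32, 16, 8
X = rng.standard_normal((n, d_model))
W_q = rng.standard_normal((d_model, d_head)) / np.sqrt(d_model)
W_k = rng.standard_normal((d_model, d_head)) / np.sqrt(d_model)

A = softmax_attention(X, W_q, W_k)

# Queries and keys use different projections, so A is generally not symmetric.
is_symmetric = np.allclose(A, A.T)

# Singular values are returned in descending order; their decay indicates
# how well a low-rank matrix can approximate A.
singular_values = np.linalg.svd(A, compute_uv=False)

# Frobenius-norm error of the truncated SVD at increasing ranks.
ranks = (1, 4, 16, 32)
errors = [np.linalg.norm(A - truncated_svd(A, r)) for r in ranks]

print("symmetric:", is_symmetric)
print("leading singular values:", np.round(singular_values[:5], 4))
print("approximation errors by rank:", dict(zip(ranks, np.round(errors, 4))))
```

Because the matrix is asymmetric, its left and right singular vectors differ, which is why an SVD rather than a symmetric eigendecomposition is the natural tool here. Note that this sketch builds the full n-by-n matrix explicitly; the abstract states that Primal-Attention avoids that computation.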
Related benchmarks
| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Language Modeling | WikiText-103 (test) | Perplexity | 31 | 524 |
| Long-range sequence modeling | Long Range Arena (LRA) (test) | Accuracy (Avg) | 60.4 | 158 |
| Time-series classification | UEA time series classification archive (test) | Average Accuracy | 73.1 | 27 |
| Offline Reinforcement Learning | D4RL | Return (HalfCheetah, Medium-Expert) | 77.8 | 7 |