
Self-Attention through Kernel-Eigen Pair Sparse Variational Gaussian Processes

About

While the strong capability of Transformers significantly boosts prediction accuracy, it can also yield overconfident predictions, which calls for calibrated uncertainty estimation, a problem commonly tackled with Gaussian processes (GPs). Existing works apply GPs with symmetric kernels under variational inference to the attention kernel; however, they overlook the fact that attention kernels are inherently asymmetric. Moreover, the complexity of deriving GP posteriors remains high for large-scale data. In this work, we propose Kernel-Eigen Pair Sparse Variational Gaussian Processes (KEP-SVGP) for building uncertainty-aware self-attention, where the asymmetry of attention kernels is handled by Kernel SVD (KSVD) and the complexity is reduced. Through KEP-SVGP, i) the SVGP pair induced by the two sets of singular vectors from KSVD w.r.t. the attention kernel fully characterizes the asymmetry; ii) using only a small set of adjoint eigenfunctions from KSVD, the derivation of SVGP posteriors can be based on the inversion of a diagonal matrix containing singular values, which reduces the time complexity; iii) an evidence lower bound is derived so that variational parameters and network weights can be optimized with it. Experiments verify the performance and efficiency of KEP-SVGP on in-distribution, distribution-shift, and out-of-distribution benchmarks.
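To make the asymmetry and the diagonal-inversion argument concrete, here is a minimal NumPy sketch. It is not the authors' KEP-SVGP implementation: a plain SVD of a finite attention-kernel matrix stands in for KSVD, and the helper names `attention_kernel`, `truncated_svd_pair`, and `diag_svgp_marginals` are hypothetical. The SVGP readout assumes (as a simplification of our own) a diagonal inducing covariance `K_uu = diag(sing)`, a cross-covariance `features @ diag(sing)`, a diagonal variational covariance, and drops the residual prior term; inputs, weights, and variational parameters are random placeholders.

```python
import numpy as np


def attention_kernel(X, W_q, W_k):
    """Unnormalized attention kernel kappa(x_i, x_j) = exp(q_i . k_j / sqrt(d)).

    Queries and keys use different projections, so the resulting matrix is
    generally not symmetric: kappa(x_i, x_j) != kappa(x_j, x_i).
    """
    Q = X @ W_q
    K = X @ W_k
    d = Q.shape[1]
    return np.exp(Q @ K.T / np.sqrt(d))


def truncated_svd_pair(K_att, s):
    """Top-s singular triplets of the kernel matrix (a stand-in for KSVD).

    Returns left singular vectors (row / query side), singular values in
    descending order, and right singular vectors (column / key side).
    """
    U, sing, Vt = np.linalg.svd(K_att)
    return U[:, :s], sing[:s], Vt[:s].T


def diag_svgp_marginals(features, sing, m_u, s_u):
    """Hypothetical, simplified SVGP marginals with a diagonal inducing covariance.

    Assumptions (not from the paper): K_uu = diag(sing), K_fu = features @ diag(sing),
    a diagonal variational covariance diag(s_u), and the residual prior term
    K_ff - A K_uu A^T is dropped (projected-process style).
    """
    K_uu_inv = 1.0 / sing                 # inverting a diagonal matrix: O(s) reciprocals
    A = (features * sing) * K_uu_inv      # A = K_fu K_uu^{-1}, which reduces to `features`
    mean = A @ m_u                        # predictive mean A m_u
    var = (A ** 2) @ s_u                  # diag(A diag(s_u) A^T)
    return mean, var


# Toy setup: random inputs and projections (placeholders, not trained weights).
rng = np.random.default_rng(0)
n, d_model, d_head, s = 16, 8, 4, 5
X = rng.standard_normal((n, d_model))
W_q = rng.standard_normal((d_model, d_head)) / np.sqrt(d_model)
W_k = rng.standard_normal((d_model, d_head)) / np.sqrt(d_model)

K_att = attention_kernel(X, W_q, W_k)
print("asymmetric:", not np.allclose(K_att, K_att.T))

U_s, sing, V_s = truncated_svd_pair(K_att, s)
rel_err = np.linalg.norm(K_att - U_s @ np.diag(sing) @ V_s.T) / np.linalg.norm(K_att)
print("relative error of rank-s reconstruction:", rel_err)

# One SVGP per set of singular vectors, each with its own variational parameters.
m_left, m_right = rng.standard_normal(s), rng.standard_normal(s)
s_left, s_right = np.full(s, 0.1), np.full(s, 0.2)
mean_left, var_left = diag_svgp_marginals(U_s, sing, m_left, s_left)
mean_right, var_right = diag_svgp_marginals(V_s, sing, m_right, s_right)
```

Under these assumptions, inverting `K_uu` is an elementwise reciprocal costing O(s) rather than the O(s^3) of a dense factorization, which is the kind of saving the abstract attributes to working with singular values; the paper's actual posterior and evidence lower bound may differ from this simplified readout.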

Yingyi Chen, Qinghua Tao, Francesco Tonin, Johan A.K. Suykens• 2024

Related benchmarks

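| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Out-of-Distribution Detection | CIFAR-100 | AUROC | 77.93 | 107 |
| Out-of-Distribution Detection | SVHN | AUROC | 88.25 | 62 |
| Uncertainty Calibration | CIFAR-10-C | -- | -- | 35 |
| Out-of-Distribution Detection | LSUN | AUROC | 0.8835 | 26 |
| Image Classification | CIFAR-10 (test) | Accuracy | 84.52 | 8 |
| Classification | IMDB | Accuracy | 87.77 | 8 |
| Text Classification | IMDB (test) | Accuracy | 0.8576 | 8 |
| Text Classification | CoLA (test) | MCC | 30.86 | 8 |
| Classification | CIFAR-10 | Accuracy | 87.48 | 8 |
| Linguistic Acceptability | CoLA (OOD) | MCC | 21.14 | 6 |

10 of 12 rows shown. Results are reproduced as listed; some are on a 0 to 100 scale and others on a 0 to 1 scale.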
