What Makes Graph Neural Networks Miscalibrated?

About

Given the importance of getting calibrated predictions and reliable uncertainty estimations, various post-hoc calibration methods have been developed for neural networks on standard multi-class classification tasks. However, these methods are not well suited for calibrating graph neural networks (GNNs), which presents unique challenges such as accounting for the graph structure and the graph-induced correlations between the nodes. In this work, we conduct a systematic study on the calibration qualities of GNN node predictions. In particular, we identify five factors which influence the calibration of GNNs: general under-confident tendency, diversity of nodewise predictive distributions, distance to training nodes, relative confidence level, and neighborhood similarity. Furthermore, based on the insights from this study, we design a novel calibration method named Graph Attention Temperature Scaling (GATS), which is tailored for calibrating graph neural networks. GATS incorporates designs that address all the identified influential factors and produces nodewise temperature scaling using an attention-based architecture. GATS is accuracy-preserving, data-efficient, and expressive at the same time. Our experiments empirically verify the effectiveness of GATS, demonstrating that it can consistently achieve state-of-the-art calibration results on various graph datasets for different GNN backbones.

Hans Hao-Hsun Hsu, Yuesong Shen, Christian Tomani, Daniel Cremers • 2022
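
The abstract describes GATS as producing a separate temperature for each node and as accuracy-preserving. The sketch below illustrates only that last property, using NumPy: dividing a node's logits by any positive temperature changes its confidence but never its predicted class. The function names and the temperature values are hypothetical placeholders chosen for illustration; in GATS the temperatures come from an attention-based architecture that is not reproduced here.

```python
import numpy as np


def softmax(z):
    """Row-wise softmax with the usual max-subtraction for numerical stability."""
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def nodewise_temperature_scale(logits, temperatures):
    """Scale each node's logits by its own positive temperature.

    logits:       array of shape (num_nodes, num_classes)
    temperatures: array of shape (num_nodes,), all entries > 0
    """
    t = np.asarray(temperatures, dtype=float)
    if np.any(t <= 0):
        raise ValueError("temperatures must be positive")
    # Broadcasting divides every row of logits by that row's temperature.
    return softmax(logits / t[:, None])


# Toy logits for three nodes and three classes (made-up numbers).
logits = np.array([
    [2.0, 0.5, -1.0],
    [0.2, 0.1, 0.0],
    [-1.0, 3.0, 0.5],
])

# Placeholder temperatures: T < 1 sharpens, T = 1 leaves unchanged, T > 1 softens.
temperatures = np.array([0.5, 1.0, 2.0])

uncalibrated = softmax(logits)
calibrated = nodewise_temperature_scale(logits, temperatures)

# The predicted class per node is identical before and after scaling.
same_predictions = np.array_equal(
    uncalibrated.argmax(axis=1), calibrated.argmax(axis=1)
)
```

Because a positive temperature rescales all of a node's logits by the same factor, the ordering of classes within each row is unchanged, which is why accuracy is preserved while confidence moves.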

Related benchmarks

| Task | Dataset | Result | Rank |
| --- | --- | --- | --- |
| Node Classification | Computers | -- | 143 |
| Confidence calibration | Citeseer | ECE 3.86 | 36 |
| Confidence calibration | Cora | ECE 3.18 | 36 |
| Confidence calibration | Pubmed | ECE 0.98 | 36 |
| Confidence calibration | CoraFull | ECE 3.54 | 28 |
| Node Classification | Cora | ECE 2.03 | 24 |
| GNN Calibration | CS | ECE 0.81 | 12 |
| GNN Calibration | Cora | NLL 0.5591 | 12 |
| GNN Calibration | CS | NLL (x10^-2) 2.13e+3 | 12 |
| GNN Calibration | Physics | Negative Log-Likelihood (NLL) 0.1187 | 12 |
Showing 10 of 35 rows
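
The benchmarks above report expected calibration error (ECE) and negative log-likelihood (NLL). As a minimal sketch of how these two metrics are commonly computed, the NumPy code below uses equal-width confidence bins for ECE. The bin count of 15 and the function names are assumptions made for illustration; the listing does not state the binning scheme or the scale (fraction or percent) used for the reported numbers.

```python
import numpy as np


def expected_calibration_error(probs, labels, n_bins=15):
    """ECE with equal-width bins over the top-class confidence.

    probs:  array of shape (num_nodes, num_classes), rows sum to 1
    labels: integer array of shape (num_nodes,)
    Returns a fraction in [0, 1]; multiply by 100 for a percentage.
    """
    conf = probs.max(axis=1)
    pred = probs.argmax(axis=1)
    correct = (pred == labels).astype(float)
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    for lo, hi in zip(edges[:-1], edges[1:]):
        in_bin = (conf > lo) & (conf <= hi)
        if in_bin.any():
            # Weight the accuracy/confidence gap by the share of samples in the bin.
            gap = abs(correct[in_bin].mean() - conf[in_bin].mean())
            ece += in_bin.mean() * gap
    return float(ece)


def negative_log_likelihood(probs, labels, eps=1e-12):
    """Mean negative log-probability assigned to the true class."""
    p_true = probs[np.arange(len(labels)), labels]
    return float(-np.mean(np.log(np.clip(p_true, eps, 1.0))))


# Toy example: five predictions, each made with confidence 0.8 for class 0.
probs = np.tile(np.array([0.8, 0.2]), (5, 1))
labels_calibrated = np.array([0, 0, 0, 0, 1])   # 4 of 5 correct: accuracy matches confidence
labels_overconfident = np.array([0, 0, 1, 1, 1])  # 2 of 5 correct: accuracy below confidence

ece_calibrated = expected_calibration_error(probs, labels_calibrated)
ece_overconfident = expected_calibration_error(probs, labels_overconfident)
nll_calibrated = negative_log_likelihood(probs, labels_calibrated)
```

In the toy data, the first label set has accuracy equal to the stated confidence, so its ECE is essentially zero, while the second has a gap between confidence and accuracy that ECE picks up.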
