Representing Long-Range Context for Graph Neural Networks with Global Attention

About

Graph neural networks are powerful architectures for structured datasets. However, current methods struggle to represent long-range dependencies. Scaling the depth or width of GNNs is insufficient to broaden receptive fields, as larger GNNs encounter optimization instabilities such as vanishing gradients and representation oversmoothing, while pooling-based approaches have yet to become as universally useful as they are in computer vision. In this work, we propose the use of Transformer-based self-attention to learn long-range pairwise relationships, with a novel "readout" mechanism to obtain a global graph embedding. Inspired by recent computer vision results that find position-invariant attention effective at learning long-range relationships, our method, which we call GraphTrans, applies a permutation-invariant Transformer module after a standard GNN module. This simple architecture leads to state-of-the-art results on several graph classification tasks, outperforming methods that explicitly encode graph structure. Our results suggest that purely learning-based approaches without graph structure may be suitable for learning high-level, long-range relationships on graphs. Code for GraphTrans is available at https://github.com/ucbrise/graphtrans.
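The pipeline described above (a GNN module whose node embeddings are then passed, as an unordered set, to a Transformer-style attention module, with a readout producing one graph-level vector) can be sketched in plain Python. This is a minimal, hypothetical sketch and not the GraphTrans implementation: the mean-aggregation `gnn_layer`, the single attention head, the helper names (`init_params`, `graph_embedding`), and the learnable `cls` readout token (modelled on BERT/ViT-style [CLS] tokens) are all assumptions for illustration. Real Transformer layers also add multiple heads, residual connections, layer normalization, and feed-forward sublayers, all omitted here.

```python
import math
import random
from typing import Dict, List

Vector = List[float]
Matrix = List[List[float]]


def matvec(w: Matrix, x: Vector) -> Vector:
    """Multiply a (rows x cols) weight matrix by a length-cols vector."""
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]


def gnn_layer(h: List[Vector], adj: List[List[int]], w: Matrix) -> List[Vector]:
    """Hypothetical mean-aggregation message-passing layer: each node averages
    its own features with its neighbours', applies a linear map, then ReLU."""
    out = []
    for i, nbrs in enumerate(adj):
        group = [h[i]] + [h[j] for j in nbrs]
        mean = [sum(col) / len(group) for col in zip(*group)]
        out.append([max(0.0, v) for v in matvec(w, mean)])
    return out


def softmax(xs: Vector) -> Vector:
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def self_attention(tokens: List[Vector], wq: Matrix, wk: Matrix, wv: Matrix) -> List[Vector]:
    """Single-head scaled dot-product attention with no positional encoding:
    every token attends to every other token, regardless of graph distance."""
    q = [matvec(wq, t) for t in tokens]
    k = [matvec(wk, t) for t in tokens]
    v = [matvec(wv, t) for t in tokens]
    scale = math.sqrt(len(k[0]))
    out = []
    for qi in q:
        weights = softmax([sum(a * b for a, b in zip(qi, kj)) / scale for kj in k])
        out.append([sum(wt * vj[c] for wt, vj in zip(weights, v)) for c in range(len(v[0]))])
    return out


def random_matrix(rows: int, cols: int, rng: random.Random) -> Matrix:
    return [[rng.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]


def init_params(d_in: int, d: int, seed: int = 0) -> Dict[str, Matrix]:
    rng = random.Random(seed)
    return {
        "gnn": random_matrix(d, d_in, rng),
        "wq": random_matrix(d, d, rng),
        "wk": random_matrix(d, d, rng),
        "wv": random_matrix(d, d, rng),
        "cls": random_matrix(1, d, rng),  # hypothetical learnable readout token (1 x d)
    }


def graph_embedding(x: List[Vector], adj: List[List[int]], params: Dict[str, Matrix]) -> Vector:
    """GNN module -> attention over the node set -> readout-token output."""
    h = gnn_layer(x, adj, params["gnn"])
    tokens = params["cls"] + h  # prepend the readout token to the node tokens
    return self_attention(tokens, params["wq"], params["wk"], params["wv"])[0]


if __name__ == "__main__":
    params = init_params(d_in=3, d=4)
    # Path graph 0-1-2-3 with 3 input features per node.
    x = [[1.0, 0.0, 0.5], [0.0, 1.0, 0.0], [0.5, 0.5, 1.0], [1.0, 1.0, 0.0]]
    adj = [[1], [0, 2], [1, 3], [2]]
    emb = graph_embedding(x, adj, params)

    # Relabel the nodes: new node i is old node perm[i].
    perm = [2, 0, 3, 1]
    new_id = {old: new for new, old in enumerate(perm)}
    x_p = [x[old] for old in perm]
    adj_p = [[new_id[j] for j in adj[old]] for old in perm]
    emb_p = graph_embedding(x_p, adj_p, params)
    print(all(abs(a - b) < 1e-9 for a, b in zip(emb, emb_p)))  # True
```

Because the attention step uses no positional encoding and the readout token's query does not depend on node order, relabelling the nodes leaves the graph embedding unchanged up to floating-point rounding; this is the permutation invariance referred to above. Note also that in the attention step every node reaches every other node at once, whereas a stack of k message-passing layers only reaches k-hop neighbours.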

Zhanghao Wu, Paras Jain, Matthew A. Wright, Azalia Mirhoseini, Joseph E. Gonzalez, Ion Stoica • 2022

Related benchmarks

Task                        Dataset                    Metric     Result   Rank
Graph Classification        Mutag (test)               Accuracy   81.54    217
Graph Classification        ogbg-molpcba (test)        AP         27.61    206
Graph Classification        NCI1 (test)                Accuracy   82.6     174
Node Classification         Cora (semi-supervised)     Accuracy   81.7     103
Graph Classification        COLLAB (test)              Accuracy   79.81    96
Graph Classification        NCI109 (test)              Accuracy   82.3     64
Node Classification         Cite semi-supervised       Accuracy   70.2     61
Graph property prediction   OGBG-CODE2 (test)          F1         18.3     57
Node Classification         Citeseer full-supervised   Accuracy   0.721    51
Node Classification         Pubmed full-supervised     Accuracy   87.8     48
Showing 10 of 28 rows

Other info

Code: https://github.com/ucbrise/graphtrans