Representing Long-Range Context for Graph Neural Networks with Global Attention
About
Graph neural networks are powerful architectures for structured datasets. However, current methods struggle to represent long-range dependencies. Scaling the depth or width of GNNs is insufficient to broaden receptive fields as larger GNNs encounter optimization instabilities such as vanishing gradients and representation oversmoothing, while pooling-based approaches have yet to become as universally useful as in computer vision. In this work, we propose the use of Transformer-based self-attention to learn long-range pairwise relationships, with a novel "readout" mechanism to obtain a global graph embedding. Inspired by recent computer vision results that find position-invariant attention performant in learning long-range relationships, our method, which we call GraphTrans, applies a permutation-invariant Transformer module after a standard GNN module. This simple architecture leads to state-of-the-art results on several graph classification tasks, outperforming methods that explicitly encode graph structure. Our results suggest that purely-learning-based approaches without graph structure may be suitable for learning high-level, long-range relationships on graphs. Code for GraphTrans is available at https://github.com/ucbrise/graphtrans.
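The core idea, a GNN module followed by a permutation-invariant Transformer whose special readout token yields the global graph embedding, can be sketched as follows. This is a minimal illustration with assumed shapes and randomly initialized parameters, not the authors' implementation (see the linked repository for that); the single-head attention layer stands in for the full Transformer module, and the `<CLS>`-style readout token is prepended to the GNN's node embeddings.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 8  # embedding dimension (hypothetical)

# Hypothetical learned parameters, random here for illustration:
# a readout token plus query/key/value projections for one attention head.
cls_token = rng.normal(size=(1, d))
W_q, W_k, W_v = (rng.normal(size=(d, d)) for _ in range(3))

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def graph_readout(node_emb):
    """Prepend the readout token to GNN node embeddings, run one
    self-attention layer (no positional encodings, so the result is
    permutation-invariant), and return the readout token's output row
    as the global graph embedding."""
    h = np.concatenate([cls_token, node_emb], axis=0)  # (n + 1, d)
    q, k, v = h @ W_q, h @ W_k, h @ W_v
    attn = softmax(q @ k.T / np.sqrt(d))               # (n + 1, n + 1)
    return (attn @ v)[0]                               # readout row only

# Node embeddings as they might come out of an upstream GNN (random here).
nodes = rng.normal(size=(5, d))
emb = graph_readout(nodes)
emb_perm = graph_readout(nodes[::-1])   # same nodes, reversed order
print(np.allclose(emb, emb_perm))       # → True: order does not matter
```

Because no positional encodings are added, permuting the node order only permutes the keys and values that the readout token attends over, so the graph embedding is unchanged, which is the permutation invariance the abstract refers to.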
Related benchmarks
| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Graph Classification | MUTAG (test) | Accuracy | 81.54 | 217 |
| Graph Classification | ogbg-molpcba (test) | AP | 27.61 | 206 |
| Graph Classification | NCI1 (test) | Accuracy | 82.6 | 174 |
| Node Classification | Cora (semi-supervised) | Accuracy | 81.7 | 103 |
| Graph Classification | COLLAB (test) | Accuracy | 79.81 | 96 |
| Graph Classification | NCI109 (test) | Accuracy | 82.3 | 64 |
| Node Classification | CiteSeer (semi-supervised) | Accuracy | 70.2 | 61 |
| Graph Property Prediction | OGBG-CODE2 (test) | F1 | 18.3 | 57 |
| Node Classification | CiteSeer (fully-supervised) | Accuracy | 0.721 | 51 |
| Node Classification | Pubmed (fully-supervised) | Accuracy | 87.8 | 48 |