Global Self-Attention as a Replacement for Graph Convolution
About
We propose an extension to the transformer neural network architecture for general-purpose graph learning by adding a dedicated pathway for pairwise structural information, called edge channels. The resultant framework, which we call the Edge-augmented Graph Transformer (EGT), can directly accept, process and output structural information of arbitrary form, which is important for effective learning on graph-structured data. Our model exclusively uses global self-attention as an aggregation mechanism rather than static localized convolutional aggregation. This allows for unconstrained long-range dynamic interactions between nodes. Moreover, the edge channels allow the structural information to evolve from layer to layer, and prediction tasks on edges/links can be performed directly from the output embeddings of these channels. We verify the performance of EGT in a wide range of graph-learning experiments on benchmark datasets, in which it outperforms Convolutional/Message-Passing Graph Neural Networks. EGT sets a new state of the art for the quantum-chemical regression task on the OGB-LSC PCQM4Mv2 dataset containing 3.8 million molecular graphs. Our findings indicate that global self-attention-based aggregation can serve as a flexible, adaptive and effective replacement for graph convolution in general-purpose graph learning. Therefore, convolutional local neighborhood aggregation is not an essential inductive bias.
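The abstract does not give the layer equations, so the following is only a rough single-head sketch of how edge channels might interact with global self-attention: pairwise edge embeddings add a bias to the attention scores between all node pairs, and the scores in turn update the edge embeddings. The function name, the weight names (`Wq`, `Wk`, `Wv`, `We_bias`, `We_out`) and the exact update rule are illustrative assumptions, not the authors' implementation; layer normalization, feed-forward blocks, multiple heads and any gating are omitted.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax along the given axis.
    x = x - x.max(axis=axis, keepdims=True)
    ex = np.exp(x)
    return ex / ex.sum(axis=axis, keepdims=True)

def edge_augmented_attention(h, e, Wq, Wk, Wv, We_bias, We_out):
    """Hypothetical single-head layer (assumed form, not from the source).

    h: (N, d) node embeddings
    e: (N, N, de) edge-channel embeddings, one per ordered node pair
    """
    d = Wq.shape[1]
    q, k, v = h @ Wq, h @ Wk, h @ Wv            # (N, d) each
    logits = (q @ k.T) / np.sqrt(d)             # (N, N): every node attends to every node
    logits = logits + (e @ We_bias)[..., 0]     # structural bias from the edge channels
    attn = softmax(logits, axis=-1)             # rows sum to 1
    h_new = attn @ v                            # global aggregation, no fixed neighborhood
    e_new = e + logits[..., None] * We_out      # edge channels evolve with a residual update
    return h_new, e_new, attn

# Toy usage with random inputs (sizes are arbitrary).
rng = np.random.default_rng(0)
N, d, de = 5, 8, 4
h = rng.normal(size=(N, d))
e = rng.normal(size=(N, N, de))
Wq, Wk, Wv = (rng.normal(size=(d, d)) for _ in range(3))
We_bias = rng.normal(size=(de, 1))
We_out = rng.normal(size=(de,))

h_new, e_new, attn = edge_augmented_attention(h, e, Wq, Wk, Wv, We_bias, We_out)
```

In this sketch, `h_new` would feed node- or graph-level predictions, while `e_new` plays the role of the output edge embeddings from which edge/link predictions could be read directly.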
Related benchmarks
| Task | Dataset | Metric | Value | Rank |
|---|---|---|---|---|
| Graph Classification | ogbg-molpcba (test) | AP | 29.61 | 206 |
| Graph Regression | ZINC (test) | MAE | 0.108 | 204 |
| Graph Regression | ZINC 12K (test) | MAE | 0.108 | 164 |
| Graph Classification | CIFAR10 (test) | Test Accuracy | 68.702 | 139 |
| Node Classification | CLUSTER (test) | Test Accuracy | 79.232 | 113 |
| Graph Classification | MNIST (test) | Accuracy | 98.173 | 110 |
| Graph Classification | CIFAR10 | Accuracy | 68.702 | 108 |
| Graph Regression | ZINC | MAE | 0.108 | 96 |
| Graph Classification | MNIST | Accuracy | 98.173 | 95 |
| Node Classification | PATTERN (test) | Test Accuracy | 86.821 | 88 |