
From Stars to Subgraphs: Uplifting Any GNN with Local Structure Awareness

About

Message Passing Neural Networks (MPNNs) are a common type of Graph Neural Network (GNN) in which each node's representation is computed recursively by aggregating representations (messages) from its immediate neighbors, akin to a star-shaped pattern. MPNNs are appealing for being efficient and scalable; however, their expressiveness is upper-bounded by the 1st-order Weisfeiler-Lehman isomorphism test (1-WL). In response, prior works propose highly expressive models at the cost of scalability and sometimes generalization performance. Our work stands between these two regimes: we introduce a general framework to uplift any MPNN to be more expressive, with limited scalability overhead and greatly improved practical performance. We achieve this by extending local aggregation in MPNNs from star patterns to general subgraph patterns (e.g., k-egonets): in our framework, each node representation is computed as the encoding of a surrounding induced subgraph rather than the encoding of its immediate neighbors only (i.e., a star). We choose the subgraph encoder to be a GNN (mainly MPNNs, considering scalability) so that the framework serves as a general wrapper to uplift any GNN. We call our proposed method GNN-AK (GNN As Kernel), as the framework resembles a convolutional neural network in which the kernel is replaced with a GNN. Theoretically, we show that our framework is strictly more powerful than 1-WL and 2-WL, and is not less powerful than 3-WL. We also design subgraph sampling strategies that greatly reduce memory footprint and improve speed while maintaining performance. Our method sets new state-of-the-art performance by large margins on several well-known graph ML tasks; specifically, 0.08 MAE on ZINC, and 74.79% and 86.887% accuracy on CIFAR10 and PATTERN, respectively.
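To make the star-versus-subgraph distinction concrete, here is a minimal, parameter-free sketch in Python. It is not the authors' GNN-AK implementation: the helper names (`star_layer`, `k_egonet`, `subgraph_layer`, `readout`) are hypothetical, the learnable weights of a real MPNN are replaced by plain sums, and the subgraph encoding is reduced to sum-pooling the output of an inner star MPNN run on each node's induced 1-egonet. It compares a 6-cycle with two disjoint triangles, a classic pair that 1-WL (and hence any star-aggregating MPNN with identical initial features) cannot tell apart, because every node in both graphs has degree 2.

```python
from collections import deque
from typing import Callable, Dict, Set

Graph = Dict[int, Set[int]]       # adjacency list: node -> set of neighbours
Features = Dict[int, float]       # node -> scalar feature


def star_layer(adj: Graph, h: Features) -> Features:
    """One simplified MPNN layer: each node sums its own feature and those of
    its immediate neighbours (the star pattern). No learnable weights."""
    return {v: h[v] + sum(h[u] for u in adj[v]) for v in adj}


def k_egonet(adj: Graph, center: int, k: int) -> Graph:
    """Induced subgraph on all nodes within k hops of `center` (found by BFS)."""
    dist = {center: 0}
    queue = deque([center])
    while queue:
        v = queue.popleft()
        if dist[v] == k:
            continue  # do not expand beyond k hops
        for u in adj[v]:
            if u not in dist:
                dist[u] = dist[v] + 1
                queue.append(u)
    nodes = set(dist)
    # Keep only edges whose endpoints both lie inside the egonet.
    return {v: adj[v] & nodes for v in nodes}


def subgraph_layer(adj: Graph, h: Features, k: int = 1, inner_layers: int = 1) -> Features:
    """Simplified subgraph-aggregation layer (hypothetical sketch): encode each
    node's k-egonet with an inner star MPNN, then sum-pool the subgraph's node
    embeddings into the centre node's new feature. A real GNN-AK layer would
    use a learnable GNN as the inner encoder."""
    out = {}
    for v in adj:
        sub = k_egonet(adj, v, k)
        hs = {u: h[u] for u in sub}
        for _ in range(inner_layers):
            hs = star_layer(sub, hs)
        out[v] = sum(hs.values())
    return out


def readout(layer: Callable[[Graph, Features], Features], adj: Graph, num_layers: int = 2) -> float:
    """Apply `layer` repeatedly from uniform features, then sum-pool the graph."""
    h = {v: 1.0 for v in adj}  # identical initial features, as in the 1-WL setting
    for _ in range(num_layers):
        h = layer(adj, h)
    return sum(h.values())


def cycle(n: int, offset: int = 0) -> Graph:
    """Cycle graph on nodes offset .. offset + n - 1."""
    return {offset + i: {offset + (i - 1) % n, offset + (i + 1) % n} for i in range(n)}


hexagon = cycle(6)
two_triangles = {**cycle(3), **cycle(3, offset=3)}

print(readout(star_layer, hexagon), readout(star_layer, two_triangles))          # 54.0 54.0
print(readout(subgraph_layer, hexagon), readout(subgraph_layer, two_triangles))  # 294.0 486.0
```

With uniform initial features, the star layer assigns both graphs the same readout, whereas the egonet encoding separates them: a node's induced 1-egonet is a 3-node path in the 6-cycle but a full triangle in the two-triangle graph, so the inner encoder sees different local structure.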

Lingxiao Zhao, Wei Jin, Leman Akoglu, Neil Shah • 2021

Related benchmarks

| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Graph Classification | MUTAG (10-fold cross-validation) | Accuracy | 91.3 | 206 |
| Graph Classification | ogbg-molpcba (test) | AP | 29.3 | 206 |
| Graph Regression | ZINC (test) | MAE | 0.08 | 204 |
| Graph Classification | PROTEINS (10-fold cross-validation) | Accuracy | 77.1 | 197 |
| Graph Classification | PTC | Accuracy | 67.8 | 167 |
| Graph Regression | ZINC 12K (test) | MAE | 0.08 | 164 |
| Graph Classification | IMDB-B (10-fold cross-validation) | Accuracy | 75 | 148 |
| Graph Classification | CIFAR10 (test) | Test Accuracy | 72.19 | 139 |
| Graph Classification | PTC (10-fold cross-validation) | Accuracy | 67.7 | 115 |
| Graph Classification | CIFAR10 | Accuracy | 74.79 | 108 |

Showing 10 of 47 rows.
