Generalizing Graph Neural Networks on Out-Of-Distribution Graphs
About
Graph Neural Networks (GNNs) are typically designed without accounting for unknown (agnostic) distribution shifts between training and testing graphs, which degrades their generalization ability in Out-Of-Distribution (OOD) settings. The fundamental reason for this degradation is that most GNNs are built on the I.I.D. assumption. Under this assumption, GNNs tend to exploit subtle statistical correlations in the training set for prediction, even when those correlations are spurious. Because spurious correlations may change in testing environments, GNNs that rely on them fail. Eliminating the impact of spurious correlations is therefore crucial for stable GNNs. To this end, we propose a general causal representation framework called StableGNN. The main idea is to first extract high-level representations from graph data and then use the distinguishing ability of causal inference to help the model discard spurious correlations. Specifically, we use a graph pooling layer to extract subgraph-based representations as the high-level representations. We further propose a causal variable distinguishing regularizer that corrects the biased training distribution, so that GNNs concentrate on stable correlations. Extensive experiments on both synthetic and real-world OOD graph datasets verify the effectiveness, flexibility, and interpretability of the proposed framework.
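The abstract does not spell out how the regularizer "corrects the biased training distribution". One common way to realize this idea, used here as an assumption rather than StableGNN's exact formulation, is to learn per-sample weights under which the high-level representation dimensions become decorrelated, so that a variable which merely tracks another in the training set loses its predictive shortcut. The NumPy sketch below stands in two toy features for the pooled subgraph representations, uses the squared off-diagonal weighted covariance as the dependence measure (the paper's actual measure may differ), and adds a uniformity penalty `lam` so the weights cannot collapse onto a few samples. All names (`decorrelation_loss`, `learn_sample_weights`, `lam`, the toy data) are hypothetical.

```python
import numpy as np


def softmax(theta):
    """Numerically stable softmax: free parameters -> positive weights summing to 1."""
    e = np.exp(theta - theta.max())
    return e / e.sum()


def decorrelation_loss(z, w):
    """Sum of squared off-diagonal entries of the w-weighted covariance of z.

    z: (n, d) sample representations; w: (n,) weights summing to 1.
    Returns the loss and its gradient with respect to w.
    """
    mu = w @ z                                        # weighted mean, shape (d,)
    cov = (z * w[:, None]).T @ z - np.outer(mu, mu)   # weighted covariance, (d, d)
    off = cov - np.diag(np.diag(cov))                 # keep cross-feature terms only
    loss = float(np.sum(off ** 2))
    # dL/dw_i = 2 * (z_i^T M z_i - 2 * z_i^T M mu), where M = off is symmetric
    grad_w = 2.0 * (np.einsum("ij,jk,ik->i", z, off, z) - 2.0 * z @ (off @ mu))
    return loss, grad_w


def objective(z, theta, lam):
    """Decorrelation loss plus a (hypothetical) penalty keeping weights near uniform."""
    n = z.shape[0]
    w = softmax(theta)
    dec, g = decorrelation_loss(z, w)
    # Without this penalty, putting all weight on one sample makes every
    # covariance zero, a trivial and useless minimizer.
    dev = n * w - 1.0
    pen = lam * float(np.sum(dev ** 2)) / n
    g = g + 2.0 * lam * dev
    # Chain rule through softmax: dL/dtheta_i = w_i * (g_i - sum_j w_j g_j)
    grad_theta = w * (g - w @ g)
    return dec + pen, grad_theta


def learn_sample_weights(z, lam=1.0, steps=300, lr=50.0):
    """Gradient descent with backtracking on softmax-parameterized sample weights."""
    theta = np.zeros(z.shape[0])  # start from uniform weights
    obj, grad = objective(z, theta, lam)
    for _ in range(steps):
        step = lr
        while step > 1e-8:
            cand = theta - step * grad
            cand_obj, cand_grad = objective(z, cand, lam)
            if cand_obj < obj:  # accept only steps that lower the objective
                theta, obj, grad = cand, cand_obj, cand_grad
                break
            step *= 0.5
        else:
            break  # no improving step found; stop early
    return softmax(theta)


# Toy data: feature 0 plays the "stable" variable; feature 1 tracks it in most
# training samples (a spurious correlation) but not in a small subgroup.
rng = np.random.default_rng(0)
n_corr, n_break = 160, 40
c = rng.normal(size=n_corr + n_break)
s = np.empty_like(c)
s[:n_corr] = c[:n_corr] + 0.5 * rng.normal(size=n_corr)
s[n_corr:] = rng.normal(size=n_break)
z = np.column_stack([c, s])

uniform = np.full(len(z), 1.0 / len(z))
w = learn_sample_weights(z)
print("off-diagonal loss, uniform weights:", decorrelation_loss(z, uniform)[0])
print("off-diagonal loss, learned weights:", decorrelation_loss(z, w)[0])
print("weight share of the subgroup:", w[n_corr:].sum())
```

The learned weights shift mass toward the subgroup where the spurious correlation breaks, which lowers the cross-feature covariance. The penalty `lam` trades decorrelation against how far the weights drift from the original training distribution. In a full pipeline the weights would presumably multiply the per-graph classification loss, but how StableGNN combines the two is not specified here.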
Related benchmarks
| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Graph Classification | MolHIV | ROC-AUC | 56.71 | 82 |
| Graph Classification | CKuzushiji-75sp unbiased (test) | Accuracy | 49.41 | 60 |
| Graph Classification | CMNIST-75sp unbiased (test) | Accuracy | 77.65 | 60 |
| Graph Classification | CFashion-75sp unbiased (test) | Accuracy | 64.03 | 60 |
| Molecular property prediction | BACE | ROC-AUC | 72.29 | 35 |
| Molecular property prediction | BBBP | ROC-AUC | 0.6695 | 35 |
| Molecular property prediction | ClinTox | ROC-AUC | 85.59 | 34 |
| Graph Classification | Molbbbp (scaffold) | ROC-AUC | 66.74 | 31 |
| Graph Classification | Motif base | Accuracy | 57.07 | 29 |
| Graph Classification | Motif (size) | Accuracy | 46.93 | 29 |