Hierarchical Graph Representation Learning with Differentiable Pooling
About
Recently, graph neural networks (GNNs) have revolutionized the field of graph representation learning through effectively learned node embeddings and have achieved state-of-the-art results in tasks such as node classification and link prediction. However, current GNN methods are inherently flat and do not learn hierarchical representations of graphs; this limitation is especially problematic for the task of graph classification, where the goal is to predict the label associated with an entire graph. Here we propose DiffPool, a differentiable graph pooling module that can generate hierarchical representations of graphs and can be combined with various graph neural network architectures in an end-to-end fashion. DiffPool learns a differentiable soft cluster assignment for nodes at each layer of a deep GNN, mapping nodes to a set of clusters, which then form the coarsened input for the next GNN layer. Our experimental results show that combining existing GNN methods with DiffPool yields an average improvement of 5-10% accuracy on graph classification benchmarks compared to all existing pooling approaches, achieving a new state-of-the-art on four out of five benchmark data sets.
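The soft cluster assignment described above can be illustrated with a minimal NumPy sketch of one coarsening step. It assumes the formulation given in the DiffPool paper: an embedding GNN produces node embeddings Z, a pooling GNN followed by a row-wise softmax produces the assignment matrix S, and the coarsened graph is X' = SᵀZ with adjacency A' = SᵀAS. The one-layer mean-aggregation `propagate` function, the random weights, and the toy graph are hypothetical stand-ins rather than the authors' implementation, and no training or auxiliary objectives are included.

```python
import numpy as np


def softmax(x, axis=-1):
    # Numerically stable softmax along the given axis.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


def propagate(adj, x, weight):
    # Hypothetical stand-in for a GNN: one mean-aggregation
    # message-passing step over neighbors plus a self-loop,
    # followed by a linear projection.
    a_hat = adj + np.eye(adj.shape[0])
    deg = a_hat.sum(axis=1, keepdims=True)
    return ((a_hat @ x) / deg) @ weight


def diffpool(adj, x, w_embed, w_pool):
    # Node embeddings Z (n x d) from the "embedding" GNN, with ReLU.
    z = np.maximum(propagate(adj, x, w_embed), 0.0)
    # Soft assignment S (n x c): each row is a distribution over clusters.
    s = softmax(propagate(adj, x, w_pool), axis=1)
    # Coarsened cluster features (c x d) and cluster adjacency (c x c).
    x_coarse = s.T @ z
    adj_coarse = s.T @ adj @ s
    return x_coarse, adj_coarse, s


# Toy graph (hypothetical): two triangles joined by the edge (2, 3).
edges = [(0, 1), (1, 2), (0, 2), (3, 4), (4, 5), (3, 5), (2, 3)]
adj = np.zeros((6, 6))
for i, j in edges:
    adj[i, j] = adj[j, i] = 1.0

rng = np.random.default_rng(0)
x = rng.normal(size=(6, 3))  # 6 nodes, 3 input features

# Level 1: 6 nodes -> 3 clusters; level 2: 3 clusters -> 1 cluster.
x1, adj1, s1 = diffpool(adj, x, rng.normal(size=(3, 4)), rng.normal(size=(3, 3)))
x2, adj2, s2 = diffpool(adj1, x1, rng.normal(size=(4, 4)), rng.normal(size=(4, 1)))

# Sum readout over the final cluster(s) gives a graph-level embedding
# that a classifier could consume.
graph_embedding = x2.sum(axis=0)
print(x1.shape, adj1.shape, x2.shape, graph_embedding.shape)
```

Every step is built from matrix products and a softmax, so in an autodiff framework gradients from a graph-level loss would flow back through S into both GNNs; this is what makes the pooling "differentiable". Because each row of S sums to 1, the total edge weight is preserved under coarsening (the entries of SᵀAS sum to the same value as those of A).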
Related benchmarks
| Task | Dataset | Accuracy (%) | Rank |
|---|---|---|---|
| Graph Classification | PROTEINS | 78.1 | 742 |
| Graph Classification | MUTAG | 86.72 | 697 |
| Graph Classification | NCI1 | 79 | 460 |
| Graph Classification | COLLAB | 82.13 | 329 |
| Graph Classification | IMDB-B | 73.55 | 322 |
| Graph Classification | ENZYMES | 62.53 | 305 |
| Graph Classification | NCI109 | 61.98 | 223 |
| Graph Classification | Mutag (test) | 87.5 | 217 |
| Graph Classification | MUTAG (10-fold cross-validation) | 86.1 | 206 |
| Graph Classification | PROTEINS (10-fold cross-validation) | 78.1 | 197 |