VQ-GNN: A Universal Framework to Scale up Graph Neural Networks using Vector Quantization
About
Most state-of-the-art Graph Neural Networks (GNNs) can be defined as a form of graph convolution, which can be realized by message passing between direct neighbors or beyond. To scale such GNNs to large graphs, various neighbor-, layer-, or subgraph-sampling techniques have been proposed to alleviate the "neighbor explosion" problem by considering only a small subset of the messages passed to the nodes in a mini-batch. However, sampling-based methods are difficult to apply to GNNs that utilize many-hops-away or global context in each layer, show unstable performance across tasks and datasets, and do not speed up model inference. We propose a principled and fundamentally different approach, VQ-GNN, a universal framework to scale up any convolution-based GNN using Vector Quantization (VQ) without compromising performance. In contrast to sampling-based techniques, our approach can effectively preserve all the messages passed to a mini-batch of nodes by learning and updating a small number of quantized reference vectors of global node representations, using VQ within each GNN layer. Our framework avoids the "neighbor explosion" problem by combining quantized representations with a low-rank version of the graph convolution matrix. We show that such a compact low-rank version of the gigantic convolution matrix is sufficient both theoretically and experimentally. Together with VQ, we design a novel approximated message passing algorithm and a nontrivial back-propagation rule for our framework. Experiments on various types of GNN backbones demonstrate the scalability and competitive performance of our framework on large-graph node classification and link prediction benchmarks.
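The core idea can be illustrated with a minimal sketch: each node's representation is replaced by its nearest codeword from a small learned codebook, so aggregation over a huge neighborhood reduces to mixing a handful of reference vectors. The snippet below is an illustrative toy (plain NumPy, k-codeword nearest-neighbor assignment, dense adjacency), not the authors' implementation; all function names here are hypothetical.

```python
import numpy as np

def vq_assign(X, codebook):
    """Index of the nearest codeword (Euclidean) for each node feature."""
    # dists[i, k] = ||X[i] - codebook[k]||^2
    dists = ((X[:, None, :] - codebook[None, :, :]) ** 2).sum(axis=-1)
    return dists.argmin(axis=1)

def approx_message_passing(A, X, codebook):
    """Approximate the graph convolution A @ X using quantized features.

    Each node's feature is replaced by its nearest codeword before
    aggregation, so the result depends only on the K codewords and the
    per-node codeword counts among neighbors, not on all n features.
    """
    idx = vq_assign(X, codebook)          # (n,) codeword index per node
    one_hot = np.eye(len(codebook))[idx]  # (n, K) assignment matrix R
    # A @ R is only (n, K): how many neighbors of each node fall in
    # each codeword cluster. This is the compact low-rank factor.
    return (A @ one_hot) @ codebook       # approximates A @ X
```

A quick sanity check: when every node feature is itself a codeword, the approximation is exact, since quantization then changes nothing.

```python
A = np.array([[0., 1.], [1., 0.]])        # 2-node toy graph
X = np.array([[1., 0.], [0., 1.]])
out = approx_message_passing(A, X, codebook=X.copy())
# exact here because each row of X is a codeword
```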
Related benchmarks
| Task | Dataset | Result | Rank |
|---|---|---|---|
| Node Classification | Reddit (test) | Accuracy: 94.49 | 134 |
| Link Prediction | ogbl-collab (test) | Hits@50: 46.73 | 92 |
| Node Classification | ogbn-products | Accuracy: 79.08 | 86 |
| Node Classification | ogbn-arxiv | Mean Accuracy: 70.55 | 74 |
| Node Classification | | Accuracy: 94.85 | 66 |
| Node Classification | ogbn-arxiv transductive official (test) | Accuracy: 70.55 | 20 |
| Node Classification | PPI inductive (test) | F1 Score: 97.37 | 14 |
| Node Classification | Flickr (transductive) | Accuracy: 53.23 | 14 |