A Generalization of Transformer Networks to Graphs

About

We propose a generalization of the transformer neural network architecture to arbitrary graphs. The original transformer was designed for Natural Language Processing (NLP), where it operates on fully connected graphs representing all connections between the words in a sequence. Such an architecture does not leverage the inductive bias of graph connectivity, and it can perform poorly when the graph topology is important but has not been encoded into the node features. We introduce a graph transformer with four new properties compared to the standard model. First, the attention mechanism is a function of the neighborhood connectivity of each node in the graph. Second, the positional encoding is represented by the Laplacian eigenvectors, which naturally generalize the sinusoidal positional encodings often used in NLP. Third, layer normalization is replaced by batch normalization, which provides faster training and better generalization performance. Finally, the architecture is extended to edge feature representation, which can be critical to tasks such as chemistry (bond type) or link prediction (entity relationships in knowledge graphs). Numerical experiments on a graph benchmark demonstrate the performance of the proposed graph transformer architecture. This work closes the gap between the original transformer, which was designed for the limited case of line graphs, and graph neural networks, which can work with arbitrary graphs. As our architecture is simple and generic, we believe it can be used as a black box for future applications that wish to consider transformers and graphs.
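The abstract describes the layer only at a high level. As an illustration, here is a minimal NumPy sketch, built on our own assumptions, of three of the four ingredients: Laplacian-eigenvector positional encodings, attention restricted to each node's neighbors, and batch normalization in place of layer normalization. It is single-head, omits edge features and the learnable scale/shift of batch norm, uses the symmetric normalized Laplacian, and adds self-loops so that every node has at least one key to attend to. The function names (`laplacian_pe`, `neighbor_attention`, `batch_norm`, `graph_transformer_layer`) are hypothetical, and this is not the authors' implementation.

```python
import numpy as np


def laplacian_pe(adj: np.ndarray, k: int) -> np.ndarray:
    """k Laplacian-eigenvector positional encodings per node (assumed variant).

    Uses the symmetric normalized Laplacian L = I - D^{-1/2} A D^{-1/2}
    and skips the trivial eigenvector (eigenvalue 0 for a connected graph).
    """
    deg = adj.sum(axis=1)
    # Isolated nodes (degree 0) get a zero scaling instead of dividing by zero.
    inv_sqrt = np.where(deg > 0, 1.0 / np.sqrt(np.maximum(deg, 1e-12)), 0.0)
    lap = np.eye(len(adj)) - inv_sqrt[:, None] * adj * inv_sqrt[None, :]
    # eigh returns eigenvalues of a symmetric matrix in ascending order.
    _, eigvecs = np.linalg.eigh(lap)
    return eigvecs[:, 1:k + 1]


def neighbor_attention(h, adj, wq, wk, wv):
    """Single-head scaled dot-product attention restricted to graph neighbors."""
    q, k, v = h @ wq, h @ wk, h @ wv
    scores = (q @ k.T) / np.sqrt(k.shape[1])
    # Assumption: add self-loops so every node attends at least to itself.
    mask = (adj + np.eye(len(adj))) > 0
    # Non-neighbors get -inf, hence exactly zero weight after the softmax.
    scores = np.where(mask, scores, -np.inf)
    scores -= scores.max(axis=1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=1, keepdims=True)
    return weights @ v


def batch_norm(x, eps=1e-5):
    """Training-mode batch norm over all nodes in the batch, per feature
    (learnable scale and shift omitted for brevity)."""
    return (x - x.mean(axis=0)) / np.sqrt(x.var(axis=0) + eps)


def graph_transformer_layer(h, adj, p):
    """Attention -> residual -> BatchNorm -> FFN -> residual -> BatchNorm."""
    attn = neighbor_attention(h, adj, p["wq"], p["wk"], p["wv"]) @ p["wo"]
    h = batch_norm(h + attn)
    ffn = np.maximum(0.0, h @ p["w1"]) @ p["w2"]  # ReLU feed-forward
    return batch_norm(h + ffn)


# Usage on a 5-node path graph 0-1-2-3-4 with random weights.
rng = np.random.default_rng(0)
n, d_in, d, k_pe = 5, 3, 8, 2
adj = np.zeros((n, n))
for i in range(n - 1):
    adj[i, i + 1] = adj[i + 1, i] = 1.0

x = rng.normal(size=(n, d_in))   # input node features
pe = laplacian_pe(adj, k_pe)     # (n, k_pe) positional encodings
w_in = rng.normal(size=(d_in, d))
w_pe = rng.normal(size=(k_pe, d))
h = x @ w_in + pe @ w_pe         # PE linearly projected and added to embeddings

shapes = {"wq": (d, d), "wk": (d, d), "wv": (d, d), "wo": (d, d),
          "w1": (d, 2 * d), "w2": (2 * d, d)}
params = {name: rng.normal(size=s) / np.sqrt(s[0]) for name, s in shapes.items()}

out = graph_transformer_layer(h, adj, params)
print(out.shape)  # (5, 8)
```

Because non-neighbors are masked to -inf before the softmax, changing the features of node 4 leaves node 0's attention output unchanged on this path graph; with a fully connected adjacency the same code reduces to ordinary NLP-style self-attention. Eigenvectors are only defined up to sign, so in training one would typically randomize their signs; the sketch does not.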

Vijay Prakash Dwivedi, Xavier Bresson • 2020

Related benchmarks

Task                  Dataset     Accuracy (%)  Rank
Node Classification   Cora        86.17         1215
Graph Classification  PROTEINS    77.25         994
Node Classification   Citeseer    72.51         931
Graph Classification  MUTAG       83.9          862
Node Classification   Pubmed      88.79         819
Node Classification   Chameleon   65.55         640
Node Classification   Wisconsin   82.4          627
Node Classification   Texas       84.44         616
Node Classification   Squirrel    49.5          591
Node Classification   Cornell     70.37         582

Showing 10 of 92 rows.
...
