
Optimal Transport Graph Neural Networks

About

Current graph neural network (GNN) architectures naively average or sum node embeddings into an aggregated graph representation, potentially losing structural or semantic information. Here we introduce OT-GNN, a model that computes graph embeddings using parametric prototypes that highlight key facets of different graph aspects. Towards this goal, we successfully combine optimal transport (OT) with parametric graph models. Graph representations are obtained from Wasserstein distances between the set of GNN node embeddings and "prototype" point clouds treated as free parameters. We theoretically prove that, unlike traditional sum aggregation, our function class on point clouds satisfies a fundamental universal approximation theorem. Empirically, we address an inherent optimization issue, collapse, by proposing a noise contrastive regularizer that steers the model towards truly exploiting the OT geometry. Finally, we outperform popular methods on several molecular property prediction tasks, while exhibiting smoother graph representations.
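The central idea above, representing a graph by its transport costs to a set of prototype point clouds rather than by a sum of node embeddings, can be illustrated with a small sketch. It assumes uniform weights on points, a squared Euclidean ground cost, and an entropy-regularized (Sinkhorn) approximation in place of the exact Wasserstein distance. The function names (`sinkhorn_cost`, `ot_graph_embedding`, `sum_pool`) and the toy embeddings and prototypes are hypothetical, and this is not the authors' implementation. In OT-GNN the node embeddings would come from a GNN and the prototype points would be trained; here both are fixed by hand.

```python
import math


def sq_dist(a, b):
    """Squared Euclidean ground cost between two points (assumed choice)."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def logsumexp(vals):
    """Numerically stable log(sum(exp(v))) over a list of floats."""
    m = max(vals)
    return m + math.log(sum(math.exp(v - m) for v in vals))


def sinkhorn_cost(xs, ys, eps=0.1, n_iters=200):
    """Entropic OT cost <P, C> between uniformly weighted point clouds xs and ys.

    Assumption: Sinkhorn (entropic) approximation, not the exact Wasserstein
    distance. Runs log-domain updates on dual potentials f (xs) and g (ys).
    """
    n, m = len(xs), len(ys)
    C = [[sq_dist(x, y) for y in ys] for x in xs]
    log_a, log_b = math.log(1.0 / n), math.log(1.0 / m)
    f, g = [0.0] * n, [0.0] * m
    for _ in range(n_iters):
        f = [eps * (log_a - logsumexp([(g[j] - C[i][j]) / eps for j in range(m)]))
             for i in range(n)]
        g = [eps * (log_b - logsumexp([(f[i] - C[i][j]) / eps for i in range(n)]))
             for j in range(m)]
    # Transport plan P_ij = exp((f_i + g_j - C_ij) / eps); return its total cost.
    return sum(math.exp((f[i] + g[j] - C[i][j]) / eps) * C[i][j]
               for i in range(n) for j in range(m))


def ot_graph_embedding(node_embs, prototypes, eps=0.1, n_iters=200):
    """Hypothetical helper: one coordinate per prototype, the OT cost to it."""
    return [sinkhorn_cost(node_embs, proto, eps, n_iters) for proto in prototypes]


def sum_pool(node_embs):
    """Baseline sum aggregation of node embeddings."""
    return [sum(coords) for coords in zip(*node_embs)]


# Two toy "graphs" whose node embeddings sum to the same vector.
graph_a = [(1.0, 0.0), (-1.0, 0.0)]
graph_b = [(0.0, 1.0), (0.0, -1.0)]
# Hypothetical prototype point clouds; in OT-GNN these would be learned.
prototypes = [
    [(1.0, 0.0), (-1.0, 0.0)],
    [(0.0, 1.0), (0.0, -1.0)],
]

print(sum_pool(graph_a), sum_pool(graph_b))  # prints [0.0, 0.0] [0.0, 0.0]
print(ot_graph_embedding(graph_a, prototypes))
print(ot_graph_embedding(graph_b, prototypes))
```

The two toy graphs have identical sum-pooled embeddings, yet their OT embeddings differ: `graph_a` sits near the first prototype (cost close to 0) and at cost 2 from the second, while `graph_b` shows the reverse pattern. This is the kind of information that sum aggregation discards. The regularization strength `eps` trades closeness to the exact Wasserstein distance (smaller values) against numerical smoothness (larger values).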

Benson Chen, Gary Bécigneul, Octavian-Eugen Ganea, Regina Barzilay, Tommi Jaakkola • 2020

Related benchmarks

| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Graph Classification | PROTEINS | Accuracy | 76.6 | 994 |
| Graph Classification | MUTAG | Accuracy | 91.6 | 862 |
| Graph Classification | COLLAB | Accuracy | 80.7 | 422 |
| Graph Classification | IMDB-M | Accuracy | 52.1 | 275 |
| Graph Classification | MUTAG (10-fold cross-validation) | Accuracy | 92.1 | 219 |
| Graph Classification | Mutag (test) | Accuracy | 94.74 | 217 |
| Graph Classification | PROTEINS (10-fold cross-validation) | Accuracy | 78 | 214 |
| Graph Classification | PTC-MR | Accuracy | 68 | 197 |
| Graph Classification | PROTEINS (test) | Accuracy | 72.59 | 180 |
| Graph Classification | IMDB-B (10-fold cross-validation) | Accuracy | 69.1 | 148 |

(10 of 29 benchmark rows shown.)
