
DIFFormer: Scalable (Graph) Transformers Induced by Energy Constrained Diffusion

About

Real-world data generation often involves complex inter-dependencies among instances, violating the IID-data hypothesis of standard learning paradigms and posing a challenge for uncovering the geometric structures for learning desired instance representations. To this end, we introduce an energy constrained diffusion model which encodes a batch of instances from a dataset into evolutionary states that progressively incorporate other instances' information by their interactions. The diffusion process is constrained by descent criteria w.r.t. a principled energy function that characterizes the global consistency of instance representations over latent structures. We provide rigorous theory that implies closed-form optimal estimates for the pairwise diffusion strength among arbitrary instance pairs, which gives rise to a new class of neural encoders, dubbed DIFFormer (diffusion-based Transformers), with two instantiations: a simple version with linear complexity for prohibitive instance numbers, and an advanced version for learning complex structures. Experiments highlight the wide applicability of our model as a general-purpose encoder backbone with superior performance in various tasks, such as node classification on large graphs, semi-supervised image/text classification, and spatial-temporal dynamics prediction.

Qitian Wu, Chenxiao Yang, Wentao Zhao, Yixuan He, David Wipf, Junchi Yan • 2023
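
The abstract mentions a simple variant whose all-pair diffusion runs in linear complexity. The following is a minimal NumPy sketch of how such a step could work, not the authors' implementation. It assumes pairwise weights of the form 1 + q_i · k_j with L2-normalized projections, a step size tau, and row-normalized weights; these choices and all function and parameter names (W_q, W_k, tau, eps) are illustrative assumptions. A dense reference version is included only to show that reordering the matrix products gives the same result without forming the n × n weight matrix.

```python
import numpy as np


def _project(Z, W, eps=1e-12):
    """Project instance states and L2-normalize each row (assumed design choice)."""
    P = Z @ W
    return P / (np.linalg.norm(P, axis=1, keepdims=True) + eps)


def dense_diffusion_step(Z, W_q, W_k, tau=0.5, eps=1e-12):
    """Reference step: builds the full n x n weight matrix, O(n^2) cost."""
    Q, K = _project(Z, W_q), _project(Z, W_k)
    S = 1.0 + Q @ K.T                       # non-negative since q.k lies in [-1, 1]
    S = S / (S.sum(axis=1, keepdims=True) + eps)
    # Each state moves toward the weighted average of all states.
    return (1.0 - tau) * Z + tau * (S @ Z)


def linear_diffusion_step(Z, W_q, W_k, tau=0.5, eps=1e-12):
    """Same update without the n x n matrix, O(n) in the number of instances."""
    Q, K = _project(Z, W_q), _project(Z, W_k)
    n = Z.shape[0]
    # sum_j (1 + q_i.k_j) z_j = sum_j z_j + q_i (K^T Z)
    numer = Z.sum(axis=0, keepdims=True) + Q @ (K.T @ Z)
    # sum_j (1 + q_i.k_j) = n + q_i . sum_j k_j
    denom = n + Q @ K.sum(axis=0)
    return (1.0 - tau) * Z + tau * numer / (denom[:, None] + eps)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    Z = rng.normal(size=(8, 4))             # 8 instances, 4-dimensional states
    W_q = rng.normal(size=(4, 3))
    W_k = rng.normal(size=(4, 3))
    fast = linear_diffusion_step(Z, W_q, W_k)
    slow = dense_diffusion_step(Z, W_q, W_k)
    print(np.allclose(fast, slow))
```

The linear version relies only on associativity: computing K^T Z first yields a small matrix whose size is independent of the number of instances, so the cost grows linearly with n under these assumptions.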

Related benchmarks

| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Node Classification | wikiCS | Accuracy | 73.46 | 198 |
| Node Classification | Photo | Mean Accuracy | 95.1 | 165 |
| Node Classification | Physics | Accuracy | 96.6 | 145 |
| Node Classification | Computers | Mean Accuracy | 91.99 | 143 |
| Node Classification | amazon-ratings | Accuracy | 47.84 | 138 |
| Node Classification | Roman-Empire | Accuracy | 79.1 | 135 |
| Node Classification | CS | Accuracy | 94.78 | 128 |
| Node Classification | questions | ROC AUC | 0.7215 | 87 |
| Node Classification | OGBN-Products | Accuracy | 74.16 | 86 |
| Node Classification | Ogbn-arxiv | Mean Accuracy | 69.86 | 74 |
Showing 10 of 26 rows
