
NN-Former: Rethinking Graph Structure in Neural Architecture Representation

About

The growing use of deep learning necessitates efficient network design and deployment, making neural predictors vital for estimating attributes such as accuracy and latency. Recently, Graph Neural Networks (GNNs) and transformers have shown promising performance in representing neural architectures. However, each has its drawbacks: GNNs lack the capacity to represent complex features, while transformers generalize poorly as architecture depth grows. To mitigate these issues, we rethink neural architecture topology and show that sibling nodes are pivotal yet have been overlooked in previous research. We thus propose a novel predictor that leverages the strengths of both GNNs and transformers to learn this enhanced topology. We introduce a novel token mixer that considers siblings, and a new channel mixer named the bidirectional graph isomorphism feed-forward network. Our approach consistently achieves promising performance in both accuracy and latency prediction, providing valuable insights for learning Directed Acyclic Graph (DAG) topology. The code is available at https://github.com/XuRuihan/NNFormer.

Ruihan Xu, Haokui Zhang, Yaowei Wang, Wei Zeng, Shiliang Zhang · 2025
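To make the two topological ideas in the abstract concrete, here is a minimal sketch in plain Python. It is not the authors' implementation (see the linked repository for that). Assumptions: the architecture DAG is a 0/1 adjacency matrix with adj[i][j] = 1 for an edge i -> j, and two nodes count as "siblings" when they share a parent or share a child (nonzero off-diagonal entries of AᵀA or AAᵀ). The aggregation step is a GIN-style sum over both predecessors and successors, which is our reading of "bidirectional graph isomorphism"; the MLP that would follow it is omitted. The helper names and the eps parameter are hypothetical.

```python
from typing import List

Matrix = List[List[int]]
Features = List[List[float]]


def transpose(a: Matrix) -> Matrix:
    return [list(row) for row in zip(*a)]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    # Plain-Python matrix product (no NumPy, to stay dependency-free).
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def sibling_mask(adj: Matrix) -> Matrix:
    """Hypothetical helper: mask[i][j] = 1 if nodes i != j share a parent or a child.

    adj[i][j] == 1 means a directed edge i -> j.
    """
    n = len(adj)
    adj_t = transpose(adj)
    common_parents = matmul(adj_t, adj)   # (A^T A)[i][j] = #nodes k with k->i and k->j
    common_children = matmul(adj, adj_t)  # (A A^T)[i][j] = #nodes k with i->k and j->k
    return [[1 if i != j and (common_parents[i][j] or common_children[i][j]) else 0
             for j in range(n)] for i in range(n)]


def bidirectional_gin_aggregate(adj: Matrix, h: Features, eps: float = 0.0) -> Features:
    """GIN-style sum over predecessors AND successors (the following MLP is omitted).

    out_i = (1 + eps) * h_i + sum_{j -> i} h_j + sum_{i -> j} h_j
    """
    n, d = len(h), len(h[0])
    out = []
    for i in range(n):
        row = [(1.0 + eps) * x for x in h[i]]
        for j in range(n):
            # adj[j][i]: forward message from predecessor j; adj[i][j]: backward message
            # from successor j. In a DAG at most one of the two is 1 for a given pair.
            weight = adj[j][i] + adj[i][j]
            for c in range(d):
                row[c] += weight * h[j][c]
        out.append(row)
    return out


# Diamond DAG: 0 -> 1, 0 -> 2, 1 -> 3, 2 -> 3, so nodes 1 and 2 are siblings.
adj = [[0, 1, 1, 0], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 0]]
print(sibling_mask(adj))  # [[0, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 0]]
print(bidirectional_gin_aggregate(adj, [[1.0], [2.0], [3.0], [4.0]]))  # [[6.0], [7.0], [8.0], [9.0]]
```

A mask like this could, for example, be combined with the forward and backward adjacency to decide which node pairs a masked self-attention token mixer is allowed to attend over. Note that a plain message-passing GNN on the diamond graph never connects nodes 1 and 2 directly, which is one way to see why sibling relations are easy to overlook.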

Related benchmarks

Task | Dataset | Metric | Result | Rank
Accuracy Prediction | NAS-Bench-101 1.0 | Kendall's Tau | 0.877 | 46
Accuracy Prediction | NAS-Bench-201 8 (whole dataset) | Kendall's Tau | 0.89 | 36
Latency Prediction | NNLQ in-domain v1 (test) | MAPE (avg, %) | 1.11 | 33
Latency Prediction | NNLQ Out-of-domain EfficientNet | MAPE (avg, %) | 5.13 | 8
Latency Prediction | NNLQ Out-of-domain MnasNet | MAPE (avg, %) | 2.71 | 8
Latency Prediction | NNLQ Out-of-domain MobileNetV2 | MAPE (avg, %) | 4.17 | 8
Latency Prediction | NNLQ Out-of-domain NasBench201 | MAPE (avg, %) | 7.93 | 8
Latency Prediction | NNLQ Out-of-domain Average | MAPE (avg, %) | 8.39 | 8
Latency Prediction | NNLQ Out-of-domain GoogleNet | MAPE (avg, %) | 6.74 | 8
Latency Prediction | NNLQ Out-of-domain MobileNetV3 | MAPE (avg, %) | 9.07 | 8

Showing 10 of 16 rows.
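The benchmarks above score accuracy predictors with Kendall's Tau (rank agreement between predicted and ground-truth accuracy, higher is better) and latency predictors with MAPE (mean absolute percentage error, lower is better). The sketch below shows how each is computed; the function names are hypothetical, and the toy numbers are made up, not benchmark data. It implements the simple tau-a variant; `scipy.stats.kendalltau` defaults to tau-b, which adjusts for ties, and the listing does not state which variant the reported scores use.

```python
from itertools import combinations
from typing import Sequence


def kendall_tau(pred: Sequence[float], true: Sequence[float]) -> float:
    """Kendall's tau-a: (concordant - discordant) / number of pairs; tied pairs count as neither."""
    n = len(pred)
    concordant = discordant = 0
    for i, j in combinations(range(n), 2):
        sign = (pred[i] - pred[j]) * (true[i] - true[j])
        if sign > 0:
            concordant += 1
        elif sign < 0:
            discordant += 1
    return (concordant - discordant) / (n * (n - 1) / 2)


def mape(pred: Sequence[float], true: Sequence[float]) -> float:
    """Mean absolute percentage error, in percent; assumes no true value is zero."""
    return 100.0 * sum(abs(p - t) / abs(t) for p, t in zip(pred, true)) / len(true)


# Toy example with made-up numbers (not benchmark data).
print(kendall_tau([0.1, 0.4, 0.35, 0.8], [0.2, 0.3, 0.5, 0.9]))  # 5 concordant, 1 discordant -> 0.666...
print(mape([9.0, 22.0], [10.0, 20.0]))  # both predictions off by 10% -> 10.0
```

Kendall's Tau only checks whether pairs are ordered correctly, which matches how accuracy predictors are used in architecture search (ranking candidates), whereas MAPE measures absolute error relative to each measured latency.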
