NN-Former: Rethinking Graph Structure in Neural Architecture Representation
About
The growing use of deep learning necessitates efficient network design and deployment, making neural predictors vital for estimating attributes such as accuracy and latency. Recently, Graph Neural Networks (GNNs) and transformers have shown promising performance in representing neural architectures. However, each has its drawbacks: GNNs lack the capacity to represent complex features, while transformers generalize poorly as architecture depth grows. To mitigate these issues, we rethink neural architecture topology and show that sibling nodes are pivotal yet overlooked in previous research. We thus propose a novel predictor that leverages the strengths of both GNNs and transformers to learn the enhanced topology. We introduce a novel token mixer that considers siblings, and a new channel mixer named bidirectional graph isomorphism feed-forward network. Our approach consistently achieves promising performance in both accuracy and latency prediction, providing valuable insights for learning Directed Acyclic Graph (DAG) topology. The code is available at https://github.com/XuRuihan/NNFormer.
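The abstract names two graph ingredients without spelling them out: sibling relations and a bidirectional graph-isomorphism (GIN-style) feed-forward step. The sketch below is one plausible reading under stated assumptions, not the authors' implementation (the linked repository has that). It assumes an adjacency matrix where `adj[i][j] == 1` marks an edge i -> j, treats two nodes as siblings if they share a parent (Boolean product AᵀA) or share a child (Boolean product AAᵀ), and aggregates node features over predecessors and successors separately with a GIN-style update `(1 + eps) * h_i + sum of neighbour h_j`, omitting the MLP. All function names are hypothetical, and concatenating the two directions is an assumed combination rule.

```python
from typing import List, Tuple

Matrix = List[List[int]]
Features = List[List[float]]


def transpose(a: Matrix) -> Matrix:
    return [list(row) for row in zip(*a)]


def bool_matmul(a: Matrix, b: Matrix) -> Matrix:
    """Boolean matrix product: out[i][j] = 1 if some k has a[i][k] and b[k][j]."""
    inner, cols = len(b), len(b[0])
    return [[int(any(a[i][k] and b[k][j] for k in range(inner))) for j in range(cols)]
            for i in range(len(a))]


def sibling_masks(adj: Matrix) -> Tuple[Matrix, Matrix]:
    """Hypothetical helper. adj[i][j] == 1 means a directed edge i -> j."""
    adj_t = transpose(adj)
    # (A A^T)[i][j] = sum_k A[i][k] * A[j][k]: i and j share a common child k.
    share_child = bool_matmul(adj, adj_t)
    # (A^T A)[i][j] = sum_k A[k][i] * A[k][j]: i and j share a common parent k.
    share_parent = bool_matmul(adj_t, adj)
    # A node is not its own sibling, so clear the diagonal.
    for i in range(len(adj)):
        share_child[i][i] = 0
        share_parent[i][i] = 0
    return share_child, share_parent


def gin_aggregate(adj: Matrix, feats: Features, eps: float = 0.0,
                  reverse: bool = False) -> Features:
    """GIN-style sum aggregation without the MLP (assumed simplification).

    reverse=False: node i sums features of its predecessors (edges j -> i).
    reverse=True:  node i sums features of its successors (edges i -> j).
    """
    n = len(adj)
    out: Features = []
    for i in range(n):
        acc = [(1.0 + eps) * x for x in feats[i]]
        for j in range(n):
            edge = adj[i][j] if reverse else adj[j][i]
            if edge:
                acc = [a + b for a, b in zip(acc, feats[j])]
        out.append(acc)
    return out


def bidirectional_gin(adj: Matrix, feats: Features, eps: float = 0.0) -> Features:
    """Concatenate forward and backward aggregations per node (assumed rule)."""
    fwd = gin_aggregate(adj, feats, eps)
    bwd = gin_aggregate(adj, feats, eps, reverse=True)
    return [f + b for f, b in zip(fwd, bwd)]


# Diamond-shaped DAG: 0 -> 1, 0 -> 2, 1 -> 3, 2 -> 3.
adj = [[0, 1, 1, 0],
       [0, 0, 0, 1],
       [0, 0, 0, 1],
       [0, 0, 0, 0]]
share_child, share_parent = sibling_masks(adj)
print(share_parent[1][2], share_child[1][2])  # 1 1
print(bidirectional_gin(adj, [[1.0], [2.0], [3.0], [4.0]]))
# [[1.0, 6.0], [3.0, 6.0], [4.0, 7.0], [9.0, 4.0]]
```

On this diamond, nodes 1 and 2 come out as siblings in both masks, even though no edge connects them directly, which is the kind of relation a plain forward or backward adjacency misses. In a transformer token mixer such masks could restrict or bias attention; how NN-Former actually uses them is described in the paper, not in this sketch.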
Related benchmarks
| Task | Dataset | Metric | Value | Rank |
|---|---|---|---|---|
| Accuracy Prediction | NAS-Bench-101 1.0 | Kendall's Tau | 0.877 | 46 |
| Accuracy Prediction | NAS-Bench-201 8 (whole dataset) | Kendall's Tau | 0.89 | 36 |
| Latency Prediction | NNLQ in-domain v1 (test) | MAPE (Average) | 1.11 | 33 |
| Latency Prediction | NNLQ Out-of-domain EfficientNet | MAPE (Avg) | 5.13 | 8 |
| Latency Prediction | NNLQ Out-of-domain MnasNet | MAPE (avg) | 2.71 | 8 |
| Latency Prediction | NNLQ Out-of-domain MobileNetV2 | MAPE (avg) | 4.17 | 8 |
| Latency Prediction | NNLQ Out-of-domain NasBench201 | MAPE (avg) | 7.93 | 8 |
| Latency Prediction | NNLQ Out-of-domain Average | MAPE (Average) | 8.39 | 8 |
| Latency Prediction | NNLQ Out-of-domain GoogleNet | MAPE (avg) | 6.74 | 8 |
| Latency Prediction | NNLQ Out-of-domain MobileNetV3 | MAPE (%) | 9.07 | 8 |
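The table scores accuracy predictors by Kendall's Tau (rank agreement between predicted and true accuracy) and latency predictors by MAPE (mean absolute percentage error). The sketch below uses the textbook definitions: tau-a, where tied pairs count as neither concordant nor discordant, and per-sample MAPE in percent. It is an illustration only; how each benchmark averages MAPE (the "Avg"/"Average" labels) is not stated here, and `scipy.stats.kendalltau` defaults to tau-b, which matches tau-a only when there are no ties. Function names are hypothetical.

```python
from itertools import combinations
from typing import Sequence


def kendall_tau(x: Sequence[float], y: Sequence[float]) -> float:
    """Kendall's tau-a: (concordant - discordant) / total number of pairs."""
    if len(x) != len(y) or len(x) < 2:
        raise ValueError("need two sequences of equal length >= 2")
    concordant = discordant = 0
    for i, j in combinations(range(len(x)), 2):
        s = (x[i] - x[j]) * (y[i] - y[j])
        if s > 0:
            concordant += 1   # pair ordered the same way in x and y
        elif s < 0:
            discordant += 1   # pair ordered oppositely; ties contribute nothing
    n_pairs = len(x) * (len(x) - 1) // 2
    return (concordant - discordant) / n_pairs


def mape(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Mean absolute percentage error in percent; y_true values must be nonzero."""
    if len(y_true) != len(y_pred) or not y_true:
        raise ValueError("need two non-empty sequences of equal length")
    total = sum(abs(t - p) / abs(t) for t, p in zip(y_true, y_pred))
    return 100.0 * total / len(y_true)


print(kendall_tau([1, 2, 3, 4], [4, 3, 2, 1]))  # -1.0
print(mape([10.0, 20.0], [11.0, 18.0]))          # about 10.0 (each prediction is 10% off)
```

A tau of 1.0 means the predictor ranks architectures exactly as the ground truth does, which is what matters when the predictor is used to pick the best candidate during search; MAPE instead measures how close each latency estimate is in relative terms, so lower is better.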