
iDARTS: Differentiable Architecture Search with Stochastic Implicit Gradients

About

Differentiable ARchiTecture Search (DARTS) has recently become a mainstream approach to neural architecture search (NAS) due to its efficiency and simplicity. Using gradient-based bi-level optimization, DARTS alternately optimizes the inner model weights and the outer architecture parameters in a weight-sharing supernet. A key challenge to the scalability and quality of the learned architectures is the need to differentiate through the inner-loop optimization. While much has been discussed about several potentially fatal factors in DARTS, the architecture gradient, a.k.a. hypergradient, has received less attention. In this paper, we tackle the hypergradient computation in DARTS based on the implicit function theorem, making it depend only on the obtained solution to the inner-loop optimization and agnostic to the optimization path. To further reduce the computational requirements, we formulate a stochastic hypergradient approximation for differentiable NAS, and theoretically show that the architecture optimization with the proposed method, named iDARTS, is expected to converge to a stationary point. Comprehensive experiments on two NAS benchmark search spaces and the common NAS search space verify the effectiveness of our proposed method: it leads to architectures that outperform, by large margins, those learned by the baseline methods.
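The core idea above, an implicit-function-theorem hypergradient that needs only the inner solution w*, can be illustrated on a toy problem. By the implicit function theorem, the hypergradient is ∇_α L_val(w*, α) − ∇²_{α,w} L_train(w*, α) · [∇²_{w,w} L_train(w*, α)]⁻¹ · ∇_w L_val(w*, α). The sketch below is not the authors' code. It assumes a truncated Neumann series for the inverse-Hessian-vector product, and every name in it (`neumann_inverse_hvp`, `implicit_hypergradient`, `A`, `B`, `t`, `lam`) is hypothetical. The quadratic losses are chosen only so the result can be checked against a closed form. In a real supernet, H would be the training-loss Hessian with respect to the weights, accessed through mini-batch Hessian-vector products; under this assumption, that is where the stochasticity would enter.

```python
import numpy as np


def neumann_inverse_hvp(hvp, v, eta, num_terms):
    """Approximate H^{-1} v with a truncated Neumann series.

    Relies on H^{-1} = eta * sum_{k>=0} (I - eta*H)^k, which converges when
    0 < eta < 2 / lambda_max(H) for a symmetric positive-definite H.
    Only Hessian-vector products hvp(p) = H @ p are needed, never H itself.
    """
    term = v.copy()   # (I - eta*H)^k v, starting at k = 0
    total = v.copy()  # running sum over k = 0..num_terms
    for _ in range(num_terms):
        term = term - eta * hvp(term)
        total = total + term
    return eta * total


# Hypothetical toy bilevel problem (not from the paper) with closed forms:
#   inner: L_train(w, a) = 0.5 * w^T A w - w^T B a  ->  w*(a) = A^{-1} B a
#   outer: L_val(w, a)   = 0.5 * ||w - t||^2 + 0.5 * lam * ||a||^2
rng = np.random.default_rng(0)
n_w, n_a = 5, 3
M = rng.standard_normal((n_w, n_w))
A = M @ M.T + n_w * np.eye(n_w)  # SPD Hessian of L_train w.r.t. w
B = rng.standard_normal((n_w, n_a))
t = rng.standard_normal(n_w)
lam = 0.1


def inner_solution(a):
    """Exact minimizer w*(a) of the toy inner loss."""
    return np.linalg.solve(A, B @ a)


def outer_objective(a):
    """L_val evaluated at the inner solution, as a function of a alone."""
    w = inner_solution(a)
    return 0.5 * np.sum((w - t) ** 2) + 0.5 * lam * np.sum(a ** 2)


def implicit_hypergradient(a, inverse_hvp):
    """IFT hypergradient built only from w*(a), not from the path to it.

    For this toy loss d2 L_train / (da dw) = -B^T, so the correction
    -(d2 L_train / da dw) H^{-1} dL_val/dw becomes +B^T H^{-1} dL_val/dw.
    """
    w_star = inner_solution(a)
    grad_w_val = w_star - t   # dL_val/dw at w*
    grad_a_val = lam * a      # direct dL_val/da
    return grad_a_val + B.T @ inverse_hvp(grad_w_val)


a = rng.standard_normal(n_a)
exact = implicit_hypergradient(a, lambda v: np.linalg.solve(A, v))
eta = 1.0 / np.linalg.eigvalsh(A).max()  # keeps the series convergent
for k in (5, 20, 200):
    approx = implicit_hypergradient(
        a, lambda v: neumann_inverse_hvp(lambda p: A @ p, v, eta, k))
    print(k, np.abs(approx - exact).max())
```

With η = 1/λ_max(A), the truncation-error bound shrinks by a factor of at most 1 − λ_min(A)/λ_max(A) per extra term, so more terms trade compute for accuracy. The exact solve used for comparison is feasible here only because the toy Hessian is 5×5.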

Miao Zhang, Steven Su, Shirui Pan, Xiaojun Chang, Ehsan Abbasnejad, Reza Haffari • 2021

Related benchmarks

Task | Dataset | Result | Rank
Image Classification | CIFAR-100 (test) | Accuracy: 70.83% | 3518
Image Classification | CIFAR-10 (test) | Accuracy: 93.58% | 3381
Image Classification | CIFAR-100 (val) | Accuracy: 70.57% | 661
Image Classification | CIFAR-10 (val) | Top-1 Accuracy: 89.86% | 329
Image Classification | ImageNet (test) | -- | 235
Image Classification | CIFAR-10 NAS-Bench-201 (test) | Accuracy: 93.58% | 173
Image Classification | CIFAR-100 NAS-Bench-201 (test) | Accuracy: 70.83% | 169
Image Classification | ImageNet-16-120 NAS-Bench-201 (test) | Accuracy: 40.89% | 139
Image Classification | CIFAR-10 NAS-Bench-201 (val) | Accuracy: 89.86% | 119
Image Classification | CIFAR-100 NAS-Bench-201 (val) | Accuracy: 70.57% | 109
Showing 10 of 20 rows
