Understanding the Generalization of Bilevel Programming in Hyperparameter Optimization: A Tale of Bias-Variance Decomposition
About
Gradient-based hyperparameter optimization (HPO) methods have emerged recently, leveraging bilevel programming techniques to optimize hyperparameters by estimating the hypergradient with respect to the validation loss. However, previous theoretical work mainly focuses on reducing the gap between the estimate and the ground truth (i.e., the bias), while ignoring the error due to the data distribution (i.e., the variance), which degrades performance. To address this issue, we conduct a bias-variance decomposition of the hypergradient estimation error and provide a detailed supplementary analysis of the variance term ignored by previous work. We also present a comprehensive analysis of the error bounds for hypergradient estimation. This analysis readily explains some phenomena commonly observed in practice, such as overfitting to the validation set. Inspired by the derived theory, we propose an ensemble hypergradient strategy that effectively reduces the variance in HPO algorithms. Experimental results on tasks including regularization hyperparameter learning, data hyper-cleaning, and few-shot learning demonstrate that our variance reduction strategy improves hypergradient estimation. To explain the improved performance, we establish a connection between excess error and hypergradient estimation, offering some understanding of the empirical observations.
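To make the notions of a hypergradient and of averaging it over data splits concrete, here is a minimal sketch under stated assumptions. It is not the paper's algorithm: the 1-D ridge regression problem, the closed-form inner solution, the function names, and the choice to build the ensemble from random train/validation resplits are all illustrative assumptions, since the abstract does not specify how the ensemble is constructed.

```python
import random

def ridge_weight(train, lam):
    """Inner problem (assumed): argmin_w 0.5*sum((w*x - y)^2) + 0.5*lam*w^2, solved in closed form."""
    sxy = sum(x * y for x, y in train)
    sxx = sum(x * x for x, _ in train)
    return sxy / (sxx + lam)

def val_loss(val, w):
    """Outer objective: mean squared error (halved) on the validation split."""
    return sum((w * x - y) ** 2 for x, y in val) / (2 * len(val))

def hypergradient(train, val, lam):
    """d val_loss / d lam, by the chain rule through the closed-form w(lam)."""
    sxx = sum(x * x for x, _ in train)
    w = ridge_weight(train, lam)
    dw_dlam = -w / (sxx + lam)  # derivative of the inner solution w.r.t. lam
    dloss_dw = sum((w * x - y) * x for x, y in val) / len(val)
    return dloss_dw * dw_dlam

def ensemble_hypergradient(data, lam, n_splits=5, val_fraction=0.3, seed=0):
    """Average the hypergradient over several random train/validation splits (illustrative choice)."""
    rng = random.Random(seed)
    n_val = max(1, int(len(data) * val_fraction))
    grads = []
    for _ in range(n_splits):
        shuffled = data[:]
        rng.shuffle(shuffled)
        val, train = shuffled[:n_val], shuffled[n_val:]
        grads.append(hypergradient(train, val, lam))
    return sum(grads) / len(grads), grads
```

Each single-split hypergradient depends on which samples happen to land in the validation set; averaging over splits is one simple way to damp that dependence, which is the kind of variance the abstract refers to.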
Related benchmarks
| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Image Classification | MNIST (test) | Accuracy | 93.4 | 882 |
| Image Classification | FashionMNIST (test) | Accuracy | 85.03 | 218 |
| Classification | Diabetes (test) | Accuracy | 78.21 | 32 |
| Hyper-data Cleaning | MNIST (test) | Test Accuracy | 0.9316 | 31 |
| Image Classification | CIFAR-10 (test) | Accuracy | 39.96 | 26 |
| 5-way Few-shot Classification | miniImageNet 5-way (meta-test) | Accuracy | 76.5 | 24 |
| 5-way Few-shot Classification | tieredImageNet 5-way (meta-test) | Accuracy | 80.62 | 24 |
| Binary Classification | Heart (test) | -- | -- | 16 |
| Regression | Abalone (test) | L2 Risk | 5.01 | 14 |
| Classification | a1a (test) | Loss | 0.3395 | 11 |