On Warm-Starting Neural Network Training
About
In many real-world deployments of machine learning systems, data arrive piecemeal. These learning scenarios may be passive, where data arrive incrementally due to structural properties of the problem (e.g., daily financial data) or active, where samples are selected according to a measure of their quality (e.g., experimental design). In both of these cases, we are building a sequence of models that incorporate an increasing amount of data. We would like each of these models in the sequence to be performant and take advantage of all the data that are available to that point. Conventional intuition suggests that when solving a sequence of related optimization problems of this form, it should be possible to initialize using the solution of the previous iterate -- to "warm start" the optimization rather than initialize from scratch -- and see reductions in wall-clock time. However, in practice this warm-starting seems to yield poorer generalization performance than models that have fresh random initializations, even though the final training losses are similar. While it appears that some hyperparameter settings allow a practitioner to close this generalization gap, they seem to only do so in regimes that damage the wall-clock gains of the warm start. Nevertheless, it is highly desirable to be able to warm-start neural network training, as it would dramatically reduce the resource usage associated with the construction of performant deep learning systems. In this work, we take a closer look at this empirical phenomenon and try to understand when and how it occurs. We also provide a surprisingly simple trick that overcomes this pathology in several important situations, and present experiments that elucidate some of its properties.
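The conventional intuition described above, that starting from the previous solution saves optimization work, can be illustrated on a toy problem. The sketch below is a minimal example, assuming a small convex least-squares fit rather than a neural network; the data and the names `make_data`, `gradient`, `train`, and `compare` are hypothetical and are not taken from the paper. It shows only the step savings of a warm start. It does not reproduce the generalization gap the abstract reports for neural networks, and it does not implement the trick the authors propose.

```python
import math

def make_data(n):
    """Deterministic toy regression data (hypothetical): y = 3x + 1 plus a small wiggle."""
    xs = [i / n for i in range(n)]
    ys = [3.0 * x + 1.0 + 0.05 * math.sin(7.0 * i) for i, x in enumerate(xs)]
    return xs, ys

def gradient(w, b, xs, ys):
    """Gradient of the mean squared error of the line w*x + b."""
    n = len(xs)
    gw = sum(2.0 * (w * x + b - y) * x for x, y in zip(xs, ys)) / n
    gb = sum(2.0 * (w * x + b - y) for x, y in zip(xs, ys)) / n
    return gw, gb

def train(w, b, xs, ys, lr=0.5, tol=1e-6, max_steps=100_000):
    """Run gradient descent from (w, b); return the fit and the number of steps taken."""
    for step in range(max_steps):
        gw, gb = gradient(w, b, xs, ys)
        if math.hypot(gw, gb) < tol:
            return (w, b), step
        w, b = w - lr * gw, b - lr * gb
    return (w, b), max_steps

def compare(n=200):
    """Fit on the first half of the data, then refit on all of it two ways."""
    xs, ys = make_data(n)
    half = n // 2
    # First model in the sequence: only the first half of the data has arrived.
    (w0, b0), _ = train(0.0, 0.0, xs[:half], ys[:half])
    # Warm start: initialize the second model from the first model's solution.
    warm_fit, warm_steps = train(w0, b0, xs, ys)
    # Cold start: initialize the second model from scratch (zeros here).
    cold_fit, cold_steps = train(0.0, 0.0, xs, ys)
    return warm_fit, warm_steps, cold_fit, cold_steps

if __name__ == "__main__":
    warm_fit, warm_steps, cold_fit, cold_steps = compare()
    print("warm-start steps:", warm_steps)
    print("cold-start steps:", cold_steps)
```

Because this toy objective is convex, both runs reach essentially the same solution and the warm start simply gets there in fewer steps. The abstract's observation is that this picture breaks down for neural networks, where similar final training losses can still come with worse generalization for the warm-started model.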
Related benchmarks
| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Image Classification | CIFAR-100 (test) | Test Acc (Last) | 63.67 | 30 |
| Image Classification | SVHN (test) | Test Accuracy | 94.83 | 30 |
| Continual Reinforcement Learning | MiniGrid | Mean SR | 46.3 | 10 |
| Image Classification | CIFAR-10 (test) | Accuracy (Test, Last) | 94.73 | 10 |
| Image Classification | SVHN (test) | Test Acc (Last) | 94.27 | 8 |
| Image Classification | CIFAR-100 (test) | Final Test Accuracy | 59.61 | 3 |