Tuning Mixed Input Hyperparameters on the Fly for Efficient Population Based AutoRL
About
Despite a series of recent successes in reinforcement learning (RL), many RL algorithms remain sensitive to hyperparameters. As a result, there has been growing interest in AutoRL, which seeks to automate design decisions to create more general algorithms. Recent work suggests that population-based approaches may be effective AutoRL algorithms because they learn hyperparameter schedules on the fly. In particular, the PB2 algorithm achieves strong performance on RL tasks by formulating online hyperparameter optimization as a time-varying GP-bandit problem, while also providing theoretical guarantees. However, PB2 is designed only for continuous hyperparameters, which severely limits its utility in practice. In this paper, we introduce a new, provably efficient hierarchical approach for optimizing both continuous and categorical variables, using a new time-varying bandit algorithm designed specifically for the population-based training regime. We evaluate our approach on the challenging Procgen benchmark, where we show that explicitly modelling the dependence between data augmentation and other hyperparameters improves generalization.
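The abstract describes the hierarchy only at a high level: pick a categorical value with a bandit, then pick continuous values with a time-varying GP-bandit. The following is a minimal pure-Python sketch of that idea under stated assumptions, not the paper's algorithm. The categorical bandit here is plain EXP3 (not the new time-varying bandit the paper introduces), a single agent is tuned sequentially rather than a population, a separate GP is fit per categorical value, and the time decay uses a kernel factor of the form (1 - ε)^{|t - t'| / 2} in the spirit of time-varying GP-UCB. The category names, `toy_reward` objective, lengthscale, ε, β, and γ are all hypothetical.

```python
import math
import random

CATEGORIES = ["none", "crop", "flip"]  # hypothetical data-augmentation choices

def exp3_probs(weights, gamma):
    """EXP3 distribution: normalized weights mixed with uniform exploration."""
    total, k = sum(weights), len(weights)
    return [(1 - gamma) * w / total + gamma / k for w in weights]

def exp3_update(weights, probs, arm, reward, gamma):
    """Importance-weighted EXP3 update for the played arm (reward in [0, 1])."""
    k = len(weights)
    weights[arm] *= math.exp(gamma * (reward / probs[arm]) / k)
    top = max(weights)
    for i in range(k):  # rescale to avoid overflow; probabilities are unchanged
        weights[i] /= top

def kernel(a, b, ta, tb, lengthscale=0.2, eps=0.1):
    """Squared-exponential kernel in x, decayed by (1 - eps)^(|ta - tb| / 2) in time."""
    se = math.exp(-((a - b) ** 2) / (2 * lengthscale ** 2))
    return se * (1 - eps) ** (abs(ta - tb) / 2)

def solve(A, b):
    """Solve A x = b by Gaussian elimination with partial pivoting."""
    n = len(b)
    M = [row[:] + [b[i]] for i, row in enumerate(A)]
    for col in range(n):
        piv = max(range(col, n), key=lambda r: abs(M[r][col]))
        M[col], M[piv] = M[piv], M[col]
        for r in range(col + 1, n):
            f = M[r][col] / M[col][col]
            for c in range(col, n + 1):
                M[r][c] -= f * M[col][c]
    x = [0.0] * n
    for i in range(n - 1, -1, -1):
        x[i] = (M[i][n] - sum(M[i][j] * x[j] for j in range(i + 1, n))) / M[i][i]
    return x

def gp_ucb(xs, ts, ys, t_now, candidates, rng, beta=2.0, noise=0.01):
    """Return the candidate maximizing posterior mean + sqrt(beta * variance) at time t_now."""
    if not xs:
        return rng.choice(candidates)  # no data yet for this category
    n = len(xs)
    K = [[kernel(xs[i], xs[j], ts[i], ts[j]) + (noise if i == j else 0.0)
          for j in range(n)] for i in range(n)]
    alpha = solve(K, ys)
    best, best_score = None, -math.inf
    for x in candidates:
        k_star = [kernel(x, xs[i], t_now, ts[i]) for i in range(n)]
        mean = sum(k * a for k, a in zip(k_star, alpha))
        v = solve(K, k_star)
        var = max(kernel(x, x, t_now, t_now) - sum(k * vi for k, vi in zip(k_star, v)), 1e-12)
        score = mean + math.sqrt(beta * var)
        if score > best_score:
            best, best_score = x, score
    return best

def toy_reward(cat, x, rng):
    """Hypothetical stand-in for an agent's reward, clipped to [0, 1]."""
    best_x = {"none": 0.2, "crop": 0.6, "flip": 0.4}[cat]
    peak = {"none": 0.5, "crop": 0.9, "flip": 0.7}[cat]
    return min(max(peak - (x - best_x) ** 2 + rng.gauss(0, 0.02), 0.0), 1.0)

def run(steps=30, gamma=0.1, seed=0):
    """Hierarchical loop: bandit picks the category, GP-UCB picks x in [0, 1]."""
    rng = random.Random(seed)
    weights = [1.0] * len(CATEGORIES)
    data = {c: ([], [], []) for c in CATEGORIES}  # per-category (xs, ts, ys)
    grid = [i / 20 for i in range(21)]
    history = []
    for t in range(steps):
        probs = exp3_probs(weights, gamma)
        arm = rng.choices(range(len(CATEGORIES)), weights=probs)[0]
        cat = CATEGORIES[arm]
        xs, ts, ys = data[cat]
        x = gp_ucb(xs, ts, ys, t, grid, rng)
        r = toy_reward(cat, x, rng)
        exp3_update(weights, probs, arm, r, gamma)
        xs.append(x); ts.append(t); ys.append(r)
        history.append((cat, x, r))
    return history, exp3_probs(weights, gamma)
```

Calling `run()` returns the sequence of (category, continuous value, reward) choices and the final categorical distribution. Fitting one GP per category is the simplest way to let the continuous choice depend on the categorical one, which mirrors the abstract's point about modelling the dependence between data augmentation and other hyperparameters.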
Related benchmarks
| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Reinforcement Learning | Procgen (test) | BigFish Return | 10.6 | 21 |
| Reinforcement Learning | Procgen CaveFlyer 1.0 (train) | Mean Performance (Train) | 7.5 | 6 |
| Reinforcement Learning | Procgen Jumper 1.0 (train levels) | Mean Train Performance | 9.2 | 6 |
| Reinforcement Learning | Procgen Leaper 1.0 (train) | Mean Train Performance | 7.1 | 6 |
| Reinforcement Learning | Procgen BigFish 1.0 (train) | Mean Train Performance | 18.2 | 6 |
| Reinforcement Learning | Procgen CoinRun 1.0 (train) | Mean Train Performance | 9.9 | 6 |
| Reinforcement Learning | Procgen FruitBot 1.0 (train) | Mean Train Performance | 30.9 | 6 |
| Reinforcement Learning | Procgen StarPilot 1.0 (train) | Mean Train Performance | 41.8 | 6 |