Task-Robust Model-Agnostic Meta-Learning
About
Meta-learning methods have shown an impressive ability to train models that rapidly learn new tasks. However, these methods only aim to perform well in expectation over tasks coming from some particular distribution that is typically equivalent across meta-training and meta-testing, rather than considering worst-case task performance. In this work we introduce the notion of "task-robustness" by reformulating the popular Model-Agnostic Meta-Learning (MAML) objective [Finn et al. 2017] such that the goal is to minimize the maximum loss over the observed meta-training tasks. The solution to this novel formulation is task-robust in the sense that it places equal importance on even the most difficult and/or rare tasks. This also means that it performs well over all distributions of the observed tasks, making it robust to shifts in the task distribution between meta-training and meta-testing. We present an algorithm to solve the proposed min-max problem, and show that it converges to an $\epsilon$-accurate point at the optimal rate of $\mathcal{O}(1/\epsilon^2)$ in the convex setting and to an $(\epsilon, \delta)$-stationary point at the rate of $\mathcal{O}(\max\{1/\epsilon^5, 1/\delta^5\})$ in nonconvex settings. We also provide an upper bound on the new task generalization error that captures the advantage of minimizing the worst-case task loss, and demonstrate this advantage in sinusoid regression and image classification experiments.
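To make the min-max idea concrete, the worst-case objective described above can be written as $\min_w \max_i f_i(w - \alpha \nabla f_i(w))$. Because a linear function over the probability simplex attains its maximum at a vertex, this equals $\min_w \max_{p \in \Delta} \sum_i p_i f_i(w - \alpha \nabla f_i(w))$. The sketch below is a toy illustration only, not the authors' algorithm or code. It assumes hypothetical 1-D quadratic tasks $f_c(w) = \tfrac{1}{2}(w - c)^2$, and it uses a multiplicative-weights (exponentiated-gradient) ascent step on the task weights $p$ as a simple stand-in for whatever update the paper uses. All names (`ALPHA`, `adapted_loss`, `train`, `worst_case`) and step sizes are assumptions made for illustration.

```python
import math

ALPHA = 0.1  # inner-loop (adaptation) step size; illustrative value, not from the paper


def adapted_loss(w, c):
    """Loss of the toy task f_c(w) = 0.5 * (w - c)**2 after one inner gradient step."""
    w_adapted = w - ALPHA * (w - c)  # one MAML-style adaptation step on task c
    return 0.5 * (w_adapted - c) ** 2


def adapted_grad(w, c):
    """Derivative of adapted_loss with respect to the meta-parameter w."""
    return (1.0 - ALPHA) ** 2 * (w - c)


def train(centers, robust, steps=3000, lr_w=0.1, lr_p=0.1):
    """Gradient descent on w; if robust, also ascend on the task weights p.

    Returns the average of the w iterates and the final task weights.
    """
    n = len(centers)
    w = 0.0
    p = [1.0 / n] * n  # uniform weights recover the usual average-case objective
    w_sum = 0.0
    for _ in range(steps):
        losses = [adapted_loss(w, c) for c in centers]
        grad = sum(p_i * adapted_grad(w, c) for p_i, c in zip(p, centers))
        w -= lr_w * grad  # descent step on the weighted adapted loss
        if robust:
            # Multiplicative-weights ascent: shift mass toward high-loss tasks.
            p = [p_i * math.exp(lr_p * l_i) for p_i, l_i in zip(p, losses)]
            total = sum(p)
            p = [p_i / total for p_i in p]
        w_sum += w
    return w_sum / steps, p


def worst_case(w, centers):
    """Maximum adapted loss over all tasks at meta-parameter w."""
    return max(adapted_loss(w, c) for c in centers)


# Three "common" tasks centred at 0 and one "rare" task centred at 4.
centers = [0.0, 0.0, 0.0, 4.0]
w_avg, _ = train(centers, robust=False)
w_rob, p_rob = train(centers, robust=True)
print("average-case w:", w_avg, "worst-case loss:", worst_case(w_avg, centers))
print("task-robust  w:", w_rob, "worst-case loss:", worst_case(w_rob, centers))
```

In this toy setup the average-case minimizer is the mean of the task centres (1), which favours the three common tasks, while the worst-case minimizer is the midpoint between the two extremes (2), which equalizes the loss on the common and rare tasks. The task weights `p_rob` shift toward the rare task as training proceeds.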
Related benchmarks
| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Few-shot Image Classification | Omniglot Alphabets (meta-test) | Average Score | 93.1 | 6 |
| Few-shot Image Classification | Omniglot Meta-Training Alphabets (train) | Average Performance | 97.4 | 6 |
| Image Classification | mini-ImageNet (train) | Average Score | 62.2 | 5 |
| System Identification | Pendulum (test) | Average MSE | 0.76 | 5 |
| Few-shot Image Classification | mini-ImageNet (Four Meta-Testing Tasks) | Average Accuracy | 48.5 | 3 |
| Few-shot Image Classification | mini-ImageNet Eight Meta (train) | Average Accuracy | 63.2 | 3 |
| Few-shot Sinusoid Regression | Sinusoid 490 tasks 5-shot (test) | Avg MSE | 1.09 | 3 |
| Few-shot Sinusoid Regression | Sinusoid 490 meta-test tasks 10-shot (test) | Average MSE | 0.77 | 3 |