Learning to Prompt for Continual Learning
About
The mainstream paradigm behind continual learning has been to adapt the model parameters to non-stationary data distributions, where catastrophic forgetting is the central challenge. Typical methods rely on a rehearsal buffer or on known task identity at test time to retrieve learned knowledge and address forgetting. This work instead presents a new paradigm for continual learning that aims to train a more succinct memory system without accessing task identity at test time. Our method, Learning to Prompt (L2P), learns to dynamically prompt a pre-trained model to learn tasks sequentially under different task transitions. In our proposed framework, prompts are small learnable parameters maintained in a memory space. The objective is to optimize prompts to instruct the model's prediction and explicitly manage task-invariant and task-specific knowledge while maintaining model plasticity. We conduct comprehensive experiments on popular image classification benchmarks under different challenging continual learning settings, where L2P consistently outperforms prior state-of-the-art methods. Surprisingly, L2P achieves competitive results against rehearsal-based methods even without a rehearsal buffer, and it is directly applicable to challenging task-agnostic continual learning. Source code is available at https://github.com/google-research/l2p.
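The abstract describes prompts as small learnable parameters kept in a memory space and selected without task identity at test time. The sketch below illustrates one way such a prompt memory can work, assuming a pool of prompts where each prompt is paired with a learnable key, an instance-wise query vector (for example, a feature taken from the frozen pre-trained model), cosine-similarity matching, and top-N selection with the chosen prompts prepended to the input token embeddings. `PromptPool`, `select`, `build_input`, and `key_matching_loss` are hypothetical names, not the API of the released code, and training (gradient updates to prompts, keys, and the classifier) is omitted.

```python
import math
from typing import List, Tuple

Vector = List[float]


def cosine_similarity(a: Vector, b: Vector) -> float:
    """Cosine similarity with a small epsilon to avoid division by zero."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b + 1e-12)


class PromptPool:
    """Hypothetical prompt memory: each prompt (a list of token embeddings)
    is paired with a key vector used to retrieve it from a query."""

    def __init__(self, keys: List[Vector], prompts: List[List[Vector]], top_n: int):
        assert len(keys) == len(prompts), "one key per prompt"
        self.keys = keys
        self.prompts = prompts
        self.top_n = top_n

    def select(self, query: Vector) -> List[int]:
        """Return indices of the top_n keys most similar to the query.
        No task identity is needed: the input itself drives retrieval."""
        scores = [cosine_similarity(query, k) for k in self.keys]
        ranked = sorted(range(len(self.keys)), key=lambda i: scores[i], reverse=True)
        return ranked[: self.top_n]

    def build_input(self, query: Vector,
                    token_embeddings: List[Vector]) -> Tuple[List[int], List[Vector]]:
        """Prepend the selected prompts' tokens to the input token embeddings."""
        idx = self.select(query)
        prefix = [tok for i in idx for tok in self.prompts[i]]
        return idx, prefix + token_embeddings

    def key_matching_loss(self, query: Vector, idx: List[int]) -> float:
        """Assumed surrogate loss: sum of cosine distances (1 - cos) that,
        when minimized, pulls the selected keys toward the query."""
        return sum(1.0 - cosine_similarity(query, self.keys[i]) for i in idx)


if __name__ == "__main__":
    pool = PromptPool(
        keys=[[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]],
        prompts=[[[10.0, 10.0]], [[20.0, 20.0]], [[30.0, 30.0]]],
        top_n=2,
    )
    query = [1.0, 0.1]
    idx, extended = pool.build_input(query, [[1.0, 1.0], [2.0, 2.0]])
    print(idx, extended, pool.key_matching_loss(query, idx))
```

Because selection depends only on the query, inputs from different tasks can retrieve different prompts (task-specific knowledge) or share the same ones (task-invariant knowledge) without a task ID. In the released implementation the pre-trained backbone, query function, and training loop are real model components; here they are reduced to plain vectors purely for illustration.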
Related benchmarks
| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Language Understanding | MMLU | Accuracy | 2.24 | 756 |
| Reasoning | BBH | -- | -- | 507 |
| Physical Commonsense Reasoning | PIQA | Accuracy | 54.19 | 329 |
| Image Classification | CIFAR-100 | Accuracy | 83.05 | 302 |
| Class-Incremental Learning | CIFAR-100 | Averaged Incremental Accuracy | 89.51 | 234 |
| Image Classification | DomainNet (test) | -- | -- | 209 |
| Few-Shot Class-Incremental Learning | miniImageNet (test) | Accuracy (Session 1) | 87.2 | 173 |
| Few-Shot Class-Incremental Learning | CIFAR100 (test) | Session 4 Top-1 Accuracy | 68.66 | 122 |
| Class-Incremental Learning | ImageNet-R | Average Accuracy | 77.07 | 103 |
| Class-Incremental Learning | ImageNet A | Average Accuracy | 49.39 | 86 |