Learning to Prompt for Continual Learning

About

The mainstream paradigm behind continual learning has been to adapt the model parameters to non-stationary data distributions, where catastrophic forgetting is the central challenge. Typical methods rely on a rehearsal buffer or known task identity at test time to retrieve learned knowledge and address forgetting, while this work presents a new paradigm for continual learning that aims to train a more succinct memory system without accessing task identity at test time. Our method learns to dynamically prompt (L2P) a pre-trained model to learn tasks sequentially under different task transitions. In our proposed framework, prompts are small learnable parameters, which are maintained in a memory space. The objective is to optimize prompts to instruct the model prediction and explicitly manage task-invariant and task-specific knowledge while maintaining model plasticity. We conduct comprehensive experiments on popular image classification benchmarks with different challenging continual learning settings, where L2P consistently outperforms prior state-of-the-art methods. Surprisingly, L2P achieves competitive results against rehearsal-based methods even without a rehearsal buffer and is directly applicable to challenging task-agnostic continual learning. Source code is available at https://github.com/google-research/l2p.
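The abstract does not spell out how prompts are retrieved from the memory space without task identity. The full paper describes a prompt pool in which each prompt is paired with a learnable key; a query is computed from the frozen pre-trained model's feature of the input, the prompts whose keys best match the query (by cosine similarity) are prepended to the input tokens, and a surrogate loss pulls the selected keys toward the query. The following is a minimal, framework-free sketch of that selection step under those assumptions: vectors are plain Python lists, the query is passed in directly rather than produced by a ViT, the names `PromptPool`, `select`, `matching_loss` and `prompt_input` are hypothetical, and no training loop or gradient update is shown. It is not the official implementation in the linked repository.

```python
import math
from typing import List, Tuple

Vector = List[float]


def cosine(a: Vector, b: Vector) -> float:
    """Cosine similarity between two non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


class PromptPool:
    """Hypothetical toy prompt pool: each prompt is paired with a key vector."""

    def __init__(self, keys: List[Vector], prompts: List[List[Vector]], top_n: int):
        assert len(keys) == len(prompts), "one key per prompt"
        self.keys = keys          # keys[i]: key vector for prompt i (learnable in L2P)
        self.prompts = prompts    # prompts[i]: list of prompt token embeddings
        self.top_n = top_n        # number of prompts selected per input

    def select(self, query: Vector) -> List[int]:
        """Indices of the top_n keys most similar to the query (no task id needed)."""
        scores = [cosine(query, k) for k in self.keys]
        order = sorted(range(len(self.keys)), key=lambda i: scores[i], reverse=True)
        return order[: self.top_n]

    def matching_loss(self, query: Vector, selected: List[int]) -> float:
        """Surrogate loss: sum of cosine distances (1 - cos) to the selected keys.
        Minimizing it would pull the selected keys toward the query."""
        return sum(1.0 - cosine(query, self.keys[i]) for i in selected)

    def prompt_input(self, query: Vector, tokens: List[Vector]) -> Tuple[List[Vector], List[int]]:
        """Prepend the selected prompts' token embeddings to the input tokens."""
        selected = self.select(query)
        prefix = [tok for i in selected for tok in self.prompts[i]]
        return prefix + tokens, selected


# Usage: three single-token prompts, select the two whose keys best match the query.
keys = [[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]]
prompts = [[[0.1, 0.1]], [[0.2, 0.2]], [[0.3, 0.3]]]
pool = PromptPool(keys, prompts, top_n=2)
query = [1.0, 0.1]  # stands in for the frozen backbone's feature of the input
extended, selected = pool.prompt_input(query, [[9.0, 9.0]])
loss = pool.matching_loss(query, selected)
print(selected, extended, round(loss, 3))
```

Because selection depends only on the query and the keys, the same pool can be used at test time without knowing which task an input came from, which is the property the abstract emphasizes.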

Zifeng Wang, Zizhao Zhang, Chen-Yu Lee, Han Zhang, Ruoxi Sun, Xiaoqi Ren, Guolong Su, Vincent Perot, Jennifer Dy, Tomas Pfister• 2021

Related benchmarks

| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Language Understanding | MMLU | Accuracy | 2.24 | 825 |
| Reasoning | BBH | -- | -- | 672 |
| Physical Commonsense Reasoning | PIQA | Accuracy | 54.19 | 572 |
| Image Classification | CIFAR-100 | Accuracy | 83.05 | 302 |
| Class-incremental learning | CIFAR-100 | Averaged Incremental Accuracy | 89.51 | 248 |
| Image Classification | DomainNet (test) | -- | -- | 219 |
| Few-Shot Class-Incremental Learning | miniImageNet (test) | Accuracy (Session 1) | 87.2 | 173 |
| Few-Shot Class-Incremental Learning | CIFAR100 (test) | Session 4 Top-1 Acc | 68.66 | 122 |
| Class-incremental learning | CIFAR-100 | Average Accuracy | 85.05 | 116 |
| Class-incremental learning | ImageNet-R | Average Accuracy | 77.07 | 112 |

Showing 10 of 288 rows.
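The table mixes two continual-learning metrics that are easy to confuse. Under their commonly used definitions (the leaderboard may compute them differently), "Average Accuracy" averages the per-task accuracies measured once, after training on the final task, while "Averaged Incremental Accuracy" averages the accuracy on all classes seen so far, measured after every incremental step. The sketch below illustrates both on made-up numbers; the function names and the accuracy values are hypothetical and are not results from this paper.

```python
from typing import List


def average_accuracy(acc: List[List[float]]) -> float:
    """acc[t][i]: accuracy on task i's test set after training through task t.
    Average Accuracy uses only the last row (after the final task)."""
    final = acc[-1]
    return sum(final) / len(final)


def averaged_incremental_accuracy(step_acc: List[float]) -> float:
    """step_acc[t]: accuracy on all classes seen so far, measured after step t.
    Averaged Incremental Accuracy is the mean over all steps."""
    return sum(step_acc) / len(step_acc)


# Illustrative (made-up) numbers for a three-task sequence.
acc_matrix = [
    [90.0],
    [80.0, 85.0],
    [70.0, 75.0, 95.0],
]
print(average_accuracy(acc_matrix))                      # 80.0
print(averaged_incremental_accuracy([90.0, 82.0, 80.0]))  # 84.0
```

Because the incremental variant also credits accuracy from early steps, when fewer classes have been seen and forgetting has not yet accumulated, it is typically higher than the final-task average for the same run.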
