
Learning to Prompt for Continual Learning

About

The mainstream paradigm behind continual learning has been to adapt the model parameters to non-stationary data distributions, where catastrophic forgetting is the central challenge. Typical methods rely on a rehearsal buffer or known task identity at test time to retrieve learned knowledge and address forgetting. This work instead presents a new paradigm for continual learning that trains a more succinct memory system without accessing task identity at test time. Our method learns to dynamically prompt (L2P) a pre-trained model to learn tasks sequentially under different task transitions. In our proposed framework, prompts are small learnable parameters maintained in a memory space. The objective is to optimize prompts to instruct the model's predictions and explicitly manage task-invariant and task-specific knowledge while maintaining model plasticity. We conduct comprehensive experiments on popular image classification benchmarks under different challenging continual learning settings, where L2P consistently outperforms prior state-of-the-art methods. Surprisingly, L2P achieves competitive results against rehearsal-based methods even without a rehearsal buffer, and it is directly applicable to challenging task-agnostic continual learning. Source code is available at https://github.com/google-research/l2p.
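The core idea — a pool of learnable prompts selected per input by key–query matching, with the chosen prompts prepended to the frozen model's token embeddings — can be sketched as follows. This is a minimal illustration, not the paper's implementation: the pool size, prompt length, dimensions, and the `select_prompts` helper are all hypothetical, and the query feature (which L2P obtains from a frozen pre-trained encoder) is faked with random data.

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative sizes (not the paper's values): a pool of M prompts,
# each L_p tokens long, with embedding dimension d; select top_n per input.
M, L_p, d, top_n = 10, 5, 32, 3

prompt_pool = rng.normal(size=(M, L_p, d))   # learnable prompt parameters
prompt_keys = rng.normal(size=(M, d))        # learnable keys, one per prompt

def select_prompts(query):
    """Return indices of the top-N prompts whose keys best match the query."""
    # Cosine similarity between the query feature and every prompt key.
    q = query / np.linalg.norm(query)
    k = prompt_keys / np.linalg.norm(prompt_keys, axis=1, keepdims=True)
    scores = k @ q
    return np.argsort(-scores)[:top_n]

# In L2P the query comes from a frozen pre-trained encoder; faked here.
query_feature = rng.normal(size=d)
idx = select_prompts(query_feature)

# Prepend the selected prompts to the input token embeddings, then the
# extended sequence would be fed through the frozen transformer.
tokens = rng.normal(size=(16, d))            # 16 input token embeddings
extended = np.concatenate([prompt_pool[idx].reshape(-1, d), tokens], axis=0)
print(extended.shape)  # (top_n * L_p + 16, d) -> (31, 32)
```

Because selection depends only on the input's own features, no task identity is needed at test time; training additionally pulls the selected keys toward their matching queries so that related inputs reuse the same prompts.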

Zifeng Wang, Zizhao Zhang, Chen-Yu Lee, Han Zhang, Ruoxi Sun, Xiaoqi Ren, Guolong Su, Vincent Perot, Jennifer Dy, Tomas Pfister • 2021

Related benchmarks

| Task | Dataset | Result | Rank |
| --- | --- | --- | --- |
| Language Understanding | MMLU | Accuracy: 2.24 | 756 |
| Reasoning | BBH | -- | 507 |
| Physical Commonsense Reasoning | PIQA | Accuracy: 54.19 | 329 |
| Image Classification | CIFAR-100 | Accuracy: 83.05 | 302 |
| Class-incremental learning | CIFAR-100 | Averaged Incremental Accuracy: 89.51 | 234 |
| Image Classification | DomainNet (test) | -- | 209 |
| Few-Shot Class-Incremental Learning | miniImageNet (test) | Accuracy (Session 1): 87.2 | 173 |
| Few-Shot Class-Incremental Learning | CIFAR100 (test) | Session 4 Top-1 Acc: 68.66 | 122 |
| Class-incremental learning | ImageNet-R | Average Accuracy: 77.07 | 103 |
| Class-incremental learning | ImageNet-A | Average Accuracy: 49.39 | 86 |

Showing 10 of 182 rows.
