
CODA-Prompt: COntinual Decomposed Attention-based Prompting for Rehearsal-Free Continual Learning

About

Computer vision models suffer from a phenomenon known as catastrophic forgetting when learning novel concepts from continuously shifting training data. Typical solutions for this continual learning problem require extensive rehearsal of previously seen data, which increases memory costs and may violate data privacy. Recently, the emergence of large-scale pre-trained vision transformer models has enabled prompting approaches as an alternative to data rehearsal. These approaches rely on a key-query mechanism to generate prompts and have been found to be highly resistant to catastrophic forgetting in the well-established rehearsal-free continual learning setting. However, the key mechanism of these methods is not trained end-to-end with the task sequence. Our experiments show that this leads to a reduction in their plasticity, hence sacrificing new-task accuracy, and an inability to benefit from expanded parameter capacity. We instead propose to learn a set of prompt components which are assembled with input-conditioned weights to produce input-conditioned prompts, resulting in a novel attention-based end-to-end key-query scheme. Our experiments show that we outperform the current SOTA method DualPrompt on established benchmarks by as much as 4.5% in average final accuracy. We also outperform the state of the art by as much as 4.4% accuracy on a continual learning benchmark which contains both class-incremental and domain-incremental task shifts, corresponding to many practical settings. Our code is available at https://github.com/GT-RIPL/CODA-Prompt
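The abstract's core idea (prompt components assembled with input-conditioned weights via a key-query attention scheme) can be sketched in a few lines of numpy. This is an illustrative reconstruction, not the released implementation: the function name, shapes, and the cosine-similarity weighting are assumptions based on the description above, with one learned key and attention vector per prompt component.

```python
import numpy as np

def coda_prompt(query, keys, attn, components):
    """Assemble an input-conditioned prompt from learned components.

    query:      (d,)       query feature for one input (e.g. a frozen ViT embedding)
    keys:       (M, d)     one learned key per prompt component
    attn:       (M, d)     learned attention vectors (per-component feature selection)
    components: (M, L, d)  M learned prompt components, each of length L
    """
    # Element-wise product focuses the query differently for each component.
    attended = attn * query                                   # (M, d)
    # Cosine similarity between each attended query and its key -> scalar weights.
    num = (attended * keys).sum(axis=1)
    denom = np.linalg.norm(attended, axis=1) * np.linalg.norm(keys, axis=1) + 1e-8
    weights = num / denom                                     # (M,)
    # Weighted sum of components yields the prompt fed to the transformer.
    return np.einsum("m,mld->ld", weights, components)        # (L, d)

rng = np.random.default_rng(0)
d, M, L = 16, 4, 2
prompt = coda_prompt(rng.normal(size=d), rng.normal(size=(M, d)),
                     rng.normal(size=(M, d)), rng.normal(size=(M, L, d)))
print(prompt.shape)  # (2, 16)
```

Because every step here is differentiable, the keys, attention vectors, and components can all be trained end-to-end with the task sequence, which is the property the paper argues prior key-query prompting methods lack.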

James Seale Smith, Leonid Karlinsky, Vyshnavi Gutta, Paola Cascante-Bonilla, Donghyun Kim, Assaf Arbelle, Rameswar Panda, Rogerio Feris, Zsolt Kira • 2022

Related benchmarks

| Task | Dataset | Result | Rank |
|---|---|---|---|
| Image Classification | CIFAR-100 | Accuracy: 86.25 | 302 |
| Class-incremental learning | CIFAR-100 | Averaged Incremental Accuracy: 91.55 | 234 |
| Image Classification | DomainNet (test) | -- | 209 |
| Few-Shot Class-Incremental Learning | miniImageNet (test) | Accuracy (Session 1): 88.86 | 173 |
| Few-Shot Class-Incremental Learning | CIFAR100 (test) | Session 4 Top-1 Acc: 71.91 | 122 |
| Class-incremental learning | ImageNet-R | Average Accuracy: 82.06 | 103 |
| Class-incremental learning | ImageNet-A | Average Accuracy: 53.54 | 86 |
| Class-incremental learning | CIFAR-100 10 (test) | Average Top-1 Accuracy: 91.08 | 75 |
| Class-incremental learning | CIFAR-100 Split (test) | Avg Acc: 86.94 | 75 |
| Few-Shot Class-Incremental Learning | CUB-200 | Session 1 Accuracy: 78.1 | 75 |

Showing 10 of 126 rows.
