One-stage Prompt-based Continual Learning

About

Prompt-based Continual Learning (PCL) has gained considerable attention as a promising continual learning solution because it achieves state-of-the-art performance while avoiding privacy violations and memory overhead. Nonetheless, existing PCL approaches carry a significant computational burden because they require two Vision Transformer (ViT) feed-forward stages: one for the query ViT, which generates a prompt query to select prompts from a prompt pool, and one for the backbone ViT, which mixes information between the selected prompts and the image tokens. To address this, we introduce a one-stage PCL framework that directly uses an intermediate layer's token embedding as the prompt query. This design removes the need for an additional feed-forward stage for the query ViT, resulting in a ~50% reduction in computational cost for both training and inference, with a marginal accuracy drop of less than 1%. We further introduce a Query-Pool Regularization (QR) loss that regulates the relationship between the prompt query and the prompt pool to improve representation power. The QR loss is applied only during training, so it adds no computational overhead at inference. With the QR loss, our approach maintains the ~50% computational cost reduction during inference and outperforms prior two-stage PCL methods by ~1.4% on public class-incremental continual learning benchmarks including CIFAR-100, ImageNet-R, and DomainNet.
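The one-stage idea described above can be illustrated with a toy sketch. This is not the authors' implementation: `toy_layer` is a hypothetical stand-in for a transformer block, and the choice of token 0 as the query, cosine-similarity matching against learnable keys, top-k selection, and prepending the chosen prompts to the token sequence are all assumptions made for illustration. The point is only that the prompt query is read from the intermediate tokens of the same forward pass, so no separate query ViT pass is needed.

```python
import numpy as np

def l2_normalize(x, axis=-1, eps=1e-12):
    """Scale vectors to unit length so dot products become cosine similarities."""
    return x / (np.linalg.norm(x, axis=axis, keepdims=True) + eps)

def select_prompts(query, prompt_keys, prompts, top_k=2):
    """Pick the top_k prompts whose keys are most cosine-similar to the query.

    query:       (dim,)
    prompt_keys: (pool_size, dim)
    prompts:     (pool_size, prompt_len, dim)
    """
    sims = l2_normalize(prompt_keys) @ l2_normalize(query)  # (pool_size,)
    top = np.argsort(-sims)[:top_k]
    # Flatten the chosen prompts into one sequence of prompt tokens.
    return prompts[top].reshape(-1, prompts.shape[-1]), top

def toy_layer(tokens, weight):
    """Hypothetical stand-in for a transformer block: residual + tanh projection."""
    return tokens + np.tanh(tokens @ weight)

def one_stage_forward(image_tokens, layer_weights, prompt_keys, prompts,
                      query_layer=1, top_k=2):
    """Single forward pass: the prompt query comes from intermediate tokens."""
    tokens = image_tokens
    selected = None
    for i, weight in enumerate(layer_weights):
        if i == query_layer:
            # Assumption: token 0 plays the role of a [CLS]-style summary token.
            query = tokens[0]
            chosen, selected = select_prompts(query, prompt_keys, prompts, top_k)
            # Assumption: prompts are simply prepended to the token sequence.
            tokens = np.concatenate([chosen, tokens], axis=0)
        tokens = toy_layer(tokens, weight)
    return tokens, selected

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    dim = 8
    image_tokens = rng.normal(size=(5, dim))           # 5 tokens, token 0 ~ [CLS]
    layer_weights = [rng.normal(size=(dim, dim)) * 0.1 for _ in range(3)]
    prompt_keys = rng.normal(size=(4, dim))            # pool of 4 prompts
    prompts = rng.normal(size=(4, 3, dim))             # each prompt has 3 tokens
    out, idx = one_stage_forward(image_tokens, layer_weights, prompt_keys, prompts)
    print(out.shape, idx)
```

In a two-stage design, the query would instead come from a full extra pass through a separate query ViT before the backbone pass begins, which is the duplicated work the abstract says the one-stage design removes.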

Youngeun Kim, Yuhang Li, Priyadarshini Panda • 2024

Related benchmarks

Task | Dataset | Result | Rank
Domain-incremental learning | CORe50 | Avg Accuracy (A): 98.3 | 49
Continual Learning | ImageNet-R | Accuracy: 50.3 | 15
Continual Learning | CUB-200 | Task Accuracy (ACC_T): 52.92 | 11
Continual Learning | CIFAR-100 | Task Accuracy (ACC_T): 66.64 | 11
Diabetic Retinopathy Classification | Diabetic Retinopathy (APTOS, DDR, DRD) (test) | Average Accuracy: 81.2 | 9
Domain-incremental learning | DR APTOS → DDR → DRD | Average Accuracy: 76.9 | 8
Domain-incremental learning | Skin Cancer ISIC → HAM → DERM7 | Average Accuracy: 72.5 | 8
Skin Cancer Classification | ISIC, HAM, and DERM7 (small external) | Average Accuracy: 72.5 | 8
Continual Learning | APTOS (final stage) | Accuracy: 74.3 | 8
Continual Learning | DDR (final stage) | Accuracy: 69.7 | 8
Showing 10 of 11 rows
