
Hierarchical Decomposition of Prompt-Based Continual Learning: Rethinking Obscured Sub-optimality

About

Prompt-based continual learning is an emerging direction for leveraging pre-trained knowledge in downstream continual learning, and has almost reached the performance pinnacle under supervised pre-training. However, our empirical research reveals that the current strategies fall short of their full potential under the more realistic self-supervised pre-training, which is essential for handling vast quantities of unlabeled data in practice. This is largely because task-specific knowledge is difficult to incorporate into instructed representations via prompt parameters and to predict from uninstructed representations at test time. To overcome the exposed sub-optimality, we conduct a theoretical analysis of the continual learning objective in the context of pre-training, and decompose it into hierarchical components: within-task prediction, task-identity inference, and task-adaptive prediction. Following these empirical and theoretical insights, we propose Hierarchical Decomposition (HiDe-)Prompt, an innovative approach that explicitly optimizes the hierarchical components with an ensemble of task-specific prompts and statistics of both uninstructed and instructed representations, further coordinated by a contrastive regularization strategy. Our extensive experiments demonstrate the superior performance of HiDe-Prompt and its robustness to pre-training paradigms in continual learning (e.g., up to 15.01% and 9.61% lead on Split CIFAR-100 and Split ImageNet-R, respectively). Our code is available at https://github.com/thu-ml/HiDe-Prompt.
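To make the hierarchical decomposition concrete: the abstract factorizes task-adaptive prediction into task-identity inference times within-task prediction. The minimal sketch below illustrates that factorization at test time, using per-task statistics (here, stored mean representations) of uninstructed features to infer the task, then applying that task's prompt and head. All names (`infer_task`, `predict`, the additive "prompting" stand-in) are illustrative assumptions, not HiDe-Prompt's actual implementation.

```python
import numpy as np

# Hierarchical decomposition (sketch):
#   P(x in X_ij) = P(x in X_ij | x in X_i) * P(x in X_i)
#   task-adaptive prediction = within-task prediction * task-identity inference.

def infer_task(feature: np.ndarray, task_means: list) -> int:
    """Task-identity inference: pick the task whose stored mean
    (a statistic of uninstructed representations) is closest."""
    dists = [np.linalg.norm(feature - mu) for mu in task_means]
    return int(np.argmin(dists))

def predict(feature: np.ndarray, task_means: list, prompts: list, heads: list):
    """Task-adaptive prediction: infer the task, then run the
    task-specific prompt and classifier head (within-task prediction)."""
    t = infer_task(feature, task_means)   # which task does x belong to?
    instructed = feature + prompts[t]     # toy stand-in for prompt-conditioned re-encoding
    return heads[t](instructed)           # predict within the inferred task

# Toy usage: two tasks with well-separated mean representations.
task_means = [np.zeros(4), np.ones(4) * 5.0]
prompts = [np.full(4, 0.1), np.full(4, -0.1)]
heads = [lambda z: z.sum(), lambda z: z.mean()]
x = np.ones(4) * 4.8                      # closer to task 1's mean
print(infer_task(x, task_means))          # -> 1
```

In the paper's setting the inference step operates on features from the frozen pre-trained backbone without any prompt attached, which is exactly why the statistics of uninstructed representations matter.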

Liyuan Wang, Jingyi Xie, Xingxing Zhang, Mingyi Huang, Hang Su, Jun Zhu • 2023

Related benchmarks

Task                       | Dataset                    | Metric                     | Result | Rank
Audio Classification       | ESC-50 (test)              | Accuracy                   | 83.75  | 84
Class-incremental learning | CIFAR-100 Split (test)     | Avg Acc                    | 93.48  | 75
Class-incremental learning | Split ImageNet-R           | Average Forgetting Measure | 4.09   | 57
Class-incremental learning | ImageNet-R 10-task         | FAA                        | 76.74  | 44
Class-incremental learning | Split CIFAR-100 (10-task)  | CAA                        | 95.02  | 41
Audio Classification       | US8K (test)                | R@1 Accuracy               | 0.7989 | 41
Class-incremental learning | CIFAR-100 B0_Inc5          | Average Accuracy           | 85.99  | 36
Audio Classification       | Speech Commands V2 (test)  | Accuracy                   | 40.1   | 35
Class-incremental learning | ImageNet-R 20-task         | Average Accuracy           | 73.59  | 33
Class-incremental learning | ImageNet-R 5-task          | --                         | --     | 27

Showing 10 of 34 rows

Other info

Code: https://github.com/thu-ml/HiDe-Prompt
