
Hierarchical Decomposition of Prompt-Based Continual Learning: Rethinking Obscured Sub-optimality

About

Prompt-based continual learning is an emerging direction in leveraging pre-trained knowledge for downstream continual learning, and has almost reached the performance pinnacle under supervised pre-training. However, our empirical research reveals that the current strategies fall short of their full potential under the more realistic self-supervised pre-training, which is essential for handling vast quantities of unlabeled data in practice. This is largely due to the difficulty of task-specific knowledge being incorporated into instructed representations via prompt parameters and predicted by uninstructed representations at test time. To overcome the exposed sub-optimality, we conduct a theoretical analysis of the continual learning objective in the context of pre-training, and decompose it into hierarchical components: within-task prediction, task-identity inference, and task-adaptive prediction. Following these empirical and theoretical insights, we propose Hierarchical Decomposition (HiDe-)Prompt, an innovative approach that explicitly optimizes the hierarchical components with an ensemble of task-specific prompts and statistics of both uninstructed and instructed representations, coordinated by a contrastive regularization strategy. Our extensive experiments demonstrate the superior performance of HiDe-Prompt and its robustness to pre-training paradigms in continual learning (e.g., up to 15.01% and 9.61% lead on Split CIFAR-100 and Split ImageNet-R, respectively). Our code is available at https://github.com/thu-ml/HiDe-Prompt.
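
The first two components named in the abstract fit together through the chain rule of probability. The notation below is an assumption made for illustration and is not taken from the abstract: x is a test input, i indexes tasks, and j indexes classes within task i. Task-adaptive prediction, the third component, is left unwritten because the abstract does not specify its form.

```latex
% Requires amsmath for \text. Notation is illustrative, not the paper's own.
\[
P(\text{task } i,\ \text{class } j \mid x)
  = \underbrace{P(\text{class } j \mid \text{task } i,\ x)}_{\text{within-task prediction}}
    \cdot
    \underbrace{P(\text{task } i \mid x)}_{\text{task-identity inference}}
\]
```

Read this way, an error in either factor lowers the joint probability, which is one way to see why the abstract treats the components as separate optimization targets.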

Liyuan Wang, Jingyi Xie, Xingxing Zhang, Mingyi Huang, Hang Su, Jun Zhu • 2023

Related benchmarks

| Task | Dataset | Result | Rank |
| --- | --- | --- | --- |
| Class-incremental learning | CIFAR-100 | -- | 248 |
| Audio Classification | ESC-50 (test) | Accuracy 83.75 | 87 |
| Class-incremental learning | ImageNet-100 | -- | 82 |
| Class-incremental learning | ImageNet-R B0 Inc20 | Last Accuracy 76.06 | 79 |
| Class-incremental learning | CIFAR-100 Split (test) | Avg Acc 93.48 | 75 |
| Class-incremental learning | Split ImageNet-R | Average Forgetting Measure 4.09 | 57 |
| Class-incremental learning | ImageNet-R 10-task | FAA 76.74 | 54 |
| Class-incremental learning | CIFAR-100 B0_Inc5 | Average Accuracy 85.99 | 47 |
| Audio Classification | Speech Commands V2 (test) | Accuracy 40.1 | 46 |
| Class-incremental learning | ImageNet-R 5-task | -- | 45 |
Showing 10 of 65 rows
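
The Result column mixes several metrics. As a rough guide, final average accuracy (FAA) and average forgetting are commonly computed from a matrix of per-task accuracies recorded after each training stage. The sketch below follows that common definition, which this page does not itself confirm for any particular row; the function names and the accuracy values are hypothetical and are not results from the paper.

```python
from statistics import mean


def final_average_accuracy(acc):
    """Mean accuracy over all tasks after training on the last task.

    acc[t][j] is the accuracy on task j measured after training on task t
    (row t therefore has t + 1 entries).
    """
    last = len(acc) - 1
    return mean(acc[last][j] for j in range(last + 1))


def average_forgetting(acc):
    """Mean drop from each earlier task's best past accuracy to its final accuracy."""
    last = len(acc) - 1
    if last == 0:
        return 0.0  # with a single task there is nothing to forget
    drops = []
    for j in range(last):
        # Best accuracy on task j at any stage before the final one.
        best_before = max(acc[t][j] for t in range(j, last))
        drops.append(best_before - acc[last][j])
    return mean(drops)


# Hypothetical accuracies (percent) for three sequential tasks.
acc = [
    [90.0],
    [86.0, 92.0],
    [84.0, 89.0, 94.0],
]
faa = final_average_accuracy(acc)
forgetting = average_forgetting(acc)
```

Higher is better for accuracy-style metrics and lower is better for forgetting, so ranks across rows with different metrics are not directly comparable.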
