
Weighted Ensemble Models Are Strong Continual Learners

About

In this work, we study the problem of continual learning (CL), where the goal is to learn a model on a sequence of tasks such that data from previous tasks becomes unavailable while learning on the current task. CL is essentially a balancing act between learning the new task (i.e., plasticity) and maintaining performance on previously learned concepts (i.e., stability). To address this stability-plasticity trade-off, we propose weight-ensembling the model parameters of the previous and current tasks. This weight-ensembled model, which we call Continual Model Averaging (CoMA), attains high accuracy on the current task by leveraging plasticity, while not deviating too far from the previous weight configuration, which ensures stability. We also propose an improved variant of CoMA, named Continual Fisher-weighted Model Averaging (CoFiMA), that selectively weighs each parameter in the weight ensemble using the Fisher information of the model's weights. Both variants are conceptually simple, easy to implement, and achieve state-of-the-art performance on several standard CL benchmarks. Code is available at: https://github.com/IemProg/CoFiMA.
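The two merging rules described above can be sketched in a few lines of NumPy. This is a minimal sketch under stated assumptions, not the authors' implementation (the linked repository has that): parameters are plain dicts of arrays, `alpha` is a hypothetical interpolation weight on the previous-task parameters, and the Fisher information is approximated by its diagonal, estimated here as the mean squared per-sample gradient (the "empirical Fisher") of a toy logistic-regression model. The function names `coma_merge`, `cofima_merge`, and `empirical_fisher_logreg` are hypothetical, and the paper's exact weighting and Fisher estimator may differ.

```python
import numpy as np


def coma_merge(prev, curr, alpha=0.5):
    """CoMA-style merge (sketch): element-wise interpolation of two parameter sets.

    alpha weights the previous-task parameters; the default 0.5 is an assumption.
    """
    return {k: alpha * prev[k] + (1.0 - alpha) * curr[k] for k in prev}


def cofima_merge(prev, curr, fisher_prev, fisher_curr, alpha=0.5, eps=1e-8):
    """Fisher-weighted merge (sketch): each parameter is averaged with weights
    proportional to its diagonal Fisher information under each task.

    Where both Fisher weights are ~0, fall back to plain interpolation.
    """
    merged = {}
    for k in prev:
        w_prev = alpha * fisher_prev[k]
        w_curr = (1.0 - alpha) * fisher_curr[k]
        denom = w_prev + w_curr
        weighted = (w_prev * prev[k] + w_curr * curr[k]) / np.maximum(denom, eps)
        plain = alpha * prev[k] + (1.0 - alpha) * curr[k]
        merged[k] = np.where(denom > eps, weighted, plain)
    return merged


def empirical_fisher_logreg(w, X, y):
    """Diagonal empirical Fisher for binary logistic regression (toy assumption):
    mean over samples of the squared gradient of log p(y | x, w)."""
    p = 1.0 / (1.0 + np.exp(-X @ w))        # predicted P(y = 1 | x)
    per_sample_grad = (y - p)[:, None] * X  # d/dw log p(y | x, w), shape (n, d)
    return np.mean(per_sample_grad ** 2, axis=0)


if __name__ == "__main__":
    prev = {"w": np.array([1.0, 0.0])}  # parameters after the previous task
    curr = {"w": np.array([0.0, 1.0])}  # parameters after fine-tuning on the current task
    print(coma_merge(prev, curr)["w"])  # [0.5 0.5]

    fisher_prev = {"w": np.array([3.0, 1.0])}  # coordinate 0 matters more for the old task
    fisher_curr = {"w": np.array([1.0, 3.0])}  # coordinate 1 matters more for the new task
    print(cofima_merge(prev, curr, fisher_prev, fisher_curr)["w"])  # [0.75 0.75]
```

In the Fisher-weighted example, each coordinate is pulled toward the task for which it carries more Fisher information (0.75 rather than the 0.5 of plain averaging). With identical Fisher values for both tasks, the rule reduces exactly to the CoMA interpolation.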

Imad Eddine Marouf, Subhankar Roy, Enzo Tartaglione, Stéphane Lathuilière • 2023

Related benchmarks

| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Class-incremental learning | CIFAR-100 | -- | -- | 281 |
| Class-incremental learning | ImageNet-R | Last Accuracy | 77.47 | 147 |
| Class-incremental learning | CUB200 | Last Accuracy | 85.95 | 64 |
| Class-incremental learning | CARS 196 | Last Accuracy | 73.35 | 32 |
| Class-incremental learning | CUB-200, Cars-196, CIFAR-100, ImageNet-R | Last Accuracy | 82.19 | 22 |
| Class-incremental learning | Four within-domain datasets average (test) | Last Accuracy | 74.63 | 17 |
| Continual Few-Shot Learning | CGQA (test) | Sys Score | 94.04 | 9 |
| Compositional Few-Shot Transfer | CGQA 10-way 10-shot | NOC | 91.97 | 9 |
| Continual Few-Shot Learning | COBJ (test) | SYS Score | 77.12 | 9 |
| Continual Learning | CGQA T=10 sessions | AA | 85.71 | 3 |
