
Overcoming Recency Bias of Normalization Statistics in Continual Learning: Balance and Adaptation

About

Continual learning entails learning a sequence of tasks and balancing their knowledge appropriately. With limited access to old training samples, much of the current work in deep neural networks has focused on overcoming catastrophic forgetting of old tasks in gradient-based optimization. However, the normalization layers are an exception: they are updated interdependently by both the gradient and the statistics of currently observed training samples, which requires specialized strategies to mitigate recency bias. In this work, we focus on the most popular Batch Normalization (BN) and provide an in-depth theoretical analysis of its sub-optimality in continual learning. Our analysis demonstrates a dilemma between balance and adaptation of BN statistics for incremental tasks, which potentially affects training stability and generalization. Targeting these particular challenges, we propose Adaptive Balance of BN (AdaB$^2$N), which appropriately incorporates a Bayesian-based strategy to adapt task-wise contributions and a modified momentum to balance BN statistics, corresponding to the training and testing stages respectively. By implementing BN in a continual learning fashion, our approach achieves significant performance gains across a wide range of benchmarks, particularly in the challenging yet realistic online scenarios (e.g., up to 7.68%, 6.86% and 4.26% on Split CIFAR-10, Split CIFAR-100 and Split Mini-ImageNet, respectively). Our code is available at https://github.com/lvyilin/AdaB2N.
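To see the recency bias the abstract refers to, it helps to look at how BN running statistics are maintained. The sketch below is not AdaB$^2$N itself, only the standard exponential-moving-average update used by BN: with a fixed momentum, contributions from earlier tasks decay geometrically, so by the end of training the statistics are dominated by the most recent task. All names here (`update_running_stats`, the two-task toy data) are illustrative assumptions, not the paper's code.

```python
import numpy as np

def update_running_stats(running_mean, running_var, batch, momentum=0.1):
    """Standard BN running-statistic update (exponential moving average).

    With a fixed momentum, a batch seen k steps ago contributes with weight
    momentum * (1 - momentum)**k, so early-task statistics decay
    geometrically -- the recency bias analyzed in the paper.
    """
    batch_mean = batch.mean(axis=0)
    batch_var = batch.var(axis=0)
    new_mean = (1 - momentum) * running_mean + momentum * batch_mean
    new_var = (1 - momentum) * running_var + momentum * batch_var
    return new_mean, new_var

# Toy illustration: two sequential "tasks" whose features have different means.
rng = np.random.default_rng(0)
mean, var = np.zeros(4), np.ones(4)
for _ in range(100):  # task 1: features centered at 0
    mean, var = update_running_stats(mean, var, rng.normal(0.0, 1.0, (32, 4)))
for _ in range(100):  # task 2: features centered at 5
    mean, var = update_running_stats(mean, var, rng.normal(5.0, 1.0, (32, 4)))
# After task 2, the running mean sits near 5: task 1's statistics have
# almost entirely decayed away, even though task 1 must still be tested.
```

A balanced alternative (e.g., a momentum that decays like 1/t, giving every batch equal weight) would center the statistics between tasks but adapt slowly to new ones; this is the balance-versus-adaptation dilemma that AdaB$^2$N's modified momentum is designed to resolve.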

Yilin Lyu, Liyuan Wang, Xingxing Zhang, Zicheng Sun, Hang Su, Jun Zhu, Liping Jing • 2023

Related benchmarks

| Task                              | Dataset                     | Metric                  | Result | Rank |
|-----------------------------------|-----------------------------|-------------------------|--------|------|
| Task-Incremental Learning         | CIFAR-10 Split (test)       | Average Accuracy        | 91.99  | 46   |
| Task-Incremental Learning         | Split CIFAR-100 (test)      | Average Accuracy (A_T)  | 71.7   | 43   |
| Online Class-Incremental Learning | CIFAR-10 Split              | Final Avg Accuracy      | 64.83  | 24   |
| Online Class-Incremental Learning | CIFAR-100 Split             | Final Accuracy          | 28.15  | 24   |
| Online Class-Incremental Learning | Mini-ImageNet Split         | Final Average Accuracy  | 17.08  | 24   |
| Online Task-Incremental Learning  | CIFAR-100 Split             | Forgetting Measure      | 0.73   | 24   |
| Online Task-Incremental Learning  | Mini-ImageNet Split         | Forgetting              | 1.36   | 24   |
| Online Task-Incremental Learning  | Mini-ImageNet Split (test)  | Avg Accuracy            | 69.12  | 24   |
| Online Task-Incremental Learning  | CIFAR-10 Split              | Forgetting Measure      | 0.38   | 24   |
| Continual Learning                | Permuted MNIST              | A_T                     | 77.15  | 19   |

(Showing 10 of 12 rows.)

Other info

Code: https://github.com/lvyilin/AdaB2N
