Conditional Channel Gated Networks for Task-Aware Continual Learning

About

Convolutional Neural Networks experience catastrophic forgetting when optimized on a sequence of learning problems: as they meet the objective of the current training examples, their performance on previous tasks drops drastically. In this work, we introduce a novel framework that tackles this problem with conditional computation. We equip each convolutional layer with task-specific gating modules that select which filters to apply to the given input. This gives two appealing properties. First, the execution patterns of the gates allow us to identify and protect important filters, ensuring no loss in the model's performance on previously learned tasks. Second, by using a sparsity objective, we promote the selection of a limited set of kernels, which lets the model retain enough capacity to learn new tasks.

Existing solutions require, at test time, knowledge of the task each example belongs to. This knowledge, however, may not be available in many practical scenarios. We therefore also introduce a task classifier that predicts the task label of each example, to handle settings in which no task oracle is available. We validate our proposal on four continual learning datasets. Results show that our model consistently outperforms existing methods both with and without a task oracle. Notably, on the Split SVHN and Imagenet-50 datasets, our model yields accuracy improvements of up to 23.98% and 17.42%, respectively, over competing methods.
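
The gating idea above can be sketched in dependency-free Python. This is a toy illustration under stated assumptions, not the authors' implementation: each filter is reduced to a scalar output, each task's gate is a linear map of pooled input features with a hard threshold at zero, the sparsity term is assumed to be the mean gate firing probability, a filter counts as "important" (and is frozen for later tasks) when its gate fires on more than half of a task's samples, and the task classifier is reduced to an argmax over given scores. All function names are hypothetical. A trainable model would also need a differentiable relaxation of the hard gates (for example Gumbel-Softmax), which is omitted here.

```python
import math
from typing import List, Sequence, Set, Tuple

# Hypothetical helper names throughout; scalar "filters" stand in for conv kernels.

def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))

def gate_logits(pooled: Sequence[float], gate_w: Sequence[Sequence[float]]) -> List[float]:
    """One logit per filter: a linear map of the globally pooled input features."""
    return [sum(w * x for w, x in zip(row, pooled)) for row in gate_w]

def gated_forward(filter_out: Sequence[float], pooled: Sequence[float],
                  gate_w: Sequence[Sequence[float]]) -> Tuple[List[float], List[int]]:
    """Zero every filter whose task-specific gate is off (hard threshold at 0)."""
    gates = [1 if z > 0 else 0 for z in gate_logits(pooled, gate_w)]
    return [g * y for g, y in zip(gates, filter_out)], gates

def sparsity_loss(batch: Sequence[Sequence[float]], gate_w) -> float:
    """Assumed penalty: mean gate firing probability over samples and filters."""
    probs = [sigmoid(z) for x in batch for z in gate_logits(x, gate_w)]
    return sum(probs) / len(probs)

def relevant_filters(batch, gate_w, threshold: float = 0.5) -> Set[int]:
    """Filters whose gate fires on more than `threshold` of a task's samples."""
    counts = [0] * len(gate_w)
    for x in batch:
        for k, z in enumerate(gate_logits(x, gate_w)):
            if z > 0:
                counts[k] += 1
    return {k for k, c in enumerate(counts) if c / len(batch) > threshold}

def masked_update(weights: Sequence[float], grads: Sequence[float],
                  frozen: Set[int], lr: float = 0.1) -> List[float]:
    """Gradient step that leaves filters relevant to earlier tasks untouched."""
    return [w if k in frozen else w - lr * g
            for k, (w, g) in enumerate(zip(weights, grads))]

def infer_task(task_scores: Sequence[float]) -> int:
    """No task oracle: pick the task the (assumed) task classifier scores highest."""
    return max(range(len(task_scores)), key=lambda t: task_scores[t])

# Usage with made-up numbers: 3 filters, 2 pooled features, one task's gate.
w_task1 = [[1.0, 0.0], [-1.0, 0.0], [0.0, 1.0]]
batch = [[1.0, 1.0], [2.0, -1.0], [0.5, 2.0]]
out, gates = gated_forward([0.3, 0.7, 0.9], batch[0], w_task1)
print(out, gates)                          # [0.3, 0.0, 0.9] [1, 0, 1]
frozen = relevant_filters(batch, w_task1)  # {0, 2}
print(masked_update([1.0, 1.0, 1.0], [1.0, 1.0, 1.0], frozen))  # [1.0, 0.9, 1.0]
print(infer_task([0.1, 0.7, 0.2]))         # 1
```

In this sketch, filter 1 never fires for task 1, so it stays free to be updated when a later task is trained, while filters 0 and 2 are protected.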

Davide Abati, Jakub Tomczak, Tijmen Blankevoort, Simone Calderara, Rita Cucchiara, Babak Ehteshami Bejnordi • 2020

Related benchmarks

Task | Dataset | Metric | Result | Rank
Image Classification | CIFAR-10 | Accuracy | 96.6 | 471
Image Classification | MNIST | Accuracy | 99.8 | 395
Image Classification | SVHN | Accuracy | 97.4 | 359
Image Classification | Split MNIST | Average Accuracy | 99.7 | 49
Class-incremental learning | MNIST (test) | Average Accuracy | 96.08 | 35
Class-incremental learning | Split CIFAR-10 | Accuracy | 70.06 | 26
Class-incremental learning | SVHN (test) | Average Accuracy | 81.02 | 20
Image Classification | S-MNIST (test) | Average Accuracy | 99.6 | 18
Image Classification | S-TinyImageNet (test) | Average Accuracy | 49.2 | 14
Image Classification | S-CIFAR100 (test) | Average Accuracy | 60.1 | 14

(Showing 10 of 15 rows.)
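
The "Average Accuracy" metric is not defined on this page. In the continual learning literature it commonly denotes the mean, over all tasks, of the test accuracy on each task measured after the model has finished training on the last task. A minimal sketch under that assumed definition (the function name and the accuracies are illustrative only, not results from the table):

```python
from typing import Sequence

def average_accuracy(acc: Sequence[Sequence[float]]) -> float:
    """acc[i][j]: test accuracy on task j after training on tasks 0..i (j <= i).

    The metric reads only the last row, i.e. accuracy on every task once
    training on the final task is complete, and averages it.
    """
    final_row = acc[-1]
    return sum(final_row) / len(final_row)

# Made-up accuracies (%) for a 3-task run; entries for unseen tasks are omitted.
history = [
    [99.0],
    [97.0, 98.0],
    [95.0, 96.0, 98.0],
]
print(round(average_accuracy(history), 2))  # 96.33
```

Earlier rows are not used by this metric, but they are what forgetting-style measures compare against, since a drop from 97.0 to 95.0 on the first task shows up there.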
