Continual learning with hypernetworks
About
Artificial neural networks suffer from catastrophic forgetting when they are sequentially trained on multiple tasks. To overcome this problem, we present a novel approach based on task-conditioned hypernetworks, i.e., networks that generate the weights of a target model based on task identity. Continual learning (CL) is less difficult for this class of models thanks to a simple key feature: instead of recalling the input-output relations of all previously seen data, task-conditioned hypernetworks only require rehearsing task-specific weight realizations, which can be maintained in memory using a simple regularizer. Besides achieving state-of-the-art performance on standard CL benchmarks, additional experiments on long task sequences reveal that task-conditioned hypernetworks display a very large capacity to retain previous memories. Notably, such long memory lifetimes are achieved in a compressive regime, when the number of trainable hypernetwork weights is comparable or smaller than target network size. We provide insight into the structure of low-dimensional task embedding spaces (the input space of the hypernetwork) and show that task-conditioned hypernetworks demonstrate transfer learning. Finally, forward information transfer is further supported by empirical results on a challenging CL benchmark based on the CIFAR-10/100 image datasets.
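The key mechanism above can be illustrated with a minimal sketch: a hypernetwork maps a learned task embedding to target-network weights, and the continual-learning regularizer penalizes drift in the weights generated for previous tasks, compared against snapshots taken before training the new task. All names, dimensions, and the `beta` coefficient below are illustrative assumptions, not the paper's actual implementation.

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative dimensions: task embedding -> flattened target-network weights
EMB_DIM, HIDDEN, TARGET_PARAMS = 8, 32, 50

def init_hypernet():
    """Hypernetwork parameters theta: a small two-layer MLP (assumed architecture)."""
    return {
        "W1": rng.normal(0.0, 0.1, (HIDDEN, EMB_DIM)),
        "W2": rng.normal(0.0, 0.1, (TARGET_PARAMS, HIDDEN)),
    }

def hypernet(theta, e):
    """Generate target-network weights from a task embedding e."""
    return theta["W2"] @ np.tanh(theta["W1"] @ e)

# One embedding per task (learned jointly with theta in the actual method)
embeddings = [rng.normal(size=EMB_DIM) for _ in range(3)]
theta = init_hypernet()

# Before training task 3, snapshot the weights the hypernetwork
# currently produces for the two previously seen tasks
snapshots = [hypernet(theta, e) for e in embeddings[:2]]

def cl_regularizer(theta, beta=0.01):
    """Output regularizer: penalize drift of previously generated weights."""
    drift = sum(np.sum((hypernet(theta, e) - w) ** 2)
                for e, w in zip(embeddings[:2], snapshots))
    return beta / len(snapshots) * drift

print(cl_regularizer(theta))  # 0.0: theta unchanged, old outputs preserved
theta["W2"] += 0.05           # simulate an update while training the new task
print(cl_regularizer(theta))  # > 0: the penalty now pulls old outputs back
```

During training on a new task, this penalty would be added to the task loss, so only the task-specific weight realizations need to be "rehearsed" rather than any stored input-output data.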
Related benchmarks
| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Image Classification | permuted MNIST (pMNIST) (test) | Accuracy | 97.87 | 69 |
| Image Classification | Split MNIST | Average Accuracy | 95.3 | 49 |
| Image Classification | M(CIFAR100-10, F-CelebA) (Overall) | Accuracy | 60.36 | 22 |
| Image Classification | M(EMNIST-10, F-EMNIST) (Overall) | Accuracy | 62.71 | 22 |
| Image Classification | M(EMNIST-20, F-EMNIST) | Accuracy | 89.7 | 22 |
| Image Classification | M(CIFAR100-20, F-CelebA) (Overall) | Accuracy | 58.78 | 22 |
| Image Classification | Split MNIST S (test) | Task-Averaged Accuracy | 99.83 | 18 |
| Continual Learning | CIFAR-10 Split | Average Accuracy | 77.7 | 17 |
| Image Classification | M(CIFAR100-10, F-CelebA) | Accuracy | 46.67 | 15 |
| Image Classification | M(CIFAR100-20, F-CelebA) | Accuracy | 60.53 | 15 |