
BatchEnsemble: An Alternative Approach to Efficient Ensemble and Lifelong Learning

About

Ensembles, where multiple neural networks are trained individually and their predictions are averaged, have been shown to be widely successful for improving both the accuracy and predictive uncertainty of single neural networks. However, an ensemble's cost for both training and testing increases linearly with the number of networks, which quickly becomes untenable. In this paper, we propose BatchEnsemble, an ensemble method whose computational and memory costs are significantly lower than those of typical ensembles. BatchEnsemble achieves this by defining each weight matrix as the Hadamard product of a weight shared among all ensemble members and a rank-one matrix per member. Unlike ensembles, BatchEnsemble is not only parallelizable across devices, where one device trains one member, but also parallelizable within a device, where multiple ensemble members are updated simultaneously for a given mini-batch. Across CIFAR-10, CIFAR-100, WMT14 EN-DE/EN-FR translation, and out-of-distribution tasks, BatchEnsemble yields accuracy and uncertainty estimates competitive with typical ensembles; at an ensemble size of 4, the speedup at test time is 3X and the memory reduction is 3X. We also apply BatchEnsemble to lifelong learning, where on Split-CIFAR-100 it yields performance comparable to progressive neural networks at much lower computational and memory cost. We further show that BatchEnsemble easily scales up to lifelong learning on Split-ImageNet, which involves 100 sequential learning tasks.
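The rank-one construction above admits a cheap vectorized forward pass: since W_i = W ∘ (r_i s_iᵀ), applying member i to an input x reduces to r_i ∘ (W (x ∘ s_i)), so no per-member weight matrix is ever materialized. A minimal NumPy sketch of one linear layer (all shapes and names here are illustrative, not the paper's reference implementation):

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out, n_members, batch = 8, 4, 4, 16

# Shared weight, plus per-member rank-one factors r_i (output side)
# and s_i (input side).
W = rng.normal(size=(d_out, d_in))
R = rng.normal(size=(n_members, d_out))
S = rng.normal(size=(n_members, d_in))

def member_forward(x, i):
    """Naive version: materialize W_i = W * (r_i s_i^T), then apply it."""
    W_i = W * np.outer(R[i], S[i])
    return x @ W_i.T

def batch_ensemble_forward(X, member_ids):
    """Vectorized version: y = r_i * (W (x * s_i)).

    Each row of the mini-batch is routed to one ensemble member via
    element-wise scaling, so all members share a single matmul with W.
    """
    return ((X * S[member_ids]) @ W.T) * R[member_ids]

X = rng.normal(size=(batch, d_in))
ids = rng.integers(0, n_members, size=batch)

# The two formulations agree row by row.
naive = np.stack([member_forward(X[b], ids[b]) for b in range(batch)])
assert np.allclose(batch_ensemble_forward(X, ids), naive)
```

This is why BatchEnsemble parallelizes within a device: assigning different rows of a mini-batch to different members only changes the cheap element-wise factors, while the expensive matrix multiply with the shared W is done once for all members.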

Yeming Wen, Dustin Tran, Jimmy Ba · 2020

Related benchmarks

Task | Dataset | Result | Rank
Image Classification | CIFAR-100 (test) | Accuracy: 81.25 | 3518
Image Classification | CIFAR-10 (test) | Accuracy: 96.19 | 3381
Image Classification | CIFAR-100 | -- | 622
Image Classification | DTD | -- | 487
Image Classification | Flowers102 | Accuracy: 91.7 | 478
Character-level Language Modeling | enwik8 (test) | BPC: 1.616 | 195
Classification | CIFAR-100 (test) | Accuracy: 78.3 | 129
Image Classification | CIFAR-10-C | Accuracy: 71.67 | 127
Out-of-Distribution Detection | CIFAR-10 | AUROC: 82.9 | 105
Out-of-Distribution Detection | SVHN | AUROC: 84.8 | 62

Showing 10 of 27 rows.
