
Learning Sparse Neural Networks through $L_0$ Regularization

About

We propose a practical method for $L_0$ norm regularization for neural networks: pruning the network during training by encouraging weights to become exactly zero. Such regularization is interesting since (1) it can greatly speed up training and inference, and (2) it can improve generalization. AIC and BIC, well-known model selection criteria, are special cases of $L_0$ regularization. However, since the $L_0$ norm of weights is non-differentiable, we cannot incorporate it directly as a regularization term in the objective function. We propose a solution through the inclusion of a collection of non-negative stochastic gates, which collectively determine which weights to set to zero. We show that, somewhat surprisingly, for certain distributions over the gates, the expected $L_0$ norm of the resulting gated weights is differentiable with respect to the distribution parameters. We further propose the \emph{hard concrete} distribution for the gates, which is obtained by "stretching" a binary concrete distribution and then transforming its samples with a hard-sigmoid. The parameters of the distribution over the gates can then be jointly optimized with the original network parameters. As a result our method allows for straightforward and efficient learning of model structures with stochastic gradient descent and allows for conditional computation in a principled way. We perform various experiments to demonstrate the effectiveness of the resulting approach and regularizer.
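The gating mechanism in the abstract can be made concrete with a short sketch: sample a binary concrete variable, stretch it to an interval (gamma, zeta) that extends past [0, 1], and clip with a hard sigmoid so that exact zeros occur with non-zero probability, while the expected L0 penalty stays differentiable in the gate parameters. The hyperparameter values and function names below are illustrative defaults, not the authors' reference implementation.

```python
import numpy as np

# Assumed hard concrete hyperparameters: temperature beta and the
# stretch interval (gamma, zeta) that extends past [0, 1].
beta, gamma, zeta = 2.0 / 3.0, -0.1, 1.1

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def sample_hard_concrete(log_alpha, rng):
    """Sample gates z in [0, 1] via the stretched binary concrete
    distribution followed by a hard sigmoid. The reparameterization
    keeps the sample differentiable in log_alpha almost everywhere."""
    u = rng.uniform(1e-6, 1.0 - 1e-6, size=log_alpha.shape)
    s = sigmoid((np.log(u) - np.log(1.0 - u) + log_alpha) / beta)  # binary concrete
    s_bar = s * (zeta - gamma) + gamma                             # stretch to (gamma, zeta)
    return np.clip(s_bar, 0.0, 1.0)                                # hard sigmoid -> exact zeros

def expected_l0(log_alpha):
    """Expected number of active gates, sum over gates of P(z != 0).
    This closed form is smooth in log_alpha, so it can be used
    directly as a regularization term."""
    return sigmoid(log_alpha - beta * np.log(-gamma / zeta)).sum()

rng = np.random.default_rng(0)
log_alpha = np.zeros(1000)                 # one gate per weight
z = sample_hard_concrete(log_alpha, rng)   # effective weights would be z * w
frac_zero = (z == 0.0).mean()              # a noticeable fraction is exactly zero
```

Because the clipping assigns point mass to {0}, sampled gates prune weights exactly during training, while `expected_l0` supplies the differentiable surrogate for the L0 penalty that is optimized jointly with the network parameters.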

Christos Louizos, Max Welling, Diederik P. Kingma · 2017

Related benchmarks

| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Image Classification | MNIST (test) | – | – | 882 |
| Question Answering | SQuAD | F1 | 81.9 | 127 |
| Natural Language Inference | MNLI | Accuracy (matched) | 78.7 | 80 |
| Paraphrase Identification | QQP | Accuracy | 88.1 | 78 |
| Gene expression dynamics prediction | Hematopoesis Erythroid lineage (test) | Sparsity | 0.1203 | 12 |
| Gene regulatory network inference | Yeast cell cycle | Sparsity | 34.43 | 12 |
| Gene regulatory network inference | SIM350 5% noise (test) | Sparsity | 33.8 | 12 |
| Gene regulatory network inference | Breast cancer in pseudotime | Sparsity | 10.77 | 12 |
| Image Classification | CIFAR-10 (test) | Sparsity | 18.2 | 12 |
| Gene regulatory dynamics prediction | SIM350 5% noise (test) | MSE | 8.5 | 12 |

Showing 10 of 17 rows.
