
Measuring and regularizing networks in function space

About

To optimize a neural network one often thinks of optimizing its parameters, but it is ultimately a matter of optimizing the function that maps inputs to outputs. Since a change in the parameters might serve as a poor proxy for the change in the function, it is of some concern that primacy is given to parameters while the correspondence between the two has not been tested. Here, we show that it is simple and computationally feasible to calculate distances between functions in an $L^2$ Hilbert space. We examine how typical networks behave in this space, and compare parameter $\ell^2$ distances with function $L^2$ distances between various points of an optimization trajectory. We find that the two distances are nontrivially related. In particular, the $L^2/\ell^2$ ratio decreases throughout optimization, reaching a steady value around when test error plateaus. We then investigate how the $L^2$ distance could be applied directly to optimization. First, we propose that in multitask learning, one can avoid catastrophic forgetting by directly limiting how much the input/output function changes between tasks. Second, we propose a new learning rule that constrains the distance a network can travel through $L^2$-space in any one update. This allows new examples to be learned in a way that minimally interferes with what has previously been learned. These applications demonstrate how one can measure and regularize function distances directly, without relying on parameters or local approximations like loss curvature.
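As a rough illustration of the comparison the abstract describes, the sketch below estimates a function-space $L^2$ distance between two small networks by averaging squared output differences over a batch of sampled inputs, and contrasts it with the parameter-space $\ell^2$ distance. The two-layer `mlp`, the Gaussian input batch, and all function names are assumptions made for this example; the abstract does not specify the estimator or architecture the authors use.

```python
import numpy as np


def mlp(params, x):
    """Hypothetical two-layer tanh network; params = (W1, b1, W2, b2)."""
    W1, b1, W2, b2 = params
    return np.tanh(x @ W1 + b1) @ W2 + b2


def function_l2_distance(params_a, params_b, inputs):
    """Monte Carlo estimate of the L^2 distance between two network functions.

    Assumption: the integral over the input distribution is approximated by
    the mean over a finite batch of sampled inputs.
    """
    diff = mlp(params_a, inputs) - mlp(params_b, inputs)
    return float(np.sqrt(np.mean(np.sum(diff ** 2, axis=1))))


def parameter_l2_distance(params_a, params_b):
    """Euclidean (l^2) distance between two flattened parameter sets."""
    sq = sum(np.sum((a - b) ** 2) for a, b in zip(params_a, params_b))
    return float(np.sqrt(sq))


def init_params(rng, d_in=4, d_hidden=8, d_out=3):
    """Random weights and zero biases for the toy network."""
    return (
        rng.normal(size=(d_in, d_hidden)),
        np.zeros(d_hidden),
        rng.normal(size=(d_hidden, d_out)),
        np.zeros(d_out),
    )


rng = np.random.default_rng(0)
params_a = init_params(rng)
# A nearby point in parameter space, standing in for a later training step.
params_b = tuple(p + 0.01 * rng.normal(size=p.shape) for p in params_a)
inputs = rng.normal(size=(256, 4))  # stand-in for samples from the data distribution

d_func = function_l2_distance(params_a, params_b, inputs)
d_param = parameter_l2_distance(params_a, params_b)
ratio = d_func / d_param  # the L^2 / l^2 ratio tracked in the abstract
print(f"function L2: {d_func:.4f}  parameter l2: {d_param:.4f}  ratio: {ratio:.4f}")
```

Because `function_l2_distance` depends only on network outputs over sampled inputs, the same quantity could in principle be added to a training loss as a penalty on how far the function moves, which is the spirit of the two applications proposed above; the details of the authors' learning rule are not given here.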

Ari S. Benjamin, David Rolnick, Konrad Kording • 2018

Related benchmarks

| Task | Dataset | Result | Rank |
| --- | --- | --- | --- |
| Continual Learning | Sequential MNIST | Avg Acc 97.79 | 149 |
| Class-incremental learning | CIFAR10 (test) | Average Accuracy 59.62 | 59 |
| Class-incremental learning | MNIST (test) | Average Accuracy 81.25 | 35 |
| Image Classification | S-CIFAR-10 Task-IL | Accuracy 94.32 | 33 |
| Image Classification | S-CIFAR-10 Class-IL | Accuracy 30.91 | 32 |
| Class-incremental learning | CIFAR100-B0 5 steps (test) | Last Step Top-1 Acc 29.99 | 31 |
| Class-incremental learning | CIFAR100 B0 (20 steps) (test) | Last Step Top-1 Acc 13.1 | 31 |
| Image Classification | R-MNIST Domain-IL | Accuracy 94.19 | 28 |
| Image Classification | P-MNIST Domain-IL | Accuracy 90.87 | 28 |
| Classification | S-CIFAR-10 | Accuracy 30.91 | 26 |

Showing 10 of 47 rows
