Classifying the classifier: dissecting the weight space of neural networks
About
This paper presents an empirical study of the weights of neural networks, in which we interpret each model as a point in a high-dimensional space: the neural weight space. To explore the complex structure of this space, we sample from a diverse selection of training variations (dataset, optimization procedure, architecture, etc.) of neural network classifiers and train a large number of models to represent the weight space. We then use machine learning to analyze and extract information from this space. At the core of the study, we train a number of novel deep meta-classifiers to classify different properties of the training setup by identifying their footprints in the weight space. In this way, the meta-classifiers probe for patterns induced by hyperparameters, which lets us quantify how much, where, and when these hyperparameters are encoded during optimization. This provides a novel and complementary view for explainable AI. We show that meta-classifiers can reveal a great deal of information about the training setup and optimization while considering only a small subset of randomly selected consecutive weights. To promote further research on the weight space, we release the neural weight space (NWS) dataset, a collection of 320K weight snapshots from 16K individually trained deep neural networks.
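The input side of the idea above (treat a model as one flat weight vector, then classify a training property from a short random run of consecutive weights) can be sketched as follows. This is a minimal illustration under stated assumptions, not the paper's method. The helper names (`flatten_weights`, `sample_window`, `features`, `NearestCentroidMetaClassifier`) are hypothetical. The deep meta-classifier is swapped out for a toy nearest-centroid classifier on two hand-picked window statistics (mean and standard deviation). The "models" are synthetic random weights whose only varying property is their scale; they are not drawn from the NWS dataset.

```python
import math
import random
import statistics
from typing import Dict, List, Sequence, Tuple


def flatten_weights(layers: Sequence[Sequence[float]]) -> List[float]:
    # Concatenate per-layer weights into one vector: the model as a point in weight space.
    flat: List[float] = []
    for layer in layers:
        flat.extend(layer)
    return flat


def sample_window(flat: Sequence[float], size: int, rng: random.Random) -> List[float]:
    # Take `size` consecutive weights starting at a uniformly random offset.
    if not 1 <= size <= len(flat):
        raise ValueError("size must be between 1 and len(flat)")
    start = rng.randrange(len(flat) - size + 1)
    return list(flat[start:start + size])


def features(window: Sequence[float]) -> Tuple[float, float]:
    # Hypothetical hand-picked summary features: mean and population std of the window.
    return statistics.fmean(window), statistics.pstdev(window)


class NearestCentroidMetaClassifier:
    # Toy stand-in for a deep meta-classifier: predicts the label whose
    # average feature vector (centroid) is closest in Euclidean distance.
    def fit(self, windows: Sequence[Sequence[float]], labels: Sequence[str]) -> "NearestCentroidMetaClassifier":
        grouped: Dict[str, List[Tuple[float, float]]] = {}
        for window, label in zip(windows, labels):
            grouped.setdefault(label, []).append(features(window))
        self.centroids: Dict[str, Tuple[float, ...]] = {
            label: tuple(statistics.fmean(col) for col in zip(*feats))
            for label, feats in grouped.items()
        }
        return self

    def predict(self, window: Sequence[float]) -> str:
        f = features(window)
        return min(self.centroids, key=lambda label: math.dist(f, self.centroids[label]))


rng = random.Random(0)


def synthetic_model(scale: float) -> List[List[float]]:
    # Hypothetical two-layer "model" whose weights are Gaussian with the given scale.
    return [[rng.gauss(0.0, scale) for _ in range(n)] for n in (200, 100)]


def make_dataset(count: int) -> Tuple[List[List[float]], List[str]]:
    # Alternate between two training "properties" and keep one 64-weight window per model.
    windows: List[List[float]] = []
    labels: List[str] = []
    for i in range(count):
        label = "low_scale" if i % 2 == 0 else "high_scale"
        scale = 0.1 if label == "low_scale" else 1.0
        windows.append(sample_window(flatten_weights(synthetic_model(scale)), 64, rng))
        labels.append(label)
    return windows, labels


train_x, train_y = make_dataset(100)
test_x, test_y = make_dataset(50)
clf = NearestCentroidMetaClassifier().fit(train_x, train_y)
accuracy = sum(clf.predict(w) == y for w, y in zip(test_x, test_y)) / len(test_y)
print(f"meta-classifier accuracy: {accuracy:.2f}")
```

Keeping the window contiguous mirrors the "consecutive weights" setting described above. In this toy, the scale difference is large enough that even two summary statistics separate the classes. Subtler training properties would need a learned meta-classifier rather than fixed features.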
Related benchmarks
| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| INR classification | F-MNIST Implicit Neural Representations (test) | Accuracy | 73.1 | 15 |
| INR classification | CIFAR-10 (test) | Accuracy | 33 | 7 |
| INR classification | MNIST (test) | Accuracy | 73.4 | 7 |
| Accuracy Prediction | K-MNIST (test) | MSE | 3.27 | 6 |
| Accuracy Prediction | F-MNIST (test) | MSE | 6.46 | 6 |
| Regression | MNIST (test) | MSE | 7 | 6 |
| Pruning mask prediction | MNIST (test) | Accuracy | 93.07 | 6 |
| Pruning mask prediction | Fashion MNIST (test) | Accuracy | 96.59 | 6 |
| Pruning mask prediction | Kuzushiji-MNIST (test) | Accuracy | 91.39 | 6 |
| Predicting image classifier test accuracy | Small ResNet CIFAR-10 trained (test) | R² | 0.95 | 4 |