Efficient Data Subset Selection to Generalize Training Across Models: Transductive and Inductive Networks
About
Existing subset selection methods for efficient learning predominantly employ discrete combinatorial and model-specific approaches, which lack generalizability: for an unseen architecture, one cannot reuse the subset chosen for a different model. To tackle this problem, we propose $\texttt{SubSelNet}$, a trainable subset selection framework that generalizes across architectures. We first introduce an attention-based neural gadget that leverages the graph structure of architectures and acts as a surrogate for trained deep neural networks, enabling quick model prediction. We then use these predictions to build subset samplers. This naturally yields two variants of $\texttt{SubSelNet}$. The first variant is transductive (called Transductive-$\texttt{SubSelNet}$): it computes the subset separately for each model by solving a small optimization problem. This optimization remains very fast because the model approximator replaces explicit model training. The second variant is inductive (called Inductive-$\texttt{SubSelNet}$): it computes the subset using a trained subset selector, without any optimization. Our experiments show that our model outperforms several methods across several real datasets.
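The sketch below is a minimal illustration, not the authors' implementation, of how the two variants differ once a model approximator is available. It assumes the approximator has already produced per-example class-probability predictions (`pred_probs`) for one architecture. The transductive variant is illustrated with greedy facility-location selection, a swapped-in stand-in for the paper's per-model optimization problem; the inductive variant uses a hypothetical pre-trained `scorer` that maps an architecture embedding and example features to per-example scores. All function and parameter names are illustrative assumptions.

```python
import numpy as np


def transductive_select(pred_probs, k):
    """Per-architecture selection (stand-in: greedy facility location).

    pred_probs: (n, c) class probabilities that a hypothetical model
    approximator predicts for each training example, replacing the outputs
    of an explicitly trained network. Returns k indices whose predicted
    outputs best "cover" the whole training set.
    """
    pred_probs = np.asarray(pred_probs, dtype=float)
    n = pred_probs.shape[0]
    # Pairwise squared distances between predicted outputs, turned into
    # non-negative similarities.
    d2 = ((pred_probs[:, None, :] - pred_probs[None, :, :]) ** 2).sum(axis=-1)
    sim = d2.max() - d2
    selected = []
    best = np.zeros(n)  # similarity of each example to its closest pick so far
    for _ in range(k):
        # Marginal gain of candidate j: sum_i max(best_i, sim_ij) - sum_i best_i
        gains = np.maximum(best[:, None], sim).sum(axis=0) - best.sum()
        gains[selected] = -np.inf  # never pick the same example twice
        j = int(np.argmax(gains))
        selected.append(j)
        best = np.maximum(best, sim[:, j])
    return selected


def inductive_select(scorer, arch_embedding, example_feats, k):
    """Apply an already-trained selector: no per-architecture optimization.

    scorer: hypothetical trained function (arch_embedding, example_feats)
    -> (n,) scores. The k highest-scoring examples form the subset.
    """
    scores = np.asarray(scorer(arch_embedding, example_feats), dtype=float)
    return [int(i) for i in np.argsort(-scores)[:k]]


if __name__ == "__main__":
    probs = [[0.9, 0.1], [0.8, 0.2], [0.1, 0.9], [0.2, 0.8]]
    # Transductive: runs a small greedy optimization for this architecture.
    print(transductive_select(probs, 2))
    # Inductive: a toy linear scorer stands in for a trained selector.
    linear_scorer = lambda a, X: X @ a
    print(inductive_select(linear_scorer, np.array([1.0, -1.0]), np.array(probs), 2))
```

Facility location was chosen here only because it is a compact, well-known subset objective. The contrast being illustrated is that `transductive_select` runs an optimization loop for every new architecture, whereas `inductive_select` is a single pass through a selector trained once.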
Related benchmarks
| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Image Classification | FMNIST | Speedup | 69.24 | 21 |
| Neural Architecture Search | CIFAR-10 (test) | Test Error Rate | 2.68 | 21 |
| Image Classification | CIFAR10 | Speedup | 16.52 | 18 |
| Neural Architecture Search | NAS-Bench-101 CIFAR-10 (test) | -- | -- | 18 |
| Image Classification | Tiny-ImageNet | Speedup | 3.97 | 14 |
| Amortized cost estimation | CIFAR10 (test) | Amortization Cost | 1.50e+3 | 12 |
| Image Classification | CIFAR100 | Speedup | 3.47 | 11 |
| Subset Selection | fMNIST (train) | Speedup | 69.24 | 10 |
| Image Classification | Caltech-256 | Speedup | 3.16 | 9 |
| Hyper-parameter optimization | CIFAR10 (test) | Test Error | 2.7 | 8 |