
Meta-Learning to Improve Pre-Training

About

Pre-training (PT) followed by fine-tuning (FT) is an effective method for training neural networks, and has led to significant performance improvements in many domains. PT can incorporate various design choices such as task and data reweighting strategies, augmentation policies, and noise models, all of which can significantly impact the quality of representations learned. The hyperparameters introduced by these strategies therefore must be tuned appropriately. However, setting the values of these hyperparameters is challenging. Most existing methods either struggle to scale to high dimensions, are too slow and memory-intensive, or cannot be directly applied to the two-stage PT and FT learning process. In this work, we propose an efficient, gradient-based algorithm to meta-learn PT hyperparameters. We formalize the PT hyperparameter optimization problem and propose a novel method to obtain PT hyperparameter gradients by combining implicit differentiation and backpropagation through unrolled optimization. We demonstrate that our method improves predictive performance on two real-world domains. First, we optimize high-dimensional task weighting hyperparameters for multitask pre-training on protein-protein interaction graphs and improve AUROC by up to 3.9%. Second, we optimize a data augmentation neural network for self-supervised PT with SimCLR on electrocardiography data and improve AUROC by up to 1.9%.

Aniruddh Raghu, Jonathan Lorraine, Simon Kornblith, Matthew McDermott, David Duvenaud • 2021
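The hypergradient described in the abstract can be illustrated on a toy problem. The sketch below is a hypothetical scalar bilevel example, not the authors' implementation. The task optima `C`, the targets `T` and `V`, the step sizes, and every function name are assumptions made for illustration. Task weights `phi` define the PT objective, FT is a few unrolled gradient steps starting from the PT solution, and the validation-loss hypergradient chains backpropagation through the FT steps with implicit differentiation of the PT optimum, dθ*/dφ = −(∂²L_PT/∂θ²)⁻¹ ∂²L_PT/∂θ∂φ.

```python
import math

# Hypothetical toy bilevel problem (illustrative only, not the authors' code).
# Outer:    min_phi L_val(psi_K(theta*(phi)))
# Inner PT: theta*(phi) = argmin_theta sum_i phi_i * 0.5 * (theta - C[i])**2
# Inner FT: psi_0 = theta*, psi_{k+1} = psi_k - FT_LR * dL_FT/dpsi(psi_k)
C = [1.0, -2.0]  # hypothetical optima of two pre-training tasks
T = 0.5          # hypothetical FT target; L_FT(psi) = log(cosh(psi - T))
V = 0.8          # hypothetical validation target; L_val(psi) = 0.5 * (psi - V)**2
FT_LR, FT_STEPS = 0.3, 5


def pretrain(phi, lr=0.1, steps=500):
    """Approximate theta*(phi) by gradient descent on the phi-weighted PT loss."""
    theta = 0.0
    for _ in range(steps):
        grad = sum(p * (theta - c) for p, c in zip(phi, C))
        theta -= lr * grad
    return theta


def finetune(theta):
    """Unrolled FT starting from the PT solution; returns the whole trajectory."""
    traj = [theta]
    for _ in range(FT_STEPS):
        psi = traj[-1]
        # d/dpsi log(cosh(psi - T)) = tanh(psi - T)
        traj.append(psi - FT_LR * math.tanh(psi - T))
    return traj


def val_loss(psi):
    return 0.5 * (psi - V) ** 2


def hypergradient(phi):
    """dL_val/dphi via unrolled backprop (FT) plus implicit differentiation (PT)."""
    theta = pretrain(phi)
    traj = finetune(theta)

    # 1) Reverse-mode pass through the unrolled FT steps: dL_val/dpsi_K -> dL_val/dtheta.
    adj = traj[-1] - V
    for psi in reversed(traj[:-1]):
        # d psi_{k+1} / d psi_k = 1 - FT_LR * sech^2(psi_k - T)
        adj *= 1.0 - FT_LR * (1.0 - math.tanh(psi - T) ** 2)

    # 2) Implicit function theorem at the PT optimum (exact inverse in 1-D):
    #    d theta*/d phi_i = -(d^2 L_PT/d theta^2)^{-1} * d^2 L_PT/(d theta d phi_i)
    hess = sum(phi)
    dtheta_dphi = [-(theta - c) / hess for c in C]

    # 3) Chain rule: dL_val/dphi_i = dL_val/dtheta * d theta*/d phi_i
    return [adj * d for d in dtheta_dphi]


def meta_learn(phi, meta_lr=0.5, steps=50):
    """Outer loop: projected gradient descent that keeps task weights positive."""
    for _ in range(steps):
        g = hypergradient(phi)
        phi = [max(p - meta_lr * gi, 1e-3) for p, gi in zip(phi, g)]
    return phi
```

For example, `meta_learn([1.0, 1.0])` should raise the weight on the task whose optimum leads to a lower post-FT validation loss and shrink the other. In this scalar setting the Hessian inverse is exact. At realistic model sizes one would instead use Hessian-vector products with an approximate inverse (for example a truncated Neumann series) and keep the FT unroll short to bound memory. These are assumptions about how such a method scales, not details taken from this page.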

Related benchmarks

Task | Dataset | Result | Rank
Pathology detection from ECG data | PTB-XL (test) | AUC 0.867 | 15
Protein Function Prediction | PPI 40 tasks (test) | Mean AUC 78.6 | 13
Protein Function Prediction | Protein Function Prediction 10 held-out tasks (test) | AUC 0.77 | 11
Multitask protein function prediction | PPI Full FT Access | AUC 78.6 | 5
Multitask protein function prediction | PPI Task Generalization | AUC 0.77 | 5
Multitask protein function prediction | PPI Full FT Access (test) | AUC 78.6 | 5
Multitask protein function prediction | PPI Partial FT Access - 50% data (test) | AUC 0.782 | 5
Multitask protein function prediction | PPI Partial FT Access - Unseen tasks (test) | AUC 0.77 | 5
