
Learning to Grow Pretrained Models for Efficient Transformer Training

About

Scaling transformers has led to significant breakthroughs in many domains, leading to a paradigm in which larger versions of existing models are trained and released on a periodic basis. New instances of such models are typically trained completely from scratch, despite the fact that they are often just scaled-up versions of their smaller counterparts. How can we use the implicit knowledge in the parameters of smaller, extant models to enable faster training of newer, larger models? This paper describes an approach for accelerating transformer training by learning to grow pretrained transformers, where we learn to linearly map the parameters of the smaller model to initialize the larger model. For tractable learning, we factorize the linear transformation as a composition of (linear) width- and depth-growth operators, and further employ a Kronecker factorization of these growth operators to encode architectural knowledge. Extensive experiments across both language and vision transformers demonstrate that our learned Linear Growth Operator (LiGO) can save up to 50% of the computational cost of training from scratch, while also consistently outperforming strong baselines that likewise reuse smaller pretrained models to initialize larger models.
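The growth idea in the abstract can be illustrated with a small NumPy sketch. This is an illustrative toy under stated assumptions, not the authors' LiGO implementation. It assumes that width growth maps each small weight matrix W to A W Bᵀ with one shared pair of matrices (A, B), and that depth growth builds each new layer as a weighted sum of the widened old layers. The names `grow_width`, `grow_depth`, and `ligo_style_grow` are hypothetical. A, B, and the depth weights are random here, whereas LiGO learns its operators. The final check demonstrates the standard identity that links this form to a Kronecker factorization: vec(A W Bᵀ) = (B ⊗ A) vec(W) when vec stacks columns.

```python
import numpy as np

rng = np.random.default_rng(0)

def grow_width(W, A, B):
    """Hypothetical width growth: map a small (d x d) weight to (D x D) via A @ W @ B.T.

    Equivalent to applying the Kronecker-factored operator kron(B, A)
    to the column-major vec(W), but far cheaper to store and apply.
    """
    return A @ W @ B.T

def grow_depth(layers, depth_weights):
    """Hypothetical depth growth: each new layer is a linear mix of the given layers.

    depth_weights has shape (L_new, L_old); row l holds the mixing
    coefficients for new layer l.
    """
    stacked = np.stack(layers)                          # (L_old, D, D)
    return list(np.tensordot(depth_weights, stacked, axes=1))  # L_new arrays of (D, D)

def ligo_style_grow(small_layers, A, B, depth_weights):
    """Compose width growth and depth growth; the overall map stays linear."""
    widened = [grow_width(W, A, B) for W in small_layers]
    return grow_depth(widened, depth_weights)

# Toy sizes (assumed): 2 layers of width 4 grown to 3 layers of width 6.
d, D, L_old, L_new = 4, 6, 2, 3
small = [rng.standard_normal((d, d)) for _ in range(L_old)]
A = rng.standard_normal((D, d))
B = rng.standard_normal((D, d))
depth_w = rng.standard_normal((L_new, L_old))

large = ligo_style_grow(small, A, B, depth_w)
print(len(large), large[0].shape)  # 3 (6, 6)

# Kronecker identity: vec(A W B^T) == kron(B, A) @ vec(W), with column-major vec.
W = small[0]
lhs = grow_width(W, A, B).flatten(order="F")
rhs = np.kron(B, A) @ W.flatten(order="F")
print(np.allclose(lhs, rhs))  # True

# With one shared (A, B), depth and width growth commute in this sketch.
depth_first = [grow_width(M, A, B) for M in grow_depth(small, depth_w)]
print(all(np.allclose(x, y) for x, y in zip(large, depth_first)))  # True
```

The Kronecker form is what keeps the operator tractable in this toy. The dense map `kron(B, A)` has (D·D)·(d·d) = 576 entries, while A and B together hold only 2·D·d = 48 parameters.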

Peihao Wang, Rameswar Panda, Lucas Torroba Hennigen, Philip Greengard, Leonid Karlinsky, Rogerio Feris, David Daniel Cox, Zhangyang Wang, Yoon Kim• 2023

Related benchmarks

Task | Dataset | Metric | Result | Rank
--- | --- | --- | --- | ---
Image Classification | ImageNet-1K 1.0 (val) | Top-1 Accuracy | 60.9 | 1952
Image Classification | ImageNet-1K | Top-1 Accuracy | 73.7 | 1239
Image Classification | Stanford Cars | Accuracy | 87.9 | 635
Image Classification | Food-101 | Accuracy | 84.4 | 542
Natural Language Understanding | GLUE | SST-2 | 92.75 | 531
Image Classification | CIFAR-10 | Accuracy | 96.9 | 507
Image Classification | CIFAR-100 | Accuracy | 82.2 | 435
Classification | Cars | Accuracy | 91.82 | 395
Image Classification | CUB-200 2011 | Accuracy | 74.8 | 356
Image Classification | CIFAR100 | Accuracy | 90.52 | 347

Showing 10 of 31 rows.
