
Zero-TPrune: Zero-Shot Token Pruning through Leveraging of the Attention Graph in Pre-Trained Transformers

About

Deployment of Transformer models on edge devices is becoming increasingly challenging due to their exponentially growing inference cost, which scales quadratically with the number of tokens in the input sequence. Token pruning is an emerging solution to address this challenge due to its ease of deployment on various Transformer backbones. However, most token pruning methods require computationally expensive fine-tuning, which is undesirable in many edge deployment cases. In this work, we propose Zero-TPrune, the first zero-shot method that considers both the importance and similarity of tokens in performing token pruning. It leverages the attention graph of pre-trained Transformer models to produce an importance distribution for tokens via our proposed Weighted Page Rank (WPR) algorithm. This distribution further guides token partitioning for efficient similarity-based pruning. Due to the elimination of the fine-tuning overhead, Zero-TPrune can prune large models at negligible computational cost, switch between different pruning configurations at no computational cost, and perform hyperparameter tuning efficiently. We evaluate the performance of Zero-TPrune on vision tasks by applying it to various vision Transformer backbones and testing them on ImageNet. Without any fine-tuning, Zero-TPrune reduces the FLOPs cost of DeiT-S by 34.7% and improves its throughput by 45.3% with only 0.4% accuracy loss. Compared with state-of-the-art pruning methods that require fine-tuning, Zero-TPrune not only eliminates the need for fine-tuning after pruning but also does so with only 0.1% accuracy loss. Compared with state-of-the-art fine-tuning-free pruning methods, Zero-TPrune reduces accuracy loss by up to 49% with similar FLOPs budgets. Project webpage: https://jha-lab.github.io/zerotprune.
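The abstract names two steps (WPR importance scoring over the attention graph, then importance-guided partitioning for similarity-based pruning) without giving their formulas, so the following is only a minimal pure-Python sketch under stated assumptions. It assumes a single attention matrix `attn` whose rows are softmax outputs (each sums to 1), propagates importance from each query token to the keys it attends to by power iteration, and uses a hypothetical partitioning rule: the less important half of the tokens are pruning candidates, and the candidates most similar (by cosine similarity) to a more important token are dropped. The names `weighted_page_rank`, `prune_tokens`, `num_iters` and `num_prune` are illustrative, not the authors' code, and the paper's actual formulation (for example, how attention heads and layers are combined) may differ.

```python
import math


def weighted_page_rank(attn, num_iters=30):
    """Score tokens by power iteration over one attention graph.

    Hypothetical sketch: attn[i][j] is the weight query token i gives to
    key token j, and every row is assumed to sum to 1 (softmax output).
    """
    n = len(attn)
    scores = [1.0 / n] * n  # uniform starting distribution
    for _ in range(num_iters):
        # Token j gathers importance from each token i that attends to it,
        # weighted by i's own current importance.
        new_scores = [sum(attn[i][j] * scores[i] for i in range(n)) for j in range(n)]
        total = sum(new_scores)
        scores = [s / total for s in new_scores]  # renormalize to sum to 1
    return scores


def cosine(u, v):
    """Cosine similarity of two feature vectors (0.0 if either is all zeros)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm > 0 else 0.0


def prune_tokens(features, scores, num_prune):
    """Drop up to num_prune low-importance tokens that duplicate important ones.

    Hypothetical partitioning rule: the less important half of the tokens are
    pruning candidates, the more important half are kept as anchors.
    Returns the indices of the surviving tokens in their original order.
    """
    order = sorted(range(len(scores)), key=lambda i: scores[i])  # ascending importance
    half = len(order) // 2
    candidates, anchors = order[:half], order[half:]
    # A candidate is redundant if it closely matches any anchor token.
    redundancy = {c: max(cosine(features[c], features[a]) for a in anchors) for c in candidates}
    ranked = sorted(candidates, key=lambda c: redundancy[c], reverse=True)
    dropped = set(ranked[:num_prune])
    return [i for i in range(len(scores)) if i not in dropped]


# Toy example: 4 tokens, every row of attn sums to 1.
attn = [
    [0.5, 0.3, 0.1, 0.1],
    [0.6, 0.2, 0.1, 0.1],
    [0.4, 0.4, 0.1, 0.1],
    [0.5, 0.2, 0.2, 0.1],
]
features = [[1.0, 0.0], [0.0, 1.0], [1.0, -1.0], [0.9, 0.1]]

scores = weighted_page_rank(attn)
print([round(s, 3) for s in scores])  # [0.516, 0.274, 0.11, 0.1]
print(prune_tokens(features, scores, num_prune=1))  # [0, 1, 2]
```

In the toy run, token 3 is both the least important token and a near-duplicate of the most important one (token 0), so it is the one pruned. Renormalizing on every iteration is redundant when rows sum exactly to 1, but it guards against rounding drift or attention rows that do not sum exactly to 1.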

Hongjie Wang, Bhishma Dedhia, Niraj K. Jha • 2023

Related benchmarks

Task                  Dataset         Metric          Result  Rank
Image Classification  DTD             Accuracy        70.9    487
Classification        Cars            Accuracy        88.2    314
Image Classification  ImageNet (val)  --              --      300
Image Classification  Pets            Accuracy        86.9    204
Image Classification  CUB-200         Accuracy        74.4    92
Image Classification  Flowers         Top-1 Accuracy  95.1    80
Image Classification  ImageNet-1K     Top-1 Accuracy  79.4    30
Image Classification  Aircrafts       Top-1 Accuracy  76.7    27
Image Classification  ImageNet (val)  Top-1 Accuracy  85.17   16
Image Classification  Indoor67        Top-1 Accuracy  73.7    6

