Learning-to-Cache: Accelerating Diffusion Transformer via Layer Caching

About

Diffusion Transformers have recently demonstrated unprecedented generative capabilities for various tasks. These encouraging results, however, come at the cost of slow inference, since each denoising step requires a forward pass through a transformer with a large number of parameters. In this study, we make an interesting and somewhat surprising observation: by introducing a caching mechanism, the computation of a large proportion of layers in the diffusion transformer can be removed without updating the model parameters. In the case of U-ViT-H/2, for example, we may remove up to 93.68% of the computation in the cache steps (46.84% over all steps), with less than a 0.01 drop in FID. To achieve this, we introduce a novel scheme, named Learning-to-Cache (L2C), that learns to conduct caching in a dynamic manner for diffusion transformers. Specifically, by leveraging the identical structure of layers in transformers and the sequential nature of diffusion, we explore redundant computations between timesteps by treating each layer as the fundamental unit for caching. To address the exponential search space that deep models present when identifying which layers to cache and remove, we propose a novel differentiable optimization objective. An input-invariant yet timestep-variant router is then optimized, which ultimately produces a static computation graph. Experimental results show that, at the same inference speed, L2C substantially outperforms samplers such as DDIM and DPM-Solver, as well as prior cache-based methods. Code is available at https://github.com/horseee/learning-to-cache
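
The layer-level caching idea described above can be illustrated with a toy sketch. This is not the authors' implementation: the scalar "layers", the function name `run_with_layer_cache`, and the hand-written schedule are all hypothetical, and the schedule here is fixed by hand rather than learned by a router. It only shows how a per-timestep, input-independent mask can decide which layers reuse the output they produced at the previous step, and how the saved computation is counted.

```python
from typing import Callable, List, Optional, Tuple

Layer = Callable[[float], float]


def run_with_layer_cache(
    layers: List[Layer],
    x: float,
    schedule: List[List[bool]],
) -> Tuple[float, int]:
    """Toy denoising loop with per-layer caching (hypothetical helper).

    schedule[t][i] is True when layer i should reuse its cached output at
    step t. The schedule depends on the step only, never on the input x.
    Returns the final state and the number of layer evaluations performed.
    """
    cache: List[Optional[float]] = [None] * len(layers)
    evals = 0
    for skip_mask in schedule:
        for i, layer in enumerate(layers):
            if skip_mask[i] and cache[i] is not None:
                residual = cache[i]      # reuse the output from an earlier step
            else:
                residual = layer(x)      # recompute and refresh the cache
                cache[i] = residual
                evals += 1
            x = x + residual             # residual connection around the layer
    return x, evals


# Four toy "layers"; each just scales its input (assumption for illustration).
layers = [lambda h, s=s: s * h for s in (0.1, -0.05, 0.02, 0.01)]

full = [False, False, False, False]      # full step: every layer is computed
cached = [True, True, True, False]       # cache step: 3 of 4 layers are reused
schedule = [full, cached, full, cached]  # alternate full and cache steps

_, evals_cached = run_with_layer_cache(layers, 1.0, schedule)
_, evals_full = run_with_layer_cache(layers, 1.0, [full] * 4)
saved_overall = 1 - evals_cached / evals_full
print(evals_cached, evals_full, saved_overall)
```

In this toy schedule, 75% of the layer evaluations in the cache steps are skipped, which is 37.5% over all steps because only half the steps are cache steps. The figures quoted in the abstract have the same ratio (46.84% is half of 93.68%), which is consistent with half of the steps being cache steps.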

Xinyin Ma, Gongfan Fang, Michael Bi Mi, Xinchao Wang • 2024

Related benchmarks

Task | Dataset | Metric | Result | Rank
Class-conditional Image Generation | ImageNet 256x256 | Inception Score (IS) | 240.7 | 967
Class-conditional Image Generation | ImageNet 256x256 (val) | Inception Score (IS) | 225.4 | 493
Image Generation | ImageNet 256x256 (val) | FID | 2.23 | 399
Image Generation | ImageNet 512x512 (val) | FID-50K | 3.98 | 219
Class-conditional Image Generation | ImageNet 512x512 | FID | 3.76 | 126
Class-conditional Image Generation | ImageNet 512x512 (val) | FID (Val) | 6.18 | 102
Class-conditional Image Generation | ImageNet class-conditional 256x256 | Inception Score (IS) | 246.5 | 61
Video Generation | Image and Video Generation | FID | 6.12 | 20
Class-to-image generation | ImageNet | Speedup | 1.25 | 19
Multi-stage Robotic Manipulation | Kitchen (test) | Success Rate (Kit_p1) | 100 | 15
Showing 10 of 23 rows
