DiffusionBlocks: Block-wise Neural Network Training via Diffusion Interpretation
About
End-to-end backpropagation requires storing activations throughout all layers, creating memory bottlenecks that limit model scalability. Existing block-wise training methods offer means to alleviate this problem, but they rely on ad-hoc local objectives and remain largely unexplored beyond classification tasks. We propose DiffusionBlocks, a principled framework for transforming transformer-based networks into genuinely independent trainable blocks that maintain competitive performance with end-to-end training. Our key insight leverages the fact that residual connections naturally correspond to updates in a dynamical system. With minimal modifications to this system, we can convert the updates to those of a denoising process, where each block can be learned independently by leveraging the score matching objective. This independence enables training with gradients for only one block at a time, thereby reducing memory requirements in proportion to the number of blocks. Our experiments on a range of transformer architectures (vision, diffusion, autoregressive, recurrent-depth, and masked diffusion) demonstrate that DiffusionBlocks training matches the performance of end-to-end training while enabling scalable block-wise training on practical tasks beyond small-scale classification. DiffusionBlocks provides a theoretically grounded approach that successfully scales to modern generative tasks across diverse architectures. Code is available at https://github.com/SakanaAI/DiffusionBlocks.
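The block-wise idea in the abstract can be illustrated with a toy sketch. This is not the authors' implementation: the function names `block_noise_ranges` and `train_block`, the log-uniform split of noise levels, the scalar linear denoiser, and the `1 / (1 + sigma^2)` loss weighting are all assumptions made for illustration. Each "block" owns one interval of noise levels and is fitted with a denoising objective on that interval only, so no gradient or activation from any other block is needed while it trains.

```python
import math
import random


def block_noise_ranges(num_blocks, sigma_min=0.1, sigma_max=10.0):
    """Split [sigma_min, sigma_max] into contiguous log-spaced intervals.

    Assumption: a log-uniform split, chosen here only for simplicity.
    """
    log_lo, log_hi = math.log(sigma_min), math.log(sigma_max)
    edges = [
        math.exp(log_lo + (log_hi - log_lo) * i / num_blocks)
        for i in range(num_blocks + 1)
    ]
    return list(zip(edges[:-1], edges[1:]))


def train_block(sigma_range, steps=5000, lr=0.005, seed=0):
    """Fit a toy scalar denoiser D(y) = w * y for one block.

    Data is x ~ N(0, 1); the noisy input is y = x + sigma * eps with sigma
    drawn from this block's interval only. The loss is the weighted
    denoising error (w * y - x)^2 / (1 + sigma^2); the weighting is an
    assumption that keeps the SGD step size well scaled at large sigma.
    """
    rng = random.Random(seed)
    lo, hi = sigma_range
    w = 0.0  # the only parameter this block owns
    for _ in range(steps):
        x = rng.gauss(0.0, 1.0)
        sigma = math.exp(rng.uniform(math.log(lo), math.log(hi)))
        y = x + sigma * rng.gauss(0.0, 1.0)
        grad = 2.0 * (w * y - x) * y / (1.0 + sigma * sigma)
        w -= lr * grad  # update touches this block's parameter only
    return w


# Train three blocks one at a time; each call is fully independent.
ranges = block_noise_ranges(3)
weights = [train_block(r, seed=i) for i, r in enumerate(ranges)]
for (lo, hi), w in zip(ranges, weights):
    print(f"sigma in [{lo:.3f}, {hi:.3f}] -> learned w = {w:.3f}")
```

In this toy setting the block assigned to low noise learns a weight near 1 (pass the input through almost unchanged), while the block assigned to high noise learns a weight near 0 (shrink heavily toward the data mean). Because the calls to `train_block` share no state, they could run sequentially on one device or in parallel on several, which is the property the abstract credits for the memory savings.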
Related benchmarks
| Task | Dataset | Metric | Value | Rank |
|---|---|---|---|---|
| Image Classification | CIFAR-100 | Accuracy | 59.3 | 302 |
| Image Generation | ImageNet 256x256 (train) | FID | 9 | 91 |
| Image Generation | ImageNet 256x256 (test) | FID | 10.63 | 46 |
| Image Generation | CIFAR10 (train) | FID | 30.59 | 32 |
| Image Generation | CIFAR-10 (test) | FID | 37.2 | 24 |
| Text Generation | 1 Billion Words Dataset (LM1B) (test) | MAUVE | 0.71 | 4 |
| Image Classification | Tiny-ImageNet 2015 | Accuracy | 36.16 | 2 |
| Text Generation | text8 | BPC | 1.45 | 2 |
| Text Generation | OpenWebText (OWT) (test) | MAUVE | 0.82 | 2 |