BAD: Bidirectional Auto-regressive Diffusion for Text-to-Motion Generation

About

Autoregressive models excel at modeling sequential dependencies by enforcing causal constraints, yet they struggle to capture complex bidirectional patterns due to their unidirectional nature. In contrast, mask-based models leverage bidirectional context, enabling richer dependency modeling. However, they often assume token independence during prediction, which undermines the modeling of sequential dependencies. Additionally, the corruption of sequences through masking or absorption can introduce unnatural distortions, complicating the learning process. To address these issues, we propose Bidirectional Autoregressive Diffusion (BAD), a novel approach that unifies the strengths of autoregressive and mask-based generative models. BAD utilizes a permutation-based corruption technique that preserves the natural sequence structure while enforcing causal dependencies through randomized ordering, enabling the effective capture of both sequential and bidirectional relationships. Comprehensive experiments show that BAD outperforms autoregressive and mask-based models in text-to-motion generation, suggesting a novel pre-training strategy for sequence modeling. The codebase for BAD is available at https://github.com/RohollahHS/BAD.
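
To make the idea of "causal dependencies through randomized ordering" concrete, here is a minimal sketch of a permutation-based context mask. It is not taken from the BAD codebase: the function name `permutation_context_mask`, the `seed` parameter, and the choice to exclude a position from its own context are all assumptions made for illustration. The sketch samples a random generation order and lets each position see only the positions generated earlier in that order, so the visible context can lie on both sides of the predicted position.

```python
import random


def permutation_context_mask(seq_len, seed=None):
    """Sample a random generation order and the context mask it implies.

    Hypothetical helper for illustration only (not the authors' code).
    order[k] is the position predicted at step k.
    visible[i][j] is True when position j is generated strictly before
    position i, so j may serve as context when predicting i.
    """
    rng = random.Random(seed)
    order = list(range(seq_len))
    rng.shuffle(order)

    # rank[p] = step at which position p is generated
    rank = [0] * seq_len
    for step, pos in enumerate(order):
        rank[pos] = step

    visible = [[rank[j] < rank[i] for j in range(seq_len)] for i in range(seq_len)]
    return order, visible


# Usage: show which positions are available as context at each step.
order, visible = permutation_context_mask(6, seed=0)
for step, pos in enumerate(order):
    context = [j for j in range(6) if visible[pos][j]]
    print(f"step {step}: predict position {pos} from positions {context}")
```

Under these assumptions the first position in the sampled order has an empty context and the last one sees every other position, which mirrors a left-to-right causal mask applied to a shuffled order rather than to the natural one.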

S. Rohollah Hosseyni, Ali Ahmad Rahmani, S. Jamal Seyedmohammadi, Sanaz Seyedin, Arash Mohammadi • 2024

Related benchmarks

Task                      | Dataset             | Result                    | Rank
Text-to-motion generation | HumanML3D (test)    | FID: 0.049                | 331
text-to-motion mapping    | KIT-ML (test)       | R Precision (Top 3): 0.75 | 275
Text-to-motion generation | KIT-ML (test)       | FID: 0.221                | 115
Text-to-motion generation | HumanML3D 19 (test) | FID: 0.065                | 37
Text-to-motion generation | KIT-ML 46 (test)    | R-Precision Top 1: 41.7   | 9
Suffix Completion         | HumanML3D (test)    | R-Precision Top-3: 80.8   | 7
Temporal Inpainting       | HumanML3D (test)    | R-Precision Top-3: 0.81   | 7
Prefix Prediction         | HumanML3D (test)    | R-Precision Top-3: 80.6   | 3
Temporal Outpainting      | HumanML3D (test)    | R-Precision@3: 80         | 3
