
Perceiving Longer Sequences With Bi-Directional Cross-Attention Transformers

About

We present a novel bi-directional Transformer architecture (BiXT) which scales linearly with input size in terms of computational cost and memory consumption, but does not suffer the drop in performance or limitation to only one input modality seen with other efficient Transformer-based approaches. BiXT is inspired by the Perceiver architectures but replaces iterative attention with an efficient bi-directional cross-attention module in which input tokens and latent variables attend to each other simultaneously, leveraging a naturally emerging attention-symmetry between the two. This approach unlocks a key bottleneck experienced by Perceiver-like architectures and enables the processing and interpretation of both semantics ('what') and location ('where') to develop alongside each other over multiple layers -- allowing its direct application to dense and instance-based tasks alike. By combining efficiency with the generality and performance of a full Transformer architecture, BiXT can process longer sequences like point clouds, text or images at higher feature resolutions and achieves competitive performance across a range of tasks like point cloud part segmentation, semantic image segmentation, image classification, hierarchical sequence modeling and document retrieval. Our experiments demonstrate that BiXT models outperform larger competitors by leveraging longer sequences more efficiently on vision tasks like classification and segmentation, and perform on par with full Transformer variants on sequence modeling and document retrieval -- but require 28% fewer FLOPs and are up to 8.4× faster.
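The core idea from the abstract can be sketched in a few lines: a small, fixed set of latent vectors and a long token sequence exchange information through a single pairwise similarity matrix, which is reused (transposed) for both attention directions, giving cost linear in sequence length. The sketch below is our simplified reading of that module, not the authors' implementation; the shared projection weights and all names are illustrative assumptions.

```python
import numpy as np

def softmax(x, axis=-1):
    """Numerically stable softmax along the given axis."""
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def bidir_cross_attention(latents, tokens, rng):
    """Illustrative sketch of one bi-directional cross-attention step.

    latents: (M, d) latent vectors, M small and fixed
    tokens:  (N, d) input tokens, N may be large
    The similarity matrix is (M, N), so cost is O(M*N):
    linear in sequence length N for fixed M.
    """
    d = latents.shape[-1]
    # Hypothetical random projections standing in for learned weights.
    Wq, Wk, Wv = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(3))
    q_lat, v_lat = latents @ Wq, latents @ Wv
    k_tok, v_tok = tokens @ Wk, tokens @ Wv
    # One pairwise similarity matrix, reused in both directions.
    sim = (q_lat @ k_tok.T) / np.sqrt(d)            # (M, N)
    latents_out = softmax(sim, axis=-1) @ v_tok     # latents attend to tokens
    tokens_out = softmax(sim.T, axis=-1) @ v_lat    # tokens attend to latents
    return latents_out, tokens_out

rng = np.random.default_rng(0)
M, N, d = 8, 1024, 32
lat = rng.standard_normal((M, d))
tok = rng.standard_normal((N, d))
lat_out, tok_out = bidir_cross_attention(lat, tok, rng)
print(lat_out.shape, tok_out.shape)  # (8, 32) (1024, 32)
```

Because the (M, N) matrix is computed once and simply transposed for the second direction, doubling the token count only doubles the attention cost, in contrast to the quadratic growth of full self-attention over the tokens.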

Markus Hiller, Krista A. Ehinger, Tom Drummond · 2024

Related benchmarks

Task                           Dataset              Result                Rank
Semantic Segmentation          ADE20K (val)         mIoU 43.2             2731
Semantic Segmentation          ADE20K               mIoU 43.2             936
Point Cloud Classification     ModelNet40 (test)    Accuracy 93.1         224
Image Classification           ImageNet1K (val)     Top-1 Accuracy 83.1   29
Point Cloud Part Segmentation  ShapeNetPart (test)  mIoU (Cls) 84.7       4

Other info

Code
