Dispatch-Aware Ragged Attention for Pruned Vision Transformers
Seifeldin Abdellatif, Ahmad Almasri

TL;DR
This paper introduces a dispatch-aware ragged attention kernel for pruned Vision Transformers that cuts host-side dispatch overhead and improves throughput at the short sequence lengths left after token pruning.
Contribution
It presents a lightweight bidirectional Triton attention kernel that minimizes host-side dispatch overhead, so that token-pruning savings become practical wall-clock speedups for pruned ViT models.
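To make "ragged attention" concrete, here is a minimal pure-Python reference sketch. It is not the authors' Triton kernel. It assumes the common packed layout in which variable-length sequences are concatenated into one buffer and delimited by cumulative offsets (the `cu_seqlens` convention used by FlashAttention-2's varlen API). The helper names `pack`, `ragged_attention`, and `unpack` are hypothetical. Attention is bidirectional (no causal mask) and confined to each sequence, so no padding tokens are ever computed.

```python
import math
from typing import List, Tuple

Vector = List[float]


def pack(seqs: List[List[Vector]]) -> Tuple[List[Vector], List[int]]:
    """Concatenate variable-length sequences into one flat token buffer.

    Returns the packed tokens and cumulative offsets `cu`, so sequence i
    occupies packed[cu[i]:cu[i + 1]] (a cu_seqlens-style layout).
    """
    packed: List[Vector] = []
    cu = [0]
    for seq in seqs:
        packed.extend(seq)
        cu.append(cu[-1] + len(seq))
    return packed, cu


def ragged_attention(q: List[Vector], k: List[Vector], v: List[Vector],
                     cu: List[int]) -> List[Vector]:
    """Bidirectional softmax attention computed independently per segment.

    Tokens attend only to tokens in the same segment; there is no padding
    and no causal mask.
    """
    if not q:
        return []
    scale = 1.0 / math.sqrt(len(q[0]))
    out: List[Vector] = []
    for start, end in zip(cu[:-1], cu[1:]):
        for i in range(start, end):
            # Scaled dot-product scores against every key in this segment only.
            scores = [scale * sum(a * b for a, b in zip(q[i], k[j]))
                      for j in range(start, end)]
            m = max(scores)  # subtract the max for a numerically stable softmax
            weights = [math.exp(s - m) for s in scores]
            z = sum(weights)
            # Weighted average of this segment's value vectors.
            out.append([
                sum(w * v[start + t][c] for t, w in enumerate(weights)) / z
                for c in range(len(v[start]))
            ])
    return out


def unpack(packed: List[Vector], cu: List[int]) -> List[List[Vector]]:
    """Split a packed buffer back into per-sequence lists."""
    return [packed[a:b] for a, b in zip(cu[:-1], cu[1:])]
```

For example, two images pruned to 3 and 1 surviving tokens pack into `cu == [0, 3, 4]`; the single-token sequence attends only to itself, so its output equals its own value vector. A real GPU kernel would tile these loops; the Python version only shows the indexing contract.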
Findings
Achieves 1.88× the end-to-end throughput of padded PyTorch SDPA.
Delivers 9-12% higher throughput than FlashAttention-2 varlen at standard inputs.
Provides 2.17× lower kernel latency at 80% token pruning.
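As a back-of-the-envelope check on why 80% pruning looks so attractive on paper (an illustrative calculation, not a result from the paper): per head, the QK^T and PV products each cost about 2N²d FLOPs for N tokens of head dimension d, so keeping a fraction 1 − p of the tokens scales this cost by (1 − p)².

```latex
F(N) \approx 4N^2 d,
\qquad
\frac{F\big((1-p)N\big)}{F(N)} = (1-p)^2,
\qquad
p = 0.8:\; (0.2)^2 = 0.04 = \tfrac{1}{25}.
```

For N = 197 this leaves about 39 tokens and a 25× cut in attention-score FLOPs, while a fixed per-call dispatch cost does not shrink at all, which is why dispatch overhead rather than FLOPs can dominate at high pruning rates.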
Abstract
Token pruning methods for Vision Transformers (ViTs) promise quadratic reductions in attention FLOPs by dropping uninformative patches. Yet standard variable-length attention APIs, including FlashAttention-2's varlen and PyTorch's NestedTensor SDPA, fail to translate these savings into proportional wall-clock gains at the short post-pruning sequence lengths typical of ViTs (≤197 tokens). We identify a dispatch-overhead bottleneck: at these lengths, host-side kernel dispatch consumes 50 µs regardless of workload, exceeding the actual GPU compute time at moderate-to-high pruning rates. We present a lightweight bidirectional Triton attention kernel whose dispatch floor is 24 µs, roughly 2.17× lower than FlashAttention-2 varlen, allowing pruning savings to become visible in wall-clock time. Integrated into a complete pack-attend-unpack pipeline…
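The dispatch-overhead argument can be illustrated with a toy latency model. This is a sketch under stated assumptions, not the paper's methodology: it assumes host-side dispatch and GPU execution overlap, so per-call latency is roughly max(dispatch floor, compute time), and that attention compute scales with the square of the kept-token fraction. The 50 µs and 24 µs floors are the figures quoted in the abstract; the 100 µs dense compute time is hypothetical, and the names `Backend`, `observed_latency_us`, and `visible_speedup` are illustrative.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Backend:
    """An attention backend characterised only by its per-call dispatch floor."""
    name: str
    dispatch_floor_us: float


def observed_latency_us(backend: Backend, compute_us: float) -> float:
    """Toy model: a call can never complete faster than its dispatch floor.

    Assumes host-side dispatch and GPU execution overlap, so whichever is
    slower sets the per-call latency.
    """
    return max(backend.dispatch_floor_us, compute_us)


def visible_speedup(backend: Backend, dense_compute_us: float,
                    keep_ratio: float) -> float:
    """Wall-clock speedup from pruning under the toy model.

    Assumes attention compute scales with keep_ratio ** 2 (the quadratic
    FLOP reduction) and ignores all linear-cost terms.
    """
    pruned_compute_us = dense_compute_us * keep_ratio ** 2
    return (observed_latency_us(backend, dense_compute_us)
            / observed_latency_us(backend, pruned_compute_us))


if __name__ == "__main__":
    standard = Backend("standard varlen API", 50.0)  # floor quoted in the abstract
    lean = Backend("lightweight kernel", 24.0)       # floor quoted in the abstract
    dense_us = 100.0  # hypothetical unpruned compute time, not from the paper
    for b in (standard, lean):
        print(b.name, round(visible_speedup(b, dense_us, keep_ratio=0.2), 2))
```

Under these assumptions, 80% pruning (keep_ratio = 0.2) would ideally give a 25× cut in compute, but the visible speedup collapses to 2× behind a 50 µs floor and to about 4.2× behind a 24 µs floor. The numbers only illustrate the shape of the argument; they are not measurements.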
