AdaSplash-2: Faster Differentiable Sparse Attention
Nuno Gonçalves, Hugo Pitorro, Vlad Niculae, Edoardo Ponti, Lei Li, Andre Martins, Marcos Treviso

TL;DR
AdaSplash-2 introduces a histogram-based initialization for α-entmax sparse attention, significantly improving computational efficiency and enabling better long-context training in transformers.
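For reference, here is the standard α-entmax mapping from prior work on entmax, restated as a sketch rather than quoted from this paper. It maps a score vector z ∈ ℝⁿ to a probability vector, with α > 1 (α → 1 recovers softmax, α = 2 gives sparsemax). The normalizer τ is the quantity whose iterative computation AdaSplash-2 accelerates, and it always lies in a bracket of width at most 1 below the largest scaled score:

```latex
\alpha\text{-entmax}(\mathbf{z})_i = \big[(\alpha - 1)\, z_i - \tau\big]_+^{1/(\alpha - 1)},
\qquad \tau \text{ chosen so that } \sum_{i=1}^{n} \alpha\text{-entmax}(\mathbf{z})_i = 1,
\qquad
\max_i (\alpha - 1) z_i - 1 \;\le\; \tau \;\le\; \max_i (\alpha - 1) z_i - n^{1 - \alpha}.
```

Scores with (α − 1)zᵢ ≤ τ receive exactly zero probability, which is the source of the input-dependent sparsity.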
Contribution
It proposes a novel histogram-based initialization that reduces the number of iterations needed to compute the α-entmax normalizer, speeding up differentiable sparse attention.
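The idea of a histogram-based initialization can be sketched in pure Python. Writing sᵢ = (α − 1)zᵢ and p = 1/(α − 1), the normalizer τ solves f(τ) = Σᵢ [sᵢ − τ]₊ᵖ = 1, and f(max s − 1) ≥ 1 always holds. The sketch below is our own illustration, not the authors' kernel: it bins the scores in [max s − 1, max s] into a small count array (standing in for the on-chip histogram), uses only those counts to push the starting point τ₀ to the right while a lower bound on f(τ₀) stays ≥ 1, and then refines with Newton's method. The names `mass_and_slope`, `histogram_init`, `newton_tau`, the choice of 16 bins, and Newton as the refinement step are all assumptions; the monotone-convergence argument in the comments requires 1 < α ≤ 2.

```python
import random


def mass_and_slope(scores, tau, p):
    """f(tau) = sum_i [s_i - tau]_+^p and df/dtau, with s_i = (alpha-1)*z_i, p = 1/(alpha-1)."""
    active = [s - tau for s in scores if s > tau]
    f = sum(d ** p for d in active)
    df = -p * sum(d ** (p - 1.0) for d in active)
    return f, df


def histogram_init(scores, alpha, num_bins=16):
    """Hypothetical histogram initialization: returns tau0 with f(tau0) >= 1, so tau0 <= tau*."""
    p = 1.0 / (alpha - 1.0)
    t_lo = max(scores) - 1.0          # f(t_lo) >= 1 always holds
    width = 1.0 / num_bins
    counts = [0] * num_bins           # coarse histogram; only counts are kept
    for s in scores:
        if s >= t_lo:                 # scores below t_lo are zero for every tau >= t_lo
            counts[min(int((s - t_lo) / width), num_bins - 1)] += 1
    lo = t_lo
    for j in range(1, num_bins):
        # Every score in bin k >= j lies at least (k - j) * width above edge j,
        # so this sum is a lower bound on f(edge_j).
        lower = sum(counts[k] * ((k - j) * width) ** p for k in range(j, num_bins))
        if lower < 1.0:
            break                     # the bound only shrinks as j grows
        lo = t_lo + j * width         # f(edge_j) >= lower >= 1, hence tau* >= edge_j
    return lo


def newton_tau(scores, alpha, tau, tol=1e-12, max_iter=100):
    """Newton's method on f(tau) = 1 from a start with f(tau) >= 1.

    For 1 < alpha <= 2, f is convex and decreasing, so the iterates increase
    monotonically toward tau* without overshooting. Returns (tau, steps).
    """
    p = 1.0 / (alpha - 1.0)
    for step in range(max_iter):
        f, df = mass_and_slope(scores, tau, p)
        if f - 1.0 <= tol:
            return tau, step
        tau -= (f - 1.0) / df
    return tau, max_iter


random.seed(0)
alpha = 1.5
scores = [(alpha - 1.0) * random.gauss(0.0, 1.0) for _ in range(256)]
tau_plain, it_plain = newton_tau(scores, alpha, max(scores) - 1.0)
tau_hist, it_hist = newton_tau(scores, alpha, histogram_init(scores, alpha))
print(f"Newton steps: {it_plain} from max(s) - 1, {it_hist} from histogram init")
```

Because the histogram only ever moves the start to the right while keeping f ≥ 1, the refinement can never need more steps than starting from max s − 1; how many it saves depends on the score distribution and the number of bins.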
Findings
AdaSplash-2 matches or improves on FlashAttention-2's training time at moderate-to-high sparsity.
Models trained with AdaSplash-2's attention outperform softmax baselines on long-context tasks.
The method enables efficient long-context training with substantial gains in downstream performance.
Abstract
Sparse attention has been proposed as a way to alleviate the quadratic cost of transformers, a central bottleneck in long-context training. A promising line of work is α-entmax attention, a differentiable sparse alternative to softmax that enables input-dependent sparsity yet has lagged behind softmax due to the computational overhead of computing the normalizer τ. In this paper, we introduce AdaSplash-2, which addresses this limitation through a novel histogram-based initialization that typically reduces the number of iterations needed to compute τ to 1–2. The key idea is to compute a coarse histogram of attention scores on the fly and store it in on-chip SRAM, yielding a more accurate initialization that enables fast forward and backward computation. Combined with a sparsity-aware GPU implementation that skips zero blocks with low overhead, AdaSplash-2…
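Skipping zero blocks can be illustrated on a toy score matrix. Once each query row i has its normalizer τᵢ, an entry contributes to the output only if its scaled score exceeds τᵢ, so any (query-tile, key-tile) pair with no such entry can be skipped entirely. The sketch below is a hypothetical pure-Python illustration, not the paper's GPU implementation: the helper names `solve_tau`, `block_mask`, and `block_sparse_output`, the tile size of 8, the bisection solver, and the toy "local" score pattern zᵢⱼ = −|i − j| are all our assumptions, and for clarity the mask is built from the full score matrix, whereas a real kernel would have to decide which tiles to skip far more cheaply.

```python
import random


def solve_tau(row, alpha, iters=60):
    """Bisection for the entmax normalizer of one row of (alpha-1)-scaled scores."""
    p = 1.0 / (alpha - 1.0)
    lo, hi = max(row) - 1.0, max(row)     # f(lo) >= 1 >= f(hi)
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if sum(max(s - mid, 0.0) ** p for s in row) >= 1.0:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)


def block_mask(S, taus, block):
    """mask[qb][kb] is True iff some score in tile (qb, kb) exceeds its row's tau."""
    nqb = (len(S) + block - 1) // block
    nkb = (len(S[0]) + block - 1) // block
    mask = [[False] * nkb for _ in range(nqb)]
    for i, row in enumerate(S):
        for j, s in enumerate(row):
            if s > taus[i]:
                mask[i // block][j // block] = True
    return mask


def block_sparse_output(S, V, taus, alpha, block):
    """Computes P @ V visiting only nonzero tiles; returns (out, visited, total)."""
    p = 1.0 / (alpha - 1.0)
    mask = block_mask(S, taus, block)
    d = len(V[0])
    out = [[0.0] * d for _ in S]
    visited = 0
    for qb, row_mask in enumerate(mask):
        for kb, nonzero in enumerate(row_mask):
            if not nonzero:
                continue                  # every probability in this tile is exactly zero
            visited += 1
            for i in range(qb * block, min((qb + 1) * block, len(S))):
                for j in range(kb * block, min((kb + 1) * block, len(V))):
                    w = max(S[i][j] - taus[i], 0.0) ** p
                    for c in range(d):
                        out[i][c] += w * V[j][c]
    total = sum(len(r) for r in mask)
    return out, visited, total


alpha, n, block = 1.5, 32, 8
random.seed(0)
# Toy local pattern: z_ij = -|i - j|, pre-scaled by (alpha - 1).
S = [[(alpha - 1.0) * -abs(i - j) for j in range(n)] for i in range(n)]
V = [[random.uniform(-1.0, 1.0) for _ in range(4)] for _ in range(n)]
taus = [solve_tau(row, alpha) for row in S]
out, visited, total = block_sparse_output(S, V, taus, alpha, block)
print(f"visited {visited} of {total} tiles")
```

With this banded pattern only tiles on or next to the diagonal contain nonzero probabilities, so the remaining tiles are never loaded, while the output matches the dense computation.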
