FlashMask: Efficient and Rich Mask Extension of FlashAttention
Guoxia Wang, Jinle Zeng, Xiyuan Xiao, Siming Wu, Jiabin Yang, Lujing Zheng, Zeyu Chen, Jiang Bian, Dianhai Yu, and Haifeng Wang

TL;DR
FlashMask extends FlashAttention by introducing a sparse mask representation that reduces memory complexity to linear, enabling efficient processing of long sequences in large language models with significant speedups.
Contribution
It proposes a novel sparse mask representation for FlashAttention, achieving linear memory complexity and improved kernel efficiency for long-sequence modeling.
Findings
Achieves 1.65x to 3.22x throughput speedup over existing methods.
Surpasses FlexAttention by 12.1% to 60.7% in kernel TFLOPs/s.
Supports models with over 100 billion parameters and contexts up to 128K tokens.
Abstract
The computational and memory demands of vanilla attention scale quadratically with the sequence length N, posing significant challenges for processing long sequences in Transformer models. FlashAttention alleviates these challenges by eliminating the O(N^2) memory dependency and reducing attention latency through IO-aware memory optimizations. However, its native support for certain attention mask types is limited, and it does not inherently accommodate more complex masking requirements. Previous approaches resort to using dense masks with O(N^2) memory complexity, leading to inefficiencies. In this paper, we propose FlashMask, an extension of FlashAttention that introduces a column-wise sparse representation of attention masks. This approach efficiently represents a wide range of mask types and facilitates the development of optimized kernel implementations. By adopting this…
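As a rough illustration of what a column-wise representation can look like, the sketch below assumes that each key column masks one contiguous run of query rows, stored as a [start, end) pair, on top of ordinary causal masking. This layout and the helper names (causal_document_columns, is_masked, to_dense, mask_bytes) are assumptions made for illustration, not the paper's exact format or API.

```python
def causal_document_columns(doc_lens):
    """Column-wise intervals for a causal mask over packed documents.

    Assumed layout: for key column j, query rows in [start[j], end[j])
    are masked. Keys of one document are hidden from all later documents.
    """
    n = sum(doc_lens)
    start, end = [], []
    offset = 0
    for length in doc_lens:
        doc_end = offset + length
        start.extend([doc_end] * length)  # masking begins after this document
        end.extend([n] * length)          # and runs to the last query row
        offset = doc_end
    return start, end


def is_masked(row, col, start, end):
    """True if query `row` may not attend to key `col`."""
    return row < col or start[col] <= row < end[col]


def to_dense(start, end):
    """Expand to a dense N x N boolean mask (for checking only)."""
    n = len(start)
    return [[is_masked(r, c, start, end) for c in range(n)] for r in range(n)]


def mask_bytes(n):
    """Storage for a dense boolean mask vs. two int32 vectors of length n."""
    dense = n * n            # one byte per entry
    columnwise = 2 * n * 4   # start and end, 4 bytes each
    return dense, columnwise


start, end = causal_document_columns([3, 2])
visible = [[int(not m) for m in row] for row in to_dense(start, end)]
for row in visible:
    print(row)
print(mask_bytes(128 * 1024))
```

Under these assumptions, a dense boolean mask at 128K tokens occupies 16 GiB, while the two index vectors occupy 1 MiB, which is the gap between O(N^2) and O(N) storage that the abstract describes.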
Peer Reviews
Decision: ICLR 2025 Poster
The paper is well-written and easy to understand. The results section is elaborate with a wide range of benchmarks to demonstrate the advantages of the proposed methods. The appendix section and the analysis with synthetic data to corroborate the claims are very insightful. The compute and memory utilization advantages of FlashMask are well demonstrated. The proposed sparse representation scheme is novel and should be adopted wherever applicable for its memory efficiency and ability to support l
While the results section shows that FlashMask achieves higher computational efficiency, I’m not sure if it’s attributable to the proposed column-wise sparse representation. The computational efficiency of FlashMask comes from skipping computation on entirely masked blocks, as discussed in Section 4.3. However, this technique is also used in Block-Sparse FlashAttention and FlexAttention. The advantages of FlashMask over Block-Sparse FlashAttention and FlexAttention in terms of computational effi
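The block skipping this review refers to can be sketched as follows. This is a minimal sketch, assuming a column-wise layout in which key column j masks query rows in [start[j], end[j]) on top of causal masking, and assuming tiles are classified from the min and max of those bounds within each column block. The function name classify_tiles and the labels "skip", "full", and "partial" are hypothetical, and the classification is deliberately conservative: a tile it cannot prove uniform is treated as partial.

```python
def classify_tiles(start, end, block):
    """Label each (row block, column block) tile of the attention matrix.

    "skip":    every entry is masked, so the tile needs no computation
    "full":    no entry is masked, so no per-element mask is applied
    "partial": anything else; the mask is applied element by element
    """
    n = len(start)
    nblocks = (n + block - 1) // block
    grid = []
    for bi in range(nblocks):
        r0, r1 = bi * block, min((bi + 1) * block, n)
        row = []
        for bj in range(nblocks):
            c0, c1 = bj * block, min((bj + 1) * block, n)
            smin, smax = min(start[c0:c1]), max(start[c0:c1])
            emin, emax = min(end[c0:c1]), max(end[c0:c1])
            above_diagonal = r1 <= c0                    # causal mask covers tile
            interval_covers = smax <= r0 and emin >= r1  # every column masks all rows
            causal_free = r0 >= c1 - 1                   # no entry above the diagonal
            interval_free = smin >= r1 or emax <= r0     # no interval touches the tile
            if above_diagonal or interval_covers:
                row.append("skip")
            elif causal_free and interval_free:
                row.append("full")
            else:
                row.append("partial")
        grid.append(row)
    return grid


# Two packed documents of length 4 under causal masking (illustrative data).
start = [4, 4, 4, 4, 8, 8, 8, 8]
end = [8] * 8
grid = classify_tiles(start, end, block=2)
for row in grid:
    print(row)
skipped = sum(label == "skip" for row in grid for label in row)
print(skipped, "of", len(grid) ** 2, "tiles skipped")
```

In this toy case the classification only reads per-block summaries of the O(N) vectors, so deciding whether to skip a tile does not require materializing any part of a dense mask.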
1. The paper open-sources a rather general sparse self-attention representation framework, which could facilitate many research and production efforts in the field. 2. The implementation is practical, showing wall-clock speedup over FlashAttention-2.
1. It seems the implementation is limited to Paddle. It would be good to see whether it can be made more general so that the Torch/Megatron community can also leverage the framework. 2. Inference support is missing. It would make more sense to discuss how such a sparse mask can be put into actual inference/serving. 3. [1] was published earlier, and also provides a general sparse self-attention training & serving framework. It would be ideal to also cite [1]. [1] S2-Attention: Hardware-Aware Conte
Although this representation might be similar to COO/CSR/CSC, this is the first time I have ever seen these techniques used in attention, one of the most important operators in LLMs.
This paper lacks two baselines: 1. FlashInfer with dense masks; 2. FlashInfer with sparse masks (https://docs.flashinfer.ai/api/python/sparse.html). Although FlashInfer does not support the backward pass, I believe it is an important baseline for SOTA attention implementations. If this comparison is presented, I will raise my score.
Taxonomy
Topics: Advanced Steganography and Watermarking Techniques · Data Visualization and Analytics · Multimedia Communication and Technology
Methods: Direct Preference Optimization · Linear Layer · Multi-Head Attention · Layer Normalization · Dense Connections · Attention Is All You Need · Adam · Residual Connection · Position-Wise Feed-Forward Layer · Label Smoothing
