DISTFLASHATTN: Distributed Memory-efficient Attention for Long-context LLMs Training
Dacheng Li, Rulin Shao, Anze Xie, Eric P. Xing, Xuezhe Ma, Ion Stoica, Joseph E. Gonzalez, Hao Zhang

TL;DR
DISTFLASHATTN is a distributed, memory-efficient attention mechanism for training long-context LLMs. It handles longer sequences and trains faster than Ring Self-Attention, Megatron-LM with FlashAttention, Ring Attention, and DeepSpeed-Ulysses.
Contribution
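The paper presents DISTFLASHATTN, a distributed attention mechanism built on three techniques (token-level workload balancing, overlapping key-value communication, and rematerialization-aware gradient checkpointing) that improve memory efficiency and enable training LLMs on longer sequences.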
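When attention is distributed along the sequence, a worker holding a chunk of queries sees the keys and values in chunks, some of them arriving from other workers. The sketch below is a minimal pure-Python illustration, not the authors' kernel: it shows the generic FlashAttention-style online-softmax merge that lets per-chunk partial results combine exactly without materializing a full row of attention scores. The single-query, list-based setup and the names `chunk_partial`, `merge`, and `chunked_attention` are assumptions made for illustration.

```python
import math
from typing import List, Tuple

Vector = List[float]
# (unnormalized weighted sum of values, running max score, softmax normalizer)
Partial = Tuple[Vector, float, float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def full_attention(q: Vector, keys: List[Vector], values: List[Vector]) -> Vector:
    """Reference: softmax(q . k / sqrt(d)) over all keys at once, applied to values."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [dot(q, k) * scale for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[j] for w, v in zip(weights, values)) / total for j in range(dim)]


def chunk_partial(q: Vector, keys: List[Vector], values: List[Vector]) -> Partial:
    """Attention against one KV chunk, left unnormalized so it can be merged later."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [dot(q, k) * scale for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    dim = len(values[0])
    acc = [sum(w * v[j] for w, v in zip(weights, values)) for j in range(dim)]
    return acc, m, sum(weights)


def merge(a: Partial, b: Partial) -> Partial:
    """Combine two partials by rescaling both to the larger running max."""
    acc_a, m_a, l_a = a
    acc_b, m_b, l_b = b
    m = max(m_a, m_b)
    ca, cb = math.exp(m_a - m), math.exp(m_b - m)
    acc = [ca * x + cb * y for x, y in zip(acc_a, acc_b)]
    return acc, m, ca * l_a + cb * l_b


def chunked_attention(q: Vector, key_chunks: List[List[Vector]],
                      value_chunks: List[List[Vector]]) -> Vector:
    """Process KV chunks one at a time (e.g. as they arrive); state size is constant."""
    state = None
    for ks, vs in zip(key_chunks, value_chunks):
        part = chunk_partial(q, ks, vs)
        state = part if state is None else merge(state, part)
    acc, _, norm = state
    return [x / norm for x in acc]


if __name__ == "__main__":
    q = [0.5, -1.0]
    keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 2.0]]
    values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
    print(full_attention(q, keys, values))
    print(chunked_attention(q, [keys[:2], keys[2:]], [values[:2], values[2:]]))
```

The two printed vectors agree up to floating-point rounding. Only the running output, max, and normalizer are kept per query, so memory does not grow with the number of KV chunks processed.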
Findings
Achieves 8x longer sequences than Ring Self-Attention and 2 - 8x longer sequences than Megatron-LM with FlashAttention.
Provides 4.45 - 5.64x speedup over Ring Self-Attention.
Evaluated on Llama-7B and variants with sequence lengths from 32K up to 512K.
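One of the three techniques in DISTFLASHATTN is overlapping key-value communication with computation. A back-of-the-envelope timing model shows why overlap can help; it is a hypothetical sketch, not the authors' implementation, and it assumes that the first KV chunk is already local, that each later chunk is received from a peer, and that the next chunk is fetched while the current one is processed (double buffering), with fixed per-chunk compute and transfer times.

```python
from typing import List


def sequential_time(compute: List[float], comm: List[float]) -> float:
    """No overlap: each remote KV chunk is received first, then processed."""
    return sum(compute) + sum(comm)


def overlapped_time(compute: List[float], comm: List[float]) -> float:
    """Double buffering: the next KV chunk is received while the current one is processed.

    compute[i]: time to attend to KV chunk i (chunk 0 is assumed local).
    comm[i]:    time to receive KV chunk i + 1 from a peer.
    """
    if len(comm) != len(compute) - 1:
        raise ValueError("expected one transfer per remote chunk")
    clock = 0.0
    for i, c in enumerate(compute):
        next_fetch = comm[i] if i < len(comm) else 0.0
        # The next step can start only once this compute and the next fetch are both done.
        clock += max(c, next_fetch)
    return clock


if __name__ == "__main__":
    compute = [4.0, 4.0, 4.0, 4.0]
    for comm in ([3.0, 3.0, 3.0], [5.0, 5.0, 5.0]):
        print(comm, sequential_time(compute, comm), overlapped_time(compute, comm))
```

With four chunks of compute time 4, transfers of 3 are hidden entirely (25 versus 16 time units), while transfers of 5 make communication the bottleneck (31 versus 19). Overlap therefore pays off most when per-chunk compute is at least as long as a transfer.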
Abstract
FlashAttention (Dao, 2023) effectively reduces the quadratic peak memory usage to linear when training transformer-based large language models (LLMs) on a single GPU. In this paper, we introduce DISTFLASHATTN, a distributed memory-efficient attention mechanism optimized for long-context LLM training. We propose three key techniques: token-level workload balancing, overlapping key-value communication, and a rematerialization-aware gradient checkpointing algorithm. We evaluate DISTFLASHATTN on Llama-7B and variants with sequence lengths from 32K to 512K. DISTFLASHATTN achieves 8x longer sequences and 4.45 - 5.64x speedup compared to Ring Self-Attention, and 2 - 8x longer sequences and 1.24 - 2.01x speedup compared to Megatron-LM with FlashAttention. It achieves 1.67x and 1.26 - 1.88x speedup compared to the recent Ring Attention and DeepSpeed-Ulysses, respectively. Code is available at…
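The abstract lists token-level workload balancing among its techniques. With a causal mask and the sequence split into contiguous chunks across P workers, the worker holding query chunk i must attend to KV chunks 0 through i, so later workers have more work than earlier ones. The toy scheduler below illustrates why balancing helps; it is a hypothetical greedy scheme, not the paper's schedule, and it counts only (query chunk, KV chunk) units of work while ignoring the communication needed to send a moved unit to a helper and merge its result back. Under these assumptions the slowest worker's load drops from P units to ceil((P + 1) / 2).

```python
import math
from typing import Dict, List, Tuple

Pair = Tuple[int, int]  # (query chunk index, KV chunk index)


def causal_pairs(num_workers: int) -> List[Pair]:
    """Chunk pairs kept by a causal mask: query chunk q attends to KV chunks 0..q."""
    return [(q, kv) for q in range(num_workers) for kv in range(q + 1)]


def naive_assignment(num_workers: int) -> Dict[int, List[Pair]]:
    """Each worker computes every pair for the query chunk it owns."""
    work: Dict[int, List[Pair]] = {w: [] for w in range(num_workers)}
    for q, kv in causal_pairs(num_workers):
        work[q].append((q, kv))
    return work


def balanced_assignment(num_workers: int) -> Dict[int, List[Pair]]:
    """Hypothetical greedy rebalancing: move pairs from the busiest to the idlest worker."""
    work = naive_assignment(num_workers)
    while True:
        busiest = max(work, key=lambda w: len(work[w]))
        idlest = min(work, key=lambda w: len(work[w]))
        if len(work[busiest]) - len(work[idlest]) <= 1:
            return work
        # The idle worker computes this pair; its partial result would be sent
        # back to the owner and merged (communication is not modeled here).
        work[idlest].append(work[busiest].pop())


def makespan(work: Dict[int, List[Pair]]) -> int:
    """Load of the slowest worker, in units of one chunk-pair computation."""
    return max(len(pairs) for pairs in work.values())


if __name__ == "__main__":
    for p in (4, 8, 16):
        naive = makespan(naive_assignment(p))
        balanced = makespan(balanced_assignment(p))
        assert balanced == math.ceil((p + 1) / 2)
        print(f"P={p}: naive {naive}, balanced {balanced}")
```

The total work is P(P + 1) / 2 pairs in both cases; balancing only changes who computes them, which is why the ideal load per worker is about half of the naive worst case.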
Taxonomy
Topics: Topic Modeling · Advanced Neural Network Applications · Natural Language Processing Techniques
Methods: Gradient Checkpointing
