Is Flash Attention Stable?
Alicia Golden, Samuel Hsia, Fei Sun, Bilge Acun, Basil Hosmer, Yejin Lee, Zachary DeVito, Jeff Johnson, Gu-Yeon Wei, David Brooks, Carole-Jean Wu

TL;DR
This paper investigates the numerical stability of Flash Attention. It finds that Flash Attention introduces more numeric deviation than Baseline Attention, but that its impact on training stability is limited compared to low-precision training.
Contribution
It introduces a framework for analyzing numeric deviation effects and applies it to assess Flash Attention's stability in large-scale training.
Findings
Flash Attention exhibits roughly ten times more numeric deviation than Baseline Attention at BF16.
The numeric deviation introduced by Flash Attention is 2-5 times less significant than that of low-precision training.
The framework helps quantify the impact of numeric deviation on training stability.
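The first finding compares the same attention computation done two ways at BF16 against a higher-precision reference. The small NumPy experiment below illustrates that kind of comparison. It is a minimal sketch that rests on several assumptions not taken from the paper:

- NumPy has no bfloat16 type, so BF16 is simulated. After every operation, values are rounded to the top 16 bits of their float32 representation.
- `tiled_attention` is a simplified Flash-style loop (online softmax over key/value blocks), not the actual Flash Attention kernel. The real kernel keeps its running statistics in higher precision.
- Deviation is measured as the maximum absolute difference from an FP64 run on the same BF16-rounded inputs.

All function names, sizes, and the block size are hypothetical. The printed numbers are not expected to reproduce the paper's measurements or its roughly tenfold ratio.

```python
import numpy as np


def to_bf16(x):
    """Round values to the nearest bfloat16 (round-half-to-even), returned as float64.

    Assumption: NumPy has no bfloat16 dtype, so BF16 is simulated by rounding the
    float32 bit pattern to its top 16 bits. NaN/Inf are not handled specially.
    """
    bits = np.ascontiguousarray(x, dtype=np.float32).view(np.uint32)
    # Add 0x7FFF (plus 1 if the kept LSB is odd) so the truncation rounds to nearest even.
    bias = ((bits >> 16) & np.uint32(1)) + np.uint32(0x7FFF)
    rounded = (bits + bias) & np.uint32(0xFFFF0000)
    return rounded.view(np.float32).astype(np.float64)


def fp64(x):
    """No-op 'rounding': keep full float64 precision (used for the reference run)."""
    return np.asarray(x, dtype=np.float64)


def baseline_attention(q, k, v, rnd):
    """Standard attention: materialize the full score matrix, softmax, multiply by V.
    `rnd` is applied after every operation to emulate the working precision."""
    scale = 1.0 / np.sqrt(q.shape[-1])
    s = rnd(rnd(q @ k.T) * scale)
    p = rnd(np.exp(s - s.max(axis=-1, keepdims=True)))
    p = rnd(p / rnd(p.sum(axis=-1, keepdims=True)))
    return rnd(p @ v)


def tiled_attention(q, k, v, rnd, block=16):
    """Hypothetical Flash-style sketch: visit K/V in blocks, keeping a running max,
    running denominator, and unnormalized output (online softmax)."""
    scale = 1.0 / np.sqrt(q.shape[-1])
    m = np.full((q.shape[0], 1), -np.inf)       # running row-wise max of the scores
    l = np.zeros((q.shape[0], 1))               # running softmax denominator
    acc = np.zeros((q.shape[0], v.shape[-1]))   # running unnormalized output
    for start in range(0, k.shape[0], block):
        k_blk, v_blk = k[start:start + block], v[start:start + block]
        s = rnd(rnd(q @ k_blk.T) * scale)
        m_new = np.maximum(m, s.max(axis=-1, keepdims=True))
        corr = rnd(np.exp(m - m_new))           # rescales statistics from earlier blocks
        p = rnd(np.exp(s - m_new))
        l = rnd(l * corr + rnd(p.sum(axis=-1, keepdims=True)))
        acc = rnd(acc * corr + rnd(p @ v_blk))
        m = m_new
    return rnd(acc / l)


def max_deviation(out, ref):
    """Numeric deviation measured as the largest absolute element-wise difference."""
    return float(np.max(np.abs(out - ref)))


rng = np.random.default_rng(0)
# Quantize inputs to BF16 once, so only the attention computation itself differs.
q, k, v = (to_bf16(rng.standard_normal((64, 32))) for _ in range(3))
reference = baseline_attention(q, k, v, fp64)

dev_baseline = max_deviation(baseline_attention(q, k, v, to_bf16), reference)
dev_tiled = max_deviation(tiled_attention(q, k, v, to_bf16), reference)
print(f"baseline attention, simulated BF16 vs FP64: {dev_baseline:.2e}")
print(f"tiled attention,    simulated BF16 vs FP64: {dev_tiled:.2e}")
```

Running `tiled_attention` with `fp64` should match `reference` up to floating-point noise. That is a quick check that the rescaling logic is correct. Any gap between the two simulated-BF16 runs then comes only from where, and how often, intermediates are rounded.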
Abstract
Training large-scale machine learning models poses distinct system challenges, given both the size and complexity of today's workloads. Recently, many organizations training state-of-the-art Generative AI models have reported cases of instability during training, often taking the form of loss spikes. Numeric deviation has emerged as a potential cause of this training instability, although quantifying this is especially challenging given the costly nature of training runs. In this work, we develop a principled approach to understanding the effects of numeric deviation, and construct proxies to put observations into context when downstream effects are difficult to quantify. As a case study, we apply this framework to analyze the widely-adopted Flash Attention optimization. We find that Flash Attention sees roughly an order of magnitude more numeric deviation as compared to Baseline…
