Reducing the Cost of Dropout in Flash-Attention by Hiding RNG with GEMM
Haiyue Ma, Jian Liu, Ronny Krashinsky

TL;DR
This paper introduces a method to hide the latency of dropout's random number generation (RNG) in Flash-Attention by overlapping it with preceding GEMM operations, improving training speed for large language models on GPUs.
Contribution
It proposes overlapping RNG with preceding GEMM layers, backed by a detailed performance model, and reports notable speedups over existing methods.
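The intuition behind overlapping can be illustrated with a deliberately simplified cost model. This is not the paper's fine-grained analytical model, whose details are not given here. It assumes, hypothetically, that sequential execution adds the GEMM and RNG latencies, that ideal overlap costs only the longer of the two, and that an `interference` factor stands in for any mutual slowdown from shared bottlenecks. All function names and parameters are illustrative.

```python
def sequential_time(t_gemm: float, t_rng: float) -> float:
    """Baseline: RNG starts only after the GEMM finishes, so latencies add."""
    return t_gemm + t_rng


def overlapped_time(t_gemm: float, t_rng: float, interference: float = 0.0) -> float:
    """Idealized co-execution of GEMM and RNG (toy model, not the paper's).

    The longer kernel dominates. `interference` is a hypothetical fractional
    slowdown applied when the two kernels contend for a shared bottleneck;
    0.0 models fully complementary resource usage.
    """
    if interference < 0.0:
        raise ValueError("interference must be non-negative")
    return max(t_gemm, t_rng) * (1.0 + interference)


def overlap_speedup(t_gemm: float, t_rng: float, interference: float = 0.0) -> float:
    """Ratio of sequential to overlapped time (> 1.0 means overlap helps)."""
    return sequential_time(t_gemm, t_rng) / overlapped_time(t_gemm, t_rng, interference)


if __name__ == "__main__":
    # Illustrative latencies in arbitrary time units (not measured data).
    for t_gemm, t_rng, interf in [(1.0, 0.3, 0.0), (1.0, 1.0, 0.0), (1.0, 1.0, 0.25)]:
        print(t_gemm, t_rng, interf, round(overlap_speedup(t_gemm, t_rng, interf), 3))
```

Under this toy model the ideal speedup peaks at 2x when GEMM and RNG take equal time and shrinks toward 1x as either one dominates. Any shared-bottleneck interference eats into the gain, which mirrors the abstract's argument for pairing RNG with GEMM rather than with Attention.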
Findings
1.26x speedup for RNG-GEMM overlapping over sequential execution
1.22x speedup over the state-of-the-art fusion method on Llama3
Performance benefits generalize across architectures and models
Abstract
Dropout, a network operator, is likely when enabled to dramatically degrade the performance of Flash-Attention, which in turn increases the end-to-end training time of Large Language Models (LLMs). The main contributor to this degradation is the Random Number Generation (RNG) phase. The state-of-the-art optimization fuses RNG into the Flash-Attention kernel. However, although RNG and Attention do not compete for compute or memory resources, they are bound by the same low-level architectural bottlenecks, so fusion can hardly hide RNG latency within the Attention kernel. We propose overlapping RNG with preceding GEMM layers in the network to hide RNG latency and improve end-to-end performance. RNG and GEMM have distinct resource requirements and hardware bottlenecks, so they can run concurrently without compromising each other's performance. We propose a fine-grained analytical…
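Moving RNG out of the Attention kernel is possible because dropout's random bits do not depend on the attention inputs. The NumPy sketch here illustrates that decoupling under stated assumptions: it materializes a boolean keep-mask in memory between a hypothetical RNG stage and the attention stage, and checks that the result matches generating the mask inside the attention step. It models neither GPU scheduling nor how the paper actually produces, stores, or streams the random bits; all function names are illustrative.

```python
import numpy as np


def dropout_keep_mask(seed: int, shape: tuple, p: float) -> np.ndarray:
    """Boolean keep-mask for dropout with drop probability p.

    The bits come from NumPy's counter-based Philox generator and depend only
    on (seed, shape), never on activations, so the mask can be generated
    before the attention scores it will be applied to exist.
    """
    if not 0.0 <= p < 1.0:
        raise ValueError("p must be in [0, 1)")
    rng = np.random.Generator(np.random.Philox(seed))
    return rng.random(shape) >= p


def attention_with_dropout(q, k, v, keep: np.ndarray, p: float) -> np.ndarray:
    """Single-head softmax attention with inverted dropout on the probabilities."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
    probs = np.exp(scores)
    probs /= probs.sum(axis=-1, keepdims=True)
    probs = np.where(keep, probs / (1.0 - p), 0.0)  # drop, rescale survivors
    return probs @ v


def attention_fused_rng(q, k, v, seed: int, p: float) -> np.ndarray:
    """'Fused' flavor: the mask is generated as part of the attention step."""
    keep = dropout_keep_mask(seed, (q.shape[0], k.shape[0]), p)
    return attention_with_dropout(q, k, v, keep, p)


if __name__ == "__main__":
    data = np.random.default_rng(0)
    seq_len, head_dim, p, seed = 8, 4, 0.1, 1234
    q, k, v = (data.standard_normal((seq_len, head_dim)) for _ in range(3))

    # Stage 1 (decoupled): produce the mask early. In an overlapped scheme this
    # is the work that would run alongside the preceding GEMM layer.
    keep = dropout_keep_mask(seed, (seq_len, seq_len), p)
    # Stage 2: attention consumes the precomputed mask.
    out_decoupled = attention_with_dropout(q, k, v, keep, p)

    out_fused = attention_fused_rng(q, k, v, seed, p)
    print(np.array_equal(out_decoupled, out_fused))
```

The two outputs are bitwise identical because the same seed and shape always yield the same mask. In this sketch the mask holds one boolean per query-key pair, so keeping it between stages costs memory that grows with the square of the sequence length; how the paper handles that storage is not described in this summary.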
Taxonomy
Topics: Advanced Data Storage Technologies · Cloud Computing and Remote Desktop Technologies · Distributed and Parallel Computing Systems
Methods: Attention Is All You Need · Linear Layer · Softmax · Multi-Head Attention
