TL;DR
FlashSampling is an exact, memory-efficient sampling method fused into the LM-head matrix multiplication, speeding up large-vocabulary decoding on multiple GPUs without approximation.
Contribution
It introduces a novel kernel that fuses sampling into the LM-head matmul, enabling exact sampling with reduced memory traffic and improved scaling across GPUs.
Findings
Achieves kernel-level speedups on decode workloads across 4 GPUs.
Reduces time per output token by up to 10% in end-to-end experiments.
Replaces the all-gather of logits with streaming peer-to-peer writes, improving multi-GPU scalability.
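The arithmetic below gives a rough sense of why sending per-tile winners instead of logits can pay off. It compares the two payloads for a single rank. Everything here is an assumption for illustration: bf16 logits, an even vocabulary split, and one (fp32 value, int32 index) candidate per row per vocabulary tile. The function names and the batch, vocabulary, and tile sizes are hypothetical, and this is not the paper's communication schedule.

```python
import math


def logit_shard_bytes(batch, vocab, ranks, bytes_per_logit=2):
    """Bytes in one rank's logit shard (bf16 assumed): what an all-gather would move."""
    return batch * math.ceil(vocab / ranks) * bytes_per_logit


def candidate_bytes(batch, vocab, ranks, tile_size, bytes_per_candidate=8):
    """Bytes in one rank's per-tile winners: one (fp32 value, int32 index) pair per row per tile."""
    tiles = math.ceil(math.ceil(vocab / ranks) / tile_size)
    return batch * tiles * bytes_per_candidate


# Hypothetical decode step: batch 64, 128k vocabulary, 8 GPUs, 128-wide tiles.
shard = logit_shard_bytes(64, 128_000, 8)
cands = candidate_bytes(64, 128_000, 8, tile_size=128)
print(shard, cands, shard // cands)  # 2048000 64000 32
```

Under these assumptions the ratio reduces to tile_size × 2 / 8, so it depends on neither batch size nor GPU count. The sketch does not model overlapping communication with computation, which the method also relies on.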
Abstract
Sampling from a categorical distribution is mathematically simple, but in large-vocabulary decoding it often triggers extra memory traffic and extra kernels after the LM head. We present FlashSampling, an exact sampling primitive that fuses sampling into the LM-head matmul and never materializes the logits tensor in HBM. The method is simple: compute logits tile-by-tile on chip, add Gumbel noise, keep only one maximizer per row and per vocabulary tile, and finish with a small reduction over tiles. In tensor-parallel decoding, FlashSampling replaces the all-gather of logits with streaming peer-to-peer writes; this overlaps GPU-to-GPU communication with computation and HBM loads across up to 8 GPUs, with near-ideal scaling at large batch sizes. Our kernel is exact because argmax decomposes over partitions; grouped variants for online and tensor-parallel settings are exact by hierarchical…
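The tile-wise procedure described in the abstract is an instance of the Gumbel-max trick. argmax(logits + Gumbel noise) is an exact sample from softmax(logits), and because argmax decomposes over any partition of the vocabulary, per-tile (and per-rank) winners reduce to the global winner.

Below is a minimal NumPy sketch under stated assumptions, not the paper's kernel. It runs on the CPU and takes precomputed Gumbel noise so that its results can be checked against a full-logits reference; a real kernel would presumably generate the noise per tile on chip. The helper names `tile_argmax` and `tensor_parallel_sample`, the even vocabulary split, and the shapes are all hypothetical.

```python
import numpy as np


def tile_argmax(hidden, weight, noise, tile_size, col_offset=0):
    """Gumbel-max over vocabulary tiles without a full logits matrix.

    hidden: (B, D) decoder states; weight: (D, V_shard) LM-head columns;
    noise: (B, V_shard) Gumbel(0, 1) samples for the same columns.
    Returns, per row, the best perturbed logit and its global token id.
    """
    batch, vocab = hidden.shape[0], weight.shape[1]
    rows = np.arange(batch)
    vals, idxs = [], []
    for start in range(0, vocab, tile_size):
        stop = min(start + tile_size, vocab)
        # Only a (B, tile) slice of logits exists at any time.
        tile = hidden @ weight[:, start:stop] + noise[:, start:stop]
        local = np.argmax(tile, axis=1)
        # Keep one maximizer per row and per tile.
        vals.append(tile[rows, local])
        idxs.append(start + local + col_offset)
    vals = np.stack(vals, axis=1)  # (B, num_tiles)
    idxs = np.stack(idxs, axis=1)
    win = np.argmax(vals, axis=1)  # small reduction over tiles
    return vals[rows, win], idxs[rows, win]


def tensor_parallel_sample(hidden, weight, noise, num_ranks, tile_size):
    """Simulate ranks that each own a contiguous vocabulary shard (hypothetical split)."""
    bounds = np.linspace(0, weight.shape[1], num_ranks + 1).astype(int)
    cand_vals, cand_idxs = [], []
    for lo, hi in zip(bounds[:-1], bounds[1:]):
        v, i = tile_argmax(hidden, weight[:, lo:hi], noise[:, lo:hi], tile_size, col_offset=lo)
        cand_vals.append(v)
        cand_idxs.append(i)
    # Hierarchical step: reduce the per-rank winners to one token per row.
    vals = np.stack(cand_vals, axis=1)
    idxs = np.stack(cand_idxs, axis=1)
    win = np.argmax(vals, axis=1)
    return idxs[np.arange(hidden.shape[0]), win]


rng = np.random.default_rng(0)
B, D, V = 4, 16, 1000
hidden = rng.standard_normal((B, D))
weight = rng.standard_normal((D, V))
noise = rng.gumbel(size=(B, V))

reference = np.argmax(hidden @ weight + noise, axis=1)  # materializes (B, V) logits
_, tiled = tile_argmax(hidden, weight, noise, tile_size=128)
tp = tensor_parallel_sample(hidden, weight, noise, num_ranks=4, tile_size=128)
print(reference, tiled, tp)
```

Both variants pick the same token as a full-logits argmax, yet only a (batch, tile_size) slice of logits exists at a time. `np.argmax` breaks ties by first index at every level, so even ties resolve identically.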
