FlashRNN: I/O-Aware Optimization of Traditional RNNs on modern hardware
Korbinian Pöppel, Maximilian Beck, Sepp Hochreiter

TL;DR
FlashRNN introduces hardware-optimized kernels and a parallelization approach for traditional RNNs, significantly improving their speed and capacity on modern GPUs, enabling better sequence modeling with state-tracking.
Contribution
The paper presents a novel hardware-aware optimization framework and parallelization method for traditional RNNs, achieving substantial speed-ups and larger hidden sizes on GPUs.
Findings
50x speed-up over vanilla PyTorch RNNs
40x larger hidden sizes enabled
Open-source kernels and optimization library released
Abstract
While Transformers and other sequence-parallelizable neural network architectures seem like the current state of the art in sequence modeling, they specifically lack state-tracking capabilities. These are important for time-series tasks and logical reasoning. Traditional RNNs like LSTMs and GRUs, as well as modern variants like sLSTM do have these capabilities at the cost of strictly sequential processing. While this is often seen as a strong limitation, we show how fast these networks can get with our hardware-optimization FlashRNN in Triton and CUDA, optimizing kernels to the register level on modern GPUs. We extend traditional RNNs with a parallelization variant that processes multiple RNNs of smaller hidden state in parallel, similar to the head-wise processing in Transformers. To enable flexibility on different GPU variants, we introduce a new optimization framework for…
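The head-wise parallelization mentioned in the abstract can be read as splitting one hidden state into several smaller, independent recurrent states. The sketch below is a plain-Python reference under that reading, not the FlashRNN kernels: it uses a simple Elman-style `tanh` update, and all names (`rnn_step`, `headwise_step`, `block_diag`, `run`) and the toy weights are hypothetical. It checks that running the heads separately matches a single RNN whose recurrent matrix is block-diagonal, and it counts the recurrent parameters in both layouts.

```python
import math

def matvec(M, v):
    # Dense matrix-vector product on nested lists.
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def rnn_step(R, h, x):
    # Elman-style update h_new = tanh(R h + x); x is assumed to be
    # the already-projected input for this time step.
    return [math.tanh(r + xi) for r, xi in zip(matvec(R, h), x)]

def headwise_step(R_heads, h, x):
    # Each head updates only its own slice of the hidden state,
    # so the heads are independent and could run in parallel.
    d = len(R_heads[0])
    out = []
    for i, R in enumerate(R_heads):
        out.extend(rnn_step(R, h[i * d:(i + 1) * d], x[i * d:(i + 1) * d]))
    return out

def block_diag(R_heads):
    # Assemble the per-head matrices into one block-diagonal matrix.
    d = len(R_heads[0])
    n = d * len(R_heads)
    full = [[0.0] * n for _ in range(n)]
    for i, R in enumerate(R_heads):
        for r in range(d):
            for c in range(d):
                full[i * d + r][i * d + c] = R[r][c]
    return full

def run(step, R, xs, n):
    # Strictly sequential scan over time, starting from a zero state.
    h = [0.0] * n
    for x in xs:
        h = step(R, h, x)
    return h

# Toy setup (assumed values): 2 heads, head dimension 2, 3 time steps.
R_heads = [[[0.5, -0.2], [0.1, 0.3]], [[-0.4, 0.6], [0.2, 0.1]]]
xs = [[0.1, -0.3, 0.2, 0.4], [0.0, 0.5, -0.1, 0.2], [0.3, 0.3, -0.2, 0.1]]

h_heads = run(headwise_step, R_heads, xs, 4)
h_full = run(rnn_step, block_diag(R_heads), xs, 4)
max_diff = max(abs(a - b) for a, b in zip(h_heads, h_full))

# Recurrent parameters: heads * d_head^2 versus (heads * d_head)^2.
recurrent_params_heads = sum(len(R) * len(R[0]) for R in R_heads)
recurrent_params_full = len(h_full) ** 2
```

In this toy case the head-wise layout stores 8 recurrent weights instead of 16. The time loop stays sequential in both versions; only the work inside each step splits across heads.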
Peer Reviews
Decision: ICLR 2025 Poster
- The paper is easy to follow and well organized.
- Hardware-aware fine-tuning.
- Comprehensive benchmarking.
- The paper emphasizes register-level CUDA optimization. It would be great if the authors could show, through profiling, how close the implementation is to the roofline.
- I believe comparing to PyTorch is not fair, as PyTorch has overhead when launching its internal CUDA kernels.
- Several fusions can be achieved through compilation. How is this different? Also, not using compilation seems to be an unfair comparison.
- Why does FlashAttention2 matter for the RNN?
+ Hardware-aware optimization for accelerating RNN model training and inference is promising.
+ It claims that FlashRNN will be open-sourced, which helps further research in this direction.
- Certain critical procedures are somewhat unclear and would benefit from further clarification.
- The experimental section could be strengthened with additional details and improvements.
- The paper is overall well written and uses clear language and structure.
- It tackles an important problem: overcoming a fundamental limitation that matters for further scaling RNNs. This would enable state-intensive tasks where state continuity and stepwise dependency matter, which might be less naturally handled by Transformers.
- They compare to relevant backends (e.g., cuDNN).
- They mention the HASTE library and that they overcome its limitations. However, I think they don't benchmark directly against HASTE, which would be relevant.
Taxonomy
Topics: Neural Networks and Applications · Brain Tumor Detection and Classification
Methods: Lib
