FlashKAT: Understanding and Addressing Performance Bottlenecks in the Kolmogorov-Arnold Transformer
Matthew Raffel, Lizhong Chen

TL;DR
This paper identifies memory stalls as the main bottleneck in training the Kolmogorov-Arnold Transformer (KAT) and introduces FlashKAT, a restructured kernel that significantly accelerates training and improves gradient accuracy.
Contribution
The paper presents FlashKAT, a novel kernel restructuring approach that reduces memory stalls and atomic operations, substantially speeding up KAT training and enhancing gradient precision.
Findings
FlashKAT achieves up to 86.5x training speedup.
Memory stalls are the primary bottleneck in KAT training.
FlashKAT reduces rounding errors in gradient computation.
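The rounding-error finding can be illustrated with a simplified CPU-side analogy. This is not FlashKAT's actual kernel. Suppose many small float32 values are added one at a time into a single accumulator, which is effectively what a stream of atomic additions to one memory location does. That approach loses more precision than first reducing fixed-size blocks and then combining the partial sums. The function names and the block size of 256 are hypothetical choices made for illustration.

```python
import numpy as np

def sequential_sum_f32(values):
    """Add values one by one into a single float32 accumulator.

    Rough analogy (assumption) for many threads atomically adding
    into one memory location: every add is rounded to float32 at the
    accumulator's current, possibly large, magnitude.
    """
    acc = np.float32(0.0)
    for v in values:
        acc = np.float32(acc + np.float32(v))
    return acc

def blocked_sum_f32(values, block_size=256):
    """Reduce each block to a float32 partial sum, then add the partials.

    Fewer additions land on a large-magnitude accumulator, so less
    precision is lost to rounding (hypothetical block_size).
    """
    values = np.asarray(values, dtype=np.float32)
    partials = [
        sequential_sum_f32(values[i:i + block_size])
        for i in range(0, len(values), block_size)
    ]
    return sequential_sum_f32(partials)
```

Comparing both results against a float64 reference sum of the same float32 inputs shows the blocked reduction staying closer to the reference. It also performs far fewer additions into the shared accumulator.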
Abstract
The Kolmogorov-Arnold Network (KAN) has been gaining popularity as an alternative to the multilayer perceptron (MLP) due to its greater expressiveness and interpretability. Even so, KAN suffers from training instability and is orders of magnitude slower due to its increased computational cost, limiting its applicability to large-scale tasks. Recently, the Kolmogorov-Arnold Transformer (KAT) has been proposed, achieving FLOPs comparable to traditional Transformer models with MLPs by leveraging Group-Rational KAN (GR-KAN). Unfortunately, despite the comparable FLOPs, our testing shows that KAT remains 123x slower during training, indicating that there are other performance bottlenecks beyond FLOPs. In this paper, we conduct a series of experiments to understand the root cause of the slowdown in KAT. We uncover that the slowdown can be isolated to memory stalls, linked more specifically…
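The sketch below is for readers unfamiliar with GR-KAN. It shows the kind of group-rational activation GR-KAN uses, following our understanding of the "safe" Padé form from the original KAT work. Channels are split into groups, and each group shares one rational function F(x) = (a0 + a1·x + ... + am·x^m) / (1 + |b1·x + ... + bn·x^n|). The function name `group_rational`, the coefficient layout, and the NumPy implementation are illustrative assumptions, not the authors' GPU kernel.

```python
import numpy as np

def group_rational(x, a, b, num_groups):
    """Hypothetical group-rational activation (sketch, not the KAT kernel).

    x: array of shape (..., d)
    a: numerator coefficients, shape (num_groups, m + 1), for powers 0..m
    b: denominator coefficients, shape (num_groups, n), for powers 1..n
    Channels are split into num_groups contiguous groups; all channels
    in a group share the same coefficients.
    """
    d = x.shape[-1]
    assert d % num_groups == 0, "d must be divisible by num_groups"
    group_size = d // num_groups
    g = np.arange(d) // group_size           # group index of each channel
    A, B = a[g], b[g]                        # per-channel coefficient rows
    # Numerator P(x) = sum_i a_i * x^i, i = 0..m
    p = np.sum(A * x[..., None] ** np.arange(A.shape[1]), axis=-1)
    # Denominator Q(x) = 1 + |sum_j b_j * x^j|, j = 1..n (never below 1)
    q = 1.0 + np.abs(
        np.sum(B * x[..., None] ** np.arange(1, B.shape[1] + 1), axis=-1)
    )
    return p / q

# Usage: 8 channels in 4 groups, m = 5 and n = 4 (assumed degrees).
rng = np.random.default_rng(0)
x = rng.standard_normal((2, 8))
y = group_rational(x, rng.standard_normal((4, 6)),
                   rng.standard_normal((4, 4)), num_groups=4)
```

Every channel in a group reuses the same coefficients. As a result, the backward pass must sum gradient contributions from many elements into a handful of coefficient gradients. That reduction is the kind of step where heavy use of atomic additions can arise, and it is plausibly what the FlashKAT restructuring targets. Because the abstract above is truncated, this link is an inference.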
Taxonomy
Topics: Neural Networks and Applications
Methods: Attention Is All You Need · Linear Layer · Byte Pair Encoding · Label Smoothing · Dropout · Adam · Multi-Head Attention · Dense Connections · Layer Normalization
