Neural Attention Search
Difan Deng, Marius Lindauer

TL;DR
Neural Attention Search (NAtS) is a learnable framework that evaluates the importance of each token in a transformer and decides which tokens can be dropped, substantially reducing KV cache size during inference while maintaining model performance.
Contribution
NAtS is the first method to automatically optimize token retention strategies in transformers, reducing inference costs through a learnable attention mask.
Findings
Reduces KV cache size significantly during inference.
Maintains model performance after cache reduction.
Effective both when training a transformer from scratch and when fine-tuning an existing one.
Abstract
We present Neural Attention Search (NAtS), a framework that automatically evaluates the importance of each token within a sequence and determines whether the corresponding token can be dropped after several steps. This approach can efficiently reduce the KV cache sizes required by transformer-based models during inference and thus reduce inference costs. In this paper, we design a search space that contains three token types: (i) Global Tokens are preserved and queried by all following tokens. (ii) Local Tokens survive until the next global token appears. (iii) Sliding Window Tokens affect the inference of only a fixed number of subsequent tokens. Similar to the One-Shot Neural Architecture Search approach, this token-type information can be learned jointly with the architecture weights via a learnable attention mask. Experiments on both training a new transformer from…
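To make the three token types concrete, here is a minimal sketch of how a fixed assignment of token types could translate into a causal attention mask and a per-step KV cache size. It is an illustration under stated assumptions, not the authors' implementation: the function name `build_token_type_mask`, the integer type codes, and the `window` parameter are hypothetical, the token types are given by hand rather than learned, and the boundary conventions (the next global token may still attend to earlier local tokens; a sliding window token is visible to itself and the following `window - 1` positions) are guesses that the abstract does not settle.

```python
import numpy as np

# Hypothetical integer codes for the three token types named in the abstract.
GLOBAL, LOCAL, SLIDING = 0, 1, 2


def build_token_type_mask(token_types, window):
    """Return a boolean mask where mask[i, j] is True if query i may attend key j.

    Assumptions (not confirmed by the source):
      * a local token stays visible up to and including the next global token;
      * a sliding window token stays visible for `window` positions, itself included.
    """
    n = len(token_types)
    mask = np.zeros((n, n), dtype=bool)
    for j, kind in enumerate(token_types):
        if kind == GLOBAL:
            last = n - 1  # kept for every following token
        elif kind == LOCAL:
            later_globals = [g for g in range(j + 1, n) if token_types[g] == GLOBAL]
            last = later_globals[0] if later_globals else n - 1
        else:  # SLIDING
            last = min(j + window - 1, n - 1)
        mask[j:last + 1, j] = True  # key j is visible to queries j..last
    return mask


types = [GLOBAL, LOCAL, SLIDING, SLIDING, GLOBAL, LOCAL]
mask = build_token_type_mask(types, window=2)

# Each key is visible over one contiguous range of steps, so the number of
# True entries in row i equals the number of KV entries held at step i.
cache_sizes = mask.sum(axis=1).tolist()
full_cache_sizes = list(range(1, len(types) + 1))
print(cache_sizes, full_cache_sizes)
```

In this toy sequence the cache holds at most four entries instead of six, because the local token and the sliding window tokens are evicted once they can no longer be queried. In NAtS the token types are learned jointly with the model weights rather than fixed in advance as they are here.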
Taxonomy
Topics: Neural Networks and Applications
Methods: Softmax · Attention Is All You Need
