Adaptive Computation Depth via Learned Token Routing in Transformers
Ahmed Abdelmuniem Abdalla Mohammed

TL;DR
This paper introduces Token-Selective Attention (TSA), a learned token routing mechanism in transformers that adaptively skips layers for easier tokens, cutting token-layer operations by up to 23% at under 0.5% quality loss.
Contribution
TSA is a lightweight, end-to-end differentiable token routing method that learns difficulty-based depth skipping without explicit regularization.
Findings
TSA reduces token-layer operations by 14-23% on language modeling tasks.
TSA achieves 0.7% lower validation loss than early exit at similar efficiency.
Routing learned during training transfers effectively to inference, enabling real speedup.
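The savings figures above are stated in token-layer operations, so it may help to see how such a fraction could be counted. The sketch below is only an illustration: the function name `tlops_saved` and the boolean "executed" mask are assumptions made for this example, not names or data from the paper.

```python
def tlops_saved(executed):
    """Fraction of token-layer operations skipped.

    `executed[t][l]` is True when layer `l` was actually run for token `t`
    (a hypothetical routing record, not the paper's data format).
    """
    total = sum(len(row) for row in executed)
    run = sum(1 for row in executed for flag in row if flag)
    return 1.0 - run / total


# Toy example: 4 tokens, 5 layers, 4 of the 20 token-layer operations skipped.
mask = [
    [True, True, True, True, True],
    [True, True, True, False, False],
    [True, True, True, True, False],
    [True, True, True, True, False],
]
print(tlops_saved(mask))
```

Under this accounting, a dense model that runs every layer for every token saves nothing, and the reported 14-23% would correspond to that share of entries in the mask being skipped.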
Abstract
Standard transformer architectures apply the same number of layers to every token regardless of contextual difficulty. We present Token-Selective Attention (TSA), a learned per-token gate on residual updates between consecutive transformer blocks. Each gate is a lightweight two-layer multi-layer perceptron (MLP) that produces a continuous halting probability, making the mechanism end-to-end differentiable with 1.7% parameter overhead and no changes to the base architecture. Notably, TSA learns difficulty-proportional routing without any explicit depth pressure: even with no depth regularisation, the task-loss gradient alone drives the router to skip 20% of token-layer operations. On character-level language modeling, TSA saved 14-23% of token-layer operations (TLOps) across Tiny-Shakespeare and enwik8 at <0.5% quality loss. At matched efficiency, TSA achieved 0.7% lower…
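The gating idea in the abstract can be sketched for a single token as follows. This is a minimal illustration under stated assumptions, not the paper's implementation: the tanh hidden activation, the sigmoid output, the rule that the residual update is scaled by one minus the halting probability, and all function and parameter names are hypothetical choices made here.

```python
import math


def gate_probability(h, W1, b1, W2, b2):
    """Two-layer MLP mapping one token's hidden state to a value in (0, 1).

    Assumed form: sigmoid(W2 . tanh(W1 h + b1) + b2). The activations are
    illustrative; the abstract only states that the gate is a two-layer MLP.
    """
    hidden = [
        math.tanh(sum(w * x for w, x in zip(row, h)) + b)
        for row, b in zip(W1, b1)
    ]
    logit = sum(w * z for w, z in zip(W2, hidden)) + b2
    return 1.0 / (1.0 + math.exp(-logit))


def gated_residual(h, update, p_halt):
    """Apply a block's residual update, scaled by (1 - p_halt).

    Assumption: a halting probability near 1 suppresses the update, so the
    token effectively skips the block; near 0 it receives the full update.
    """
    return [x + (1.0 - p_halt) * u for x, u in zip(h, update)]


# Toy usage: 2-dimensional hidden state, 3 hidden gate units, zero weights.
h = [1.0, 2.0]
W1 = [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]
b1 = [0.0, 0.0, 0.0]
W2 = [0.0, 0.0, 0.0]
b2 = 0.0
p = gate_probability(h, W1, b1, W2, b2)   # zero logit, so p is 0.5
new_h = gated_residual(h, [2.0, -2.0], p)
```

Because the gate output is continuous rather than a hard 0/1 decision, every step in this sketch is differentiable, which is the property the abstract relies on when it says the task-loss gradient alone shapes the routing.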
