Faster Language Models with Better Multi-Token Prediction Using Tensor Decomposition
Artem Basharin, Andrei Chertkov, Ivan Oseledets

TL;DR
This paper introduces a tensor decomposition-based method for multi-token prediction in transformers, significantly improving inference speed and efficiency in text and code generation tasks while maintaining accuracy.
Contribution
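It generalizes the rank-1 canonical tensor decomposition implicit in multi-head token prediction to a rank-r model that predicts several tokens simultaneously, and interprets this model as a mixture of experts so that established mixture-of-experts training techniques can be reused.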
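The rank-r model can be written in the standard form of a canonical (CP) probability decomposition of the joint distribution of the next n tokens. The notation below is a sketch and may not match the paper's exactly: w_α are non-negative mixture weights and P_α^{(s)} is the s-th head of component α. Reading the weights as a gating distribution over r "experts" gives the mixture-of-experts interpretation. Computing them from the transformer's hidden state is an assumption about the parameterization.

```latex
P(x_{t+1}, \dots, x_{t+n} \mid x_{1:t}) \approx \sum_{\alpha=1}^{r} w_\alpha \prod_{s=1}^{n} P^{(s)}_{\alpha}(x_{t+s} \mid x_{1:t}),
\qquad w_\alpha \ge 0, \quad \sum_{\alpha=1}^{r} w_\alpha = 1
```

For r = 1 the sum has a single term with w_1 = 1. The joint then factorizes into n independent per-position distributions, which corresponds to the multiple-heads approach the paper starts from.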
Findings
Significant speedup in inference for text and code generation
Effective across various model sizes and training epochs
Low overhead for training and sampling
Abstract
We propose a new model for multi-token prediction in transformers, aiming to enhance sampling efficiency without compromising accuracy. Motivated by recent work that predicts the probabilities of subsequent tokens using multiple heads, we connect this approach to rank-1 canonical tensor decomposition. By generalizing it to a rank-r canonical probability decomposition, we develop an improved model that predicts multiple tokens simultaneously. This model can also be interpreted as a mixture of experts, allowing us to leverage successful techniques from that domain for efficient and robust training. Importantly, the overall overhead for training and sampling remains low. Our method demonstrates significant improvements in inference speed for both text and code generation tasks, proving particularly beneficial within the self-speculative decoding paradigm. It maintains its effectiveness…
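The abstract's central idea can be sketched in NumPy under several assumptions. The sketch uses a tiny fixed vocabulary, and random logits stand in for the transformer's heads and gating output. The helper names (`rank_r_joint_prob`, `full_joint`, `sample_tokens`, `accept_draft`) are hypothetical and do not come from the paper. The code does four things:

- evaluates the rank-r joint probability of an n-token block;
- samples a block by first choosing a mixture component (the "expert") and then drawing each token independently from that component's heads;
- checks that r = 1 reduces to independent heads;
- shows a simplified greedy verification step of the kind used in self-speculative decoding.

It is an illustration, not the authors' implementation.

```python
import numpy as np


def softmax(z, axis=-1):
    """Numerically stable softmax along `axis`."""
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def rank_r_joint_prob(weights, factors, tokens):
    """P(x_1, ..., x_n) = sum_a weights[a] * prod_s factors[a, s, x_s].

    weights: (r,) mixture weights, non-negative and summing to 1 (hypothetical gate output)
    factors: (r, n, V) token distribution of head s in component a (hypothetical)
    tokens:  length-n sequence of token ids
    """
    n = factors.shape[1]
    # factors[a, s, tokens[s]] for every component a and position s -> shape (r, n)
    picked = factors[:, np.arange(n), tokens]
    return float(np.sum(weights * np.prod(picked, axis=1)))


def full_joint(weights, factors):
    """Materialize the full V**n joint tensor (only feasible for tiny V and n)."""
    r, n, V = factors.shape
    joint = np.zeros((V,) * n)
    for a in range(r):
        term = factors[a, 0]
        for s in range(1, n):
            # Repeated outer products build the rank-1 tensor of component a
            term = np.multiply.outer(term, factors[a, s])
        joint += weights[a] * term
    return joint


def sample_tokens(weights, factors, rng):
    """Draw one n-token block: pick a component, then each token independently."""
    r, n, V = factors.shape
    a = rng.choice(r, p=weights)
    return [int(rng.choice(V, p=factors[a, s])) for s in range(n)]


def accept_draft(draft, target_greedy):
    """Simplified greedy verification for self-speculative decoding.

    target_greedy[i] is the base next-token head's greedy choice after the context
    plus draft[:i], so it has one more entry than `draft`. Keeps the agreeing prefix
    and appends the base model's token at the first mismatch (or the bonus token).
    """
    assert len(target_greedy) == len(draft) + 1
    k = 0
    while k < len(draft) and draft[k] == target_greedy[k]:
        k += 1
    return list(draft[:k]) + [target_greedy[k]]


rng = np.random.default_rng(0)
r, n, V = 3, 2, 5
weights = softmax(rng.normal(size=r))          # stand-in for a gating head
factors = softmax(rng.normal(size=(r, n, V)))  # stand-in for n heads per component
joint = full_joint(weights, factors)
print(f"joint sums to {joint.sum():.6f}")
print("P(x = [1, 3]) =", rank_r_joint_prob(weights, factors, [1, 3]))
print("sampled block:", sample_tokens(weights, factors, rng))

# Rank-1 special case: the joint is the outer product of the per-head marginals,
# i.e. the independent multi-head predictor.
w1, f1 = np.array([1.0]), factors[:1]
assert np.allclose(full_joint(w1, f1), np.outer(f1[0, 0], f1[0, 1]))
```

Sampling a block in this sketch costs one draw for the component plus n independent token draws. The full V**n tensor is built only to check the sketch and is never needed in practice.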
