Optimized Multi-Token Joint Decoding with Auxiliary Model for LLM Inference
Zongyue Qin, Ziniu Hu, Zifan He, Neha Prakriya, Jason Cong, Yizhou Sun

TL;DR
This paper introduces MTAD, a decoding framework that uses an auxiliary model to approximate multi-token joint decoding in large language models efficiently. It lowers perplexity and improves task performance while reducing inference time and energy use.
Contribution
The paper proposes MTAD, a new auxiliary-model-based framework that makes multi-token joint decoding in large language models both more efficient and more effective.
Findings
MTAD reduces perplexity by 21.2%.
MTAD achieves a 1.42x speed-up over conventional methods.
MTAD consumes 1.54x less energy than traditional speculative decoding.
Abstract
Large language models (LLMs) have achieved remarkable success across diverse tasks, yet their inference processes are hindered by substantial time and energy demands due to single-token generation at each decoding step. While previous methods such as speculative decoding mitigate these inefficiencies by producing multiple tokens per step, each token is still generated by its single-token distribution, thereby enhancing speed without improving effectiveness. In contrast, our work simultaneously enhances inference speed and improves the output effectiveness. We consider multi-token joint decoding (MTJD), which generates multiple tokens from their joint distribution at each iteration, theoretically reducing perplexity and enhancing task performance. However, MTJD suffers from the high cost of sampling from the joint distribution of multiple tokens. Inspired by speculative decoding, we…
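To make the contrast between single-token decoding and multi-token joint decoding concrete, here is a minimal sketch on a hypothetical toy model. The transition table `P`, the function names, and the use of argmax selection (rather than sampling) with exhaustive enumeration are all assumptions made for illustration. This is not the paper's MTAD algorithm, only a small demonstration of why choosing tokens from their joint distribution can yield lower perplexity, and why doing so naively is expensive.

```python
import itertools
import math

# Hypothetical toy language model (an assumption for illustration):
# P[prev][next] is the probability of `next` given the previous token `prev`.
P = {
    0: [0.10, 0.50, 0.40],
    1: [0.40, 0.30, 0.30],
    2: [0.05, 0.05, 0.90],
}

def sequence_prob(prev, tokens):
    """Joint probability of `tokens` under the toy model, starting after `prev`."""
    prob = 1.0
    for t in tokens:
        prob *= P[prev][t]
        prev = t
    return prob

def greedy_decode(prev, k):
    """Single-token decoding: pick the most likely token at each step."""
    tokens = []
    for _ in range(k):
        nxt = max(range(len(P[prev])), key=lambda t: P[prev][t])
        tokens.append(nxt)
        prev = nxt
    return tokens

def joint_decode(prev, k):
    """Joint decoding (argmax variant): pick the k-token block with the highest
    joint probability. Enumeration costs |V|**k evaluations, which is the
    expense that motivates approximating the joint distribution."""
    vocab = range(len(P[prev]))
    best = max(itertools.product(vocab, repeat=k),
               key=lambda cand: sequence_prob(prev, cand))
    return list(best)

def perplexity(prob, k):
    """Per-token perplexity of a k-token sequence with joint probability `prob`."""
    return math.exp(-math.log(prob) / k)

if __name__ == "__main__":
    start, k = 0, 2
    for name, decode in [("greedy", greedy_decode), ("joint", joint_decode)]:
        toks = decode(start, k)
        p = sequence_prob(start, toks)
        print(name, toks, round(p, 4), round(perplexity(p, k), 4))
```

In this toy setting, the greedy decoder commits to the locally best first token and ends up with a less probable two-token block than the joint decoder, so its perplexity is higher. The exhaustive search in `joint_decode` grows exponentially with the block length, which is why an efficient approximation is needed in practice.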
Taxonomy
Topics: Natural Language Processing Techniques · Topic Modeling · Speech Recognition and Synthesis
Methods: SPEED: Separable Pyramidal Pooling EncodEr-Decoder for Real-Time Monocular Depth Estimation on Low-Resource Settings · OPT
