Hardware-Aware Parallel Prompt Decoding for Memory-Efficient Acceleration of LLM Inference
Hao Mark Chen, Wayne Luk, Ka Fai Cedric Yiu, Rui Li, Konstantin Mishchenko, Stylianos I. Venieris, Hongxiang Fan

TL;DR
This paper introduces a hardware-aware parallel prompt decoding method for LLM inference that significantly reduces memory and training costs while improving speed and prediction accuracy, and that can be adapted to different GPU architectures.
Contribution
It proposes a novel parallel decoding technique inspired by human language generation, with a dynamic sparse tree optimization, enabling efficient training and inference on standard GPUs.
Findings
Achieves up to 2.49× speedup in inference.
Requires only 0.0002% trainable parameters, enabling quick training.
Maintains a minimal memory overhead of 0.0004%.
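To put the 0.0002% figure in perspective, the sketch below computes the trainable share when only the embeddings of a few appended prompt tokens are tuned and the base model stays frozen. The model size (7B parameters), hidden size (4096) and number of prompt tokens (3) are illustrative assumptions, not the paper's reported configuration, and the helper `trainable_fraction` is hypothetical.

```python
def trainable_fraction(num_prompt_tokens: int, hidden_size: int, total_params: float) -> float:
    """Return the trainable-parameter share, in percent, assuming that only
    one embedding vector of size `hidden_size` per prompt token is trained
    while every weight of the base model is frozen."""
    trainable = num_prompt_tokens * hidden_size
    return 100.0 * trainable / total_params


# Illustrative (assumed) configuration: 3 prompt tokens, hidden size 4096,
# 7B-parameter base model.
pct = trainable_fraction(num_prompt_tokens=3, hidden_size=4096, total_params=7e9)
print(f"trainable share: {pct:.6f}%")
```

Under these assumptions, 3 × 4096 = 12,288 trainable values out of 7 × 10^9 gives about 0.000176%, which is on the same order as the 0.0002% quoted above.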
Abstract
The auto-regressive decoding of Large Language Models (LLMs) imposes significant overheads on their hardware performance. While recent research has investigated various speculative decoding techniques for multi-token generation, these efforts have primarily focused on improving processing speed, such as throughput. Crucially, they often neglect other metrics essential for real-life deployments, such as memory consumption and training cost. To overcome these limitations, we propose a novel parallel prompt decoding (PPD) method that requires only 0.0002% trainable parameters, enabling efficient training on a single A100-40GB GPU in just 16 hours. Inspired by the human natural language generation process, PPD approximates outputs generated at future timesteps in parallel by using multiple prompt tokens. This approach partially recovers the missing conditional dependency information necessary for…
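The core mechanism in the abstract (prompt tokens standing in for future timesteps) can be sketched as simple bookkeeping around a single forward pass. This is a minimal illustration assuming greedy decoding, that the prompt tokens are appended after the real context, and that the output at the i-th prompt position (0-based) is read as the guess for the token i + 2 steps ahead. The token ids and both helper functions are hypothetical, and the model itself is replaced by a list of per-position predictions.

```python
from typing import List, Tuple

# Hypothetical ids reserved for the k trainable prompt tokens.
PROMPT_ID_BASE = 100_000


def build_ppd_input(context: List[int], k: int) -> List[int]:
    """Append k placeholder prompt tokens after the real context."""
    return context + [PROMPT_ID_BASE + i for i in range(k)]


def split_predictions(per_position: List[int], context_len: int, k: int) -> Tuple[int, List[int]]:
    """Split one forward pass's greedy predictions into the next token and k guesses.

    per_position[j] is the model's argmax output at input position j.
    The last real token (position context_len - 1) predicts the next token;
    prompt token i (position context_len + i) guesses the token i + 2 steps ahead.
    """
    next_token = per_position[context_len - 1]
    guesses = per_position[context_len:context_len + k]
    return next_token, guesses


context = [5, 6, 7]
model_input = build_ppd_input(context, k=2)
# Stand-in for the model: pretend these are the argmax outputs at each of the
# 5 input positions (a real run would obtain them from the LLM).
fake_outputs = [6, 7, 8, 9, 10]
next_token, guesses = split_predictions(fake_outputs, len(context), k=2)
print(model_input, next_token, guesses)
```

In this sketch the only extra work per step is the k appended input positions; no extra prediction heads or separate draft model are involved.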
Peer Reviews
Decision · Submitted to ICLR 2025
- They present a prompt-token-based method that adapts the model to perform parallel tree decoding. These tokens are appended to the end of the sequence and tuned to predict multiple tokens into the future, approximating tokens generated at future timesteps in order to recover missing conditional dependency information.
- Their hardware-aware sparse algorithm dynamically allocates more or fewer tokens to particular branches depending on their probability, and also incorporates hardwa…
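The probability-driven allocation described in this review resembles best-first construction of a sparse candidate tree under a fixed node budget. The sketch below is a generic version of that idea, not the paper's algorithm: it assumes the per-depth candidate probabilities are independent, scores a path by the product of its probabilities, and treats the node budget as the knob a hardware-aware scheme would tune per GPU. `build_sparse_tree` and its input format are hypothetical.

```python
import heapq
from typing import List, Tuple


def build_sparse_tree(topk_probs: List[List[float]], budget: int) -> List[Tuple[int, ...]]:
    """Best-first selection of up to `budget` tree nodes.

    topk_probs[d][j] is the (assumed independent) probability that the j-th
    ranked candidate is correct at depth d. A node is a tuple of candidate
    ranks, one per depth; its score is the product of probabilities along
    the path. Children are only pushed after their parent is selected, so
    the selected set is always prefix-closed (a valid tree).
    """
    # Max-heap via negated scores; start with all first-level candidates.
    heap = [(-p, (j,)) for j, p in enumerate(topk_probs[0])]
    heapq.heapify(heap)
    selected: List[Tuple[int, ...]] = []
    while heap and len(selected) < budget:
        neg_score, node = heapq.heappop(heap)
        selected.append(node)
        depth = len(node)
        if depth < len(topk_probs):
            for j, p in enumerate(topk_probs[depth]):
                # neg_score * p keeps the score negated: -(score * p).
                heapq.heappush(heap, (neg_score * p, node + (j,)))
    return selected


print(build_sparse_tree([[0.7, 0.2], [0.6, 0.3]], budget=3))
```

With depth-wise top-2 probabilities [[0.7, 0.2], [0.6, 0.3]] and a budget of 3, the tree keeps the strongest first-level candidate and both of its children (path scores 0.42 and 0.21), while the weak second first-level candidate (0.2) is dropped.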
- The inference-time speedups relative to prior decoding approaches such as Medusa are relatively minor.
- Their approach requires fine-tuning, which may make it harder to adapt to different end use cases depending on training availability (relative to speculative decoding methods).
- The inference-time memory savings of their approach relative to Medusa are smaller for larger models.
Some secondary contributions deserve positive mention, namely the hardware-aware optimisation. The analysis of accuracy across token positions helps identify the method's performance drivers.
The paper's claim to novelty may be challenged, since the main idea of Parallel Prompt Decoding practically overlaps with that of the BiTA paper (https://arxiv.org/html/2401.12522v2), published early this year. BiTA uses the same core idea of learnable tokens fed to the transformer model to generate a speculative continuation. Below is a quote from the BiTA paper's method description: "Thanks to the transformer architectures of LLMs, we leverage multiple learnable placeholder tokens known as mask tokens, empoweri…
1. PPD effectively uses pre-trained prompt tokens as placeholders, allowing the LLM to generate multiple tokens in parallel and significantly boosting throughput.
2. By using prompt tokens instead of modifying the model, PPD achieves a state-of-the-art acceptance rate with minimal training overhead.
3. By merging the "Verify" and "Next Guess" steps into a single model execution, PPD improves processing speed and maximizes inference efficiency.
4. PPD incorporates both online and offline decoding tre…
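Strength 3 refers to folding verification of the previous guesses and generation of new guesses into one model call. The verification half is commonly done with a greedy acceptance rule like the one sketched here; this is an assumption about how PPD verifies (the excerpt does not specify it), and `accept_greedy` is a hypothetical helper. In a merged step, the same forward pass would also carry the prompt tokens that produce the next guesses, which this sketch omits.

```python
from typing import List


def accept_greedy(guesses: List[int], verified: List[int]) -> List[int]:
    """Accept the longest prefix of `guesses` that the base model agrees with.

    verified[i] is the base model's greedy token at the position where
    guesses[i] was placed, conditioned on the context plus guesses[:i]
    (all computed in one forward pass). It has len(guesses) + 1 entries;
    the last one is the model's token after all guesses.
    """
    if len(verified) != len(guesses) + 1:
        raise ValueError("verified must have exactly one more entry than guesses")
    accepted: List[int] = []
    for g, v in zip(guesses, verified):
        if g != v:
            break
        accepted.append(g)
    # The model's own prediction at the first mismatch (or after the last
    # guess, if all matched) is correct by construction, so emit it too.
    accepted.append(verified[len(accepted)])
    return accepted


print(accept_greedy([8, 9, 10], [8, 9, 4, 7]))
```

With guesses [8, 9, 10] and base-model predictions [8, 9, 4, 7], the first two guesses are accepted and the model's own token 4 replaces the rejected third guess, so one step emits three tokens.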
1. ***Modest Improvement over SOTA***: While PPD achieves a notable 2.49x speedup over autoregressive decoding, its 1.07x advantage over Medusa is less impressive. Given that training Medusa is a one-time effort, the slight difference in training time (1.24 vs. 0.52 hours) may not justify the switch for all use cases. Considering that Medusa's guessing happens only in the last stage, the **LMHeads**, PPD introduces more decoding workload, as each guess is generated through the entire model. This approach…
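The reviewer's two figures can be related with simple arithmetic, assuming both speedups were measured in the same setting: 2.49x over autoregressive decoding and a 1.07x edge over Medusa imply that Medusa itself runs at roughly 2.49 / 1.07 ≈ 2.33x. The second helper uses a simplified first-order model (an assumption, not the paper's analysis) to show the trade-off the reviewer raises: if each guess passes through the whole model, a step costs more, and the extra accepted tokens must outweigh that cost. Both function names are hypothetical.

```python
def implied_baseline_speedup(speedup_vs_ar: float, speedup_vs_baseline: float) -> float:
    """Speedup of the baseline over autoregressive (AR) decoding implied by
    the method's speedup over AR and its speedup over the baseline."""
    return speedup_vs_ar / speedup_vs_baseline


def step_speedup(tokens_per_step: float, step_cost_ratio: float) -> float:
    """First-order wall-clock speedup over AR decoding.

    tokens_per_step: average tokens emitted per decoding step (1.0 for AR).
    step_cost_ratio: latency of one step relative to one AR step.
    Ignores effects such as prefill time and memory-bandwidth changes.
    """
    return tokens_per_step / step_cost_ratio


print(round(implied_baseline_speedup(2.49, 1.07), 2))
# Hypothetical numbers: 3 tokens per step at 1.2x the cost of an AR step.
print(step_speedup(3.0, 1.2))
```

In this simplified model, a method that accepts more tokens per step can still lose its edge if its per-step cost grows proportionally, which is the crux of the reviewer's comparison with Medusa.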
