Multi-Head Low-Rank Attention
Songtao Liu, Hongwu Peng, Zhiwei Zhang, Zhengyu Chen, Yue Guo

TL;DR
This paper introduces Multi-Head Low-Rank Attention (MLRA), a novel method that improves large language model decoding efficiency by enabling partitionable latent states, reducing memory traffic, and achieving faster inference without sacrificing performance.
Contribution
MLRA provides a partitionable latent state mechanism for efficient distributed decoding, overcoming sharding bottlenecks of prior methods like MLA.
Findings
Achieves state-of-the-art perplexity and downstream task performance.
Delivers a 2.8× decoding speedup over MLA.
Reduces memory traffic during distributed decoding.
Abstract
Long-context inference in large language models is bottlenecked by Key-Value (KV) cache loading during the decoding stage, where the sequential nature of generation requires repeatedly transferring the KV cache from off-chip High-Bandwidth Memory (HBM) to on-chip Static Random-Access Memory (SRAM) at each step. While Multi-Head Latent Attention (MLA) significantly reduces the total KV cache size, it suffers from a sharding bottleneck during distributed decoding via Tensor Parallelism (TP). Since its single latent head cannot be partitioned, each device is forced to redundantly load the complete KV cache for every token, consuming excessive memory traffic and diminishing TP benefits like weight sharding. In this work, we propose Multi-Head Low-Rank Attention (MLRA), which enables partitionable latent states for efficient 4-way TP decoding. Extensive experiments show that MLRA achieves…
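To make the memory-traffic argument concrete, the following is a rough back-of-envelope sketch, not the authors' cost model. It estimates how many bytes of KV cache one device must read from HBM to decode a single token. The per-device figures of 1.5 dₕ (MLRA) and 2 dₕ (GQA) under 4-way TP are the ones the reviews below cite from Table 1. The 4.5 dₕ figure for MLA is an assumption based on a DeepSeek-V2-style configuration (latent dimension 512 plus RoPE dimension 64, with dₕ = 128), which every device loads in full because the latent head cannot be sharded. The head dimension, context length, layer count, and element size are hypothetical values chosen only for illustration.

```python
# Back-of-envelope estimate of per-device KV-cache reads per decoding step.
# Model sizes here are illustrative assumptions, not values from the paper.

def kv_bytes_per_step(units_of_dh, d_h=128, seq_len=131072,
                      n_layers=32, bytes_per_elem=2):
    """Bytes of KV cache one device reads from HBM to decode one token.

    units_of_dh: cached elements per token per layer on one device,
                 expressed as a multiple of the head dimension d_h.
    """
    elems_per_token = int(units_of_dh * d_h)
    return elems_per_token * seq_len * n_layers * bytes_per_elem

# Per-device cache width under 4-way TP, in units of d_h.
PER_DEVICE_UNITS = {
    "MLRA": 1.5,  # cited in the reviews (Table 1)
    "GQA": 2.0,   # cited in the reviews (Table 1)
    "MLA": 4.5,   # assumption: DeepSeek-V2-style latent (512 + 64) / 128,
                  # loaded in full on every device
}

def report():
    """Return {method: GiB read per device per decoding step}."""
    return {name: kv_bytes_per_step(units) / 2**30
            for name, units in PER_DEVICE_UNITS.items()}

if __name__ == "__main__":
    for name, gib in report().items():
        print(f"{name}: {gib:.2f} GiB per device per step")
```

Under these assumed settings the sketch gives 1.5 GiB per step for MLRA, 2.0 GiB for GQA, and 4.5 GiB for MLA, so an unsharded MLA latent would cost each device three times the reads of MLRA. This ratio concerns cache traffic only and should not be read as a prediction of the reported end-to-end speedup.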
Peer Reviews
Decision: ICLR 2026 Poster
1. The paper does a good job of outlining how different mechanisms behave when sharded (Table 1), which shows MLRA achieving 1.5 dₕ per device with 4‑way TP, below GQA’s 2 dₕ target and lower than MLA/GLA under the same conditions.
2. The authors formalize translation equivariance for RoPE (§2.2–§2.3; Theorem 1), explaining why post‑RoPE projections typically break it and hence using partial RoPE so MLRA maintains a (semi) equivariance property (§3.1–§3.3).
3. Efficiency: Long‑context latency
1. Scale: The paper's primary motivation is to enable efficient inference for production-scale models, but the experiments only extend to 2.9B parameters. This leaves a significant gap, since it is uncertain how the architectural complexity introduced by the dual-path design will scale. At 70B scale with 80 layers, it is unclear whether the base path still maintains its quality advantages or whether the low-rank path will dominate. The paper shows that naively sharding MLA hurts quality (Figure 2, 354M model
- The problem is well-motivated, since MLA's latent vector cannot be sharded across TP devices, leading to sub-optimal speed when multiple GPUs are available.
- MLRA is designed around the constraints of tensor parallelism. The integration with system-level optimizations is a major strength.
- Extensive experiments demonstrate MLRA's superior performance, achieving higher efficiency than SOTA attention implementations like GQA and MLA.
- Long-context tasks: KV cache optimization becomes more important in long-context scenarios. Adding some long-text tasks (such as tasks from LongBench/RULER) would better demonstrate the effectiveness of MLRA.
- Model architecture: The experiments are based on the LLaMA-3 architecture. MLRA should be validated on other architectures (e.g., Qwen) to prove generality. Considering the pre-training setting, it's understandable if the authors cannot complete this experiment during the rebuttal peri
- The paper is well-written and well-structured in explaining the method (the detailed specification of dimensions is quite helpful).
- The motivation is very clear, and the proposed algorithm is well aligned with it.
- The authors compare various attention mechanisms in large-scale settings.
- What do you think fundamentally improves model quality in MLRA, and why?
- Could you elaborate on the throughput results of GQA? Why does its throughput keep dropping as sequences get longer (is it due to the custom Triton kernel, or to using only TP)?
- Some studies on the effects of different DP and TP degrees would be valuable for comparing these attention mechanisms.
- In the Figure 2 experiment, why does the Shard approach have a higher training loss?
- It'd be good to add explanation for GLA t
Taxonomy
Topics: Advanced Neural Network Applications · Parallel Computing and Optimization Techniques · Big Data and Digital Economy
