From Interpretability to Performance: Optimizing Retrieval Heads for Long-Context Language Models
Youmi Ma, Naoaki Okazaki

TL;DR
This paper introduces RetMask, a training method that builds training signals by contrasting a long-context language model's normal outputs with the outputs it produces when its retrieval heads are masked. The method improves long-context performance while preserving performance on general tasks.
Contribution
The work demonstrates that mechanistic insights into retrieval heads can be used to enhance long-context performance through a novel contrastive training approach.
Findings
RetMask improves Llama-3.1 by 2.28 points on HELMET at 128K.
It yields gains of +70% on generation with citation and +32% on passage re-ranking.
Models with sparser retrieval score distributions benefit more from RetMask.
Abstract
Advances in mechanistic interpretability have identified special attention heads, known as retrieval heads, that are responsible for retrieving information from the context. However, the role of these retrieval heads in improving model performance remains unexplored. This work investigates whether retrieval heads can be leveraged to enhance the long-context capabilities of LLMs. Specifically, we propose RetMask, a method that generates training signals by contrasting normal model outputs with those from an ablated variant in which the retrieval heads are masked. This mechanism-based approach achieves substantial improvements: +2.28 points on HELMET at 128K for Llama-3.1, with +70% gains on generation with citation and +32% on passage re-ranking, while preserving performance on general tasks. Experiments across four models in three families demonstrate that RetMask consistently improves…
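The contrast described in the abstract can be illustrated with a toy example. The sketch below is not the authors' implementation: it assumes a single randomly initialized causal attention layer, a hypothetical set of "retrieval head" indices, and a per-token score defined as the difference in target log-probability between the full and the head-masked forward pass. How RetMask turns such a contrast into a training signal is not specified in the text above, so all names here (`toy_attention_logits`, `contrast_scores`, `retrieval_heads`) are illustrative only.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def log_softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))

def toy_attention_logits(x, params, head_mask):
    """One causal multi-head attention layer projected to vocabulary logits.

    head_mask has shape (n_heads,): 1.0 keeps a head, 0.0 ablates it.
    """
    seq_len, d_model = x.shape
    n_heads = head_mask.shape[0]
    d_head = d_model // n_heads

    def split(w):
        # (T, D) -> (H, T, d_head)
        return (x @ w).reshape(seq_len, n_heads, d_head).transpose(1, 0, 2)

    q, k, v = split(params["wq"]), split(params["wk"]), split(params["wv"])
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(d_head)        # (H, T, T)
    future = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
    scores = np.where(future, -np.inf, scores)                  # causal mask
    heads = softmax(scores, axis=-1) @ v                        # (H, T, d_head)
    heads = heads * head_mask[:, None, None]                    # zero out ablated heads
    merged = heads.transpose(1, 0, 2).reshape(seq_len, d_model)
    return merged @ params["wo"]                                # (T, vocab)

def contrast_scores(x, targets, params, n_heads, retrieval_heads):
    """Per-token log p_full(target) - log p_masked(target).

    Assumption: this difference is one plausible way to quantify how much
    a target token depends on the masked heads.
    """
    full_mask = np.ones(n_heads)
    ablated_mask = full_mask.copy()
    ablated_mask[np.asarray(retrieval_heads, dtype=int)] = 0.0

    positions = np.arange(x.shape[0])
    lp_full = log_softmax(toy_attention_logits(x, params, full_mask))[positions, targets]
    lp_masked = log_softmax(toy_attention_logits(x, params, ablated_mask))[positions, targets]
    return lp_full - lp_masked

# Toy setup: random weights and inputs, hypothetical retrieval heads 1 and 3.
rng = np.random.default_rng(0)
T, D, H, V = 6, 8, 4, 10
params = {name: rng.normal(size=(D, D)) for name in ("wq", "wk", "wv")}
params["wo"] = rng.normal(size=(D, V))
x = rng.normal(size=(T, D))
targets = rng.integers(0, V, size=T)

scores = contrast_scores(x, targets, params, n_heads=H, retrieval_heads=[1, 3])
no_ablation = contrast_scores(x, targets, params, n_heads=H, retrieval_heads=[])
print(scores)
```

With no heads masked, the two forward passes are identical and every score is zero. A large positive score marks a token whose prediction depended on the masked heads in this toy model.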
