How Transformers Learn to Plan via Multi-Token Prediction
Jianhao Huang, Zhanpeng Zhou, Renqiu Xia, Baharan Mirzasoleiman, Weijie Su, Wei Huang

TL;DR
This paper investigates how multi-token prediction (MTP) improves reasoning in Transformers, demonstrating empirical performance gains and providing a theoretical analysis of its underlying reverse reasoning mechanism.
Contribution
The paper introduces a theoretical framework explaining how MTP induces a reverse reasoning process in Transformers, and supports it with empirical results on reasoning benchmarks.
Findings
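MTP consistently outperforms NTP on synthetic graph path-finding tasks and on more realistic reasoning benchmarks such as Countdown and boolean satisfiability (SAT) problems.
In a simplified two-layer Transformer trained on a star graph task, MTP provably induces a two-stage reverse reasoning process: the model first attends to the end node, then reconstructs the path by tracing intermediate nodes backward.
A gradient decoupling property of MTP drives this behavior and leads to more interpretable reasoning circuits.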
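To make the NTP/MTP distinction concrete, here is a minimal sketch of how the training targets differ. The helper `build_targets`, the window size `k`, and the toy token sequence are illustrative assumptions; this summary does not state the paper's exact MTP formulation (number of prediction heads, loss weighting).

```python
def build_targets(tokens, k):
    """For each position t, collect the next k tokens (fewer near the end).

    k=1 corresponds to next-token prediction (NTP); k>1 is a simple
    stand-in for multi-token prediction (MTP). Hypothetical helper,
    not the paper's implementation.
    """
    return [tokens[t + 1 : t + 1 + k] for t in range(len(tokens) - 1)]


# Toy path from a start node to a goal node.
path = ["s", "a", "b", "goal"]

ntp_targets = build_targets(path, k=1)  # one future token per position
mtp_targets = build_targets(path, k=3)  # up to three future tokens per position
```

Under NTP, the first position is supervised only on its immediate successor. Under MTP with `k=3`, the same position is also supervised on the goal token, which is one way to see how the objective can expose global structure early in the sequence.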
Abstract
While next-token prediction (NTP) has been the standard objective for training language models, it often struggles to capture global structure in reasoning tasks. Multi-token prediction (MTP) has recently emerged as a promising alternative, yet its underlying mechanisms remain poorly understood. In this paper, we study how MTP facilitates reasoning, with a focus on planning. Empirically, we show that MTP consistently outperforms NTP on both synthetic graph path-finding tasks and more realistic reasoning benchmarks, such as Countdown and boolean satisfiability problems. Theoretically, we analyze a simplified two-layer Transformer on a star graph task. We prove that MTP induces a two-stage reverse reasoning process: the model first attends to the end node and then reconstructs the path by tracing intermediate nodes backward. This behavior arises from a gradient decoupling property of MTP,…
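The star graph task and the backward-tracing idea can be illustrated with a small sketch. It assumes a star graph is a center node with several disjoint arms and that the task is to output the path from the center to a given leaf; the function names, node labeling, and edge-list format are hypothetical and not taken from the paper. The code shows only why backward tracing is unambiguous on such a graph, not the Transformer mechanism itself.

```python
import random


def make_star_graph(num_arms, arm_len, seed=0):
    """Build a star graph: center node 0 with `num_arms` arms of `arm_len` nodes.

    Returns (center, shuffled directed edge list, list of leaf nodes).
    """
    rng = random.Random(seed)
    labels = list(range(1, num_arms * arm_len + 1))
    rng.shuffle(labels)  # random labels so arm membership is not obvious
    center = 0
    edges, leaves = [], []
    it = iter(labels)
    for _ in range(num_arms):
        prev = center
        for _ in range(arm_len):
            node = next(it)
            edges.append((prev, node))  # edge points away from the center
            prev = node
        leaves.append(prev)
    rng.shuffle(edges)  # edge order carries no information about the path
    return center, edges, leaves


def trace_back(edges, center, goal):
    """Recover the center-to-goal path by following unique parents backward."""
    parent = {child: par for par, child in edges}
    path = [goal]
    while path[-1] != center:
        path.append(parent[path[-1]])
    return path[::-1]


center, edges, leaves = make_star_graph(num_arms=3, arm_len=4)
path = trace_back(edges, center, leaves[0])
```

Going forward from the center, the first step requires choosing among all arms. Going backward from the goal, every node has exactly one parent, so the path is determined step by step. This mirrors the two-stage description above: locate the end node first, then trace intermediate nodes backward.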
