The Belief State Transformer
Edward S. Hu, Kwangjun Ahn, Qinghua Liu, Haoran Xu, Manan Tomar, Ada Langford, Jayden Teoh, Bryon Xu, David Yan, Dinesh Jayaraman, Alex Lamb, John Langford

TL;DR
The paper introduces the Belief State Transformer, a model that predicts tokens from both a prefix and a suffix. It improves performance on goal-conditioned text tasks and handles problems that forward-only transformers struggle with.
Contribution
It presents a new transformer architecture that learns a belief state for better handling of prefix-suffix prediction tasks across domains.
Findings
Outperforms Fill-in-the-Middle in story writing tasks
Achieves better goal-conditioned decoding and inference
Provides high-quality representations on small-scale problems
Abstract
We introduce the "Belief State Transformer", a next-token predictor that takes both a prefix and suffix as inputs, with a novel objective of predicting both the next token for the prefix and the previous token for the suffix. The Belief State Transformer effectively learns to solve challenging problems that conventional forward-only transformers struggle with, in a domain-independent fashion. Key to this success is learning a compact belief state that captures all relevant information necessary for accurate predictions. Empirical ablations show that each component of the model is essential in difficult scenarios where standard Transformers fall short. For the task of story writing with known prefixes and suffixes, our approach outperforms the Fill-in-the-Middle method for reaching known goals and demonstrates improved performance even when the goals are unknown. Altogether, the Belief…
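The objective described above can be sketched as a data-construction step. The minimal Python sketch below assumes that every split of a sequence into a prefix, a non-empty gap, and a (possibly empty) suffix is used. For each split it builds one training tuple: the prefix is supervised with the token that follows it, and the suffix with the token that precedes it. The function name `bst_training_tuples` and the exact indexing convention are illustrative assumptions, not the authors' code.

```python
from typing import List, Sequence, Tuple

# One training example: (prefix, suffix, next-token target, previous-token target).
Example = Tuple[Tuple[str, ...], Tuple[str, ...], str, str]


def bst_training_tuples(tokens: Sequence[str]) -> List[Example]:
    """Enumerate every prefix/suffix split of `tokens` with at least one token between them.

    Assumed convention (hypothetical): prefix = tokens[:i], suffix = tokens[j:], i < j.
    The prefix is supervised with the token that follows it (tokens[i]) and the
    suffix with the token that precedes it (tokens[j - 1]).
    """
    n = len(tokens)
    examples: List[Example] = []
    for i in range(n):                  # prefix length 0..n-1 (empty prefix allowed)
        for j in range(i + 1, n + 1):   # suffix start i+1..n (empty suffix allowed)
            examples.append((
                tuple(tokens[:i]),
                tuple(tokens[j:]),
                tokens[i],       # next-token target for the prefix
                tokens[j - 1],   # previous-token target for the suffix
            ))
    return examples


pairs = bst_training_tuples(list("abcd"))
print(len(pairs))  # 10 == 4 * 5 // 2
print(pairs[1])    # ((), ('c', 'd'), 'a', 'b')
```

For a sequence of length n, the loop produces n(n+1)/2 pairs. That is a quadratic number of supervised (prefix, suffix) combinations from a single sequence. A practical implementation would more plausibly encode each prefix and each suffix once and then combine their representations. This sketch materializes every tuple only for clarity.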
Peer Reviews
Decision: ICLR 2025 Poster
The paper presents a new form of training that results in better belief state representation learning for LMs. This is shown by performance on the star graph task, where the Belief State Transformer achieves a perfect score (100%). Similarly, the results on TinyStories show that this training regime achieves better LLM-as-Judge scores when completing stories given a prefix and suffix.
The star graph task seems very specific to the training objective proposed in the paper. Training involves predicting tokens both left-to-right and right-to-left, which seems aligned with the problem. Specifically, such a solution was mentioned in Bachmann & Nagarajan (2024), where the LM can learn to predict "right-to-left", starting at the goal and ending in the start state. In order to claim that the proposed training objective results in a better belief state representation i
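The reviewer's point about right-to-left prediction can be made concrete with a toy star graph. The minimal Python sketch below assumes a simplified construction: the start node is the hub and every arm is a simple chain. The node numbering and the `degree`/`arm_length` parameters are hypothetical and do not reproduce the paper's generator. Under this construction, walking forward from the start requires choosing the correct arm at the hub. Walking backward from the goal is deterministic, because every non-hub node has exactly one parent.

```python
from typing import Dict, List, Tuple

Edge = Tuple[int, int]


def make_star_graph(degree: int, arm_length: int) -> List[Edge]:
    """Directed edges of a star graph: hub 0 with `degree` arms of `arm_length` nodes each."""
    edges: List[Edge] = []
    next_id = 1
    for _ in range(degree):
        prev = 0  # every arm starts at the hub
        for _ in range(arm_length):
            edges.append((prev, next_id))
            prev = next_id
            next_id += 1
    return edges


def path_by_backtracking(edges: List[Edge], start: int, goal: int) -> List[int]:
    """Recover the start-to-goal path by following unique parents back from the goal."""
    parent: Dict[int, int] = {child: par for par, child in edges}
    path = [goal]
    while path[-1] != start:
        path.append(parent[path[-1]])  # no branching choice is ever needed
    return path[::-1]  # reverse into left-to-right order


edges = make_star_graph(degree=3, arm_length=3)
hub_children = [c for p, c in edges if p == 0]
print(hub_children)                                  # [1, 4, 7]: three possible first moves
print(path_by_backtracking(edges, start=0, goal=9))  # [0, 7, 8, 9]
```

In this construction, only the first move out of the hub depends on the goal. After that, each arm has a single continuation.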
- The paper addresses the important challenge that modern architectures lack mechanisms for planning ahead before generating solutions.
- By predicting both the next token for the left context and the previous token for the right context, the model uses a more complex training scenario with more training signal (O(n^2)) than FIM and regular next-token language modeling.
- Despite the more complex training, the model retains the same autoregressive inference mechanism as regular lan
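The remark that inference stays autoregressive can be illustrated with a greedy decoding loop. The suffix (the goal) is held fixed and only the prefix grows. In this minimal Python sketch, `score_fn` is a hypothetical stand-in for a next-token head conditioned on both prefix and suffix. `toy_counting_scorer` is an invented toy, not a trained Belief State Transformer. The stop token `<eos>` is also an assumption.

```python
from typing import Callable, Dict, List, Sequence

# Hypothetical interface: maps (prefix, suffix) to scores over candidate next tokens.
ScoreFn = Callable[[Sequence[str], Sequence[str]], Dict[str, float]]


def goal_conditioned_decode(score_fn: ScoreFn, prefix: List[str], suffix: List[str],
                            max_new_tokens: int, stop_token: str = "<eos>") -> List[str]:
    """Greedy left-to-right decoding; the suffix (goal) is fixed, only the prefix grows."""
    generated = list(prefix)
    for _ in range(max_new_tokens):
        scores = score_fn(generated, suffix)
        token = max(scores, key=lambda t: scores[t])  # greedy pick of the best-scoring token
        if token == stop_token:
            break
        generated.append(token)
    return generated[len(prefix):]  # only the newly generated middle


def toy_counting_scorer(prefix: Sequence[str], suffix: Sequence[str]) -> Dict[str, float]:
    """Invented stand-in for a trained model: counts upward until it reaches the goal."""
    last, goal = int(prefix[-1]), int(suffix[0])
    if last + 1 >= goal:
        return {"<eos>": 1.0, str(last + 1): 0.0}
    return {str(last + 1): 1.0, "<eos>": 0.0}


print(goal_conditioned_decode(toy_counting_scorer, ["1"], ["5"], max_new_tokens=10))
# ['2', '3', '4']
```

Replacing greedy selection with sampling or beam search would leave the loop's structure unchanged. The goal is simply re-supplied at every step.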
- Comparison with FIM on Star-Shaped Graphs. The paper does not provide results for the Fill-in-the-Middle (FIM) approach on the star-shaped graph task. Including FIM as a baseline in this task would offer a more comprehensive comparison and help determine whether the improvements are due to the model architecture or the specific training objectives.
- The paper does not discuss whether a backward language model could solve the star-shaped graph task more effectively than a forward LM. Exploring
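For context on the baseline the reviewer requests, the Fill-in-the-Middle (FIM) approach is commonly implemented as a data transformation for a standard forward model. The minimal Python sketch below shows the prefix-suffix-middle (PSM) reordering. The sentinel strings, the uniform choice of cut points, and the helper names `fim_transform`/`random_fim` are assumptions for illustration, not any particular codebase's implementation.

```python
import random
from typing import List

# Hypothetical sentinel tokens; real tokenizers define their own FIM special tokens.
PRE, SUF, MID = "<PRE>", "<SUF>", "<MID>"


def fim_transform(tokens: List[str], a: int, b: int) -> List[str]:
    """Split tokens at cut points a <= b and reorder them as prefix-suffix-middle (PSM)."""
    prefix, middle, suffix = tokens[:a], tokens[a:b], tokens[b:]
    # Ordinary next-token prediction on this sequence lets a forward-only model
    # read the suffix before it has to generate the middle.
    return [PRE, *prefix, SUF, *suffix, MID, *middle]


def random_fim(tokens: List[str], rng: random.Random) -> List[str]:
    """Draw two cut points uniformly from 0..len(tokens) and apply the PSM transform."""
    a, b = sorted(rng.randint(0, len(tokens)) for _ in range(2))
    return fim_transform(tokens, a, b)


print(fim_transform(list("abcde"), 1, 3))
# ['<PRE>', 'a', '<SUF>', 'd', 'e', '<MID>', 'b', 'c']
```

FIM conditions on the suffix only by reordering the input. The rearranged sequence can therefore be trained with the same next-token loss as any other document.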
- This paper addresses an important problem in improving the non-myopic prediction capabilities of current auto-regressive language models. Although this study is limited to small-scale settings, its design contributes meaningfully to expanding the architectural space of flexible language models.
- The proofs in Section 4 further strengthen the findings and nicely highlight the biases present in modeling non-myopic dependencies across various objectives (namely BSTs, next token prediction, and t
The study lacks experiments in more realistic, large-scale settings or with larger models. For instance, evaluating bidirectional representations on tasks like code infilling could provide more robust insights.
Videos
Belief state transformers | Microsoft Research Forum (YouTube)
Taxonomy
Topics: Bayesian Modeling and Causal Inference · Intelligent Tutoring Systems and Adaptive Learning · AI-based Problem Solving and Planning
Methods: Attention Is All You Need · Linear Layer · Dense Connections · Label Smoothing · Layer Normalization · Residual Connection · Position-Wise Feed-Forward Layer · Adam · Multi-Head Attention · Softmax
