Transformers Provably Solve Parity Efficiently with Chain of Thought
Juno Kim, Taiji Suzuki

TL;DR
This paper provides a theoretical analysis of how transformers can efficiently learn to solve parity problems through chain-of-thought reasoning, highlighting the importance of intermediate supervision and data augmentation.
Contribution
It introduces the first theoretical framework showing transformers can learn parity efficiently with chain-of-thought, especially when using intermediate supervision and data augmentation.
Findings
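Without intermediate supervision, any finite-precision gradient-based algorithm requires substantial iterations to solve parity with finite samples.
When intermediate parities are included in the loss and teacher forcing is used, parity can be learned in one gradient update.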
Data augmentation enables end-to-end learning without teacher forcing.
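The difference between teacher forcing and end-to-end generation in the findings above can be illustrated with a toy rollout. This is a minimal sketch, not the paper's model or training procedure: it assumes a simple left-to-right running parity rather than the paper's decomposition, and `faulty_step` is a hypothetical imperfect predictor that makes exactly one mistake. It only shows which inputs each generation step receives in the two modes.

```python
def true_chain(bits):
    """Ground-truth running parities: s_t = s_{t-1} XOR bits[t]."""
    chain, s = [], 0
    for b in bits:
        s ^= b
        chain.append(s)
    return chain


def faulty_step(prev, bit, t, bad_step=1):
    """Hypothetical predictor: correct XOR, except it flips its output at one step."""
    out = prev ^ bit
    return out ^ 1 if t == bad_step else out


def rollout(bits, teacher_forcing):
    """Generate the chain step by step.

    With teacher forcing, each step is fed the ground-truth previous state.
    Without it, each step is fed the predictor's own previous output.
    """
    truth = true_chain(bits)
    preds, prev = [], 0
    for t, b in enumerate(bits):
        p = faulty_step(prev, b, t)
        preds.append(p)
        prev = truth[t] if teacher_forcing else p
    return preds


def num_errors(preds, truth):
    return sum(p != y for p, y in zip(preds, truth))


bits = [1, 0, 1, 1, 0]
truth = true_chain(bits)
print("teacher forcing errors:", num_errors(rollout(bits, True), truth))
print("end-to-end errors:     ", num_errors(rollout(bits, False), truth))
```

In this toy setting the single mistake stays local under teacher forcing, while in the end-to-end rollout it is carried into every later step, which is the kind of error compounding the reviews mention.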
Abstract
This work provides the first theoretical analysis of training transformers to solve complex problems by recursively generating intermediate states, analogous to fine-tuning for chain-of-thought (CoT) reasoning. We consider training a one-layer transformer to solve the fundamental $k$-parity problem, extending the work on RNNs by Wies et al. (2023). We establish three key results: (1) any finite-precision gradient-based algorithm, without intermediate supervision, requires substantial iterations to solve parity with finite samples. (2) In contrast, when intermediate parities are incorporated into the loss function, our model can learn parity in one gradient update when aided by teacher forcing, where ground-truth labels of the reasoning chain are provided at each generation step. (3) Even without teacher forcing, where the model must generate CoT chains end-to-end, parity can be…
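To make the task in the abstract concrete, here is a minimal sketch of the $k$-parity target and one way to break it into intermediate parities. The pairwise (binary-tree) reduction, the function names, and the handling of odd-sized levels are assumptions made for illustration; the paper's exact construction may differ.

```python
import random


def sample_instance(d, k, rng):
    """Pick a hidden set of k relevant positions among d bits, plus a random input."""
    support = sorted(rng.sample(range(d), k))
    x = [rng.randint(0, 1) for _ in range(d)]
    return support, x


def parity(x, support):
    """k-parity label: XOR of the input bits at the relevant positions."""
    return sum(x[i] for i in support) % 2


def intermediate_parities(x, support):
    """Reduce the relevant bits pairwise, level by level.

    Returns every XOR computed along the way; the last entry is the final parity.
    An unpaired bit at an odd-sized level is carried up unchanged (an assumption
    of this sketch) and is not recorded as a step.
    """
    level = [x[i] for i in support]
    chain = []
    while len(level) > 1:
        nxt = [level[j] ^ level[j + 1] for j in range(0, len(level) - 1, 2)]
        chain.extend(nxt)
        if len(level) % 2 == 1:
            nxt.append(level[-1])
        level = nxt
    return chain


rng = random.Random(0)
support, x = sample_instance(d=16, k=4, rng=rng)
chain = intermediate_parities(x, support)
print("relevant positions:", support)
print("intermediate parities:", chain)
print("final parity matches:", chain[-1] == parity(x, support))
```

Each entry of the chain depends on only two earlier values, which is what makes supervising the intermediate steps easier than supervising the final label alone.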
Peer Reviews
Decision: ICLR 2025 Oral
I find the work to be of very high quality and very relevant to current research on the reasoning abilities of language models. The theoretical setup is, in my opinion, well chosen, very concise, and easy to understand. Even though this is a theoretical paper, the authors are well aware of current research on the applied side, and the questions studied reflect it. I believe it is a strong accept, but I'm giving accept because I'm not sure whether theoretical results of this kind should have spot
I was not able to judge how limiting the setup described in the paper is overall. For example, the special masking seems quite artificial, though I understand that it is required for the theoretical results. It would be nice if the authors explicitly described the limitations they are aware of.
1. The paper innovatively uses the k-parity problem to theoretically analyze how transformers develop stepwise reasoning capabilities. 2. This paper introduces a novel hierarchical decomposition of the k-parity problem and designs a corresponding transformer architecture to handle this structure effectively. 3. This paper proposes mechanisms like data augmentation and self-consistency checks, enabling transformers to perform CoT reasoning even in the absence of explicit intermediate supervision.
1. The theoretical analysis is heavily focused on the k-parity problem, which, while illustrative, may not extend seamlessly to more complex or varied reasoning tasks. This limits the applicability of the findings to broader transformer applications. 2. The paper lacks empirical experiments to support the theoretical conclusions, which could leave readers questioning the practical effectiveness of the proposed methods in real-world scenarios or on diverse datasets.
S1. The study is very rigorous, shedding some light on the importance of CoT for multi-hop reasoning. S2. The introduction of the causal mask is very simple to explain and very intuitive, but also very effective at limiting error compounding. I think this is the main contribution, given that teacher forcing is not really useful on this problem. S3. This work extends previous related work to a more realistic setting, offering a clearer picture of why CoT is indeed essential for the given pro
W1. The main weakness is a lack of empirical validation. I would love to see just a trained model in various settings (for example, with and without teacher forcing, and parity problem sizes with different $k$ and $d$ values) to show that without teacher forcing the model is indeed able to learn the task. W2. The claim in the conclusion about fine-tuning transformers to improve multi-step reasoning seems really too strong. In particular, it would be nice to show that it's easy to produce CoT data for all reaso
Taxonomy
Topics: Graph Theory and Algorithms · Parallel Computing and Optimization Techniques · Neural Networks and Applications
