Implicit Chain of Thought Reasoning via Knowledge Distillation
Yuntian Deng, Kiran Prasad, Roland Fernandez, Paul Smolensky, Vishrav Chaudhary, Stuart Shieber

TL;DR
This paper introduces an implicit reasoning method for language models that leverages hidden states distilled from explicit chain-of-thought models, enabling efficient multi-step reasoning without generating intermediate natural language steps.
Contribution
It proposes a novel implicit reasoning approach using knowledge distillation from explicit chain-of-thought models, improving reasoning efficiency and capability.
Findings
Enables solving complex math problems previously unsolvable without explicit reasoning.
Achieves reasoning speed comparable to models without chain-of-thought.
Demonstrates effectiveness on multi-digit multiplication and grade school math datasets.
Abstract
To augment language models with the ability to reason, researchers usually prompt or finetune them to produce chain-of-thought reasoning steps before producing the final answer. However, although people use natural language to reason effectively, it may be that LMs could reason more effectively with some intermediate computation that is not in natural language. In this work, we explore an alternative reasoning approach: instead of explicitly producing the chain-of-thought reasoning steps, we use the language model's internal hidden states to perform implicit reasoning. The implicit reasoning steps are distilled from a teacher model trained on explicit chain-of-thought reasoning, and instead of doing reasoning "horizontally" by producing intermediate words one-by-one, we distill it such that the reasoning happens "vertically" among the hidden states in different layers. We conduct…
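The speed argument behind "horizontal" versus "vertical" reasoning can be made concrete with a toy cost model. This is a minimal sketch under a stated assumption: each autoregressively generated token costs one sequential forward pass. The function name and the token counts are hypothetical and are not taken from the paper.

```python
def sequential_forward_passes(num_cot_tokens: int, num_answer_tokens: int, explicit_cot: bool) -> int:
    """Count the sequential forward passes needed to decode after the prompt.

    Toy assumption: each generated token costs one full forward pass, and
    passes for successive tokens cannot run in parallel.
    """
    if explicit_cot:
        # "Horizontal" reasoning: every intermediate step is emitted as a token.
        return num_cot_tokens + num_answer_tokens
    # "Vertical" reasoning: intermediate computation stays in hidden states
    # across layers, so only the answer tokens are decoded.
    return num_answer_tokens


# Hypothetical sizes: a 40-token chain of thought and a 10-token answer.
print(sequential_forward_passes(40, 10, explicit_cot=True))   # 50
print(sequential_forward_passes(40, 10, explicit_cot=False))  # 10
```

Under this model the implicit variant decodes as fast as a model with no chain of thought. The flip side is that all of its intermediate computation must fit inside a single pass through a fixed number of layers.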
Peer Reviews
Submitted to ICLR 2024
The idea is straightforward and the motivation is clear.
1. **The writing is not clear and some paragraphs are not rigorous enough.** The proposed pipeline includes 3 different modules, and the explanations of how they work are hard to follow. In the `Information Extraction` paragraph on page 5, the representation of the CoTs is extracted from the diagonal elements of the matrix $z$. However, matrix $z$ is often not a square matrix. In that case, what does it mean to extract the elements from $z_{11}$ to $z_{LL}$? Does it mean that the hidden states o…
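The diagonal selection the reviewer asks about can be sketched as follows, assuming `z[l][t]` holds the teacher's hidden state at layer `l` and chain-of-thought position `t` (0-based). In the square case it returns $z_{11}$ through $z_{LL}$. How to handle a non-square $z$ is exactly the open question raised above. The evenly spaced mapping used here is a hypothetical choice for illustration, not necessarily the rule used in the paper, and the function names are made up.

```python
from typing import List, Sequence, TypeVar

H = TypeVar("H")  # placeholder type for one hidden-state vector


def diagonal_positions(num_layers: int, num_cot_tokens: int) -> List[int]:
    """Return one chain-of-thought position per layer (0-based).

    Square case (L == T): layer l is paired with position l.
    Non-square case: HYPOTHETICAL rule that spreads the L picks evenly
    from the first to the last CoT position (floor rounding).
    """
    if num_layers < 1 or num_cot_tokens < 1:
        raise ValueError("need at least one layer and one CoT position")
    if num_layers == 1:
        return [0]  # single layer: arbitrarily pair it with the first position
    span = num_cot_tokens - 1
    return [(l * span) // (num_layers - 1) for l in range(num_layers)]


def extract_diagonal(z: Sequence[Sequence[H]]) -> List[H]:
    """Select the 'diagonal' hidden states from a layers-by-positions grid."""
    num_layers = len(z)
    num_cot_tokens = len(z[0]) if z else 0
    positions = diagonal_positions(num_layers, num_cot_tokens)
    return [z[l][t] for l, t in enumerate(positions)]


# Usage with string placeholders standing in for hidden-state vectors.
z = [[f"z[{l}][{t}]" for t in range(5)] for l in range(3)]
print(extract_diagonal(z))  # ['z[0][0]', 'z[1][2]', 'z[2][4]']
```

With more positions than layers (T > L) this rule skips some CoT tokens. With fewer (T < L) several layers read the same token, for example `diagonal_positions(4, 2)` gives `[0, 0, 0, 1]`. The sketch does not settle whether either behavior is desirable.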
1. The proposed method distills the CoT reasoning capability from a large model to a small one. The small LM does not have to consume its context window to conduct CoT reasoning.
2. The paper is well-organized and easy to follow.
1. As the authors also acknowledge, such implicit CoT reasoning is not interpretable, and it is hard to tell whether the proposed system is indeed conducting CoT reasoning or simply learning reasoning shortcuts.
2. The proposed method may not generalize compositionally to questions requiring more reasoning steps, or to out-of-distribution data more generally. It forces the model to conduct CoT with limited computation.
The proposed method is original and interesting. It tries to address a challenging task for language models, i.e., multi-step reasoning. The paper is well written and easy to understand. The experimental section contains interesting ablation studies that show the importance of various components. It is good to see the effect of selecting different hidden states from the teacher network, the importance of mixture on GSM8k, and the importance of optimizing both the emulator and student network we…
1. Given that the method requires a teacher model that does explicit CoT reasoning to distill into the emulator, it should be evaluated against these models. Unfortunately, the proposed method is weaker than explicit CoT models, although faster. Overall, this method trades interpretability and performance (explicit CoT) for speed, and this does not seem like a good trade-off given the extensive literature on faster Transformer inference.
2. Another weakness of the proposed approach is that…
