Transformers Learn to Implement Multi-step Gradient Descent with Chain of Thought
Jianhao Huang, Zixuan Wang, Jason D. Lee

TL;DR
This paper demonstrates that transformers trained with Chain of Thought (CoT) prompting can learn to perform multi-step gradient descent, enabling better in-context learning and generalization on linear regression tasks.
Contribution
It provides a theoretical analysis showing how CoT prompting allows transformers to implement multi-step gradient descent. A one-layer linear transformer without CoT lacks this capability, because it can implement only a single gradient descent step.
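For intuition about why one step is not enough, consider plain gradient descent on the least-squares loss over n in-context examples, stacked as a data matrix X with labels y and ground truth w^*. The 1/(2n) normalization, the zero initialization w_0 = 0, and the step size η are assumptions made for illustration; the paper's exact parameterization may differ.

```latex
\begin{aligned}
L(w) &= \tfrac{1}{2n}\,\lVert Xw - y \rVert_2^2, \qquad w_0 = 0,\\
w_{k+1} &= w_k - \tfrac{\eta}{n}\, X^\top (X w_k - y),\\
w_1 &= \tfrac{\eta}{n}\, X^\top X\, w^* \qquad (\text{one step, noiseless } y = X w^*),\\
w_k - w^* &= \bigl(I - \tfrac{\eta}{n}\, X^\top X\bigr)^{k} (w_0 - w^*).
\end{aligned}
```

Under these assumptions a single step recovers w^* only when (η/n)XᵀX = I, whereas the error shrinks geometrically when XᵀX is positive definite and 0 < η < 2n/λ_max(XᵀX). This matches the intuition that emitting several GD iterates in sequence can approach near-exact recovery.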
Findings
Transformers with CoT can perform multi-step gradient descent.
Looped transformers outperform non-looped transformers on in-context linear regression.
CoT prompting significantly improves empirical performance.
Abstract
Chain of Thought (CoT) prompting has been shown to significantly improve the performance of large language models (LLMs), particularly in arithmetic and reasoning tasks, by instructing the model to produce intermediate reasoning steps. Despite the remarkable empirical success of CoT and its theoretical advantages in enhancing expressivity, the mechanisms underlying CoT training remain largely unexplored. In this paper, we study the training dynamics of transformers over a CoT objective on an in-context weight prediction task for linear regression. We prove that while a one-layer linear transformer without CoT can only implement a single step of gradient descent (GD) and fails to recover the ground-truth weight vector, a transformer with CoT prompting can learn to perform multi-step GD autoregressively, achieving near-exact recovery. Furthermore, we show that the trained transformer…
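The contrast drawn in the abstract between a single GD step and multi-step GD can be illustrated numerically without any transformer. The NumPy sketch below runs plain gradient descent on the in-context least-squares loss. It is not the authors' construction: the loss normalization, zero initialization, step size `eta`, problem sizes, and the helper name `gd_iterates` are all illustrative assumptions. The one-step iterate stands in for what a one-layer linear transformer without CoT can express, and the later iterates stand in for the weight estimates a CoT model would produce autoregressively, one step at a time.

```python
import numpy as np


def gd_iterates(X, y, eta, num_steps):
    """Run plain gradient descent on L(w) = ||X w - y||^2 / (2n) from w_0 = 0.

    Returns [w_0, w_1, ..., w_K]. Treating each entry as one intermediate
    weight estimate is an illustrative assumption, not the paper's
    transformer construction.
    """
    n, d = X.shape
    w = np.zeros(d)
    iterates = [w.copy()]
    for _ in range(num_steps):
        grad = X.T @ (X @ w - y) / n  # gradient of the least-squares loss
        w = w - eta * grad
        iterates.append(w.copy())
    return iterates


# Hypothetical problem sizes and step size, chosen only for illustration.
rng = np.random.default_rng(0)
n, d = 40, 5
w_star = rng.standard_normal(d)   # ground-truth weight vector
X = rng.standard_normal((n, d))   # in-context inputs x_1..x_n as rows
y = X @ w_star                    # noiseless in-context labels

iterates = gd_iterates(X, y, eta=0.5, num_steps=100)
err_one = np.linalg.norm(iterates[1] - w_star)    # after a single GD step
err_many = np.linalg.norm(iterates[-1] - w_star)  # after 100 GD steps
print(f"error after 1 step:    {err_one:.3f}")
print(f"error after 100 steps: {err_many:.2e}")
```

The single-step error stays on the order of the norm of `w_star`, while the multi-step error becomes very small; exact values depend on the seed, so none are claimed here. The step size must stay below 2 divided by the largest eigenvalue of `X.T @ X / n` for the iteration to converge.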
