Training Chain-of-Thought via Latent-Variable Inference
Du Phan, Matthew D. Hoffman, David Dohan, Sholto Douglas, Tuan Anh Le, Aaron Parisi, Pavel Sountsov, Charles Sutton, Sharad Vikram, Rif A. Saurous

TL;DR
This paper introduces a novel fine-tuning method for large language models that maximizes the marginal likelihood of correct answers using chain-of-thought prompting, effectively averaging over possible rationales without requiring expensive rationale annotations.
Contribution
It proposes an MCMC-EM (Markov chain Monte Carlo expectation-maximization) fine-tuning approach that improves LLM performance on reasoning tasks by marginalizing over rationales, avoiding the need for detailed rationale supervision.
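The MCMC-EM idea can be illustrated on a toy problem. The sketch below is a hypothetical stand-in, not the authors' implementation: the "model" is a softmax over four fixed candidate rationales, each of which deterministically produces one answer, and the names RATIONALES, ANSWERS, TARGET, mh_step and mcmc_em are made up for illustration. The E-step is one independence-chain Metropolis-Hastings step that proposes a rationale from the model itself; with a 0/1 answer likelihood, a proposal is accepted exactly when it yields the labeled answer. The M-step is a gradient step on the log-probability of the currently accepted rationale.

```python
import math
import random

# Hypothetical toy "model": a softmax over K fixed candidate rationales.
# Each rationale deterministically produces an answer, so p(y | x, z) is 0 or 1.
RATIONALES = ["r0", "r1", "r2", "r3"]
ANSWERS = ["18", "17", "18", "20"]  # answer produced by each rationale
TARGET = "18"                       # the labeled correct answer y


def softmax(logits):
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    s = sum(exps)
    return [e / s for e in exps]


def sample(probs, rng):
    # Draw one index from a categorical distribution.
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


def mh_step(state, logits, rng):
    """One independence-chain MH step targeting p(z | x, y).

    Proposal q(z') = p(z' | x). Because the target is p(z | x) * 1[answer(z) == y],
    the acceptance ratio reduces to 1[answer(z') == y] when the current state is
    correct. A state of None means no correct rationale has been found yet.
    """
    proposal = sample(softmax(logits), rng)
    if ANSWERS[proposal] == TARGET:
        return proposal
    return state  # reject: keep the current rationale (or None)


def mcmc_em(logits, steps=200, lr=0.1, seed=0):
    rng = random.Random(seed)
    logits = list(logits)
    state = None
    for _ in range(steps):
        state = mh_step(state, logits, rng)  # E-step (one MCMC transition)
        if state is None:
            continue  # no correct rationale yet: skip the M-step
        # M-step: gradient ascent on log p(z | x) at the sampled rationale;
        # d/d logits_j of log softmax(logits)[z] is 1[j == z] - p_j.
        probs = softmax(logits)
        for j in range(len(logits)):
            logits[j] += lr * ((1.0 if j == state else 0.0) - probs[j])
    return logits


def marginal_likelihood(logits):
    # p(y | x) = sum over rationales of p(z | x) * 1[answer(z) == y]
    probs = softmax(logits)
    return sum(p for p, a in zip(probs, ANSWERS) if a == TARGET)


init = [0.0] * len(RATIONALES)
trained = mcmc_em(init)
print(marginal_likelihood(init), marginal_likelihood(trained))
```

Because proposals come from the current model, the acceptance rate in this toy rises as probability mass shifts toward correct rationales. With a real LLM, the softmax would presumably be replaced by sampled chain-of-thought completions, and the answer check by comparing the generated answer with the label.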
Findings
Improves accuracy on GSM8K and BIG-Bench Hard tasks
Outperforms STaR and prompt-tuning methods
Uses a variance-reduction control variate technique
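The findings mention a control variate without detail. As a generic illustration only (a standard score-function baseline, not necessarily the paper's technique), the sketch below uses a hypothetical four-rationale toy model and computes, by exact enumeration rather than sampling, the mean and variance of a single-sample estimate of the gradient of p(y | x) with and without a baseline b. Subtracting b times the score leaves the mean unchanged, because the score has zero expectation, while a sensible choice such as b = p(y | x) lowers the variance. The names LOGITS, CORRECT, estimator and mean_and_variance are illustrative.

```python
import math

# Hypothetical toy setup: p(z | x) = softmax(LOGITS), and CORRECT[z] = 1
# if rationale z yields the labeled answer.
LOGITS = [0.0, 0.0, 0.0, 0.0]
CORRECT = [1, 0, 0, 0]


def softmax(logits):
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    s = sum(exps)
    return [e / s for e in exps]


def estimator(z, probs, baseline):
    """Single-sample score-function estimate of grad_logits p(y | x).

    g(z) = (c_z - b) * grad log p(z | x), where grad log p(z | x) = e_z - p.
    Subtracting b keeps the mean unchanged because E[grad log p(z | x)] = 0.
    """
    coeff = CORRECT[z] - baseline
    return [coeff * ((1.0 if j == z else 0.0) - probs[j]) for j in range(len(probs))]


def mean_and_variance(baseline):
    # Exact moments by enumerating every rationale z instead of sampling.
    probs = softmax(LOGITS)
    grads = [estimator(z, probs, baseline) for z in range(len(probs))]
    mean = [sum(probs[z] * g[j] for z, g in enumerate(grads)) for j in range(len(probs))]
    second = sum(probs[z] * sum(v * v for v in g) for z, g in enumerate(grads))
    variance = second - sum(v * v for v in mean)  # trace of the covariance
    return mean, variance


probs = softmax(LOGITS)
p_correct = sum(p * c for p, c in zip(probs, CORRECT))
# Exact gradient: d p(y | x) / d logits_j = p_j * (c_j - p(y | x)).
exact = [probs[j] * (CORRECT[j] - p_correct) for j in range(len(probs))]
for b in (0.0, p_correct):
    mean, var = mean_and_variance(b)
    print(b, [round(v, 4) for v in mean], round(var, 4))
```

Both baselines give the same (exact) mean, so the estimator stays unbiased; only the spread changes.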
Abstract
Large language models (LLMs) solve problems more accurately and interpretably when instructed to work out the answer step by step using a "chain-of-thought" (CoT) prompt. One can also improve LLMs' performance on a specific task by supervised fine-tuning, i.e., by using gradient ascent on some tunable parameters to maximize the average log-likelihood of correct answers from a labeled training set. Naively combining CoT with supervised tuning requires supervision not just of the correct answers, but also of detailed rationales that lead to those answers; these rationales are expensive to produce by hand. Instead, we propose a fine-tuning strategy that tries to maximize the marginal log-likelihood of generating a correct answer using CoT prompting, approximately averaging over all possible rationales. The core challenge is sampling from the posterior over rationales conditioned…
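In symbols, with x the question, z a rationale, and y the answer (this notation is assumed here rather than taken from the abstract), the marginal log-likelihood and its gradient can be written as below. The gradient is an expectation under the posterior over rationales p(z | x, y), which is why sampling from that posterior is the core challenge; an MCMC-EM scheme approximates this expectation with Markov chain samples.

```latex
\[
\log p_\theta(y \mid x) = \log \sum_{z} p_\theta(z \mid x)\, p_\theta(y \mid x, z)
\]
\[
\nabla_\theta \log p_\theta(y \mid x)
  = \sum_{z} \frac{p_\theta(y, z \mid x)}{p_\theta(y \mid x)}\, \nabla_\theta \log p_\theta(y, z \mid x)
  = \mathbb{E}_{p_\theta(z \mid x, y)}\!\left[\nabla_\theta \log p_\theta(y, z \mid x)\right]
\]
```

The middle step uses \(\nabla_\theta p_\theta(y, z \mid x) = p_\theta(y, z \mid x)\, \nabla_\theta \log p_\theta(y, z \mid x)\), and the final step uses Bayes' rule, \(p_\theta(y, z \mid x) / p_\theta(y \mid x) = p_\theta(z \mid x, y)\).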
Taxonomy
Topics: Topic Modeling · Natural Language Processing Techniques · Explainable Artificial Intelligence (XAI)
