Optimizing Chain-of-Thought Reasoners via Gradient Variance Minimization in Rejection Sampling and RL
Jiarui Yao, Yifan Hao, Hanning Zhang, Hanze Dong, Wei Xiong, Nan Jiang, Tong Zhang

TL;DR
This paper introduces GVM-RAFT, a dynamic sample allocation strategy for training chain-of-thought (CoT) reasoning in LLMs that minimizes stochastic gradient variance, leading to faster convergence and improved accuracy on mathematical reasoning tasks.
Contribution
It proposes a prompt-specific dynamic sample allocation method that reduces gradient variance and accelerates training in CoT reasoning models.
Findings
GVM-RAFT achieves 2-4x speedup over vanilla RAFT.
The method improves reasoning accuracy on mathematical tasks.
Dynamic sampling enhances convergence in reinforcement learning algorithms.
Abstract
Chain-of-thought (CoT) reasoning in large language models (LLMs) can be formalized as a latent variable problem, where the model needs to generate intermediate reasoning steps. While prior approaches such as iterative reward-ranked fine-tuning (RAFT) have relied on such formulations, they typically apply uniform inference budgets across prompts, which fails to account for variability in difficulty and convergence behavior. This work identifies the main bottleneck in CoT training as inefficient stochastic gradient estimation due to static sampling strategies. We propose GVM-RAFT, a prompt-specific Dynamic Sample Allocation Strategy designed to minimize stochastic gradient variance under a computational budget constraint. The method dynamically allocates computational resources by monitoring prompt acceptance rates and stochastic gradient norms, ensuring that the resulting gradient…
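The allocation idea in the abstract can be illustrated with a toy rule. This is a minimal sketch, not the paper's derivation: it assumes a simplified variance model in which prompt i, with stochastic gradient norm G_i and acceptance rate p_i, contributes (G_i^2 / p_i) / n_i to the total gradient variance when given n_i samples. Minimizing that sum under a fixed budget sum_i n_i = N with a Lagrange multiplier gives n_i proportional to G_i / sqrt(p_i). The function name allocate_samples and the min_samples floor are hypothetical.

```python
import math


def allocate_samples(grad_norms, accept_rates, budget, min_samples=1):
    """Split an integer sampling budget across prompts (hypothetical helper).

    Assumed variance model (not the paper's exact one): prompt i contributes
    sigma_i^2 / n_i to the gradient variance, with sigma_i = G_i / sqrt(p_i).
    Each prompt first gets min_samples; the remaining budget is then split
    in proportion to sigma_i, which is the Lagrange solution of
    min sum_i sigma_i^2 / n_i  subject to  sum_i n_i = budget.
    """
    if len(grad_norms) != len(accept_rates):
        raise ValueError("grad_norms and accept_rates must have equal length")
    n = len(grad_norms)
    if budget < min_samples * n:
        raise ValueError("budget too small for the per-prompt minimum")

    eps = 1e-8  # guards against division by zero when p_i == 0
    sigmas = [g / math.sqrt(max(p, eps)) for g, p in zip(grad_norms, accept_rates)]
    total = sum(sigmas)
    spare = budget - min_samples * n

    # Real-valued share of the spare budget for each prompt.
    if total == 0:
        raw = [spare / n] * n  # no signal: split evenly
    else:
        raw = [spare * s / total for s in sigmas]

    # Largest-remainder rounding keeps the integer total equal to budget.
    alloc = [min_samples + math.floor(r) for r in raw]
    leftover = budget - sum(alloc)
    order = sorted(range(n), key=lambda i: raw[i] - math.floor(raw[i]), reverse=True)
    for i in order[:leftover]:
        alloc[i] += 1
    return alloc


# Prompt 1 has a low acceptance rate, prompt 2 a large gradient norm.
print(allocate_samples([1.0, 1.0, 2.0], [1.0, 0.25, 1.0], budget=12))
```

Under this assumed model, both a lower acceptance rate and a larger gradient norm pull samples toward a prompt, which matches the abstract's description of monitoring acceptance rates and gradient norms; the paper's actual allocation rule may differ.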
Taxonomy
Topics: Data Stream Mining Techniques · Adversarial Robustness in Machine Learning · Anomaly Detection Techniques and Applications
