Approximation of Log-Partition Function in Policy Mirror Descent Induces Implicit Regularization for LLM Post-Training
Zhenghao Xu, Qin Lu, Changlong Yu, Tuo Zhao

TL;DR
This paper introduces PMD-mean, an approximation method for policy mirror descent in reinforcement learning for large language models, which enhances stability and efficiency through implicit regularization.
Contribution
It approximates the log-partition function in PMD with the mean reward under the sampling policy, shows that this approximation acts as an implicit regularizer, and reports improved performance on math reasoning tasks.
Findings
PMD-mean achieves superior performance on math reasoning tasks.
It improves stability and time efficiency in RL for LLMs.
The method introduces an adaptive mixed KL–χ² regularizer.
Abstract
Policy mirror descent (PMD) provides a principled framework for reinforcement learning (RL) by iteratively solving KL-regularized policy improvement subproblems. While this approach has been adopted in training advanced LLMs such as Kimi K1.5/K2, the ideal closed-form PMD updates require reliable partition function estimation, a significant challenge when working with limited rollouts in the vast action spaces of LLMs. We investigate a practical algorithm, termed PMD-mean, that approximates the log-partition term with the mean reward under the sampling policy and performs regression in log-policy space. Specifically, we characterize the population solution of PMD-mean and demonstrate that it implicitly optimizes mirror descent subproblems with an adaptive mixed KL–χ² regularizer. This additional regularization constrains large probability changes, producing more…
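The PMD-mean recipe described in the abstract can be sketched in a few lines of NumPy. This is an illustrative reading, not the authors' implementation. It assumes a single prompt with a small group of rollouts, a temperature named `tau`, and a squared-error loss on the log-ratio; the function names and the toy rewards are hypothetical. The exact KL-regularized update would subtract `tau * log E[exp(r / tau)]` from each reward, whereas the approximation subtracts the mean reward instead.

```python
import numpy as np

def log_partition_exact(rewards, tau):
    """Monte Carlo estimate of tau * log E[exp(r / tau)] from sampled rewards."""
    r = np.asarray(rewards, dtype=float)
    m = r.max()
    # Subtract the max before exponentiating for numerical stability.
    return m + tau * np.log(np.mean(np.exp((r - m) / tau)))

def pmd_mean_targets(rewards, tau):
    """Log-ratio targets with the log-partition term replaced by the mean reward."""
    r = np.asarray(rewards, dtype=float)
    return (r - r.mean()) / tau

def pmd_mean_loss(logp_new, logp_old, rewards, tau):
    """Squared-error regression in log-policy space (assumed form of the loss)."""
    log_ratio = np.asarray(logp_new, dtype=float) - np.asarray(logp_old, dtype=float)
    return float(np.mean((log_ratio - pmd_mean_targets(rewards, tau)) ** 2))

# Toy group of four rollouts with binary rewards (hypothetical values).
rewards = np.array([1.0, 0.0, 0.0, 1.0])
tau = 0.5
logp_old = np.log(np.full(4, 0.25))

targets = pmd_mean_targets(rewards, tau)
# A policy whose log-ratio equals the targets has zero regression loss.
logp_new = logp_old + targets
loss = pmd_mean_loss(logp_new, logp_old, rewards, tau)

# By Jensen's inequality the mean reward never exceeds tau * log E[exp(r / tau)].
gap = log_partition_exact(rewards, tau) - rewards.mean()
```

The nonnegative `gap` is the quantity the approximation discards. How that discrepancy turns into the mixed KL–χ² regularizer is the subject of the paper's analysis and is not reproduced here.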
Taxonomy
Topics: Reinforcement Learning in Robotics · Domain Adaptation and Few-Shot Learning · Stochastic Gradient Optimization Techniques
