On a few pitfalls in KL divergence gradient estimation for RL
Yunhao Tang, Rémi Munos

TL;DR
This paper identifies common pitfalls in estimating KL divergence gradients for reinforcement learning in language models, emphasizing correct implementation methods to ensure accurate gradient computation.
Contribution
It highlights specific implementation errors in KL gradient estimation and provides correct approaches, improving RL training for large language models.
Findings
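Differentiating through a KL estimate, treating it as a loss to minimize, generally does not produce the desired KL gradient.
Implementations that ignore the sequential nature of the estimation problem produce a partial gradient at best.
Illustrative tabular and LLM experiments demonstrate the impact of these issues and the correct way to implement the KL gradient.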
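To make the second finding concrete, here is a sketch of the sequence-level gradient under assumed notation that does not come from the page itself: a response y = (y_1, ..., y_T) sampled from the policy \pi_\theta, a fixed reference \pi_{\mathrm{ref}}, and a per-token log-ratio \ell_s. The prompt is omitted for brevity. By the standard score-function argument, the log-ratio at step s is affected by the choices at every step t \le s, so each token's score is weighted by the sum of current and future log-ratios:

```latex
\ell_s = \log \frac{\pi_\theta(y_s \mid y_{<s})}{\pi_{\mathrm{ref}}(y_s \mid y_{<s})},
\qquad
\mathrm{KL}(\pi_\theta \,\|\, \pi_{\mathrm{ref}})
  = \mathbb{E}_{y \sim \pi_\theta}\Big[ \sum_{s=1}^{T} \ell_s \Big],

\nabla_\theta \mathrm{KL}(\pi_\theta \,\|\, \pi_{\mathrm{ref}})
  = \mathbb{E}_{y \sim \pi_\theta}\Big[ \sum_{t=1}^{T}
      \nabla_\theta \log \pi_\theta(y_t \mid y_{<t}) \sum_{s=t}^{T} \ell_s \Big].
```

An estimator that keeps only the s = t term weights each token's score by its own log-ratio and drops the contribution of later tokens, which is one way to read "a partial gradient at best".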
Abstract
We point out a few pitfalls in implementing gradient estimation for KL divergence in RL training for LLMs, as seen in a number of open-source projects and papers. The first major pitfall is to differentiate through KL estimates, treating them as loss functions for minimizing the KL divergence. We show that such implementations are generally incorrect and do not produce the desired KL gradient. Second, we show that some implementations do not account for the sequential nature of the estimation problem and produce a partial gradient at best. We demonstrate the impact of these issues with illustrative tabular and LLM experiments, and show the correct way to implement the KL gradient.
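The first pitfall can be checked in a small tabular setting. The sketch below is not the authors' code. It assumes a softmax policy over four actions with made-up logits, and it computes expectations exactly instead of sampling. It compares (a) the expected gradient obtained by differentiating the per-sample log-ratio log pi(x) - log pi_ref(x) as if it were a loss, and (b) a score-function estimator, against a finite-difference gradient of the exact KL.

```python
import numpy as np

def softmax(z):
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()

def kl(theta, ref):
    """Exact KL(pi_theta || ref) for a softmax policy with logits theta."""
    pi = softmax(theta)
    return float(np.sum(pi * (np.log(pi) - np.log(ref))))

def numerical_grad(f, theta, eps=1e-6):
    """Central finite differences, used here as ground truth."""
    g = np.zeros_like(theta)
    for j in range(theta.size):
        d = np.zeros_like(theta)
        d[j] = eps
        g[j] = (f(theta + d) - f(theta - d)) / (2 * eps)
    return g

# Hypothetical logits, chosen only for illustration.
theta = np.array([0.5, -1.0, 0.2, 1.3])
ref = softmax(np.array([0.0, 0.3, -0.7, 0.4]))

pi = softmax(theta)
log_ratio = np.log(pi) - np.log(ref)
# Row x holds grad_theta log pi(x) = e_x - pi for a softmax policy.
score = np.eye(theta.size) - pi

true_grad = numerical_grad(lambda t: kl(t, ref), theta)

# (a) Differentiating the sampled log-ratio with x held fixed gives
# grad log pi(x); its expectation under pi is identically zero.
naive_grad = pi @ score

# (b) Score-function estimator: E[(log pi(x) - log ref(x)) * grad log pi(x)].
score_fn_grad = (pi * log_ratio) @ score

print("true KL gradient     :", true_grad)
print("naive (through k1)   :", naive_grad)
print("score-function       :", score_fn_grad)
```

In this setting the naive expected gradient vanishes while the score-function form matches the finite-difference gradient, which is consistent with the abstract's claim that differentiating through a KL estimate does not generally give the desired gradient. Other KL estimates behave differently when differentiated, so this example covers only the plain log-ratio case.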
Taxonomy
Topics: Natural Language Processing Techniques · Domain Adaptation and Few-Shot Learning · Imbalanced Data Classification Techniques
