Improving End-to-End Training of Retrieval-Augmented Generation Models via Joint Stochastic Approximation
Hongyu Cao, Yuxuan Wu, Yucheng Cai, Xianyu Zhao, Zhijian Ou

TL;DR
This paper introduces JSA-RAG, a novel end-to-end training method for retrieval-augmented generation (RAG) models that improves on top-K marginalization and variational RAG (VRAG) by reducing the bias and variance of gradient estimates.
Contribution
It develops a joint stochastic approximation (JSA) algorithm for end-to-end training of RAG models, addressing the biased or high-variance gradient estimates of previous methods.
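The quantities a JSA-style training algorithm works with can be written out as follows. This is a sketch in the general JSA notation, not the paper's exact formulation, and it assumes symbols this summary does not define: x is the query, y the target output, z a passage from the knowledge base Z, p_η(z|x) the retriever, p_θ(y|x,z) the generator, and q_φ(z|x,y) an auxiliary inference model.

```latex
\begin{aligned}
\log p_{\eta,\theta}(y \mid x) &= \log \sum_{z \in \mathcal{Z}} p_\eta(z \mid x)\, p_\theta(y \mid x, z) \\
\nabla_{\eta,\theta} \log p_{\eta,\theta}(y \mid x) &= \mathbb{E}_{p_{\eta,\theta}(z \mid x, y)}\Big[\nabla_{\eta,\theta}\big(\log p_\eta(z \mid x) + \log p_\theta(y \mid x, z)\big)\Big] \\
-\nabla_\phi \,\mathrm{KL}\big(p_{\eta,\theta}(z \mid x, y) \,\big\|\, q_\phi(z \mid x, y)\big) &= \mathbb{E}_{p_{\eta,\theta}(z \mid x, y)}\big[\nabla_\phi \log q_\phi(z \mid x, y)\big]
\end{aligned}
```

Both gradients are expectations under the posterior over passages, which is intractable for a large knowledge base. In the general JSA formulation they are estimated with samples from a Metropolis independence sampler that uses q_φ as its proposal, instead of truncating the sum to the top-K retrieved passages.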
Findings
JSA-RAG significantly outperforms vanilla RAG and VRAG on five datasets.
The method improves generation quality, retrieval accuracy, and gradient stability.
Extensive experiments validate the effectiveness of JSA-RAG across tasks.
Abstract
Retrieval-augmented generation (RAG) has become a widely recognized paradigm for combining parametric memory with non-parametric memory. A RAG model consists of two serially connected components: a retriever and a generator. A major challenge in end-to-end optimization of a RAG model is that it requires marginalizing over relevant passages from a knowledge base, which are modeled as discrete latent variables. Traditional top-K marginalization and variational RAG (VRAG) suffer from biased or high-variance gradient estimates. In this paper, we propose and develop end-to-end training of RAG based on joint stochastic approximation (JSA), which we refer to as JSA-RAG. The JSA algorithm is a stochastic extension of the EM (expectation-maximization) algorithm and is particularly powerful for estimating discrete latent variable models. Extensive experiments are conducted on five datasets for two tasks…
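To make the sampling-plus-update loop concrete, here is a minimal sketch of JSA on a deliberately tiny, hypothetical model rather than a real retriever and generator: one fixed query, a knowledge base of `n_passages` discrete passages, and answers from a small vocabulary. The retriever p(z|x), generator p(y|x,z), and inference model q(z|x,y) are assumed to be plain softmax tables (`a`, `B`, `C`), and the function names are illustrative; this is not the authors' implementation. Each step proposes a passage from q, accepts or rejects it with a Metropolis independence sampler targeting the posterior p(z|x,y), caches the result per example, and takes a gradient step on log p(z|x) + log p(y|x,z) and on log q(z|x,y) at the sampled passage.

```python
import numpy as np


def softmax(logits):
    """Numerically stable softmax over the last axis."""
    e = np.exp(logits - logits.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


def log_marginal(a, B, ys):
    """Average exact log p(y|x) = log sum_z p(z|x) p(y|x,z) over the data.

    Tractable only because this toy knowledge base is tiny.
    """
    p_y = softmax(a) @ softmax(B)  # (N,) @ (N, V) -> marginal over answers, (V,)
    return float(np.mean(np.log(p_y[ys])))


def train_jsa(ys, n_passages=8, vocab=5, steps=2000, lr=0.1, seed=0):
    """Toy JSA training loop (hypothetical parameterization, not the paper's code).

    a: retriever logits,        p(z|x)   = softmax(a)
    B: generator logits,        p(y|x,z) = softmax(B[z])
    C: inference-model logits,  q(z|x,y) = softmax(C[y])
    """
    rng = np.random.default_rng(seed)
    a = np.zeros(n_passages)
    B = np.zeros((n_passages, vocab))
    C = np.zeros((vocab, n_passages))
    # Persistent cache: one latent passage per training example (the MCMC state).
    cache = rng.integers(n_passages, size=len(ys))
    accepted = 0
    for _ in range(steps):
        i = rng.integers(len(ys))
        y, z = ys[i], cache[i]
        q = softmax(C[y])
        z_prop = rng.choice(n_passages, p=q)
        # Log importance weight log[p(z|x) p(y|x,z) / q(z|x,y)] for every passage.
        log_w = np.log(softmax(a)) + np.log(softmax(B)[:, y]) - np.log(q)
        # Metropolis independence sampler: accept with prob min(1, w(z_prop) / w(z)).
        if np.log(rng.random()) < log_w[z_prop] - log_w[z]:
            z = z_prop
            accepted += 1
        cache[i] = z
        # Gradient ascent at the sampled passage; for a softmax,
        # d log softmax(l)[k] / dl = onehot(k) - softmax(l).
        onehot_z = np.eye(n_passages)[z]
        a += lr * (onehot_z - softmax(a))
        B[z] += lr * (np.eye(vocab)[y] - softmax(B[z]))
        C[y] += lr * (onehot_z - q)
    return a, B, C, accepted / steps
```

Because `log_marginal` can sum over all `n_passages` terms exactly in this toy setting, it can be used to check that training raises the likelihood. In a real RAG model that sum over the full knowledge base is exactly what is intractable, which is why the sampler and the inference model are needed.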
