Improving End-to-End Training of Retrieval-Augmented Generation Models via Joint Stochastic Approximation

About

Retrieval-augmented generation (RAG) has become a widely recognized paradigm for combining parametric memory with non-parametric memory. A RAG model consists of two serially connected components: a retriever and a generator. A major challenge in end-to-end optimization of a RAG model is that it requires marginalization over relevant passages from a knowledge base, which are modeled as discrete latent variables. Traditional top-K marginalization and variational RAG (VRAG) suffer from biased or high-variance gradient estimates. In this paper, we propose and develop end-to-end training of RAG based on joint stochastic approximation (JSA), which we refer to as JSA-RAG. The JSA algorithm is a stochastic extension of the expectation-maximization (EM) algorithm and is particularly powerful for estimating discrete latent variable models. Extensive experiments on five datasets covering two tasks (open-domain question answering and knowledge-grounded dialog) show that JSA-RAG significantly outperforms both vanilla RAG and VRAG. Further analysis demonstrates the efficacy of JSA-RAG in terms of generation, retrieval, and low-variance gradient estimation.

Hongyu Cao, Yuxuan Wu, Yucheng Cai, Xianyu Zhao, Zhijian Ou · 2025
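The abstract treats retrieved passages as discrete latent variables and contrasts top-K marginalization with JSA. The following is a minimal toy sketch in Python, not the authors' implementation. It assumes a five-passage knowledge base with made-up retriever scores and generator likelihoods p(y|x,z), a frozen retriever and generator, and a tabular proposal q(z|x,y) standing in for an inference model. It compares the exact marginal likelihood with its top-K approximation, then runs a Metropolis independence sampler (MIS) that uses q as the proposal and targets the posterior p(z|x,y). This is the sampling step JSA is built on, as we understand it. The helper names (`exact_marginal`, `top_k_marginal`, `jsa_toy`) are hypothetical.

```python
import math
import random


def softmax(logits):
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def exact_marginal(prior, lik):
    # p(y|x) = sum over ALL passages z of p(z|x) * p(y|x,z)
    return sum(p * l for p, l in zip(prior, lik))


def top_k_marginal(prior, lik, k):
    # Top-K approximation: keep only the K passages the retriever ranks highest.
    # Every dropped term is non-negative, so this underestimates p(y|x).
    top = sorted(range(len(prior)), key=lambda i: prior[i], reverse=True)[:k]
    return sum(prior[i] * lik[i] for i in top)


def jsa_toy(prior, lik, steps=20000, lr=0.01, seed=0):
    """Toy JSA-style loop with a FROZEN retriever p(z|x) and generator p(y|x,z).

    Each step proposes a passage from q(z|x,y), accepts or rejects it with an MIS
    step whose target is the posterior p(z|x,y), proportional to p(z|x) p(y|x,z),
    then nudges q toward the accepted passage. Returns the chain's empirical
    passage frequencies and the final proposal q.
    """
    rng = random.Random(seed)
    n = len(prior)
    q_logits = [0.0] * n  # hypothetical tabular proposal, starts uniform
    z = rng.randrange(n)  # initial chain state
    counts = [0] * n
    for _ in range(steps):
        q = softmax(q_logits)
        z_new = rng.choices(range(n), weights=q)[0]
        # Importance weight w(z) = p(z, y|x) / q(z|x, y); MIS accepts with min(1, w_new / w_cur)
        w_new = prior[z_new] * lik[z_new] / q[z_new]
        w_cur = prior[z] * lik[z] / q[z]
        if rng.random() < min(1.0, w_new / w_cur):
            z = z_new
        counts[z] += 1
        # SGD step on log q(z|x,y) at the accepted sample (gradient of log-softmax).
        for i in range(n):
            q_logits[i] += lr * ((1.0 if i == z else 0.0) - q[i])
    return [c / steps for c in counts], softmax(q_logits)


if __name__ == "__main__":
    prior = softmax([2.0, 1.0, 0.5, 0.0, -1.0])  # made-up retriever scores
    lik = [0.05, 0.6, 0.3, 0.02, 0.01]          # made-up p(y|x,z) values
    print(exact_marginal(prior, lik), top_k_marginal(prior, lik, k=2))
    freqs, q = jsa_toy(prior, lik)
    print(freqs, q)
```

With these made-up numbers the exact marginal is about 0.192 while the top-2 estimate is about 0.152, because the generator strongly prefers a passage the retriever does not rank first. The MIS chain's visit frequencies and the learned q both drift toward the posterior. In full JSA training, the same accepted passage would also drive a gradient step on log p(z|x) + log p(y|x,z) for the retriever and generator; that step is omitted here so the target posterior stays fixed and easy to check.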

Related benchmarks

| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Open-domain Question Answering | NQ | -- | -- | 74 |
| Knowledge-grounded dialog | DoQA | R@1 | 68.09 | 4 |
| Open-domain Question Answering | NQ | R@1 | 29.23 | 4 |
| Open-domain Question Answering | TQA | R@1 | 37.39 | 4 |
| Open-domain Question Answering | MS Marco | R@1 | 24.75 | 4 |
| Knowledge-grounded dialog | DoQA | BLEU-4 | 17.11 | 3 |
| Knowledge-grounded dialog | OR-QUAC | BLEU-4 | 7.76 | 3 |
| Open-domain Question Answering | TQA | Exact Match (EM) | 75.23 | 3 |
| Open-domain Question Answering | MS Marco | BLEU-1 | 35.28 | 3 |
