On the Generalization Ability of Retrieval-Enhanced Transformers
Tobias Norlund, Ehsan Doostmohammadi, Richard Johansson, Marco Kuhlmann

TL;DR
This paper investigates the generalization capabilities of Retrieval-Enhanced Transformers (RETRO), revealing that their performance gains mainly stem from token overlap with the database rather than true generalization, highlighting evaluation challenges.
Contribution
The study clarifies the relative impact of retrieval and model weights in RETRO, showing that token overlap largely explains performance improvements, challenging previous assumptions about non-trivial generalization.
Findings
Performance gains mainly due to token overlap
Limited evidence for non-trivial generalization
Highlights challenges in evaluating retrieval-augmented models
| Component | Param | Value |
|---|---|---|
| Encoder | Num layers | 2 |
| Encoder | Num heads | 14 |
| Encoder | Hidden size | 896 |
| Encoder | FFN | 3584 |
| Encoder | CA layers | [2] |
| Decoder | Num layers | 12 |
| Decoder | Num heads | 12 |
| Decoder | Hidden size | 1536 |
| Decoder | FFN | 6144 |
| Decoder | CCA layers | [6, 9, 12] |
Tobias Norlund (1, 4), Ehsan Doostmohammadi (2), Richard Johansson (1, 3), Marco Kuhlmann (2)
(1) Chalmers University of Technology; (2) Linköping University; (3) University of Gothenburg; (4) Recorded Future
Corresponding author: [email protected]
Abstract
Recent work on the Retrieval-Enhanced Transformer (Retro) model has shown that off-loading memory from trainable weights to a retrieval database can significantly improve language modeling and match the performance of non-retrieval models that are an order of magnitude larger in size. It has been suggested that at least some of this performance gain is due to non-trivial generalization based on both model weights and retrieval. In this paper, we try to better understand the relative contributions of these two components. We find that the performance gains from retrieval largely originate from overlapping tokens between the database and the test data, suggesting less non-trivial generalization than previously assumed. More generally, our results point to the challenges of evaluating the generalization of retrieval-augmented language models such as Retro, as even limited token overlap may significantly decrease test-time loss. We release our code and model at https://github.com/TobiasNorlund/retro
1 Introduction
Large-scale generative language models have shown promising results toward creating a general-purpose foundation for many natural language applications. While sheer scale-up has resulted in better language modeling performance, the immense costs are an inhibiting factor towards further improvements (Sharir et al., 2020).
Recent work on retrieval-augmented language models, such as the Retrieval-Enhanced Transformer (Retro; Borgeaud et al., 2022), suggests that memory can be effectively off-loaded from the model parameters to an external database. In Retro, the information retrieved from the database is used to augment the context from which the model predicts new tokens, reducing the need to memorize this information in the model parameters. This opens the door to smaller language models that retain the performance of larger ones. Specifically, Borgeaud et al. (2022) report that, with a large enough retrieval database, Retro can achieve performance comparable to GPT-3 (Brown et al., 2020) and Jurassic-1 (Lieber et al., 2021) on the Pile (Gao et al., 2020) with only 4% of the parameters. Similarly, Retro achieves significantly lower bits-per-byte than a baseline of the same size without retrieval.
Borgeaud et al. (2022) conclude that Retro has the capacity for non-trivial generalization based on both the model parameters and the retrieval database, even though they find that part of the performance gains can be attributed to lexical overlap between retrieval and test data. In this work, we want to better understand the nature and magnitude of this effect. Our findings indicate that performance gains¹ originate almost exclusively from Retro’s ability to copy tokens verbatim from retrieved data, effectively exploiting any (small or large) overlap between training and test data. This suggests that Retro’s ability to fuse retrieved and in-parameter information may be more limited than previously assumed.
¹ Results on Retro were originally reported in bits-per-byte, while we report results in loss.
2 Method
To investigate gains from retrieval, we re-implement the Retro model described by Borgeaud et al. (2022), with a few deviations (see below). We briefly describe the model here.
2.1 The Retro Model
Retro is an autoregressive language model trained with the next-token prediction objective, where the prediction probability is conditioned on additional context retrieved from a database.
Retrieval
Retrieval occurs at the granularity of contiguous token chunks with a fixed size $m$. More specifically, assume that Retro has already generated a sequence of tokens $x_1, \dots, x_n$. Each token $x_i$ belongs to a chunk $C_{u(i)}$, where $u(i) = \lceil i/m \rceil$. The probability of the next token $x_{n+1}$ depends on the previously generated tokens and the context retrieved from the previously seen chunks:
$$P\bigl(x_{n+1} \mid x_1, \dots, x_n, \mathrm{Ret}(C_1), \dots, \mathrm{Ret}(C_{u(n+1)-1})\bigr)$$
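To make the chunk indexing concrete, here is a minimal Python sketch of $u(i) = \lceil i/m \rceil$ and of which chunks may supply retrieved context when predicting $x_{n+1}$. The function names `chunk_index` and `retrievable_chunks` are illustrative only and are not taken from the authors' code; positions and chunk indices are 1-based to match the notation above.

```python
import math


def chunk_index(i: int, m: int) -> int:
    """Return u(i) = ceil(i / m): the 1-based chunk that holds token position i."""
    return math.ceil(i / m)


def retrievable_chunks(n: int, m: int) -> list:
    """Chunks whose neighbours may condition P(x_{n+1} | ...).

    Only chunks strictly before the chunk of x_{n+1} qualify, so retrieval never
    looks at the chunk that is still being generated.
    """
    return list(range(1, chunk_index(n + 1, m)))


if __name__ == "__main__":
    m = 64  # chunk size used in Section 2.3
    for n in (63, 64, 130):
        print(n, retrievable_chunks(n, m))
    # 63 []      -> x_64 is still in C_1, which has no predecessor
    # 64 [1]     -> x_65 opens C_2 and may use Ret(C_1)
    # 130 [1, 2] -> x_131 lies in C_3
```

The first chunk of every sequence therefore predicts its tokens without any retrieved context.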
Database
Retro’s database takes the form of a key–value storage $\{(E(N), [N, F])\}$, where $N$ is a chunk from one of the indexed documents, $F$ is the immediately following chunk, and the key $E(N)$ is the embedding of $N$ according to some embedding model $E$. This database is used to retrieve the $k$ nearest neighbors of a chunk $C$, based on the embedding $E(C)$:
$$\mathrm{Ret}(C) = \bigl([N^1, F^1], \dots, [N^k, F^k]\bigr)$$
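The key–value structure and the retrieval step can be illustrated with a toy, in-memory Python sketch. It assumes NumPy, a hypothetical bag-of-tokens embedding `toy_embed` standing in for the Sentence-BERT model, exact L2 search instead of the approximate faiss search the paper uses, and an empty continuation $F$ for the last chunk of each document; the class name `ChunkDatabase` is likewise hypothetical. The optional `exclude_doc` argument mirrors the same-document filtering described in Section 2.3.

```python
import numpy as np


def toy_embed(chunk, dim=16):
    """Hypothetical stand-in for E(.): a bag-of-tokens count vector."""
    v = np.zeros(dim)
    for tok in chunk:
        v[tok % dim] += 1.0
    return v


class ChunkDatabase:
    """Toy key-value store {E(N): [N, F]} with exact k-nearest-neighbour search."""

    def __init__(self, embed, chunks, doc_ids):
        # chunks: token lists in corpus order; doc_ids[j]: source document of chunks[j]
        self.embed = embed
        self.values, self.docs, keys = [], [], []
        for j, chunk in enumerate(chunks):
            same_doc = j + 1 < len(chunks) and doc_ids[j + 1] == doc_ids[j]
            following = chunks[j + 1] if same_doc else []  # F: the next chunk, if any
            keys.append(embed(chunk))               # key E(N)
            self.values.append((chunk, following))  # value [N, F]
            self.docs.append(doc_ids[j])
        self.keys = np.stack(keys)

    def retrieve(self, chunk, k, exclude_doc=None):
        """Ret(C): the k values [N, F] whose keys are closest (L2) to E(C)."""
        dists = np.linalg.norm(self.keys - self.embed(chunk), axis=1)
        ranked = [j for j in np.argsort(dists, kind="stable")
                  if self.docs[j] != exclude_doc]
        return [self.values[j] for j in ranked[:k]]


if __name__ == "__main__":
    chunks = [[1, 2, 3], [4, 5, 6], [1, 2, 4], [7, 8, 9]]
    db = ChunkDatabase(toy_embed, chunks, doc_ids=[0, 0, 1, 1])
    print(db.retrieve([1, 2, 3], k=1, exclude_doc=0))  # [([1, 2, 4], [7, 8, 9])]
```

Storing the continuation $F$ alongside $N$ is what lets a retrieved neighbour contain the very tokens that follow a matching context, which is the mechanism the overlap analysis in Section 3 examines.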
Architecture
Retro is based on the original Transformer architecture (Vaswani et al., 2017). Chunk neighbors are encoded by the encoder and attended to by the decoder. Due to the quadratic complexity of self-attention, each neighbor is encoded separately; all representations are then concatenated and made available to the decoder (Izacard and Grave, 2021). The original decoder is modified such that, for the prediction of token $x_i$, cross-attention (CA) can only attend to the neighbor representations retrieved based on the previous chunk $C_{u(i)-1}$. This is called chunked cross-attention (CCA). Furthermore, the encoder is modified to include a restricted form of cross-attention to the decoder. Specifically, the encoder CA attends to the decoder hidden states immediately before the first CCA. We refer to Borgeaud et al. (2022) for more details.
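The chunk-level restriction on cross-attention can be pictured as a boolean mask over (predicted token, neighbour chunk) pairs. The following NumPy sketch, with the hypothetical helper name `cca_mask`, encodes only the rule stated above, namely that predicting $x_i$ may use $\mathrm{Ret}(C_{u(i)-1})$ and nothing else; it does not implement the attention computation itself or the exact token-to-hidden-state alignment of Borgeaud et al. (2022).

```python
import numpy as np


def cca_mask(n: int, m: int) -> np.ndarray:
    """Boolean mask of shape (n, number of chunks).

    mask[i - 1, v - 1] is True iff the prediction of token x_i (1-based) may
    attend to the encoded neighbours of chunk C_v, i.e. iff v == u(i) - 1.
    Tokens in the first chunk get an all-False row: there is nothing to attend to.
    """
    num_chunks = -(-n // m)  # ceil(n / m) with integer arithmetic
    mask = np.zeros((n, num_chunks), dtype=bool)
    for i in range(1, n + 1):
        u = -(-i // m)  # u(i) = ceil(i / m)
        if u > 1:
            mask[i - 1, u - 2] = True  # column of C_{u(i)-1}
    return mask


if __name__ == "__main__":
    print(cca_mask(n=8, m=4).astype(int))
```

Each row has at most one allowed column, which is why CCA cost grows linearly with the number of chunks rather than with all retrieved tokens at every position.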
Implementation Details
For tokenizing documents, we use the pre-trained T5 tokenizer. Retrieval was performed using approximate nearest neighbor search with the high-performance faiss library (Johnson et al., 2019). We implement Retro in PyTorch (Paszke et al., 2019) and use PyTorch Lightning for distributing the training and validation data across GPUs and compute nodes. Our implementation deviates from that of Borgeaud et al. (2022) only in that we:
- use learnable relative positional biases as in T5 (Raffel et al., 2020), with a bucket for each unique relative position; and
- instantiate the chunk embedding model with a pretrained Sentence-BERT (SB) model (Reimers and Gurevych, 2019) instead of BERT. We deemed SB preferable to BERT as it is smaller (i.e., cheaper to compute) and produces embeddings of lower dimensionality (i.e., saves disk space).
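The first deviation, one learnable bias per head for each distinct relative position, can be sketched in NumPy as below. This is an illustration under assumptions: the class name, the initialization scale, and the `max_len` bound are hypothetical, and T5's default scheme instead shares log-spaced buckets among distant positions. The returned array would be added to the attention logits before the softmax.

```python
import numpy as np


class RelativePositionBias:
    """One learnable scalar per (head, relative offset j - i), T5-style."""

    def __init__(self, num_heads: int, max_len: int, seed: int = 0):
        self.max_len = max_len
        rng = np.random.default_rng(seed)
        # 2 * max_len - 1 distinct offsets, from -(max_len - 1) to +(max_len - 1)
        self.table = rng.normal(scale=0.02, size=(2 * max_len - 1, num_heads))

    def __call__(self, q_len: int, k_len: int) -> np.ndarray:
        """Bias of shape (num_heads, q_len, k_len) to add to attention logits."""
        offsets = np.arange(k_len)[None, :] - np.arange(q_len)[:, None]  # j - i
        return self.table[offsets + self.max_len - 1].transpose(2, 0, 1)


if __name__ == "__main__":
    bias = RelativePositionBias(num_heads=2, max_len=1024)
    b = bias(4, 4)
    print(b.shape)                   # (2, 4, 4)
    print(b[0, 0, 1] == b[0, 2, 3])  # True: both have offset +1
```

With 1,024-token sequences this gives 2,047 parameters per head, a negligible cost compared with the rest of the model.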
2.2 Dataset
Borgeaud et al. (2022) used a multilingual version of MassiveText (Rae et al., 2021) for both training and retrieval data. To replicate the English portion of this data, we sought open-source alternatives. MassiveText comprises text from the categories web text, news, code, books, and Wikipedia. By pooling matching categories from the Pile (Gao et al., 2020) and adding the RealNews dataset (Zellers et al., 2019), we obtain a large dataset composed of all five categories, consisting of 36M documents and 52B tokens. We keep the training/validation splits from the Pile categories. For RealNews, we use the provided training set and a subsample of 16,400 documents from the validation set. The full description of our dataset is shown in Table 1.
2.3 Model Training
For our experiments, we train a Retro model that resembles the 425M model² in Borgeaud et al. (2022), as shown in Table 2. We train and test on our open-source version of MassiveText as described in Section 2.2. During training, we retrieve neighbors from the training set, while at validation time, we retrieve from the union of the training and validation sets. We filter out neighbors that originate from the same source document as the query chunk. Each model is trained on sequences of no more than 1,024 tokens; longer sequences are truncated. We use a chunk size of 64 and retrieve two neighbors during both training and validation. We train the model for 140k steps with a batch size of 16. This means that only 6% of the training documents are actually used during training, excluding retrieved neighbors. We use the Adam optimizer with a fixed learning rate of .
² The 425M parameters exclude embeddings.
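The 6% figure follows from the training budget. The back-of-the-envelope Python check below assumes that each training sequence holds exactly one (possibly truncated) document, that no document is seen twice, and that the 36M documents of Section 2.2 approximate the training set (they also include the validation documents, so the true fraction is marginally higher).

```python
steps = 140_000
batch_size = 16
num_docs = 36_000_000  # Section 2.2 total; includes the (small) validation split

sequences_seen = steps * batch_size  # one document per sequence (assumption)
fraction = sequences_seen / num_docs
print(f"{sequences_seen:,} sequences = {fraction:.1%} of documents")
# 2,240,000 sequences = 6.2% of documents
```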
3 Experiments
Borgeaud et al. (2022) observed that retrieval improves language modeling performance. To validate this observation, we compare two configurations of our model: Retro[on], where we enable retrieval, and Retro[off], where we remove the CCA layers, thereby reducing Retro to a standard decoder-only language model. As we can see in Figure 1, retrieval reduces the loss across all data categories, and by 11% across the full validation set. GitHub data has the lowest validation loss among all categories and is also where we see the largest reduction in loss, at 42%. Wikipedia sees the smallest reduction, at only 3%. A closer comparison to the results of Borgeaud et al. (2022) is available in Appendix D.
3.1 Loss per Degree of Overlap
As Borgeaud et al. (2022) note, retrieval-based models such as Retro may more easily exploit evaluation dataset leakage. To quantify how much of the positive effect of retrieval on language modeling performance can be attributed to such leakage, the authors computed bits-per-byte (bpb) for evaluation chunks with different amounts of consecutive token overlap relative to their retrieved neighbors. This analysis showed that, while the positive effect of retrieval decreased with smaller overlaps, it was still significant at overlap levels of at most 8 contiguous tokens, which the authors considered small enough to conclude that Retro actually learns to generalize from retrieval data rather than merely copying and pasting it. Here we investigate the hypothesis that the bpb reductions observed by Borgeaud et al. (2022) are localized exclusively in the overlapping tokens. If this were true, it would challenge the conclusion that Retro learns non-trivial generalizations based on retrieval data.
To test our hypothesis, we sort the validation set tokens into buckets based on their leftward overlap. Specifically, we put a token $x_i$ into bucket $B_k$, where $k$ is the largest number such that $x_i$ and the $k-1$ tokens preceding it consecutively overlap with some neighboring chunk in $\mathrm{Ret}(C_{u(i)-1})$. For example, the bucket $B_1$ contains all tokens $x_i$ for which the unigram $x_i$ appears in some neighbor, but not the bigram $x_{i-1}x_i$; the bucket $B_2$ contains all $x_i$ for which $x_{i-1}x_i$ overlaps but not $x_{i-2}x_{i-1}x_i$, and so on. As a special case, the bucket $B_0$ contains all tokens that do not overlap with any of their neighbors. This includes all tokens that occur in the first chunk $C_1$, which lacks neighbors.
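The bucket assignment can be sketched as a brute-force longest-suffix match. In this illustrative Python version (function name hypothetical, positions 0-based), `neighbours` holds the token lists $[N, F]$ retrieved for the previous chunk, and the return value is the $k$ of the bucket $B_k$. The authors' implementation may differ in detail, for example in how it handles matches that cross chunk boundaries.

```python
def overlap_bucket(tokens, i, neighbours):
    """Return k such that tokens[i] belongs to bucket B_k.

    k is the length of the longest run tokens[i-k+1 .. i] that also appears
    contiguously in some neighbour; k = 0 if tokens[i] occurs in no neighbour
    (always the case in the first chunk, whose neighbour list is empty).
    """
    best = 0
    for nb in neighbours:
        for end in range(len(nb)):
            k = 0
            # extend the match leftwards from tokens[i] against nb[end]
            while k <= min(i, end) and tokens[i - k] == nb[end - k]:
                k += 1
            best = max(best, k)
    return best


if __name__ == "__main__":
    tokens = [5, 1, 2, 3, 9]
    neighbours = [[7, 1, 2, 3, 4]]
    print([overlap_bucket(tokens, i, neighbours) for i in range(len(tokens))])
    # [0, 1, 2, 3, 0]
```

In the example, the run 1, 2, 3 is shared with the neighbour, so its tokens land in $B_1$, $B_2$ and $B_3$ respectively, while 5 and 9 fall into $B_0$.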
In Figure 2 we plot the average loss per bucket,
$$\bar{L}_k = \frac{1}{|B_k|} \sum_{x_i \in B_k} L_{\mathrm{on}}(x_i),$$
as a function of $k$. Here, $L_{\mathrm{on}}(x_i)$ is the loss when predicting token $x_i$ using Retro[on].³ We see that the loss drastically decreases as the consecutive overlap increases. For example, at an overlap of tokens, the loss is only 6% of the loss for non-overlapping tokens. This suggests that Retro enters “copy mode” when the previous tokens overlap with those from a neighbor.
³ The sizes of each bucket (accumulated over the validation data) are shown in the appendix, Figure 4.
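Computing $\bar{L}_k$ amounts to a group-by over tokens. The Python sketch below is a hypothetical helper that assumes the per-token bucket indices and Retro[on] losses are already available as parallel sequences; it also reports each bucket's mean relative to $B_0$, which is how a statement such as "6% of the loss for non-overlapping tokens" is read. The numbers in the usage example are made up and are not results from the paper.

```python
from collections import defaultdict


def mean_loss_per_bucket(buckets, losses_on):
    """Return {k: mean of L_on(x_i) over x_i in B_k}, sorted by k."""
    total, count = defaultdict(float), defaultdict(int)
    for k, loss in zip(buckets, losses_on):
        total[k] += loss
        count[k] += 1
    return {k: total[k] / count[k] for k in sorted(total)}


def relative_to_no_overlap(means):
    """Each bucket's mean loss as a fraction of the B_0 mean."""
    return {k: v / means[0] for k, v in means.items()}


if __name__ == "__main__":
    buckets = [0, 0, 1, 4, 4]  # illustrative values only
    losses = [4.0, 2.0, 1.5, 0.25, 0.5]
    means = mean_loss_per_bucket(buckets, losses)
    print(means)                          # {0: 3.0, 1: 1.5, 4: 0.375}
    print(relative_to_no_overlap(means))  # {0: 1.0, 1: 0.5, 4: 0.125}
```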
3.2 Loss Reductions per Degree of Overlap
For a more detailed analysis of the effect of overlap on predictive performance, we look at the token-specific loss differences between the two configurations Retro[off] and Retro[on]:
$$\Delta(x_i) = L_{\mathrm{off}}(x_i) - L_{\mathrm{on}}(x_i)$$
Note that a loss difference is positive if access to the retrieved context reduces the token-specific loss for $x_i$. The overall reduction in loss visible in Figure 1 is the average of the loss differences across all tokens in the validation data. By aggregating loss differences per bucket $B_k$, we get a fine-grained picture of how the reductions are distributed with respect to different degrees of consecutive overlap. This is illustrated in Figure 3.
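The aggregation behind Figure 3 can be sketched in the same style. This hypothetical helper accumulates $\Delta(x_i)$ per bucket and keeps positive and negative contributions apart, so that a bucket with a small negative net (as reported for $B_0$) remains visible; dividing the sum of all nets by the number of tokens recovers the overall mean loss difference. Inputs are assumed to be parallel per-token sequences, and the example values are illustrative only.

```python
from collections import defaultdict


def loss_differences_per_bucket(buckets, losses_off, losses_on):
    """Return {k: (positive sum, negative sum, net sum)} of Delta = L_off - L_on."""
    pos, neg = defaultdict(float), defaultdict(float)
    for k, off, on in zip(buckets, losses_off, losses_on):
        delta = off - on  # > 0 when retrieval lowers the token's loss
        if delta >= 0:
            pos[k] += delta
        else:
            neg[k] += delta
    keys = sorted(set(pos) | set(neg))
    return {k: (pos[k], neg[k], pos[k] + neg[k]) for k in keys}


if __name__ == "__main__":
    buckets = [0, 0, 2, 2]
    off = [3.0, 2.0, 1.5, 1.0]
    on = [2.5, 2.75, 0.5, 0.25]
    per_bucket = loss_differences_per_bucket(buckets, off, on)
    print(per_bucket)  # {0: (0.5, -0.75, -0.25), 2: (1.75, 0.0, 1.75)}
    mean_delta = sum(net for _, _, net in per_bucket.values()) / len(buckets)
    print(mean_delta)  # 0.375, equal to mean(off) - mean(on)
```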
For non-overlapping tokens ($B_0$), we can see that there are both positive and negative differences, with a small negative net. For all overlapping tokens ($B_k$ with $k \geq 1$), the net differences are positive, and for buckets with 3 or more overlapping tokens, there are almost no negative differences at all.⁴ This shows that the largest share of all loss reductions originates from tokens that consecutively overlap with neighbors. Interestingly, the net differences are positive even for very small degrees of overlap. Borgeaud et al. (2022) considered reductions in bits-per-byte from chunks with up to 8 consecutively overlapping tokens as evidence of a non-trivial generalization capacity. However, our results suggest that even a small number of overlapping tokens may cause a large reduction in loss, which we take as an argument against this conclusion.
⁴ We note a sudden increase in accumulated loss difference for $k > 64$, which is expected considering the way in which we construct the buckets; see Appendix C for more details.
4 Related Work
Equipping language models with a retrievable external memory has been extensively studied (Guu et al., 2020; Karpukhin et al., 2020; Lewis et al., 2020; Izacard and Grave, 2021; Li et al., 2022). Explicitly leveraging the training data through retrieval to reduce perplexity was proposed in kNN-LM (Khandelwal et al., 2020). kNN-LM matches the leftward context with the leftward contexts of all training data tokens, and explicitly interpolates between generating and copying the next token. A recent study analyzes kNN-LM to better understand the causes of its performance gains (Xu et al., 2023). Similar to our findings for Retro, lexical overlap has been found to play a significant role in explaining retrieval performance gains in kNN-LM (Drozdov et al., 2022). Spalm (Yogatama et al., 2021) extends the idea of kNN-LM by instead learning a gating function that facilitates more dynamic interpolation.
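For contrast with Retro, the interpolation at the core of kNN-LM can be sketched in a few lines of NumPy: neighbour distances are turned into a softmax over negative distances, the resulting mass is aggregated per target token, and the result is mixed with the model's distribution as $\lambda p_{\mathrm{kNN}} + (1-\lambda) p_{\mathrm{LM}}$. This follows the formulation of Khandelwal et al. (2020) as we understand it; the function name and the default `lam` are hypothetical (in kNN-LM, $\lambda$ is tuned on validation data). Spalm replaces the fixed $\lambda$ with a learned gate, which is not shown here.

```python
import numpy as np


def knn_lm_distribution(p_lm, neighbour_targets, neighbour_dists, lam=0.25):
    """Interpolate a model distribution with a kNN distribution over next tokens.

    p_lm: (vocab,) model probabilities for the next token.
    neighbour_targets: token that followed each retrieved training context.
    neighbour_dists: distance between the query context and each retrieved context.
    """
    p_lm = np.asarray(p_lm, dtype=float)
    d = np.asarray(neighbour_dists, dtype=float)
    w = np.exp(-(d - d.min()))  # softmax over negative distances (shifted for stability)
    w /= w.sum()
    p_knn = np.zeros_like(p_lm)
    np.add.at(p_knn, np.asarray(neighbour_targets), w)  # sum weights per target token
    return lam * p_knn + (1.0 - lam) * p_lm


if __name__ == "__main__":
    p = knn_lm_distribution(np.full(4, 0.25), [2, 2, 1], [1.0, 1.0, 50.0], lam=0.5)
    print(p.round(3))  # token 2, copied from two close neighbours, dominates
```

Because the copy happens directly in output space, a verbatim match in the training data translates straight into probability mass on the matching token, which is the surface-level behaviour discussed in the next paragraph.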
In both kNN-LM and Spalm, retrieval is incorporated at the top of the network. This might induce a bias towards surface-level rather than semantic augmentation. In contrast, retrieval in Retro is incorporated in lower layers of the network, which allows for more sophisticated integration of the retrieved information. Our results suggest, however, that retrieval in Retro also contributes at the surface level rather than at the semantic level, similar to previous work.
5 Conclusions and Future Work
The capacity of language models for generalization is often measured intrinsically using perplexity, loss, or bits-per-byte on held-out validation data. Low-perplexity language models perform well as few-shot learners on many downstream tasks due to their capacity to both memorize and non-trivially combine textual information from many sources (Brown et al., 2020; Rae et al., 2021; Lieber et al., 2021; Chowdhery et al., 2022). The hope is that we can externalize memory to reduce the footprint of language models without reducing generalization and downstream task performance.
Our results show that the low loss in Retro almost exclusively originates from tokens overlapping between retrieval and validation data, rather than from more sophisticated generalization. To better understand this effect, it would be interesting to modify the retrieval component to deliver semantically similar but lexically different context during training. If the retrieved context is uninformative, the model will learn to ignore it, but if the context is too specific (e.g., literal overlap), the model will learn to copy. By better balancing these two modes, models may become better at utilizing retrieved information at a deeper and more generalizable level.
Limitations
We have made our best effort in trying to reproduce the model and results of Borgeaud et al. (2022). Nonetheless, our experiments were performed on one of the smaller model sizes and with a dataset that is only 2.5% of their size (52 billion vs. 2 trillion tokens). This was due to computational constraints and lack of larger open datasets. However, as was also shown by Borgeaud et al. (2022), the performance gain of retrieval is constant with respect to model size. We speculate that larger Retro models mostly improve with respect to loss on tokens that are not overlapping, which would not change our conclusions here.
One noteworthy limitation of our work is the fact that we compare to a non-retrieval baseline (Retro[off]) that was trained with access to retrieved context. We were not able to train a separate non-retrieval baseline due to computational constraints, but note that the bits-per-byte results of Retro[off] and the baseline in Borgeaud et al. (2022) were close to identical.
Acknowledgements
This work was partially supported by the Wallenberg AI, Autonomous Systems and Software Program (WASP) funded by the Knut and Alice Wallenberg Foundation. The computations were enabled by resources provided by the National Academic Infrastructure for Supercomputing in Sweden (NAISS) at Alvis partially funded by the Swedish Research Council through grant agreement no. 2022-06725, and by the Berzelius resources provided by the Knut and Alice Wallenberg Foundation at the National Supercomputer Centre.
Appendix A MassiveOpenText statistics
Statistics on the number of documents, chunks and tokens for each split and text category are shown in Table 1.
Appendix B Retro model details
We show hyperparameters of our Retro model in Table 2.
Appendix C Consecutively overlapping tokens
As explained in Section 3.1, we sort validation set tokens into buckets, denoted $B_k$, according to the length $k$ of the longest overlapping leftward context.
In Figure 4 we show the number of tokens in each bucket. We note a big “jump” from $B_{64}$ to $B_{65}$, which can be explained by the following rationale. A neighbor $[N, F]$ of a chunk $C_{u-1}$ is retrieved based on the similarity between $E(C_{u-1})$ and $E(N)$. In the case where both $N = C_{u-1}$ and $F = C_u$, the tokens in $C_u$ will be put into $B_k$ with $k > 64$. The jump in Figure 4 indicates that such duplicates are common in our data.
Appendix D Model validation
As we aim to reproduce the 425M model trained in Borgeaud et al. (2022), it is important to validate that the implementations are equivalent and that their evaluation results are comparable. However, Borgeaud et al. (2022) do not report evaluations of the 425M model on the Pile, which makes direct comparisons hard. They do report evaluation results on the C4 dataset (Raffel et al., 2020) with various sizes of retrieval datasets. For their setup with 36B retrieval tokens, which is the most similar to our own retrieval size, they report that bits-per-byte is reduced by 2% (from 0.92 to 0.90) when using retrieval. This can be compared with our results on Pile-CC, as both datasets originate from Common Crawl. In our experiments, loss is reduced by 7% (from 3.05 to 2.83) on Pile-CC.
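The percentages above are relative reductions, (without − with) / without. The short Python check below reproduces them from the reported values and also illustrates one assumption worth stating: bits-per-byte is the summed token loss converted to bits and divided by the byte count, so on the same text with the same tokenizer a relative reduction in mean loss equals the same relative reduction in bpb. Cross-dataset comparisons such as C4 against Pile-CC remain only indicative, because text, tokenizer, and retrieval set all differ. The token and byte counts used in the check are arbitrary.

```python
import math


def relative_reduction(without_retrieval, with_retrieval):
    """Fractional decrease from the no-retrieval value to the retrieval value."""
    return (without_retrieval - with_retrieval) / without_retrieval


def loss_to_bpb(mean_loss_nats, num_tokens, num_bytes):
    """Mean per-token cross-entropy in nats -> bits per byte of raw text."""
    return mean_loss_nats * num_tokens / (num_bytes * math.log(2))


c4_bpb = relative_reduction(0.92, 0.90)        # Borgeaud et al. (2022), C4, 36B retrieval tokens
pile_cc_loss = relative_reduction(3.05, 2.83)  # this paper, Pile-CC
print(f"C4: {c4_bpb:.1%}, Pile-CC: {pile_cc_loss:.1%}")  # C4: 2.2%, Pile-CC: 7.2%

# Same text and tokenizer: converting both losses to bpb leaves the ratio unchanged.
tokens, num_bytes = 1_000, 4_300  # arbitrary illustrative counts
bpb_gain = relative_reduction(loss_to_bpb(3.05, tokens, num_bytes),
                              loss_to_bpb(2.83, tokens, num_bytes))
print(math.isclose(bpb_gain, pile_cc_loss))  # True
```

The printed values agree with the rounded 2% and 7% quoted above.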
Evaluations on the Pile in Borgeaud et al. (2022) are only reported for their largest model (7B params) and largest retrieval set (2T tokens). For example, on Pile–GitHub their reduction is 53% whereas our reduction is 42%.
While these numbers are not directly comparable, we believe they indicate that our reimplementation of the Retro model is working as expected.
- Sebastian Borgeaud, Arthur Mensch, Jordan Hoffmann, Trevor Cai, Eliza Rutherford, Katie Millican, George Bm Van Den Driessche, Jean-Baptiste Lespiau, Bogdan Damoc, Aidan Clark, Diego De Las Casas, Aurelia Guy, Jacob Menick, Roman Ring, Tom Hennigan, Saffron Huang, Loren Maggiore, Chris Jones, Albin Cassirer, Andy Brock, Michela Paganini, Geoffrey Irving, Oriol Vinyals, Simon Osindero, Karen Simonyan, Jack Rae, Erich Elsen, and Laurent Sifre. 2022. Improving language
- Tom Brown, Benjamin Mann, Nick Ryder, Melanie Subbiah, Jared D Kaplan, Prafulla Dhariwal, Arvind Neelakantan, Pranav Shyam, Girish Sastry, Amanda Askell, Sandhini Agarwal, Ariel Herbert-Voss, Gretchen Krueger, Tom Henighan, Rewon Child, Aditya Ramesh, Daniel Ziegler, Jeffrey Wu, Clemens Winter, Chris Hesse, Mark Chen, Eric Sigler, Mateusz Litwin, Scott Gray, Benjamin Chess, Jack Clark, Christopher Berner, Sam McCandlish, Alec Radford, Ilya Sutskever, and Dario Amodei. 2020
- Aakanksha Chowdhery, Sharan Narang, Jacob Devlin, Maarten Bosma, Gaurav Mishra, Adam Roberts, Paul Barham, Hyung Won Chung, Charles Sutton, Sebastian Gehrmann, Parker Schuh, Kensen Shi, Sasha Tsvyashchenko, Joshua Maynez, Abhishek Rao, Parker Barnes, Yi Tay, Noam Shazeer, Vinodkumar Prabhakaran, Emily Reif, Nan Du, Ben Hutchinson, Reiner Pope, James Bradbury, Jacob Austin, Michael Isard, Guy Gur-Ari, Pengcheng Yin, Toju Duke, Anselm Levskaya, Sanjay Ghemawat, Sunipa De
- Andrew Drozdov, Shufan Wang, Razieh Rahimi, Andrew McCallum, Hamed Zamani, and Mohit Iyyer. 2022. You can’t pick your neighbors, or can you? When and how to rely on retrieval in the kNN-LM. arXiv preprint arXiv:2210.15859.
- Leo Gao, Stella Biderman, Sid Black, Laurence Golding, Travis Hoppe, Charles Foster, Jason Phang, Horace He, Anish Thite, Noa Nabeshima, Shawn Presser, and Connor Leahy. 2020. The Pile: An 800GB dataset of diverse text for language modeling. arXiv preprint arXiv:2101.00027.
- Kelvin Guu, Kenton Lee, Zora Tung, Panupong Pasupat, and Ming-Wei Chang. 2020. REALM: Retrieval-augmented language model pre-training. In International Conference on Machine Learning, pages 3929–3938. PMLR.
- Gautier Izacard and Edouard Grave. 2021. Leveraging passage retrieval with generative models for open domain question answering. In Proceedings of the 16th Conference of the European Chapter of the Association for Computational Linguistics: Main Volume, pages 874–880, Online. Association for Computational Linguistics.
- Jeff Johnson, Matthijs Douze, and Hervé Jégou. 2019. Billion-scale similarity search with GPUs. IEEE Transactions on Big Data, 7(3):535–547.
