When can transformers reason with abstract symbols?
Enric Boix-Adsera, Omid Saremi, Emmanuel Abbe, Samy Bengio, Etai Littwin, and Joshua Susskind

TL;DR
This paper demonstrates that transformers can learn and generalize abstract relational reasoning tasks when trained with sufficient data, outperforming classical fully-connected networks, and introduces minimal architectural modifications that improve data efficiency.
Contribution
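It proves that transformers trained by gradient descent learn abstract relations and generalize to unseen symbols across a large family of reasoning tasks, and proposes simple architectural modifications (two extra trainable parameters per attention head) that improve data efficiency.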
Findings
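Transformers learn relational reasoning and generalize it to symbols unseen during training, given enough training data.
Classical fully-connected networks provably fail to learn these relational reasoning tasks.
Minimal architectural changes (two extra trainable parameters per head) empirically improve data efficiency in transformers.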
Abstract
We investigate the capabilities of transformer models on relational reasoning tasks. In these tasks, models are trained on a set of strings encoding abstract relations, and are then tested out-of-distribution on data that contains symbols that did not appear in the training dataset. We prove that for any relational reasoning task in a large family of tasks, transformers learn the abstract relations and generalize to the test set when trained by gradient descent on sufficiently large quantities of training data. This is in contrast to classical fully-connected networks, which we prove fail to learn to reason. Our results inspire modifications of the transformer architecture that add only two trainable parameters per head, and that we empirically demonstrate improve data efficiency for learning to reason.
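To make the train/test setup concrete, here is a minimal Python sketch of a toy template task. It assumes, for illustration only, two hypothetical templates ("aba" labeled 1, "abb" labeled 0) in which each distinct wildcard is replaced by a distinct symbol; the template set, symbol names, and pool sizes are not taken from the paper. Test strings are drawn from a symbol pool disjoint from the training pool, so a model can succeed only by using the relation between positions (here, whether the first and third tokens match) rather than the identity of any symbol.

```python
import random

# Hypothetical toy templates (not from the paper): a wildcard pattern and a
# binary label. Distinct wildcard letters are filled with distinct symbols.
TEMPLATES = [("aba", 1), ("abb", 0)]


def instantiate(template, symbols, rng):
    """Replace each distinct wildcard in `template` with a distinct symbol."""
    wildcards = sorted(set(template))
    chosen = rng.sample(symbols, len(wildcards))  # sampled without replacement
    mapping = dict(zip(wildcards, chosen))
    return [mapping[w] for w in template]


def make_split(symbols, n, rng):
    """Draw n (string, label) pairs using only the given symbol pool."""
    data = []
    for _ in range(n):
        template, label = rng.choice(TEMPLATES)
        data.append((instantiate(template, symbols, rng), label))
    return data


rng = random.Random(0)
# Disjoint symbol pools: every test symbol is absent from the training data.
train_symbols = [f"s{i}" for i in range(100)]
test_symbols = [f"s{i}" for i in range(100, 150)]

train = make_split(train_symbols, 1000, rng)
test = make_split(test_symbols, 200, rng)
```

Because every test symbol is new, memorized symbol identities carry no signal at test time; the label is recoverable only from which positions hold equal tokens.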
Peer Reviews
Decision: ICLR 2024 poster
This paper studies an important problem of generalization to unseen symbols in transformer architectures. Much of the contribution is formalizing a set of symbolic reasoning tasks which can be analyzed in different architectures. The authors give a theoretical analysis that suggests implementation tweaks (adding the identity matrices) which improve performance on real-valued template tasks and preserve good performance on natural language modeling. They also give an important separation result b…
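The abstract says the architectural modification adds two trainable parameters per head, and this review describes it as adding identity matrices. Below is a minimal NumPy sketch of one such head, assuming the two parameters are scalars `a` and `b` that scale identity matrices added to the query-key product W_Q W_K^T and the value-output product W_V W_O. The function names, shapes, and the 1/sqrt(k) scaling are illustrative assumptions, not the authors' code.

```python
import numpy as np


def softmax(z, axis=-1):
    """Softmax along `axis`, shifted by the max for numerical stability."""
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def attention_head(X, W_Q, W_K, W_V, W_O, a=0.0, b=0.0):
    """One self-attention head with identity terms (hypothetical sketch).

    X:        (n, d) token embeddings
    W_Q, W_K: (d, k) query and key projections
    W_V:      (d, k) value projection; W_O: (k, d) output projection
    a, b:     the two extra trainable scalars (assumed parametrization);
              a = b = 0 recovers a standard head.
    """
    d, k = W_Q.shape
    I = np.eye(d)
    # Scores X (W_Q W_K^T + a I) X^T: the a*I term adds a * <x_i, x_j>,
    # so a larger a pushes attention toward tokens with similar embeddings.
    scores = X @ (W_Q @ W_K.T + a * I) @ X.T / np.sqrt(k)
    A = softmax(scores, axis=-1)
    # Values X (W_V W_O + b I): the b*I term passes embeddings through.
    return A @ X @ (W_V @ W_O + b * I)
```

With `a = b = 0` this is an ordinary head. When embeddings are near-orthogonal, a large `a` makes each token attend to positions holding the same embedding, which is the kind of equality test a template task needs, and `b` lets the head copy those embeddings forward without learning a full identity map in W_V W_O.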
Though the paper provides some generic results about architectures for MLPs, the main result for transformers relies on operating the model in the kernel regime. It is unclear at the moment if any of the results would change (1) at finite width or (2) in the feature learning regime (see questions below). Further, the proposed theory bounds the generalization error in terms of the data diversity metric, which does not have an obvious scaling with samples n, making it hard to reason about the numb…
As noted in the summary, I find the targeted problem valuable and exciting. The theoretical framework itself involves several assumptions that are arguably impractical (making this essentially a kernel regression take on LLMs paper), but nonetheless the tasks established earlier in the paper are very well described and the results are accordingly insightful.
1. I think the authors are over-claiming their results at several points. The simplified architecture that is studied in this work, i.e., a self-attention + MLP model, is not justified to be labeled a Transformer. Residual connections play a huge role in both optimization stability and, e.g., some of the seemingly emergent abilities of Transformers. In this sense, I would argue the paper is actually focused on understanding inductive biases and limitations of a feedforward, MLP model with self-a…
The paper was a great read! The authors conduct a thorough theoretical analysis with practical insights suggesting changes to the transformer architecture. The work studies whether transformers are capable of symbol binding, which is a very important question of broad interest to the community. The paper presents a diverse set of results, but makes them easy to understand for a broader audience too. The template tasks are elegant and easy to understand. It grounds the problem in a concrete scenario…
The theory for the transformers seems to rely on training only the final layer of the transformer. Does that mean that there are other architectures that can also do symbol binding? It is unclear which aspect of transformers is important for symbol binding: is it the attention, the MLP, the invariance properties, or something else? It isn't entirely clear to me how… A limitation of most theory that uses the PAC framework is that it often yields vacuous bounds. While the paper derives asymptotic trends, it is un…
Taxonomy
Topics: Natural Language Processing Techniques · Advanced Text Analysis Techniques
