Tree Cross Attention
Leo Feng, Frederick Tung, Hossein Hajimirsadeghi, Yoshua Bengio, Mohamed Osama Ahmed

TL;DR
Tree Cross Attention (TCA) organizes context tokens in a tree so that each prediction retrieves information from only a logarithmic number of tokens. This makes inference more efficient while performing comparably to Cross Attention on classification and regression tasks.
Contribution
The paper proposes Tree Cross Attention, a novel Cross Attention-based module that reduces the number of tokens retrieved per prediction from linear to logarithmic in the number of context tokens, making token-based inference more efficient.
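The retrieval idea can be sketched in a few lines of NumPy. This is a minimal illustration under simplifying assumptions, not the authors' implementation: tokens are split in half by position, each internal node summarizes its subtree by mean pooling, and a greedy dot-product rule stands in for TCA's learned expansion policy. During the descent, the summary of each sibling that is not expanded is kept, so the query attends to one vector per tree level plus the final leaf, a logarithmic number of vectors in total. The helper names `build_tree`, `retrieve`, and `cross_attend` are hypothetical.

```python
import numpy as np


def build_tree(tokens):
    """Build a balanced binary tree over the rows of `tokens` (shape (N, d))."""
    if len(tokens) == 1:
        # Leaf: a single token is its own summary.
        return {"summary": tokens[0], "children": None}
    mid = len(tokens) // 2
    return {
        # Assumption: mean pooling stands in for a learned aggregation function.
        "summary": tokens.mean(axis=0),
        "children": (build_tree(tokens[:mid]), build_tree(tokens[mid:])),
    }


def retrieve(root, query):
    """Greedily descend from the root, keeping each unexpanded sibling's summary.

    Assumption: a dot-product score replaces the learned expansion policy.
    Returns one vector per level plus the reached leaf, shape (depth + 1, d).
    """
    selected = []
    node = root
    while node["children"] is not None:
        left, right = node["children"]
        if query @ left["summary"] >= query @ right["summary"]:
            node, other = left, right
        else:
            node, other = right, left
        selected.append(other["summary"])  # summary of the branch not expanded
    selected.append(node["summary"])  # the leaf token itself
    return np.stack(selected)


def cross_attend(query, keys):
    """Single-query scaled dot-product attention; keys double as values here."""
    scores = keys @ query / np.sqrt(query.shape[0])
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()
    return weights @ keys
```

With 1,024 tokens, `retrieve` returns 11 vectors (10 sibling summaries and one leaf), whereas full Cross Attention would attend to all 1,024.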
Findings
TCA achieves accuracy comparable to standard Cross Attention.
ReTreever outperforms Perceiver IO when both use the same number of tokens.
TCA significantly reduces inference cost, retrieving from a logarithmic rather than a linear number of tokens per prediction.
Abstract
Cross Attention is a popular method for retrieving information from a set of context tokens for making predictions. At inference time, for each prediction, Cross Attention scans the full set of $N$ context tokens. In practice, however, often only a small subset of tokens is required for good performance. Methods such as Perceiver IO are cheap at inference as they distill the information to a smaller set of $L < N$ latent tokens on which cross attention is then applied, resulting in only $\mathcal{O}(L)$ complexity. However, in practice, as the number of input tokens and the amount of information to distill increases, the number of latent tokens needed also increases significantly. In this work, we propose Tree Cross Attention (TCA) - a module based on Cross Attention that only retrieves information from a logarithmic number of tokens for performing…
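To make the complexity contrast in the abstract concrete, the NumPy sketch below shows full Cross Attention, where every query scores all $N$ context tokens, next to a Perceiver IO-style bottleneck in which $L$ latents first cross-attend to the context once and queries then attend only to the $L$ distilled latents. It is a simplified illustration: keys are used directly as values, there are no learned projections, the function names are hypothetical, and the self-attention layers that Perceiver IO runs over its latents are omitted.

```python
import numpy as np


def softmax(x):
    """Numerically stable softmax over the last axis."""
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


def cross_attention(queries, context):
    """queries: (M, d), context: (N, d). Each query scores all N tokens."""
    d = queries.shape[-1]
    scores = queries @ context.T / np.sqrt(d)  # (M, N): O(N) work per query
    return softmax(scores) @ context


def perceiver_io_style(queries, context, latents):
    """Assumption: a single encode step, with no latent self-attention stack."""
    distilled = cross_attention(latents, context)  # (L, d), computed once
    return cross_attention(queries, distilled)     # O(L) work per query
```

The per-query cost of the second function depends only on the number of latents, which is why it is cheap at inference; the abstract's caveat is that the number of latents needed to hold enough information grows with the input.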
Peer Reviews
Decision: ICLR 2024 poster
- The paper is well written and builds the theory coherently.
- The proposed cross-attention architecture, TCA, along with the general-purpose retrieval model, ReTreever, is novel.
- Because ReTreever uses reinforcement learning to learn the internal node representations, the reward used for optimization can be non-differentiable, like accuracy, which improves performance over a reward based on cross entropy because the reward model is simpler in case of accuracy.
- The reasoning behind each of th
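The point that a reinforcement-learning objective admits a non-differentiable reward such as accuracy can be illustrated with a toy REINFORCE update. This is a minimal sketch under illustrative assumptions (a single left/right branching decision, a Bernoulli policy with one logit, a 0/1 correctness reward, and a moving-average baseline); `train_branch_policy` is a hypothetical name, and this is not ReTreever's actual training procedure. The reward itself is never differentiated; only the log-probability of the sampled action is.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def train_branch_policy(correct_action=1, steps=500, lr=0.5, seed=0):
    """Toy REINFORCE for one branching choice: 0 = left, 1 = right.

    The reward is 1.0 if the sampled action is correct, else 0.0, an
    accuracy-like signal with no gradient of its own.
    Returns the learned probability of choosing 'right'.
    """
    rng = np.random.default_rng(seed)
    logit = 0.0
    baseline = 0.0
    for _ in range(steps):
        p_right = sigmoid(logit)
        action = int(rng.random() < p_right)
        reward = float(action == correct_action)
        # Gradient of log pi(action) w.r.t. the logit for a Bernoulli(sigmoid(logit)) policy.
        grad_logp = action - p_right
        logit += lr * (reward - baseline) * grad_logp
        # Moving-average baseline to reduce the variance of the update.
        baseline = 0.9 * baseline + 0.1 * reward
    return sigmoid(logit)
```

Calling `train_branch_policy(1)` drives the probability of choosing "right" toward 1, and `train_branch_policy(0)` drives it toward 0, even though the 0/1 reward offers no gradient.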
- It would be good if a row similar to the one in Table 2 could be added to Table 1 for Perceiver IO with more latent tokens, matching the performance of TCA on the copy task.
- Theoretical complexity is fine, but the paper should also report wall-clock time for ReTreever and compare it with the full Transformer+Cross Attention and Perceiver IO models. I am guessing the tree approach is not parallelizable on accelerated devices like GPUs, but it would be good to see if there's considerable
I enjoyed reading the paper. The paper is well-written and easy to follow. The idea is simple and clever.
Overall, I believe the paper is pretty complete. I am mostly curious about how to make this method work for self-attention and (masked) autoregressive modeling. Larger-scale experiments would be appreciated, as the current experiments are quite small-scale. Presumably the challenges of training the tree expansion policy would increase with harder datasets. One suggestion for a larger-scale experiment would be training a translation or summarization model and replacing the encoder attention wit
The idea of constructing context as a tree is interesting and could have broad implications for constructing context in language model use cases, including agent trajectories, in-context learning examples, retrieved documents, and more.
- **More context on the baseline Perceiver IO**: The authors need a background section on Perceiver IO so the work is self-contained. In the current version, Perceiver IO, though a famous and well-cited paper, is not clearly described.
- **Speedup by attending to fewer context tokens**: One claimed benefit of the method is that it attends to fewer context tokens when performing the task, which I assume would result in an inference speedup. But the work does not explicitly measure if TC
Taxonomy
Topics: Explainable Artificial Intelligence (XAI) · Machine Learning and Data Classification · Topic Modeling
Methods: Perceiver IO
