Trainable Transformer in Transformer
Abhishek Panigrahi, Sadhika Malladi, Mengzhou Xia, Sanjeev Arora

TL;DR
This paper introduces TinT, a novel transformer architecture that efficiently simulates and fine-tunes complex models internally during inference, enabling improved performance on language tasks with fewer parameters.
Contribution
The paper proposes the Transformer in Transformer (TinT), a new method allowing transformers to internally simulate and fine-tune complex models efficiently during inference.
Findings
A TinT model with fewer than 2B parameters can simulate and fine-tune a 125M-parameter transformer within a single forward pass.
TinT improves OPT-125M performance by 4-16% with limited inference steps.
The approach is compatible with various transformer variants.
Abstract
Recent works attribute the capability of in-context learning (ICL) in large pre-trained language models to implicitly simulating and fine-tuning an internal model (e.g., linear or 2-layer MLP) during inference. However, such constructions require large memory overhead, which makes simulation of more sophisticated internal models intractable. In this work, we propose an efficient construction, Transformer in Transformer (in short, TinT), that allows a transformer to simulate and fine-tune complex models internally during inference (e.g., pre-trained language models). In particular, we introduce innovative approximation techniques that allow a TinT model with less than 2 billion parameters to simulate and fine-tune a 125 million parameter transformer model within a single forward pass. TinT accommodates many common transformer variants and its design ideas also improve the efficiency of…
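The simplest case mentioned in the abstract, an internal linear model fine-tuned during inference, can be sketched in a few lines. The example below is not TinT's construction; it only illustrates the prior-work idea that one gradient step on an in-context linear regression gives the same prediction as a softmax-free (linear) attention readout. All names (`gd_step_prediction`, `linear_attention_prediction`, `lr`) are illustrative assumptions, and the internal model is assumed to start from zero weights.

```python
import numpy as np

def gd_step_prediction(X, y, x_query, lr):
    """Explicitly fine-tune a linear model w (initialised at zero) with one
    gradient step on the mean squared error, then predict on x_query."""
    n = X.shape[0]
    w = np.zeros(X.shape[1])
    grad = X.T @ (X @ w - y) / n   # gradient of (1/2n) * sum((X w - y)^2)
    w = w - lr * grad
    return float(w @ x_query)

def linear_attention_prediction(X, y, x_query, lr):
    """The same prediction written as linear attention:
    keys = rows of X, values = y, query = x_query (no softmax)."""
    n = X.shape[0]
    scores = X @ x_query           # unnormalised key-query dot products
    return float(lr / n * (scores @ y))

# Hypothetical in-context dataset: 8 examples with 4 features each.
rng = np.random.default_rng(0)
X = rng.normal(size=(8, 4))
y = X @ rng.normal(size=4)
x_query = rng.normal(size=4)

pred_gd = gd_step_prediction(X, y, x_query, lr=0.1)
pred_attn = linear_attention_prediction(X, y, x_query, lr=0.1)
assert np.isclose(pred_gd, pred_attn)
```

The two functions agree because, starting from zero weights, the updated weight vector is `lr / n * X.T @ y`, so its dot product with the query is a weighted sum of the labels with weights given by key-query dot products. Simulating a full pre-trained transformer this way is what the abstract describes as intractable for earlier constructions.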
Peer Reviews
Decision: ICML 2024 Poster
1. New construction for simulating a backward pass within a model is provided, using fewer model parameters than prior constructions. 2. Experiments on more realistic language tasks are provided. 3. The idea of TinT might be useful in designing new language models for other applications.
1. The writing could be improved in some places. Two examples: * In Definition 2.1, what are the "relevant" auxiliary model weights? The current definition is a bit difficult for me to interpret. * In Definition 2.3, do the $p_t$'s refer to positional embeddings? Could you explain why there are no positional embeddings in Definition 2.10? 2. Theorem 2.5 shows linear attention could be approximated by softmax attention. Can softmax attention also be approximated by linear attention? If not,
1. The paper details a novel transformer construction that enables forward and backward operations, as well as parameter updates, within a single inference pass, and without the need for weight adjustments in the TINT model. This method surpasses previous approaches in terms of parameter efficiency and the complexity of models it can simulate. 2. The TINT architecture's efficiency is corroborated through real-data evaluations.
1. The TINT model requires access to auxiliary model weights, which it uses in prefix embeddings. This dependency differs from some previous works where the transformer independently and implicitly learns the auxiliary model's weights (and architecture) from the provided in-context dataset. 2. Despite its parameter efficiency, the TINT model's reliance on prefix embeddings to access auxiliary weights may lead to longer input sequences, which could potentially reduce computational and memory effic
1. Using a transformer to perform in-context-learning to fine-tune a transformer sounds like a quite fancy idea. 2. The authors performed experiments and demonstrated the effectiveness of TinT.
The writing of this paper is not completely clear, which makes it hard to understand the exact architecture of TinT. More specifically: 1. In Figure 1, it is not clear what the dimensions of $V_k$, $e_i$, and $\partial y_j$ are. Are there multiple back-propagation and gradient update steps in the TinT forward pass? 2. What is the meaning of the notation $\partial$? It appears many times in Sections 2.5 and 2.6, but I don't understand the equations that contain this symbol. 3. What are the parameters of th
Taxonomy
Topics: Topic Modeling · Natural Language Processing Techniques · Speech Recognition and Synthesis
Methods: Multi-Head Attention · Attention Is All You Need · Layer Normalization · Absolute Position Encodings · Byte Pair Encoding · Linear Layer · Label Smoothing · Adam · Position-Wise Feed-Forward Layer · Residual Connection
