Toward generalizable learning of all (linear) first-order methods via memory augmented Transformers
Sanchayan Dutta (UC Davis), Suvrit Sra (TU Munich)

TL;DR
This paper demonstrates that memory-augmented Transformers can learn and implement the entire class of linear first-order optimization methods (LFOMs), including advanced algorithms such as conjugate gradient descent. It also explores how the learned algorithms adapt to out-of-distribution inputs and how LFOM parameters can themselves be learned from data.
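To make the class concrete: an LFOM moves the iterate by a linear combination of all gradients computed so far, x_{k+1} = x_k - sum_{j<=k} gamma_{k,j} * grad f(x_j). The sketch below is a minimal illustration that assumes scalar coefficients gamma_{k,j} and a least-squares (linear regression) objective. The paper's exact parameterization may differ, for example by using matrix-valued coefficients, and the names lfom and heavy_ball are hypothetical. The sketch checks that gradient descent (gamma_{k,k} = eta, all other coefficients zero) and heavy-ball momentum (gamma_{k,j} = eta * beta^(k-j)) are both instances of the same template.

```python
import numpy as np

# Sketch under stated assumptions: scalar coefficients and a least-squares
# objective. This is not the paper's parameterization.


def lfom(grad, x0, coeffs, num_steps):
    """Generic LFOM: x_{k+1} = x_k - sum_{j<=k} coeffs(k, j) * grad(x_j)."""
    x = np.asarray(x0, dtype=float)
    past_grads = []
    for k in range(num_steps):
        past_grads.append(grad(x))
        # The step is a linear combination of every gradient seen so far.
        step = sum(coeffs(k, j) * g for j, g in enumerate(past_grads))
        x = x - step
    return x


def heavy_ball(grad, x0, eta, beta, num_steps):
    """Polyak heavy ball: x_{k+1} = x_k - eta*grad(x_k) + beta*(x_k - x_{k-1})."""
    x_prev = x = np.asarray(x0, dtype=float)
    for _ in range(num_steps):
        x, x_prev = x - eta * grad(x) + beta * (x - x_prev), x
    return x


# Toy linear-regression problem: f(w) = ||X w - y||^2 / (2n).
rng = np.random.default_rng(0)
X = rng.standard_normal((20, 3))
y = X @ np.array([1.0, -2.0, 0.5])


def grad(w):
    return X.T @ (X @ w - y) / len(y)


eta, beta, steps = 0.1, 0.5, 25
w0 = np.zeros(3)

# GD: only the newest gradient gets weight eta.
w_gd = lfom(grad, w0, lambda k, j: eta if j == k else 0.0, steps)
# Momentum: gradient j gets weight eta * beta^(k - j) at step k.
w_mom_lfom = lfom(grad, w0, lambda k, j: eta * beta ** (k - j), steps)
w_mom = heavy_ball(grad, w0, eta, beta, steps)

print(np.allclose(w_mom_lfom, w_mom))  # True
```

Unrolling the momentum recursion gives x_{k+1} - x_k = -eta * sum_{j<=k} beta^(k-j) * grad f(x_j), which is why the geometric coefficients reproduce heavy ball exactly.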
Contribution
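It introduces a framework in which memory-augmented Transformers learn to implement all linear first-order methods (LFOMs). It then extends the learned algorithms to out-of-distribution inputs through a mixture-of-experts approach for test-time adaptation, and treats LFOMs themselves as learnable algorithms.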
Findings
Memory-augmented Transformers can implement all LFOMs, including gradient descent (GD), conjugate gradient descent (CGD), and momentum methods.
Memory-augmented Transformers outperform traditional methods on certain tasks.
LFOMs can be learned from data to improve performance.
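Conjugate gradient descent appears among the LFOMs because each CG search direction is the current negative gradient plus a multiple of the previous direction. Every iterate is therefore the starting point plus a linear combination of the gradients visited so far. The sketch below is a minimal illustration. It assumes a least-squares (linear regression) objective and the Fletcher-Reeves formula for beta, which coincides with other CG variants on exact quadratics, and cg_least_squares is a hypothetical helper. CG computes its coefficients on the fly from the gradients, so this illustrates the LFOM template rather than the paper's parameterization.

```python
import numpy as np

# Sketch under stated assumptions: CG on a least-squares problem, with its
# iterate rewritten in LFOM form. Not the paper's implementation.


def cg_least_squares(X, y, w0, num_steps):
    """Conjugate gradient on f(w) = ||X w - y||^2 / (2n).

    Returns the final iterate, the gradients g_0, g_1, ... it computed, and
    coefficients gammas such that w_final = w0 + sum_j gammas[j] * g_j.
    """
    n = len(y)
    A, b = X.T @ X / n, X.T @ y / n      # grad f(w) = A w - b
    w = np.asarray(w0, dtype=float)
    g = A @ w - b
    p = -g                               # first direction: steepest descent
    grads = [g]
    dir_coeffs = np.array([-1.0])        # p = sum_j dir_coeffs[j] * grads[j]
    gammas = np.zeros(1)                 # w - w0 = sum_j gammas[j] * grads[j]
    for _ in range(num_steps):
        if g @ g < 1e-30:                # gradient vanished: at the minimizer
            break
        Ap = A @ p
        alpha = (g @ g) / (p @ Ap)       # exact line search along p
        w = w + alpha * p
        gammas = gammas + alpha * dir_coeffs
        g_new = g + alpha * Ap           # gradient at the new iterate
        beta = (g_new @ g_new) / (g @ g)  # Fletcher-Reeves coefficient
        p = -g_new + beta * p
        # The new direction is again a linear combination of past gradients.
        dir_coeffs = np.append(beta * dir_coeffs, -1.0)
        gammas = np.append(gammas, 0.0)
        grads.append(g_new)
        g = g_new
    return w, grads, gammas


rng = np.random.default_rng(0)
X = rng.standard_normal((20, 3))
y = X @ np.array([1.0, -2.0, 0.5]) + 0.1 * rng.standard_normal(20)
w0 = np.zeros(3)

w_cg, grads, gammas = cg_least_squares(X, y, w0, num_steps=3)
w_lfom = w0 + sum(c * g for c, g in zip(gammas, grads))  # LFOM form
w_ls = np.linalg.lstsq(X, y, rcond=None)[0]             # reference solution
```

After three steps on this three-dimensional problem, w_cg should match the least-squares solution w_ls up to rounding, as CG theory predicts for a 3-by-3 positive definite system, and w_lfom should reproduce w_cg from the recorded gradients alone.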
Abstract
We show that memory-augmented Transformers can implement the entire class of linear first-order methods (LFOMs), a class that contains gradient descent (GD) and more advanced methods such as conjugate gradient descent (CGD), momentum methods and all other variants that linearly combine past gradients. Building on prior work that studies how Transformers simulate GD, we provide theoretical and empirical evidence that memory-augmented Transformers can learn more advanced algorithms. We then take a first step toward turning the learned algorithms into actually usable methods by developing a mixture-of-experts (MoE) approach for test-time adaptation to out-of-distribution (OOD) samples. Lastly, we show that LFOMs can themselves be treated as learnable algorithms, whose parameters can be learned from data to attain strong performance.
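The abstract's final claim is that LFOM parameters can be learned from data. This page does not describe the paper's training setup, so the sketch below swaps in a deliberately simple stand-in: a 3-step LFOM with fixed scalar coefficients gamma[k, j]. It starts from plain gradient descent, is tuned on random linear-regression tasks by central finite differences with backtracking, and is then evaluated on held-out tasks from the same distribution. All names and hyperparameters (run_lfom, meta_loss, sample_tasks, learn_coefficients, K, lr, eps) are hypothetical.

```python
import numpy as np

# Hypothetical stand-in for learning LFOM parameters from data; this is not
# the paper's training procedure.


def run_lfom(gamma, X, y):
    """K-step LFOM from w = 0 with fixed scalar coefficients gamma[k, j], j <= k."""
    n, d = X.shape
    w = np.zeros(d)
    grads = []
    for k in range(gamma.shape[0]):
        grads.append(X.T @ (X @ w - y) / n)  # gradient of ||X w - y||^2 / (2n)
        w = w - sum(gamma[k, j] * g for j, g in enumerate(grads))
    return w


def meta_loss(gamma, tasks):
    """Average least-squares loss after K steps, over a set of regression tasks."""
    return np.mean([0.5 * np.mean((X @ run_lfom(gamma, X, y) - y) ** 2)
                    for X, y in tasks])


def sample_tasks(rng, num_tasks, n=20, d=5):
    """Noise-free linear-regression tasks with random Gaussian weights."""
    tasks = []
    for _ in range(num_tasks):
        X = rng.standard_normal((n, d))
        tasks.append((X, X @ rng.standard_normal(d)))
    return tasks


def learn_coefficients(tasks, K=3, eta0=0.1, lr=0.05, iters=50, eps=1e-5):
    """Fit gamma by central finite differences plus backtracking, starting from GD."""
    gamma = np.diag(np.full(K, eta0))
    lower = list(zip(*np.tril_indices(K)))  # only entries with j <= k are used
    best = meta_loss(gamma, tasks)
    for _ in range(iters):
        grad = np.zeros_like(gamma)
        for idx in lower:
            bump = np.zeros_like(gamma)
            bump[idx] = eps
            grad[idx] = (meta_loss(gamma + bump, tasks)
                         - meta_loss(gamma - bump, tasks)) / (2 * eps)
        step = lr
        while step > 1e-8:  # halve the step until the training loss drops
            cand = gamma - step * grad
            cand_loss = meta_loss(cand, tasks)
            if cand_loss < best:
                gamma, best = cand, cand_loss
                break
            step /= 2
    return gamma, best


rng = np.random.default_rng(0)
train, held_out = sample_tasks(rng, 16), sample_tasks(rng, 16)
gd = np.diag(np.full(3, 0.1))  # plain GD with step size 0.1, as an LFOM
gamma, train_loss = learn_coefficients(train)
print("GD held-out loss:     ", meta_loss(gd, held_out))
print("learned held-out loss:", meta_loss(gamma, held_out))
```

Backtracking keeps the training loss non-increasing, so the learned coefficients never do worse than the GD starting point on the training tasks. Whether they also help on held-out tasks depends on how closely those tasks match the training distribution.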
Taxonomy
Topics: Neural Networks and Applications
Methods: Linear Regression
