M$^2$RNN: Non-Linear RNNs with Matrix-Valued States for Scalable Language Modeling
Mayank Mishra, Shawn Tan, Ion Stoica, Joseph Gonzalez, Tri Dao

TL;DR
This paper introduces M$^2$RNN, a non-linear RNN architecture with matrix-valued hidden states that improves language modeling through better state tracking and scalability, outperforming hybrid baselines while using fewer resources.
Contribution
The paper proposes M$^2$RNN with matrix-valued states and state size expansion, demonstrating improved performance and generalization in large-scale language modeling tasks.
Findings
M$^2$RNN achieves perfect state tracking at unseen sequence lengths.
Hybrid M$^2$RNN outperforms Gated DeltaNet by 0.4-0.5 perplexity points on a 7B MoE model.
Replacing a single recurrent layer with M$^2$RNN improves accuracy with minimal throughput impact.
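The finding on state tracking at unseen lengths can be made concrete with a small synthetic task. The sketch below assumes a hypothetical swap-tracking setup that is not taken from the paper's benchmark. Five items are shuffled by a sequence of pairwise position swaps, and the target at every step is the current arrangement. The functions `track_swaps` and `make_example`, and the lengths 32 and 128, are illustrative assumptions. Transpositions generate the symmetric group S5, so this is an instance of the S5 word problem. That problem is NC$^1$-complete and is therefore believed to lie outside TC$^0$. Under this framing, "perfect tracking at unseen lengths" means predicting every intermediate arrangement correctly on sequences longer than any seen in training.

```python
import random
from typing import List, Tuple

Swap = Tuple[int, int]
Arrangement = Tuple[int, ...]


def track_swaps(n: int, swaps: List[Swap]) -> List[Arrangement]:
    """Apply each swap of positions (i, j) in order; record the arrangement after every step."""
    state = list(range(n))  # state[p] is the item currently at position p
    history: List[Arrangement] = []
    for i, j in swaps:
        state[i], state[j] = state[j], state[i]
        history.append(tuple(state))
    return history


def make_example(n: int, length: int, seed: int) -> Tuple[List[Swap], List[Arrangement]]:
    """Sample `length` random swaps over n positions; return (inputs, per-step targets)."""
    rng = random.Random(seed)  # seeded for reproducible examples
    swaps: List[Swap] = []
    for _ in range(length):
        i, j = rng.sample(range(n), 2)  # two distinct positions
        swaps.append((i, j))
    return swaps, track_swaps(n, swaps)


# Hypothetical split: train on short sequences, evaluate on longer, unseen lengths.
train = [make_example(5, 32, seed=s) for s in range(1000)]
test_long = [make_example(5, 128, seed=10_000 + s) for s in range(100)]
```

A model that merely memorizes short patterns can do well on `train` and still fail on `test_long`. Only a model that actually maintains the arrangement as state can solve both.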
Abstract
Transformers are highly parallel but are limited to computations in the TC$^0$ complexity class, excluding tasks such as entity tracking and code execution that provably require greater expressive power. Motivated by this limitation, we revisit non-linear Recurrent Neural Networks (RNNs) for language modeling and introduce Matrix-to-Matrix RNN (M$^2$RNN): an architecture with matrix-valued hidden states and expressive non-linear state transitions. We demonstrate that the language modeling performance of non-linear RNNs is limited by their state size, and show how the state size expansion mechanism enables efficient use of tensor cores. Empirically, M$^2$RNN achieves perfect state tracking generalization at sequence lengths not seen during training. These benefits also translate to large-scale language modeling. In hybrid settings that interleave recurrent layers with attention, Hybrid…
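The abstract does not spell out the update rule. What follows is therefore only a minimal sketch of a non-linear recurrence over a matrix-valued state, under assumed shapes, and it should not be read as the authors' M$^2$RNN equations. The state `H` is `d_k × d_v`. Each step adds the outer product of a key and a value, mixes the previous state through a `d_v × d_v` matrix `W`, and applies an element-wise `tanh`. The update form and the names `step`, `readout`, and `run` are hypothetical.

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]
Vector = List[float]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Dense (m x n) @ (n x p) matrix product."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def step(h: Matrix, w: Matrix, k: Vector, v: Vector) -> Matrix:
    """Hypothetical update H_t = tanh(H_{t-1} @ W + k_t v_t^T), with H of shape d_k x d_v."""
    mixed = matmul(h, w)  # matrix-matrix product: (d_k x d_v) @ (d_v x d_v)
    return [[math.tanh(mixed[i][j] + k[i] * v[j]) for j in range(len(v))]
            for i in range(len(k))]


def readout(h: Matrix, q: Vector) -> Vector:
    """Output y_t = H_t^T q_t, a vector of length d_v."""
    return [sum(q[i] * h[i][j] for i in range(len(h))) for j in range(len(h[0]))]


def run(keys: List[Vector], values: List[Vector], queries: List[Vector],
        w: Matrix) -> Tuple[List[Vector], Matrix]:
    """Strictly sequential loop: the tanh on the state rules out the
    associative parallel scan available to linear RNNs."""
    d_k, d_v = len(keys[0]), len(values[0])
    h: Matrix = [[0.0] * d_v for _ in range(d_k)]  # d_k * d_v state entries
    outputs: List[Vector] = []
    for k, v, q in zip(keys, values, queries):
        h = step(h, w, k, v)
        outputs.append(readout(h, q))
    return outputs, h
```

In this sketch, each step's core cost is a matrix-matrix product (`H @ W`) rather than the matrix-vector product of a vector-state RNN. This is one way to see how a larger matrix state can map onto tensor-core hardware. Increasing `d_k` expands the state to `d_k * d_v` entries without changing the sequential structure of the loop.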
