Optimal low-rank stochastic gradient estimation for LLM training
Zehao Li, Tao Ren, Zishi Zhang, Xi Chen, Yijie Peng

TL;DR
This paper introduces an optimal low-rank stochastic gradient estimator that reduces memory usage and improves training efficiency for large language models. It projects gradients onto random low-dimensional subspaces, with the projection distribution chosen so that the resulting estimator has minimal variance.
Contribution
It proposes a novel unbiased low-rank gradient estimator with an optimal projection distribution. The estimator applies across various stochastic gradient paradigms and improves both memory efficiency and training performance.
Findings
Achieves significant GPU memory savings in RoBERTa-large fine-tuning.
Outperforms traditional gradient estimation methods in LLM pretraining.
Maintains competitive accuracy while reducing memory consumption.
Abstract
Large language model (LLM) training is often bottlenecked by memory constraints and stochastic gradient noise in extremely high-dimensional parameter spaces. Motivated by empirical evidence that many LLM gradient matrices are effectively low-rank during training, we present an unbiased, memory-efficient, low-rank matrix estimator with the lowest variance that is applicable across common stochastic gradient estimation paradigms. The core idea is to project a high-dimensional stochastic gradient estimator onto a random low-dimensional subspace and lift it back, reducing memory while keeping the estimator unbiased and controlling mean-squared error via an optimally designed projection distribution, including Haar–Stiefel projections. The projection distribution is derived by solving a constrained functional optimization problem, yielding an optimal random projector that guides algorithm…
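The project-then-lift idea can be illustrated with a minimal NumPy sketch of the Haar–Stiefel case only. It assumes V is an n×r matrix with orthonormal columns drawn uniformly at random. Rotation invariance then gives E[V Vᵀ] = (r/n) I, so rescaling the lifted matrix G V Vᵀ by n/r yields an unbiased estimate of the gradient matrix G. The function names (`sample_haar_stiefel`, `project`, `lift`) and the toy sizes are hypothetical. This is not the authors' optimal projection distribution or training code; it shows only the unbiasedness mechanism.

```python
import numpy as np


def sample_haar_stiefel(n: int, r: int, rng: np.random.Generator) -> np.ndarray:
    """Draw an n x r matrix with orthonormal columns, uniformly (Haar) on the Stiefel manifold."""
    # QR of a Gaussian matrix gives orthonormal columns; forcing R's diagonal
    # to be positive removes the sign convention of the QR routine, so Q is Haar.
    z = rng.standard_normal((n, r))
    q, rr = np.linalg.qr(z)  # reduced QR: q has shape (n, r)
    signs = np.sign(np.diag(rr))
    signs[signs == 0] = 1.0
    return q * signs  # flips column j by signs[j]


def project(grad: np.ndarray, v: np.ndarray) -> np.ndarray:
    """Compress an m x n gradient to m x r; this is the part that would be stored."""
    return grad @ v


def lift(compressed: np.ndarray, v: np.ndarray) -> np.ndarray:
    """Map back to m x n and rescale by n / r, so that E[lift(project(G, V), V)] = G."""
    n, r = v.shape
    return (n / r) * (compressed @ v.T)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    m, n, r = 8, 32, 4  # hypothetical toy sizes
    grad = rng.standard_normal((m, n))  # stands in for one stochastic gradient draw

    v = sample_haar_stiefel(n, r, rng)
    stored = project(grad, v)
    print("stored floats:", stored.size, "vs full:", grad.size)

    # Average many independent lifts; the mean should approach grad (unbiasedness).
    trials = 20000
    avg = np.zeros_like(grad)
    for _ in range(trials):
        v = sample_haar_stiefel(n, r, rng)
        avg += lift(project(grad, v), v)
    avg /= trials
    rel_err = np.linalg.norm(avg - grad) / np.linalg.norm(grad)
    print(f"relative error of Monte Carlo mean: {rel_err:.3f}")
```

In this sketch, only the m×r product G V needs to be kept, which is the source of the memory saving (32 floats instead of 256 for the toy sizes). The cost is extra variance. For a fixed G, this Haar–Stiefel estimator has mean-squared error (n/r − 1)‖G‖²_F, and that variance is what an optimized projection distribution is meant to control.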
Taxonomy
Topics: Stochastic Gradient Optimization Techniques · Advanced Neural Network Applications · Tensor decomposition and applications
