
Optimal low-rank stochastic gradient estimation for LLM training

About

Large language model (LLM) training is often bottlenecked by memory constraints and by stochastic gradient noise in extremely high-dimensional parameter spaces. Motivated by empirical evidence that many LLM gradient matrices are effectively low-rank during training, we present an unbiased, memory-efficient, minimum-variance low-rank matrix estimator that is applicable across common stochastic gradient estimation paradigms. The core idea is to project a high-dimensional stochastic gradient estimator onto a random low-dimensional subspace and lift it back, reducing memory while keeping the estimator unbiased and controlling mean-squared error via an optimally designed projection distribution, including Haar--Stiefel projections. The projection distribution is derived by solving a constrained functional optimization problem, yielding an optimal random projector that guides algorithm design. Empirically, the resulting low-rank gradient estimators deliver both practical memory savings and improved training behavior. In RoBERTa-large fine-tuning, our method attains the lowest peak GPU memory among compared methods (e.g., 3.83 GB versus 16.7 GB for full backpropagation) while remaining competitive in accuracy; in autoregressive LLM pretraining (LLaMA-20M/60M/100M), our method outperforms traditional baselines, supporting the benefit of the proposed optimal projection strategy.
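The project-and-lift idea above can be sketched in a few lines. The snippet below is a minimal illustration, not the paper's implementation: it assumes a Haar-distributed column-orthonormal projector P (drawn via QR of a Gaussian matrix), for which E[P Pᵀ] = (r/m) I, so rescaling by m/r makes the low-rank estimate (m/r) P Pᵀ G unbiased for a gradient matrix G. All function names and dimensions here are illustrative.

```python
import numpy as np

def haar_stiefel(m, r, rng):
    """Column-orthonormal m x r frame, Haar-distributed on the Stiefel manifold."""
    A = rng.standard_normal((m, r))
    Q, R = np.linalg.qr(A)
    # Standard sign correction so the QR output is Haar-distributed.
    Q = Q * np.sign(np.diag(R))
    return Q

def lowrank_grad_estimate(G, r, rng):
    """Project G onto a random r-dim subspace and lift back.

    The m/r rescaling restores unbiasedness, since under the Haar
    distribution E[P @ P.T] = (r/m) * I.
    """
    m = G.shape[0]
    P = haar_stiefel(m, r, rng)      # only an m x r frame is stored
    return (m / r) * (P @ (P.T @ G))

rng = np.random.default_rng(0)
G = rng.standard_normal((64, 32))    # stand-in for a dense gradient matrix
est = np.mean([lowrank_grad_estimate(G, 8, rng) for _ in range(5000)], axis=0)
# The sample mean of many low-rank estimates approaches G (unbiasedness);
# each individual estimate has rank at most r = 8.
print(np.linalg.norm(est - G) / np.linalg.norm(G))
```

Each individual estimate touches only an m x r projector and an r x n projected gradient, which is where the memory savings come from; the paper's contribution is choosing the projection distribution to minimize the variance of this family of estimators.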

Zehao Li, Tao Ren, Zishi Zhang, Xi Chen, Yijie Peng• 2026

Related benchmarks

Task                        | Dataset | Metric   | Result | Rank
Natural Language Inference  | RTE     | Accuracy | 63.7   | 448
Question Classification    | TREC    | Accuracy | 80.9   | 259
Sentiment Classification   | SST-2   | Accuracy | 91.3   | 184
Natural Language Inference  | SNLI    | Accuracy | 74.3   | 180
Natural Language Inference  | MNLI    | --       | --     | 80
Sentiment Classification   | SST-5   | Accuracy | 43.2   | 46
