From Projection to Prediction: Beyond Logits for Scalable Language Models
Jianbing Dong, Jianbin Chang

TL;DR
This paper proposes a unified approach to LLM output projection and loss computation that reduces memory and bandwidth usage, enabling faster training with larger batches and longer sequences.
Contribution
It introduces a novel method that combines projection and loss calculation into a single operation, bypassing explicit logits materialization for improved efficiency.
Findings
Significant memory savings during training.
Measurable speedups over standard methods.
Supports larger batch sizes and longer sequences.
Abstract
Training Large Language Models (LLMs) typically involves a two-stage pipeline at the output layer: hidden states are projected into vocabulary logits via a linear transformation (lm_head), followed by cross-entropy loss computation against target tokens. While conceptually simple, this design incurs substantial overhead. The intermediate logits tensor, with dimensions proportional to batch size, sequence length, and vocabulary size, must be fully materialized in GPU memory, even though only one target token per position is ultimately used. This leads to a significant memory footprint and bandwidth consumption, limiting scalability and slowing training throughput. In this work, we introduce a novel approach that integrates the output projection and loss prediction into a single operation. By directly computing the loss from hidden states and target tokens, our approach bypasses explicit…
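The abstract is cut off before it describes the actual kernel, so the following is only a minimal sketch of the general idea, not the authors' implementation. It assumes the loss at each position is the usual cross-entropy, logsumexp over the vocabulary of h·W_v minus the target logit h·W_y, and it streams over vocabulary chunks with a running maximum and running sum so that the full logits row is never held at once. The function names, the chunk parameter, and the plain-list data layout are all hypothetical, chosen for illustration.

```python
import math


def dot(a, b):
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def fused_linear_cross_entropy(hidden, weight, targets, chunk=4):
    """Mean cross-entropy computed directly from hidden states.

    hidden:  N vectors of length D (one per token position)
    weight:  V rows of length D (hypothetical lm_head weight, one row per token)
    targets: N target token ids
    chunk:   number of vocabulary rows processed at a time (illustrative)
    """
    total = 0.0
    for h, y in zip(hidden, targets):
        running_max = -math.inf  # largest logit seen so far
        running_sum = 0.0        # sum of exp(logit - running_max) so far
        for start in range(0, len(weight), chunk):
            # Only this chunk of logits exists at any moment.
            block = [dot(h, row) for row in weight[start:start + chunk]]
            new_max = max(running_max, max(block))
            # Rescale the old sum to the new maximum, then add the chunk.
            running_sum = running_sum * math.exp(running_max - new_max)
            running_sum += sum(math.exp(z - new_max) for z in block)
            running_max = new_max
        log_sum_exp = running_max + math.log(running_sum)
        total += log_sum_exp - dot(h, weight[y])  # only the target logit is needed
    return total / len(hidden)


def reference_cross_entropy(hidden, weight, targets):
    """Two-stage baseline: materialize every logit, then compute the loss."""
    total = 0.0
    for h, y in zip(hidden, targets):
        logits = [dot(h, row) for row in weight]
        m = max(logits)
        total += m + math.log(sum(math.exp(z - m) for z in logits)) - logits[y]
    return total / len(hidden)
```

Both functions return the same value up to floating-point rounding; the streaming version differs only in how many logits it holds at a time. A practical GPU version would also need a matching backward pass, which this sketch leaves out.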
Taxonomy
Topics: Natural Language Processing Techniques · Topic Modeling · Big Data and Digital Economy
