Adjoint sharding for very long context training of state space models
Xingzi Xu, Amir Tavanaei, Kavosh Asadi, Karim Bouyarmane

TL;DR
This paper introduces adjoint sharding, a gradient computation technique that reduces memory usage by up to 3X when training large language models on very long contexts, making training at longer context lengths feasible.
Contribution
The paper presents adjoint sharding, a method based on the adjoint technique that enables memory-efficient training of large models on very long input contexts, a regime previously out of reach because of memory constraints.
Findings
Memory usage reduced by up to 3X with adjoint sharding.
Enabled training with context lengths above 100K tokens for a 1.27B parameter model.
Increased the maximum training context length from 35K to over 100K tokens.
Abstract
Despite very fast progress, efficiently training large language models (LLMs) on very long contexts remains challenging. Existing methods fall back to training LLMs with short contexts (a maximum of a few thousand tokens in training) and use inference-time techniques when evaluating on long contexts (context windows above 1M tokens at inference). Unlike long-context inference, training on very long context input prompts is quickly limited by GPU memory availability and by the prohibitively long training times it requires on state-of-the-art hardware. Meanwhile, many real-life applications require not only inference but also training or fine-tuning with long context on specific tasks. Such applications include, for example, augmenting the context with various sources of raw reference information for fact extraction, fact summarization, or fact reconciliation tasks. We propose adjoint…
Taxonomy
Topics: Fault Detection and Control Systems
Methods: SPEED: Separable Pyramidal Pooling EncodEr-Decoder for Real-Time Monocular Depth Estimation on Low-Resource Settings
