Training Deep Nets with Sublinear Memory Cost
Tianqi Chen, Bing Xu, Chiyuan Zhang, Carlos Guestrin

TL;DR
This paper introduces an algorithm that trains an n-layer network in O(sqrt(n)) memory at the cost of one extra forward pass per mini-batch, so that deeper and more complex models fit within GPU memory.
Contribution
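The authors present a systematic method to train deep networks with sublinear memory cost. Computation graph analysis drives automatic in-place operations and memory sharing, and a small amount of extra computation is traded for a large reduction in the memory used to store intermediate feature maps and gradients.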
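The memory-sharing idea can be illustrated with a small liveness-based buffer planner. This is a minimal sketch, not the authors' implementation: the function name `plan_buffers`, the graph encoding, and the assumption that every node's output has the same size are all hypothetical simplifications. It walks a topologically ordered graph, gives each node a buffer, and returns a buffer to a free pool once every consumer of that node has run.

```python
def plan_buffers(graph):
    """Assign a buffer id to every node of a computation graph.

    graph: list of (node_name, [input_names]) in topological order.
    Every input name must itself appear earlier as a node.
    Assumes (for illustration only) that all outputs have equal size.
    Returns (assignment dict, number of distinct buffers used).
    """
    # Count how many consumers still need each node's output.
    remaining = {name: 0 for name, _ in graph}
    for _, inputs in graph:
        for src in inputs:
            remaining[src] += 1

    free = []          # buffer ids whose contents are no longer needed
    assignment = {}
    next_id = 0
    for name, inputs in graph:
        # Take the output buffer first, so it never aliases a live input.
        if free:
            buf = free.pop()
        else:
            buf = next_id
            next_id += 1
        assignment[name] = buf
        # After this node runs, inputs with no remaining consumers are dead.
        for src in inputs:
            remaining[src] -= 1
            if remaining[src] == 0:
                free.append(assignment[src])
    return assignment, next_id


# A five-node chain: a -> b -> c -> d -> e
chain = [("a", []), ("b", ["a"]), ("c", ["b"]), ("d", ["c"]), ("e", ["d"])]
assignment, n_buffers = plan_buffers(chain)
```

For this forward-only chain the planner alternates between two buffers. During training, activations that the backward pass still needs count as live, so sharing alone recovers less memory there.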
Findings
Reduced memory for a 1000-layer residual network from 48 GB to 7 GB
Achieved this memory reduction with only a 30% increase in training time
Enabled training of complex RNNs on long sequences
Abstract
We propose a systematic approach to reduce the memory consumption of deep neural network training. Specifically, we design an algorithm that costs O(sqrt(n)) memory to train an n-layer network, with only the computational cost of an extra forward pass per mini-batch. As many of the state-of-the-art models hit the upper bound of the GPU memory, our algorithm allows deeper and more complex models to be explored, and helps advance the innovations in deep learning research. We focus on reducing the memory cost to store the intermediate feature maps and gradients during training. Computation graph analysis is used for automatic in-place operation and memory sharing optimizations. We show that it is possible to trade computation for memory - giving a more memory efficient training algorithm with a little extra computation cost. In the extreme case, our analysis also shows that the memory…
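The O(sqrt(n)) trade-off described in the abstract can be sketched on a toy chain of scalar layers. This is an illustrative sketch under stated assumptions, not the paper's code: the layer `tanh(w * x)`, the function names, and the choice of segment length `isqrt(n)` are all hypothetical. The forward pass keeps only the input of every segment; the backward pass recomputes the activations inside one segment at a time, which amounts to roughly one extra forward pass overall.

```python
import math

def layer(w, x):
    # Toy layer: y = tanh(w * x)
    return math.tanh(w * x)

def layer_grad(w, x, g_out):
    # Gradients of y = tanh(w * x) with respect to x and w, scaled by g_out.
    y = math.tanh(w * x)
    d = (1.0 - y * y) * g_out
    return w * d, x * d

def backprop_full(ws, x0):
    # Baseline: store every activation, O(n) memory.
    acts = [x0]
    for w in ws:
        acts.append(layer(w, acts[-1]))
    g, gws = 1.0, [0.0] * len(ws)
    for i in reversed(range(len(ws))):
        g, gws[i] = layer_grad(ws[i], acts[i], g)
    return gws, len(acts)

def backprop_checkpointed(ws, x0):
    # Store only segment inputs, then recompute inside each segment.
    n = len(ws)
    k = max(1, math.isqrt(n))          # segment length (assumed choice)
    checkpoints = {}
    x = x0
    for i, w in enumerate(ws):
        if i % k == 0:
            checkpoints[i] = x         # input to layer i
        x = layer(w, x)
    g, gws, peak = 1.0, [0.0] * n, 0
    for start in sorted(checkpoints, reverse=True):
        end = min(start + k, n)
        # Recompute the inputs to layers start .. end-1.
        acts = [checkpoints[start]]
        for i in range(start, end - 1):
            acts.append(layer(ws[i], acts[-1]))
        # Values held at once: all checkpoints plus this segment's activations.
        peak = max(peak, len(checkpoints) + len(acts))
        for i in reversed(range(start, end)):
            g, gws[i] = layer_grad(ws[i], acts[i - start], g)
    return gws, peak

ws = [0.5 + 0.01 * i for i in range(100)]
full_grads, full_stored = backprop_full(ws, 0.3)
ckpt_grads, ckpt_stored = backprop_checkpointed(ws, 0.3)
```

With 100 layers the baseline holds 101 activations, while the checkpointed version holds at most 20 values at once (10 checkpoints plus 10 recomputed activations), and both return the same weight gradients.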
Taxonomy
Topics: Advanced Neural Network Applications · Machine Learning and Data Classification · Domain Adaptation and Few-Shot Learning
Methods: Gradient Checkpointing
