Optimal Gradient Checkpointing for Sparse and Recurrent Architectures using Off-Chip Memory
Wadjih Bencheikh, Jan Finkbeiner, Emre Neftci

TL;DR
This paper introduces memory-efficient gradient checkpointing strategies tailored for sparse RNNs and Spiking Neural Networks (SNNs). These strategies enable training on much longer sequences and larger networks with minimal time overhead, especially on architectures with distributed local memory such as Intelligence Processing Units (IPUs).
Contribution
It proposes the Double Checkpointing method, which optimizes the use of local memory and reduces recomputation. This significantly improves training scalability for sparse and recurrent neural networks.
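The idea of combining coarse and fine checkpoints can be illustrated with a toy two-level scheme. This is a minimal sketch under our own assumptions, not the authors' implementation. A scalar tanh RNN serves as a hypothetical stand-in for a sparse RNN or SNN layer, and the loss is assumed to be the sum of the hidden states. The forward pass stores coarse checkpoints every `outer` steps in what we treat as off-chip memory. During the backward pass, each coarse segment is re-run once to place fine checkpoints every `inner` steps in what we treat as local memory. Each fine segment is then rematerialized and backpropagated. All names (`step`, `bptt_full`, `backprop_segment`, `bptt_double`, `outer`, `inner`) are hypothetical.

```python
import math


def step(w, h, x):
    # Toy scalar RNN cell: hypothetical stand-in for a sparse RNN/SNN layer.
    return math.tanh(w * h + x)


def bptt_full(w, xs):
    # Reference BPTT that keeps all T + 1 states; loss = sum of h[1..T].
    hs = [0.0]
    for x in xs:
        hs.append(step(w, hs[-1], x))
    dw, carry = 0.0, 0.0
    for t in reversed(range(len(xs))):
        dpre = (1.0 + carry) * (1.0 - hs[t + 1] ** 2)  # tanh' = 1 - tanh^2
        dw += dpre * hs[t]
        carry = dpre * w  # gradient flowing back into h[t]
    return dw


def backprop_segment(w, xs, s, e, h_s, carry):
    # Rematerialize h[s..e] from the checkpoint h_s, then backprop through it.
    hs = [h_s]
    for t in range(s, e):
        hs.append(step(w, hs[-1], xs[t]))
    dw = 0.0
    for t in reversed(range(s, e)):
        i = t - s
        dpre = (1.0 + carry) * (1.0 - hs[i + 1] ** 2)
        dw += dpre * hs[i]
        carry = dpre * w
    return dw, carry, len(hs)


def bptt_double(w, xs, outer, inner):
    T = len(xs)
    # Forward pass: coarse checkpoints every `outer` steps (assumed off-chip).
    offchip, h = {}, 0.0
    for t in range(T):
        if t % outer == 0:
            offchip[t] = h
        h = step(w, h, xs[t])
    dw, carry, peak_local = 0.0, 0.0, 0
    for S in sorted(offchip, reverse=True):
        E = min(S + outer, T)
        # Re-run this coarse segment once, keeping fine checkpoints every
        # `inner` steps (assumed to fit in local on-chip memory).
        local, h = {}, offchip[S]
        for t in range(S, E):
            if (t - S) % inner == 0:
                local[t] = h
            h = step(w, h, xs[t])
        # Backprop through the fine segments, latest first.
        for s in sorted(local, reverse=True):
            e = min(s + inner, E)
            d, carry, n = backprop_segment(w, xs, s, e, local[s], carry)
            dw += d
            peak_local = max(peak_local, len(local) + n)
    return dw, len(offchip), peak_local


xs = [math.sin(0.3 * t) for t in range(100)]
dw, n_offchip, peak_local = bptt_double(0.7, xs, outer=25, inner=5)
print(dw - bptt_full(0.7, xs), n_offchip, peak_local)
```

In this toy setting with T = 100, outer = 25 and inner = 5, the local footprint during the backward pass is 11 states. Full BPTT keeps 101 states. The cost is that every step is recomputed twice: once in the coarse re-run and once in its fine segment. The gradient matches full BPTT up to floating-point rounding. The paper's actual scheduling and memory placement may differ.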
Findings
Double Checkpointing outperforms the other checkpointing strategies in efficiency
Enables training on sequences over 10 times longer than previously feasible
Supports larger networks with marginal time overhead
Abstract
Recurrent neural networks (RNNs) are valued for their computational efficiency and reduced memory requirements on tasks involving long sequence lengths but require high memory-processor bandwidth to train. Checkpointing techniques can reduce the memory requirements by only storing a subset of intermediate states, the checkpoints, but are still rarely used due to the computational overhead of the additional recomputation phase. This work addresses these challenges by introducing memory-efficient gradient checkpointing strategies tailored for the general class of sparse RNNs and Spiking Neural Networks (SNNs). SNNs are energy efficient alternatives to RNNs thanks to their local, event-driven operation and potential neuromorphic implementation. We use the Intelligence Processing Unit (IPU) as an exemplary platform for architectures with distributed local memory. We exploit its suitability…
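A back-of-the-envelope cost model makes the memory/recomputation trade-off of single-level checkpointing concrete. This is a minimal sketch with several assumptions. It places one checkpoint every `k` steps and counts memory in hidden states. It assumes that one segment is rematerialized at a time during the backward pass. `single_level_cost` and `best_interval` are hypothetical helpers. Real costs on an IPU would also depend on state size, sparsity and memory bandwidth, none of which this model includes.

```python
import math


def single_level_cost(T, k):
    # Hypothetical cost model for one checkpoint every k steps, with memory
    # counted in hidden states (state size, sparsity and bandwidth ignored).
    n_ckpt = math.ceil(T / k)   # checkpoints kept from the forward pass
    peak = n_ckpt + min(k, T)   # plus roughly one rematerialized segment
    recompute = T               # each step re-run once: ~1 extra forward pass
    return peak, recompute


def best_interval(T):
    # Brute-force the interval that minimizes peak memory (lands near sqrt(T)).
    return min(range(1, T + 1), key=lambda k: single_level_cost(T, k)[0])


for T in (100, 10_000):
    k = best_interval(T)
    peak, recompute = single_level_cost(T, k)
    print(f"T={T}: k={k}, peak states={peak} (vs ~{T} without checkpointing), "
          f"recomputed steps={recompute}")
```

In this model, T = 10,000 steps gives an optimal interval of k = 100, which needs about 200 stored states instead of about 10,000. The price is roughly one extra forward pass, which is the recomputation overhead described above. The model says nothing about where the checkpoints live. That becomes relevant on hardware that combines local and off-chip memory.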
Taxonomy
Topics: Distributed systems and fault tolerance · Parallel Computing and Optimization Techniques · Interconnection Networks and Systems
Methods: Spiking Neural Networks · Gradient Checkpointing
