GraB-sampler: Optimal Permutation-based SGD Data Sampler for PyTorch
Guanghao Wei

TL;DR
This paper introduces GraB-sampler, an efficient Python library implementing the theoretically optimal Gradient Balancing (GraB) algorithm for permutation-based data sampling in SGD, improving training efficiency with minimal overhead.
Contribution
It provides an accessible implementation of GraB with five variants, enabling practical use and demonstrating near-optimal training performance with low computational overhead.
Findings
Achieves training loss and accuracy comparable to original GraB
Reproduces results with only 8.7% additional training time
Uses 0.85% more GPU memory at peak
Abstract
The online Gradient Balancing (GraB) algorithm, which greedily chooses the example ordering by solving the herding problem using per-sample gradients, has been proved to be the theoretically optimal solution and is guaranteed to outperform Random Reshuffling. However, there is currently no efficient implementation of GraB that the community can easily use. This work presents an efficient Python library, GraB-sampler, that allows the community to easily use GraB algorithms, and proposes 5 variants of the GraB algorithm. The best-performing configuration of GraB-sampler reproduces the training loss and test accuracy results at the cost of only 8.7% training time overhead and 0.85% peak GPU memory usage overhead.
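The greedy ordering step described in the abstract can be illustrated with a minimal plain-Python sketch. This is not the GraB-sampler API: the function name `balance_and_reorder`, its arguments, and the use of Python lists instead of PyTorch tensors are all assumptions made for illustration. The sketch centers each per-sample gradient with a (stale) mean estimate, assigns it the sign that keeps the running signed sum smaller, and builds the next epoch's order by placing positively signed examples at the front and negatively signed examples at the back.

```python
from typing import List, Sequence


def balance_and_reorder(
    order: Sequence[int],
    grads: Sequence[Sequence[float]],
    stale_mean: Sequence[float],
) -> List[int]:
    """Hypothetical helper: one epoch of greedy gradient balancing.

    order      -- example indices in the order they were visited this epoch
    grads      -- grads[k] is the per-sample gradient observed for order[k]
    stale_mean -- mean-gradient estimate used to center each gradient
    Returns the example order to use in the next epoch.
    """
    n = len(order)
    running = [0.0] * len(stale_mean)  # running signed sum of centered gradients
    next_order = [0] * n
    left, right = 0, n - 1             # fill positions from both ends

    for idx, g in zip(order, grads):
        centered = [gi - mi for gi, mi in zip(g, stale_mean)]
        # Squared norms of the running sum under each candidate sign.
        plus = sum((r + c) ** 2 for r, c in zip(running, centered))
        minus = sum((r - c) ** 2 for r, c in zip(running, centered))
        if plus < minus:
            # Sign +1: keep the example toward the front of the next order.
            running = [r + c for r, c in zip(running, centered)]
            next_order[left] = idx
            left += 1
        else:
            # Sign -1 (ties included): push the example toward the back.
            running = [r - c for r, c in zip(running, centered)]
            next_order[right] = idx
            right -= 1
    return next_order
```

In a real training loop the per-sample gradients would come from the model, and a sampler would feed the returned order to the data loader for the next epoch; how GraB-sampler wires this up is not shown here.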
Taxonomy
Topics: Machine Learning and Data Classification · Advanced Image and Video Retrieval Techniques · Advanced Neural Network Applications
