TL;DR
KAISA is a scalable second-order optimizer framework for deep neural networks, built on K-FAC, that adapts memory footprint, communication, and computation to the given model and hardware to speed up convergence and improve scalability.
Contribution
KAISA introduces an adaptable framework that quantifies and balances the tradeoff between memory footprint and communication cost, enabling faster convergence and better scalability for large neural network models.
Findings
KAISA converges 18.1-36.3% faster than the original optimizers across applications at the same global batch size.
Under a fixed memory budget, KAISA converges 32.5% and 41.6% faster on ResNet-50 and BERT-Large, respectively.
KAISA maintains or improves scaling efficiency compared to traditional optimizers.
Abstract
Kronecker-factored Approximate Curvature (K-FAC) has recently been shown to converge faster in deep neural network (DNN) training than stochastic gradient descent (SGD); however, K-FAC's larger memory footprint hinders its applicability to large models. We present KAISA, a K-FAC-enabled, Adaptable, Improved, and ScAlable second-order optimizer framework that adapts the memory footprint, communication, and computation given specific models and hardware to improve performance and increase scalability. We quantify the tradeoffs between memory and communication cost and evaluate KAISA on large models, including ResNet-50, Mask R-CNN, U-Net, and BERT, on up to 128 NVIDIA A100 GPUs. Compared to the original optimizers, KAISA converges 18.1-36.3% faster across applications with the same global batch size. Under a fixed memory budget, KAISA converges 32.5% and 41.6% faster in ResNet-50 and…
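To make the memory concern in the abstract concrete, here is a minimal NumPy sketch of generic K-FAC preconditioning for a single linear layer. It is not KAISA's implementation: the function names, the `damping` value, and the single-process setup are assumptions made for illustration. It shows the two Kronecker factors a K-FAC optimizer keeps per layer and how many extra floats they cost.

```python
import numpy as np


def kfac_precondition(grad_w, acts, grad_out, damping=1e-3):
    """Precondition a linear layer's weight gradient with K-FAC factors.

    grad_w:   (out_dim, in_dim) gradient of the loss w.r.t. the weights
    acts:     (batch, in_dim) inputs to the layer
    grad_out: (batch, out_dim) gradients w.r.t. the layer outputs
    damping:  hypothetical Tikhonov damping added to both factors
    """
    batch = acts.shape[0]
    # Kronecker factors: input covariance A and output-gradient covariance G.
    A = acts.T @ acts / batch          # (in_dim, in_dim)
    G = grad_out.T @ grad_out / batch  # (out_dim, out_dim)
    A_damped = A + damping * np.eye(A.shape[0])
    G_damped = G + damping * np.eye(G.shape[0])
    # Compute G^{-1} @ grad_w @ A^{-1} with linear solves instead of
    # forming explicit inverses.
    left = np.linalg.solve(G_damped, grad_w)
    return np.linalg.solve(A_damped, left.T).T


def factor_floats(in_dim, out_dim):
    """Extra floats needed to store the two factors for one layer.

    Caching the factor inverses as well would add the same amount again.
    """
    return in_dim**2 + out_dim**2


# Example: a 1024 -> 4096 linear layer has 4,194,304 weights, while its
# two factors alone need 1024**2 + 4096**2 extra floats.
print(factor_floats(1024, 4096))
```

The per-layer factor storage grows quadratically in the layer dimensions, which is the kind of overhead the abstract refers to when it says K-FAC's memory footprint limits its use on large models.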
Taxonomy
Methods: Multi-Head Attention · Attention Is All You Need · Linear Layer · Region Proposal Network · Concatenated Skip Connection · Adam · Layer Normalization · Weight Decay · Dropout · Max Pooling
