TL;DR
Scalify introduces a formalized scale propagation method for low-precision training of large language models, enabling efficient float8 and float16 computations with minimal accuracy loss.
Contribution
It provides a unified, end-to-end framework for tensor scaling in low-precision LLM training, simplifying adoption and improving efficiency.
Findings
Supports out-of-the-box float8 matrix multiplication and gradient representation
Enables float16 optimizer state storage
Open-sourced JAX implementation available
Abstract
Low-precision formats such as float8 have been introduced in machine learning hardware accelerators to improve computational efficiency for large language model training and inference. Nevertheless, adoption by the ML community has been slowed by the complex, and sometimes brittle, techniques required to match higher-precision training accuracy. In this work, we present Scalify, an end-to-end scale propagation paradigm for computational graphs, generalizing and formalizing existing tensor scaling methods. Experimental results show that Scalify supports out-of-the-box float8 matrix multiplication and gradient representation, as well as float16 optimizer state storage. Our JAX implementation of Scalify is open-sourced at https://github.com/graphcore-research/jax-scalify
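The core idea of tensor scaling can be illustrated with a toy example: each tensor is stored as a low-precision payload plus a full-precision scale, and operations propagate the scale instead of materializing the unscaled values. The sketch below is a minimal illustration under stated assumptions, not the Scalify algorithm or the jax-scalify API. It is written in plain Python, with the standard library's half-precision `struct` format `'e'` standing in for float16 hardware. The names `ScaledArray`, `as_scaled`, and `scaled_matmul` are hypothetical and chosen only for illustration. It shows that tiny gradient-like values which underflow to zero in raw float16 survive when stored with a power-of-two scale, and that a matmul can propagate scales as (A_d · s_a) @ (B_d · s_b) = (A_d @ B_d) · (s_a · s_b).

```python
import math
import struct
from dataclasses import dataclass
from typing import List

Matrix = List[List[float]]


def to_fp16(x: float) -> float:
    """Round a Python float to IEEE float16 (tiny values flush to 0.0)."""
    return struct.unpack("e", struct.pack("e", x))[0]


@dataclass
class ScaledArray:
    """Hypothetical scaled tensor: represented value = data * scale."""
    data: Matrix  # low-precision payload (float16-rounded values)
    scale: float  # power-of-two scale kept in full precision

    def to_array(self) -> Matrix:
        # Reconstruct the full-precision view of the tensor.
        return [[v * self.scale for v in row] for row in self.data]


def as_scaled(x: Matrix) -> ScaledArray:
    """Pick a power-of-two scale so the payload's max magnitude is in [1, 2)."""
    amax = max(abs(v) for row in x for v in row)
    scale = 2.0 ** math.floor(math.log2(amax)) if amax > 0 else 1.0
    # Dividing by a power of two is exact, so only the float16 cast rounds.
    data = [[to_fp16(v / scale) for v in row] for row in x]
    return ScaledArray(data, scale)


def scaled_matmul(a: ScaledArray, b: ScaledArray) -> ScaledArray:
    """Matmul on payloads only; the output scale is the product of input scales."""
    n, k, m = len(a.data), len(b.data), len(b.data[0])
    out = [
        # Accumulate in Python float (float64), then round the result to float16.
        [to_fp16(sum(a.data[i][p] * b.data[p][j] for p in range(k))) for j in range(m)]
        for i in range(n)
    ]
    return ScaledArray(out, a.scale * b.scale)


if __name__ == "__main__":
    grads = [[1e-8, 2e-8], [3e-8, 4e-8]]
    print("raw float16:", [[to_fp16(v) for v in row] for row in grads])
    print("scaled float16:", as_scaled(grads).to_array())
```

Restricting scales to powers of two keeps rescaling exact in binary floating point, so the only rounding comes from the float16 payload itself. Real float8 training involves further details (separate formats, accumulation precision, where scales are updated) that this toy example does not attempt to model.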
