Oscillation-Reduced MXFP4 Training for Vision Transformers
Yuxiang Chen, Haocheng Xi, Jun Zhu, Jianfei Chen

TL;DR
This paper introduces TetraJet, a novel method for training Vision Transformers in FP4 (4-bit floating-point) precision. It reduces the accuracy loss caused by weight oscillation and achieves performance close to full-precision training.
Contribution
The paper proposes TetraJet, together with two techniques, the EMA Quantizer (Q-EMA) and the Adaptive Ramping Optimizer (Q-Ramping), to address weight oscillation in MXFP4 training, significantly improving accuracy over existing methods.
Findings
TetraJet reduces accuracy degradation by over 50%.
Q-EMA and Q-Ramping effectively mitigate oscillation.
Achieves competitive performance with full precision training.
Abstract
Pre-training Transformers in FP4 precision is becoming a promising approach to gain substantial speedup, but it comes with a considerable loss of accuracy. The Microscaling (MX) data format provides a fine-grained, per-group quantization method that improves the representation ability of the FP4 format, and it is supported by the next-generation Blackwell GPU architecture. However, training with the MXFP4 data format still results in significant degradation, and the reason for this degradation lacks systematic study. In this work, we propose a novel training method, TetraJet, for more accurate FP4 training. We comprehensively evaluate all of the quantizers involved in training and identify weight oscillation in the forward pass as the main source of degradation in MXFP4 training. Therefore, we introduce two novel methods, EMA Quantizer (Q-EMA) and Adaptive Ramping Optimizer…
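The per-group MXFP4 quantization and the oscillation problem described in the abstract can be sketched as follows. This is a minimal, illustrative sketch assuming the OCP Microscaling conventions for MXFP4 (blocks of 32 elements sharing one power-of-two E8M0 scale, with FP4 E2M1 elements). It is not TetraJet's implementation, and the helper names (`mxfp4_fake_quant`, `count_flips`) are hypothetical. The flip counter is only a simple proxy for weight oscillation, not the metric used in the paper.

```python
import math

# FP4 E2M1 representable magnitudes (OCP Microscaling spec).
FP4_E2M1_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
FP4_EMAX = 2      # exponent of the largest E2M1 value (6 = 1.5 * 2**2)
BLOCK_SIZE = 32   # MX block size: 32 elements share one scale


def round_to_fp4(x: float) -> float:
    """Round a scaled value to the nearest E2M1 value, saturating at +/-6.

    Simplification (assumption): ties go to the smaller magnitude,
    whereas hardware typically uses round-half-to-even.
    """
    mag = min(abs(x), FP4_E2M1_VALUES[-1])
    nearest = min(FP4_E2M1_VALUES, key=lambda v: abs(v - mag))
    return math.copysign(nearest, x)


def mxfp4_quantize_block(block):
    """Return (shared_scale, fp4_codes) for one block of floats."""
    amax = max(abs(v) for v in block)
    if amax == 0.0:
        return 1.0, [0.0] * len(block)
    # Shared power-of-two scale: 2**(floor(log2(amax)) - emax_elem).
    # The E8M0 exponent range limit is ignored in this sketch.
    shared_exp = math.floor(math.log2(amax)) - FP4_EMAX
    scale = 2.0 ** shared_exp
    return scale, [round_to_fp4(v / scale) for v in block]


def mxfp4_fake_quant(x, block_size=BLOCK_SIZE):
    """Quantize then dequantize a flat list, one shared scale per block."""
    out = []
    for start in range(0, len(x), block_size):
        scale, codes = mxfp4_quantize_block(x[start:start + block_size])
        out.extend(c * scale for c in codes)
    return out


def count_flips(trajectory):
    """Count how often each weight's quantized value changes between steps.

    trajectory: list of latent-weight snapshots (flat lists of equal length).
    A latent weight hovering near a rounding boundary flips its quantized
    value repeatedly even though the latent value barely moves.
    """
    prev = mxfp4_fake_quant(trajectory[0])
    flips = [0] * len(prev)
    for snapshot in trajectory[1:]:
        cur = mxfp4_fake_quant(snapshot)
        for i, (a, b) in enumerate(zip(prev, cur)):
            if a != b:
                flips[i] += 1
        prev = cur
    return flips


if __name__ == "__main__":
    # One block: amax = 1.2, so the shared scale is 2**(0 - 2) = 0.25.
    print(mxfp4_fake_quant([0.1, -0.35, 0.8, 1.2]))
    # Latent weight 0 moves by only 0.001 around the 0.1875 boundary
    # (0.75 * scale), yet its quantized value flips at every step.
    traj = [[0.187, 1.2], [0.188, 1.2], [0.187, 1.2], [0.188, 1.2]]
    print(count_flips(traj))
```

In the example, the first weight's latent value changes by 0.001 per step, but its forward-pass value jumps between 0.125 and 0.25 each time. This is the kind of forward-pass instability that the abstract identifies as the main source of degradation.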
Taxonomy
Topics: Advanced Neural Network Applications · Ferroelectric and Negative Capacitance Devices · Parallel Computing and Optimization Techniques
