Improving Generalization of Pre-trained Language Models via Stochastic Weight Averaging
Peng Lu, Ivan Kobyzev, Mehdi Rezagholizadeh, Ahmad Rashid, Ali Ghodsi, Philippe Langlais

TL;DR
This paper adapts Stochastic Weight Averaging (SWA) to the fine-tuning of pre-trained language models (PLMs). The adaptation improves their generalization across NLP tasks at no additional computational cost and surpasses knowledge distillation (KD) methods.
Contribution
The paper introduces a novel application of SWA to PLM fine-tuning. It improves generalization without training a separate teacher model and outperforms KD techniques.
Findings
SWA improves model generalization across NLP tasks (text classification, question answering, and generation).
SWA outperforms state-of-the-art KD methods.
SWA adds no extra computational cost.
Abstract
Knowledge Distillation (KD) is a commonly used technique for improving the generalization of compact Pre-trained Language Models (PLMs) on downstream tasks. However, such methods impose the additional burden of training a separate teacher model for every new dataset. Alternatively, one may directly work on the improvement of the optimization procedure of the compact model toward better generalization. Recent works observe that the flatness of the local minimum correlates well with better generalization. In this work, we adapt Stochastic Weight Averaging (SWA), a method encouraging convergence to a flatter minimum, to fine-tuning PLMs. We conduct extensive experiments on various NLP tasks (text classification, question answering, and generation) and different model architectures and demonstrate that our adaptation improves the generalization without extra computation cost. Moreover, we…
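The abstract describes SWA only at a high level. The sketch below illustrates the generic SWA idea. It is not the authors' exact recipe, because this summary does not give their averaging schedule or learning-rate settings. The sketch keeps an equal-weight running average of parameter snapshots, taken at the end of each epoch once a hypothetical `swa_start` epoch is reached. Weights are plain Python dicts of float lists that stand in for a model's parameters. `SWAAverager`, `fine_tune_with_swa`, and the toy objective `toy_epoch` are hypothetical names chosen for illustration.

```python
from typing import Callable, Dict, List, Tuple

# Hypothetical stand-in for a model's parameters: name -> flat list of floats.
Weights = Dict[str, List[float]]


class SWAAverager:
    """Equal-weight running average of parameter snapshots (generic SWA)."""

    def __init__(self) -> None:
        self.avg: Weights = {}
        self.n_models = 0

    def update(self, weights: Weights) -> None:
        if self.n_models == 0:
            # First snapshot: copy it so later training updates cannot alias it.
            self.avg = {name: list(values) for name, values in weights.items()}
        else:
            # Incremental mean: avg <- avg + (w - avg) / (n + 1)
            for name, values in weights.items():
                avg = self.avg[name]
                for i, w in enumerate(values):
                    avg[i] += (w - avg[i]) / (self.n_models + 1)
        self.n_models += 1


def fine_tune_with_swa(
    weights: Weights,
    train_one_epoch: Callable[[Weights, int], Weights],
    num_epochs: int,
    swa_start: int,
) -> Tuple[Weights, Weights]:
    """Fine-tune normally, averaging the weights after every epoch >= swa_start.

    Returns (last_weights, swa_weights).
    """
    averager = SWAAverager()
    for epoch in range(num_epochs):
        weights = train_one_epoch(weights, epoch)
        if epoch >= swa_start:
            averager.update(weights)
    return weights, averager.avg


# Toy "fine-tuning" (assumption, for illustration only): one gradient step per
# epoch on f(w) = (w - 3)^2 with a deliberately large learning rate, so the
# iterates bounce back and forth across the minimum at w = 3.
def toy_epoch(weights: Weights, epoch: int) -> Weights:
    lr = 0.9
    return {"w": [w - lr * 2 * (w - 3.0) for w in weights["w"]]}


last, swa = fine_tune_with_swa({"w": [0.0]}, toy_epoch, num_epochs=6, swa_start=2)
print("last iterate:", last["w"][0], "SWA average:", swa["w"][0])
```

In this toy run the last iterate is still oscillating around the minimum, and the average of the final snapshots lands much closer to it. That mirrors the intuition behind SWA: the method averages points along the optimization trajectory instead of keeping only the endpoint. For real PLMs, PyTorch provides `torch.optim.swa_utils.AveragedModel` and `SWALR` for the same averaging and an SWA learning-rate schedule. Whether the paper uses those utilities is not stated here.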
Taxonomy
Topics: Topic Modeling · Natural Language Processing Techniques · Multimodal Machine Learning Applications
Methods: Stochastic Weight Averaging
