Sharpness-Aware Pretraining Mitigates Catastrophic Forgetting
Ishaan Watts, Catherine Li, Sachin Goyal, Jacob Mitchell Springer, Aditi Raghunathan

TL;DR
This paper demonstrates that pretraining with sharpness-aware optimization techniques like SAM reduces catastrophic forgetting in models across various sizes and post-training scenarios.
Contribution
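It studies three pretraining interventions that bias optimization toward flatter minima (SAM, large learning rates, and shortened learning rate annealing periods) and shows that each reduces forgetting after post-training across multiple model sizes and datasets.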
Findings
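Flatness-biased pretraining (SAM, large learning rates, shortened annealing) yields up to 80% less forgetting after post-training on five common datasets.
Applying a short SAM mid-training phase to an existing OLMo-2-1B checkpoint decreases post-training forgetting by 31-40%.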
Methods are effective across model sizes from 20M to 150M parameters.
Abstract
Pretraining optimizers are tuned to produce the strongest possible base model, on the assumption that a stronger starting point yields a stronger model after subsequent changes like post-training and quantization. This overlooks the geometry of the base model, which controls how much of the base model's capabilities survive subsequent parameter updates. We study three pretraining optimization approaches that bias optimization toward flatter minima: Sharpness-Aware Minimization (SAM), large learning rates, and shortened learning rate annealing periods. Across model sizes ranging from 20M to 150M parameters, we find that these interventions consistently improve downstream performance after post-training on five common datasets with up to 80% less forgetting. These principles hold at scale: a short SAM mid-training phase applied to an existing OLMo-2-1B checkpoint reduces forgetting by 31%…
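For readers unfamiliar with SAM, the following is a minimal sketch of the standard two-step SAM update on a toy quadratic loss. It is not the paper's training code. The names `sam_step` and `make_quadratic_grad` and the values of `lr` and `rho` are illustrative assumptions. The idea is to take the gradient at an approximately worst-case point inside an L2 ball of radius `rho` around the current weights, then apply that gradient at the original weights.

```python
import math

def sam_step(params, grad_fn, lr=0.1, rho=0.05):
    """One SAM update on a flat list of floats (illustrative sketch)."""
    g = grad_fn(params)
    norm = math.sqrt(sum(x * x for x in g))
    if norm == 0.0:
        # Stationary point: the perturbation direction is undefined, so skip.
        return list(params)
    # Ascent step: move to the approximate worst-case point within radius rho.
    perturbed = [p + rho * gi / norm for p, gi in zip(params, g)]
    # Descent step: apply the gradient from the perturbed point at the
    # original (unperturbed) weights.
    g_sam = grad_fn(perturbed)
    return [p - lr * gi for p, gi in zip(params, g_sam)]

def make_quadratic_grad(curvatures):
    """Gradient of the toy loss 0.5 * sum(c_i * w_i**2) (hypothetical helper)."""
    def grad_fn(w):
        return [c * wi for c, wi in zip(curvatures, w)]
    return grad_fn

def quadratic_loss(curvatures, w):
    return 0.5 * sum(c * wi * wi for c, wi in zip(curvatures, w))
```

A usage example under the same assumptions: with `grad_fn = make_quadratic_grad([1.0, 10.0])`, repeatedly calling `w = sam_step(w, grad_fn, lr=0.05)` from `w = [1.0, 1.0]` drives the toy loss down. Each step costs two gradient evaluations, which is the usual overhead of SAM relative to plain gradient descent.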
