A Hessian-Aware Stochastic Differential Equation for Modelling SGD
Xiang Li, Zebang Shen, Liang Zhang, Niao He

TL;DR
This paper introduces the Hessian-Aware Stochastic Modified Equation (HA-SME), a stochastic differential equation (SDE) model that captures SGD dynamics more faithfully by incorporating Hessian information. This improves approximation accuracy and clarifies how SGD escapes stationary points.
Contribution
The paper develops HA-SME, the first SDE model that exactly recovers SGD dynamics for quadratic objectives and offers superior approximation guarantees for general functions.
Findings
HA-SME achieves the order-best approximation error guarantee among existing SDE models.
Experiments on neural-network loss functions confirm the improved approximation.
HA-SME enables an analytical escape-time analysis for SGD.
Abstract
Continuous-time approximation of Stochastic Gradient Descent (SGD) is a crucial tool to study its escaping behaviors from stationary points. However, existing stochastic differential equation (SDE) models fail to fully capture these behaviors, even for simple quadratic objectives. Built on a novel stochastic backward error analysis framework, we derive the Hessian-Aware Stochastic Modified Equation (HA-SME), an SDE that incorporates Hessian information of the objective function into both its drift and diffusion terms. Our analysis shows that HA-SME achieves the order-best approximation error guarantee among existing SDE models in the literature, while significantly reducing the dependence on the smoothness parameter of the objective. Empirical experiments on neural network-based loss functions further validate this improvement. Further, for quadratic objectives, under mild conditions,…
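The abstract's point that Hessian information belongs in both the drift and the diffusion can be seen in the simplest possible case. The sketch below is our own 1-D illustration, not the paper's HA-SME formula. It assumes SGD on f(x) = h*x^2/2 with additive Gaussian gradient noise of variance sigma^2, and it matches an Ornstein-Uhlenbeck process dX = a X dt + b dW to SGD's mean and variance at times t = k*eta. The helper names (`sgd_moments`, `ou_moments`, `hessian_aware_coeffs`) are hypothetical. The matched drift a = log(1 - eta*h)/eta and diffusion b both depend on h, and they reproduce SGD's moments exactly. The usual first-order model (drift -h, diffusion sqrt(eta)*sigma) does not.

```python
import math

def sgd_moments(m0, v0, h, sigma, eta, k):
    """Mean/variance after k SGD steps on f(x) = h*x^2/2.
    Assumption: stochastic gradient = h*x + xi with xi ~ N(0, sigma^2)."""
    m, v = m0, v0
    for _ in range(k):
        m = (1 - eta * h) * m
        v = (1 - eta * h) ** 2 * v + (eta * sigma) ** 2
    return m, v

def ou_moments(m0, v0, a, b, t):
    """Closed-form mean/variance at time t of dX = a X dt + b dW (a != 0)."""
    e = math.exp(a * t)
    return m0 * e, v0 * e ** 2 + b ** 2 * (e ** 2 - 1) / (2 * a)

def hessian_aware_coeffs(h, sigma, eta):
    """Hypothetical helper: drift a and diffusion b chosen so the OU process
    matches SGD's first two moments at t = k*eta. Requires 0 < eta*h < 1."""
    q = 1 - eta * h
    a = math.log(q) / eta                                 # ~ -h - eta*h^2/2
    b = math.sqrt(2 * a * (eta * sigma) ** 2 / (q ** 2 - 1))  # ~ sqrt(eta)*sigma*(1 + eta*h/2)
    return a, b

h, sigma, eta, k = 2.0, 0.5, 0.1, 20
m0, v0 = 1.0, 0.0

m_sgd, v_sgd = sgd_moments(m0, v0, h, sigma, eta, k)

# Hessian-dependent drift and diffusion: agrees with SGD up to rounding.
a, b = hessian_aware_coeffs(h, sigma, eta)
m_ha, v_ha = ou_moments(m0, v0, a, b, k * eta)

# First-order model with drift -h and diffusion sqrt(eta)*sigma: does not agree.
m_sme, v_sme = ou_moments(m0, v0, -h, math.sqrt(eta) * sigma, k * eta)

print(abs(m_sgd - m_ha), abs(v_sgd - v_ha))
print(abs(m_sgd - m_sme), abs(v_sgd - v_sme))
```

Matching only the first two moments is enough here because both processes are linear with Gaussian noise, so their distributions at t = k*eta coincide. For non-quadratic objectives or non-Gaussian noise, this moment-matching argument does not carry over directly.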
Taxonomy
Topics: Stochastic processes and financial applications
Methods: Diffusion · Stochastic Gradient Descent
