An accurate flatness measure to estimate the generalization performance of CNN models
Rahman Taleghani, Maryam Mohammadi, and Francesco Marchetti

TL;DR
This paper introduces an exact, architecture-aware flatness measure for CNNs that correlates with generalization performance, addressing limitations of previous Hessian-based measures.
Contribution
The authors derive a closed-form Hessian trace for CNNs with global average pooling and propose a parameterization-aware flatness measure tailored to convolutional layers.
Findings
The proposed measure correlates with the generalization performance of CNNs.
It provides a robust tool for comparing and designing CNN architectures.
Empirical results validate the measure's effectiveness on standard benchmarks.
Abstract
Flatness measures based on the spectrum or the trace of the Hessian of the loss are widely used as proxies for the generalization ability of deep networks. However, most existing definitions either are tailored to fully connected architectures, relying on stochastic estimators of the Hessian trace, or ignore the specific geometric structure of modern Convolutional Neural Networks (CNNs). In this work, we develop a flatness measure that is both exact and architecturally faithful for a broad and practically relevant class of CNNs. We first derive a closed-form expression for the trace of the Hessian of the cross-entropy loss with respect to convolutional kernels in networks that use global average pooling followed by a linear classifier. Building on this result, we then specialize the notion of relative flatness to convolutional layers and obtain a parameterization-aware flatness measure…
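To make the contrast between an exact Hessian trace and a stochastic estimate concrete, here is a minimal NumPy sketch. It is not the paper's closed-form result for convolutional kernels, which this page does not reproduce. It covers only the simplest related case, a linear softmax classifier with logits z = W x on a single input, where the Hessian of the cross-entropy loss with respect to W is (diag(p) - p pᵀ) ⊗ x xᵀ and its trace is (1 - ‖p‖²)‖x‖². The function names and the use of Hutchinson's estimator with Rademacher probes are assumptions made for illustration.

```python
import numpy as np

def softmax(z):
    """Numerically stable softmax of a 1-D logit vector."""
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()

def exact_trace(W, x):
    """Exact Hessian trace of cross-entropy w.r.t. W for logits z = W @ x.

    Illustrative linear-classifier case only: tr(H) = (1 - ||p||^2) * ||x||^2.
    The trace does not depend on the label, because the Hessian does not.
    """
    p = softmax(W @ x)
    return (1.0 - p @ p) * (x @ x)

def hessian_vector_product(W, x, V):
    """Apply H = (diag(p) - p p^T) kron (x x^T) to a direction V shaped like W."""
    p = softmax(W @ x)
    u = V @ x                    # change in logits induced by V
    Au = p * u - p * (p @ u)     # (diag(p) - p p^T) @ u
    return np.outer(Au, x)

def hutchinson_trace(W, x, num_samples, rng):
    """Stochastic trace estimate: average of v^T H v over Rademacher probes v."""
    total = 0.0
    for _ in range(num_samples):
        V = rng.choice([-1.0, 1.0], size=W.shape)
        total += np.sum(V * hessian_vector_product(W, x, V))
    return total / num_samples

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    W = rng.normal(size=(3, 4))  # hypothetical 3-class classifier on 4 features
    x = rng.normal(size=4)
    print("exact trace:     ", exact_trace(W, x))
    print("Hutchinson (100):", hutchinson_trace(W, x, 100, rng))
```

The exact value costs one forward pass, while the stochastic estimate fluctuates with the number of probes. That noise is the kind of limitation a closed-form trace avoids.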
Taxonomy
Topics: Advanced Neural Network Applications · Stochastic Gradient Optimization Techniques · Generative Adversarial Networks and Image Synthesis
