Stiffness: A New Perspective on Generalization in Neural Networks
Stanislav Fort, Paweł Krzysztof Nowak, Stanislaw Jastrzebski, Srini Narayanan

TL;DR
This paper introduces neural network stiffness, which measures how a parameter update on one example affects the loss on another. It links stiffness to generalization and explores how stiffness depends on training dynamics across several models and datasets.
Contribution
The paper proposes neural network stiffness as a new perspective on generalization and analyzes its dependence on class membership, distance between inputs, training iteration, and learning rate.
Findings
Higher stiffness correlates with better generalization.
Stiffness varies with class membership and input distance.
CIFAR-100 stiffness matrix reveals super-class awareness.
Abstract
In this paper we develop a new perspective on generalization of neural networks by proposing and investigating the concept of a neural network stiffness. We measure how stiff a network is by looking at how a small gradient step in the network's parameters on one example affects the loss on another example. Higher stiffness suggests that a network is learning features that generalize. In particular, we study how stiffness depends on 1) class membership, 2) distance between data points in the input space, 3) training iteration, and 4) learning rate. We present experiments on MNIST, FASHION MNIST, and CIFAR-10/100 using fully-connected and convolutional neural networks, as well as on a transformer-based NLP model. We demonstrate the connection between stiffness and generalization, and observe its dependence on learning rate. When training on CIFAR-100, the stiffness matrix exhibits a…
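The measurement described in the abstract can be sketched in a few lines. To first order, a gradient step of size eps on example 1 changes the loss on example 2 by -eps * (g1 · g2), where g1 and g2 are the per-example loss gradients, so the alignment of the two gradients indicates whether the step helps or hurts the other example. The sketch below is illustrative only: the logistic-regression model, the toy data points, and the helper names `loss`, `grad`, and `stiffness` are assumptions made for this example, and the sign and cosine of the gradient dot product are used as one plausible way to put a number on stiffness rather than as a definition quoted from the text above.

```python
import numpy as np

def loss(w, x, y):
    """Binary cross-entropy of a logistic-regression model on one example."""
    z = x @ w
    # log(1 + exp(z)) - y*z, computed in a numerically stable way
    return np.logaddexp(0.0, z) - y * z

def grad(w, x, y):
    """Gradient of the per-example loss with respect to the weights w."""
    p = 1.0 / (1.0 + np.exp(-(x @ w)))
    return (p - y) * x

def stiffness(w, ex1, ex2):
    """Return (sign, cosine) of the gradient alignment between two examples.

    Hypothetical helper: both quantities are assumed proxies for stiffness.
    """
    g1, g2 = grad(w, *ex1), grad(w, *ex2)
    dot = g1 @ g2
    cos = dot / (np.linalg.norm(g1) * np.linalg.norm(g2))
    return np.sign(dot), cos

# Toy setup (assumed values, chosen only for illustration).
w = np.array([0.5, -0.3])
a = (np.array([1.0, 2.0]), 1.0)
b = (np.array([1.5, 1.8]), 1.0)  # same label as a, nearby input
c = (np.array([1.2, 2.1]), 0.0)  # opposite label, nearby input

sign_ab, cos_ab = stiffness(w, a, b)
sign_ac, cos_ac = stiffness(w, a, c)

# Direct check: take a small gradient step on example a and
# see how the loss on b and on c actually changes.
eps = 1e-3
w_step = w - eps * grad(w, *a)
delta_b = loss(w_step, *b) - loss(w, *b)
delta_c = loss(w_step, *c) - loss(w, *c)

print("a vs b:", sign_ab, cos_ab, "loss change on b:", delta_b)
print("a vs c:", sign_ac, cos_ac, "loss change on c:", delta_c)
```

In this toy setting the step on `a` lowers the loss on `b` (aligned gradients) and raises it on `c` (opposed gradients). For a real network the same idea would apply with gradients taken over all parameters, typically obtained from an automatic-differentiation library.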
Taxonomy
Topics: Neural Networks and Applications · Model Reduction and Neural Networks · Machine Learning and ELM
