Trainability Preserving Neural Pruning
Huan Wang, Yun Fu

TL;DR
This paper introduces trainability preserving pruning (TPP), a scalable method that maintains neural network trainability during pruning by regularizing convolutional filters and batch normalization parameters, improving pruning performance and robustness to retraining hyper-parameters such as the learning rate.
Contribution
The paper proposes trainability preserving pruning (TPP), which decorrelates the pruned filters from the retained filters and regularizes the batch normalization scale and bias, so that pruning performs better and is less sensitive to retraining hyper-parameters.
Findings
TPP performs comparably to oracle trainability recovery on linear networks.
On nonlinear ConvNets, TPP outperforms other pruning methods on CIFAR datasets.
On ImageNet, TPP consistently outperforms other structured pruning approaches.
Abstract
Many recent works have shown trainability plays a central role in neural network pruning -- unattended broken trainability can lead to severe under-performance and unintentionally amplify the effect of retraining learning rate, resulting in biased (or even misinterpreted) benchmark results. This paper introduces trainability preserving pruning (TPP), a scalable method to preserve network trainability against pruning, aiming for improved pruning performance and being more robust to retraining hyper-parameters (e.g., learning rate). Specifically, we propose to penalize the gram matrix of convolutional filters to decorrelate the pruned filters from the retained filters. In addition to the convolutional layers, per the spirit of preserving the trainability of the whole network, we also propose to regularize the batch normalization parameters (scale and bias). Empirical studies on linear MLP…
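The two regularizers described in the abstract can be illustrated with a minimal NumPy sketch. This is not the authors' implementation: the function names (`gram_decorrelation_penalty`, `bn_penalty`), the boolean `pruned_mask` convention, and the choice of a squared penalty over only the pruned-versus-retained entries of the Gram matrix are assumptions made for illustration; the exact loss used in the paper may differ.

```python
import numpy as np


def gram_decorrelation_penalty(weight, pruned_mask):
    """Sum of squared Gram entries that pair a pruned filter with a retained one.

    weight:      array of shape (out_channels, in_channels, kh, kw)
    pruned_mask: boolean array of shape (out_channels,), True = filter to prune
    (Hypothetical helper; the penalty form is an assumption, not the paper's code.)
    """
    flat = weight.reshape(weight.shape[0], -1)   # one row per filter
    gram = flat @ flat.T                         # filter-by-filter inner products
    # Rows: pruned filters. Columns: retained filters.
    cross = gram[np.ix_(pruned_mask, ~pruned_mask)]
    return float(np.sum(cross ** 2))


def bn_penalty(gamma, beta, pruned_mask):
    """Squared magnitude of BN scale and bias on the channels marked for pruning.

    (Hypothetical helper; assumes a plain L2-style penalty.)
    """
    return float(np.sum(gamma[pruned_mask] ** 2) + np.sum(beta[pruned_mask] ** 2))


# Four mutually orthogonal 1x2x2 filters; filter 0 is marked for pruning.
weight = np.eye(4).reshape(4, 1, 2, 2)
pruned_mask = np.array([True, False, False, False])
orthogonal_penalty = gram_decorrelation_penalty(weight, pruned_mask)

# Make retained filter 2 a copy of pruned filter 0, so the two are correlated.
correlated = weight.copy()
correlated[2] = correlated[0]
correlated_penalty = gram_decorrelation_penalty(correlated, pruned_mask)
```

In this sketch the penalty is zero when the pruned filter is orthogonal to every retained filter and grows once a retained filter overlaps with it, which is the decorrelation effect the abstract describes. In a training loop, both terms would be added to the task loss with weighting coefficients, which are likewise left as assumptions here.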
Taxonomy
Topics: Advanced Neural Network Applications · Domain Adaptation and Few-Shot Learning · Seismic Imaging and Inversion Techniques
Methods: Pruning · Batch Normalization
