A Theoretical Understanding of Shallow Vision Transformers: Learning, Generalization, and Sample Complexity
Hongkang Li, Meng Wang, Sijia Liu, Pin-Yu Chen

TL;DR
This paper provides the first theoretical analysis of training shallow Vision Transformers (ViTs). It characterizes their sample complexity and the sparsity of the learned attention map, and shows how token sparsification can improve generalization. Experiments support these results.
Contribution
The paper offers a new theoretical framework for understanding shallow ViTs, including sample complexity bounds and an analysis of how token sparsification affects generalization.
Findings
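The sample complexity grows with the inverse of the fraction of label-relevant tokens, the token noise level, and the initial model error.
Training with SGD yields a sparse attention map, formally verifying a common intuition about why attention works.
Proper token sparsification improves test performance by removing label-irrelevant or noisy tokens.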
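Token sparsification, as listed in the findings above, means keeping only a subset of tokens. The following is a minimal pure-Python illustration, not the paper's procedure. It assumes tokens are ranked by the total attention they receive (the column sums of a row-stochastic attention map) and that the top `k` are kept in their original order. The scoring rule and the helper names `attention_received` and `sparsify_tokens` are hypothetical.

```python
def attention_received(attention_map):
    """Hypothetical score: total attention each token receives.

    attention_map[l][s] is the weight query token l places on token s,
    so the score of token s is the sum of column s.
    """
    n = len(attention_map[0])
    return [sum(row[s] for row in attention_map) for s in range(n)]


def sparsify_tokens(tokens, scores, k):
    """Keep the k highest-scoring tokens, preserving their original order.

    Returns (kept_tokens, kept_indices). Python's sort is stable even with
    reverse=True, so tied scores keep their original relative order.
    """
    if not 0 < k <= len(tokens):
        raise ValueError("k must satisfy 0 < k <= number of tokens")
    ranked = sorted(range(len(tokens)), key=lambda i: scores[i], reverse=True)
    keep = sorted(ranked[:k])
    return [tokens[i] for i in keep], keep


if __name__ == "__main__":
    # Toy attention map in which most queries attend to tokens 0 and 1.
    attn = [[0.7, 0.1, 0.1, 0.1],
            [0.6, 0.2, 0.1, 0.1],
            [0.1, 0.7, 0.1, 0.1],
            [0.5, 0.3, 0.1, 0.1]]
    kept, idx = sparsify_tokens(["p0", "p1", "p2", "p3"], attention_received(attn), 2)
    print(kept, idx)
```

In this toy map, tokens `p2` and `p3` receive little attention and are dropped, which mirrors the idea of removing label-irrelevant tokens. In practice, which tokens count as irrelevant depends on the criterion used.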
Abstract
Vision Transformers (ViTs) with self-attention modules have recently achieved great empirical success in many vision tasks. Due to non-convex interactions across layers, however, theoretical learning and generalization analysis is mostly elusive. Based on a data model characterizing both label-relevant and label-irrelevant tokens, this paper provides the first theoretical analysis of training a shallow ViT, i.e., one self-attention layer followed by a two-layer perceptron, for a classification task. We characterize the sample complexity to achieve a zero generalization error. Our sample complexity bound is positively correlated with the inverse of the fraction of label-relevant tokens, the token noise level, and the initial model error. We also prove that a training process using stochastic gradient descent (SGD) leads to a sparse attention map, which is a formal verification of the…
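The architecture named in the abstract, one self-attention layer followed by a two-layer perceptron, can be sketched as a forward pass. This is a minimal pure-Python sketch under assumptions this summary does not state: a single attention head, no 1/sqrt(d) scaling, ReLU hidden units, an output vector `a`, and a logit averaged over all tokens. The paper's exact parameterization may differ, and the function names are hypothetical.

```python
import math


def matvec(W, x):
    """Multiply a matrix W (list of rows) by a vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def relu(v):
    return [max(0.0, t) for t in v]


def shallow_vit(tokens, W_Q, W_K, W_V, W_O, a):
    """Hypothetical shallow ViT: one attention layer plus a two-layer perceptron.

    Returns (logit, attention_map), where attention_map[l][s] is the weight
    query token l places on token s. No 1/sqrt(d) scaling is applied; it can
    be absorbed into W_Q.
    """
    queries = [matvec(W_Q, x) for x in tokens]
    keys = [matvec(W_K, x) for x in tokens]
    values = [matvec(W_V, x) for x in tokens]
    dim = len(values[0])
    attention_map, per_token = [], []
    for q in queries:
        # Attention weights of this query over all tokens (one row of the map).
        weights = softmax([sum(ki * qi for ki, qi in zip(k, q)) for k in keys])
        attention_map.append(weights)
        # Attention-weighted mixture of the value vectors.
        mixed = [sum(w * v[d] for w, v in zip(weights, values)) for d in range(dim)]
        # Two-layer perceptron: ReLU hidden layer W_O, then output vector a.
        hidden = relu(matvec(W_O, mixed))
        per_token.append(sum(aj * hj for aj, hj in zip(a, hidden)))
    # Average per-token outputs into one logit; its sign gives the class.
    return sum(per_token) / len(per_token), attention_map


if __name__ == "__main__":
    I = [[1.0, 0.0], [0.0, 1.0]]
    logit, attn = shallow_vit([[1.0, 0.0], [0.0, 1.0]], I, I, I, I, [1.0, -1.0])
    print(logit, attn)
```

With identity weights and two orthogonal tokens, each token attends most to itself and the two per-token outputs cancel, giving a logit of 0 (up to floating-point rounding). Sparse attention in the paper's sense corresponds to rows of `attention_map` concentrating their mass on a few tokens.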
Taxonomy
Topics: Advanced Neural Network Applications · Machine Learning and ELM · Advanced Memory and Neural Computing
Methods: Test
