How transformers learn structured data: insights from hierarchical filtering
Jerome Garnier-Brun, Marc Mézard, Emanuele Moscato, Luca Saglietti

TL;DR
This paper investigates how vanilla transformers learn hierarchical structured data, revealing that they approximate exact inference algorithms by progressively capturing correlations at increasing hierarchical levels during training.
Contribution
It introduces a hierarchical filtering procedure for data sequences on trees and demonstrates how transformers learn to reconstruct correlations across multiple length scales.
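To make "data sequences on trees" concrete, here is a minimal sketch of a hierarchical generative model of this general kind. It is not the authors' code, and several choices are assumptions: a full binary tree, a vocabulary of q symbols, a random table P(b, c | a) over ordered child pairs for each parent symbol (the hypothetical helpers `make_grammar`, `expand`, `sample_sequence`), and a hypothetical `cut_levels` knob that limits the range of correlations by drawing the nodes k levels below the root independently. The paper defines its own filtering procedure, which may differ; the knob here only illustrates how the range of positional correlations can be dialed by hand.

```python
import random


def make_grammar(q, seed=0):
    """Hypothetical random grammar: for every parent symbol a, a normalized
    probability table over the q*q ordered child pairs (b, c), stored flat
    at index b * q + c."""
    rng = random.Random(seed)
    grammar = {}
    for a in range(q):
        # Small offset keeps every weight strictly positive.
        weights = [rng.random() + 1e-12 for _ in range(q * q)]
        total = sum(weights)
        grammar[a] = [w / total for w in weights]
    return grammar


def expand(symbol, levels, grammar, q, rng):
    """Recursively expand one symbol into 2**levels leaf symbols."""
    if levels == 0:
        return [symbol]
    pair = rng.choices(range(q * q), weights=grammar[symbol])[0]
    left, right = divmod(pair, q)  # pair = left * q + right
    return (expand(left, levels - 1, grammar, q, rng)
            + expand(right, levels - 1, grammar, q, rng))


def sample_sequence(q, depth, grammar, cut_levels=0, seed=None):
    """Sample one sequence of 2**depth leaves.

    cut_levels=0: plain hierarchical model, returns (root, leaves).
    cut_levels=k>0 (assumed filtering rule, for illustration only): the top
    k levels are discarded and the 2**k nodes at level k are drawn i.i.d.
    uniformly, so leaves in different blocks of 2**(depth - k) positions are
    independent. The root is then undefined and returned as None.
    """
    if not 0 <= cut_levels <= depth:
        raise ValueError("cut_levels must lie in [0, depth]")
    rng = random.Random(seed)
    if cut_levels == 0:
        root = rng.randrange(q)
        return root, expand(root, depth, grammar, q, rng)
    leaves = []
    for _ in range(2 ** cut_levels):
        leaves += expand(rng.randrange(q), depth - cut_levels, grammar, q, rng)
    return None, leaves
```

For example, `sample_sequence(4, 5, make_grammar(4), seed=0)` draws one 32-symbol sequence together with its root label, the kind of (input, target) pair a root-classification task would use. Cutting levels shrinks the blocks of correlated positions by a factor of two per level.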
Findings
Transformers approximate exact inference algorithms on hierarchical data.
Correlations at larger distances are incorporated sequentially during training.
Attention maps reveal how the network reconstructs hierarchical correlations.
Abstract
Understanding the learning process and the embedded computation in transformers is becoming a central goal for the development of interpretable AI. In the present study, we introduce a hierarchical filtering procedure for data models of sequences on trees, allowing us to hand-tune the range of positional correlations in the data. Leveraging this controlled setting, we provide evidence that vanilla encoder-only transformers can approximate the exact inference algorithm when trained on root classification and masked language modeling tasks, and study how this computation is discovered and implemented. We find that correlations at larger distances, corresponding to increasing layers of the hierarchy, are sequentially included by the network during training. By comparing attention maps from models trained with varying degrees of filtering and by probing the different encoder levels, we find…
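On a tree-structured model, the standard exact inference algorithm is belief propagation; we assume that is the algorithm the abstract refers to. The sketch below computes the exact posterior over the root symbol given the leaves (the root classification target) with a single upward, leaves-to-root pass. It is an illustration, not the authors' implementation: the flat grammar layout `grammar[a][b * q + c]` and the function name `root_posterior` are assumptions.

```python
def root_posterior(leaves, grammar, q, prior=None):
    """Exact P(root | leaves) for a full binary tree, computed with an
    upward (leaves-to-root) belief-propagation pass.

    grammar[a][b * q + c] is the (assumed) probability that parent symbol a
    emits the ordered children (b, c).
    """
    n = len(leaves)
    if n == 0 or n & (n - 1):
        raise ValueError("number of leaves must be a power of two")
    if prior is None:
        prior = [1.0 / q] * q  # uniform prior over root symbols
    # Leaf message: indicator of the observed symbol.
    msgs = [[1.0 if a == x else 0.0 for a in range(q)] for x in leaves]
    while len(msgs) > 1:
        parents = []
        for left, right in zip(msgs[0::2], msgs[1::2]):
            # m(a) = sum_{b,c} P(b, c | a) * left(b) * right(c)
            m = [sum(grammar[a][b * q + c] * left[b] * right[c]
                     for b in range(q) for c in range(q))
                 for a in range(q)]
            z = sum(m)
            # Rescale to avoid underflow; only ratios matter for the posterior.
            parents.append([v / z for v in m] if z > 0 else m)
        msgs = parents
    post = [p * m for p, m in zip(prior, msgs[0])]
    z = sum(post)
    if z == 0:
        raise ValueError("leaves have zero probability under this grammar")
    return [v / z for v in post]
```

Each level of the loop merges pairs of sibling messages, so information from leaves 2**l positions apart is combined only at level l, which mirrors the idea of correlations being organized by hierarchy level. For masked language modeling, the analogous exact computation adds a downward (root-to-leaves) pass to obtain the marginal of each masked leaf; that pass is not shown here.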
