Disentangling Representations through Multi-task Learning
Pantelis Vafidis, Aman Bhargava, Antonio Rangel

TL;DR
This paper demonstrates that multi-task learning in neural networks leads to the emergence of disentangled, interpretable representations that facilitate zero-shot generalization, supported by theoretical guarantees and extensive experiments.
Contribution
The paper provides a theoretical framework linking multi-task performance to disentangled representations and validates it with experiments across various neural architectures.
Findings
Disentangled representations emerge in multi-task-trained RNNs as continuous attractors.
Transformers are particularly effective at disentangling representations.
The framework explains zero-shot out-of-distribution generalization capabilities.
Abstract
Intelligent perception and interaction with the world hinge on internal representations that capture its underlying structure ("disentangled" or "abstract" representations). Disentangled representations serve as world models, isolating latent factors of variation in the world along approximately orthogonal directions, thus facilitating feature-based generalization. We provide experimental and theoretical results guaranteeing the emergence of disentangled representations in agents that optimally solve multi-task evidence accumulation classification tasks, canonical in the neuroscience literature. The key conceptual finding is that, by producing accurate multi-task classification estimates, a system implicitly represents a set of coordinates specifying a disentangled representation of the underlying latent state of the data it receives. The theory provides conditions for the…
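The abstract's key conceptual claim, that accurate multi-task classification outputs implicitly encode coordinates of the latent state, can be illustrated with a toy model. The sketch below is a minimal illustration under assumptions chosen here, not the authors' setup or code: a latent state x in D dimensions, T noisy observations x + N(0, σ²I) accumulated over time, a flat prior on x, and N binary tasks that each ask whether w_i · x > 0 for a unit-norm normal w_i through the origin. Under these assumptions the Bayes-optimal output for task i is Φ(√T · w_i · x̄ / σ), where x̄ is the mean observation, so inverting Φ and solving a least-squares problem recovers x̄ whenever the task normals span all D dimensions. The helper names optimal_task_outputs and decode_latent are hypothetical.

```python
import numpy as np
from statistics import NormalDist

# Toy assumptions (not from the paper): Gaussian observation noise, flat prior,
# linear decision boundaries through the origin with unit-norm normals.
STD_NORMAL = NormalDist()  # provides cdf (Phi) and inv_cdf (probit)


def optimal_task_outputs(obs, W, sigma):
    """Hypothetical helper: Bayes-optimal P(w_i . x > 0 | obs) for every task i."""
    T = obs.shape[0]
    x_bar = obs.mean(axis=0)        # accumulated evidence = posterior mean of x
    scale = np.sqrt(T) / sigma      # 1 / posterior std of w . x for unit-norm w
    z = scale * (W @ x_bar)         # signed, noise-normalized evidence per task
    return np.array([STD_NORMAL.cdf(float(zi)) for zi in z])


def decode_latent(p, W, T, sigma):
    """Hypothetical helper: recover the posterior mean of x from task probabilities."""
    p = np.clip(p, 1e-12, 1 - 1e-12)  # guard inv_cdf against exactly 0 or 1
    z = np.array([STD_NORMAL.inv_cdf(float(pi)) for pi in p])
    # z = (sqrt(T) / sigma) * W @ x_bar, so least squares inverts it when rank(W) == D.
    x_hat, *_ = np.linalg.lstsq(W, z * sigma / np.sqrt(T), rcond=None)
    return x_hat


rng = np.random.default_rng(0)
D, N, T, sigma = 3, 12, 10, 2.0   # latent dim, number of tasks, time steps, noise std
W = rng.normal(size=(N, D))
W /= np.linalg.norm(W, axis=1, keepdims=True)   # unit-norm task normals
x_true = rng.uniform(-1, 1, size=D)
obs = x_true + sigma * rng.normal(size=(T, D))  # T noisy observations of x_true

p = optimal_task_outputs(obs, W, sigma)
x_hat = decode_latent(p, W, T, sigma)
print(np.round(x_hat, 3), np.round(obs.mean(axis=0), 3))
```

Because the N task probabilities determine x̄ through a known linear map, they behave as coordinates for the latent state in this toy. With fewer tasks than latent dimensions, W loses rank and lstsq can only return a minimum-norm guess rather than x̄ itself.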
Peer Reviews
Decision: ICLR 2025 Poster
The paper presents theoretical results that establish specific conditions (relating to the number of tasks, input dimensionality, input noise, and more) that lead to the emergence of abstract and disentangled representations in agents solving multi-task evidence aggregation classification tasks. The authors conduct thorough experiments across several architectures (RNNs, LSTMs, and transformers) to validate their theoretical results, showing that even architectures like GPT-2 can exhibit these pro…
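The dependence on the number of tasks has a simple linear-algebra core in a linear toy setting. This is an assumption of the sketch, not the paper's theorem, which states its own conditions (including noise): if N unit-norm task normals in D latent dimensions are stacked into an N × D matrix W, the task evidence W x identifies x only when W has rank D, which requires N ≥ D. The hypothetical helper random_task_normals draws random task normals, and the example shows two distinct latents that produce identical evidence when N < D.

```python
import numpy as np


def random_task_normals(n_tasks, latent_dim, seed=0):
    """Hypothetical helper: random unit-norm task normals and the rank of their matrix."""
    rng = np.random.default_rng(seed)
    W = rng.normal(size=(n_tasks, latent_dim))
    W /= np.linalg.norm(W, axis=1, keepdims=True)
    return W, np.linalg.matrix_rank(W)


D = 3
W_few, r_few = random_task_normals(2, D)    # fewer tasks than latent dimensions
W_many, r_many = random_task_normals(6, D)  # more tasks than latent dimensions

# With N < D, W has a nontrivial null space: moving the latent along it
# leaves every task's evidence W @ x unchanged, so no decoder can tell them apart.
null_dir = np.linalg.svd(W_few)[2][-1]      # right-singular vector outside the row space
x_a = np.array([0.5, -0.2, 0.1])
x_b = x_a + 0.7 * null_dir

print(r_few, r_many)                            # 2 3
print(np.allclose(W_few @ x_a, W_few @ x_b))    # True
```

Random Gaussian normals are used only because they are full rank with probability one; any set of N ≥ D normals spanning the latent space would serve the same purpose in this toy.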
One notable limitation of this work is the assumption of factorization, as acknowledged by the authors. Additionally, the theoretical framework is tailored to a specific type of multi-task learning problem (evidence aggregation classification with linear decision boundaries), which may not capture the full diversity of tasks and decision boundaries that agents encounter in dynamic environments. It would be valuable to explore how these ideas generalize to other multi-task learning scenarios. Furthe…
1. Originality: the study of the emergence of disentangled representations in temporal tasks and models is relatively new.
2. Quality: the empirical validation is comprehensive.
3. Clarity: the theory is well explained. The background section is also very nicely written and thorough.
4. Significance: the paper addresses the key problem of learning a world model in representation learning and also discusses its biological relevance in detail. The results offer meaningful insights for future r…
1. The theory assumes that agents are optimal multi-task classifiers, which may not be achievable in realistic settings where the input dimension D is already large and the number of tasks N_task ≫ D. This raises questions about the practical relevance of the regime considered in the paper. Additionally, it's difficult to imagine a large number of orthogonal tasks in real-world settings. For instance, in the dSprites dataset, as the authors noted, how could a meaningful, larger set of tasks be co…
The paper introduces a theoretical framework connecting optimal multi-task classification to disentangled representations, building upon the empirical observations in [Johnston and Fusi (2023)](https://www.nature.com/articles/s41467-023-36583-0) . To support these theoretical claims, the paper provides a range of experiments exploring different architectures, task structures, and decision boundary geometries.
A significant weakness of this paper is its heavy reliance on dense supervised signals to achieve abstract representations. Similar to [Johnston and Fusi (2023)](https://www.nature.com/articles/s41467-023-36583-0), this work uses multiple binary classification tasks, framed as "multi-task" learning. This framing implies a diversity of tasks that is not truly reflected in the setup. The dependence on numerous, highly similar supervised tasks raises questions about the biological plausibilit…
Taxonomy
Topics: Digital Media Forensic Detection · Domain Adaptation and Few-Shot Learning · Adversarial Robustness in Machine Learning
Methods: Sparse Evolutionary Training
