Modeling Latent Attention Within Neural Networks
Christopher Grimm, Dilip Arumugam, Siddharth Karamcheti, David Abel, Lawson L.S. Wong, Michael L. Littman

TL;DR
This paper introduces a visualization method for neural networks that highlights which input features influence decisions, enhancing interpretability across various domains like vision, language, and reinforcement learning.
Contribution
It presents a general, dataset-centric approach to visualize and interpret the internal attention mechanisms of neural networks across multiple modalities.
Findings
Effective visualization of attention masks in diverse neural architectures
Improved understanding of input attribute importance in decision-making
Framework applicable to vision, NLP, and reinforcement learning models
Abstract
Deep neural networks are able to solve tasks across a variety of domains and modalities of data. Despite many empirical successes, we lack the ability to clearly understand and interpret the learned internal mechanisms that contribute to such effective behaviors or, more critically, failure modes. In this work, we present a general method for visualizing an arbitrary neural network's inner mechanisms and their power and limitations. Our dataset-centric method produces visualizations of how a trained network attends to components of its inputs. The computed "attention masks" support improved interpretability by highlighting which input attributes are critical in determining output. We demonstrate the effectiveness of our framework on a variety of deep neural network architectures in domains from computer vision, natural language processing, and reinforcement learning. The primary…
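The abstract does not say how the attention masks are computed. As an illustration only, the sketch below assumes one common perturbation-based recipe: keep the attended input features, replace the others with noise, and trade off how much the trained network's output changes against how much of the input the mask keeps. The fixed linear "network" F(x) = x @ w, the Gaussian noise, the penalty weight `beta`, and the function `learn_global_mask` are hypothetical, and learning a single mask for the whole dataset is a simplification. This is not presented as the authors' exact method.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def learn_global_mask(X, w, beta=1.0, lr=1.0, steps=300, seed=0):
    """Learn one soft mask a in (0, 1)^d shared across the dataset X.

    Hypothetical setup: the "trained network" is the fixed linear model
    F(x) = x @ w. Masked-out features are replaced by Gaussian noise, so the
    mask stays near 1 only where noising a feature changes F's output.
    """
    rng = np.random.default_rng(seed)
    n, d = X.shape
    z = np.zeros(d)                 # mask logits; a = sigmoid(z) starts at 0.5
    target = X @ w                  # outputs on the unperturbed inputs
    for _ in range(steps):
        a = sigmoid(z)
        noise = rng.standard_normal((n, d))
        x_tilde = a * X + (1.0 - a) * noise   # keep attended features, noise the rest
        diff = target - x_tilde @ w           # output change caused by the perturbation
        # Loss = mean(diff**2) + beta * mean(a); analytic gradient w.r.t. a:
        grad_a = -2.0 * np.mean(diff[:, None] * (X - noise), axis=0) * w + beta / d
        z -= lr * grad_a * a * (1.0 - a)      # chain rule through the sigmoid
    return sigmoid(z)


# Usage: only the first two features influence the hypothetical model.
X = np.random.default_rng(1).standard_normal((1000, 6))
w = np.array([2.0, -3.0, 0.0, 0.0, 0.0, 0.0])
mask = learn_global_mask(X, w)
print(np.round(mask, 2))
```

Because the mask is shared across all of `X`, it summarizes which input attributes the fixed model relies on over the dataset as a whole; producing a separate mask for each input would require making the mask a function of x rather than a single learned vector.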
Taxonomy
Topics: Explainable Artificial Intelligence (XAI) · Adversarial Robustness in Machine Learning · Machine Learning and Data Classification
