A comprehensive evaluation of self-attention for detecting regulatory feature interactions
Saira Jabeen, Asa Ben-Hur

TL;DR
This paper evaluates how self-attention models can detect interactions between transcription factors in gene regulation, improving model interpretability with an entropy term.
Contribution
The addition of an entropy term to self-attention models improves precision and interpretability of regulatory feature interactions.
Findings
Entropy-enhanced attention models produce high-precision sparse attention maps.
Different attention-based methods vary in performance for transcription factor cooperativity discovery.
The study provides insights for effectively using attention models in biological discovery.
Abstract
The successful use of deep learning in computational biology depends on the ability to extract meaningful biological information from the trained models. Recent work has demonstrated that the attention maps generated by self-attention layers can be interpreted to predict cooperativity between binding of transcription factors, a key feature of gene regulatory networks. We extend this earlier work and demonstrate that the addition of an entropy term yields high-precision sparse attention maps that are easy to interpret. Furthermore, we performed a comprehensive evaluation of the relative performance of different flavors of attention-based transcription factor cooperativity discovery. Our findings demonstrate the benefit of the entropy-enhanced attention models and provide additional insights that would enable practitioners to make effective use of this valuable tool for biological…
| Dataset name | Number of unique TFs | Number of pairs |
|---|---|---|
| Data-40 | 20 | 40 |
| Data-60 | 25 | 60 |
| Data-80 | 30 | 80 |
- U.S. Department of Energy Office of Science (funder ID 10.13039/100006132)
- Biological and Environmental Research (funder ID 10.13039/100006206)
Introduction
Deep learning has made a major impact in computational biology, addressing complex prediction problems, including the prediction of epigenomic features [1, 2], gene expression [3–5], deciphering protein–protein interactions [6], and predicting protein structures [7]. A key advantage of deep networks is their ability to automatically learn task-relevant features directly from raw input data, eliminating the need for manual feature engineering [8]. This capability is particularly important for genomic sequences, where the underlying “language” of DNA is only partially understood. Consequently, deriving biological insights from trained deep networks is critical for understanding regulatory mechanisms and guiding downstream analyses.
Despite their strength in accurately modeling complex high-dimensional prediction problems, interpreting deep networks remains challenging. Interpretability has been a part of deep learning in computational biology from the very beginning: the authors of DeepBind introduced a method to interpret the learned weights of the first convolutional layer as a motif model [9]. Another major effort has been the development of attribution methods, which aim to discover the parts of a sequence that most affect the classification results. These include in silico mutagenesis [2] and gradient-based methods [10, 11].
Gene expression is regulated by transcription factors (TFs), and this regulation is orchestrated by the cooperative binding of TFs and other regulatory proteins in close proximity [12, 13]. Identifying these cooperative binding events can reveal insights that are important for understanding specific biological processes [14, 15]. Several methods have been proposed to capture regulatory interactions between TFs. The authors of Deep Feature Interaction Maps (DFIM) [16] proposed a method based on the DeepLIFT attribution method [17]. The more recent SATORI approach leverages the ability of self-attention layers to capture the relationship between different positions in a sequence [18]. We build on this earlier work in several ways:
- We improve model interpretability with the addition of an entropy term that promotes the sparsity of the attention weights, improving both interpretability and precision of the detected regulatory interactions.
- We extend the SATORI framework by scoring regulatory interactions using either raw attention values or attention attribution.
- We perform a comprehensive evaluation of the approach and its variants on simulated and real datasets in the context of deep-learning models of varying levels of complexity, demonstrating the effectiveness of attention for this task.
Related work
While there has been extensive research on deep learning architectures for genomic data, efforts to infer cooperativity between regulatory elements have been somewhat limited. As mentioned earlier, Greenside et al. [16] introduced the DFIM method to identify interactions at both nucleotide and motif-resolution levels. Their approach uses DeepLIFT to quantify the relationship between a pair of features by computing the difference in attribution scores between a sequence with and without a mutation. A drawback of DFIM is its significant computational expense and the necessity of a post-processing step. The capability of attention layers to capture relationships between features regardless of their spatial distance is the basis for SATORI [18]. The authors demonstrated comparable results to DFIM with much lower computational overhead. However, neither method has been evaluated on realistic simulated datasets that can provide a clear indication of the accuracy of its predictions.
The idea of using attention weights to interpret a trained model has been explored in several other papers. The authors of EPCOT [19] proposed a self-attention-based encoder-decoder architecture to establish a framework for predicting epigenetic state that is able to generalize to conditions that were not a part of their training set. They demonstrated that the learned epigenomic feature embeddings of co-occurring motifs are nearest neighbors in the embedding space and validated selected interactions using the STRING database. However, the method has not been well tested for the purpose of discovering regulatory interactions. While SATORI uses raw attention values to score regulatory interactions, the ISANREG method performs attribution using integrated gradients over the attention scores to infer TF cooperativity [20]. However, their model is limited to detecting TF cooperativity with a single TF of interest and is designed specifically for the analysis of TF ChIP-seq datasets. SATORI, on the other hand, is more general, supporting the detection of TF cooperativity for more complex datasets, such as those that probe gene regulation on a global scale. The TIANA model uses an approach similar to SATORI, with the distinction of employing pre-loaded convolutional layers initialized with known motif position-specific score matrices and using attribution scores relative to the attention layers [21]. Their experimental results demonstrate the benefit of this approach, which addresses the challenge in motif matching and filter-size choice. However, their model was tested on a limited set of data, and a clearer understanding of the factors that contribute to its performance is still required. The ECHO model takes a very different approach and uses attribution over Micro-C contacts to identify cooperative binding [22]. This method identifies cooperative binding events without distance limitations, an advantage over sequence-based methods. 
However, it is limited by the scarcity of Micro-C data.
In conclusion, while several methods have been proposed to infer regulatory interactions and TF cooperativity, existing approaches have limitations. These include computational complexity (DFIM), reliance on scarce data types (ECHO), and the ability to identify interactions only with specific ChIP-seq targeted proteins (ISANREG). Moreover, there is a notable lack of comprehensive evaluation for these methods. This gap in thorough benchmarking limits our understanding of how well these methods perform in realistic scenarios. In this work, we integrate the developments described earlier into the SATORI framework and provide a detailed benchmark over realistic simulated data to understand the relative performance of different methods of using attention to detect regulatory cooperativity. Furthermore, we demonstrate that adding an entropy-penalty term to the attention scores leads to increased network interpretability, especially in noisy, real-world datasets.
Materials and methods
Simulated datasets
To obtain reliable estimates of interpretability accuracy, we created simulated datasets with implanted interactions. This addresses the lack of ground-truth interactions in real-world datasets. In generating our simulated datasets, we aimed to replicate the characteristics of real-world datasets. These datasets are designed to be sufficiently complex to challenge the ability of the network to discover interactions. We generated binary-labeled datasets with varying levels of complexity. In each dataset, the positive class is distinguished from the negative class by the presence of interacting pairs of motifs. All sequences are randomly generated based on the nucleotide frequency of the human genome. For each dataset, we randomly selected motifs from the JASPAR database [23] and implanted motif occurrences according to their position weight matrices (PWMs). In the positive class, for each sequence, three interacting pairs are randomly chosen from the predefined set of interacting pairs. The motifs of these pairs are embedded into the sequence according to their PWMs, ensuring PWM scores exceeding half of the maximum PWM score. Motifs from an interacting pair are placed in close proximity, within 8–15 base pairs. In the negative class, one to three motifs are randomly selected from the same pool of motifs. These motifs are uniformly embedded using their PWMs, with each motif positioned at least 40 base pairs away from every other motif within the sequence. The generated datasets are balanced and consist of a total of 60,000 instances per dataset; each sequence is 300 bp long. Table 1 describes the number of TFs embedded in the three simulated datasets, along with the number of interactions.
For the second set of simulated data, used to evaluate pre-loaded CNN filters, we used sequences of length 1500 bp. Motif pairs were drawn from a new set: JASPAR motifs were clustered, 40 motifs from distinct clusters were selected, and 80 interacting motif pairs were selected. Positive sequences had 3–5 interacting pairs embedded within a window of 8–80 bp, while negatives included 1–3 randomly chosen motifs from the pool of unique motifs.
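The implanting procedure for a positive sequence can be illustrated with a minimal sketch. It is not the authors' generator: the function names, the approximate background frequencies in `HG_FREQS`, and the reading of "within 8–15 base pairs" as the gap between the end of the first motif and the start of the second are all assumptions made for illustration. The PWM-score filter (scores above half of the maximum) and the implanting of three pairs per sequence are omitted for brevity.

```python
import random

BASES = "ACGT"
# Assumed, approximate human background frequencies for A, C, G, T (illustrative only).
HG_FREQS = [0.295, 0.205, 0.205, 0.295]

def sample_from_pwm(pwm, rng):
    """Draw one motif instance; pwm is a list of [pA, pC, pG, pT] columns."""
    return "".join(rng.choices(BASES, weights=col)[0] for col in pwm)

def random_background(length, freqs, rng):
    """Generate a random sequence from independent per-base frequencies."""
    return "".join(rng.choices(BASES, weights=freqs, k=length))

def implant(seq, motif, start):
    """Overwrite seq with motif at position start, keeping the length fixed."""
    return seq[:start] + motif + seq[start + len(motif):]

def make_positive(pwm_a, pwm_b, length, freqs, rng, min_gap=8, max_gap=15):
    """Implant one interacting pair with a gap of min_gap..max_gap bases."""
    seq = random_background(length, freqs, rng)
    motif_a = sample_from_pwm(pwm_a, rng)
    motif_b = sample_from_pwm(pwm_b, rng)
    gap = rng.randint(min_gap, max_gap)
    span = len(motif_a) + gap + len(motif_b)
    start_a = rng.randint(0, length - span)  # keep both motifs inside the sequence
    start_b = start_a + len(motif_a) + gap
    seq = implant(seq, motif_a, start_a)
    seq = implant(seq, motif_b, start_b)
    return seq, start_a, start_b
```

A negative sequence would be produced analogously, with one to three motifs placed at least 40 bp apart instead of a close pair.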
Interpreting self-attention layers with SATORI
SATORI uses a network with a structure that is composed of two or more blocks: an initial block composed of a 1D convolutional feature extraction layer followed by a multi-head self-attention layer; the initial block is followed by additional (optional) deep learning layers and a fully connected output block (see Fig. 1). The attention layer in the primary block serves as a pivotal element for interpretation. Any additional layers serve to provide improved model accuracy, and we demonstrate that this framework can handle models of varying levels of complexity.
Figure 1. The SATORI framework. The deep learning architecture receives as input a one-hot encoded DNA sequence, which is then processed through a 1D-CNN layer, a self-attention layer, optional supplementary layers, and a classification layer. SATORI combines the attention heads and compares the attention values between pairs of filters in positive examples to a background set. It then reports pairs of filters that demonstrate statistical significance.
The input to the model is a DNA sequence that is represented by its one-hot encoding. This representation enables the first convolutional layer to learn motif-like features directly from raw sequence, which the subsequent attention layer leverages to capture dependencies between motifs.
The output of convolution is passed through an activation function. In this work we use the rectified linear unit function (ReLU) defined as
$$
\text{ReLU}(x) = \begin{cases} x, & \text{if } x > 0 \\ 0, & \text{otherwise.} \end{cases} \tag{1}
$$

After applying the ReLU function, a max-pooling operation is employed to downsample the spatial dimensions of the feature maps. This max-pooling step is optional and can help in reducing the computation time and memory usage by reducing the length of the sequences that need to be processed by the attention layer, which is the most time- and memory-consuming aspect of the model. In our experiments, we used a very small value for the pooling parameter (four or less) to ensure that the attention layer receives accurate spatial information about the sequence. The output of the max-pooling operation is fed into the multi-head self-attention layer, which captures relationships between elements within the sequence. This allows the model to weigh the importance of different elements, enabling it to focus on relevant information. Self-attention is computed using three matrices known as query, key, and value defined as:
$$
Q = XW_q, \quad K = XW_k, \quad V = XW_v, \tag{2}
$$

where $X$ is the max-pooled output of the first CNN layer, and $W_q$, $W_k$, and $W_v$ are learnable weight matrices used to transform the input $X$ into the query ($Q$), key ($K$), and value ($V$) matrices. By computing the dot product between the query and key vectors of specific positions in the sequence, we can assess the relationship between the selected positions. We then compute dot product attention as follows:
$$
A = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right), \tag{3}
$$

where $d_k$ is the embedding dimension of $K$, $Q$, and $V$. The softmax function is applied row-wise, transforming the dot products in each row into probabilities. This allows us to interpret the output of the attention layer as the likelihood of a relation between different parts of the input.
The output of the softmax function is used to weight the value matrix $V$ to generate the output of a given attention head:
$$
Z = A \cdot V \tag{4}
$$

Multi-head attention enables the model to simultaneously attend to different aspects of the input from diverse subspaces across various positions. Subsequently, the outputs of the attention heads are concatenated to produce the final output of the attention layer.
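The computation of a single attention head, from the projections through the row-wise softmax to the weighted values, can be sketched in plain Python. This is an illustrative sketch using nested lists rather than the PyTorch layers used in the paper; the helper names (`matmul`, `softmax_rows`, `attention_head`) are hypothetical.

```python
import math

def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def transpose(m):
    return [list(col) for col in zip(*m)]

def softmax_rows(m):
    """Row-wise softmax; subtracting the row maximum keeps exp() stable."""
    out = []
    for row in m:
        row_max = max(row)
        exps = [math.exp(v - row_max) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out

def attention_head(x, w_q, w_k, w_v):
    """Return the attention matrix A and the head output Z = A V."""
    q, k, v = matmul(x, w_q), matmul(x, w_k), matmul(x, w_v)
    d_k = len(k[0])  # embedding dimension of the keys
    scores = [[s / math.sqrt(d_k) for s in row] for row in matmul(q, transpose(k))]
    a = softmax_rows(scores)
    return a, matmul(a, v)
```

Each row of `a` sums to one, so entry `a[i][j]` can be read as how strongly position `i` attends to position `j`. A multi-head layer would run several such heads with separate weight matrices and concatenate their outputs.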
We use a similar process as described by Ullah et al. [18] to discover filter–filter interactions and infer motif interactions from those filter interactions. To infer interactions, we first combine the values of the attention heads by computing the maximum value across heads. Instead of employing a fixed threshold value, as was used in the original SATORI, we utilized the mean value of attention scores as a cutoff threshold for selecting interacting attention values. Following this, the statistical significance of the attention values is evaluated using the non-parametric Mann–Whitney U Test [24] relative to the background attention scores computed over the negative examples for binary classification or shuffled sequences in the context of multi-label classification. This process is illustrated in Fig. 1.
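The head-combination and thresholding steps can be sketched as follows. This is a simplified illustration and not the SATORI code: it assumes the mean cutoff is computed per combined attention matrix, and it includes only the Mann–Whitney U statistic, not its p-value (in practice a library routine such as `scipy.stats.mannwhitneyu` would supply the significance test). All function names are hypothetical.

```python
def combine_heads(heads):
    """Element-wise maximum over a list of equally sized attention matrices."""
    n_rows, n_cols = len(heads[0]), len(heads[0][0])
    return [[max(h[i][j] for h in heads) for j in range(n_cols)]
            for i in range(n_rows)]

def above_mean(combined):
    """Keep (row, col, value) entries whose attention exceeds the matrix mean."""
    flat = [v for row in combined for v in row]
    cutoff = sum(flat) / len(flat)
    return [(i, j, v)
            for i, row in enumerate(combined)
            for j, v in enumerate(row)
            if v > cutoff]

def mann_whitney_u(sample, background):
    """U statistic of sample versus background; ties count as one half."""
    u = 0.0
    for s in sample:
        for b in background:
            if s > b:
                u += 1.0
            elif s == b:
                u += 0.5
    return u
```

A large U for the attention values of a filter pair in positive examples, relative to the background values, indicates that the pair tends to receive more attention than expected by chance.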
We leveraged the weights of the initial CNN layer and attention layer to map filters to putative TFs using an established protocol [2]. Each substring yielding an activation score exceeding 0.65 of the maximum score of the filter is used in the final computation of the PWM for that filter. The resulting PWMs are then compared against known motif databases (JASPAR [23] for the simulated datasets, Human CISBP [25] for the human promoters dataset, and Arabidopsis DAP [26]) using the TomTom tool [27], and the top matches are assigned to each filter.
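The filter-to-PWM step can be illustrated with a small sketch. It assumes a filter is represented as one dictionary of per-base weights per position, that activation is the sum of the matching weights, and that the maximum score is taken over the supplied sequences and is positive; these representations and the function names are assumptions for illustration, not the protocol's actual implementation.

```python
def filter_activation(kmer, filt):
    """Sum the filter weight of the observed base at each position."""
    return sum(col[base] for col, base in zip(filt, kmer))

def filter_to_pwm(seqs, filt, frac=0.65):
    """Build a PWM from substrings scoring above frac * max activation."""
    k = len(filt)
    scored = [(filter_activation(s[i:i + k], filt), s[i:i + k])
              for s in seqs
              for i in range(len(s) - k + 1)]
    max_score = max(score for score, _ in scored)  # assumed to be positive
    kept = [kmer for score, kmer in scored if score > frac * max_score]
    counts = [{b: 0 for b in "ACGT"} for _ in range(k)]
    for kmer in kept:
        for pos, base in enumerate(kmer):
            counts[pos][base] += 1
    # Normalize counts to per-position base frequencies.
    return [{b: col[b] / len(kept) for b in "ACGT"} for col in counts]
```

The resulting frequency matrix is what would then be compared against a motif database with a tool such as TomTom.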
Architectures
Our baseline architecture (SATORI-basic) directly feeds the output of the attention layer to a fully connected layer. To demonstrate the effectiveness of SATORI with more complex architectures, we explored adding a block of convolutional layers whose structure was inspired by the SEI model [28]. The basic component of this block is a residual dual linear and nonlinear CNN sub-block, shown in Fig. 2. The linear pathway enables rapid training, while the nonlinear pathway is intended to enhance the representational power of the block. Furthermore, a residual connection is established between the stacked linear and nonlinear blocks, thereby permitting computation to traverse either the linear or nonlinear pathway. This also provides gradient stability during training and counteracts the vanishing gradient problem associated with deeper models. The last sub-block uses dilated CNN layers, augmented with residual connections, to capture longer-range relationships across the sequence. Subsequently, the resulting feature maps are averaged along the sequence dimension prior to processing through a fully connected layer and final output layer.
Figure 2. Additional layers included in the SATORI-deep architecture after the self-attention layer.
Improving network interpretability with an entropy loss term
In order to improve the interpretability of the attention values, we incorporated an entropy term in our loss function, in addition to the classification loss. The entropy loss is calculated from the attention scores to promote sparsity of the attention matrix. The resulting loss function is defined as follows:
$$
\mathrm{Loss} = \mathrm{Loss}_{c} + \lambda \, \mathrm{Loss}_{e}. \tag{5}
$$

Here, $\mathrm{Loss}_{c}$ is a classification loss (e.g. binary cross entropy) and $\mathrm{Loss}_{e}$ is the entropy term defined as:
$$
\mathrm{Loss}_{e} = -\sum_{k=1}^{h}\left( \frac{1}{n}\sum_{j=1}^{n} \sum_{i=1}^{n} A_{ij}^{(k)} \log A_{ij}^{(k)}\right), \tag{6}
$$

where $n$ is the number of rows (or columns) in the attention head $A^{(k)}$, and $h$ is the number of heads. The parameter $\lambda$ in Equation (5) is the regularization parameter for the entropy loss and is selected as part of our hyperparameter search.
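The entropy term can be made concrete with a short sketch that evaluates it on lists of attention matrices. The small constant `eps` inside the logarithm is an assumption added here to avoid `log(0)` and is not part of the definition above; the function names are hypothetical.

```python
import math

def entropy_loss(heads, eps=1e-12):
    """Sum over heads of the mean row entropy of each n x n attention matrix."""
    total = 0.0
    for a in heads:
        n = len(a)
        # Sum of A_ij * log(A_ij) over all entries of this head.
        s = sum(v * math.log(v + eps) for row in a for v in row)
        total += -s / n
    return total

def total_loss(loss_c, heads, lam):
    """Classification loss plus the lambda-weighted entropy term."""
    return loss_c + lam * entropy_loss(heads)
```

For a uniform 2 x 2 attention matrix the term equals log 2 (about 0.693), whereas a matrix whose rows are one-hot gives approximately zero. Minimizing the term therefore pushes each row of the attention matrix toward a few dominant entries, which is the sparsity the loss is meant to promote.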
Attention attribution scores
To compare the interpretation performance of raw attention scores with attributed attention scores, as used by [20], we incorporated their approach into the SATORI framework. Attention attribution scores are computed by calculating the integrated gradients of the target output with respect to the self-attention scores. This method was first proposed by [29] to interpret the information learned by a transformer model. Following [30], the attribution scores can be approximated from a zero-baseline using the Riemann approximation of the integration as follows:
$$
\text{Attr}(A^{(k)}) = \frac{A^{(k)}}{m} \odot \sum_{i=1}^{m} \frac{\partial F\left(\frac{i}{m} A\right)}{\partial A^{(k)}}, \tag{7}
$$

where $\odot$ is element-wise multiplication, $A^{(k)}$ describes a single attention head among the set of $h$ attention heads, denoted by $A$; $\frac{\partial F(A)}{\partial A^{(k)}}$ is the gradient of the model $F(\cdot)$ with respect to $A^{(k)}$; and $m$ is the number of approximation steps, which is set to 20. To identify significant interactions, we process the attention attribution scores in a similar manner to the way SATORI processes raw attention scores.

When using the entropy loss with the deep architecture, we incorporated residual connections from the output of the feature extraction CNN layer in the initial block to the output of the multi-head attention layer before applying additional convolution layers. This is done to maintain the accuracy of the model, which can be impacted by the increased sparsity induced by the entropy loss term.
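The Riemann approximation of the attribution scores can be sketched on a toy problem. In the sketch below, a single attention head stands in for the full set of heads, and the model is replaced by a hypothetical function F(A) = (sum of w_ij A_ij)^2 whose gradient is known in closed form; in the actual framework the gradient would come from automatic differentiation of the trained network. The names `attention_attribution`, `toy_f`, and `toy_grad` are illustrative.

```python
def attention_attribution(a, grad_f, m=20):
    """Approximate integrated gradients from a zero baseline in m steps."""
    n_rows, n_cols = len(a), len(a[0])
    grad_sum = [[0.0] * n_cols for _ in range(n_rows)]
    for i in range(1, m + 1):
        # Evaluate the gradient at the scaled attention matrix (i/m) * A.
        scaled = [[(i / m) * v for v in row] for row in a]
        g = grad_f(scaled)
        for r in range(n_rows):
            for c in range(n_cols):
                grad_sum[r][c] += g[r][c]
    # Element-wise product of A/m with the accumulated gradients.
    return [[a[r][c] / m * grad_sum[r][c] for c in range(n_cols)]
            for r in range(n_rows)]

# Toy stand-in for the model (an assumption, not the trained network).
W = [[1.0, 0.0], [0.0, 2.0]]

def toy_f(a):
    return sum(w * v for wr, ar in zip(W, a) for w, v in zip(wr, ar)) ** 2

def toy_grad(a):
    s = sum(w * v for wr, ar in zip(W, a) for w, v in zip(wr, ar))
    return [[2.0 * s * w for w in wr] for wr in W]
```

For a zero baseline the attributions should sum approximately to F(A) − F(0); with m = 20 the toy example agrees to within about 5%, and entries that the toy model ignores (zero weight) receive zero attribution.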
Model selection and training
We used two primary architectures in our experiments: a basic model, a two-layer architecture comprising a CNN layer and a self-attention layer, and a more complex model, which includes additional CNN layers. For model selection for the basic architecture, we employed randomized grid search using the search space shown in Supplementary Table S1. This tunes multiple hyperparameters, including the number of CNN filters, CNN filter size, size of the attention head, number of attention heads, and the size of the linear layer after the multi-head attention layer. In addition to these, the optimizer, learning rate, and weight decay are also determined through hyperparameter search. An instance of the simulated dataset is used for hyperparameter selection based on its AUC on the validation set. Details of selected models for training are provided in Supplementary Table S2. To obtain an optimal value for the regularization parameter $\lambda$ of the entropy loss, we used values on a logarithmic scale ranging from $10^{-5}$ to $10^{-2}$. We selected the value of $\lambda$ corresponding to the best interaction performance (F1-score) on the dataset used for model selection for further experimentation.
Datasets that are used for model and hyperparameter selection are later discarded, and new instances are utilized to evaluate the performance of the final selected model. For the real datasets, we used architecture settings similar to our previous work [18]; the deep architecture used the same hyperparameters for the initial block; details of the SEI-inspired additional block are given in Supplementary Table S3 in supplementary material. The simulated datasets were divided into train, validation, and test sets with 80%, 10%, and 10% ratios, respectively. For each dataset size, we generated three dataset instances. Subsequently, we trained a model three times using different seed values on each instance of the dataset, and the best-performing model, determined by validation AUC, was selected. The final results for each dataset type are reported by averaging over the three dataset instances. Model accuracy was measured using area under the ROC curve (AUC), and interpretation accuracy of the predicted interactions was measured using precision, recall, and F1-score. All methods were implemented using PyTorch [31], and all the experiments were conducted on a Linux workstation with a 12 GB TITAN V GPU. For DFIM implementation, we have used the implementation provided in the original SATORI code base [18].
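Interpretation accuracy over predicted interactions can be computed as in the sketch below. It assumes that interactions are unordered TF pairs, so that (A, B) and (B, A) count as the same interaction; the function name and the TF labels in the usage example are hypothetical.

```python
def interaction_scores(predicted, truth):
    """Precision, recall, and F1 of predicted pairs against implanted pairs."""
    pred = {frozenset(p) for p in predicted}  # unordered, de-duplicated pairs
    true = {frozenset(p) for p in truth}
    tp = len(pred & true)
    precision = tp / len(pred) if pred else 0.0
    recall = tp / len(true) if true else 0.0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom else 0.0
    return precision, recall, f1

# Usage with made-up TF labels.
scores = interaction_scores(
    predicted=[("TF_A", "TF_B"), ("TF_C", "TF_D"), ("TF_B", "TF_A")],
    truth=[("TF_A", "TF_B"), ("TF_E", "TF_F")],
)
```

With simulated data the `truth` set is the list of implanted interacting pairs, which is what makes precision and recall measurable in the first place.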
Results
SATORI accurately infers TF interactions embedded in simulated datasets
In this section, we compare the performance of various flavors of SATORI, as described earlier, using simulated datasets of varying levels of complexity, and evaluate the accuracy of the approach.
Impact of the entropy loss
To comprehensively evaluate the impact of adding an entropy loss term on the quality of predicted interactions, we trained deep learning models with two architectures of different depth, each with and without the entropy loss. Our analysis reveals that the entropy loss largely maintains the classification accuracy of the models while enhancing precision by promoting sparsity in the attention profile (see Fig. 3). This reduction in false positives aligns with our expectations. We also observe gains in overall interpretation accuracy, and in most cases the entropy loss also contributed positively to recall. The marginal decrease in classification accuracy observed in some cases when using the entropy loss indicates a potential loss of information induced by the increased sparsity (see Fig. 3).
Comparative performance of SATORI on simulated datasets of varying levels of complexity using the Basic and Deep architectures. We show: classification AUC (a) and interpretation accuracy: F1-score (b), precision (c), and recall (d). Results are shown for four models: Basic and Deep trained with and without the entropy loss; results are reported for datasets of increasing levels of complexity, comprising 40, 60, or 80 implanted TF motifs; performance numbers are averages over three instances of each dataset. Results are reported for the version of SATORI that uses raw attention values (Attn-Raw).
We also observe that, as the complexity of the simulated dataset increases, both interpretation and classification accuracy decrease, as would be expected. It is also interesting to note that moving from Basic to Deep architecture increases classification AUC, as shown in Fig. 3a. However, interpretation performance demonstrated a slight decrease. This behavior has been consistent across all three simulated datasets.
To obtain a better understanding of the contribution of the entropy term, we looked at the distribution of the attention scores with and without entropy (see Fig. 4a). As expected, we observe that the inclusion of entropy results in a distribution of attention scores that is clearly bi-modal. SATORI scores interactions based on the top attention values, requiring a threshold for choosing attention values that contribute to interactions. The addition of the entropy term has the added benefit of making it trivial to establish such a cutoff. The sparsity resulting from the use of entropy can also be observed in the heatmap representing the attention matrix (see Fig. 4b). The vertical lines observed in the heatmap correspond to positions of embedded motifs. Other positions in the sequence attend to these motif positions, suggesting that the model has successfully learned to attend to motif occurrences. In the figure, we also observe that the signal in the attention map is clearer in a model trained with an entropy term. Note that the entropy term is multiplied by a regularization parameter, whose value needs to be determined. In our experiments, values of the parameter between 0.001 and 0.01 yielded similar levels of sparsity.
The effect of entropy on attention weights for a basic model trained without (left) and with (right) the entropy loss, using a regularization parameter of 0.01 for Data-80. (a) The distribution of attention weights across accurately predicted test sequences. (b) Heatmaps of attention weights maximized over all attention heads for a true positive sequence. The X-axis and Y-axis in the heatmaps are the positions in the simulated DNA sequence after max pooling; arrows indicate the locations of embedded motifs in the original sequence, adjusted for the pooling window.
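To make the role of the entropy term concrete, the following sketch computes the entropy of attention rows, adds it to a classification loss with a regularization weight, and takes the maximum over heads as in the heatmaps. It is a plain-Python illustration under stated assumptions: the exact form of the penalty (here, the mean Shannon entropy over the rows of one attention matrix) is not given in this section, and all function names are hypothetical.

```python
import math


def row_entropy(weights, eps=1e-12):
    """Shannon entropy (in nats) of one attention row; weights should sum to 1."""
    return -sum(w * math.log(w + eps) for w in weights)


def attention_entropy(attn):
    """Mean row entropy of one attention matrix (assumed aggregation)."""
    return sum(row_entropy(row) for row in attn) / len(attn)


def total_loss(classification_loss, attn, lam):
    """Classification loss plus the entropy penalty weighted by lam."""
    return classification_loss + lam * attention_entropy(attn)


def max_over_heads(heads):
    """Element-wise maximum across a list of per-head attention matrices."""
    return [[max(vals) for vals in zip(*rows)] for rows in zip(*heads)]


uniform = [[0.25, 0.25, 0.25, 0.25]]   # diffuse attention: high entropy
peaked = [[1.0, 0.0, 0.0, 0.0]]        # sparse attention: entropy near zero
```

Minimizing the penalty pushes each row toward the peaked case, which is consistent with the sparse, bi-modal attention distribution described above.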
Comparison of raw attention, attention attribution, and DFIM
Here, we compare the interpretation accuracy of SATORI using raw attention scores with that of attention attribution scores and DFIM. In our experiments, we observed that raw attention scores and attention attribution scores give comparable results on simulated datasets, as shown in Fig. 5. In fact, the attention attribution matrices exhibit a high similarity to the raw attention matrices (see Supplementary Figs S1 and S2 in the supplementary material). SATORI with entropy loss outperforms the corresponding model without entropy, both with raw attention scores and attention attribution scores. Both models also outperform DFIM when used in conjunction with entropy.
Interpretation accuracy. We compare performance of SATORI using raw attention scores (Attn-Raw), attention attribution (Attn-Attr), and DFIM based on precision (a), recall (b), and F1-score (c). All methods use the same basic architecture; datasets comprise 40, 60, or 80 implanted TF motifs; performance numbers are averages over three instances of each dataset.
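For reference, the interpretation metrics can be computed by comparing the set of predicted TF pairs with the set of implanted pairs. The sketch below assumes that interactions are unordered pairs, so (A, B) and (B, A) count as the same interaction; the function name and the TF labels are hypothetical.

```python
def interaction_scores(predicted, truth):
    """Precision, recall, and F1-score over unordered TF pairs."""
    pred = {frozenset(pair) for pair in predicted}
    true = {frozenset(pair) for pair in truth}
    tp = len(pred & true)  # predicted pairs that were actually implanted
    precision = tp / len(pred) if pred else 0.0
    recall = tp / len(true) if true else 0.0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom else 0.0
    return precision, recall, f1


# Toy example with made-up TF names.
predicted_pairs = [("TF_A", "TF_B"), ("TF_C", "TF_D"), ("TF_E", "TF_F")]
true_pairs = [("TF_B", "TF_A"), ("TF_C", "TF_D"), ("TF_G", "TF_H"), ("TF_I", "TF_J")]
precision, recall, f1 = interaction_scores(predicted_pairs, true_pairs)
```

In this toy case two of the three predictions are correct and two of the four true pairs are recovered, so precision is 2/3 and recall is 1/2.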
An advantage of SATORI with raw attention scores is computation time (see Supplementary Fig. S3 in the supplement). DFIM and SATORI with attention attribution have the added overhead of requiring the computation of gradients after the model has been trained. For the basic architecture, using raw attention scores has only a slight advantage in terms of running time; for the deep architecture, gradient computations dominate the running time, and using raw attention scores is around five times faster. The difference with respect to DFIM is even more dramatic, where SATORI with raw attention scores runs over 100 times faster for the deep architecture.
Using pre-loaded filters
The accuracy of inferred interactions depends on the ability of TomTom to accurately match filters with known motifs. The use of pre-loaded CNN filters based on known motifs offers the potential advantage of not requiring computing matches to the trained filters, avoiding a source of potential error [21]. To quantify this, we compared models using trainable CNN filters versus models where the CNN filters were pre-loaded with clustered motifs from the Jaspar database, following the setup proposed in [21]. For this purpose, we generated simulated datasets using a sequence length of 1500 bp, which will also enable us to demonstrate the ability of self-attention to capture dependencies in longer sequences. We trained both the basic and deep models using hyperparameters selected via random search. Each experiment was repeated on three independently generated dataset instances, and results were averaged to ensure robustness. As expected, using pre-loaded filters leads to improved interpretation accuracy for both the Basic and Deep architectures (Supplementary Figs S4 and S5). However, this comes at the expense of a drop in classification accuracy due to reduced flexibility in feature learning (Supplementary Fig. S5). We also evaluated the role of positional encoding [21]. While it provides minor improvements in some cases, overall it does not improve classification or interpretation accuracy. We also observe that, in some cases, attention attribution failed to match the interpretation accuracy of using raw attention scores and was somewhat inconsistent in performance. Finally, consistent with our previous results, the deeper model had somewhat lower interpretation accuracy.
Incorporating recurrent neural networks
In our original work, we used a bidirectional recurrent neural network (RNN) layer after the initial CNN layer [18]. Although this design introduces a less direct relationship between the output of the CNN filters and the self-attention layer, we have found that this did not impact interpretability accuracy: on the Data-60 simulated dataset, adding an RNN layer (a bidirectional LSTM) to the SATORI-Basic architecture produced an interpretation F1-score of 0.36 compared to 0.33 without the RNN (these represent an average over 10 instances). Adding the RNN also improved classification accuracy, measured by AUC, from 0.88 to 0.96. The use of an RNN led to a reduction in the reproducibility of the results: the interpretability F1-score with SATORI-Basic demonstrated a standard deviation of 0.0004 across 10 runs of the model with different seed values. The addition of the RNN increased this to 0.004, which we would still consider stable.
TF interactions in accessible chromatin in human promoters
To test the ability of SATORI to detect cooperation of TFs in real data, we explored the regulatory cooperation among TFs within human promoter regions, leveraging DNase I hypersensitive sites (DHSs) collected from 164 immortalized cell lines and tissues. This dataset was introduced in our earlier work [18] and is derived from the ENCODE [32] and ROADMAP [33] consortia, consistent with the data used in [2]. For our analysis, we selected DHSs overlapping promoter regions, defined as 1 kb upstream of transcription start sites based on Ensembl hg19/GRCh37 annotations. Overlapping DHSs were merged if they shared >200 bp, and each region was extended to 600 bp around its center, yielding a final dataset of 20 613 promoter-associated sequences. The prediction task was to determine the presence or absence of a DHS in each of the 164 cell lines, resulting in a multi-label classification problem. We partitioned the dataset into training, testing, and validation sets at an approximate ratio of 80%, 10%, and 10%, respectively, while ensuring entire chromosomes were allocated to each set. Specifically, the training set includes chromosomes 1–14 and 18–19, the validation set includes chromosomes 17 and 20–21, and the test set contains the remaining chromosomes. To compare the performance of SATORI with trainable CNN filters against pre-loaded CNN weights from clustered PSSMs, as described elsewhere [21], we clustered 733 human TF motifs into 505 groups. The corresponding PSSMs for these clusters were loaded as weights in the first CNN layer. During interaction inference, we followed a similar approach to TIANA [21], where sequences with a filter score greater than half the maximum PSSM score for that filter were considered for further analysis. The mean classification AUC-ROC values achieved by the basic and deep SATORI models without entropy were 0.82 and 0.84, respectively.
The addition of the entropy term resulted in a slight decrease in classification accuracy, yielding an AUC of 0.79 for the basic model and 0.81 for the deep model. Models with pre-loaded CNN weights achieved an AUC-ROC of 0.81 with entropy for both basic and deep architectures. However, the entropy term was crucial for detecting interactions: the addition of the entropy term increased the number of interactions detected for both architectures (see Fig. 6). Its effect on the attention values is shown in Supplementary Fig. S6 in the supplementary material. The use of pre-loaded filters provided a further large increase in the number of detected interactions; in this dataset, deep and shallow models had similar numbers of detected interactions, as did the attention-attribution and attention-raw models.
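The half-maximum rule used with pre-loaded filters can be illustrated as follows. This is a simplified sketch, not the SATORI or TIANA implementation: the PSSM is represented as a list of per-position score dictionaries, the "filter score" is taken to be the best window score along the sequence, and the motif and function names are made up for the example.

```python
def max_pssm_score(pssm):
    """Highest achievable score: the best base at every motif position."""
    return sum(max(column.values()) for column in pssm)


def best_window_score(seq, pssm):
    """Best PSSM score over all windows of the sequence (a 1D scan)."""
    width = len(pssm)
    scores = [
        sum(pssm[i][seq[start + i]] for i in range(width))
        for start in range(len(seq) - width + 1)
    ]
    return max(scores, default=float("-inf"))


def passes_filter(seq, pssm):
    """Keep a sequence only if its score exceeds half the maximum PSSM score."""
    return best_window_score(seq, pssm) > 0.5 * max_pssm_score(pssm)


# Toy 3-bp PSSM favoring the (hypothetical) motif TGA.
toy_pssm = [
    {"A": -1, "C": -1, "G": -1, "T": 2},
    {"A": -1, "C": -1, "G": 2, "T": -1},
    {"A": 2, "C": -1, "G": -1, "T": -1},
]
```

Here `passes_filter("CCTGACC", toy_pssm)` holds because the TGA window reaches the maximum score of 6, whereas a sequence whose best window scores exactly 3 is rejected by the strict inequality.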
Number of interactions inferred by SATORI basic and deep architectures with and without entropy loss, compared against DFIM on the human promoter dataset (a), accessible chromatin in Arabidopsis (b), and Arabidopsis enhancers (c).
Since the co-occurrence of TF binding can be the result of physical interactions between the TFs, we searched these interactions against the HIPPIE protein–protein interaction database [34]. The interaction between SP7 and KDM2B was discovered by all of the methods. Supplementary Table S4 provides the number of matches in the HIPPIE database for all experiments, and we observe that the pre-loaded CNN approach provided the largest number of matches. Supplementary Tables S5–S16 in the supplementary material provide the list of matches identified. For both real-data experiments, we were unable to run DFIM on deeper models due to time-intensive gradient computations.
TF interactions in accessible chromatin in Arabidopsis
For the next experiment, we predicted chromatin accessibility across 36 Arabidopsis samples gathered from DNase I-Seq and ATAC-Seq studies used in our earlier work [18]. This dataset was also generated using the pipeline described in [2]. In this case, we utilized genome-wide accessible chromatin regions. This is also a multi-label classification problem with a total of 80 000 instances, each of length 600 bp. We partitioned the dataset into training, validation, and testing subsets in an 80%, 10%, and 10% ratio, respectively. Chromosome 2 was allocated to the validation set, chromosome 4 to the test set, and all other chromosomes were included in the training set. We trained two SATORI models, basic and deep, with and without the entropy loss. As in the human promoter experiments, we clustered 872 Arabidopsis TF motifs, resulting in 343 clusters, which were loaded as CNN weights. The mean classification AUC achieved by the basic and deep models without entropy on this dataset was 0.82 and 0.83, respectively. For the models trained with entropy loss, basic and deep SATORI models achieved an AUC of 0.77 and 0.82, respectively. Models with pre-loaded CNN weights achieved an AUC of 0.79 on both basic and deep models when using entropy regularization.
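The chromosome-level partition described above can be expressed as a small helper. This is a sketch with hypothetical names; it assumes that each region is a tuple of (chromosome, start, end) and that chromosome labels may carry a "Chr" prefix.

```python
def assign_split(chrom, val_chroms=("2",), test_chroms=("4",)):
    """Map a chromosome label to 'validation', 'test', or 'train'."""
    name = chrom[3:] if chrom.lower().startswith("chr") else chrom
    if name in val_chroms:
        return "validation"
    if name in test_chroms:
        return "test"
    return "train"


def split_regions(regions):
    """Group (chrom, start, end) regions by split, keeping chromosomes intact."""
    splits = {"train": [], "validation": [], "test": []}
    for region in regions:
        splits[assign_split(region[0])].append(region)
    return splits


# Toy regions with made-up coordinates.
toy_regions = [("Chr1", 100, 700), ("Chr2", 50, 650), ("Chr4", 0, 600), ("Chr5", 10, 610)]
toy_splits = split_regions(toy_regions)
```

Holding out whole chromosomes keeps overlapping or neighboring regions from appearing in both the training and evaluation sets.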
In this dataset, the addition of the entropy term was not as critical for detecting interactions for most of the experiments, and in some cases we obtained a higher number of interactions for the model without entropy. We believe this is a result of the much larger number of examples available in this dataset, making it easier for SATORI to detect significant interactions from the noisy attention weights, thereby requiring less help from the entropy term (Fig. 6b). Interestingly, the basic pre-loaded CNN model trained with entropy recovered the highest number of interactions for Arabidopsis (Fig. 6b). This can be attributed to the fact that the pre-loaded PSSMs effectively captured the relevant patterns in the data. Out of 343 filters, 342 were activated in >10 sequences within the test set, indicating that nearly all motif-based filters were involved during inference. In contrast, for models with trainable CNN filters, when TomTom is unable to match a filter to any known motif, that filter tends to remain inactive during interaction inference. Moreover, when multiple filters are mapped to the same motif, the resulting interactions can become redundant, further reducing the number of inferred interactions. In both the Arabidopsis and the human promoter datasets, the attention-based models produced a much larger number of interactions than DFIM.
TF interactions in Arabidopsis enhancers
To further validate the impact of entropy loss on retrieving interactions, we used data on Arabidopsis active enhancers, since cooperative binding is a defining characteristic of these regulatory regions [13]. We obtained a set of 4327 enhancers from a previous study that used the STARR-seq assay [35]. We used this set as positive examples for training and evaluating a classifier of active enhancers. To construct negative examples, we extracted genomic features from the TAIR10 genome annotations, including exons, introns, promoters, UTRs, downstream regions, and distal intergenic regions. Each candidate region was extended to 600 bp centered on its midpoint, and regions overlapping with enhancers were excluded. To generate the negative examples, we sampled 43 270 sequences (10 negatives per enhancer) such that the distribution of negatives reflected that of enhancers, namely 53% exonic, 38% promoters, 6% distal, and 1% intronic. We randomly split data into 80-10-10 for training, testing, and validation sets, respectively. As in our other real-data experiments, we trained both basic and deep SATORI models, with and without the entropy loss. For pre-loaded CNN experiments, we used the same clustered PSSMs as in the whole-genome Arabidopsis experiment. Across all settings, we achieved AUPR values between 0.80 and 0.83 (Fig. 6c).
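The construction of candidate negative regions (extension to 600 bp around the midpoint, then exclusion of anything overlapping an enhancer) can be sketched as below. The helper names are hypothetical, coordinates are assumed to be half-open (start, end) intervals, and clipping at chromosome ends is handled only at position 0; the category-matched sampling of negatives is not shown.

```python
def center_window(start, end, width=600):
    """Fixed-width window centered on the midpoint of (start, end)."""
    mid = (start + end) // 2
    new_start = max(0, mid - width // 2)  # do not run off the chromosome start
    return new_start, new_start + width


def overlaps(a, b):
    """Half-open intervals overlap if each starts before the other ends."""
    return a[0] < b[1] and b[0] < a[1]


def candidate_negatives(regions, enhancers, width=600):
    """Extend each region to a fixed width and drop those overlapping enhancers."""
    kept = []
    for chrom, start, end in regions:
        window = center_window(start, end, width)
        hits_enhancer = any(
            chrom == e_chrom and overlaps(window, (e_start, e_end))
            for e_chrom, e_start, e_end in enhancers
        )
        if not hits_enhancer:
            kept.append((chrom, window[0], window[1]))
    return kept


# Toy coordinates, not taken from the TAIR10 annotation.
toy_features = [("1", 1000, 1200), ("1", 5000, 5100), ("2", 1000, 1200)]
toy_enhancers = [("1", 1300, 1900)]
toy_negatives = candidate_negatives(toy_features, toy_enhancers)
```

The first toy feature is dropped because its extended window (800 to 1400) overlaps the enhancer, even though the original feature did not.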
In our experiments, we observed that the addition of entropy loss provided a large improvement in the number of inferred TF–TF interactions for most architectures. DFIM achieved a similar number of inferred interactions as the deeper model with entropy loss but required substantially higher computational cost. Models with pre-loaded CNN filters performed relatively poorly in this setting compared to human promoter and whole-genome Arabidopsis experiments, despite the fact that 243 out of 343 filters were activated in >10 sequences. Overall, the entropy loss helped in identifying more interactions when using attention-based interpretation techniques compared to models trained without it, while also being computationally lighter than DFIM.
Conclusions
In this work, we presented SATORI 2.0, which adds the option of an entropy penalty term and incorporates the option of scoring feature interactions based on attribution scores over the attention matrix as an alternative to scoring based on the raw attention values. We performed a comprehensive evaluation of its performance over realistic simulated datasets and real-world datasets with architectures with varying levels of complexity. The results show that the self-attention layers can effectively and accurately discover regulatory interactions across different architectures, datasets, and sequence lengths. We found that the entropy loss helps remove false positives, and in most cases, also increases the method’s recall. These findings support entropy as an effective inductive bias for guiding attention towards improved precision when inferring interactions. We also evaluated SATORI models with pre-loaded CNNs, as suggested by [21]. Our findings confirm that pre-loaded CNNs perform well in terms of interaction inference due to reduced dependence on TomTom accuracy. Entropy also appears to improve the precision of models with pre-loaded CNNs. However, using pre-loaded CNNs comes with the challenge of selecting the right set of TFs relevant to the problem. In our experiments, we found that using attention attribution scores performed similarly to using raw attention values; the earlier DFIM method did not perform as well and required significantly higher computation times. In summary, the SATORI 2.0 framework provides a comprehensive tool for interpreting neural network layers and obtaining valuable predictions about regulatory relationships that can be further tested in the laboratory.
Supplementary Material
lqaf209_Supplemental_File
- 1. Zhou J, Troyanskaya OG. Predicting effects of noncoding variants with deep learning-based sequence model. Nat Methods. 2015;12:931–4. doi:10.1038/nmeth.3547. PMID: 26301843. PMCID: PMC4768299.
- 2. Kelley DR, Snoek J, Rinn JL. Basset: learning the regulatory code of the accessible genome with deep convolutional neural networks. Genome Res. 2016;26:990–9. doi:10.1101/gr.200535.115. PMID: 27197224. PMCID: PMC4937568.
- 3. Avsec Ž, Agarwal V, Visentin D et al. Effective gene expression prediction from sequence by integrating long-range interactions. Nat Methods. 2021;18:1196–203. doi:10.1038/s41592-021-01252-x. PMID: 34608324. PMCID: PMC8490152.
- 4. Singh R, Lanchantin J, Robins G et al. DeepChrome: deep-learning for predicting gene expression from histone modifications. Bioinformatics. 2016;32:i639–48. doi:10.1093/bioinformatics/btw427. PMID: 27587684.
- 5. Karbalayghareh A, Sahin M, Leslie CS. Chromatin interaction-aware gene regulatory modeling with graph attention networks. Genome Res. 2022;32:930–44. doi:10.1101/gr.275870.121. PMID: 35396274. PMCID: PMC9104700.
- 6. Hashemifar S, Neyshabur B, Khan AA et al. Predicting protein–protein interactions through sequence-based deep learning. Bioinformatics. 2018;34:i802–10. doi:10.1093/bioinformatics/bty573. PMID: 30423091. PMCID: PMC6129267.
- 7. Jumper J, Evans R, Pritzel A et al. Highly accurate protein structure prediction with AlphaFold. Nature. 2021;596:583–9. doi:10.1038/s41586-021-03819-2. PMID: 34265844. PMCID: PMC8371605.
- 8. Bengio Y, Courville A, Vincent P. Representation learning: a review and new perspectives. IEEE Trans Pattern Anal Mach Intell. 2013;35:1798–828. doi:10.1109/TPAMI.2013.50. PMID: 23787338.
