TL;DR
This paper introduces the Permutohedral Attention Module (PAM), a memory- and computation-efficient attention mechanism designed to capture non-local information in 3D medical images, improving segmentation accuracy.
Contribution
The paper presents a novel, efficient attention module for neural networks that effectively captures non-local context in 3D medical imaging tasks.
Findings
Demonstrates improved vertebrae segmentation accuracy.
Shows scalability and efficiency of PAM in 3D medical imaging.
Provides GPU implementation suitable for real-world applications.
Results table (Dice scores; cell values not recovered): FCN, FCN+PAM, Dil.FCN, Dil.FCN+PAM, U-Net and U-PAM-Net on all vertebrae (Full) and on the cervical, thoracic and lumbar vertebrae.
Permutohedral Attention Module for Efficient Non-Local Neural Networks
Samuel Joutard, Reuben Dorent, Amanda Isaac, Sebastien Ourselin, Tom Vercauteren, Marc Modat
School of Biomedical Engineering & Imaging Sciences, King’s College London
Abstract
Medical image processing tasks such as segmentation often require capturing non-local information. As organs, bones, and tissues share common characteristics such as intensity, shape, and texture, the contextual information plays a critical role in correctly labeling them. Segmentation and labeling are now typically done with convolutional neural networks (CNNs), but the context of a CNN is limited by its receptive field, which is itself limited by memory requirements and other properties. In this paper, we propose a new attention module, which we call the Permutohedral Attention Module (PAM), to efficiently capture non-local characteristics of the image. The proposed method is both memory- and computationally efficient. We provide a GPU implementation of this module suitable for 3D medical imaging problems. We demonstrate the efficiency and scalability of our module on the challenging task of vertebrae segmentation and labeling, where context plays a crucial role because different vertebrae have very similar appearances.
Keywords: Non-local neural networks, Attention module, Permutohedral Lattice, Vertebrae Segmentation
1 Introduction
Convolutional neural networks (CNNs) have become one of the most effective tools for many medical image processing tasks such as segmentation. However, working with medical images has its own idiosyncratic challenges. The organs, tissues or bones can have very similar characteristics, such as intensity, texture, or shape. As a consequence, the differentiating aspects of each individual structure come from the context and the position of the item of interest in the larger surroundings. However, naively extracting non-local characteristics of a region requires much more computation and memory than focusing on its local characteristics. This currently makes using non-local context highly non-trivial in medical imaging. Hence, an efficient approach to exploit non-local characteristics in deep learning could transform several medical imaging pipelines.
The notion of contextual information is intimately related to the concept of receptive field in deep learning. The receptive field of an output variable corresponds to the region of the input that influences its value. Recent studies on the receptive field of CNNs [10] have shown that the receptive field size grows sub-linearly with the number of convolutional layers. Two main solutions have been adopted to enlarge the receptive field of a CNN: down-sampling layers and dilated convolutions [14]. Down-sampling layers efficiently increase the receptive field size but reduce the resolution of the information. Hence, they are not suitable for very granular segmentation, in which case dilated convolutions are often preferred [9]. Both solutions result in a fixed receptive field, which means that all contextual information in the receptive field is taken into account, whether it is relevant or not. Attention modules have been used to prune irrelevant information in medical imaging [12, 15]. Yet these tools remain suboptimal, as they cannot capture large-scale context. By contrast, the extended self-attention formulation of [13] offers a way to dynamically adapt the individual receptive field of each output variable so that it only uses relevant non-local information. Despite its attractive properties, this formulation of self-attention has not yet been applied to medical images, partly because its computational requirements scale as $\mathcal{O}(N^2)$, where $N$ is the number of voxels.
In this paper, we propose a new self-attention module called the Permutohedral Attention Module (PAM), which makes use of the efficient approximation algorithm of the Permutohedral Lattice [1]. We adapted the algorithm of [1], originally designed to perform denoising, into a trainable self-attention module able to capture and process contextual information. The Permutohedral Lattice algorithm was previously used in a trainable framework in the more general context of sparse high-dimensional convolutions for computer vision [7]. In terms of memory and computation, the self-attention approach is a suitable compromise for medical image processing between [7] and standard convolutions, while preserving most of the gain in representational capacity that [7] brings to the processing of contextual information. Like the original non-local self-attention formulation, our module dynamically adapts the receptive field of each output variable in a learned way; in contrast to [13], however, it is applicable to medical images, as it has low memory requirements and its computation scales as $\mathcal{O}(N)$. We evaluate our module on the challenging task of vertebrae segmentation. Vertebrae segmentation aims to label each individual vertebra and is used in practice as an initial step of various pipelines such as modality fusion, spine surgery planning and surgical guidance. As consecutive vertebrae have very similar local appearance, non-local information is crucial to identify them.
In Section 2, we first define self-attention and review how it has been used, then introduce the PAM. In Section 3, we first highlight the capability of our module to capture and process contextual information without requiring a deep architecture. Then, we demonstrate its capability to improve state-of-the-art segmentation architectures for vertebrae segmentation.
2 Methods
2.0.1 The self-attention mechanism
Self-attention in deep learning frameworks can be defined as follows. Consider a standard deep learning framework where the input $x$ is processed first by a section of the network we call the descriptor network $D_{\theta_D}$, and then by the rest of the network, which we call the prediction network $G_{\theta_G}$ ($\theta_D$ and $\theta_G$ are the respective parameter sets). The model predicts $\hat{y}$ so that:
$$\hat{y} = G_{\theta_G}\big(D_{\theta_D}(x)\big) \tag{1}$$
We define the self-attention mechanism $A_{\theta_A}$, parameterized by $\theta_A$, which combines the non-local input descriptors in a learned way. For any input $x$, $A_{\theta_A}(x)$ is a self-attention matrix whose coefficient $A_{\theta_A}(x)_{i,j}$ characterizes the attention of variable $i$ towards variable $j$. Our framework, including an attention mechanism, predicts $\hat{y}$ as:
$$\hat{y} = G_{\theta_G}\big(A_{\theta_A}(x) \cdot D_{\theta_D}(x)\big) \tag{2}$$
where $\cdot$ represents the matrix multiplication operator. This formulation has two principal strengths: it can increase the receptive field of each output variable up to the whole input, and it can modulate the receptive field of each output variable with respect to the input characteristics. To our knowledge, attention modules in deep learning either compute the entire self-attention matrix on a low-dimensional input or use a local attention mechanism, which can be seen as a strong approximation of the non-local self-attention formulation. In the medical imaging context specifically, previous works [12, 15, 11] implicitly used a simplification of (2) with a diagonal self-attention matrix. This solution can be applied to large images since it scales linearly with the number of voxels, but it does not help to capture contextual information.
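To make (2) concrete, here is a small pure-Python sketch on toy matrices: `attend_and_predict` applies an attention matrix to the descriptors before a stand-in prediction function, and `diagonal_attention` builds the diagonal special case used by gating-style modules. The helper names and the identity "prediction network" are hypothetical illustrations, not part of the PAM code.

```python
from typing import Callable, List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain matrix product: (N x N) attention times (N x C) descriptors."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def attend_and_predict(attention: Matrix, descriptors: Matrix,
                       predict: Callable[[Matrix], Matrix]) -> Matrix:
    # y_hat = G(A(x) . D(x)): output row i mixes the descriptors of every
    # variable j with weight A[i][j] before the prediction network sees them.
    return predict(matmul(attention, descriptors))


def diagonal_attention(gates: List[float]) -> Matrix:
    # Gating-style attention: A[i][i] = gate_i and A[i][j] = 0 for i != j,
    # so each variable only rescales its own descriptor (no context mixed in).
    n = len(gates)
    return [[gates[i] if i == j else 0.0 for j in range(n)] for i in range(n)]
```

With a uniform attention matrix (every entry $1/N$), each output row becomes the mean descriptor, i.e. a receptive field covering the whole input; with `diagonal_attention`, each row is only rescaled, which is why the diagonal simplification scales linearly but captures no context.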
Different implementations of the non-local self-attention matrix are listed in [13]. These can be unified as follows:
$$A_{\theta_A}(x)_{i,j} = h\big(e_1(x_i)^{\top} e_2(x_j)\big) \tag{3}$$
where $(e_1, e_2)$ is a pair of embedding functions (possibly identities) and $h$ is typically either the identity, the exponential or a ReLU. Hence, these approaches are impractical for 3D images, because the number of interactions to be computed scales as $\mathcal{O}(N^2)$, where $N$ is the number of voxels.
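A brute-force sketch of the unified non-local form of [13] shows where the quadratic cost comes from: one score per pair $(i, j)$. The dictionary of $h$ choices and the helper names are illustrative assumptions, and the normalization used by some variants in [13] is left out. The last helper counts the bytes needed just to store the dense matrix.

```python
import math
from typing import Callable, List, Sequence

Vector = Sequence[float]

# Candidate choices for h in A_ij = h(e1(x_i)^T e2(x_j)).
H_CHOICES = {
    "identity": lambda s: s,
    "exp": math.exp,
    "relu": lambda s: max(0.0, s),
}


def dot(u: Vector, v: Vector) -> float:
    return sum(a * b for a, b in zip(u, v))


def non_local_attention(x: List[Vector], e1: Callable[[Vector], Vector],
                        e2: Callable[[Vector], Vector],
                        h: Callable[[float], float]) -> List[List[float]]:
    # Brute force: one score per (i, j) pair, i.e. N * N evaluations.
    q = [e1(xi) for xi in x]
    k = [e2(xj) for xj in x]
    return [[h(dot(qi, kj)) for kj in k] for qi in q]


def dense_attention_bytes(shape: Sequence[int], bytes_per_entry: int = 4) -> int:
    # Memory needed only to store the N x N matrix for an image of this shape.
    n = math.prod(shape)
    return n * n * bytes_per_entry
```

For a hypothetical 128 × 128 × 64 volume ($N = 2^{20}$ voxels) stored as 4-byte floats, `dense_attention_bytes((128, 128, 64))` gives $4 \cdot 2^{40}$ bytes, i.e. 4 TiB, far beyond GPU memory before any gradient is stored.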
2.0.2 Permutohedral Attention Module
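The proposed PAM relies on a slightly different formulation of the self-attention matrix, which aligns more closely with the non-local means filtering algorithm [3] used in the denoising literature. When applied to the set of feature-descriptor pairs $\{(f_i, d_i)\}_{i=1}^{N}$ (where $N$ is the number of variables described), non-local means gives the set of filtered descriptors:
$$\tilde{d}_i = \frac{\sum_{j=1}^{N} \exp\left(-\tfrac{1}{2}\lVert f_i - f_j \rVert^2\right) d_j}{\sum_{j=1}^{N} \exp\left(-\tfrac{1}{2}\lVert f_i - f_j \rVert^2\right)}$$
Hence
$$A_{\theta_F}(x)_{i,j} = \frac{\exp\left(-\tfrac{1}{2}\lVert F_{\theta_F}(x)_i - F_{\theta_F}(x)_j \rVert^2\right)}{\sum_{k=1}^{N} \exp\left(-\tfrac{1}{2}\lVert F_{\theta_F}(x)_i - F_{\theta_F}(x)_k \rVert^2\right)} \tag{4}$$
is the corresponding attention formulation, with a feature extractor network $F_{\theta_F}$ ($\theta_F$ being its parameter set).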
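As a reference point for what the PAM approximates, (4) can be evaluated exactly by brute force on a handful of points. This sketch assumes the features are already the outputs $F_{\theta_F}(x)_i$ and uses the unit bandwidth of (4); the function names are hypothetical, and the quadratic loops are exactly what the lattice avoids.

```python
import math
from typing import List, Sequence


def gaussian_attention(features: Sequence[Sequence[float]]) -> List[List[float]]:
    """Row-normalised Gaussian attention of Eq. (4), computed by brute force."""
    rows = []
    for fi in features:
        w = [math.exp(-0.5 * sum((a - b) ** 2 for a, b in zip(fi, fj)))
             for fj in features]
        total = sum(w)
        rows.append([wij / total for wij in w])
    return rows


def non_local_means(features: Sequence[Sequence[float]],
                    descriptors: Sequence[Sequence[float]]) -> List[List[float]]:
    # d_tilde_i = sum_j A_ij d_j, identical to the non-local means formula.
    att = gaussian_attention(features)
    n, c = len(features), len(descriptors[0])
    return [[sum(att[i][j] * descriptors[j][k] for j in range(n))
             for k in range(c)] for i in range(n)]
```

With features `[[0.0], [0.0], [10.0]]` and descriptors `[[1.0], [3.0], [100.0]]`, the first two points average each other to about 2.0, while the distant third point keeps about 100.0: the receptive field follows feature similarity rather than spatial distance.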
Avoiding a brute-force computation of (4), we adapted the Permutohedral Lattice approximation algorithm [1] to estimate the self-attention module output in $\mathcal{O}(N)$, against $\mathcal{O}(N^2)$ for the original non-local neural network formulations listed in [13]. The parameter sets $\theta_F$ and $\theta_D$ are learned through back-propagation. Hence, the PAM can be integrated into a deep learning framework to compute self-attention for high-dimensional inputs (cf. Section 3.0.3 for concrete architecture examples). The PAM approximates the proposed attention mechanism in 4 steps: embedding of the features into the higher-dimensional space of the Permutohedral Lattice, Splat, Blur and Slice, as illustrated in Fig. 1. Each of these steps scales linearly in $N$.
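The four steps can be sketched on the CPU in plain Python. This is a minimal illustration, not the authors' GPU implementation: it follows the hash-table construction of [1], assumes the features are already scaled for a unit-variance Gaussian, and handles the normalization of (4) by splatting a homogeneous 1 next to each descriptor. The scaling constant $\sqrt{2/3}\,(d+1)$ is taken from the reference code accompanying [1] and should be treated as an assumption; the names `embed` and `permutohedral_filter` are hypothetical.

```python
import math
from typing import Dict, List, Sequence, Tuple

Key = Tuple[int, ...]


def embed(position: Sequence[float]) -> Tuple[List[float], List[Tuple[Key, float]]]:
    """Embedding: lift a d-dim feature onto the hyperplane x_0 + ... + x_d = 0 and
    return it with its enclosing simplex as (lattice key, barycentric weight) pairs."""
    d = len(position)
    inv_std = math.sqrt(2.0 / 3.0) * (d + 1)  # assumed constant, see lead-in
    cf = [position[i] * inv_std / math.sqrt((i + 1) * (i + 2)) for i in range(d)]
    el = [0.0] * (d + 1)
    el[d] = -d * cf[d - 1]
    for i in range(d - 1, 0, -1):
        el[i] = el[i + 1] - i * cf[i - 1] + (i + 2) * cf[i]
    el[0] = el[1] + 2 * cf[0]
    # Nearest "remainder-0" point: every coordinate a multiple of d + 1.
    rem0 = []
    for v in el:
        down = math.floor(v / (d + 1)) * (d + 1)
        up = down + (d + 1)
        rem0.append(up if up - v < v - down else down)
    # Rank coordinates by their offset from rem0 (rank 0 = largest offset).
    rank = [0] * (d + 1)
    for i in range(d):
        for j in range(i + 1, d + 1):
            if el[i] - rem0[i] < el[j] - rem0[j]:
                rank[i] += 1
            else:
                rank[j] += 1
    # Bring rem0 back onto the hyperplane if its coordinates do not sum to 0.
    s = sum(rem0) // (d + 1)
    for i in range(d + 1):
        if s > 0 and rank[i] >= d + 1 - s:
            rem0[i] -= d + 1
            rank[i] += s - (d + 1)
        elif s < 0 and rank[i] < -s:
            rem0[i] += d + 1
            rank[i] += (d + 1) + s
        else:
            rank[i] += s
    # Barycentric weights of the d + 1 simplex vertices.
    bary = [0.0] * (d + 2)
    for i in range(d + 1):
        v = (el[i] - rem0[i]) / (d + 1)
        bary[d - rank[i]] += v
        bary[d + 1 - rank[i]] -= v
    bary[0] += 1.0 + bary[d + 1]
    # Vertex r = rem0 + the r-th canonical simplex offset, permuted by rank.
    simplex = []
    for r in range(d + 1):
        key = tuple(rem0[i] + (r if rank[i] <= d - r else r - (d + 1))
                    for i in range(d + 1))
        simplex.append((key, bary[r]))
    return el, simplex


def permutohedral_filter(features: Sequence[Sequence[float]],
                         descriptors: Sequence[Sequence[float]]) -> List[List[float]]:
    """Approximate d_tilde_i = sum_j A_ij d_j with A the Gaussian attention of Eq. (4)."""
    d, c = len(features[0]), len(descriptors[0])
    lattice: Dict[Key, List[float]] = {}
    simplices = []
    # Splat: spread [d_i, 1] onto the simplex vertices with barycentric weights.
    for f, desc in zip(features, descriptors):
        simplex = embed(f)[1]
        simplices.append(simplex)
        for key, w in simplex:
            cell = lattice.setdefault(key, [0.0] * (c + 1))
            for k, val in enumerate(list(desc) + [1.0]):
                cell[k] += w * val
    # Blur: a [1, 2, 1] / 4 kernel along each of the d + 1 lattice directions.
    for j in range(d + 1):
        step = tuple(1 - (d + 1) * (i == j) for i in range(d + 1))
        blurred = {}
        for key, val in lattice.items():
            lo = lattice.get(tuple(a - b for a, b in zip(key, step)))
            hi = lattice.get(tuple(a + b for a, b in zip(key, step)))
            blurred[key] = [0.5 * v + 0.25 * ((lo[k] if lo else 0.0) + (hi[k] if hi else 0.0))
                            for k, v in enumerate(val)]
        lattice = blurred
    # Slice: read back with the same weights and divide by the homogeneous term.
    out = []
    for simplex in simplices:
        acc = [0.0] * (c + 1)
        for key, w in simplex:
            for k, val in enumerate(lattice[key]):
                acc[k] += w * val
        out.append([a / acc[c] for a in acc[:c]])
    return out
```

Each point touches only its $d+1$ simplex vertices and each blur pass visits every stored vertex once, so the cost is linear in $N$ for a fixed feature dimension $d$. On two well-separated clusters, for instance features near 0 and near 20 with descriptors 0 and 10, the outputs stay at 0 and 10, because the blur never bridges the gap between occupied lattice vertices.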
The advantage of this approximation algorithm over other possibilities [2, 5] is that the gradients with respect to the input feature vectors $f$ and the descriptor vectors $d$ can be expressed using the four steps composing the forward pass, and can be fully parallelized. Omitting the dependencies in $x$, $\theta_F$ and $\theta_D$, we can express the forward pass as:
$$\tilde{d} = \mathrm{Sl}\Big(\mathrm{B}\big(\mathrm{Sp}\big(d,\, \mathrm{E}(f)\big)\big),\, \mathrm{E}(f)\Big)$$
where $\mathrm{E}$ is the embedding operator, $\mathrm{Sp}$ is the Splat operator, $\mathrm{B}$ is the Blur operator and $\mathrm{Sl}$ is the Slice operator. With the same notations, the backward pass can be expressed as:
[TABLE]
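where $\mathcal{L}$ is the loss and $\partial \mathcal{L} / \partial \tilde{d}$ denotes its gradient with respect to the filtered descriptors (similarly with $d$ and $f$). $\bar{\mathrm{B}}$ is the Gaussian blurring operator in which the one-dimensional blurs are applied in the reverse order of the Lattice directions. $P$ is the position embedding matrix and $\sigma$ is a permutation computed during $\mathrm{E}$.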
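Why the descriptor gradient reuses the forward operators with the blur directions reversed can be checked with a short derivation, a sketch under two assumptions: the normalization of (4) is left out (it only adds a per-row diagonal rescaling), and the lattice steps are written as matrices for fixed features. Let $W \in \mathbb{R}^{N \times M}$ hold the barycentric weights of the $N$ points on the $M$ lattice vertices, so that splatting is $W^{\top}$ and slicing is $W$, and let the blur be the product of $d+1$ one-dimensional $[1, 2, 1]/4$ blurs $K_1, \dots, K_{d+1}$, each a symmetric matrix:

$$\tilde{d} = W K_{d+1} \cdots K_{1} W^{\top} d
\quad\Longrightarrow\quad
\frac{\partial \mathcal{L}}{\partial d} = \left(W K_{d+1} \cdots K_{1} W^{\top}\right)^{\top} \frac{\partial \mathcal{L}}{\partial \tilde{d}} = W K_{1} \cdots K_{d+1} W^{\top} \frac{\partial \mathcal{L}}{\partial \tilde{d}}$$

Read right to left, the incoming gradient is splatted, blurred with the directions visited in reverse order, and sliced, so the same parallel kernels serve both passes. The gradient with respect to the features $f$ additionally depends on how the weights in $W$ move with the embedded positions and is not covered by this sketch.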
3 Experiments
3.0.1 Data
We evaluate the impact of the PAM for non-local neural networks on the task of simultaneous segmentation and labeling of vertebrae. We performed our experiments on the CSI 2014 workshop challenge data (http://spineweb.digitalimaginggroup.ca/), which consists of 20 CT images. We used all 20 CT images, with 5-fold cross-validation for evaluation. We resampled the data to obtain voxels of 1 mm × 1 mm × 3 mm.
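The evaluation protocol can be sketched as follows: a 5-fold split of the 20 images, so that each image is tested exactly once and each model trains on 16, plus the target grid size after resampling to (1 mm, 1 mm, 3 mm) voxels. The shuffling seed, the helper names and the example input spacing are hypothetical; the text does not say how images were assigned to folds.

```python
import random
from typing import List, Sequence, Tuple


def five_fold_splits(n_images: int = 20, n_folds: int = 5,
                     seed: int = 0) -> List[Tuple[List[int], List[int]]]:
    """Shuffle image indices once, then hold out each consecutive chunk as a test fold."""
    ids = list(range(n_images))
    random.Random(seed).shuffle(ids)
    size = n_images // n_folds  # 20 // 5 = 4 test images per fold
    splits = []
    for k in range(n_folds):
        test = sorted(ids[k * size:(k + 1) * size])
        train = sorted(set(ids) - set(test))
        splits.append((train, test))
    return splits


def resampled_shape(shape: Sequence[int], spacing: Sequence[float],
                    target: Sequence[float] = (1.0, 1.0, 3.0)) -> Tuple[int, ...]:
    """Number of voxels per axis after resampling to the target spacing (mm)."""
    return tuple(round(n * s / t) for n, s, t in zip(shape, spacing, target))
```

For a hypothetical scan of 512 × 512 × 100 voxels at (0.3, 0.3, 1.0) mm, `resampled_shape` gives `(154, 154, 33)`.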
3.0.2 Implementation details
We implemented the PAM as well as all our pipelines using PyTorch. We optimized our networks with Adam on patches with a fixed learning rate of and a batch size of 1. We used the Dice loss as the loss function. Our implementation is publicly available at https://github.com/SamuelJoutard/Permutohedral_attention_module.
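The text does not specify which Dice loss variant was used. A common soft, multi-class form is sketched below in plain Python rather than PyTorch; the smoothing constant `eps`, the unweighted mean over classes and the function name are assumptions.

```python
from typing import Sequence


def soft_dice_loss(probs: Sequence[Sequence[float]],
                   onehot: Sequence[Sequence[float]], eps: float = 1e-6) -> float:
    """1 - mean over classes of (2 |P.G| + eps) / (|P| + |G| + eps).

    Rows are voxels, columns are classes; probs are softmax outputs and
    onehot is the ground truth encoded one-hot.
    """
    n_classes = len(probs[0])
    dices = []
    for c in range(n_classes):
        inter = sum(p[c] * g[c] for p, g in zip(probs, onehot))
        denom = sum(p[c] for p in probs) + sum(g[c] for g in onehot)
        dices.append((2.0 * inter + eps) / (denom + eps))
    return 1.0 - sum(dices) / n_classes
```

A perfect prediction gives a loss of 0 and a completely swapped one gives a loss close to 1. Because each class is normalized by its own size, small structures such as cervical vertebrae weigh as much as large ones.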
3.0.3 Models
As a preliminary experiment, we consider a specific 6-layer fully convolutional network (referred to as FCN). We design 2 baselines for this shallow setting. FCN is a plain fully convolutional network with a first convolution with 18 output channels, followed by 4 embedding convolutions with 18 output channels each and a prediction convolutional layer. Dil.FCN is a similar architecture in which each embedding convolution is replaced by a dilated block. A dilated block consists of 3 parallel convolutions with 6 output channels each, two of which have dilated filters (dilation factors of 2 and 4, respectively). The outputs of a dilated block are concatenated before the next block. We then incorporate the PAM into each baseline (the resulting networks are called FCN+PAM and Dil.FCN+PAM, respectively) and compare the results of these 4 configurations.
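The difference in context between the two baselines can be estimated with standard receptive-field arithmetic for stride-1 layers, where each layer adds (kernel size - 1) × dilation voxels per axis. The kernel size of 3 and the 1 × 1 × 1 prediction layer are hypothetical choices for illustration only; the constant names are also hypothetical.

```python
from typing import Sequence

KERNEL = 3  # hypothetical kernel size per axis, for illustration only


def receptive_field(layers: Sequence[Sequence[int]], kernel: int = KERNEL) -> int:
    """Receptive field (voxels per axis) of a stack of stride-1 layers.

    Each layer lists the dilation factors of its parallel branches; since the
    branch outputs are concatenated, the widest branch sets the layer's reach.
    """
    rf = 1
    for dilations in layers:
        rf += (kernel - 1) * max(dilations)
    return rf


# First convolution + 4 embedding convolutions (prediction layer assumed 1x1x1).
FCN_LAYERS = [[1]] * 5
# First convolution + 4 dilated blocks with branch dilations 1, 2 and 4.
DIL_FCN_LAYERS = [[1]] + [[1, 2, 4]] * 4
BLOCK_CHANNELS = 3 * 6  # 3 parallel branches of 6 channels, concatenated
```

Under these assumptions FCN sees 11 voxels per axis and Dil.FCN sees 35, with the same 18 channels per layer. Both remain fixed windows, which is the limitation the PAM targets.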
Figure 2 represents the Dil.FCN+PAM architecture. In this figure, we observe that, once we obtain the features and descriptors used to compute attention, we split each feature and descriptor vector in two. Hence, we obtain two sets of feature-descriptor pairs, $\{(f_i^1, d_i^1)\}_{i=1}^{N}$ and $\{(f_i^2, d_i^2)\}_{i=1}^{N}$, on which we apply the PAM independently. There are two main advantages to doing so. First, it further reduces computation time and memory footprint. Second, it generates one attention map per group of channels, which makes the model more flexible (a single attention matrix shared by all descriptor channels is a particular case of two attention matrices, one per channel group). We do not split the feature-descriptor pairs into more subsets because we want a trade-off between the advantages described above and the preservation of enough relevant features to compute attention.
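The channel-splitting scheme can be sketched with an exact Gaussian attention standing in for each PAM: the feature and descriptor channels are cut into two halves, each half gets its own attention map, and the outputs are concatenated. The function names, and the use of brute-force attention instead of the lattice approximation, are assumptions for illustration.

```python
import math
from typing import List, Sequence

Rows = List[List[float]]


def gaussian_filter(features: Sequence[Sequence[float]],
                    descriptors: Sequence[Sequence[float]]) -> Rows:
    # Brute-force stand-in for one PAM: row-normalised Gaussian attention (Eq. 4).
    out = []
    for fi in features:
        w = [math.exp(-0.5 * sum((a - b) ** 2 for a, b in zip(fi, fj)))
             for fj in features]
        z = sum(w)
        out.append([sum(wj * dj[k] for wj, dj in zip(w, descriptors)) / z
                    for k in range(len(descriptors[0]))])
    return out


def grouped_attention(features: Sequence[Sequence[float]],
                      descriptors: Sequence[Sequence[float]], groups: int = 2) -> Rows:
    """Split feature and descriptor channels into `groups` chunks and attend independently."""
    fc = len(features[0]) // groups
    dc = len(descriptors[0]) // groups
    outs = []
    for g in range(groups):
        f_g = [f[g * fc:(g + 1) * fc] for f in features]
        d_g = [d[g * dc:(g + 1) * dc] for d in descriptors]
        outs.append(gaussian_filter(f_g, d_g))
    # Concatenate the per-group outputs back along the channel axis.
    return [sum((o[i] for o in outs), []) for i in range(len(features))]
```

With features `[[0.0, 0.0], [30.0, 0.0]]` and descriptors `[[1.0, 1.0], [3.0, 3.0]]`, the first feature channel separates the two points, so the first descriptor channel stays at 1 and 3, while the second feature channel is identical for both points, so the second descriptor channel is averaged to 2. A single attention matrix computed from both feature channels could not produce both behaviours.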
Then, we consider a 3D U-Net [4], which is one of the most popular architectures for segmentation [6]. We refer to our 3D U-Net simply as U-Net. We incorporate the PAM into our U-Net as shown in Fig. 3 (we call this network U-PAM-Net) and demonstrate that the PAM can also improve architectures that have large receptive fields. As shown in Fig. 3, we incorporate the PAM at the half-resolution level. Hence, we compute attention between voxel regions (one per half-resolution location) rather than individual voxels, which, in our experiments, led to results similar to computing attention at the voxel level while decreasing computation time and speeding up convergence.
As the PAM introduces a small number of extra parameters, we compensate by adding channels to the first convolution of the architectures without the PAM, so that each of these networks has at least as many degrees of freedom as its counterpart with the PAM.
3.0.4 Results
We measure the performance of the different architectures with Dice scores. Table 1 shows that the PAM improves performance for all the architectures it was incorporated into. In addition, we highlight that the shallow network Dil.FCN+PAM performs almost as well as the much deeper 3D U-Net. Indeed, the dilated convolutions describe the voxels using contextual information, while the PAM uses those meaningful features to compute voxel interactions. Table 1 also illustrates the limitation of down-sampling layers pointed out earlier, as U-Net performs poorly on cervical vertebrae, which appear very small in our images. U-PAM-Net reaches a higher accuracy than [8], which uses a task-specific framework especially tuned to “count” the vertebrae from the spine segmentation. While [8] reports an accuracy of 81%, our proposed framework obtained 89% using the same evaluation metric on the same dataset. Note, however, that the two approaches used different train-test splits. Figure 4 shows a representative example of the results we observed.
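Per-label Dice scores can be computed on flattened label maps as below. The sketch assumes label 0 is background and that each vertebra has its own integer label; how the scores were aggregated per spine region is not described in the text, and the function name is hypothetical.

```python
from typing import Dict, Sequence


def dice_per_label(pred: Sequence[int], truth: Sequence[int]) -> Dict[int, float]:
    """Dice = 2 |P and G| / (|P| + |G|) for every non-background label in the ground truth."""
    scores = {}
    for label in sorted(set(truth) - {0}):
        p = {i for i, v in enumerate(pred) if v == label}
        g = {i for i, v in enumerate(truth) if v == label}
        scores[label] = 2.0 * len(p & g) / (len(p) + len(g))
    return scores
```

For example, `dice_per_label([0, 1, 2, 2, 2, 2], [0, 1, 1, 2, 2, 2])` gives 2/3 for label 1 and 6/7 for label 2: one voxel assigned to the neighbouring vertebra lowers the Dice of both labels, which is how labeling errors show up even when the spine itself is well segmented.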
4 Discussion
In this work, we propose the Permutohedral Attention Module, a computationally efficient attention module for 3D deep learning frameworks. The PAM can be incorporated into any CNN architecture. We demonstrated its ability to efficiently handle non-local information in the context of vertebrae segmentation and showed its potential to reduce network size for specific tasks. Future work will notably include the investigation of asymmetric attention matrices for feature filtering and the integration of the PAM formulation in path training.
4.0.1 Acknowledgement
We thank E. Molteni, C. Sudre, B. Murray, K. Georgiadis, Z. Eaton-Rosen, M. Ebner for their useful comments. This work is supported by the Wellcome/EPSRC Centre for Medical Engineering [WT 203148/Z/16/Z]. TV is supported by a Medtronic / RAEng Research Chair [RCSRF1819/7/34].
References
- [1] Adams, A., Baek, J., Davis, M.A.: Fast high-dimensional filtering using the permutohedral lattice. Computer Graphics Forum (2010)
- [2] Adams, A., Gelfand, N., Dolson, J., Levoy, M.: Gaussian KD-trees for fast high-dimensional filtering. ACM Trans. Graph. 28(3), 21:1–21:12 (Jul 2009)
- [3] Buades, A., Coll, B., Morel, J.M.: A non-local algorithm for image denoising. In: CVPR. pp. 60–65 (2005)
- [4] Çiçek, Ö., Abdulkadir, A., Lienkamp, S.S., Brox, T., Ronneberger, O.: 3D U-Net: Learning dense volumetric segmentation from sparse annotation. In: MICCAI (2016)
- [5] Chen, J., Paris, S., Durand, F.: Real-time edge-aware image processing with the bilateral grid. In: ACM SIGGRAPH 2007 Papers. SIGGRAPH ’07 (2007)
- [6] Isensee, F., Petersen, J., Klein, A., Zimmerer, D., Jaeger, P.F., Kohl, S., Wasserthal, J., Koehler, G., Norajitra, T., Wirkert, S.J., Maier-Hein, K.H.: nnU-Net: Self-adapting framework for U-Net-based medical image segmentation. arXiv preprint arXiv:1809.10486 (2018)
- [7] Jampani, V., Kiefel, M., Gehler, P.: Learning sparse high dimensional filters: Image filtering, dense CRFs and bilateral neural networks. In: CVPR (2016). https://doi.org/10.1109/CVPR.2016.482
- [8] Lessmann, N., van Ginneken, B., de Jong, P.A., Isgum, I.: Iterative fully convolutional neural networks for automatic vertebra segmentation and identification. Medical Image Analysis 53, 142–155 (2019)
