AMD Severity Prediction And Explainability Using Image Registration And Deep Embedded Clustering
Dwarikanath Mahapatra

TL;DR
This paper introduces a deep learning approach combining image registration and clustering to predict and explain AMD severity from OCT images, achieving high accuracy and improved interpretability.
Contribution
It presents a novel method integrating image registration and deep embedded clustering for AMD severity prediction with enhanced explainability.
Findings
Achieves state-of-the-art classification performance
Performs well on unseen data
Provides better explainability than traditional methods
Abstract
We propose a method to predict the severity of age-related macular degeneration (AMD) from input optical coherence tomography (OCT) images. Although there is no standard clinical severity scale for AMD, we leverage deep learning (DL) based image registration and clustering methods to identify diseased cases and predict their severity. Experiments demonstrate that our approach's disease classification performance matches state-of-the-art methods. The predicted disease severity performs well on previously unseen data. Registration output provides better explainability than class activation maps regarding label and severity decisions.
Table 1. Registration results before registration (Bef. Reg) and after registration with each method.

| | Bef. Reg | Reg-DEC | RegNoDEC | Reg-kMeans | DIRNet | FlowNet | VoxelMorph |
|---|---|---|---|---|---|---|---|
| DM (%) | 78.9 | 89.3 | 85.9 | 84.8 | 83.5 | 87.6 | 88.0 |
| HD95 (mm) | 12.9 | 6.9 | 8.4 | 8.7 | 9.8 | 7.5 | 7.4 |
| MAD | 13.7 | 7.3 | 8.9 | 10.3 | 9.1 | 8.6 | 7.9 |
| Time (s) | – | 0.5 | 0.4 | 0.6 | 0.5 | 0.6 | 0.6 |

Table 2. Classification results for normal, DME and AMD.

| | | | | | DEC | kmeans | MultCNN [66] |
|---|---|---|---|---|---|---|---|
| Sen | 93.6 | 91.7 | 92.5 | 92.6 | 89.5 | 85.7 | 92.5 |
| Spe | 94.3 | 92.8 | 93.6 | 93.5 | 90.6 | 86.8 | 93.4 |
| AUC | 96.4 | 94.1 | 95.2 | 95.3 | 91.9 | 87.7 | 95.2 |
| Time (h) | 4.3 | 16.7 | 12.4 | 13.6 | 2.5 | 0.5 | 15.1 |
Dwarikanath Mahapatra and Hidemasa Muta
IBM Research Australia
[dwarim,hidem]@au1.ibm.com
1 Introduction
Most approaches to deep learning (DL) based medical image classification output a binary decision about the presence or absence of a disease without explicitly justifying the decision. Moreover, disease severity prediction in an unsupervised setting is not clearly defined unless the labels provide such information, as in diabetic retinopathy [1]. Diseases such as age-related macular degeneration (AMD) do not have a standard clinical severity scale, and assessing severity is left to the observer's expertise. While class activation maps (CAMs) [73] highlight image regions that have a high response to the trained classifier, they do not provide measurable parameters to explain the decision. Explainability of classifier decisions is an essential requirement of modern diagnosis systems.
In this paper we propose a convolutional neural network (CNN) based optical coherence tomography (OCT) image registration method that: 1) predicts the disease class of a given image (e.g., normal, diabetic macular edema (DME) or dry AMD); 2) uses the registration output to grade disease severity on a normalized scale whose lowest value indicates normal and whose highest value indicates confirmed disease; and 3) provides explainability by outputting measurable parameters.
Previous approaches to DL based image registration include regressors [70, 68, 4, 5, 23, 75, 63, 28, 22, 11, 21, 35] and generative adversarial networks (GANs) [34, 30, 31, 36, 39, 18]. The methods of [3, 17, 20, 16, 57, 55] learn a parameterized registration function from training data without requiring the simulated deformations used in [68, 53, 33, 25, 34, 24]. Although there is considerable research in the field of interpretable machine learning, its application to medical image analysis problems is limited [65, 7, 43, 42, 74, 27, 26, 32]. The CAMs of [73] serve as visualization aids rather than showing quantitative parameters. We propose a novel approach to overcome the limitations of CAMs by providing quantitative measures, and their visualization, for disease diagnosis based on image registration. Image registration makes the approach fast and enables projection of registration parameters onto a linear scale for comparison against normal and diseased cases. It also provides more localized and accurate quantitative output than CAMs. Our paper makes the following contributions: 1) a novel approach for AMD severity estimation using registration parameters and clustering; and 2) mapping registration output to a classification decision and outputting quantitative values that explain that decision.
2 Method
Our proposed method consists of: 1) atlas construction for the different classes; 2) end-to-end training of a neural network to estimate registration parameters and assign severity labels; and 3) assignment of a test volume to a disease severity scale, outputting its registration parameters and providing quantitatively interpretable information.
2.1 Atlas Construction Using Groupwise Registration
All normal volumes are coarsely aligned using their point cloud clusters and the iterative closest point (ICP) algorithm. Groupwise registration of all these volumes using ITK [2] gives the normal atlas image. Each normal image is registered to this atlas using B-splines. The registration parameters are the displacements of the grid nodes. They are easier to store and predict than a dense 3D deformation field and can be used to generate that 3D deformation field. The same steps are used to obtain the atlases for AMD and DME.
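The grid-node displacements can be expanded into a dense deformation field. The following is a minimal sketch of the standard cubic B-spline free-form deformation, simplified to a single 2D slice. It assumes a uniform control grid with `spacing` pixels between nodes and one padding node on each side. The function names and padding convention are hypothetical; in practice ITK's B-spline transform performs this step.

```python
import numpy as np


def cubic_bspline_weights(u):
    """Uniform cubic B-spline basis B0..B3 at fractional offsets u in [0, 1)."""
    return np.stack([
        (1 - u) ** 3 / 6.0,
        (3 * u**3 - 6 * u**2 + 4) / 6.0,
        (-3 * u**3 + 3 * u**2 + 3 * u + 1) / 6.0,
        u**3 / 6.0,
    ])  # shape (4, len(u)); the four weights sum to 1 for every u


def dense_field_from_grid(ctrl, shape, spacing):
    """Expand control-point displacements into a dense 2D displacement field.

    ctrl:    (gy + 3, gx + 3, 2) node displacements: the gy + 1 (gx + 1) nodes
             spanning the image plus one padding node on each side (assumed layout).
    shape:   (H, W) with H <= gy * spacing and W <= gx * spacing.
    spacing: distance between control points in pixels (integer).
    Returns an (H, W, 2) array of per-pixel displacements.
    """
    H, W = shape
    ys, xs = np.arange(H), np.arange(W)
    iy, uy = ys // spacing, (ys % spacing) / spacing  # cell index and offset in cell
    ix, ux = xs // spacing, (xs % spacing) / spacing
    By, Bx = cubic_bspline_weights(uy), cubic_bspline_weights(ux)
    field = np.zeros((H, W, 2))
    for l in range(4):
        for m in range(4):
            w = np.outer(By[l], Bx[m])  # (H, W) tensor-product weight of node (l, m)
            field += w[..., None] * ctrl[iy[:, None] + l, ix[None, :] + m]
    return field
```

For a 64×64 slice with `spacing=16` the grid has 4 cells per axis, so `ctrl` has shape `(7, 7, 2)`. Because the basis weights sum to one, a constant node displacement yields the same constant displacement at every pixel.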
2.2 Deep Embedded Clustering Network
Deep embedded clustering (DEC) [72, 40, 62, 10, 38, 37, 46] is an unsupervised clustering approach that gives better results than traditional clustering algorithms. To cluster $n$ points $\{x_i \in X\}_{i=1}^{n}$ into $k$ clusters, each represented by a centroid $\mu_j$, $j = 1, \dots, k$, DEC first transforms the data with a nonlinear mapping $f_\theta: X \rightarrow Z$, where $\theta$ are learnable parameters and $Z$ is the latent feature space with lower dimensionality than $X$. Similarity between the embedded point $z_i = f_\theta(x_i)$ and the cluster centroid $\mu_j$ is given by the Student's t-distribution as

$$q_{ij} = \frac{\left(1 + \lVert z_i - \mu_j \rVert^2 / \alpha\right)^{-\frac{\alpha+1}{2}}}{\sum_{j'} \left(1 + \lVert z_i - \mu_{j'} \rVert^2 / \alpha\right)^{-\frac{\alpha+1}{2}}},$$

where $\alpha = 1$ for all experiments. DEC simultaneously learns the cluster centers $\{\mu_j\}$ in feature space and the parameters $\theta$. It involves (1) parameter initialization with a deep autoencoder [69] and (2) iterative parameter optimization by computing an auxiliary target distribution and minimizing the Kullback–Leibler (KL) divergence. For further details we refer the reader to [72].
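The DEC quantities can be written in a few lines of NumPy. The soft assignment implements the Student's t similarity with α = 1 by default. The target distribution and KL objective follow the standard DEC formulation of [72]: squared assignments are divided by the soft cluster frequencies and then renormalized per sample. This sketch is for checking the math only. It is not a trainable implementation, and it omits the gradient updates of θ and the centroids.

```python
import numpy as np


def soft_assignment(z, mu, alpha=1.0):
    """Student's t similarity q_ij between embeddings z (n, d) and centroids mu (k, d)."""
    sq_dist = ((z[:, None, :] - mu[None, :, :]) ** 2).sum(axis=-1)  # (n, k)
    q = (1.0 + sq_dist / alpha) ** (-(alpha + 1.0) / 2.0)
    return q / q.sum(axis=1, keepdims=True)  # each row sums to 1


def target_distribution(q):
    """Auxiliary target p_ij: square q, divide by soft cluster frequency, renormalize rows."""
    weight = q**2 / q.sum(axis=0)
    return weight / weight.sum(axis=1, keepdims=True)


def kl_loss(p, q, eps=1e-12):
    """KL(P || Q), the quantity DEC minimizes w.r.t. the encoder and the centroids."""
    return float(np.sum(p * np.log((p + eps) / (q + eps))))
```

Squaring the assignments makes the target sharper than `q` for confidently assigned points. Minimizing the KL divergence therefore pulls embeddings toward their nearest centroid.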
2.3 Estimation of Registration parameters
Conventional registration methods output a deformation field from an input image pair, while we jointly estimate the grid displacements and the severity label using end-to-end training. Figure 1 depicts our workflow. An input volume, consisting of a number of slices, is converted to a stack of convolution feature maps by downsampling and employing convolution. The output is shown in Figure 1 as d256 fN k1, which indicates the output map dimension (d, here 256), the number of feature maps (f, here N) and the kernel dimension (k, here 1). The next convolution layer is followed by a max pooling step that reduces the map dimensions, and by a further convolution layer. After three further max pooling and convolution layers, the “Encoder” stage outputs its final set of feature maps.
The Encoder output is used in two ways. The first branch is the input to the Deep Embedded Clustering (DEC) network (green boxes depicting fully connected layers), which outputs a cluster label indicating the severity score. The second branch from the Encoder is connected, along with the input volume's disease label, to a fully connected (FC) layer (orange boxes). It is followed by two more FC layers, and the final output is the set of registration parameters. The “Class Label id” (disease label of the input volume) and the Encoder output are combined using a global pooling step. The motivation for combining the two is as follows. We want to register, for example, a normal volume to the normal atlas. The ground truth registration parameters of a normal volume are those obtained when registering it to the normal atlas, and we want the regression network to predict these parameters. Feeding in the input volume's actual disease label guides the regression network to register the image to the corresponding atlas.
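The disease label can condition the regression branch as shown in the NumPy forward pass below. It rests on several assumptions. The encoder maps are globally average-pooled, the label is one-hot encoded and concatenated to the pooled features, and the FC layer widths, class coding and output size are placeholders. This illustrates the data flow only. It is not the authors' Keras model, and `init_fc` and `regression_branch` are hypothetical names.

```python
import numpy as np

rng = np.random.default_rng(0)


def init_fc(sizes):
    """Random (W, b) pairs for FC layers with the given (hypothetical) widths."""
    return [(rng.normal(0.0, 0.1, (m, n)), np.zeros(n)) for m, n in zip(sizes[:-1], sizes[1:])]


def regression_branch(enc_maps, class_id, layers, n_classes=3):
    """Predict grid displacements from encoder maps, conditioned on the disease label.

    enc_maps: (C, h, w) encoder output for one volume.
    class_id: integer label, e.g. 0 = normal, 1 = DME, 2 = AMD (assumed coding).
    layers:   list of (W, b) pairs; ReLU between layers, linear output.
    """
    pooled = enc_maps.mean(axis=(1, 2))          # global average pooling -> (C,)
    onehot = np.eye(n_classes)[class_id]         # label tells the regressor which atlas to target
    x = np.concatenate([pooled, onehot])
    for i, (W, b) in enumerate(layers):
        x = x @ W + b
        if i < len(layers) - 1:
            x = np.maximum(x, 0.0)               # ReLU on hidden layers only
    return x
```

For example, `layers = init_fc([64 + 3, 256, 256, n_params])` builds three FC layers for 64 encoder maps. The same volume fed with different labels yields different parameter vectors, which is how the label steers registration toward the matching atlas.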
2.4 Training Stage Implementation
The entire dataset is divided into training, validation and test folds for each class. The DEC parameter initialization closely follows the steps outlined in [72]. The regression network is trained using the input images, their labels and the corresponding registration parameters. We augment the datasets by rotation and flipping and obtain the registration parameters of the augmented volumes with the corresponding atlas. In the first phase of training, only the regression network is trained, using a mean squared error (MSE) loss, to get an initial set of weights. Subsequently, the DEC is trained using the output of the Encoder network. After training is complete we cluster the different volumes and observe that of the normal patients are assigned to clusters and . of DME cases are assigned to clusters and , while of AMD cases are assigned to clusters and . Thus the following mapping between image labels and cluster labels is obtained .
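A label-to-cluster mapping of this kind could be derived by majority vote over the clustered training volumes. The sketch below assumes that rule, which the text does not state. It also computes the statistic the text reports, namely the fraction of one class that falls into a chosen set of clusters. The function names and toy data are hypothetical.

```python
from collections import Counter


def map_clusters_to_classes(cluster_ids, class_labels):
    """Assign each cluster the class label carried by most of its members (assumed rule)."""
    members = {}
    for c, y in zip(cluster_ids, class_labels):
        members.setdefault(c, []).append(y)
    return {c: Counter(ys).most_common(1)[0][0] for c, ys in members.items()}


def class_fraction_in_clusters(cluster_ids, class_labels, target_class, clusters):
    """Fraction of samples of target_class that fall into the given set of clusters."""
    in_class = [c for c, y in zip(cluster_ids, class_labels) if y == target_class]
    return sum(c in clusters for c in in_class) / len(in_class)
```

At test time, the cluster predicted by DEC can then be looked up in this dictionary to recover the disease class.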
2.5 Predicting Severity of test image
When a test image comes in, we first use the trained DEC to predict its cluster label, which provides the disease severity and also gives the image's disease class. The disease label is then used to predict the image's registration parameters to the corresponding atlas. Depending on the desired granularity of disease severity, the number of clusters can be varied to identify different cohorts that exhibit specific traits.
3 Experimental Results
We demonstrate the effectiveness of our algorithm on a public dataset [66, 19, 29, 61, 59, 60, 58] consisting of OCT volumes from 50 normal, 48 dry AMD, and 50 DME patients. The axial resolution of the images is μm with a scan dimension of pixels. The number of B-scans per volume varies between patients. The dataset is publicly available at http://www.biosigdata.com. For all registration steps we used a grid size of . The number of predicted grid parameters is
3.1 Registration Results
The registration parameters output by our method are used to generate a deformation field using B-splines, which is compared with the outputs of other registration methods. For quantitative evaluation, we applied simulated deformation fields and used the different registration methods to recover them. Accuracy is measured by the mean absolute distance (MAD) between the applied and recovered deformation fields. We also manually annotate retinal layers and compute their 95th percentile Hausdorff distance (HD95) and Dice Metric (DM) before and after registration. Our method was implemented in Python and Keras, using SGD and Adam with and batch normalization. Training and testing were performed on an NVIDIA Tesla K GPU with GB RAM.
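The three evaluation metrics can be sketched in NumPy. The sketch rests on three assumptions. MAD is taken as the mean Euclidean norm of the per-voxel difference between applied and recovered displacement fields. Dice is computed on binary layer masks. HD95 takes the larger of the two directed 95th-percentile nearest-neighbour distances between annotated boundary points. Coordinates should be multiplied by the voxel spacing to report HD95 in mm. The brute-force pairwise distances keep the code short but only suit small point sets.

```python
import numpy as np


def mean_absolute_distance(applied, recovered):
    """MAD between applied and recovered displacement fields of shape (..., dim)."""
    return float(np.mean(np.linalg.norm(applied - recovered, axis=-1)))


def dice(a, b):
    """Dice overlap between two binary masks of equal shape."""
    a, b = a.astype(bool), b.astype(bool)
    return 2.0 * np.logical_and(a, b).sum() / (a.sum() + b.sum())


def hd95(points_a, points_b):
    """Symmetric 95th-percentile Hausdorff distance between point sets (n, dim), (m, dim)."""
    d = np.linalg.norm(points_a[:, None, :] - points_b[None, :, :], axis=-1)  # (n, m)
    return float(max(np.percentile(d.min(axis=1), 95), np.percentile(d.min(axis=0), 95)))
```

Using the 95th percentile instead of the maximum makes the Hausdorff distance robust to a few outlying annotation points.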
Table 1 compares the results of the following methods: 1) Reg-DEC: our proposed method; 2) RegNoDEC: registration only, without the additional clustering; 3) VoxelMorph: the method of [3, 13, 15, 71, 14, 12, 56]; 4) FlowNet: the registration method of [6, 50, 51, 52, 54, 47, 48]; 5) DIRNet: the method of [70, 45, 44, 49, 64, 41]; 6) Reg-kMeans: our method with DEC replaced by k-means clustering. Our method outperforms the state-of-the-art DL based registration methods.
3.2 Classification Results
Table 2 summarizes the performance of different methods on the test set for classifying between normal, DME and AMD. Results are also shown for CNN based classification networks, namely VGG-16 [67], ResNet [8] and DenseNet [9], three of the most widely used classification CNNs, and for the multiscale CNN ensemble of [66], which serves as the baseline for this dataset. Our proposed method outperforms the standard CNN architectures, demonstrating the efficacy of combining registration with clustering for classification tasks. It also shows Reg-DEC's advantages of lower computing time and fewer training parameters.
3.3 Identification of Disease Subgroups And Explainability
Besides predicting a disease label and severity score, our method provides explainability behind the decision. For a given test image and its predicted registration parameters, we calculate the distance from each of the cluster centers, giving a single value that quantifies the sample's similarity with each disease cluster. Let the sample be assigned to a given cluster, and consider its corresponding distances to each of the clusters. We calculate a normalized value
[TABLE]
where the normalized value gives the probability of the test sample reaching the highest severity score; it is also a severity score on a normalized scale. Scores from multiple visits help build a patient severity profile for analysing the factors behind an increase or decrease in severity, as well as the corresponding rate of change. The rate of severity change is an important factor in determining a personalized diagnosis plan. This normalized value is different from the class probability obtained from a CNN classifier: the classifier probability is its confidence in the decision, whereas the normalized value gives the probability of transitioning to the most severe stage.
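The exact normalization equation is not recoverable from this text, so the code below is only a hypothetical stand-in with the properties described. It uses distances to the cluster centers, returns a value in [0, 1], and grows as a sample approaches the most severe cluster. The swapped-in choice of normalized inverse distances is not the authors' formula. A second helper estimates the rate of severity change across visits as a least-squares slope.

```python
import numpy as np


def severity_probability(x, centers):
    """Hypothetical score: inverse-distance weight of the most severe cluster.

    x:       (d,) representation of the test sample.
    centers: (K, d) cluster centers ordered from least (index 0) to most severe (K - 1).
    Returns a value in [0, 1]; NOT the paper's (unrecovered) normalization.
    """
    d = np.linalg.norm(centers - x, axis=1)
    inv = 1.0 / (d + 1e-12)  # closer clusters receive larger weight
    return float(inv[-1] / inv.sum())


def severity_rate(visit_times, scores):
    """Rate of severity change per unit time: slope of a least-squares line through visits."""
    slope, _ = np.polyfit(visit_times, scores, 1)
    return float(slope)
```

With scores from visits at months 0, 6 and 12, `severity_rate` gives the change per month that a personalized plan could track.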
Tables 1 and 2 demonstrate Reg-DEC's superior performance for classification and registration. To determine Reg-DEC's effectiveness in predicting the disease severity of classes not part of the training data, we train our severity prediction network on normal and AMD images only, leaving out the DME images. We keep the same number of clusters (i.e., ) as before. Since there are no DME images and the number of clusters is unchanged, the assignment of images to clusters differs from before. In this case of AMD images are assigned to clusters , which is a drop of from the previous assignment, while of normal samples are assigned to clusters , a decrease of .
We see fewer images in clusters although the majority of the original assignments of normal and AMD cases are unchanged. When we use this trained model on the DME images, we find that of the images are assigned to clusters , a decrease of from before. The above results lead to the following conclusions: 1) Reg-DEC's performance drops by for DME and by a maximum of (for normal images) when DME images are not part of the training data. This is not a significant drop, indicating Reg-DEC's capacity to identify subgroups that were not part of the training data. 2) Using k-means clustering does not give the same performance levels, demonstrating that end-to-end feature learning combined with clustering gives much better results than performing the steps separately. Reg-DEC accurately predicts disease severity even though there is no standard severity grading scale. The severity scale also identifies subgroups of the population with a specific disease activity.
The first and second columns of Figure 2 show AMD images accurately classified by Reg-DEC and DenseNet, respectively. The yellow arrows highlight regions of abnormality identified by clinicians. Red ellipses (first column) show the region of disease activity: the length of the major axis quantifies the magnitude of displacement of the corresponding grid point, and the orientation indicates its direction. The local displacement magnitude is proportional to disease severity, while the orientation identifies the exact location. The second column shows the corresponding CAMs obtained from DenseNet (regions highlighted in green). Although the CAMs include the region of disease activity, they do not localize it accurately and are spread out, nor do they output a measurable value. By dividing the displacement magnitude by the distance between the grid points we get a value very close to . The advantages of our registration based method are clear, since it pinpoints the abnormality and quantifies it in terms of displacement magnitude and angle.
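Each ellipse glyph encodes a grid node's displacement by its magnitude (major axis) and direction (orientation). The magnitude is also normalized by the grid spacing. The vectorized sketch below computes these three quantities. It assumes 2D displacements stored as (dy, dx) in pixels and angles from `np.arctan2(dy, dx)` in array coordinates. The function name is hypothetical.

```python
import numpy as np


def displacement_glyphs(disp, grid_spacing):
    """Summarize grid-node displacements as ellipse-glyph parameters.

    disp:         (..., 2) displacements (dy, dx) in pixels, one per grid node.
    grid_spacing: distance between neighbouring grid nodes in pixels.
    Returns (magnitude, angle in degrees, magnitude / grid_spacing), each of shape (...).
    """
    dy, dx = disp[..., 0], disp[..., 1]
    magnitude = np.hypot(dx, dy)                 # major-axis length of the glyph
    angle_deg = np.degrees(np.arctan2(dy, dx))   # glyph orientation
    return magnitude, angle_deg, magnitude / grid_spacing
```

The normalized magnitude is the measurable value referred to above. Unlike a CAM intensity, it has a direct geometric meaning relative to the grid.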
The third column of Figure 2 shows examples of normal images that were correctly classified by Reg-DEC but incorrectly classified as AMD by DenseNet. The green regions highlight disease activity as identified by DenseNet, which is erroneous since there are no abnormalities here. Reg-DEC does not show any localization of pathologies in these examples. The fourth column shows examples of DME that were correctly identified by Reg-DEC, despite DME not being part of the training data, along with red ellipses showing localized regions of disease activity. They were assigned to clusters respectively. The CNNs trained to classify AMD and normal would mostly classify the second and third images as diseased, while the first image was usually classified as normal because of its similar appearance to some normal images. Thus, our method identifies different patient cohorts even when they are not part of the training data.
4 Conclusion
We propose a method to predict disease severity from retinal OCT images even though no labels are provided for disease severity. A CNN regressor predicts registration parameters for a given test image, which then undergo clustering to output a disease severity scale and a disease probability score in addition to the classification label (diseased or normal). Experimental results show that our proposed method achieves better registration and classification performance than existing approaches. We are able to identify distinct patient cohorts that were not part of the training data. Our approach also provides explainability behind the classification decision by quantifying disease activity from the registration parameters.
- [1] https://www.eyepacs.com
- [2] The Insight Segmentation and Registration Toolkit (ITK). www.itk.org
- [3] Balakrishnan, G., Zhao, A., Sabuncu, M., Guttag, J.: An unsupervised learning model for deformable medical image registration. In: Proc. CVPR. pp. 9252–9260 (2018)
- [4] Bozorgtabar, B., Mahapatra, D., von Teng, H., Pollinger, A., Ebner, L., Thiran, J.P., Reyes, M.: Informative sample generation using class aware generative adversarial networks for classification of chest Xrays. Computer Vision and Image Understanding 184, 57–65 (2019)
- [5] Mahapatra, D., Bozorgtabar, B., Garnavi, R.: Image super-resolution using progressive generative adversarial networks for medical image analysis. Computerized Medical Imaging and Graphics 71(1), 30–39 (2019)
- [6] Dosovitskiy, A., Fischer, P., et al.: FlowNet: Learning optical flow with convolutional networks. In: Proc. IEEE ICCV. pp. 2758–2766 (2015)
- [7] Graziani, M., Andrearczyk, V., Müller, H.: Regression concept vectors for bidirectional explanations in histopathology. In: Proc. MICCAI iMIMIC. pp. 124–132 (2018)
- [8] He, K., Zhang, X., Ren, S., Sun, J.: Deep residual learning for image recognition. In: Proc. CVPR (2016)
