Joint Imbalance Adaptation for Radiology Report Generation
Wang Li, Guangzeng Han, Yuexin Wu, I.-Chan Huang, Xiaolei Huang

TL;DR
This paper introduces a new method to improve radiology report generation by addressing data imbalance issues in medical tokens and labels.
Contribution
The novel JIMA model uses a hard-to-easy learning strategy to reduce overfitting on frequent patterns and improve performance on infrequent medical terms.
Findings
JIMA improves evaluation metrics by 16.75–50.50% on radiology datasets.
The model enhances performance on infrequent tokens and abnormal radiological entries.
Human evaluations confirm improved clinical accuracy of generated reports.
Abstract
Radiology report generation, translating radiological images into precise and clinically relevant description, may face the data imbalance challenge — medical tokens appear less frequently than regular tokens, and normal entries are significantly more than abnormal ones. However, very few studies consider the imbalance issues, not even with conjugate imbalance factors. In this study, we propose a Joint Imbalance Adaptation (JIMA) model to promote task robustness by leveraging token and label imbalance. We employ a hard-to-easy learning strategy that mitigates overfitting to frequent labels and tokens, thereby encouraging the model to focus more on infrequent labels and clinical tokens. JIMA presents notable improvements (16.75–50.50% on average) across evaluation metrics on IU X-ray and MIMIC-CXR datasets. Our ablation analysis and human evaluations show the improvements mainly come…
Funding
- National Science Foundation, United States
- National Cancer Institute (https://doi.org/10.13039/100000054)
Introduction
Radiology report generation is a multimodal, medical image-to-text task that generates text descriptions for radiographs (e.g., X-ray or CT scans), which may reduce the workload of radiologists [1, 2]. The task has characteristics that set it apart from general image-to-text tasks (e.g., image captioning), such as lengthy medical notes, medical annotations, and clinical terminology. As demonstrated in Fig. 1, data imbalance can significantly degrade model robustness and hinder deployment in practice, because models easily overfit to frequent patterns. However, modeling data imbalance to make radiology report generation more robust is understudied.
Fig. 1. State-of-the-art model performance on normal and abnormal entries by BLEU-4 (left two) and on low- and high-frequency tokens by F1 scores (right two). Colors denote model performance on normal (orange) vs. abnormal (light green) reports, or on frequent (orange) vs. infrequent (light green) tokens.
Two major data imbalances exist in the radiology report generation task: label and token. Label imbalance refers to a disproportionate ratio of normal and abnormal diagnosis categories, which exists in both radiological images and text reports. For instance, normal cases (images and reports) dominate radiology data, which can easily lead to underperformance in disease detection and professional description. As shown in Table 1, abnormal reports are considerably longer than normal reports yet can account for less than 15% of reports. These abnormal reports are much harder to generate than shorter reports [3–5], and having fewer samples than normal cases can make generation worse. Existing imbalance learning studies of radiology report generation primarily focus on label imbalance [7, 8]. Token imbalance, in which tokens occur with widely varying frequencies, is a critical challenge in generation, and the issue is more severe in the medical task. Learning infrequent tokens can be harder than learning frequent tokens for generation models [9, 10]. Medical tokens appear less frequently than regular ones, and the infrequent tokens may carry more medical findings, which highlights a challenge unique to this task. The imbalance issue in radiology report generation also connects to broader challenges in multi-label classification and image object detection, where models often struggle to learn effectively from infrequent but clinically significant labels [11] and object categories [12]. Figure 1 illustrates the learning progress of the state-of-the-art (SOTA) model RRG [13] in predicting a report with predominantly normal diagnoses. The model shows strong performance on normal cases but struggles on abnormal reports.
Table 1. Data statistics summary
| Dataset | Image | Report | Vocab | Abnormal % | Normal % | L | $L_{normal}$ | $L_{abnormal}$ |
|---|---|---|---|---|---|---|---|---|
| IU X-ray | 7470 | 3955 | 1517 | 32.96% | 67.04% | 35.99 | 27.76 | 40.72 |
| MIMIC-CXR | 377,110 | 227,835 | 13,876 | 13.97% | 86.03% | 59.70 | 34.57 | 59.36 |
Variations exist in label (Normal and Abnormal %) and average report length (L). Vocab refers to the vocabulary size, and normal or abnormal indicates report labels.
To improve the quality of generated reports, we propose the Joint Imbalance Adaptation (JIMA) model, built on curriculum learning [14], a training strategy that presents models with increasingly complex examples and mimics human learning from simple to difficult tasks. JIMA automatically guides the model learning process by leveraging optimization difficulties, strengthening learning capability on infrequent samples, and alleviating overfitting to frequent patterns for both labels and tokens. Our method implements a tailored curriculum strategy that dynamically adjusts the difficulty of data examples by integrating radiology-specific knowledge and difficulty metrics not previously considered in curriculum learning approaches, with the aim of addressing data imbalance and improving radiology report quality. We incorporate the token and label metrics into a joint optimization and design a novel Training Scheduler that samples and sorts training instances with a multi-aspect scoring mechanism. The scheduler automatically adjusts training samples when model performance varies across multiple imbalance factors. We conduct experiments on two publicly available datasets, MIMIC-CXR [15] and IU X-ray [16], with automatic and human evaluations. Compared with six state-of-the-art (SOTA) baselines in overall and imbalance performance settings, our approach shows promising results. Notably, JIMA has minimal impact on the performance of highly frequent tokens and labels, significantly enhances performance for moderately frequent samples, and shows limitations in boosting performance for the rarest tokens. Our ablation and qualitative analyses show that JIMA can generate more precise medical reports, alleviating label and token imbalance.
Related Work
Radiology report generation is a domain-specific image-to-text task with two major directions: retrieval-based [17, 18] and generation-based [19–21]. The retrieval-based approach compares similarities between an input radiology image and a set of report candidates, ranks the candidates, and returns the most similar one [5, 17, 18, 22, 23]. In contrast, our study focuses on the generation-based task, which automatically generates a precise report from an input image. The task has domain-specific characteristics in the clinical field. Clinical data contains many infrequent medical terms and longer documents than image captioning in general domains [6]. As radiology report generation can reduce the workload of radiologists, generating high-quality and precise reports is a critical challenge, especially under imbalance settings. Differing from previous work, we aim to improve model robustness and reliability under imbalance settings, which have rarely been studied in radiology report generation.
Imbalance learning aims to model skewed data distributions. Its primary focus is class or label imbalance, such as positive or negative reviews in sentiment analysis [11]. Recent studies have proposed new approaches to address imbalance by weighting hard and infrequent examples [12] or by leveraging imbalance distributions to augment minority data labels [11, 24]. These studies inspire our work to model multiple imbalance factors, as imbalance in radiology report generation is a multifaceted issue that extends beyond the data label. Some studies developed new objective functions or data augmentation approaches to improve model performance on minority labels, such as focal loss, which down-weights easy samples [12], or SMOTE, which creates synthetic data for minority labels [24]. However, those methods may not be applicable to our primary generation unit, the token, which has a large vocabulary and extreme sparsity. In radiology report generation, reports may carry disease-related labels. Recent studies have improved model robustness by using reinforcement learning to balance performance between disease and normal cases [7, 8]. However, those methods focus on label imbalance, whereas our study considers a multifaceted imbalance challenge that includes both label and token imbalance. Token imbalance can be even more critical in the clinical domain, as medical tokens appear less frequently than regular tokens in radiology reports. A closely related work is TIMER [10], which considers token imbalance. However, that approach ignores other imbalance factors, which our approach addresses. In particular, our approach jointly models multiple imbalance factors, label and token, and we propose a new curriculum learning method to learn them.
Fig. 2. Frequent and infrequent token distributions conditioned on report label. We denote the report types with light green (normal) and orange (abnormal), which show different token imbalance distributions.
Data
We use two publicly accessible, de-identified chest X-ray datasets, IU X-ray [16] and MIMIC-CXR [15], to evaluate radiology report generation. IU X-ray [16], collected from the Indiana Network for Patient Care, includes 7470 X-ray images and 3955 corresponding radiology reports. MIMIC-CXR [15], collected from the Beth Israel Deaconess Medical Center, contains 377,110 X-ray images and 227,827 radiology reports for 65,379 patients. Each report is a text document associated with one or more front and side X-ray images. Table 1 summarizes statistics of data imbalance, and Fig. 2 visualizes the distributions of frequent (ranked in the top 12.5% of the vocabulary) and infrequent tokens. We include preprocessing details in Appendix 1.
Table 1 presents imbalance patterns in tokens and labels. Abnormal entries are predominant in both datasets, and MIMIC-CXR displays a more skewed label distribution, as more abnormal samples were collected during diagnosis phases rather than for screening purposes. MIMIC-CXR has a longer average report length than IU X-ray. The lengthier documents may pose a unique multimodal generation challenge in the medical field. For our analysis, we treat the top 12.5% most frequent tokens as high-frequency and the remaining tokens as low-frequency. Figure 2 suggests a joint relation between label and token imbalance, with higher ratios of low-frequency tokens in abnormal reports. This observation motivates us to investigate how the imbalance impacts model robustness and reliability.
Imbalance Effects
We examine the potential impact of label and token imbalance on model performance. For consistency, we use the same top-12.5% threshold to split low- and high-frequency tokens for evaluation. The analysis includes three state-of-the-art models: R2Gen [19], WCL [25], and CMN [26]. We use BLEU-4 [27] and F1 scores to measure performance across both token (low vs. high frequency) and label (normal vs. abnormal) imbalance. We visualize performance variations in Fig. 1.
The results suggest that the models have significant difficulty coping with label and token imbalance. Models consistently perform worse on abnormal reports, which are lengthier and contain more infrequent tokens than normal reports. For example, the top 12.5% most frequent tokens account for more than 80% of token occurrences in both datasets, and low-frequency tokens show much worse performance than frequent tokens, as infrequent tokens are harder to optimize [28]. However, infrequent tokens contain higher ratios of medical terms (e.g., silhouettes and pulmonary) describing health states. The widely varying performance highlights the unique challenge of adapting to token and label imbalance. While existing work [7] has considered label imbalance, that study did not examine the performance effects of label or token imbalance. These findings inspire us to propose our model, Joint Imbalance Adaptation (JIMA), to model token and label imbalance.
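The frequency split used in this analysis can be illustrated with a short sketch. It assumes whitespace-tokenized reports; the helper name `split_by_frequency` and the toy corpus are hypothetical, and only the 12.5% cutoff comes from the text.

```python
from collections import Counter
import math

def split_by_frequency(reports, top_ratio=0.125):
    """Split a vocabulary into high- and low-frequency token sets.

    reports: list of token lists.
    top_ratio: share of the vocabulary (ranked by count) treated as
    high-frequency; 12.5% follows the text.
    Returns (high, low, coverage), where coverage is the share of all
    token occurrences that belong to the high-frequency set.
    """
    counts = Counter(tok for report in reports for tok in report)
    ranked = [tok for tok, _ in counts.most_common()]
    n_high = max(1, math.ceil(len(ranked) * top_ratio))
    high = set(ranked[:n_high])
    low = set(ranked[n_high:])
    total = sum(counts.values())
    coverage = sum(counts[t] for t in high) / total
    return high, low, coverage

# Toy corpus (hypothetical, for illustration only).
reports = [
    "the lungs are clear".split(),
    "the heart is normal".split(),
    "the lungs are clear".split(),
    "the cardiac silhouette is enlarged".split(),
]
high, low, coverage = split_by_frequency(reports)
print(sorted(high), round(coverage, 3))
```

On real data, `coverage` is the quantity behind the statement that a small share of the vocabulary accounts for most token occurrences.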
Joint Imbalance Adaptation
In this section, we present our approach, Joint Imbalance Adaptation (JIMA), which is illustrated in Fig. 3 and built on curriculum learning. JIMA aims to improve model robustness under label and token imbalance. Because optimization under data imbalance has proven difficult, deploying such a learning strategy can strengthen model robustness and reliability. Our approach deploys curriculum learning (CL) [29], which automatically adjusts the optimization process by gradually selecting training entries according to learning difficulty; we adopt learning from hard to easy samples as our optimization strategy [30]. To this end, we design two major CL modules: a difficulty measurer that assesses the difficulty of samples, and a training scheduler that determines the percentage of training data. We then apply our CL training strategy to two tasks. Task 1 predicts entities from the images, and task 2 generates a report from the image features and the entity distribution.
Fig. 3. JIMA has two curriculum learning tasks. Task 1 aims to predict the entity distribution from images, and task 2 aims to generate a report from the image features and the entity distribution. We assign one color per task and use solid arrows for workflows. We extracted imbalanced entity distributions from the training data with RadGraph as the gold truth and compared them with the predicted entity distribution. We feed the fused radiology image features and imbalance patterns to the generation process. During training, tasks 1 and 2 decide which data samples are fed to the model.
The difficulty measurer is the core scoring function of curriculum learning; it decides which data samples are fed to the model for training. To diversify learning aspects and jointly incorporate imbalance factors, we propose a novel measurement to improve model performance over imbalance patterns. Our measurement adopts a competitive mechanism that encourages correct options to rank higher than incorrect ones, rather than independently increasing the likelihood of correct options and decreasing the likelihood of incorrect options. This approach helps mitigate overfitting on common samples and underfitting on rare samples, since it focuses on the ranking of correct options rather than on prediction confidence. Specifically, given a reference token z, a vocabulary list V, and the prediction $\mathbf{p} \in \mathcal{R}^{|V|}$, we calculate the probability ranking of the reference token z in the prediction $\mathbf{p}$ as follows:
$$k = Rank(\mathbf{p}, \mathbf{p}[z]) / |V| \tag{1}$$
where |V| is the vocabulary size. $Rank(\mathbf{p}, \mathbf{p}[z])$ ranks the entries of $\mathbf{p}$ in descending order and returns the position of $\mathbf{p}[z]$ within this ranking. We use the Rank() function to infer imbalance distributions at the token level, which gives higher preference to tokens that are hard to predict. To avoid biased performance evaluation across labels and tokens, we average the values separately over entity and non-entity tokens, which are extracted by RadGraph [31]. Normalizing by |V| keeps k between 0 and 1. A higher value of k indicates that the sample is more difficult. We then feed the difficulty information to the next step, the training scheduler.
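As a minimal sketch of the rank-based difficulty score in Eq. (1), the snippet below computes k for a single prediction and averages it separately over entity and non-entity tokens. Two choices are assumptions, not taken from the source: ranks are 1-based (so k lies in (0, 1]), and ties are resolved in favor of the reference token. The function names are hypothetical.

```python
def rank_difficulty(p, z):
    """Normalized rank k = Rank(p, p[z]) / |V| of the reference index z.

    p: list of predicted scores over the vocabulary.
    z: index of the reference token.
    Assumption: 1-based descending rank; entries tied with p[z] do not
    push it down.
    """
    rank = 1 + sum(1 for v in p if v > p[z])
    return rank / len(p)

def mean_by_group(ks, is_entity):
    """Average difficulty separately over entity and non-entity tokens."""
    ent = [k for k, e in zip(ks, is_entity) if e]
    non = [k for k, e in zip(ks, is_entity) if not e]
    mean = lambda xs: sum(xs) / len(xs) if xs else 0.0
    return mean(ent), mean(non)

# Toy prediction over a 4-token vocabulary (hypothetical values).
p = [0.1, 0.6, 0.05, 0.25]
ks = [rank_difficulty(p, z) for z in (1, 3, 2)]
entity_k, non_entity_k = mean_by_group(ks, [False, True, True])
print(ks, entity_k, non_entity_k)
```

A reference token that the model already ranks first gets the lowest score, while one ranked last gets k = 1, regardless of how confident the model is in either case.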
The training scheduler aims to automatically leverage imbalance effects by selecting training samples according to the difficulty measurer. Our goal is to increase the number of easier samples when performance decreases, and vice versa. Accordingly, we design our scheduler function $c(s_{t})$ as follows:
$$c(s_{t}) = \min\left(1, \left[1 - \frac{s_{t} - s_{t-1}}{s_{t-1}}\right] \times c(s_{t-1})\right), \quad t \ge 1 \tag{2}$$
where s is the average performance over all training samples, measuring the model's learning ability, and t is the training step. For example, when performance decreases, $\frac{s_{t} - s_{t-1}}{s_{t-1}}$ is negative, so the factor $1 - \frac{s_{t} - s_{t-1}}{s_{t-1}} > 1$ lets the model include more easy training data than at the previous step, $c(s_{t-1})$. When performance increases, the scheduler feeds fewer easy samples to the model and reduces overfitting on these samples. After multiple epochs of training, harder samples receive more training iterations than easier samples. In this way, we alleviate the challenge posed by imbalanced tokens and labels in the radiology report generation task.
To start curriculum learning, we record the samples' average performance over the last two regular training epochs as $s_{0}$ and $s_{1}$, and we empirically initialize $c(s_{0})$ as 1.
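The scheduler update in Eq. (2) can be traced with a few lines of code. This is a sketch under stated assumptions: the function name `next_ratio` and the performance values are hypothetical, and the guard against a zero previous score is an addition not discussed in the text.

```python
def next_ratio(c_prev, s_prev, s_curr):
    """Eq. (2): c(s_t) = min(1, [1 - (s_t - s_{t-1}) / s_{t-1}] * c(s_{t-1})).

    c_prev: previous data ratio c(s_{t-1}).
    s_prev, s_curr: average performance at steps t-1 and t.
    """
    if s_prev == 0:
        # Assumption: the relative change is undefined for a zero baseline.
        raise ValueError("s_prev must be non-zero")
    relative_change = (s_curr - s_prev) / s_prev
    return min(1.0, (1.0 - relative_change) * c_prev)

# Hypothetical performance trace; c(s_0) is initialized to 1 as in the text.
scores = [0.40, 0.44, 0.396, 0.45]
ratio = 1.0
ratios = []
for s_prev, s_curr in zip(scores, scores[1:]):
    ratio = next_ratio(ratio, s_prev, s_curr)
    ratios.append(ratio)
print(ratios)
```

In this trace a 10% gain shrinks the ratio to about 0.9, and the following 10% drop raises it again to about 0.99, which matches the behavior described above: improvement removes easy samples, degradation brings them back, and the ratio never exceeds 1.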
CL-Task 1
CL-Task 1 is to exploit imbalance patterns of radiology labels to generate clinically accurate reports. Entities in clinical reports play a crucial role in disease diagnosis. However, these clinical tokens often occur infrequently and are significantly underestimated during model training. Hence, we assess the accuracy of clinical entities to evaluate performance. Our intuition is that as abnormal cases contain more infrequent entities, focusing on the clinical entities may benefit the abnormal cases. If our generated reports are clinically correct, the visual extractor can accurately extract the same entities as gold entities from images.
The computing process is as follows. Given a radiology image Img and the corresponding report $Z = \left(z_{0}, \ldots, z_{l}\right)$ of length l, we extract features from the image with a visual extractor. We use ResNet101 [32] ($f_{\mathcal{R}}$) as our visual extractor and obtain image features ($\mathbf{X}$) from different convolutional channels, $\mathbf{X} = f_{\mathcal{R}}(\text{Img})$, with $\mathbf{X} \in \mathcal{R}^{patch\_size \times d}$, where d is the size of the feature vector.
To predict the entity distribution, we feed the features $\mathbf{X}$ into the Entity Extractor ($f_{E}$) with parameters $W_{E} \in \mathcal{R}^{d \times |V|}$ and average the values over the patches (the first dimension),
$$\mathbf{q} = AVG_{:1}(f_{E}(\mathbf{X} \mid W_{E})) \tag{3}$$
Then we obtain the entity distribution representation $\mathbf{q} \in \mathcal{R}^{|V|}$. To optimize the model, we minimize the binary cross-entropy as follows,
$$\mathcal{L}_{task1} = \frac{1}{|V|} \sum_{i=1}^{|V|} -\left(y_{i} \log\left(q_{i}\right) + \left(1 - y_{i}\right) \log\left(1 - q_{i}\right)\right) \tag{4}$$
where $q_{i}$ is the predicted probability of the i-th token and $y_{i} = 1$ if the i-th token is an entity. We extract the gold entities ($\mathbf{e}$) with RadGraph [31].
To evaluate the sample's difficulty in this task, we input the entity distribution prediction $\mathbf{q}$ into Eq. (1) and obtain $k^{task1} = \sum_{i}^{|\mathbf{e}|} Rank(\mathbf{q}, \mathbf{q}[e_{i}]) / (|V| \cdot |\mathbf{e}|)$.
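The three quantities of CL-Task 1 (patch averaging in Eq. (3), the loss in Eq. (4), and the difficulty score $k^{task1}$) can be sketched in plain Python. The sketch assumes the entity extractor already outputs per-patch probabilities in (0, 1), uses 1-based ranks, and adds a small `eps` for numerical safety; the function names and toy numbers are hypothetical.

```python
import math

def average_over_patches(scores):
    """Eq. (3): average per-patch scores (patch_size x |V|) over patches."""
    n_patches = len(scores)
    vocab = len(scores[0])
    return [sum(row[j] for row in scores) / n_patches for j in range(vocab)]

def bce_loss(q, y, eps=1e-12):
    """Eq. (4): mean binary cross-entropy over the vocabulary."""
    terms = [
        -(yi * math.log(qi + eps) + (1 - yi) * math.log(1 - qi + eps))
        for qi, yi in zip(q, y)
    ]
    return sum(terms) / len(q)

def task1_difficulty(q, entities):
    """k^{task1}: mean normalized rank of the gold entity indices in q."""
    ranks = [1 + sum(1 for v in q if v > q[e]) for e in entities]
    return sum(ranks) / (len(q) * len(entities))

# Two patches, vocabulary of four tokens (hypothetical values).
patch_scores = [[0.9, 0.2, 0.1, 0.6],
                [0.7, 0.4, 0.3, 0.2]]
gold_entities = [0, 3]                  # indices of gold entity tokens
y = [1, 0, 0, 1]                        # multi-hot entity labels
q = average_over_patches(patch_scores)  # approximately [0.8, 0.3, 0.2, 0.4]
loss = bce_loss(q, y)
k_task1 = task1_difficulty(q, gold_entities)
print(q, loss, k_task1)
```

Here the two gold entities are ranked first and second in `q`, so the difficulty is (1 + 2) / (4 · 2) = 0.375; a sample whose entities are ranked lower would score closer to 1.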
CL-Task 2
CL-Task 2 implements an image-to-text generation pipeline with the objective of improving the prediction of infrequent tokens in reports. To generate a report containing more clinically useful information, we integrate the predicted entity probabilities ($\mathbf{q}$) from Eq. (3) with the image features ($\mathbf{X}$). Since $d \ne |V|$, $\mathbf{q}$ and $\mathbf{X}$ cannot interact directly. To facilitate their interaction and information sharing, we employ a cross-concatenation and perform an element-wise multiplication on the cross-concatenated matrices as follows:
$$\mathbf{S} = concat_{:2}(\mathbf{X}, \mathbf{q}) \odot concat_{:2}(\mathbf{q}, \mathbf{X}) \tag{5}$$
where $\mathbf{S} \in \mathcal{R}^{patch\_size \times (d + |V|)}$ and $\odot$ refers to element-wise multiplication. We empirically compared this with the simple sum, dot product, and a cross-modal attention network [33]; element-wise multiplication achieved the best results on the validation set. We infer that the multiplication is less complex than the cross-modal attention network while keeping more vector information from the text and image modalities.
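A small sketch may help make the shapes in Eq. (5) concrete. It assumes that $\mathbf{q}$ is repeated for every patch before concatenation, which is one way to obtain the stated shape patch_size × (d + |V|); the source does not spell this step out. The function name and toy values are hypothetical.

```python
def cross_concat_fuse(X, q):
    """Eq. (5): S = concat(X, q) * concat(q, X), element-wise, per patch.

    X: patch_size x d image features (list of lists).
    q: entity distribution of length |V|.
    Assumption: q is broadcast to every patch row before concatenation.
    Returns S with shape patch_size x (d + |V|).
    """
    fused = []
    for row in X:
        left = list(row) + list(q)    # concat along the feature axis: (X, q)
        right = list(q) + list(row)   # same axis, opposite order: (q, X)
        fused.append([a * b for a, b in zip(left, right)])
    return fused

# Two patches with d = 2 features and a vocabulary of 3 (hypothetical values).
X = [[1.0, 2.0],
     [3.0, 4.0]]
q = [0.5, 0.1, 0.2]
S = cross_concat_fuse(X, q)
print(len(S), len(S[0]))
```

Because the two concatenations place image and entity values in opposite orders, each position of `S` multiplies an image feature, an entity probability, or a mix of the two, without requiring d and |V| to match.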
Finally, we adopt a transformer structure to encode $\mathbf{S}$ and generate the probability distribution of the i-th token, $\mathbf{P}_{i}$, from the encoded feature $\mathbf{S}$ and the preceding token: $\mathbf{P}_{i} = f_{\mathcal{T}}(\mathbf{S}, z_{i-1})$. To optimize the model, we minimize the negative log-likelihood (NLL) loss as follows,
$$\begin{aligned} \mathcal {L}_{task2} = -\sum _{i}^{l} \log \left( \mathbf {P_{i}}\right) \end{aligned}$$
We can assess the sample’s difficulty with $\mathbf {P_{i}}$ as in Eq. (1): $k^{task2} = \sum _{i}^{l} Rank(\textbf{P}_{i},\textbf{P}_{i}[z_{i}])/( |V| \cdot l)$.
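A small sketch of the NLL loss and the rank-based difficulty score for one report follows. Two points are assumptions on our part: the loss is read as the log-probability of the reference token at each step, and Rank is taken to be the number of vocabulary entries with a strictly higher probability than the reference token (so a perfectly ranked token contributes 0). The function names are hypothetical.

```python
import math


def nll_loss(step_probs, targets):
    """Negative log-likelihood of the reference tokens (assumed reading of the loss)."""
    return -sum(math.log(p[z]) for p, z in zip(step_probs, targets))


def target_rank(probs, target):
    """Assumed Rank: how many vocabulary entries outrank the reference token."""
    return sum(1 for p in probs if p > probs[target])


def difficulty(step_probs, targets):
    """k = sum_i Rank(P_i, P_i[z_i]) / (|V| * l), as in the formula above."""
    vocab_size = len(step_probs[0])
    length = len(targets)
    total_rank = sum(target_rank(p, z) for p, z in zip(step_probs, targets))
    return total_rank / (vocab_size * length)


# Toy report of length l = 2 over a vocabulary of |V| = 3 tokens.
step_probs = [[0.7, 0.2, 0.1],
              [0.1, 0.6, 0.3]]
targets = [0, 2]  # reference token ids z_1, z_2

loss = nll_loss(step_probs, targets)
k = difficulty(step_probs, targets)
print(loss, k)
```

In this toy case the first reference token is ranked first and the second is outranked by one token, so the summed rank is 1 and the score is 1 / (3 · 2).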
Algorithm 1. Optimization process of JIMA.
CL-Joint Optimization
We propose a joint optimization approach to integrate the two tasks. Algorithm 1 summarizes the overall optimization process of our approach. We set the learning rate of task 1 as $\alpha$, and $\beta$ refers to the learning rate of task 2. In each training step, we sample different data for different tasks, and each task focuses on optimizing its own modules of the model. For example, we update the visual extractor ($f_{\mathcal {R}}$) and the entity extractor ($f_{E}$) in task 1. Next, we freeze the parameters of the visual extractor and the entity extractor and update the parameters of the transformer ($f_{\mathcal {T}}$) specifically for task 2.
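The alternating update described above can be sketched as follows. This is an illustration only and not a reproduction of Algorithm 1: each module is reduced to a single scalar parameter, the gradients are made-up numbers, and plain gradient descent is assumed as the update rule. The helper `sgd_step` is hypothetical.

```python
def sgd_step(params, grads, lr, trainable):
    """Return updated parameters; only the modules named in `trainable` move."""
    return {
        name: value - lr * grads[name] if name in trainable else value
        for name, value in params.items()
    }


# One scalar stands in for each module's parameters (toy setting).
params = {"f_R": 1.0, "f_E": 1.0, "f_T": 1.0}
alpha, beta = 0.1, 0.01  # learning rates of task 1 and task 2

# Task 1: update the visual extractor f_R and the entity extractor f_E.
grads_task1 = {"f_R": 0.5, "f_E": -0.2, "f_T": 0.3}
after_task1 = sgd_step(params, grads_task1, alpha, trainable={"f_R", "f_E"})

# Task 2: freeze f_R and f_E, update only the transformer f_T.
grads_task2 = {"f_R": 0.4, "f_E": 0.4, "f_T": 1.0}
after_task2 = sgd_step(after_task1, grads_task2, beta, trainable={"f_T"})

print(after_task1)
print(after_task2)
```

Passing the trainable set explicitly makes the freezing visible: gradients may exist for every module, but the task-2 step leaves f_R and f_E exactly as task 1 left them.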
Our optimization approach integrates curriculum learning to tailor joint imbalance learning for each module ($f_\mathcal {R}$, $f_{E}$, $f_\mathcal {T}$). Curriculum learning lets the model concentrate on optimizing hard samples while mitigating the risk of overfitting to easier samples. The joint optimization scheme lets each task manage the optimization of its own module parameters and transfer knowledge from the simpler task to the more complex one. As a result, all modules collaborate to reduce the errors carried over from previous tasks.
Experiments
We design our experiments to evaluate performance in both regular and imbalanced settings via automatic and human evaluations. The automatic evaluation includes NLG-oriented and clinical-correctness metrics. NLG-oriented metrics measure the similarity between generated and reference reports. Clinical-correctness metrics and human evaluation are factually oriented, domain-specific evaluation methods. To be consistent with our baselines [10, 13, 19], we use the CheXbert F1 [34] as the clinical-correctness metric. The experiments compare our proposed approach (JIMA) with state-of-the-art baselines. Two of our seven baselines (CMM + RL and RRG) are designed to address label imbalance by improving the generation of abnormal findings. We conduct ablation and case analyses to fully understand the capabilities of our proposed approach. We include more implementation details and hyperparameter settings in Appendix 2.1.
Baselines
To examine the validity of our method, we include seven state-of-the-art baselines under the same experimental settings: R2Gen [19], CMN [26], WCL [25], CMM + RL [20], RRG [23], TIMER [10], and RGRG [35]. We obtain all baselines from their open-source code repositories.
R2Gen [19] is a transformer-based model with ResNet101 [32] as the visual extractor. To capture some patterns in medical reports, R2Gen proposes a relational memory to enhance the transformer so that the model can learn from the patterns’ characteristics. Furthermore, R2Gen deploys a memory-driven conditional layer normalization to the transformer decoder, facilitating the incorporation of the previous step generation into the current step.
CMN [26] is a novel extension to the transformer architecture that facilitates the alignment of textual and visual modalities. The cross-modal memory network records the shared information of visual and textual features. The alignment process is carried out via memory querying and responding. The model maps the visual and textual features into the same representation space in memory querying and learns a weighted representation of these features in memory responding.
WCL [25] utilizes the R2Gen framework and incorporates a weakly supervised contrastive loss. Specifically, WCL leverages the contrastive loss to enhance the similarity between a given source image and its corresponding target sequence. Furthermore, the model enhances its ability to learn from difficult samples by assigning more weights to instances sharing common labels.
CMM + RL [20] is a cross-modal memory-based model optimized with reinforcement learning. CMM + RL designs a cross-modal memory model to align the visual and textual features and deploys reinforcement learning to capture the label imbalance between abnormality and normality. The authors use BLEU-4 as a reward to guide the model to generate the next word from the image and previous words.
RRG [13, 23] aims to generate clinically correct reports by weakly supervised learning of the entities and relations from reports. RRG is a BERT-based model with DenseNet-121 [36] as the visual extractor. RRG leverages RadGraph [31] to extract the entity and relation labels in a report and uses reinforcement learning to optimize the model. The reward assesses the consistency and completeness of the entity and relation sets between generated reports and reference radiology reports. RRG addresses label imbalance by maximizing the reward for predicting the more complicated entities and relations in abnormal samples.
TIMER [10] aims to decrease over-fitting on frequent tokens by introducing an unlikelihood loss that penalizes errors on these tokens. The token set of the unlikelihood loss is automatically adjusted by maximizing the average F1 score over tokens of different frequencies.
RGRG [35] adopts GPT-2 as the language generation model and generates a report based on the localized visual features of anatomical regions, which are extracted by an object detector. This baseline was run only on the MIMIC-CXR dataset, as the IU X-ray dataset lacks anatomical region information and therefore cannot be used to train an object detection module effectively.
Imbalance Setting
We evaluate model robustness under token and label imbalance settings and present the results in Sects. 6.2 and 6.3. For token imbalance, we compare F1 scores on frequent and infrequent tokens separately. We use three scales to define the frequent token set: 1/4, 1/6, and 1/8. Each split treats the top 1/4, 1/6, or 1/8 of the vocabulary as frequent tokens and the rest of the vocabulary as infrequent tokens. This setting demonstrates the effectiveness of our approach in adapting to token imbalance. For label imbalance, we divide our samples into two categories, normal and abnormal.

Table 2. Overall performance

| Dataset | Model | BLEU-1 | BLEU-2 | BLEU-3 | BLEU-4 | METEOR | ROUGE-L | F1 (CE) |
|---|---|---|---|---|---|---|---|---|
| IU X-ray | R2Gen | 48.80 | 31.93 | 23.24 | 17.72 | 20.21 | 37.10 | 63.62 |
| IU X-ray | CMN | 45.53 | 29.50 | 21.47 | 16.53 | 18.99 | 36.78 | 64.83 |
| IU X-ray | WCL | 44.74 | 29.30 | 21.49 | 16.79 | 20.45 | 37.11 | 49.24 |
| IU X-ray | CMM + RL | 49.40 | 30.08 | 21.45 | 16.10 | 20.10 | 38.40 | 40.79 |
| IU X-ray | RRG | 49.96 | 31.44 | 22.11 | 17.05 | 18.81 | 33.46 | 49.10 |
| IU X-ray | TIMER | 49.34 | 32.49 | 23.84 | 18.61 | 20.38 | 38.25 | 94.52 |
| IU X-ray | JIMA (Ours) | **50.50** | **33.12** | **24.15** | **18.88** | **21.16** | **38.56** | **96.58** |
| IU X-ray | $\overline{\Delta }$ (%) | 5.49 | 7.74 | 8.65 | 10.44 | 6.86 | 4.86 | 72.10 |
| IU X-ray | $\hat{\Delta }$ (%) | 2.35 | 1.93 | 1.30 | 1.45 | 3.82 | 0.81 | 2.18 |
| MIMIC-CXR | R2Gen | 35.42 | 21.99 | 14.50 | 10.30 | 13.75 | 27.24 | 54.60 |
| MIMIC-CXR | CMN | 35.60 | 21.41 | 14.07 | 9.91 | 14.18 | 27.14 | 50.50 |
| MIMIC-CXR | WCL | 37.30 | 23.13 | 15.49 | 10.70 | 14.40 | 27.39 | 55.58 |
| MIMIC-CXR | CMM + RL | 35.35 | 21.80 | 14.82 | 10.58 | 14.20 | 27.37 | 65.43 |
| MIMIC-CXR | RRG | 37.57 | 19.78 | 15.87 | 9.56 | 14.77 | 26.81 | 62.20 |
| MIMIC-CXR | TIMER | 38.30 | 22.49 | 14.60 | 10.40 | 14.70 | 28.00 | 75.86 |
| MIMIC-CXR | RGRG | 30.7 | 20.59 | 14.10 | 10.18 | 15.43 | 24.03 | 80.28 |
| MIMIC-CXR | JIMA (Ours) | **41.37** | **24.83** | **16.72** | **11.20** | **16.75** | **30.15** | **81.25** |
| MIMIC-CXR | $\overline{\Delta }$ (%) | 16.26 | 15.24 | 13.34 | 9.59 | 15.73 | 12.52 | 31.29 |
| MIMIC-CXR | $\hat{\Delta }$ (%) | 8.02 | 10.40 | 14.52 | 7.69 | 13.94 | 7.68 | 1.21 |

$\overline{\Delta }$ is the averaged percentage improvement over the baselines, and $\hat{\Delta }$ is the percentage improvement over the state-of-the-art model. The evaluation includes both generation (NLG) metrics (BLEU-1 to ROUGE-L) and a clinical-correctness (CE) metric (F1); bold numbers indicate the best performance.
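To make the two improvement rows of Table 2 concrete, the sketch below recomputes them for the IU X-ray BLEU-1 column. It assumes a relative gain of the form (ours − baseline) / baseline, averaged over the six IU X-ray baselines; taking TIMER as the state-of-the-art reference is also our assumption, made because it is the baseline whose gain matches the reported value.

```python
# BLEU-1 scores on IU X-ray, copied from Table 2.
baselines = {
    "R2Gen": 48.80,
    "CMN": 45.53,
    "WCL": 44.74,
    "CMM + RL": 49.40,
    "RRG": 49.96,
    "TIMER": 49.34,
}
jima = 50.50


def pct_gain(ours, theirs):
    """Relative improvement of `ours` over `theirs`, in percent (assumed formula)."""
    return (ours - theirs) / theirs * 100.0


# Averaged improvement over all baselines.
avg_gain = sum(pct_gain(jima, score) for score in baselines.values()) / len(baselines)

# Improvement over the assumed state-of-the-art reference (TIMER).
gain_over_timer = pct_gain(jima, baselines["TIMER"])

print(round(avg_gain, 2), round(gain_over_timer, 2))
```

Under these assumptions the two values round to 5.49 and 2.35, which agree with the BLEU-1 entries of the two improvement rows for IU X-ray.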
Results and Analysis
In this section, we present the overall performance, report the results of the imbalance evaluations, and include an ablation analysis and a case study. Generally, JIMA outperforms the state-of-the-art baselines by a large margin, especially under imbalance settings. Our qualitative studies show that our method achieves higher clinical accuracy and generates more precise clinical terms.
Overall Performance
Table 2 presents the performance of JIMA on NLG and clinical-correctness metrics. JIMA outperforms the baseline models (both imbalance and regular methods) on BLEU scores by a large margin, confirming the validity of selecting training samples with our curriculum learning method. The approach enables the model to learn multiple times from the samples with lower BLEU-4, resulting in better performance than the baseline models. For example, averaged over all metrics, JIMA improves on the baselines by 16.59% on IU X-ray and by 16.28% on MIMIC-CXR. We attribute this to tasks 1 and 2 working jointly to mitigate the token and label imbalance problems.
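As a check on the arithmetic, the two averages quoted above coincide with the mean of the seven averaged-improvement entries that Table 2 reports for each dataset (six NLG metrics plus the CE F1). Reading them this way is our interpretation, not a definition given in the text:

$$\begin{aligned} \text{IU X-ray: } & \frac{5.49 + 7.74 + 8.65 + 10.44 + 6.86 + 4.86 + 72.10}{7} = \frac{116.14}{7} \approx 16.59 \\ \text{MIMIC-CXR: } & \frac{16.26 + 15.24 + 13.34 + 9.59 + 15.73 + 12.52 + 31.29}{7} = \frac{113.97}{7} \approx 16.28 \end{aligned}$$

Both averages therefore include the F1 column, which contributes the largest single term for each dataset.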
Second, our model achieves the best F1 on the clinical metric, which indicates that task 1 (Sect. 4.1) enables the model to pay more attention to difficult samples with lower F1 scores. Additionally, our method promotes clinical token prediction, as performance on infrequent tokens and medical terms improves. For example, our generation outperforms the baselines on F1 by 72.10% on IU X-ray and by 31.29% on MIMIC-CXR on average. CMM + RL performs better than other baselines on IU X-ray but not on MIMIC-CXR, whereas JIMA maintains stable performance on both IU X-ray and MIMIC-CXR. We attribute this to our joint imbalance adaptation, which yields more consistent improvements.
Token Imbalance
Table 3 compares F1 scores on high- and low-frequency tokens under different ratio splits. Our method consistently outperforms the baselines on low-frequency tokens across the frequency splits ($\frac{1}{4}$, $\frac{1}{6}$, and $\frac{1}{8}$) on IU X-ray and MIMIC-CXR. While the RRG and CMM + RL approaches adapt to label imbalance, they may not be able to adapt to token imbalance, and our approach achieves better performance in this setting. Generating rare tokens accurately remains difficult despite the high performance achieved on frequent tokens: common tokens are prone to overfitting, while rare tokens are predicted with less precision. For example, R2Gen scores 0.00 on the infrequent tokens of MIMIC-CXR under the 1/4 split, that is, on the remaining 3/4 of the vocabulary. This performance imbalance can deteriorate the clinical correctness of generated reports, as medical terms are usually infrequent. Nonetheless, our joint imbalance adaptation approach shows considerable improvements in this area, indicating a promising direction for enhancing the robustness of radiology report generation, a critical clinical task.

Table 3. Results on high- and low-frequent tokens with three ratio splits

| Ratio | Method | IU X-ray infreq | IU X-ray freq | MIMIC-CXR infreq | MIMIC-CXR freq |
|---|---|---|---|---|---|
| 1/8 | R2Gen | 4.46 | 62.73 | 2.52 | 52.01 |
| 1/8 | CMN | 5.88 | 55.86 | 2.23 | 45.60 |
| 1/8 | WCL | 5.29 | 60.23 | 2.91 | 48.60 |
| 1/8 | CMM + RL | 5.19 | 49.36 | 0.21 | 23.64 |
| 1/8 | RRG | 7.28 | 41.94 | 2.50 | 43.57 |
| 1/8 | TIMER | 13.23 | 61.89 | 3.15 | 52.66 |
| 1/8 | RGRG | - | - | 0.22 | 31.33 |
| 1/8 | JIMA (ours) | **14.87** | 62.55 | **3.58** | **53.06** |
| 1/6 | R2Gen | 2.80 | 61.62 | 2.02 | 49.86 |
| 1/6 | CMN | 5.75 | 65.12 | 0.85 | 52.02 |
| 1/6 | WCL | 3.72 | 59.26 | 2.13 | 47.88 |
| 1/6 | CMM + RL | 5.19 | 49.36 | 0.14 | 23.36 |
| 1/6 | RRG | 4.55 | 40.46 | 2.09 | 43.56 |
| 1/6 | TIMER | 5.93 | 67.79 | 2.02 | 51.72 |
| 1/6 | RGRG | - | - | 0.26 | 30.66 |
| 1/6 | JIMA (ours) | **10.52** | **68.82** | **2.83** | **52.32** |
| 1/4 | R2Gen | 1.16 | 59.98 | 0.00 | 48.77 |
| 1/4 | CMN | 2.60 | 63.92 | 0.33 | 51.09 |
| 1/4 | WCL | 1.50 | 56.83 | 0.30 | 46.95 |
| 1/4 | CMM + RL | 5.19 | 49.36 | 0.07 | 23.05 |
| 1/4 | RRG | 2.04 | 38.84 | 0.39 | 41.45 |
| 1/4 | TIMER | 8.66 | 64.00 | 0.58 | 51.39 |
| 1/4 | RGRG | - | - | 0.20 | 29.56 |
| 1/4 | JIMA (ours) | **9.77** | **66.23** | **0.94** | **51.92** |

We measured the model performance on frequent and infrequent tokens by F1 score.
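The frequent/infrequent vocabulary split behind Table 3 can be sketched as below. The details are assumptions for illustration: tokens are ranked by corpus count, ties are broken alphabetically, and the cut-off is rounded down with at least one frequent token. The function name `split_vocabulary` and the toy reports are hypothetical.

```python
from collections import Counter


def split_vocabulary(tokenized_reports, frequent_ratio):
    """Split the vocabulary into (frequent, infrequent) token sets.

    The top `frequent_ratio` share of the vocabulary, ranked by corpus
    frequency, is treated as frequent; everything else is infrequent.
    """
    counts = Counter(tok for report in tokenized_reports for tok in report)
    # Sort by descending count, then alphabetically for a deterministic order.
    ranked = sorted(counts, key=lambda tok: (-counts[tok], tok))
    cut = max(1, int(len(ranked) * frequent_ratio))
    return set(ranked[:cut]), set(ranked[cut:])


reports = [
    ["no", "acute", "disease"],
    ["no", "pleural", "effusion"],
    ["no", "acute", "pneumothorax"],
    ["heart", "size", "normal"],
]

for ratio in (1 / 4, 1 / 6, 1 / 8):
    frequent, infrequent = split_vocabulary(reports, ratio)
    print(ratio, sorted(frequent), len(infrequent))
```

A smaller ratio moves more of the vocabulary into the infrequent set, so the infrequent column of each split covers a progressively larger and more mixed group of tokens.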
Label Imbalance
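We report NLG evaluations under label imbalance (normal vs. abnormal) in Table 4. JIMA outperforms the baseline models on both the normal and abnormal splits, which demonstrates its effectiveness under label imbalance. JIMA also performs better than the label imbalance methods, RRG and CMM + RL, indicating that joint imbalance adaptation is a promising direction for improving model robustness. It is worth noting that models generally perform better on normal samples than on abnormal ones. We attribute this to two reasons: (1) abnormal reports contain more infrequent medical tokens and (2) abnormal reports are longer, as discussed in Sect. 3. JIMA shows larger improvements over the baselines on abnormal samples while maintaining similar performance on samples with normal labels. These observations suggest that our approach can successfully learn from lengthier documents with more medical tokens.

Table 4. Label imbalance evaluation with binary label types, normal and abnormal

| Dataset | Label | Model | BLEU-1 | BLEU-2 | BLEU-3 | BLEU-4 | METEOR | ROUGE-L |
|---|---|---|---|---|---|---|---|---|
| IU X-ray | Normal | R2Gen | 50.50 | 34.91 | 25.86 | 20.93 | 23.66 | 40.56 |
| IU X-ray | Normal | CMN | 47.42 | 32.80 | 25.25 | 18.72 | 20.51 | 38.69 |
| IU X-ray | Normal | WCL | 49.74 | 35.44 | 28.02 | 18.71 | 26.88 | 42.09 |
| IU X-ray | Normal | CMM + RL | 51.68 | 36.65 | 21.99 | 19.47 | 24.53 | 40.05 |
| IU X-ray | Normal | RRG | 50.03 | 33.76 | 24.81 | 19.89 | 20.43 | 34.39 |
| IU X-ray | Normal | TIMER | 51.83 | 32.43 | 33.71 | 20.19 | 24.43 | 39.39 |
| IU X-ray | Normal | JIMA (ours) | **52.65** | **37.06** | **28.39** | **21.56** | **27.20** | **42.33** |
| IU X-ray | Abnormal | R2Gen | 42.67 | 27.86 | 18.47 | 12.35 | 15.04 | 30.10 |
| IU X-ray | Abnormal | CMN | 35.09 | 21.42 | 14.97 | 11.32 | 14.36 | 29.85 |
| IU X-ray | Abnormal | WCL | 32.31 | 19.93 | 13.87 | 10.50 | 13.81 | 30.37 |
| IU X-ray | Abnormal | CMM + RL | 38.09 | 25.42 | 11.17 | 15.09 | 13.13 | 27.64 |
| IU X-ray | Abnormal | RRG | 43.38 | 23.44 | 10.02 | 15.58 | 12.43 | 31.52 |
| IU X-ray | Abnormal | TIMER | 44.25 | 26.73 | 15.28 | 10.76 | 15.43 | 33.26 |
| IU X-ray | Abnormal | JIMA (ours) | **45.41** | **27.95** | **19.15** | **15.68** | **16.36** | **34.59** |
| MIMIC-CXR | Normal | R2Gen | 40.42 | 26.76 | 19.75 | 15.60 | 17.58 | 32.02 |
| MIMIC-CXR | Normal | CMN | 41.42 | 27.80 | 20.25 | 15.72 | 17.51 | 33.69 |
| MIMIC-CXR | Normal | WCL | 39.74 | 25.44 | 18.02 | 13.71 | 16.88 | 32.09 |
| MIMIC-CXR | Normal | CMM + RL | 17.50 | 10.11 | 6.83 | 14.99 | 8.05 | 19.10 |
| MIMIC-CXR | Normal | RRG | 38.78 | 21.63 | 18.04 | 12.09 | 18.27 | 27.56 |
| MIMIC-CXR | Normal | TIMER | 40.33 | 27.53 | 19.88 | 14.87 | 17.47 | 33.08 |
| MIMIC-CXR | Normal | RGRG | 32.09 | 22.67 | 16.40 | 12.30 | 18.26 | 27.28 |
| MIMIC-CXR | Normal | JIMA (ours) | **41.79** | **27.87** | **20.49** | **16.00** | **17.93** | **33.87** |
| MIMIC-CXR | Abnormal | R2Gen | 33.97 | 19.31 | 12.07 | 10.17 | 10.98 | 26.82 |
| MIMIC-CXR | Abnormal | CMN | 33.00 | 19.44 | 10.02 | 8.73 | 10.21 | 25.16 |
| MIMIC-CXR | Abnormal | WCL | 34.56 | 22.45 | 14.63 | 10.26 | 12.43 | 26.87 |
| MIMIC-CXR | Abnormal | CMM + RL | 27.74 | 10.87 | 5.18 | 3.43 | 6.11 | 16.08 |
| MIMIC-CXR | Abnormal | RRG | 17.47 | 9.71 | 5.78 | 3.74 | 8.37 | 17.59 |
| MIMIC-CXR | Abnormal | TIMER | 35.66 | 21.83 | 14.25 | 14.87 | 9.84 | 26.77 |
| MIMIC-CXR | Abnormal | RGRG | 30.54 | 20.34 | 13.82 | 9.92 | 15.13 | 23.66 |
| MIMIC-CXR | Abnormal | JIMA (ours) | **37.81** | **22.46** | **15.26** | **10.28** | **14.56** | **27.38** |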
Ablation Analysis
In this section, we carry out ablation experiments to analyze the impact of our curriculum learning approach on tokens and labels with different frequencies. To investigate the performance across different tokens, we categorize tokens into five groups based on their frequency, with “0” representing the most frequent tokens and “4” representing the least frequent tokens. Each group contains an equal number of tokens. To compare the performance across different labels, we present the performance on each label individually. We conduct our ablation experiments on the MIMIC-CXR dataset, and the results are depicted in Fig. 4.

Fig. 4. Ablation analysis: JIMA performance with and without curriculum learning across label and token frequencies

Table 5. Human evaluation on the state-of-the-art baseline and our approach

| Dataset | Label | CMM+RL | Same | JIMA (Ours) |
|---|---|---|---|---|
| IU X-ray | Normal | 6 / 7 | 12 / 7 | 6 / 10 |
| IU X-ray | Abnormal | 4 / 4 | 10 / 5 | 12 / 13 |
| MIMIC-CXR | Normal | 6 / 7 | 15 / 7 | 7 / 11 |
| MIMIC-CXR | Abnormal | 5 / 6 | 10 / 7 | 7 / 16 |
| Overall | Normal | 12 / 14 | 27 / 14 | 13 / 21 |
| Overall | Abnormal | 9 / 10 | 20 / 12 | 19 / 29 |
| Overall | All | 21 / 24 | 47 / 26 | 32 / 50 |

The baseline does not consider the imbalance effects. To better illustrate the varying performance on the labels, we report performance conditioned on the normal and abnormal reports. “Same” means the experts vote the same for the generated reports.
First, we notice that removing curriculum learning does not have a significant detrimental impact on highly frequent tokens and labels. For instance, the performance is comparable in the “0” token group and the “0–5” label groups. Curriculum learning leads the model to allocate more attention to challenging samples, thereby reducing the predicted likelihood on highly frequent samples. However, our curriculum learning strategy selects training samples based on the ranking of the correct answers. Therefore, despite the reduced probability of the correct answer, its ranking remains unchanged (for example, the correct option still holds the highest estimate). As a result, our curriculum learning approach does not diminish the performance on highly frequent samples. Next, our curriculum learning approach enhances performance primarily on moderately frequent samples. The average improvement amounts to 6.49% in the “1–3” token group and 2.58% in the “6–10” label group. However, our method exhibits limitations in enhancing the performance on exceedingly rare tokens. Notably, the model struggles to predict tokens in the “4” group.
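The ranking argument above can be illustrated with a toy example in which the probability of a frequent correct token drops while its rank stays the same. The two distributions are invented for illustration, and the rank is assumed to count the vocabulary entries with a strictly higher probability than the correct token.

```python
def target_rank(probs, target):
    """Assumed rank: number of vocabulary entries that outrank the correct token."""
    return sum(1 for p in probs if p > probs[target])


correct = 0  # index of the correct (frequent) token

# Hypothetical predictions for the same position under two training regimes.
without_curriculum = [0.90, 0.06, 0.04]
with_curriculum = [0.60, 0.25, 0.15]

rank_without = target_rank(without_curriculum, correct)
rank_with = target_rank(with_curriculum, correct)

print(without_curriculum[correct], rank_without)
print(with_curriculum[correct], rank_with)
```

The probability of the correct token falls from 0.90 to 0.60, yet it remains the top-ranked entry in both cases, so a rank-based selection rule treats the sample as equally easy and the greedy prediction is unchanged.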
Human Evaluation
To verify factual correctness, we invite two health professionals to perform the evaluation. First, we randomly select 50 test instances from each of the IU X-ray and MIMIC-CXR datasets. We choose CMM+RL as our comparison target, as the model is the best-performing baseline by automatic metrics. In the evaluation, we show the X-ray images, the corresponding ground-truth reports, and two generated reports (one from our model and the other from CMM+RL) to the experts without disclosing their sources. The experts select the better description from the two candidate reports or choose the “Same” option if both reports are of similar quality.
We present our human evaluation results in Table 5, which are consistent with the automatic evaluation results in Table 2. Overall, JIMA is preferred over the baseline in 11 more reports. Notably, our approach exhibits its largest improvements on abnormal samples: although JIMA has only one more vote than the baseline on normal samples, it secures ten more votes on abnormal samples. Because abnormal samples have lengthier reports on average and encompass more medical entities, this suggests that our approach generates more clinically precise reports.
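The vote margins quoted above can be recomputed from the Overall rows of Table 5. The sketch assumes that the margins refer to the first of the two numbers in each cell; this is our reading, chosen because those numbers reproduce the stated differences.

```python
# First number of each cell in the "Overall" rows of Table 5 (assumed reading).
votes = {
    "Normal": {"CMM+RL": 12, "Same": 27, "JIMA": 13},
    "Abnormal": {"CMM+RL": 9, "Same": 20, "JIMA": 19},
}

# Extra votes for JIMA over the baseline, per label type.
margins = {label: row["JIMA"] - row["CMM+RL"] for label, row in votes.items()}

# Overall margin and total number of judged reports.
total_margin = sum(margins.values())
total_reports = sum(sum(row.values()) for row in votes.values())

print(margins, total_margin, total_reports)
```

The per-label margins are 1 and 10, which sum to the overall margin of 11 across the 100 judged reports.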
Case Study
To verify our model’s effectiveness in generating clinically correct descriptions, we perform a case study in this section and present the result in Fig. 5. We select four samples from the IU X-ray and MIMIC-CXR datasets and compare performance on the normal and abnormal samples separately. The correct pathological and anatomical entity predictions are highlighted in blue. Generally, our predictions cover more than 90% of the entities in the reference reports. Compared to normal samples, abnormal samples have longer descriptions and contain more complex entities. These entities are usually rare in the corpus and are under-fitted by models, so models underperform on abnormal samples. However, JIMA captures most of the entities in all kinds of samples and achieves similar performance on both normal and abnormal samples, which supports our model’s effectiveness in improving the factual completeness and correctness of generated radiology reports.

Fig. 5. Qualitative comparison between JIMA and CMM+RL. We highlight correct predictions of pathological and anatomical entities in blue
Conclusion
In this study, we have examined the effects of two imbalance factors, label and token, on radiology report generation models. We developed the Joint Imbalance Adaptation (JIMA) model to counter this multifaceted imbalance challenge. JIMA takes a novel curriculum learning approach to jointly learn the imbalance patterns of radiology tokens and image labels through two modules, a difficulty measurer and a training scheduler. Extensive experiments, ablation analysis, and human evaluations show that JIMA improves over the existing state-of-the-art baseline by 0.81% to 3.82% on IU X-ray and by 1.21% to 14.52% on MIMIC-CXR. Our approach also promotes model robustness in handling token and label imbalance, as shown in Tables 3 and 4. In particular, our ablation analysis shows that JIMA does not significantly reduce performance on highly frequent tokens and labels and significantly improves performance on moderately frequent samples, but still exhibits limitations in enhancing performance on the rarest tokens. This study makes a unique contribution to radiology report generation by jointly considering multiple imbalance factors via curriculum learning. Our future work will focus on refining the JIMA approach to address the limitations highlighted in our ablation study, including low model performance on exceedingly rare tokens. Further exploration will also include a deeper analysis of other imbalance factors, such as different demographic groups.
Limitations
Several limitations should be acknowledged when interpreting this study. Rare tokens. Our approach improves the model performance on rare tokens, but their F1 scores remain considerably lower than those of frequent tokens. For example, the JIMA model achieves an F1 score of 14.87 on the infrequent tokens versus 62.55 on the frequent tokens under the 1/8 ratio split of IU X-ray. Our ongoing work will explore whether other imbalance learning approaches (e.g., data augmentation [24]), alone or combined with JIMA, can achieve better performance. Evaluation. We are aware of other evaluation metrics, such as RadGraph [31] and CheXpert [37]. However, these additional metrics may only be applicable to MIMIC-CXR or may overlap with our existing metrics, as CheXpert does with CheXbert [34]. We have included diverse metrics, covering NLG, clinical correctness, and human evaluations. To stay consistent with our state-of-the-art baselines, we use a similar evaluation schema. The consistent observations between our human and automatic evaluations also support the validity of our evaluation.
References
- 1. Wang Z, Liu L, Wang L, Zhou L (2023) METransformer: radiology report generation by transformer with multiple learnable expert tokens. In: Proceedings of the IEEE/CVF conference on Computer Vision and Pattern Recognition (CVPR), pp 11558–11567. https://openaccess.thecvf.com/content/CVPR2023/html/Wang_METransformer_Radiology_Report_Generation_by_Transformer_With_Multiple_Learnable_Expert_CVPR_2023_paper.html
- 2. Wu Y, Huang I-C, Huang X (2023) Token imbalance adaptation for radiology report generation. In: Mortazavi BJ, Sarker T, Beam A, Ho JC (eds) Proceedings of the conference on health, inference, and learning, vol 209, pp 72–85. PMLR, Boston, MA. https://proceedings.mlr.press/v209/wu23a.html
- 3. Endo M, Krishnan R, Krishna V, Ng AY, Rajpurkar P (2021) Retrieval-based chest x-ray report generation using a pre-trained contrastive language-image model. In: Roy S, Pfohl S, Rocheteau E, Tadesse GA, Oala L, Falck F, Zhou Y, Shen L, Zamzmi G, Mugambi P, Zirikly A, McDermott MBA, Alsentzer E (eds) Proceedings of machine learning for health, vol 158, pp 209–219. PMLR, Virtual. https://proceedings.mlr.press/v158/endo21a.html
- 4. Jeong J, Tian K, Li A, Hartung S, Adithan S, Behzadi F, Calle J, Osayande D, Pohlen M, Rajpurkar P (2024) Multimodal image-text matching improves retrieval-based chest x-ray report generation. In: Oguz I, Noble J, Li X, Styner M, Baumgartner C, Rusu M, Heinmann T, Kontos D, Landman B, Dawant B (eds) Medical imaging with deep learning. Proceedings of machine learning research, vol 227, pp 978–990. PMLR, Paris, France. https://proceedings.mlr.press/v227/jeong24a.html
- 5. Zhou T, Wang S, Bilmes J (2020) Curriculum learning by dynamic instance hardness. In: Larochelle H, Ranzato M, Hadsell R, Balcan MF, Lin H (eds) Advances in neural information processing systems, vol 33, pp 8602–8613. Curran Associates, Inc., Virtual. https://proceedings.neurips.cc/paper_files/paper/2020/file/62000dee5a05a6a71de3a6127a68778a-Paper.pdf
- 6. Jain S, Agrawal A, Saporta A, Truong S, Duong DN, Bui T, Chambon P, Zhang Y, Lungren M, Ng A, Langlotz C, Rajpurkar P (2021) RadGraph: extracting clinical entities and relations from radiology reports. In: Vanschoren J, Yeung S (eds) Proceedings of the neural information processing systems track on datasets and benchmarks, vol 1. Virtual. https://datasets-benchmarks-proceedings.neurips.cc/paper/2021/file/c8ffe9a587b126f152ed3d89a146b445-Paper-round1.pdf
- 7. Song X, Zhang X, Ji J, Liu Y, Wei P (2022) Cross-modal contrastive attention model for medical report generation. In: Proceedings of the 29th international conference on computational linguistics, pp 2388–2397. International Committee on Computational Linguistics, Gyeongju, Republic of Korea. https://aclanthology.org/2022.coling-1.210
- 8. Tanida T, Müller P, Kaissis G, Rueckert D (2023) Interactive and explainable region-guided radiology report generation. In: Proceedings of the IEEE/CVF conference on Computer Vision and Pattern Recognition (CVPR), Vancouver, Canada, pp 7433–7442. https://openaccess.thecvf.com/content/CVPR2023/papers/Tanida_Interactive_and_Explainable_Region-Guided_Radiology_Report_Generation_CVPR_2023_paper.pdf
