TransMed: Transformers Advance Multi-Modal Medical Image Classification

Over the past decade, convolutional neural networks (CNNs) have shown very competitive performance in medical image analysis tasks, such as disease classification, tumor segmentation, and lesion detection. CNNs have great advantages in extracting local image features. However, due to the locality of the convolution operation, they cannot handle long-range relationships well. Recently, transformers have been applied to computer vision and have achieved remarkable success on large-scale datasets. Compared with natural images, multi-modal medical images have explicit and important long-range dependencies, and effective multi-modal fusion strategies can greatly improve the performance of deep models. This prompts us to study transformer-based structures and apply them to multi-modal medical images. Existing transformer-based network architectures require large-scale datasets to achieve good performance. However, medical imaging datasets are relatively small, which makes it difficult to apply pure transformers to medical image analysis. Therefore, we propose TransMed for multi-modal medical image classification. TransMed combines the advantages of CNNs and transformers to efficiently extract low-level image features and establish long-range dependencies between modalities. We evaluated our model on two datasets, for parotid gland tumor classification and knee injury classification. Combining our contributions, we achieve improvements of 10.1% and 1.9% in average accuracy, respectively, outperforming other state-of-the-art CNN-based models. The results of the proposed method are promising and show considerable potential for application to a large number of medical image analysis tasks. To the best of our knowledge, this is the first work to apply transformers to multi-modal medical image classification.


I. INTRODUCTION
Transformers were first applied in the field of natural language processing (NLP) [1]. The transformer is a deep neural network that relies mainly on the self-attention mechanism to extract intrinsic features. Because of its powerful representation capabilities, researchers have sought ways to apply transformers to computer vision tasks. Compared with text, images are larger, noisier, and contain more redundant modalities, so applying transformers to these tasks is considered more difficult. Recently, transformers have made a breakthrough in computer vision. A large number of transformer-based methods have been proposed for computer vision tasks, such as DETR [2] for object detection, SETR [3] for semantic segmentation, and ViT [4] and DeiT [5] for image recognition.
Transformers have achieved success on natural images, but they have received little attention in medical image analysis, especially in multi-modal medical image fusion. Multi-modal images are widely used in medical image analysis for lesion segmentation and disease classification. Existing deep-learning-based multi-modal fusion methods for medical images can be divided into three categories: input-level fusion, feature-level fusion, and decision-level fusion [6]. The input-level fusion strategy feeds the multi-modal images into a single deep network as multiple channels, learns a fused feature representation, and then trains the network. Input-level fusion retains the original image information to the maximum extent while learning the image features. The feature-level fusion strategy trains a separate deep network for each modality, taking the image of that modality as its input. The resulting representations are fused in the network layers, and the fused features are fed to the decision layer to obtain the final result. A feature-level fusion network can effectively capture the information of different modalities of the same patient. Decision-level fusion integrates the output of each network to obtain the final result. A decision-level fusion network aims to learn richer information from the different modalities independently.
However, all three strategies have shortcomings to varying degrees. With the input-level fusion strategy, it is difficult to establish the internal relationship between different modalities of the same patient, which degrades model performance. In a feature-level fusion network, each modality corresponds to its own neural network, which brings huge computational costs, especially when the number of modalities is large. In decision-level fusion, the outputs for the individual modalities are independent of each other, so the model cannot establish the internal relationship between different modalities of the same patient. In addition, like the feature-level fusion strategy, the decision-level fusion strategy is computationally intensive.
Therefore, there is an urgent need to combine the three fusion strategies efficiently. A good multi-modal fusion strategy should achieve as much interaction between different modalities as possible with low computational complexity.
Compared with CNNs, transformers can effectively mine long-range relationships between sequences. Existing transformer-based computer vision models mainly deal with 2D natural images, such as ImageNet [6] and other large-scale datasets. Sequences are constructed from 2D images by cutting the images into a series of patches. Sequences constructed in this way carry long-range dependencies only implicitly, which is not very intuitive, so it may be difficult to obtain a significant performance improvement.
In contrast, medical images contain more explicit sequences, which carry important long-range dependencies and semantic information, as shown in Fig. 1. Due to the similarity of human organs, most visual representations in medical images are ordered. Destroying these sequences significantly reduces the performance of the model. Compared with natural images, the sequence relationships of medical images (such as modality, slice, and patch) can be considered to hold richer information. In practice, doctors synthesize the pathological information of each modality to make a diagnosis. However, existing multi-modal fusion methods are too simple to consider the correlation of these sequences and do not model these long-range dependencies. The transformer is an elegant, efficient, and powerful encoder for processing sequence relations, which motivates us to propose a multi-modal medical image classification method based on transformers.
In this work, we present the first study to explore the potential of transformers in the context of multi-modal medical image classification. The proposed method is inspired by the effectiveness of the transformer in extracting relationships between sequences. However, due to the small scale of medical image datasets and the lack of sufficient information to establish relationships between low-level semantic features, the performance of pure transformer networks based on ViT and DeiT is not satisfactory in multi-modal medical image classification. Therefore, we propose TransMed, which combines the advantages of CNNs and transformers to capture low-level features and cross-modality high-level connections. TransMed first processes the multi-modal images as sequences and sends them to a CNN, then uses transformers to learn the relationships between the sequences and make predictions. Since the transformer effectively models the global features of multi-modal images, TransMed outperforms existing multi-modal fusion methods in terms of parameters, operation speed, and accuracy. Extensive experiments demonstrate the effectiveness of our method.
In summary, we make the following three contributions: 1) We apply transformers to medical image classification for the first time, and greatly improve the accuracy and efficiency of deep models. 2) We propose a novel multi-modal image fusion strategy, which can be leveraged to capture mutual information from images of different modalities more efficiently. 3) Experimental evaluations demonstrate that the proposed method achieves state-of-the-art performance in the classification of parotid gland tumors.
The rest of this paper is organized as follows. Section II presents closely related work. Section III describes the pipeline of our proposed method. Section IV introduces the experimental details and results. Finally, we summarize our work in Section V.

II. RELATED WORK

A. MULTI-MODAL MEDICAL IMAGE ANALYSIS
Multi-modal medical analysis is one of the most fundamental and challenging parts of medical image analysis. A reasonable fusion of different modalities has been shown to be a promising means of enhancing deep networks [6]. Multi-modal fusion can capture richer pathological information and improve the quality of diagnosis.
The works in [8]-[10] mainly used input-level fusion, which is the most common fusion method in multi-modal medical image analysis. Some other papers have shown the potential of feature-level fusion in medical image processing. HyperDenseNet built dual deep networks for different modalities of magnetic resonance imaging (MRI) and linked features across these streams [11]. [12] fused the final features from modality-specific paths to make final decisions. MMFNet used specific encoders to capture modality-specific features and designed a decoder with a complex structure to fuse these features [13]. Different from the first two techniques, [14] and [15] applied decision-level fusion to improve performance. [15] set up three modality-specific encoders to capture low-level features and a decoder to fuse low-level and high-level features; the results of each branch were then fused to generate the final result. [14] designed a gate network to dynamically combine each decision and make a prediction.
In addition, some studies have evaluated multiple fusion methods at the same time. [16] used both feature-level fusion and decision-level fusion. [17] designed three kinds of fusion networks and obtained better performance than with a single modality. These fusion methods improve model performance to a certain extent, but they have shortcomings such as poor scalability, high computational complexity, and difficulty in establishing long-range connections.

B. TRANSFORMERS
Transformers were first proposed for machine translation and achieved satisfactory results in a large number of NLP tasks. Transformer structures were then introduced into the field of computer vision, with modifications made for specific tasks. The results show the potential of transformers to surpass pure CNNs. Some works use a hybrid framework of CNN and transformer [2], [18], while others directly use pure transformers to replace the CNN [2], [4], [5], [19]. These methods have shown encouraging results in computer vision tasks, but applying them directly to multi-modal medical images is not effective and requires a lot of computing resources. As far as we know, TransMed is the first multi-modal medical image classification framework based on transformers, and it provides a novel multi-modal image fusion strategy.

III. METHODS
The most common method of multi-modal medical image classification is to train a CNN directly (such as ResNet [20]). The image is first encoded as a high-level feature representation, and then its features or decisions are fused. Different from existing methods, our method uses transformers to introduce the self-attention mechanism into the multi-modal fusion strategy. We first introduce how to directly apply transformers to aggregate feature representations from decomposed image patches in Section III.A. The overall framework of TransMed is then described in detail in Section III.B.

A. TRANSFORMERS AGGREGATE MULTI-MODAL FEATURES
In this work, we follow the original DeiT implementation as closely as possible. The advantage of this intentionally simple setting is that it reduces the impact of other tricks on model performance and shows the benefits of transformers directly. In addition, we can use the scalable DeiT models and their pre-trained weights almost immediately. The important components of the transformer include self-attention (SA), multi-head self-attention (MSA), and the multi-layer perceptron (MLP). The input of the transformer includes a variety of embeddings and tokens. Slightly different from DeiT, we remove the linear projection layer and the distillation token. We describe each of these components in this section.

1) Self-Attention
SA is an attention mechanism that uses other parts of the same sample to predict the rest of the sample. In computer vision, it is somewhat similar to non-local networks [21]. SA has many forms, and the common transformer relies on the scaled dot-product form shown in Figure 3. In the SA layer, the input is first transformed into three different matrices: the query matrix Q, the key matrix K, and the value matrix V. The output is the weighted sum of the value vectors, where the weight assigned to each value is determined by the dot product of the query and the corresponding key. The attention function between different input vectors is calculated as follows:
$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^{T}}{\sqrt{d_k}}\right)V \tag{1}$$
where $d_k$ is the dimension of the key vector $k$. The factor $\sqrt{d_k}$ provides an appropriate normalization that makes the gradient more stable.
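As an illustration of the scaled dot-product form described above, the following is a minimal NumPy sketch rather than the implementation used in this work. The function names, the toy sizes (4 tokens, 8 features), and the random projection matrices `w_q`, `w_k`, `w_v` are assumptions made only for the example.

```python
import numpy as np

def softmax(scores, axis=-1):
    # Subtract the row maximum for numerical stability.
    shifted = scores - scores.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)

def scaled_dot_product_attention(q, k, v):
    """q, k: (n, d_k); v: (n, d_v). Returns (output, attention weights)."""
    d_k = k.shape[-1]
    scores = q @ k.T / np.sqrt(d_k)      # (n, n) query-key similarities
    weights = softmax(scores, axis=-1)   # each row sums to 1
    return weights @ v, weights          # weighted sum of the value vectors

# Toy input: 4 tokens with 8 features each (illustrative sizes only).
rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8))
w_q, w_k, w_v = (rng.normal(size=(8, 8)) for _ in range(3))
out, weights = scaled_dot_product_attention(x @ w_q, x @ w_k, x @ w_v)
print(out.shape, weights.sum(axis=-1))
```

Each row of `weights` is a probability distribution over the input tokens, so every output token is a convex combination of the value vectors.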
2) Multi-head Self-Attention
MSA is the core component of the transformer. As shown in Figure 4, the difference from SA is that the multi-head mechanism splits the input into many small parts, calculates the scaled dot-product attention of each part in parallel, and concatenates all the attention outputs to obtain the final result. MSA can be written as follows:
$$\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h)\,W^{O} \tag{2}$$
$$\mathrm{head}_i = \mathrm{Attention}(QW_i^{Q}, KW_i^{K}, VW_i^{V}) \tag{3}$$
where the projections $W_i^{Q}$, $W_i^{K}$, $W_i^{V}$, and $W^{O}$ are trainable parameter matrices and $h$ is the number of attention heads. The advantage of MSA is that it allows the model to learn sequence and location information in different representation subspaces.
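The multi-head mechanism can be sketched in a few lines of NumPy. This is an illustrative sketch under assumed toy dimensions (`n = 5` tokens, model width 12, `h = 3` heads) with random weights, not the DeiT implementation; the helper names are hypothetical.

```python
import numpy as np

def attention(q, k, v):
    # Scaled dot-product attention for 2D inputs of shape (n, d).
    scores = q @ k.T / np.sqrt(k.shape[-1])
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v

def multi_head_attention(x, w_q, w_k, w_v, w_o):
    """x: (n, d_model); w_q, w_k, w_v: (h, d_model, d_head); w_o: (h * d_head, d_model)."""
    # One attention output per head, each computed in its own subspace.
    heads = [attention(x @ wq, x @ wk, x @ wv) for wq, wk, wv in zip(w_q, w_k, w_v)]
    # Concatenate the heads and apply the output projection.
    return np.concatenate(heads, axis=-1) @ w_o

rng = np.random.default_rng(0)
n, d_model, h = 5, 12, 3          # illustrative sizes only
d_head = d_model // h
x = rng.normal(size=(n, d_model))
w_q, w_k, w_v = (rng.normal(size=(h, d_model, d_head)) for _ in range(3))
w_o = rng.normal(size=(h * d_head, d_model))
y = multi_head_attention(x, w_q, w_k, w_v, w_o)
print(y.shape)
```

With a single head the function reduces to ordinary self-attention followed by the output projection, which is a convenient sanity check.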

3) Multi-Layer Perceptron
In this paper, an MLP is added on top of the MSA layer. The MLP is composed of linear layers separated by a GELU [22] activation. Both the MSA and the MLP use residual-style skip connections and are preceded by layer normalization. Therefore, if the representation of layer $t-1$ is $x_{t-1}$ and LN denotes layer normalization, the output of layer $t$ can be written as follows:
$$\hat{x}_t = \mathrm{MSA}(\mathrm{LN}(x_{t-1})) + x_{t-1} \tag{4}$$
$$x_t = \mathrm{MLP}(\mathrm{LN}(\hat{x}_t)) + \hat{x}_t \tag{5}$$
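The residual update of one transformer layer can be sketched as follows. This is a simplified illustration, not the DeiT code: the layer normalization omits its learnable scale and shift, GELU uses the common tanh approximation, the 4x hidden width of the MLP is an assumption, and the MSA sub-layer is replaced by a hypothetical linear stand-in so that the block stays self-contained.

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    # Normalize each token over its feature axis (no learnable scale/shift here).
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def gelu(x):
    # tanh approximation of the GELU activation.
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def make_mlp(w1, b1, w2, b2):
    # Two linear layers separated by a GELU activation.
    return lambda x: gelu(x @ w1 + b1) @ w2 + b2

def encoder_layer(x_prev, msa, mlp):
    x_hat = msa(layer_norm(x_prev)) + x_prev   # MSA sub-layer with skip connection
    return mlp(layer_norm(x_hat)) + x_hat      # MLP sub-layer with skip connection (Eq. 5)

rng = np.random.default_rng(0)
n, d = 4, 8                                    # illustrative sizes only
x_prev = rng.normal(size=(n, d))
w_mix = rng.normal(size=(d, d)) * 0.1          # hypothetical stand-in for the MSA sub-layer
mlp = make_mlp(rng.normal(size=(d, 4 * d)) * 0.1, np.zeros(4 * d),
               rng.normal(size=(4 * d, d)) * 0.1, np.zeros(d))
x_next = encoder_layer(x_prev, lambda z: z @ w_mix, mlp)
print(x_next.shape)
```

Because both sub-layers are added to their input, a layer whose branches output zero leaves the representation unchanged.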

4) Embeddings and Tokens
The input layer contains five embeddings and tokens: the patch embedding, position embedding, class embedding, patch token, and class token. The patch embedding is the representation of each patch output by the CNN, and the class embedding is a trainable vector. To encode the spatial and location information of a patch into the patch tokens, we add the position embeddings to the patch embeddings. The class embedding has no patch embedding to be added to, so the class token and the class embedding are equivalent. Suppose the input is $x$, the trainable vector is $W_c$, and the position embedding is $x_{po}$; the patch tokens $x_{pt}$ and the class token $x_{ct}$ can then be expressed as follows:
$$x_{pt} = \mathrm{CNN}(x) + x_{po} \tag{6}$$
$$x_{ct} = W_c \tag{7}$$
The class token is attached to the patch tokens before the input layer of the transformer, passes through the transformer layers, and is then fed to the fully connected layer to predict the class.
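The token construction described above can be illustrated with a small NumPy sketch. It assumes the patch embeddings have already been produced by the CNN, and it places the class token at the first position of the sequence, as in ViT; that position, the function name, and the toy sizes are assumptions of the example.

```python
import numpy as np

def build_tokens(patch_embeddings, position_embeddings, class_embedding):
    """patch_embeddings: (B, N, P); position_embeddings: (N, P); class_embedding: (P,).

    Returns the transformer input of shape (B, N + 1, P).
    """
    # Patch tokens: patch embeddings plus position embeddings (broadcast over the batch).
    patch_tokens = patch_embeddings + position_embeddings
    # Class token: the trainable class embedding itself, repeated for every sample.
    b, p = patch_embeddings.shape[0], class_embedding.shape[0]
    class_tokens = np.broadcast_to(class_embedding, (b, 1, p))
    # Attach the class token in front of the patch tokens (assumed position).
    return np.concatenate([class_tokens, patch_tokens], axis=1)

rng = np.random.default_rng(0)
patches = rng.normal(size=(2, 6, 4))   # B=2 samples, N=6 patches, P=4 features
positions = rng.normal(size=(6, 4))
w_c = rng.normal(size=(4,))
tokens = build_tokens(patches, positions, w_c)
print(tokens.shape)
```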

B. TRANSMED
The structure of TransMed is shown in Figure 2. Instead of using pure transformers as the encoder, TransMed adopts a hybrid model including CNN and transformer, in which CNN is used as a low-level feature extractor to generate the patch embedding.
Given a multi-modal image $x \in \mathbb{R}^{B \times M \times C \times D \times H \times W}$, where the spatial resolution is $H \times W$, the depth is $D$, the number of channels is $C$, the number of modalities is $M$, and the batch size is $B$, a sequence must be constructed before the image is sent to the CNN encoder. First, three adjacent 2D slices of the multi-modal image are stacked to construct three-channel images. Then, following [4], each image is divided into $K \times K$ patches. A larger $K$ means that each patch is smaller. We evaluate the impact of different $K$ values on model performance in Section IV.E. Finally, the image is encoded into a sequence of $\frac{1}{3}MCDK^{2}$ three-channel patches per sample.
After the image sequence is constructed, it is fed into the 2D CNN. The last fully connected layer of the 2D CNN is replaced by a linear projection layer that maps the vectorized patch features to the latent embedding space. The 2D CNN extracts low-level features from the image sequence and encodes them preliminarily. The output shape is $(B, \frac{1}{3}MCDK^{2}, P)$, in which $P$ is set to match the input size of the transformer.
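The sequence construction can be made concrete with a reshape-based NumPy sketch. It assumes that the three-slice groups do not overlap (which is what the factor 1/3 in the output shape suggests) and that D is divisible by 3 and H, W by K; the function name and the toy sizes are illustrative, not taken from the released code.

```python
import numpy as np

def build_patch_sequence(x, k):
    """x: (B, M, C, D, H, W). Returns (patches, n) with patches of shape
    (B * n, 3, H // k, W // k) and n = M * C * (D // 3) * k * k patches per sample."""
    b, m, c, d, h, w = x.shape
    assert d % 3 == 0 and h % k == 0 and w % k == 0
    # Split depth into groups of 3 adjacent slices and each plane into k x k patches.
    x = x.reshape(b, m, c, d // 3, 3, k, h // k, k, w // k)
    # Reorder to (B, M, C, group, patch row, patch col, 3 slices, H/k, W/k).
    x = x.transpose(0, 1, 2, 3, 5, 7, 4, 6, 8)
    n = m * c * (d // 3) * k * k
    return x.reshape(b * n, 3, h // k, w // k), n

# Toy volume: B=2, M=2 modalities, C=1, D=6 slices, 8 x 8 pixels, K=2.
x = np.arange(2 * 2 * 1 * 6 * 8 * 8, dtype=np.float32).reshape(2, 2, 1, 6, 8, 8)
patches, n = build_patch_sequence(x, k=2)
print(patches.shape, n)
```

After the CNN maps each three-channel patch to a vector of length P, reshaping the result to `(B, n, P)` gives the sequence consumed by the transformer. For example, with M = 2, C = 1, D = 18, and K = 2, each sample yields 48 patches.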

IV. EXPERIMENTS

A. DATASET AND PREPROCESSING

1) DATASET
We use a dataset collected in cooperating hospitals to evaluate the performance of our proposed method on multi-modal images. The dataset includes two MRI modalities (T1 and T2) for 344 patients, and the ground truth labels were obtained from biopsies.
The incidence of malignant tumors among parotid gland tumors is about 20% [23]. Correct preoperative diagnosis of these tumors is essential for proper surgical planning. Imaging examination plays an important role in determining the nature of parotid gland masses. MRI is considered the preferred imaging method for the preoperative diagnosis of parotid tumors [24]. MRI can provide information about the exact location of the lesion and its relationship with the surrounding structures, and can be used to assess nerve spread and bone invasion. However, parotid gland tumors are reported to show considerable overlap in imaging features (such as tumor margins, homogeneity, and signal intensity), so it is difficult for doctors to identify the mass.

2) PREPROCESSING
First, Otsu thresholding [26] is applied to extract the foreground area of the original image. The images of the different modalities of the same patient are then registered to improve the consistency of the foreground area. Each image is then resampled to (18, 448, 448). As a result, 344 images are finally included, each of which is a stack of the 3D MRI T1 and T2 images with a size of (36, 448, 448). Data augmentation uses random flipping and random noise. Random flipping flips the image with 50% probability. Random noise adds Gaussian noise with a mean of 0 and a variance of 0.1 to the image.
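The thresholding and augmentation steps can be sketched in NumPy as follows. This is an illustration under stated assumptions, not the TorchIO-based pipeline used in the experiments: the histogram-based Otsu implementation, the choice of the last axis for flipping, and the function names are assumptions, and registration and resampling are omitted.

```python
import numpy as np

def otsu_threshold(image, bins=256):
    # Choose the threshold that maximizes the between-class variance.
    hist, edges = np.histogram(image, bins=bins)
    centers = (edges[:-1] + edges[1:]) / 2
    w0 = np.cumsum(hist)                      # pixel count at or below each bin
    w1 = w0[-1] - w0                          # pixel count above each bin
    s0 = np.cumsum(hist * centers)
    m0 = s0 / np.maximum(w0, 1)               # mean intensity of the lower class
    m1 = (s0[-1] - s0) / np.maximum(w1, 1)    # mean intensity of the upper class
    between = w0 * w1 * (m0 - m1) ** 2
    return centers[np.argmax(between)]

def augment(volume, rng, flip_prob=0.5, noise_variance=0.1):
    # Random flip along the last axis (assumed axis) with probability flip_prob.
    if rng.random() < flip_prob:
        volume = volume[..., ::-1]
    # Additive Gaussian noise with mean 0 and the given variance.
    noise = rng.normal(0.0, np.sqrt(noise_variance), size=volume.shape)
    return volume + noise

rng = np.random.default_rng(0)
image = np.zeros((16, 16))
image[4:12, 4:12] = 1.0                       # bright square on a dark background
foreground = image > otsu_threshold(image)
augmented = augment(image, rng)
print(foreground.sum(), augmented.shape)
```

Note that a variance of 0.1 corresponds to a standard deviation of about 0.316, which is the value passed to the noise generator.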

B. EXPERIMENTAL SETTINGS AND EVALUATION CRITERIA
The patients were randomly divided into a training group (n = 275) and an independent test group (n = 69) at a ratio of 4:1, and the training group was used to optimize the model parameters. We use SGD as the optimizer with a learning rate of $10^{-3}$ and a momentum of 0.7. The maximum number of training epochs is set to 100. Our experiments were performed on an NVIDIA 3080 GPU (with 10 GB of GPU memory). The code is implemented using PyTorch [27] and TorchIO [28]. To eliminate accidental factors, each model is subjected to 10 independent experiments, and in each experiment the data are randomly re-divided into a training group and a test group. All other experimental parameters are kept consistent during training.
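The repeated random 4:1 split can be sketched with the Python standard library. This is an illustrative sketch only: the seeding scheme, the function name, and the use of integer patient identifiers are assumptions, not details reported in the paper.

```python
import random

def split_patients(patient_ids, test_fraction=0.2, seed=0):
    """Randomly split patient identifiers into (train, test) lists."""
    ids = list(patient_ids)
    random.Random(seed).shuffle(ids)          # reproducible shuffle for a given seed
    n_test = round(len(ids) * test_fraction)
    return ids[n_test:], ids[:n_test]

# 10 independent experiments, each with its own random split (seeds are illustrative).
splits = [split_patients(range(344), seed=run) for run in range(10)]
train_ids, test_ids = splits[0]
print(len(train_ids), len(test_ids))
```

Splitting at the patient level keeps both modalities of one patient in the same group, so no patient contributes to both training and testing within a run.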
The evaluation criteria for each model are the overall accuracy ACC and the precision of each category P(i), defined as follows:
$$\mathrm{ACC} = \frac{T_c}{T}$$
where $T$ is the total number of samples and $T_c$ is the total number of samples with the correct prediction, and
$$P(i) = \frac{T_{ic}}{T_{ic} + T_{if}}$$
where $T_{ic}$ is the total number of samples correctly predicted as class $i$, and $T_{if}$ is the total number of samples that are wrongly predicted as class $i$. P(i) can describe the stability and robustness of the model on small datasets.
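The two criteria can be computed directly from the definitions above. The following plain-Python sketch uses made-up labels for three hypothetical classes; returning 0.0 when a class is never predicted is an assumption of the example.

```python
def accuracy(y_true, y_pred):
    # T_c / T: fraction of samples with the correct prediction.
    correct = sum(t == p for t, p in zip(y_true, y_pred))
    return correct / len(y_true)

def precision(y_true, y_pred, cls):
    # T_ic / (T_ic + T_if) over all samples predicted as class `cls`.
    t_ic = sum(1 for t, p in zip(y_true, y_pred) if p == cls and t == cls)
    t_if = sum(1 for t, p in zip(y_true, y_pred) if p == cls and t != cls)
    return t_ic / (t_ic + t_if) if (t_ic + t_if) else 0.0

# Made-up labels for illustration only.
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]
acc = accuracy(y_true, y_pred)
per_class = [precision(y_true, y_pred, c) for c in range(3)]
print(acc, per_class)
```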

C. BASELINE METHODS
The input-level fusion strategy can be easily implemented using mainstream 2D and 3D CNNs, so the selected networks include ResNet34, ResNet152, 3D ResNet34, P3D [29], and C3D [30]. In the feature-level fusion experiments, we used two common feature-level fusion methods [11], [12]. Since these two papers focus on segmentation tasks, we modified the network structures to adapt them to classification. The deep networks used in the decision-level fusion experiments are the same as those used for the input-level strategy.

D. RESULTS
Table 1 reports the performance of our proposed models, in which four variants are provided: the tiny version (TransMed-T) uses ResNet18 and DeiT-Tiny (DeiT-T) as backbones for the CNN branch and the transformer branch, respectively; the small version (TransMed-S) uses ResNet34 and DeiT-Small (DeiT-S); the base version (TransMed-B) uses ResNet50 and DeiT-Base (DeiT-B); and the large version (TransMed-L) uses ResNet152 and DeiT-B. TransMed consistently outperforms previous multi-modal fusion strategies by a large margin. TransMed-S achieves an improvement of about 12.8% in average accuracy over P3D, while the larger versions, TransMed-B and TransMed-L, suffer slightly from overfitting on the dataset. Table 1 also compares the number of parameters and the computational costs of our proposed models and previous methods. TransMed achieves state-of-the-art performance with far fewer parameters and lower computational costs. TransMed is highly efficient because it models the long-range relationships between modalities very well. We expect that our method can inspire further exploration of multi-modal medical image fusion in future work.

E. ABLATION EXPERIMENTS
To demonstrate the effect of transformers in TransMed, we conducted ablation experiments. For TransMed, changing the backbone from TransMed-T to TransMed-S results in a 1.9% improvement in average accuracy, at the expense of a much larger computational cost. Therefore, considering the computational cost, all experimental comparisons in this paper are conducted with TransMed-T to demonstrate the effectiveness of TransMed.
In the experiment, TransMed's CNN and transformer were removed in turn, and all other conditions remained unchanged. The results are shown in Table 2. They indicate that the transformer greatly improves the ability of the deep model to explore the relationships between modalities with little increase in parameters and computation. However, the performance of the pure transformer structure is poor due to the small dataset.
We also explored the impact of different patch sizes on performance in image serialization by changing the K value while keeping the other conditions unchanged. The results are shown in Table 3. The experimental results show that performance is poor when the K value is large. A possible reason is that image patches that are too small destroy the semantic information of the image.

V. CONCLUSION
The transformer is a powerful deep neural network structure for processing sequences in NLP, but it has received little attention in medical image analysis. In this paper, we propose TransMed, a novel design for multi-modal medical image classification based on transformers. TransMed has achieved very competitive results in the challenging task of parotid tumor classification. TransMed is easy to implement and has a flexible structure, which can be extended to multiple medical image modalities at low resource cost.
These preliminary results are encouraging, but many challenges remain. One is to apply TransMed to other medical image analysis tasks, such as tumor segmentation and lesion detection. Another is to use a pure transformer structure. Pure transformer structures have been successful on large-scale natural image datasets. However, our preliminary experiments show that there is still a big gap between the pure transformer and a typical CNN on small medical image datasets. We expect future work to further improve TransMed.

FIGURE 1. Compared with natural images, multi-modal medical images have more informative sequences.

FIGURE 2. Overview of TransMed, which is composed of a CNN branch and a transformer branch.

FIGURE 3. Overview of self-attention; matmul denotes the matrix product of two arrays.

FIGURE 4. An illustration of our multi-head self-attention component; concat denotes concatenation of representations.

TABLE 1. Comparison on the parotid gland tumors dataset (average accuracy % and precision % for each disease). IF, FF, and DF represent input-level fusion, feature-level fusion, and decision-level fusion, respectively.

TABLE 3. Ablation study on different patch sizes.