Explainable Transformer-Based Deep Learning Model for the Detection of Malaria Parasites from Blood Cell Images

Malaria is a life-threatening disease transmitted through the bites of female Anopheles mosquitoes. Plasmodium parasites multiply in the victim's blood cells, putting the victim's life at risk, and if not treated at an early stage, malaria can even be fatal. Microscopy is the standard process for diagnosing malaria: the victim's blood samples are collected, and the parasites and red blood cells are counted. However, the microscopy process is time-consuming and can produce erroneous results in some cases. With the recent success of machine learning and deep learning in medical diagnosis, it is quite possible to minimize diagnosis costs and improve overall detection accuracy compared with the traditional microscopy method. This paper proposes a multiheaded attention-based transformer model to diagnose the malaria parasite from blood cell images. To demonstrate the effectiveness of the proposed model, the gradient-weighted class activation map (Grad-CAM) technique was implemented to identify, via a generated heatmap image, which parts of an image the proposed model attended to most. The proposed model achieved a testing accuracy, precision, recall, f1-score, and AUC score of 96.41%, 96.99%, 95.88%, 96.44%, and 99.11%, respectively, for the original malaria parasite dataset and 99.25%, 99.08%, 99.42%, 99.25%, and 99.99%, respectively, for the modified dataset. Various hyperparameters were also fine-tuned to obtain optimum results, which were compared with state-of-the-art (SOTA) methods for malaria parasite detection; the proposed method outperformed the existing methods.


Introduction
The World Health Organization states that about 438,000 and 620,000 people died from malaria in 2015 and 2017, respectively, whereas 300 to 500 million people are infected by malaria [1]. Malaria transmission is influenced by weather conditions that allow mosquitoes to live for extended periods, where environmental temperatures are high enough, particularly after rain. For that reason, 90% of malaria cases occur in Africa, and cases are also frequent in humid areas of Asia and Latin America [2][3][4]. If the disease is not treated at the early stages, it may even lead to death. The usual process for detecting malaria starts with collecting blood samples and counting the parasites and red blood cells (RBCs). Figure 1 shows images of RBCs both uninfected and infected by the malaria parasite. This process needs medical experts to collect and examine millions of blood samples, which is a costly, time-consuming, and error-prone process [5]. There are two traditional approaches for detecting malaria: one is very time-consuming because it needs to identify at least 5,000 RBCs, and the other is an antigen-based rapid diagnostic examination that is very costly. To overcome the limitations of the traditional approaches, in the last few years, researchers have focused on solving this problem using several machine learning and deep learning algorithms. A number of studies have been carried out recently to identify malaria using image analysis by artificial intelligence (AI). Bibin et al. proposed a deep belief network (DBN) to detect malaria parasites (MPs) in RBC images [6]. They used 4100 images for training their model and achieved a specificity of 95.92%, a sensitivity of 97.60%, and an F-score of 89.66%. Pandit and Anand detected MPs from RBC images using an artificial neural network [7], training their model on 24 healthy RBC and 24 infected RBC images and obtaining accuracies between 90% and 100%. Jain et al. 
used a CNN model to detect MPs from RBC images [8] without using a GPU or preprocessing techniques, providing a low-cost detection algorithm that achieved an accuracy of 97%. Rajaraman et al. used pretrained CNN models to extract features from 27,558 RBC images to detect MPs and achieved an accuracy of 92.7% [5]. Alqudah et al. developed a lightweight CNN to accurately detect MPs using RBC images [9]. They trained their model using 19,290 images with 4134 test images and achieved an accuracy of 98.85%. Sriporn et al. used six transfer learning (TL) models, Xception, Inception-V3, ResNet-50, NasNetMobile, VGG-16, and AlexNet, to detect MPs [10]. Several combinations of activation function and optimizer were employed to improve the models' effectiveness, and a combined accuracy of 99.28% was achieved by their models trained with 7000 images. Fuhad et al. proposed an automated CNN model to detect MPs from RBC images [11] and performed three training techniques (general, distillation, and autoencoder training) to improve model accuracy after correcting the incorrectly labeled images. Masud et al. proposed leveraging a CNN model to detect MPs using a mobile application [12] with a cyclical stochastic gradient descent optimizer and achieved an accuracy of 97.30%. Maqsood et al. developed a customized CNN model to detect MPs [13] with the assistance of bilateral filtering (BF) and image augmentation methods and achieved an accuracy of 96.82%. Umer et al. developed a stacked CNN model to predict MPs from thin RBC images and achieved an outstanding performance with an accuracy of 99.98%, precision of 100%, and recall of 99.9% [14]. Hung and Carpenter proposed a region-based CNN to detect objects in RBC images [15]; the total accuracy using one-stage and two-stage classification was 59% and 98%, respectively. Pattanaik et al. suggested a methodology for detecting malaria from cell images using computer-aided diagnosis (CAD) [16]. 
They employed an artificial neural network with a functional link and sparse stacking to pretrain the system's parameters and achieved an accuracy of 89.10% and a sensitivity of 93.90% in detecting malaria from a private dataset of 2565 RBC images gathered from the University of Alabama at Birmingham. Olugboja et al. used a support vector machine (SVM) and a CNN [17] to obtain accuracies of 95% and 91.66%, respectively. Gopakumar et al. created a custom CNN based on a stack of images [18]. A two-level segmentation technique was introduced after the cell counting problem was reinterpreted as a segmentation problem; an accuracy of 98.77%, a sensitivity of 99.14%, and a specificity of 99.62% were achieved with the CNN focus stack model. Khan et al. used three machine learning (ML) models (logistic regression (LR), decision tree (DT), and random forest (RF)) to predict MPs from RBC images [19]. They first extracted aggregated features from the cell images and achieved a high recall of 86% using RF. Fatima and Farid developed a computer-aided diagnosis (CAD) system to detect MPs from RBC images [20] after removing noise and enhancing the quality of the images using the BF method. To detect the MPs, they used adaptive thresholding and morphological image processing and achieved an accuracy of 91%. Mohanty et al. used two models, an autoencoder (AE) [21] and self-organizing maps (SOM) [22], to detect MPs and found that the AE outperformed the SOM, achieving an accuracy of 87.5% [23]. Dong et al. evaluated three TL models, LeNet [24], AlexNet, and GoogLeNet [25], to detect MPs [26] and compared them with an SVM; the TL models achieved an accuracy of 95%, notably higher than the 92% obtained with the SVM. Anggraini et al. proposed a CAD system to detect MPs from RBC images [27] with gray-scale preprocessing to stretch the contrast of the images and global thresholding to separate the different blood cell components in the images.
So far, many computerized systems have been proposed; most of them were based on traditional machine learning or conventional deep learning approaches, which provided satisfactory performances, but there is still scope for further improvement. Since the development of the vision transformer model [28], attention-based transformer models have shown promising results in medical imaging, bioinformatics, computer vision tasks, etc. compared with conventional convolution-based deep learning models. However, to date, no attention-based works have been carried out to detect malaria parasites. Moreover, the interpretability of a deep CNN model is a major issue. Recently, visualizing what a deep learning model has learned has attracted significant attention in the deep learning community; however, most previous works have failed to introduce interpretability into models for malaria parasite detection. To overcome these issues, in this work, an explainable transformer-based model is proposed to detect the malaria parasite from cell images of blood smears. Various hyperparameters, such as encoder depth, optimizer (Adam and stochastic gradient descent (SGD)), batch size, etc., were experimented with to achieve better performance. Two malaria parasite datasets (original and modified) were taken into consideration to conduct the experiments.
The key contributions of this paper are: (1) a multiheaded attention transformer-based model was implemented for the detection of malaria parasites for the first time; and (2) the proposed model for malaria parasite detection was compared with SOTA models. Figure 2 shows the overall design of the proposed methodology. Firstly, the raw images were preprocessed, and the dataset was split into training and testing sets to build the model. Finally, to visualize what the trained model learned, Grad-CAM was used to show the heatmap image.

Dataset Description
The dataset for malaria detection contains segmented RBC images. It is archived at the National Library of Medicine and is openly accessible at "https://lhncbc.nlm.nih.gov/LHC-publications/pubs/MalariaDatasets.html" (accessed on 10 May 2022). Rajaraman et al. [5] developed a segmentation process and applied it to segment RBC images from thin blood smear images. The dataset has a total of 27,558 RBC images, among which 13,779 are infected with the malaria parasite and 13,779 are uninfected. A detailed distribution of the dataset is given in Table 1. This dataset was further studied by a medical expert in the research work conducted by Fuhad et al. [11]. They discovered some suspicious data in the dataset, including images that seemed infected but were labeled as uninfected, as well as images that appeared uninfected but were labeled as infected. The mislabeled data were afterward manually annotated, and during annotation the incorrectly labeled and suspicious data were set aside: 647 falsely labeled infected or suspicious images and 750 falsely labeled uninfected or suspicious images were eliminated. The updated dataset was uploaded to Google Drive [29], which is open to the public, and is also taken into account in this work. In both datasets, 20% of the images were used for testing, and 80% were used for training the proposed model.
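The 80/20 split described above can be sketched in PyTorch as follows; the random tensors merely stand in for the actual RBC image folders, and the fixed seed is an illustrative choice, not one stated in the paper:

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Stand-in for the RBC dataset; in practice this would be an image dataset
# built over the parasitized/uninfected image directories.
dataset = TensorDataset(torch.randn(100, 3, 96, 96), torch.randint(0, 2, (100,)))

# 80% of the images for training, 20% for testing, as used for both datasets.
n_train = int(0.8 * len(dataset))
train_set, test_set = random_split(
    dataset, [n_train, len(dataset) - n_train],
    generator=torch.Generator().manual_seed(0),  # fixed seed for reproducibility
)
```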

Preprocessing (Resize)
The raw images of the dataset come in a variety of sizes. As the proposed model contains fully connected layers in its classifier head, it needs a fixed-size input image [30]. Therefore, the raw images were resized to 96 × 96 × 3.
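As a minimal sketch, this resizing step can be done with PyTorch's `interpolate`; the bilinear resampling mode is an assumption, as the paper does not specify the method:

```python
import torch
import torch.nn.functional as F

# A dummy raw cell image with an arbitrary size (H = 120, W = 80).
raw = torch.rand(1, 3, 120, 80)

# Resize to the fixed 96 x 96 x 3 input required by the fully connected
# classifier layers of the model.
x = F.interpolate(raw, size=(96, 96), mode="bilinear", align_corners=False)
print(x.shape)  # torch.Size([1, 3, 96, 96])
```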

Model Architecture
Various attention-based models have been developed recently. To date, the vision transformer [28] has attracted the most attention from researchers for computer vision tasks. The compact convolutional transformer (CCT), introduced in 2021, is a slightly modified version of the vision transformer [31]. CCT with Grad-CAM visualization was implemented in this study to detect the malaria parasite. Figure 3 shows the model architecture of CCT.

Convolutional Block
A traditional convolutional layer with the ReLU(.) activation function was used. A 3 × 3 kernel with stride 3 was used to make the slide nonoverlapping, followed by a max-pooling layer. In a transformer model, instead of the full input image, the image is divided into patches/grid images, which are given to the transformer's encoder. In the proposed transformer-based model, convolutional filters were used for patching: instead of patching the images directly, the convolutional block maps the input images to a latent representation, which provides more flexibility than the vision transformer. The number of filters in the convolutional layer was chosen to match the vision transformer embedding dimension, E = 768. Given an input image X ∈ R^(H×W×C), the convolutional block produces an output in R^(H'×W'×E), which is reshaped to R^(N×E) to convert it into convolutional patches, where the number of sequences or patches is N = H'·W'. This convolutional block maintains the local spatial information. To keep track of the position or sequence number of each patch, a learnable positional embedding was added.
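A minimal PyTorch sketch of such a convolutional patching block is given below. The 3 × 3 stride-3 convolution, ReLU, max-pooling, and E = 768 follow the text; the max-pool kernel/stride of 2 is an assumption:

```python
import torch
import torch.nn as nn

class ConvTokenizer(nn.Module):
    """Convolutional patching block: conv + ReLU + max-pool, then flatten the
    spatial grid into a sequence of N patch embeddings of dimension E = 768."""
    def __init__(self, in_ch=3, embed_dim=768):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, embed_dim, kernel_size=3, stride=3)
        self.act = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)  # assumed pooling config

    def forward(self, x):                       # x: (B, C, H, W)
        x = self.pool(self.act(self.conv(x)))   # (B, E, H', W')
        return x.flatten(2).transpose(1, 2)     # (B, N, E) with N = H' * W'

tok = ConvTokenizer()
patches = tok(torch.randn(2, 3, 96, 96))
print(patches.shape)  # torch.Size([2, 256, 768]): 96/3 = 32, 32/2 = 16, N = 16 * 16
```

The learnable positional embedding mentioned above would then be a parameter of shape (1, N, E) added to this output before the encoder.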

Multiheaded Attention Mechanism
The main part of the compact convolutional transformer is multiheaded self-attention (MSA). The whole image is not necessary for extracting valuable information; the attention mechanism focuses on the valuable parts. Various attention mechanisms have been developed so far; multiheaded self-attention was first introduced in the original transformer and later adopted by the vision transformer. Figure 4 shows the scaled dot-product-based multihead attention mechanism [32]. In self-attention, the input is projected to queries, keys, and values using different learnable weights with linear layers.
The queries Q ∈ R^(N×d), keys K ∈ R^(N×d), and values V ∈ R^(N×d) are thus obtained. In the scaled dot-product form of attention, the dot product between the queries and keys is computed and scaled by √d_k, and the nonlinear softmax function is then applied to obtain the attention weights:

Z = softmax(QK^T / √d_k) ∈ R^(N×N)

This Z is the attention weight among the patches. The attention weight is then multiplied with the values to obtain the self-attention-weighted output H = ZV ∈ R^(N×d). Therefore, the scaled dot-product attention function can be written in short form as Attention(Q, K, V) = softmax(QK^T / √d_k)V. Rather than performing a single attention function with d-dimensional queries, keys, and values, it is advantageous to linearly project the queries, keys, and values to d_h, d_k, and d_v dimensions, h times, using different learnable weights with linear layers. The scaled dot-product attention function is then applied in parallel in all h heads, resulting in h d_v-dimensional outputs. These attention-weighted values are concatenated and further projected with a linear layer:

MultiHead(Q, K, V) = Concat(head_1, ..., head_h) W_O, where W_O ∈ R^(hd×d).

The multiheaded attention mechanism helps the model attend to different parts of the input from different representation subspaces. In this work, d_h = d_k = d_v = d was applied, and the number of heads used was h = 8.
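The mechanism described above can be sketched in PyTorch as follows, with h = 8 heads; fusing the three projections into one linear layer and taking the per-head dimension as d/h are common simplifications, not choices stated in the paper:

```python
import torch
import torch.nn as nn

class MultiHeadSelfAttention(nn.Module):
    """Scaled dot-product multiheaded self-attention with h heads."""
    def __init__(self, dim=768, heads=8):
        super().__init__()
        self.heads, self.dk = heads, dim // heads
        self.qkv = nn.Linear(dim, 3 * dim)   # learnable projections for Q, K, V
        self.proj = nn.Linear(dim, dim)      # final linear projection W_O

    def forward(self, x):                    # x: (B, N, dim)
        B, N, _ = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # split into heads: (B, h, N, d_k)
        q, k, v = (t.view(B, N, self.heads, self.dk).transpose(1, 2) for t in (q, k, v))
        attn = (q @ k.transpose(-2, -1)) / self.dk ** 0.5  # scaled dot product
        z = attn.softmax(dim=-1) @ v                       # attention-weighted values
        z = z.transpose(1, 2).reshape(B, N, -1)            # concatenate heads
        return self.proj(z)

msa = MultiHeadSelfAttention()
out = msa(torch.randn(2, 256, 768))
print(out.shape)  # torch.Size([2, 256, 768])
```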

Transformer Encoder
The image patches come from the convolutional block and are passed through the transformer encoder. Firstly, layer normalization is applied to the image patches; unlike batch normalization, which normalizes along the mini-batch direction, layer normalization normalizes the activations along the feature direction. Multiheaded self-attention is then applied to these normalized patches, and the result is added to the original patches via a residual connection, as shown in Figure 4. A further layer normalization and a feed-forward block are then applied, along with another residual connection. The feed-forward block has two linear layers, a dropout layer, and a GELU nonlinearity: the first linear layer expands the dimension four times, and the second linear layer reduces the dimension back.
The outcomes from these two paths are added again. In this work, 16 transformer encoders were applied sequentially.
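A sketch of one such encoder block (pre-norm attention and a 4×-expanding feed-forward block, each with a residual connection) might look like the following; the dropout rate is an assumption, and only 2 of the 16 blocks are stacked here for brevity:

```python
import torch
import torch.nn as nn

class EncoderBlock(nn.Module):
    """Transformer encoder block: LayerNorm -> MSA -> residual,
    then LayerNorm -> feed-forward (4x expand, GELU, dropout) -> residual."""
    def __init__(self, dim=768, heads=8, drop=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.ff = nn.Sequential(
            nn.Linear(dim, 4 * dim),  # first linear layer expands 4x
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(4 * dim, dim),  # second linear layer reduces back
        )

    def forward(self, x):
        h = self.norm1(x)
        x = x + self.attn(h, h, h, need_weights=False)[0]  # residual connection 1
        return x + self.ff(self.norm2(x))                  # residual connection 2

encoder = nn.Sequential(*[EncoderBlock() for _ in range(2)])  # paper uses 16
out = encoder(torch.randn(2, 256, 768))
print(out.shape)  # torch.Size([2, 256, 768])
```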

Sequence Pooling
In the vision transformer, a class token is used for the final classification. In the compact convolutional transformer, sequence pooling is used instead of a class token; it pools over the entire sequence of data. Given the output of the last transformer encoder block X_L ∈ R^(b×N×d), where b is the mini-batch size and N is the number of sequences, X_L is sent through a linear layer that maps each of the N embeddings to a scalar score; a softmax over the sequence turns these scores into importance weights, and the weighted sum of the rows of X_L yields the pooled representation F ∈ R^(b×d).
The output F is then sent to the final linear classifier to classify the input data.
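Sequence pooling as described can be sketched in a few lines, following the CCT formulation; the variable names and the toy dimensions are illustrative:

```python
import torch
import torch.nn as nn

class SeqPool(nn.Module):
    """CCT sequence pooling: a linear layer scores each of the N token
    embeddings, softmax over the sequence turns the scores into weights,
    and the weighted sum gives one d-dimensional vector per image."""
    def __init__(self, dim=768):
        super().__init__()
        self.score = nn.Linear(dim, 1)

    def forward(self, x_l):                          # x_l: (b, N, d)
        w = self.score(x_l).softmax(dim=1)           # (b, N, 1) importance weights
        return (w.transpose(1, 2) @ x_l).squeeze(1)  # (b, d) pooled output F

pool = SeqPool()
classifier = nn.Linear(768, 2)        # final linear classifier (binary task)
f = pool(torch.randn(4, 256, 768))
logits = classifier(f)
print(f.shape, logits.shape)  # torch.Size([4, 768]) torch.Size([4, 2])
```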

Grad-CAM Visualization
The gradient-weighted class activation map (Grad-CAM) is a technique to interpret what the model has actually learned [33]. This technique generates a class-specific heatmap using a trained deep learning model for a particular input image. The Grad-CAM approach highlights the input image regions where the model pays the most attention to produce discriminative patterns, using the last layer before the final classifier, as that layer contains the most highly semantic features. Grad-CAM uses the feature maps from the last convolutional layer, which provide the best discriminative semantics. Let y^c be the class score for class c from the classifier before the softmax layer. Grad-CAM has three basic steps:
Step 1: Compute the gradients of the class score y^c with respect to the feature maps A^k of the last convolutional layer before the classifier, i.e., ∂y^c/∂A^k.
Step 2: To obtain the attention weights α_k^c, global-average-pool the gradients over the width (indexed by i) and height (indexed by j): α_k^c = (1/Z) Σ_i Σ_j ∂y^c/∂A_ij^k.
Step 3: Calculate the final Grad-CAM heatmap by the weighted sum of the feature maps, L^c = ReLU(Σ_k α_k^c A^k), where the ReLU(.) function retains only the positive values and turns all negative values into zero.
Firstly, the proposed model was trained with the training samples from the dataset. After the training phase was completed, the trained model was evaluated with the testing part of the dataset. In addition, to explain what the trained model had actually learned, the Grad-CAM technique explained above was applied. Various test images were selected randomly to generate the corresponding heatmaps from the trained model using the Grad-CAM approach. In this case, the multilayer perceptron layer of the last transformer encoder before the final classifier was chosen as the target layer. Features and gradients were extracted from that layer, and a heatmap was generated using the Grad-CAM formula above. Subsequently, the heatmap was resized with nearest-neighbor interpolation to the same size as the input image and overlaid on the input image. Figure 5 shows the original input images and their corresponding heatmap images. For the heatmap conversion, a jet color map was used. It can be seen from the overlaid heatmap images that the lesion areas are much more reddish than the other regions of the image. These reddish areas are the main lesions indicating the malaria parasite [34].
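The three Grad-CAM steps, plus the nearest-neighbor resize described above, can be sketched as follows; the toy feature map and class score only stand in for the activations and logits of the trained model (for the transformer, the token features of the target layer would first be reshaped into a spatial grid):

```python
import torch
import torch.nn.functional as F

def grad_cam(feature_maps, class_score, out_size=(96, 96)):
    """Grad-CAM from the three steps above: gradients of the class score
    w.r.t. the feature maps, global-average-pooled into weights, then a
    ReLU-rectified weighted sum, upsampled with nearest-neighbor resizing."""
    grads = torch.autograd.grad(class_score, feature_maps)[0]      # step 1
    alpha = grads.mean(dim=(2, 3), keepdim=True)                   # step 2: GAP over i, j
    cam = F.relu((alpha * feature_maps).sum(dim=1, keepdim=True))  # step 3
    cam = F.interpolate(cam, size=out_size, mode="nearest")        # resize to input size
    return cam / (cam.max() + 1e-8)                                # normalize to [0, 1]

# Toy demonstration with a random "last layer" feature map and class score.
A = torch.randn(1, 8, 12, 12, requires_grad=True)
y_c = (A ** 2).sum()              # stand-in for the class score y^c
heatmap = grad_cam(A, y_c)
print(heatmap.shape)  # torch.Size([1, 1, 96, 96])
```

The normalized heatmap can then be mapped through a jet color map and blended with the input image for visualization.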
There is no existing segmentation dataset of RBC cell images with parasite masks for quantitative analysis. In the dataset used in this work, normal RBC images appear clean, without any lesions, whereas parasite images contain lesions [34]. Based on the presence of these lesions, the RBC cell images were classified as either normal or parasitized. To demonstrate the explainability of the trained model, the Grad-CAM technique was applied to generate heatmap images showing the actual parts (lesions) the model paid attention to during feature extraction and classification. This technique can bring new insights toward detecting MPs accurately.

Performance Evaluation Procedure
The PyTorch Python framework was used to conduct the whole experiment. The model was run on a GPU-equipped desktop PC with an 11th Generation Intel(R) Core(TM) i9-11900 CPU @ 2.50 GHz, 32 GB RAM, and an NVIDIA GeForce RTX 3090 24 GB GPU, running a 64-bit Windows 10 Pro operating system.
The cell images were preprocessed, then the proposed transformer-based model was trained using original and modified datasets. The performance of the model was tested using 20% of the dataset in both cases. In both cases, the proposed model was trained for 50 epochs, and the learning rate was fixed to 0.001. Various hyperparameters such as optimizers, batch size, and transformer's encoder depth were experimented with for performance analysis.
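The training configuration described above (learning rate fixed at 0.001, optimizer chosen as a hyperparameter) can be sketched as follows; the tiny stand-in model and the 2-step loop are placeholders for the CCT and the 50-epoch runs:

```python
import torch
import torch.nn as nn

# Trivial stand-in model for the CCT, mapping 96 x 96 x 3 images to 2 classes.
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 96 * 96, 2))
criterion = nn.CrossEntropyLoss()

def make_optimizer(name, params, lr=0.001):
    """Select the optimizer hyperparameter: 'sgd' or 'adam', both at lr 0.001."""
    return {"sgd": torch.optim.SGD, "adam": torch.optim.Adam}[name](params, lr=lr)

optimizer = make_optimizer("sgd", model.parameters())
x, y = torch.randn(8, 3, 96, 96), torch.randint(0, 2, (8,))  # dummy mini-batch
for _ in range(2):                  # 50 epochs over the real data in the paper
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()
print("final loss:", round(loss.item(), 4))
```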
For measuring the performance of the deep learning model, various evaluation metrics were used: the confusion matrix (CM), accuracy, precision, recall, f1-score, and area under the curve (AUC) [35,36]. Here, TP (true positive) means that a malaria-infected person is correctly detected as infected; TN (true negative) means that a noninfected person is correctly detected as noninfected; FP (false positive) means that a noninfected person is wrongly detected as infected; and FN (false negative) means that an infected person is wrongly detected as noninfected. The original and modified datasets were both treated as balanced binary classification problems. To examine the performance of the proposed model, various hyperparameters were considered. Among the optimization methods developed for deep learning, Adam [37] and SGD [38] are the two most widely used and popular; therefore, to demonstrate their effectiveness in malaria parasite detection, the proposed model was trained using both optimizers. Batch size is also a key factor in the model's learning and in obtaining more generalized results: a larger batch size speeds up training, whereas a much larger batch size often leads to poor generalization. In this study, the proposed model was tuned with various batch sizes (8, 16, 32, and 64). Furthermore, different encoder depths (8, 12, and 16) were also experimented with.
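The confusion-matrix-based metrics can be computed directly from the four counts; AUC is omitted from this sketch, as it requires the full score distribution rather than the counts alone, and the example counts are purely illustrative:

```python
def binary_metrics(tp, tn, fp, fn):
    """Standard confusion-matrix metrics: accuracy, precision, recall, f1."""
    accuracy = (tp + tn) / (tp + tn + fp + fn)
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)            # also called sensitivity
    f1 = 2 * precision * recall / (precision + recall)
    return accuracy, precision, recall, f1

# Illustrative counts: 95 infected correctly flagged, 5 missed,
# 97 noninfected correctly cleared, 3 false alarms.
acc, prec, rec, f1 = binary_metrics(tp=95, tn=97, fp=3, fn=5)
print(acc, rec)  # 0.96 0.95
```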

Adam Optimizer for Original Dataset
The different performance criteria of the proposed model with the Adam optimizer, for instance, precision, recall, f1-score, and accuracy, were calculated and are presented in Table 2. The ROCs of the model with the Adam optimizer for different batch sizes are presented in Figure 6, which shows that the highest AUC of 64.61% was achieved using a batch size of 8. This could be because the model trained with larger batch sizes and the Adam optimizer did not show continuous improvement. Even though the Adam optimizer produced very high precision, the other results for recall, f1-score, and accuracy were disappointing.

SGD Optimizer for Original Dataset
The effectiveness of the model with the SGD optimizer, along with various batch sizes, is briefly discussed in this section. The model's training and test accuracies for the original dataset are shown in Figure 7, and the training loss curve in Figure 8. The proposed model's highest training accuracy was 98.61%, achieved with a batch size of 8, whereas the highest testing accuracy was 96.86% with a batch size of 64. A batch size of 64 also resulted in the lowest training loss of 0.1%. To assess how well the proposed model with the SGD optimizer detects malaria-infected patients, the same number of cell images as with Adam was used for testing. The numbers of predicted patients are shown by the CMs in Figure 9. The highest accuracy of 96.41% was achieved with the SGD optimizer and a batch size of 32, and the highest recall of 95.88% was achieved with the same batch size (Table 3).

Encoder Depth for Original Dataset
To show the impact of different encoder depths, the SGD optimizer was used, and the batch size was fixed to 16. The highest test accuracy of 96.66% was obtained with an encoder depth of 8 (Figure 11).
AUC scores of ~99% were achieved by the model with all encoder depths of 8, 12, and 16 (Figure 12).

Adam Optimizer for Modified Dataset
The proposed transformer-based model's classification performance with the modified dataset and the Adam optimizer was evaluated and is presented in Table 4. Again, apart from precision, the performance results were poor. The highest AUC of 59.7% was achieved with a batch size of 16 (Figure 13), indicating no promising results.

SGD Optimizer for Modified Dataset
Furthermore, the SGD optimizer was used with various batch sizes for the modified dataset; the accuracy and loss curves are shown in Figures 14 and 15, respectively. The CMs for the model with the SGD optimizer are shown in Figure 16. Although the highest accuracy of 99.25% and recall of 99.50% were achieved for a batch size of 64 with the SGD optimizer, the results showed insignificant differences between the batch sizes (Table 5). The ROCs for each batch size are shown in Figure 17; the results did not differ much, with AUC values close to 1.0 in all cases.

Encoder Depth for Modified Dataset
Encoder depth was also fine-tuned for the modified dataset. The larger models with higher encoder depths showed higher fluctuation, and the highest test accuracy of 99.29% was obtained from the proposed model with encoder depths of 8 and 16. The ROC curves in Figure 18 confirmed that all models achieved the same high AUC score of approximately 99.9%.

Performance Comparison between Two Datasets
With smaller batch sizes, the learning process became easier and provided the best evaluation results; with larger batch sizes, on the other hand, the transformer-based model converged faster. After correcting the mislabeled data in the original dataset, the problem became a much easier balanced binary classification task. From the above experimental results, it was observed that the Adam optimizer showed poor results in all cases, as it is not guaranteed to converge to the global optimum. On the other hand, the results obtained using the SGD optimizer were significantly better than those obtained with the Adam optimizer for all performance criteria employed in this work. From Figure 19, it was observed that the classification performance of the model was better with the modified dataset than with the original dataset. Therefore, the proposed model with the SGD optimizer and the modified dataset could be the best combination for accurate MP detection. However, a similar comparison between the two datasets using the Adam optimizer showed no improvement with the modified dataset. This further indicated that, for both datasets, the SGD optimizer produced the better results.

Performance Comparison with Previous Works
In this section, the performance of the proposed model is compared with several SOTA methods for both datasets. The details of the SOTA methods have been described in the Introduction section.
For the original dataset, the first five rows of Table 6 show the results of the SOTA models. It was observed that the highest accuracy score among them, 91.80%, was achieved by the previous work of Fatima and Farid [20]. On the other hand, the proposed transformer-based model achieved a promising accuracy of 96.41% when the SGD optimizer was used with a batch size of 32 and a learning rate of 0.001, almost 5% higher than the best work reported to date. This suggests that the proposed model can produce better results than those reported on the same dataset. For the modified dataset, the highest SOTA accuracy of 99.23% was achieved by Fuhad et al. [11]. However, the AUC of the SOTA models did not reach a satisfactory level, whereas the proposed model showed an optimistic AUC score of 99.99% when the SGD optimizer with a batch size of 64 and a learning rate of 0.001 was employed. The high AUC of the proposed model demonstrates its strong capability to differentiate between malaria-infected and uninfected patients. The table suggests that the proposed transformer-based model achieved a satisfactory classification performance compared with the SOTA models mentioned.
The RBC images were collected from an open-source repository. However, the real working procedure starts with first segmenting the RBCs in blood smear images, which contain various other cells; afterward, the malaria parasite is classified from the affected regions of the RBCs. In this work, the segmentation task was omitted, and ready-made RBC cell images were used only to classify the malaria parasite. Conventional CNN models show an image-specific inductive bias [40], as they are based on a local receptive field; to capture global information, CNN models need larger kernels or very deep networks. The transformer models are free from these shortcomings, and, therefore, the transformer-based model proposed for malaria parasite detection showed excellent performance. Moreover, the Grad-CAM visualization demonstrated its explainability. It was also noticed that, similar to this study, the application of the SGD optimizer in other studies also produced the highest performance.

Conclusions
A multiheaded attention-based transformer model was proposed for malaria parasite detection. In addition, to interpret the trained model, Grad-CAM visualization was used to verify the learning. The proposed work with the transformer model achieved accuracy, precision, recall, and an AUC score of 99.25%, 99.00%, 99.50%, and 99.99%, respectively. Various SOTA works for malaria parasite detection were compared with the proposed model. The model outperformed all the previous works for detecting malaria parasites. Our future work will be focused on segmenting RBCs from blood smear images and classifying malaria parasites from the segmented RBC images.