Computer-Aided Diagnosis of Alzheimer’s Disease through Weak Supervision Deep Learning Framework with Attention Mechanism

Alzheimer’s disease (AD) is the most prevalent neurodegenerative disease causing dementia and poses significant health risks to middle-aged and elderly people. Brain magnetic resonance imaging (MRI) is the most widely used diagnostic method for AD. However, it is challenging to collect sufficient brain imaging data with high-quality annotations. Weakly supervised learning (WSL) is a machine learning technique aimed at learning effective feature representation from limited or low-quality annotations. In this paper, we propose a WSL-based deep learning (DL) framework (ADGNET) consisting of a backbone network with an attention mechanism and a task network for simultaneous image classification and image reconstruction to identify and classify AD using limited annotations. The ADGNET achieves excellent performance based on six evaluation metrics (Kappa, sensitivity, specificity, precision, accuracy, F1-score) on two brain MRI datasets (2D MRI and 3D MRI data) using fine-tuning with only 20% of the labels from both datasets. The ADGNET has an F1-score of 99.61% and sensitivity is 99.69%, outperforming two state-of-the-art models (ResNext WSL and SimCLR). The proposed method represents a potential WSL-based computer-aided diagnosis method for AD in clinical practice.


Introduction
Alzheimer's disease (AD) is a common chronic progressive neurodegenerative disease of the elderly characterized by progressive dementia and brain degeneration. It significantly affects cognitive functions, memory functions, the quality of life, and the emotions of more than 50 million people worldwide [1]. According to a report by the World Health Organization (WHO), AD has become the fifth leading cause of death, and the number of AD patients will increase to 152 million by 2050, and by 2050 [2]. However, the etiology of AD remains unclear, and there are no effective drugs or treatments to reverse dementia [3]. The preclinical stage of AD, called mild cognitive impairment (MCI), is a transitional state between normal aging and AD [4]. According to a report of the American Academy of Neurology [5], about 10% to 15% of patients with MCI may eventually suffer from AD, whereas only 1% to 2% of patients experience normal aging. Unfortunately, due to a lack of understanding of AD by patients and their family members, most patients suffer from moderate and severe stages of AD at the time of diagnosis and have missed the optimal intervention stage [6]. Therefore, it is of great significance to identify the risk and extent of AD as early as possible. Typically, doctors have to conduct careful medical assessments of 1.
An attention module (AM) is proposed to improve the discriminative ability of the backbone network with a low computational cost. The AM is an automatic weighting module that adjusts the weights of the channels in the feature maps so that the backbone network selectively focuses on the significant parts of the input.

2.
The task networks perform two tasks (image classification and image reconstruction) in parallel. The task networks utilize the feature vector generated by the backbone network and use fully-connected (FC) layers and a decoder for label prediction and image reconstruction.

3.
A multi-task learning (MTL) framework is proposed for conducting image recognition and reconstruction in parallel, with low computational requirements and good performance (with the best F1-score of 99.61% and a sensitivity of 99.69%) using only 20% of the labels from the datasets for fine-tuning.
The rest of the paper consists of five parts. The related studies are described in Section 2. The proposed method is explained in Section 3. Section 4 summarizes the results. The discussion is presented in Section 5, and the conclusion is given in Section 6.

Multi-Task Learning
Multi-task learning is a transfer learning method that extracts domain-specific information from related tasks for an improved representation of the input data [24]. The concept of MTL was first proposed by Caruana et al. [25] and has been applied in many fields. Various studies have been conducted to explore effective MTL methods. Misra et al. built a novel sharing unit to learn representation from different tasks and reported excellent results [26]. Lu et al. developed an adaptive feature sharing mechanism in MTL to identify different attributes of people [27]. In the above studies, each task has a same priority. While some studies focused on a single task, whereas others acted as auxiliary tasks. Auxiliary tasks, which provide additional information, provide valid information from aspects of the main task. Zhang et al. used head pose estimation and attribute prediction of faces as auxiliary tasks and facial landmark detection as the main task; it was found that the precision of this method was higher than that of other methods [28]. Our framework belongs to this type of transfer learning method.

Weakly Supervised Learning
It is well-known that supervised learning-based models require a large amount of well-labeled data to obtain accurate predictions. In contrast, unsupervised learning-based models typically lack high precision, and the learning process is less effective than that of supervised learning methods [29]. Weakly supervised learning is a machine learning technique with the objective of learning effective feature representation from limited or low-quality annotations [30]. Various WSL-based methods have been explored in different fields. Hu et al. proposed a CNN-based WSL framework for the task of multimodal image registration and achieved STOA performance [31]. Wang et al. developed a WSL-based method for accurate automated segmentation of remote sensing data with a proposed U-Nets framework and obtained superior segmentation performance [32]. ResNeXt WSL is the most recent WSL method. It was used to pre-train images from the Instagram website, followed by fine-tuning on the ImageMet dataset [22]. SimCLR, an unsupervised learning method, was used in combination with WSL and achieved SOTA performance [23].

Image Classification
The objective of image classification is to classify an image or instance into categories [33]. With the rapid development and verification of CNNs, various CNN-based frameworks have been developed and used for image classification [34]. LeCun et al. first develop the CNN framework (LeNet) for document recognition [35]. AlexNet, which is an improvement of LeNet and surpassed traditional machine learning methods, won the Imagenet competition in 2012 [16]. Since then, different types of CNN architectures have been proposed, such as ResNet [36], ResNext [21], and InceptionNet [37]. In this study, we adopted the structure of the residual block, which is used in ResNet [36].

Image Reconstruction
Image reconstruction, which is a critical problem in medical imaging, is a technique for creating 2D or 3D images from sets of 1D projections [38]. The autoencoder is the most popular technique and has proved effective in image reconstruction of unlabeled images [39]. In this study, we designed a sub-network to perform image reconstruction using abundant features.

The Pipeline of the Proposed Framework
The proposed WSL-based DL framework, which is called ADGNET, is a CNN-based single-input-multi-output (SIMO) architecture consisting of two components: an improved backbone network with the attention mechanism and task nets that consists of two subnetworks, i.e., the classification sub-network (CSN) and the reconstruction sub-network (RSN). The backbone network has a residual network structure with the proposed AM to obtain highly discriminative representations while suppressing unrelated regions in the images. As shown in Figure 1, the backbone network consists of five convolution stages (C1-Attention to C5-Attention), followed by the Resnet. Generally, feature maps generated by the deeper stages contain more semantic information, and those generated by the shallower stages contain more detailed information, such as edges and corners. The backbone network extracts features step-by-step from the MRI input images and generates a pooling map using global average pooling (GAP). The task nets use the pooling map as input and flatten it as a feature vector V f . The V f is then sent to two different task branches; one generates the prediction vector V p using the FC layer, and the other reconstructs the original images using the FC layers and a decoding module. As shown in Figure 1, the V p is sent to the argmax (Amax), which returns the index with the largest value of the axes of the V p and provides the classification results. Here C denotes the number of categories. Figure 1. The proposed framework. The framework consists of two parts, including a backbone network with an attention mechanism that acts as a shared network for extracting salient features and task nets that contain two sub-networks. The two sub-networks simultaneously conduct two sub-tasks, i.e., classification and reconstruction.

Backbone Network with the Proposed Attention Module
The backbone network is a multi-stage convolution network that follows the Resnet to avoid the gradient vanishing problem. Different stages in the backbone network generate feature maps with different resolutions. Notably, the backbone network is a convolution network shared by the two sub-task nets, providing a parameter-efficient and time-efficient method. In the reconstruction task, the backbone network can be regarded as part of an encoder network that automatically learns the feature representation from the images without annotation information. In the classification task, the backbone network can be considered a feature extractor, which is optimized using the supervised learning principle.
As described in Section 3.1, each stage of the backbone network generates the attention feature maps after implementing the proposed AM. As shown in Figure 2, the input feature maps F s i with a size of H × W × C and a scale of s at layer i are the output of the ith stage of Resnet. Given the F s i as input, the AM outputs the channel attention factors ( CAF) with a size of 1 × 1 × C so that the network can automatically determine the importance of the extracted features.This process can also be regarded as a feature filtering and selection method that improves the discrimination ability of the network at low computational cost. The CAF are then fused with the F s i using the element-wise multiplication operation (EWMO), and the fused feature map F s i is the output. The feature extraction process of the backbone network can be expressed as follows: where AM is the proposed attention module, and ⊗ is the EWMO. The AM is an automatic weighting module that learns the channel weights of the input feature maps. As shown in Figure 2, the input F s i is first downsampled using the global max-pooling operation to retain important information while reducing the computational cost. The downsampled feature maps are then flattened into a one-dimensional vector. The flattened vector is then sent to a convolution layer with a 1 × 1 kernel to extract features from the vector and adjust its dimension. The extracted features are sent to the batch norm (BN) layer and activated using the ReLU function to speed up the training and convergence speed of the module and increase its nonlinear representation ability. After these operations, an FC layer (Linear) with ReLU as the activation function is adopted to output the CAF with a size of 1 × 1 × C. The final output is generated using the sigmoid function to convert the values of the CAF to a range of 0 to 1; these values can be considered the importance scores of the channels in the F s i .

Task Sub-Networks
As shown in Figure 1, the task sub-networks consist of the CSN and the RSN. The input to the two parts is the flattened vector (V f ). The two tasks are performed in parallel. The CSN is a simple FC layer called FC p that generates a vector with a size of 1 × C. The vector is then sent to the sigmoid function to generate the prediction vector (V p ) with a range of 0 to 1 that is used as a probability of prediction of the specified classes. The process of the CSN can be formulated as follows: where FC is the FC layer FC p ; σ is the sigmoid function. The RSN consists of two components, i.e., the encoder and the decoder. The encoder is constructed using the backbone network and two FC layers called FC e1 and FC e2 . The backbone network extracts and abstracts the features step-by-step, and the two FC layers encode the features into a vector (V e ). The decoder component is a multi-layer transposed convolution network. The details of the decoder component are shown in Figure 3. The input of the decoder component is the V e after the reshape operation, which converts the V e to a two-dimensional feature map. The decoder is a modular network consisting of two parts with multiple transposed convolution layers. Each transposed convolution layer in the decoder has multiple transposed convolution kernels with a size of 3 × 3, a stride of 2, and a padding of 1. As shown in Figure 3, there are M× transposed convolution layers and ReLU layers in the first part, which decode the input feature maps and up-sample the input. The second part of the encoder contains a convolution layer and a Tanh layer. The convolution layer is used for dimension normalization to convert the feature maps to the same size as the input MRI image. The Tanh layer is used to output the predicted MRI image since it has a wide range of predicted values, improving the prediction accuracy. The RSN process can be formulated as follows: where FC e1 and FC e2 are the FC layers; tanh is the tanh function. Dec is the proposed decoder part and Im r is the reconstructed image.

Loss Function
The proposed ADGNET can be trained in an end-to-end manner. The loss function of the framework is composed of two parts and is defined as follows: where L cls and L rec are the classification loss and the image reconstruction loss, respectively. λ 1 and the λ 2 are the weighting factors which balance the two losses.
As we described in the introduction, it is difficult to distinguish dementia in the early and middle stages from normal aging because of the small differences in brain imaging. Thus, we adopted the focus idea, as introduced in previous works [40,41], to ensure that the framework focuses primarily on difficult and misclassified samples. We proposed a new loss function based on the cross-entropy loss. The modified classification loss is defined as follows: where N is the number of samples participating in a single optimization. γ i represents the class-wise weight reduction factors, which adjust the importance of different samples for an improved representation. i is the ground truth probability of a target belonging to a given class, and i is the prediction probability of the target belonging to the given class.
The new loss function is a modification based on the cross-entropy (CE) loss. As shown in Figure 1, the final probability of a sample be a specified category is generated using the sigmoid function which range from 0 to 1. Therefore, the equation 5 demonstrated the loss for each specified category following the formulation of the CE loss, and the class-wise weight reduction factors γ i can be considered as a numerical vector. In our manuscript, the values of γ i were all set as 2 as default.
We used the mean square error function as the reconstruction loss; it is defined as follows: where N is the number of samples participating in a single optimization. Y i is the ground truth value of the input sample, andŶ i is the prediction value of the reconstructed sample.

Multi-Modal Brain Imaging Dataset
In this study, the proposed ADGNET was evaluated on two different brain imaging datasets for a comprehensive assessment. The two brain imaging datasets are the Kaggle Alzheimer's classification dataset (KACD) [42] and the Recognition of Alzheimer's Disease dataset (ROAD) [43]. The example data of the two datasets are shown in Figure 4. Each dataset was divided into two parts: a train-val part and a test part using the train-testsplit function (TTSF) from the scikit-learn library. The details of the split are shown in Table 1. The KACD dataset contains 6400 2D MRI images from 6400 cases, and each case is assigned into one of four categories: Non-Demented, Very Mild Demented, Mildly Demented, and Moderately Demented. The ROAD contains 532 3D MRI images from 532 cases, and each case is assigned into one of three categories: Non-Demented, Mildly Demented, and Alzheimer's disease. As shown in Table 1, the data set was separated into two parts, including a training-val set (TVS) for training and selection of model weights and an independent test set (TS) to evaluate the performance of the models. The TVS of the KACD contains 5121 2D MRI images, and the TS of the KACD contains 1279 2D MRI images. The TVS of the ROAD contains 300 3D MRI images, and the TS of the ROAD contains 232 3D MRI images.

Evaluation Metrics
The Kappa score (Kappa), sensitivity (Sen), specificity (Spe), precision (Pr), accuracy (Acc) and F1-score metrics were used to evaluate the performance of the proposed ADGNET comprehensively. The equations of the six metrics are as follows: Given the definitions of pe and p0, the Kappa score is defined as follows: The sensitivity is defined as: The specificity is defined as: The precision is defined as: The accuracy is defined as: The F1-score is defined as: where TP represents the true positive, TN represents the true negative, FP represents the false positive, and FN represents the false negative. Six evaluation metrics (Kappa, Sen, Spe, Pr, Acc and F1-score) were employed to evaluate the performance of the proposed ADGNET and other SOTA WSL-based methods. The Kappa is a statistical indicator of the stability of the model prediction. The Sen is related to the positive prediction rate and is a significant indicator in medical diagnosis. The Spe indicates the correctness of the model's prediction and also has great significance in medical diagnosis. The Pr refers to the ability of the model to provide a positive prediction. The Acc is an indicator of the correctness of the model's prediction. The F1-score is the harmonic mean of the Pr and Sen.

Experimental Results
The performance of the proposed ADGNET was evaluated using multimodal datasets (2D MRI images and 3D MRI images) to assess the generalization ability and transferability of the model with six metrics (Kappa, Sen, Spe, Pr, Acc and F1-score). Two sets of experiments were conducted (experiments A and B) to evaluate the performance of the proposed ADGNET. A bootstrapping method was used to calculate the empirical distributions of the boxplots. All experiments were conducted on the KACD dataset and ROAD dataset. An ablation study was also conducted to better demonstrate the effectiveness of the proposed framework. As shown in Tables 2 and 3, the ADGNET (proposed) means the proposed framework as demonstrated in Figure 1; the ADGNET (no RSN) means the Subnet2:RSN as shown in Figure 1 was excluded while the rest of the proposed framework are retained and trained with the same amount of annotations; the ADGNET (no AM) means the Attention Mechanism as shown in Figure 2 was excluded while the rest of the proposed framework are retained and trained with the same amount of annotations. The training and inference processes were performed on four Nvidia GTX 2080Ti GPUs and Intel Xeon E5-2600 v4 3.60 GHz CPU using the Pytorch framework. Table 2. Performance indices of the proposed ADGNET framework of the experiment A and the average performance of the two state-of-the-art (SOTA) models on the KACD dataset.  Table 3. Performance indices of the proposed ADGNET framework of the experiment B and the average performance of the two SOTA models on the ROAD dataset. In this experiment, we used the 2D MRI images from the KACD dataset to evaluate the models' performances. The overall results of the six metrics for the proposed ADGNET, the ResNeXt WSL, and the SimCLR are listed in Table 2. The optimum performance was obtained by ADGNET, with an F1-score of 99.61%, followed by SimCLR (98.67%) and ResNeXt WSL (98.37%). The Acc was highest for ADGNET (99.61%), followed by SimCLR (98.60%) and ResNeXt WSL (98.36%). The Pr, Spe, Sen and Kappa of ADGNET were 99.53%, 99.53%, 99.69% and 99.22%, respectively. The values of the indices were higher than those of ResNeXt WSL (97.84%, 97.81%, 98.91% and 96.72%) and SimCLR (98.60%, 98.59%, 98.75% and 97.34%). The corresponding boxplots of the six evaluation metrics (Kappa, Sen, Spe, Pr, Acc and F1-score) of the models' performance on the KACD dataset are shown in Figure 5. In this experiment, we used the 3D MRI images from the ROAD dataset to evaluate the models' performance. The overall result of the six metrics for the proposed ADGNET, the ResNeXt WSL, and the SimCLR are listed in Table 3. The best performance was obtained by the proposed ADGNET, with an F1-score of 98.49%, followed by SimCLR (93.00%) and ResNeXt WSL (92.61%). The Acc was highest for ADGNET (98.71%), followed by ResNeXt WSL (93.53%) and SimCLR (93.00%). The Pr, Spe, Sen and Kappa of ADGNET were 98.99%, 99.24%, 98.00% and 97.36%, respectively. The values of the indices were higher than those of ResNeXt WSL (91.26%, 93.18%, 94.00% and 86.87%) and SimCLR (93.00%, 94.70%, 93.00% and 87.70%). The corresponding boxplots of the six evaluation metrics (Kappa, Sen, Spe, Pr, Acc and F1-score) of the models' performance on the ROAD dataset are shown in Figure 6.

Training Details
The TVS of each dataset was split into 5 parts using a stratified sampling method. The model was trained using 20% of the labels. A 5-fold cross-validation was adopted to evaluate the performance of the trained model. The samples in the TS of each dataset were used to verify the performance of the proposed ADGNET.

Discussion
We developed a CAD method for the identification and classification of AD in multimodal brain imaging data (2D MRI and 3D MRI) using WSL-based DL techniques. Excellent performance was obtained by the proposed ADGNET based on six evaluation metrics, and the method proved superior to two SOTA WSL-based methods. The proposed ADGNET is a modular framework consisting of a backbone and the task subnets. We used a residual block in the design of the backbone network to retain the features while preventing degradation of the framework. We incorporated an AM into the backbone network to ensure the high discriminatory ability of the backbone network with a low computational cost. Unlike the conventional methods like Resnext WSL which assign the same weight to each channel, the proposed AM learned the channel weights of the input feature maps from the supervised information and the images for an improved feature representation of the samples. This helps the framework to focus more on the most discriminative part from the feature space. The feature vector obtained from the pooling map of the backbone network was flattened and sent to the two sub-networks. The CSN extracted the feature information directly from the vector and was optimized using the supervised information. The RSN encoded the vector to a new feature space using two FC layers and used a decoder to reconstruct the input images. The two FC layers and the backbone network comprised the encoder that was used for feature coding. The proposed decoder network consisted of the transposed convolution layer and the convolution layer. The objective of the transposed convolution layer was to learn the feature information for image reconstruction, and the convolution layer was used to adjust the number of channels and generate the final output. Unlike previous WSL-based methods (e.g., ResNeXt WSL and SimCLR), which have to be pretrained on a large independent dataset and fine-tuned on the target dataset, the proposed ADGNET is trained in an end-to-end manner. The training process of the ADGNET is controlled by adjusting the weighting parameters. When λ 1 is zero, the network only learns the features from the images. When λ 2 is zero, the network only learns the features from the annotations. In this way, the large-scale unlabeled data can be fully utilized to help the proposed framework obtain more stable and representative features. Excellent performance was achieved by ADGNET based on the six statistical metrics for the multi-modal brain imaging datasets (KACD (2D MRI) and ROAD (3D MRI)). In order to intuitively analyze the experimental results, the heatmaps of the proposed ADGNET and the two SOTA models were demonstrated in Figure 7. As can be seen from Figure 7, the proposed ADGNET is able to capture key features while retaining more features by means of the AM and the RSN. While the ResNeXt WSL and SimCLR only use very limited features for prediction and their prediction scores are relatively low. Notably, the ADGNET's prediction score is quite higher than ResNeXt WSL and SimCLR, which indicated that the ADGNET's prediction is more reliable. The ADGNET has a promising potential as an auxiliary tool to assist in the diagnosis of AD due to its high performance, good stability, and cross-modal flexibility. Besides, medical diagnosis in a real situation is much more complex than in experimental environments, and sufficient and high-quality annotations are difficult to obtain. Therefore, the development of our proposed WSL-based DL methods is crucial for diagnosing conditions such as AD. However, the proposed ADGNET may also encounter some problems when the distribution of the data has extremely category imbalance. It is also important to develop effective generative frameworks to generate a large amount of effective data for compensation with WSL-based methods. Figure 7. Visualization of heatmaps. We compare the visualization heatmaps of the ADGNET (proposed, no reconstruction sub-network (RSN) and no attention module (AM)), the ResNeXt WSL (weakly supervised learning), and the SimCLR. The heatmap visualization is calculated for the last convolutional outputs and P denotes the prediction score of each network for the ground-truth category.

Conclusions
This study presented a unique WSL-based DL framework for the identification and classification of AD using multi-modal brain imaging data (2D MRI and 3D MRI). The proposed ADGNET provided excellent performances on six metrics (Kappa, Sen, Spe, Pr, Acc and F1-score), outperforming the two SOTA WSL-based models on two public datasets (KACD (2D MRI) and ROAD (3D MRI)) using limited annotation (only 20% of the labels). Most notably, the Kappa of the ADGNET was 0.9922 on the KACD dataset and 0.9736 on the ROAD dataset. These values were 2.50% and 1.88% higher than those of the two SOTA methods on the KACD dataset and 10.49% and 9.66% higher on the ROAD dataset, respectively.
The excellent performance achieved by ADGNET indicates that the proposed AM and the framework are suitable for the task and that the model is superior to the two SOTA WSL-based methods. The proposed AM module enabled the ADGNET to automatically assign different weights for different channels in the feature maps for a better capture of discriminative features. It is well-known that obtaining a large sample size and highquality annotations of medical images is time-consuming and expensive. The introduction of sub-network for image reconstruction task help the ADGNET acquire effective features mining from large scale unlabeled data. Therefore, the development of WSL-based DL methods might represent a potential research direction to achieve accurate mining of massive medical data. In the future, the potential of this framework will be explored in-depth for other challenging tasks, including the detection of brain tumors and other major diseases.
Author Contributions: All authors contributed extensively to the study presented in this manuscript. S.L. and Y.G. contributed significantly to the conception of the study. S.L. designed the network and conduct the experiments. S.L. and Y.G. provided, marked, and analyzed the experimental results. Y.G. supervised the work and contributed with valuable discussions and scientific advice. All authors contributed in writing this manuscript. All authors have read and agreed to the published version of the manuscript.

Conflicts of Interest:
The authors declare no conflict of interest.

Abbreviations
The following abbreviations are used in this manuscript: