Multi-Label Diagnosis of Arrhythmias Based on a Modified Two-Category Cross-Entropy Loss Function

The 12-lead resting electrocardiogram (ECG) is commonly used in hospitals to assess heart health. An ECG can reflect a variety of cardiac abnormalities at once, which calls for multi-label classification. However, the diagnostic results of previous studies have been imprecise; for example, cardiac abnormalities that cannot coexist often appeared together in their diagnostic results. In this work, we explore how to achieve effective multi-label diagnosis of ECG signals while preventing the prediction of cardiac arrhythmias that cannot coexist. We present a multi-label classification method based on a convolutional neural network (CNN), long short-term memory (LSTM), and an attention mechanism for the multi-label diagnosis of cardiac arrhythmias from resting ECGs. In addition, we propose a modified two-category cross-entropy loss function that introduces a regularization term to penalize the co-occurrence of arrhythmias that cannot coexist. The effectiveness of the modified loss function is validated on a 12-lead resting ECG database collected by our team. Using the traditional and modified cross-entropy loss functions, three deep learning methods are employed to classify six types of ECG signals. Experimental results show that the modified cross-entropy loss function greatly reduces the number of non-coexisting label pairs while largely maintaining prediction accuracy. Deep learning methods are effective for the multi-label diagnosis of ECG signals, and diagnostic efficiency can be improved by using the modified cross-entropy loss function. In addition, the modified loss function helps prevent diagnostic models from outputting two arrhythmias that cannot coexist, further reducing the false positive rate for non-coexisting arrhythmias and demonstrating the potential value of the modified loss function in clinical applications.


Introduction
Cardiovascular disease (CVD) is one of the leading causes of death, accounting for over 31% of deaths worldwide [1]. There are many types of cardiovascular diseases, and their impact on human health varies. Determining the type of CVD plays an important role in subsequent treatment. In the clinic, one of the most commonly used methods to diagnose CVD is the resting electrocardiogram (ECG). Medical personnel place electrodes at fixed positions on the resting patient, acquire a high-quality 10 s ECG, and make a diagnosis based on the ECG waveform. By some estimates, there are more than 100 kinds of cardiovascular diseases, and ECG interpretation depends on the diagnostic experience of medical professionals. Therefore, it is very important to develop ECG-based diagnostic tools.
Most early ECG diagnostic tools were built by imitating the logical reasoning of physicians. Geddes et al. [1] proposed classifying various premature ventricular contractions (PVCs) using rule-based reasoning. First, detection parameters were selected according to the ECG characteristics of PVCs, such as the R-R interval and the duration and shape of the QRS complex. Then, medical rules were used as criteria for assessing the occurrence of PVCs. Kezdi et al. [2] proposed an algorithm for detecting ectopic beats and arrhythmias based on clinical experience. The R-wave was located by calculating the slope of the QRS complex. Supraventricular tachycardia and ventricular ectopy were detected from changes in the R-R interval and in the width, polarity, and height of the QRS complex. The parameters selected by these methods are clinically interpretable. However, feature extraction other than R-wave detection is not accurate enough because ECG signals are highly individual and nonlinear, especially across different types of arrhythmias. Since different types of ECG signals have different time-frequency characteristics, large errors easily arise when the feature parameters are calculated, causing this type of method to fail.
Another type of method is pattern recognition. First, statistical features are extracted, and then a machine learning (ML) classifier is trained to distinguish different types of arrhythmias. In many studies, time-domain/morphological statistics [3][4][5][6][7], spectral features [8,9], and higher-order statistical parameters [10][11][12][13] have been used to detect ventricular arrhythmias among malignant arrhythmias. These features, combined with classifiers such as artificial neural networks (ANNs) or support vector machines (SVMs) [14][15][16], can efficiently identify rhythms such as ventricular fibrillation and ventricular tachycardia. The two-step pipeline of pattern recognition (feature extraction followed by classification) aids in diagnosing cardiac arrhythmias, and its detection accuracy and efficiency are better than those of methods that imitate physicians' reasoning. The disadvantage is that the signal features are hand-crafted; more precisely, their quality often depends on human experience. Because there are so many types of arrhythmias, it is difficult to find effective statistical features for all of them.
In recent years, with the development of deep learning, researchers have begun to use deep learning instead of artificial feature extraction methods [17] to evaluate ECG signals. "Artificial feature extraction methods" refers to methods that calculate features of electrocardiogram signals from different perspectives (such as the time domain, frequency domain, and time-frequency domain) for arrhythmia classification; the selection of these features is based on subjective personal experience. Feng et al. [18] employed dynamic time warping (DTW), C-means clustering, and the backpropagation (BP) algorithm to optimize the parameters of a probabilistic process neural network (PPNN). The method achieved an F1 score of 0.7615 and an accuracy of 74.16% on the Chinese Cardiovascular Disease Database (CCDD). Although PPNN offers advantages such as few-shot learning and low computational complexity, its limited number of parameters restricts its classification performance. Yıldırım et al. [19] proposed a new one-dimensional convolutional neural network (1D CNN) model to classify 17 types of cardiac arrhythmias. Its accuracy and F1 score on the MIT-BIH arrhythmia database were 91.33% and 0.8538, respectively, and the model demonstrated efficient and rapid diagnostic capabilities. Luo et al. [20] used the same database and proposed a hybrid convolutional recurrent neural network (HCRNet), achieving an accuracy of 99.01%. However, the MIT-BIH data were derived from a limited pool of patients, and the ECG signals exhibited highly personalized characteristics. Thus, a model with high accuracy on this database may not generalize well across different patients. Yao et al. [21] proposed the ATI-CNN model to address the poor performance of CNNs on variable-length ECG signals. This model integrated a CNN, recurrent cells, and an attention module. On the China Physiological Signal Challenge (CPSC) dataset, ATI-CNN achieved an F1 score of 0.812 and a precision of 0.826. By combining the spatiotemporal features of ECG signals, ATI-CNN improved accuracy while reducing the number of model parameters, thereby lowering training costs. However, this model did not consider the one-to-many relationship between patients and arrhythmia labels. Overall, deep learning methods learn features from large amounts of data to classify ECG signals, and they are likely to be the future direction of intelligent ECG diagnosis.
In ECG signals, some arrhythmias can occur simultaneously, whereas others cannot. For example, during a period of sustained atrial fibrillation, PVCs can occur, but premature atrial contractions cannot. The relationships between the various labels are complex, which makes multi-label classification of ECG signals challenging [22][23][24]. Yoo et al. [25] approached the problem from the perspective of multi-label arrhythmia classification and proposed xECGNet. By incorporating the L2 norm of the attention maps of different disease categories into the loss function, xECGNet achieved a multi-label subset accuracy of 84.6% in classifying eight types of arrhythmias on the CPSC dataset. Yang et al. [26] proposed a stacking approach that combines the classification results of ResNet and random forest and obtains the final result through voting. Although this method improved accuracy to 95%, integrating multiple models increased deployment costs, making it difficult to apply to general medical embedded devices. Current methods emphasize learning the relationships between labels from the training data (the labels themselves). However, because the relationships between ECG labels are complex, it is difficult to learn them from the data alone. As a result, diagnostic models output some arrhythmias that cannot coexist, which increases the misdiagnosis rate of multi-label ECG diagnostic algorithms [27].
In this work, we propose a multi-label diagnostic method based on a modified two-category cross-entropy loss function. The method first incorporates LSTM and attention mechanisms to enhance the classification accuracy of the CNN model. Building on this, and to address the fact that certain diagnostic conclusions cannot coexist, we add a regularization term to the traditional binary cross-entropy loss function that penalizes the co-occurrence of specified arrhythmia label pairs. The regularization term constrains the network's learning direction so that it takes the mutually exclusive relationships between disease labels into account. This improves the applicability of the ECG diagnostic algorithm in real diagnostic scenarios.
The main contributions of this article are as follows: (A) a new multi-label training loss function is proposed by adding a regularization term that penalizes the coexistence of certain arrhythmias; (B) a CNN + LSTM + ATTENTION architecture is presented to improve ECG classification performance; and (C) more than 10,000 ECG recordings of the six most common cardiac arrhythmias are used to test the loss function and classification method, and performance is compared across patients. Our method improves the classification accuracy for four ECG classes (normal, sinus tachycardia, atrial flutter, and atrial tachycardia) and reduces false-positive misdiagnoses of atrial flutter and atrial tachycardia.
This paper is organized as follows. Section 2 explains the CNN + LSTM + ATTENTION architecture and the modified cross-entropy loss function. Section 3 describes the new ECG database. Section 4 analyzes the modified cross-entropy loss function and compares it with other methods. Section 5 discusses the presented method further and outlines future research topics. Section 6 presents the conclusions of this paper.

Deep Learning Model
In this work, a deep learning model, consisting of a convolutional neural network (CNN) [28], long short-term memory (LSTM) [29], and an attention mechanism [30] is used to classify ECG signals.

Feature Extraction
A CNN is used for feature extraction, as shown in Figure 1. For the convolution operation, let $z_j^l$ denote the output of the $j$-th channel of the $l$-th convolutional layer and $o_j^l$ the corresponding input. The input $o_j^l$ and output $z_j^l$ of the $l$-th layer are given by Equations (1) and (2), respectively.
Here, $f(\cdot)$ is the activation function, $M_j$ is the subset of feature maps of the $(l-1)$-th layer, $k_{ij}^l$ is the convolution kernel matrix, $b_j^l$ is the bias, and $*$ denotes convolution. For the pooling operation, $\alpha$ is the sampling coefficient, and the pooling function is maximum pooling. The input $o_j^{l+1}$ and output $z_j^{l+1}$ of the $(l+1)$-th layer are given by Equations (3) and (4), respectively.
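To make Equations (1) to (4) concrete, the sketch below implements one 1D convolutional layer and one max-pooling step in NumPy. It assumes the common form $o_j^l = \sum_{i \in M_j} z_i^{l-1} * k_{ij}^l + b_j^l$, $z_j^l = f(o_j^l)$, with $M_j$ taken as all input channels, ReLU as $f$, stride 1, and no padding. The sampling coefficient $\alpha$ is omitted, which is equivalent to taking it as 1. The function names and filter sizes are hypothetical.

```python
import numpy as np


def conv1d_layer(z_prev, kernels, bias, f=lambda x: np.maximum(x, 0.0)):
    """One 1D convolutional layer in the assumed form of Equations (1)-(2).

    z_prev:  (C_in, L) feature maps z^{l-1}; M_j is taken as all C_in channels
    kernels: (C_out, C_in, K) kernels k_ij^l
    bias:    (C_out,) biases b_j^l
    Returns z^l with shape (C_out, L - K + 1) (stride 1, no padding).
    """
    c_out, c_in, k = kernels.shape
    out_len = z_prev.shape[1] - k + 1
    o = np.zeros((c_out, out_len))
    for j in range(c_out):
        for i in range(c_in):
            # Deep learning libraries implement "convolution" as cross-correlation
            o[j] += np.correlate(z_prev[i], kernels[j, i], mode="valid")
        o[j] += bias[j]
    return f(o)


def max_pool1d(z, size=2, stride=2):
    """Max pooling along the time axis of z with shape (C, L)."""
    n_out = (z.shape[1] - size) // stride + 1
    return np.stack(
        [z[:, t * stride : t * stride + size].max(axis=1) for t in range(n_out)], axis=1
    )


rng = np.random.default_rng(0)
x = rng.standard_normal((12, 5000))          # 12 leads x 10 s at 500 Hz
k = rng.standard_normal((32, 12, 7)) * 0.1   # 32 hypothetical filters of width 7
feat = max_pool1d(conv1d_layer(x, k, np.zeros(32)))
print(feat.shape)  # (32, 2497)
```

Each stride-2 pooling step roughly halves the time axis. Stacking such blocks is what shrinks the 5000-sample input to the few dozen time steps passed on to the LSTM.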

LSTM
The features $Z \in \mathbb{R}^{T \times D}$ obtained by the CNN are fed to the subsequent LSTM, where $T$ is the length of the feature sequence and $D$ is the number of features. The workflow is shown in Figure 2.
The internal state $C_t \in \mathbb{R}^S$ passed between LSTM units captures the relationships between the ECG features extracted by the CNN, where $S$ is the length of the vector output by the LSTM layer. $z_t$ is the $t$-th slice of the input feature sequence ($1 \le t \le T$), and $h_t \in \mathbb{R}^S$ is the hidden state of the LSTM layer corresponding to $z_t$. The output $h_t$ is computed from the forget, input, and output gates of the LSTM.
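The sketch below is a minimal NumPy version of the standard LSTM recurrence, with sigmoid gates and a tanh state update. It assumes the common formulation in which each gate acts on the concatenation $[h_{t-1}, z_t]$. The weight layout and function names are hypothetical, and the sizes follow the VGG16 variant described under Parameter Setting ($T = 44$, $D = 512$, $S = 60$).

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def lstm_forward(Z, W, b):
    """Run a standard LSTM over Z (T, D) and return all hidden states (T, S).

    W: dict with keys 'f', 'i', 'c', 'o', each an (S, S + D) matrix acting on [h_{t-1}, z_t]
    b: dict with the same keys, each an (S,) bias vector
    """
    T = Z.shape[0]
    S = b["f"].shape[0]
    h, C = np.zeros(S), np.zeros(S)
    H = np.zeros((T, S))
    for t in range(T):
        x = np.concatenate([h, Z[t]])
        f_t = sigmoid(W["f"] @ x + b["f"])        # forget gate
        i_t = sigmoid(W["i"] @ x + b["i"])        # input gate
        c_tilde = np.tanh(W["c"] @ x + b["c"])    # candidate state
        C = f_t * C + i_t * c_tilde               # internal state C_t
        o_t = sigmoid(W["o"] @ x + b["o"])        # output gate
        h = o_t * np.tanh(C)                      # hidden state h_t
        H[t] = h
    return H


T, D, S = 44, 512, 60
rng = np.random.default_rng(0)
W = {g: rng.standard_normal((S, S + D)) * 0.05 for g in "fico"}
b = {g: np.zeros(S) for g in "fico"}
H = lstm_forward(rng.standard_normal((T, D)), W, b)
print(H.shape)  # (44, 60)
```

Every $h_t$ is kept rather than only the last one, because the attention layer that follows weights all time steps.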

Attention Mechanism
The attention mechanism computes an attention distribution over the hidden states $h_t$ ($1 \le t \le T$) at all time steps, and the final output features are the weighted average of the hidden states under this distribution. Here, $Z \in \mathbb{R}^S$ denotes the weighted-average result, $\beta_i$ is the weight of hidden state $h_i$ ($1 \le i \le T$), $W_s$ and $b_s$ are trainable weights, $u_s$ is the query vector, and $u_i$ ($1 \le i \le T$) is the intermediate weighting factor in the calculation.
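The attention pooling can be sketched as follows. The sketch assumes the widely used additive form $u_i = \tanh(W_s h_i + b_s)$, $\beta_i = \exp(u_i^\top u_s) / \sum_j \exp(u_j^\top u_s)$, $Z = \sum_i \beta_i h_i$. This form is consistent with the parameter sizes listed for the attention layer ($W_s$: 60 × 60; $b_s$ and $u_s$: 60 × 1). The function name is hypothetical.

```python
import numpy as np


def attention_pool(H, W_s, b_s, u_s):
    """Additive attention over hidden states H (T, S); returns (Z, beta)."""
    U = np.tanh(H @ W_s.T + b_s)          # intermediate factors u_i, shape (T, S)
    scores = U @ u_s                      # similarity of each u_i to the query u_s, shape (T,)
    scores = scores - scores.max()        # subtract the max for numerical stability
    beta = np.exp(scores) / np.exp(scores).sum()   # attention weights, sum to 1
    Z = beta @ H                          # weighted average of hidden states, shape (S,)
    return Z, beta


rng = np.random.default_rng(0)
S = 60
H = np.tanh(rng.standard_normal((44, S)))
W_s = rng.standard_normal((S, S)) * 0.1
b_s = np.zeros(S)
u_s = rng.standard_normal(S) * 0.1
Z, beta = attention_pool(H, W_s, b_s, u_s)
```

Setting $u_s = 0$ makes all weights equal, so $Z$ reduces to the plain average of the hidden states. A trained query vector lets the model emphasize the time steps most relevant to the diagnosis.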

Fully Connected Layer
Finally, the features $Z$ obtained by the attention mechanism are fed to the fully connected layers, which perform the final classification and produce the prediction vector $z$. Here, $W_1$ and $W_2$ are the weight matrices of the two fully connected layers, $b_1$ and $b_2$ are the bias vectors, $z_1$ is the output of the first fully connected layer, and Sigmoid is the output activation function.
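The two fully connected layers can be sketched as below, using the sizes given under Parameter Setting: 64 ReLU units, then six sigmoid outputs. The 0.5 threshold for turning probabilities into labels is an assumption for illustration; the text does not state the decision threshold.

```python
import numpy as np


def classify(Z, W1, b1, W2, b2, threshold=0.5):
    z1 = np.maximum(W1 @ Z + b1, 0.0)            # first dense layer, ReLU, 64 units
    z = 1.0 / (1.0 + np.exp(-(W2 @ z1 + b2)))    # second dense layer, sigmoid, 6 outputs
    return z, (z >= threshold).astype(int)       # per-label probabilities and 0/1 labels


rng = np.random.default_rng(0)
Z = rng.standard_normal(60)
W1, b1 = rng.standard_normal((64, 60)) * 0.1, np.zeros(64)
W2, b2 = rng.standard_normal((6, 64)) * 0.1, np.zeros(6)
probs, labels = classify(Z, W1, b1, W2, b2)
```

Each output has its own sigmoid rather than sharing a softmax, so the six probabilities need not sum to 1. This allows several labels to be active at once, which is what makes the head multi-label.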

The Modified Cross-Entropy Loss Function
In multi-label classification, a two-category (binary) cross-entropy loss function is usually used to calculate the loss between the labels and the predicted outcomes. In this work, two cross-entropy loss functions are studied, given by Equations (16) and (17).
Here, $N$ is the number of arrhythmia types, $y_k$ is the $k$-th element of the true ECG label vector, and $a_k$ is the $k$-th element of the predicted ECG label vector. $M$ is the number of combinations of strongly negatively correlated arrhythmias that cannot coexist, $a_l$ is the $l$-th such combination, and $a_i$ and $a_j$ are the $i$-th and $j$-th elements of the predicted ECG label vector that belong to $a_l$. According to Equation (16), $loss_1$ is the traditional cross-entropy loss function for multi-label classification and is widely used in deep learning. However, it does not consider the correlations between different labels [31][32][33]. As a result, arrhythmias that almost never occur simultaneously can still appear together in the predicted results.
According to Equation (17), $loss_2$ is the modified cross-entropy loss function, obtained by introducing a regularization term. The regularization term increases the penalty for co-occurring predictions of arrhythmias that cannot coexist, which is expected to improve the prediction performance of deep learning models.
Specifically, when the model predicts both labels of a non-coexisting pair, the derivative behavior of the logarithm makes the regularization term grow rapidly, so the model continues to receive a strong training signal. Conversely, when no such pair is predicted, the regularization term tends toward 0 and the loss degenerates into the binary cross-entropy loss, which does not affect the prediction of other, coexisting disease labels.
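The two losses can be sketched in NumPy as below. Equation (16) is written in its usual mean-over-labels form. The exact form of the regularization term in Equation (17) is not given in the prose, so the penalty used here is a hypothetical stand-in: $-\frac{1}{M}\sum \log(1 - a_i a_j)$ over the non-coexisting pairs. It was chosen only because it shows the behavior described above: it stays near 0 when at most one label of a pair is predicted and grows without bound as both approach 1. The weight `lam` is also hypothetical.

```python
import numpy as np

EPS = 1e-7  # clip probabilities away from 0 and 1 so the logs stay finite


def loss_1(y, a):
    """Binary cross-entropy averaged over the N labels (Equation (16), mean form)."""
    a = np.clip(a, EPS, 1 - EPS)
    return float(-np.mean(y * np.log(a) + (1 - y) * np.log(1 - a)))


def loss_2(y, a, pairs, lam=1.0):
    """loss_1 plus a HYPOTHETICAL penalty on non-coexisting label pairs.

    pairs: list of (i, j) label indices that cannot coexist (M = len(pairs)).
    Assumed penalty: -(1/M) * sum(log(1 - a_i * a_j)).
    """
    a = np.clip(a, EPS, 1 - EPS)
    penalty = -np.mean([np.log(1 - a[i] * a[j]) for i, j in pairs])
    return loss_1(y, a) + lam * float(penalty)


# Label order: normal, sinus tachy., sinus brady., atrial flutter, atrial tachy., PVC
y = np.array([0, 0, 0, 1, 0, 0])
pairs = [(3, 4)]                                       # atrial flutter vs atrial tachycardia
a_ok = np.array([0.05, 0.1, 0.1, 0.9, 0.05, 0.1])      # flutter only
a_bad = np.array([0.05, 0.1, 0.1, 0.9, 0.9, 0.1])      # both labels of the pair predicted
print(loss_2(y, a_ok, pairs) - loss_1(y, a_ok))        # small penalty
print(loss_2(y, a_bad, pairs) - loss_1(y, a_bad))      # much larger penalty
```

With `lam=0` the modified loss reduces exactly to the binary cross-entropy, matching the degenerate case described in the text.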

ECG Database
This work is based on 12-lead ECG data collected by SID MEDICAL TECHNOLOGY CO., LTD from many hospitals in Shanghai. The ECG signals were acquired with the Inno-12 ECG acquisition workstation, shown in Figure 3. Each recording was 10 s long, and the sampling frequency was 500 Hz. The signals picked up by the electrode tabs were first amplified 400 times and then digitized, ensuring acquisition accuracy. To suppress power-line interference, a notch filter was implemented in the hardware circuit. Each ECG sample was then processed with a Butterworth bandpass filter (0.5–100 Hz) to remove high- and low-frequency noise.
A total of 39,069 ECG recordings were collected, covering six types: normal ECG, sinus tachycardia, sinus bradycardia, atrial flutter, atrial tachycardia, and premature ventricular contraction (PVC), as shown in Figure 4. Atrial flutter and atrial tachycardia cannot coexist within a 10 s ECG signal, so in this work they are treated as a non-coexisting arrhythmia label pair. All ECG data were labeled by two professional cardiologists; when they disagreed, the label was determined by a third, chief cardiologist. The ECG data were then divided into a training dataset (23,322), a validation dataset (2591), and a test dataset (13,156). The distribution of arrhythmias in the different datasets is shown in Table 1.
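As a quick check of the split sizes stated above, the three subsets add up to the full 39,069 recordings. That gives roughly a 59.7% / 6.6% / 33.7% split:

```python
splits = {"training": 23_322, "validation": 2_591, "test": 13_156}
total = sum(splits.values())
shares = {name: round(100 * n / total, 1) for name, n in splits.items()}
print(total)   # 39069
print(shares)  # {'training': 59.7, 'validation': 6.6, 'test': 33.7}
```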

Experimental Setup
In terms of hardware, all experiments were carried out on a Dell T5820 workstation with an Intel Core i9-10900X CPU, 64 GB of RAM, and two graphics cards (NVIDIA RTX 3060, 12 GB) sourced from Dell in Shanghai, China. In terms of software, all deep learning models were built with NumPy 1.19.5, TensorFlow 1.13.1, and Keras 2.2.4, installed on Ubuntu 20.04.

Parameter Setting
The Deep Learning Model
Three CNN models (i.e., 1D VGG16 [34], 1D ResNet34 [35], and 1D ResNet50 [35]) were used to test whether our proposed method improves the performance of CNN models, as shown in Figure 5. In our method, only one CNN model is used for feature extraction. The label '/2' in each sub-image means that the stride of the corresponding network layer is 2. The VGG16 model has 16 layers and follows the traditional approach of stacking convolutional layers; its structure is relatively deep but simple. In contrast to VGG16, the ResNet34 model has 34 layers and adds residual structures, which mitigate vanishing gradients during training through skip connections that add a block's input directly to its output. The ResNet50 model has 50 layers and uses bottleneck structures to reduce computational complexity and improve efficiency.
Three deep learning models (i.e., VGG16 + LSTM + ATTENTION, ResNet34 + LSTM + ATTENTION, and ResNet50 + LSTM + ATTENTION) were used to verify whether our proposed method improves the performance of CNN models. The corresponding parameter settings and network structures are shown in Table 2. The input size of all three models was 5000 × 12 (10 s sampled at 500 Hz on 12 leads), and the outputs of their CNN feature extractors had sizes 44 × 512, 22 × 512, and 22 × 2048, respectively. In the LSTM layer, an intermediate output of size 1 × 60 was generated at each time step. The 'sigmoid' activation function was used for the forget, input, and output gates, and 'tanh' was used to update the state $C_t$. The weight matrices were initialized with 'glorot_uniform'.
In the attention layer, the weight matrix $W_s$, bias $b_s$, and query vector $u_s$ had sizes 60 × 60, 60 × 1, and 60 × 1, respectively, and were also initialized with 'glorot_uniform'. The first dense layer had 64 neurons with the 'ReLU' activation function. The second dense layer had six neurons (one per disease) with the 'sigmoid' activation function. Both fully connected layers were initialized with 'glorot_uniform'.
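From the sizes above, the number of trainable parameters in the LSTM, attention, and dense layers can be worked out by hand. The sketch uses the standard count for a four-gate LSTM, $4\,(S(D+S)+S)$, which is also how Keras counts LSTM parameters. Convolutional layers are excluded, and the helper names are hypothetical.

```python
def lstm_params(input_dim, units):
    # Each of the 4 gates has input weights (units x input_dim),
    # recurrent weights (units x units), and a bias (units)
    return 4 * (units * (input_dim + units) + units)


def head_params(input_dim, units=60, dense=64, classes=6):
    lstm = lstm_params(input_dim, units)
    attention = units * units + units + units  # W_s (60 x 60), b_s (60), u_s (60)
    fc1 = units * dense + dense                # 60 -> 64 ReLU layer
    fc2 = dense * classes + classes            # 64 -> 6 sigmoid layer
    return lstm + attention + fc1 + fc2


print(head_params(512))   # VGG16 / ResNet34 features (D = 512): 145534
print(head_params(2048))  # ResNet50 features (D = 2048): 514174
```

Most of these parameters sit in the LSTM input weights. As a result, the 2048-dimensional ResNet50 features more than triple the size of the recurrent head.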

The Modified Cross-Entropy Loss Function
In the database created in this work, atrial flutter and atrial tachycardia are strongly negatively correlated. A correlation analysis of arrhythmias based on 200,000 ECG conclusions obtained from Shanghai Zhongshan Hospital found a Pearson correlation of −0.98 between atrial flutter and atrial tachycardia. Thus, the loss function used in this work can be expressed by Equation (18).
Here, $a_4$ and $a_5$ are the predicted probabilities of atrial flutter and atrial tachycardia, respectively, taken from the predicted ECG label vector. Figure 6a shows how the joint presence of atrial flutter and atrial tachycardia in the predicted outcomes affects the regularization term. The term tended toward 0 when only one or neither of the two labels appeared in the predicted outcomes, and it increased rapidly when the probabilities of both exceeded 0.5. Figure 6b shows how the regularization term affects the partial derivatives of the loss function with respect to the model's weight matrix. When two negatively correlated labels were predicted simultaneously, $\partial Loss/\partial W$ increased, which sped up weight updates during backpropagation. This allowed the model to quickly recognize the negative correlation between the labels and adjust its weights accordingly. Conversely, once the two labels no longer co-occurred, the weight updates slowed and training tended toward stability.
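To illustrate the behavior described for Figure 6, the sketch below evaluates a regularizer of the hypothetical form $R(a_4, a_5) = -\log(1 - a_4 a_5)$ and its partial derivative $\partial R/\partial a_4 = a_5/(1 - a_4 a_5)$. Equation (18) is not spelled out in the prose, so this form is an assumption, not the authors' exact term. It merely shares the described properties: it stays near 0 when at most one label is predicted, and its gradient grows sharply as both probabilities approach 1.

```python
import numpy as np


def reg(a4, a5):
    """HYPOTHETICAL pair penalty; a4, a5 = predicted P(atrial flutter), P(atrial tachycardia)."""
    return -np.log(1.0 - a4 * a5)


def reg_grad_a4(a4, a5):
    """Analytic partial derivative of reg with respect to a4."""
    return a5 / (1.0 - a4 * a5)


for a4, a5 in [(0.9, 0.01), (0.5, 0.5), (0.9, 0.9), (0.99, 0.99)]:
    print(f"a4={a4:.2f} a5={a5:.2f}  R={reg(a4, a5):.3f}  dR/da4={reg_grad_a4(a4, a5):.2f}")
```

Under this assumed form, the gradient pushing $a_4$ down is almost zero when atrial tachycardia is absent. It becomes large only when both labels are confidently predicted, which is the selective pressure the text attributes to the regularization term.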

Evaluation Indicators
In this section, six evaluation indicators are examined to assess the performance of the presented models. The six evaluation indicators are (1) Error Num, (2) Hamming Loss, (3) …
… 94.74%, 93.64%, and 93.52%, respectively. This indicates that a CNN with a suitable structure is effective for multi-label ECG classification. After LSTM + ATTENTION was added, the F1 scores of the three methods were 95.21%, 93.98%, and 94.16%, respectively, showing that adding LSTM + ATTENTION improves or at least maintains the prediction performance of a CNN. In this section, the traditional cross-entropy loss function (see Equation (16)) is used to train the deep learning models listed in Table 2.
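The indicators named above can be computed as in the sketch below. The source does not spell out their definitions, so these are assumptions. 'Error Num' is taken to count predictions that contain both labels of a non-coexisting pair. Precision, recall, and F1 are micro-averaged over all labels. The function name is hypothetical.

```python
import numpy as np


def multilabel_metrics(Y, P, pairs):
    """Y, P: (n_samples, n_labels) 0/1 arrays of true and predicted labels.
    pairs: (i, j) label indices that cannot coexist."""
    Y, P = Y.astype(bool), P.astype(bool)
    tp = np.sum(Y & P)
    fp = np.sum(~Y & P)
    fn = np.sum(Y & ~P)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    inter = np.sum(Y & P, axis=1)
    union = np.sum(Y | P, axis=1)
    # Per-sample Jaccard index; a sample with no true and no predicted labels counts as 1
    jaccard = np.where(union > 0, inter / np.maximum(union, 1), 1.0)
    return {
        "error_num": int(sum(np.sum(P[:, i] & P[:, j]) for i, j in pairs)),
        "hamming_loss": float(np.mean(Y != P)),
        "subset_accuracy": float(np.mean(np.all(Y == P, axis=1))),
        "jaccard_index": float(np.mean(jaccard)),
        "precision": float(precision),
        "recall": float(recall),
        "f1": float(f1),
    }


Y = np.array([[0, 0, 0, 1, 0, 0], [1, 0, 0, 0, 0, 1]])
P = np.array([[0, 0, 0, 1, 1, 0], [1, 0, 0, 0, 0, 1]])   # sample 0 predicts flutter + tachycardia
print(multilabel_metrics(Y, P, pairs=[(3, 4)]))
```

In this toy case, the single spurious atrial tachycardia prediction is counted once by Error Num. It also lowers subset accuracy and precision, while recall is unaffected.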
VGG16 [34], ResNet34 [35], and ResNet50 [35] are among the most commonly used CNN models for ECG classification. In this work, VGG16, ResNet34, ResNet50, and their combinations with LSTM + ATTENTION were used to verify the performance of the improved loss function. Adam was chosen as the optimizer, the initial learning rate was set to 0.001, and the maximum number of training epochs was set to 100. The hyperparameters of each CNN model were based on parameters published in the literature [36,37], and the optimal training configuration was determined using the GridSearchCV algorithm [38]. To prevent overfitting, an early stopping mechanism based on the validation dataset was introduced: training was stopped if the validation loss of the multi-label model did not decrease for 10 consecutive epochs.
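The stopping rule above can be sketched in plain Python: training stops when the validation loss has not decreased for 10 consecutive epochs, with at most 100 epochs. In Keras, the equivalent built-in is `keras.callbacks.EarlyStopping(monitor='val_loss', patience=10)`. The function name is hypothetical, and the validation-loss sequence in the example is synthetic.

```python
def stopping_epoch(val_losses, patience=10, max_epochs=100):
    """Return the 1-based epoch at which training stops under patience-based early stopping."""
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses[:max_epochs], start=1):
        if loss < best:
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return min(len(val_losses), max_epochs)


# Synthetic curve: improves for 5 epochs, then plateaus
losses = [1.0, 0.9, 0.8, 0.7, 0.6] + [0.65] * 20
print(stopping_epoch(losses))  # 15
```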
The training dataset (23,322) and the validation dataset (2591) described in Section 3 were used to train the three deep learning models, with the validation dataset monitored for early stopping. The test dataset (13,156) was used to evaluate the trained models.
The training process using the traditional cross-entropy loss function is shown in Figure 7a, and that using the modified cross-entropy loss function is shown in Figure 7b. The experimental results for both loss functions are given in Table 4.
The early stopping mechanism halted training once the model entered a stable phase. Figure 7 shows that (1) with the traditional loss function, the model loss stopped decreasing after 68 training epochs, and (2) with the modified loss function, it stopped decreasing after 52 epochs. Thus, training required fewer epochs with the modified loss function.
In addition, Table 4 shows that (1) for both loss functions, VGG16 outperformed ResNet34 and ResNet50 on all evaluation metrics; (2) 'Error Num' was significantly reduced with the modified loss function; (3) 'Precision' increased slightly with the modified loss function; and (4) 'Subset Accuracy', 'Jaccard Index', 'Recall', and 'F1 score' decreased slightly with the modified loss function. We conclude that the modified loss function significantly reduces the number of coexisting strongly negatively correlated labels while largely maintaining model performance. It can therefore effectively prevent strongly negatively correlated arrhythmias from being predicted together in multi-label arrhythmia diagnosis.
Table 5 compares the classification accuracy for each arrhythmia under the two loss functions. With the modified loss function, accuracy improved slightly for normal ECG, sinus tachycardia, atrial flutter, and atrial tachycardia. However, the model's accuracy in classifying PVCs decreased by more than 1%. Furthermore, given the overall improvement in precision evident in Table 4 and Figure 8, we conclude that the modified two-category cross-entropy loss function significantly reduces the number of misdiagnoses of atrial tachycardia.

Discussion
This article proposes a multi-label diagnosis method for cardiac arrhythmias based on a modified two-category cross-entropy loss function. To validate the contribution of LSTM + ATTENTION, the classic neural networks VGG16, ResNet34, and ResNet50 are used for evaluation. The results show that adding LSTM + ATTENTION improves or maintains the prediction performance of the CNN.
Many types of diseases can be identified from ECG signals, and some of them cannot exist simultaneously. We compare the traditional loss function with our improved loss function across different CNN models. The results indicate that the traditional loss function still produces non-coexisting labels. In contrast, with the proposed modified loss function, the added regularization term strengthens the weight update rate between negatively correlated labels. This forces the CNN model to learn the relationships between non-coexisting labels and largely prevents non-coexisting label pairs from appearing in the diagnostic results. In addition, the improved loss function shortens the required training time of the model, which reduces training costs and enhances the feasibility of clinical applications.
To validate the classification performance of the modified loss function in diagnosing cardiac arrhythmias with neural network models, we compare the accuracy of the two loss functions on six types of ECG arrhythmias. The results indicate that our method improves the model's precision for the negatively correlated atrial tachycardia and atrial flutter labels. This means that it can reduce the risk of false positives in medical diagnosis, demonstrating the potential value of the improved loss function in clinical applications. However, our method shows decreased accuracy in identifying PVCs and sinus bradycardia. Our research currently focuses on six common types of cardiac arrhythmias; in the future, we will extend it to a broader range of cardiac arrhythmia datasets.

Conclusions
This work applies a CNN + LSTM + ATTENTION model to multi-label ECG classification. To prevent the occurrence of label pairs that cannot exist simultaneously, a modified cross-entropy loss function is proposed. The modified loss function introduces a regularization term that increases the penalty for the coexistence of arrhythmias exhibiting a strong negative correlation. Experimental results show that the modified loss function helps prevent the prediction of strongly negatively correlated arrhythmias at only a small cost in prediction accuracy. This work provides evidence supporting the use of multi-label ECG classification in clinical diagnosis.

Figure 1. Structure of the deep learning model for multi-label diagnosis of cardiac arrhythmias.

Figure 6. Changes in the regularization term: (a) the impact of negatively correlated labels on the regularization term; (b) the impact of the regularization term on the partial derivatives of the loss function with respect to the model's weight matrix.

Figure 7. Comparison of the traditional and modified loss functions: (a) training process using the traditional loss function; (b) training process using the modified loss function.
In the LSTM equations, $f_t$, $i_t$, and $o_t$ are the outputs of the forget gate, input gate, and output gate, respectively, and $W_f$, $W_i$, $W_o$, and $W_c$ are the weights of the forget gate, input gate, output gate, and LSTM cell state, respectively.
Figure 2. Structure of the LSTM and attention mechanism.

Table 1. Distribution of cardiac arrhythmias in different datasets.

Normal Sinus Tachycardia Sinus Bradycardia Atrial Flutter Atrial Tachycardia PVC
Note: atrial flutter and atrial tachycardia are non-coexisting arrhythmia disease label pairs.

Table 2. Parameter settings and network structures of the three deep learning models.

Table 3. Comparison of the performance of the three CNN models after adding LSTM + ATTENTION.

Table 4. Comparison between the modified and traditional loss functions.
Effectiveness of the Modified Cross-Entropy Loss Function
In this section, the modified cross-entropy loss function (see Equation (18)) is used to train the deep learning models listed in Table 2. The other parameter settings are the same as those used in Section 4.2.

Table 5. Comparison of the accuracy of 6 types of cardiac arrhythmias between the modified and the traditional loss functions.
Comparison of the precision of 6 types of cardiac arrhythmias between the modified and traditional loss functions.