SAG-DTA: Prediction of Drug–Target Affinity Using Self-Attention Graph Network

The prediction of drug–target affinity (DTA) is a crucial step for drug screening and discovery. In this study, a new graph-based prediction model named SAG-DTA (self-attention graph drug–target affinity) was implemented. Unlike previous graph-based methods, the proposed model utilized self-attention mechanisms on the drug molecular graph to obtain effective representations of drugs for DTA prediction. Features of each atom node in the molecular graph were weighted using an attention score before being aggregated as the molecule representation. Various self-attention scoring methods were compared in this study. In addition, two pooling architectures, namely, global and hierarchical architectures, were presented and evaluated on benchmark datasets. Results of comparative experiments on both regression and binary classification tasks showed that SAG-DTA was superior to previous sequence-based and other graph-based methods and exhibited good generalization ability.


Introduction
Developing a new drug that gains marketing approval is estimated to cost USD 2.6 billion, and the approval rate for drugs entering clinical development is less than 12% [1,2]. Such massive investments and high risks drive scientists to explore novel and more efficient approaches in drug discovery. Under such circumstances, computer-aided drug design methods, especially the recent deep learning-based approaches, have been rapidly developing and have made key contributions to the development of drugs that are in either clinical use or clinical trials. Among the broad range of drug design phases that computational approaches involve, the prediction of drug-target affinity (DTA) is one of the most important steps, as an accurate and efficient DTA prediction algorithm could effectively speed up the process of virtual screening of potential drug molecules, minimizing unnecessary biological and chemical experiments by refining the search space for potential drugs.
Computational approaches for DTA prediction generally comprise two major steps. First, features of drugs or proteins, or representations/descriptors as alternative expressions, are obtained from the raw input data by feature extraction methods. Compared to the original input data, the embedded representations are normally more applicable to the subsequent phase and can achieve better performance. The next step, as previously mentioned, is the classification/regression procedure, where the representations act as inputs and the network outputs either data labels (i.e., active or inactive) or specific affinity values. Some recent methods additionally embed the SMILES strings and amino acid sequences; this chemical context was then combined with graph-derived features for DTA prediction [25].
The performance of either the 1D or structure-based representation can be enhanced by introducing attention mechanisms. Attention mechanisms allow the network to focus on the most relevant parts of the input and have been proven useful for various tasks [19,26]. For instance, AttentionDTA added an additional attention block following the two branches of the drug and protein, and, therefore, the learned features could be further weighted according to the attention score before they were fed into the fully connected classifying layers [27]. Lim et al. proposed a distance-aware attention algorithm that could capture the most relevant intermolecular interactions within the 3D protein–ligand complex. Such attention mechanisms proved effective when applied to DTA prediction tasks with structural information of a complex [16]. Recently, Lee et al. proposed a novel attention structure named self-attention graph pooling (SAGPool) that introduced self-attention mechanisms for node pooling, and it achieved state-of-the-art performance in many graph learning tasks [28]. Inspired by this work, we implemented the SAG-DTA network in this study, which adopted a self-attention graph pooling approach for molecular graph representation. Two architectures, namely, global pooling and hierarchical pooling, were implemented and evaluated, with a detailed comparison of the pooling ratio and scoring method for each architecture.

Materials and Methods
SAG-DTA is an end-to-end prediction algorithm that takes the SMILES of drug molecules and the amino acid sequence of proteins as inputs and the affinity value, measured by either the dissociation constant or the KIBA (kinase inhibitors bioactivity data) score, as the output. SAG-DTA regards DTA prediction as a regression task, and training data of drug–target pairs are sent to the network, which then learns the intrinsic relationship between the input sample and the output affinity value. Building on GraphDTA, we implemented a more sophisticated graph representation of the drug molecule by introducing the self-attention pooling mechanism into the network. Specifically, the atom nodes were weighted by attention scores that were learned based on the features of the nodes themselves. Moreover, the atom nodes were also sorted according to the attention scores, and only those nodes with higher scores were kept. We hypothesized that such modification would allow the network to pay more attention to the most important parts and thus learn more complex and efficient feature representations for the prediction task. The overall architecture of SAG-DTA is presented in Figure 1. It can be seen that the SMILES of the drug molecule was used to build a molecular graph, and then the graph was sent to the GCN network with SAGPooling layers to learn drug features. For the protein, the amino acid sequence was sent to the CNN network to learn the protein representation.

Datasets
The proposed model was first evaluated on two benchmark datasets of DTA prediction, namely, the Davis [29] and KIBA [30] datasets. The Davis dataset contains selectivity assay data of the kinase protein family and the relevant inhibitors with their respective dissociation constant (K_d) values. The KIBA dataset is about four times the size of the Davis dataset regarding the number of interaction entries. Additionally, it differs from the Davis dataset in that the interaction value was recorded using the KIBA score, which was computed from a combination of heterogeneous information sources, i.e., IC_50, K_i, and K_d. The dataset is of high quality, as the integrated heterogeneous measurements mitigate the data inconsistency arising from the use of a single information source. For consistency with previous studies [12,13], the values were transformed into log space (pK_d) using Equation (1):

pK_d = −log_10( K_d / 10^9 ).  (1)
In addition to the DTA prediction datasets, the proposed model was also evaluated on two benchmark binary classification datasets of CPI prediction, namely, the BindingDB [23] and Human [31] datasets. The Human dataset includes positive CPI pairs derived from DrugBank [32] and Matador [33], and it is characterized by highly credible negative CPI samples. BindingDB is another well-designed CPI dataset derived from a public database [34], and it contains pre-processed training, validation, and test sets. Statistics of these four datasets are summarized in Table 1.

Input Representation
The datasets consist of numerous binding entities, and each entity comprises a drug molecule and target protein pair. The drug molecules were originally stored in the SMILES format, and they were converted to molecular graphs where the atoms and bonds were taken as the nodes and edges, respectively. Self-connection was considered so that the diagonal elements were set to 1. In this study, features for atoms were kept the same as those in GraphDTA and are listed in Table 2. The process was implemented using the RDKit tool (version: 2020.03.4) [35], as shown in Figure 2. For proteins, unique letters that represent categories of amino acids were extracted, and each letter was further represented by an integer. The protein sequences could thus be encoded using these integers, which is similar to the method of representation in DeepDTA [12].

Figure 2. The process of molecular graph construction.
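As a concrete illustration, the adjacency construction described above can be sketched in plain Python. This is a simplified stand-in for the RDKit pipeline: the atom symbols and bond list are hand-coded hypothetical inputs rather than parsed SMILES, and the real model uses the 78-dimensional atom features of Table 2.

```python
# Build an adjacency matrix for a tiny molecular graph (ethanol backbone,
# C-C-O), with atoms as nodes, bonds as undirected edges, and
# self-connections on the diagonal, as described in the text.
def build_adjacency(num_atoms, bonds):
    # Start from the identity so that diagonal elements are 1 (self-connection).
    adj = [[1 if i == j else 0 for j in range(num_atoms)] for i in range(num_atoms)]
    for u, v in bonds:
        adj[u][v] = 1
        adj[v][u] = 1  # molecular graphs are undirected
    return adj

atoms = ["C", "C", "O"]   # hypothetical parsed atom symbols
bonds = [(0, 1), (1, 2)]  # hypothetical parsed bond list
adj = build_adjacency(len(atoms), bonds)
print(adj)  # [[1, 1, 0], [1, 1, 1], [0, 1, 1]]
```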

Network Architectures
SAG-DTA network architectures are shown in Figure 3. In this study, we consider two types of architecture with regard to the pooling strategy, namely, the global pooling architecture and the hierarchical pooling architecture. The global pooling architecture, as illustrated in the left panel of Figure 3, consists of three graph convolutional layers, and the outputs of these three layers are concatenated before being fed into an SAGPooling layer, i.e., pooling in a global way. The remaining nodes then go through the readout layer and are finally passed to fully connected layers to obtain the drug molecule representations. The hierarchical pooling architecture shown in Figure 3b is composed of three blocks, each of which contains a graph convolutional layer and an SAGPooling layer. The convolutional results of each layer are thus hierarchically pooled and read out. These outputs are then summed before being passed to the fully connected layers to obtain the final drug representations.
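The contrast between the two architectures can be sketched schematically. The `conv`, `pool`, and `readout` functions below are toy scalar stand-ins for the trained GCN, SAGPooling, and readout layers (hypothetical placeholders, not the actual model), chosen only to show how the global variant pools once over concatenated outputs while the hierarchical variant sums per-block readouts.

```python
def conv(x):                      # stand-in graph convolution
    return [v + 1.0 for v in x]

def pool(x, k=1.0):               # stand-in SAGPool (keeps top k of nodes)
    n = max(1, int(len(x) * k))
    return sorted(x, reverse=True)[:n]

def readout(x):                   # stand-in readout: (mean, max)
    return [sum(x) / len(x), max(x)]

def global_arch(x):
    # three convolutions, concatenated, then one global pooling + readout
    h1, h2, h3 = conv(x), conv(conv(x)), conv(conv(conv(x)))
    return readout(pool(h1 + h2 + h3))

def hierarchical_arch(x):
    # three (conv -> pool -> readout) blocks whose readouts are summed
    out, h = [0.0, 0.0], x
    for _ in range(3):
        h = pool(conv(h))
        r = readout(h)
        out = [a + b for a, b in zip(out, r)]
    return out

x = [0.5, 1.0, 2.0]               # toy node scalars
print(global_arch(x), hierarchical_arch(x))
```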

Graph Convolution Layer
The graph convolution layer is formulated as Equation (2):

H^(l+1) = σ( D̃^(−1/2) Ã D̃^(−1/2) H^(l) Θ ),  (2)

where Ã = A + I_N is the adjacency matrix with self-connections, D̃ is the corresponding diagonal degree matrix, H^(l) ∈ R^(N×F) is the node feature matrix of the l-th layer (the initial atom features are listed in Table 2), and Θ ∈ R^(F×F′) is the trainable convolution weight with input feature dimension F and output feature dimension F′. Finally, the rectified linear unit (ReLU) function σ(·) is used as the activation function in our model.
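A minimal numerical sketch of this propagation rule, assuming scalar node features and a scalar weight for readability, is:

```python
import math

# Minimal sketch of the GCN propagation rule of Equation (2):
# H' = ReLU(D^-1/2 * A_tilde * D^-1/2 * H * Theta),
# with A_tilde = A + I (self-connections) and D its degree matrix.
def gcn_layer(adj_with_self_loops, h, theta):
    n = len(h)
    deg = [sum(row) for row in adj_with_self_loops]
    out = []
    for i in range(n):
        acc = 0.0
        for j in range(n):
            if adj_with_self_loops[i][j]:
                # symmetric normalization 1 / sqrt(d_i * d_j)
                acc += h[j] / math.sqrt(deg[i] * deg[j])
        out.append(max(0.0, acc * theta))  # ReLU activation
    return out

adj = [[1, 1, 0], [1, 1, 1], [0, 1, 1]]  # 3-node chain with self-loops
h = [1.0, 2.0, 3.0]
print(gcn_layer(adj, h, theta=1.0))
```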

Self-Attention Graph Pooling Layer
The self-attention graph pooling (SAGPool) layer comprises a scoring method and a subsequent mask operation. The process is depicted in Figure 4. Briefly, self-attention scores Z for all of the atoms in the molecular graph are obtained using a certain scoring method; then, all of the nodes are ranked, and the top ⌈kN⌉ nodes are selected based on their scores Z, where k ∈ (0,1] is the pooling ratio that indicates the portion of retained nodes. The mask operation can be formulated as Equation (3):

idx = top-rank(Z, ⌈kN⌉),  Z_mask = Z_idx,  X′ = X_idx ⊙ Z_mask,  A′ = A_idx,idx,  (3)

where idx is the indexing operation used to obtain the feature attention mask Z_mask, X′ and A′ are the pooled node features and adjacency matrix, and ⊙ denotes the broadcasted element-wise product. In this study, four types of scoring methods were evaluated, namely, the GNN, GCN, GAT, and SAGE. These four networks are representative GNN variants and have been proven to achieve good performance in graph-related tasks.
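The rank-and-mask step can be sketched as follows. The toy features and scores are hypothetical, and a real SAGPool layer would also slice the adjacency matrix of the retained nodes:

```python
import math

# Minimal sketch of the SAGPool mask operation of Equation (3):
# rank nodes by attention score Z, keep the top ceil(k*N), and weight
# the retained node features by their scores.
def sag_pool(features, scores, k):
    n_keep = math.ceil(k * len(scores))
    # indices of the top-scoring nodes (the `idx` operation)
    idx = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:n_keep]
    idx.sort()  # preserve original node order among survivors
    # weight features by the attention mask Z_mask
    pooled = [[f * scores[i] for f in features[i]] for i in idx]
    return idx, pooled

features = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
scores = [0.1, 0.9, 0.5, 0.3]
idx, pooled = sag_pool(features, scores, k=0.5)
print(idx, pooled)  # keeps nodes 1 and 2
```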

GNN Scoring Method
The GNN scoring method is defined as Equation (4):

Z_v = σ( h_v^(l) Θ_1 + Σ_{u ∈ N(v)} h_u^(l) Θ_2 ),  (4)

where v represents the node itself and N(v) is the set of all neighbors of node v. h_v^(l) ∈ R^(1×F) is the feature of node v in the l-th layer, and Θ_1, Θ_2 ∈ R^(F×1) are the trainable convolution weights with input feature dimension F. σ(·) represents the activation function ReLU.
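A minimal sketch of this scoring rule, with scalar node features and scalar weights for readability, is:

```python
# Minimal sketch of the GNN scoring rule of Equation (4): each node's score
# combines its own features (via theta1) with the sum of its neighbors'
# features (via theta2), followed by ReLU.
def gnn_scores(h, neighbors, theta1, theta2):
    scores = []
    for v in range(len(h)):
        s = h[v] * theta1 + sum(h[u] for u in neighbors[v]) * theta2
        scores.append(max(0.0, s))  # ReLU
    return scores

h = [1.0, 2.0, 3.0]                       # toy node features
neighbors = {0: [1], 1: [0, 2], 2: [1]}   # 3-node chain
print(gnn_scores(h, neighbors, theta1=0.5, theta2=0.25))  # [1.0, 2.0, 2.0]
```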

GCN Scoring Method
The GCN scoring method is defined as Equation (5):

Z = σ( D̃^(−1/2) Ã D̃^(−1/2) H^(l) Θ ),  (5)

Equation (5) is identical to Equation (2), except that the dimension of the convolutional weight Θ is changed to R^(F×1) to obtain the attention score value Z.

GAT Scoring Method
The GAT scoring method is defined as Equation (6):

Z_v = σ( Σ_{u ∈ N(v) ∪ {v}} α_{v,u} h_u^(l) Θ ),  (6)

where Θ ∈ R^(F×1) is the trainable convolution weight that is shared by all of the nodes. α_{v,u} is the attention coefficient that is computed as Equation (7):

α_{v,u} = exp( LeakyReLU( a(h_v^(l) ‖ h_u^(l)) ) ) / Σ_{w ∈ N(v) ∪ {v}} exp( LeakyReLU( a(h_v^(l) ‖ h_w^(l)) ) ),  (7)

where a is the shared attention operation that maps R^(2F) to R and ‖ denotes concatenation.

SAGE Scoring Method
The SAGE scoring method is defined as Equation (8):

Z_v = σ( mean_{u ∈ N(v) ∪ {v}}( h_u^(l) ) Θ ),  (8)

where mean(·) indicates an averaging operation over the node and its neighbors.

Readout Layer
The readout layer aggregates node features globally or hierarchically, depending on the pooling architecture. In this work, the readout layer is the concatenation of the average and the max of the node features, which can be written as Equation (9):

r = ( (1/N) Σ_{i=1}^{N} x_i ) ‖ ( max_{i=1,...,N} x_i ),  (9)

where N denotes the number of nodes, x_i is the feature vector of the i-th node, and ‖ denotes concatenation.
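A minimal sketch of this readout, assuming plain Python lists as node feature vectors, is:

```python
# Minimal sketch of the readout of Equation (9): concatenate the
# element-wise mean and max of the node feature vectors.
def readout(node_features):
    n = len(node_features)
    dim = len(node_features[0])
    mean = [sum(x[d] for x in node_features) / n for d in range(dim)]
    mx = [max(x[d] for x in node_features) for d in range(dim)]
    return mean + mx  # concatenation doubles the feature dimension

nodes = [[1.0, 4.0], [3.0, 2.0]]
print(readout(nodes))  # [2.0, 3.0, 3.0, 4.0]
```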

Results and Discussion
The proposed SAG-DTA model contains a number of hyperparameters, and combinations of these hyperparameters form a vast search space. This section presents the evaluation of the two most critical aspects of the self-attention scheme: the self-attention pooling ratio and the method for calculating the attention score. The comparison experiments are detailed in Sections 3.3 and 3.4. For all of these model evaluation experiments, fivefold cross-validation was used. Specifically, the benchmark training set was shuffled and randomly split into five folds, with four of them being used as the training set and the remainder as the validation set. The model was trained on the four-fold training set and validated on the validation set, and this process was repeated five times. The average result was recorded to assess the model performance. After all of the hyperparameters were determined in this way, we used all five folds to train the model and tested it on the benchmark test set. Finally, we compared SAG-DTA with several existing DTA and CPI prediction methods in Sections 3.5 and 3.6.

Metrics
In order to make comparisons with the baseline models, the concordance index (CI) and mean squared error (MSE) were used to evaluate the performance of the model. CI can be used to evaluate the ranking performance of models that output continuous values [38], and it is computed as Equation (10):

CI = (1/Z) Σ_{δ_x > δ_y} h(b_x − b_y),  (10)

where δ_x and δ_y are the larger and smaller affinity values, respectively, and b_x and b_y are the corresponding prediction values of the model. Z is a normalization constant equal to the number of compared pairs, and h(x) is the step function that takes the form of Equation (11):

h(x) = 1 if x > 0;  0.5 if x = 0;  0 if x < 0.  (11)

The other metric, MSE, measures the difference between the vector of predicted values and the vector of actual values, and it is widely used in regression tasks. It can be calculated as Equation (12):

MSE = (1/n) Σ_{i=1}^{n} (p_i − y_i)^2,  (12)

where p_i is the predicted value and y_i is the actual value.
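Both metrics can be sketched directly from these definitions (a naive O(n²) pairwise loop for CI; tie handling follows the step function h):

```python
# Minimal sketch of the evaluation metrics: concordance index
# (Equations (10)-(11)) and mean squared error (Equation (12)).
def concordance_index(y_true, y_pred):
    num, z = 0.0, 0
    for i in range(len(y_true)):
        for j in range(len(y_true)):
            if y_true[i] > y_true[j]:        # pairs with delta_x > delta_y
                z += 1
                diff = y_pred[i] - y_pred[j]
                num += 1.0 if diff > 0 else (0.5 if diff == 0 else 0.0)
    return num / z

def mse(y_true, y_pred):
    return sum((p - y) ** 2 for p, y in zip(y_pred, y_true)) / len(y_true)

y_true = [5.0, 6.0, 7.0]
y_pred = [5.2, 6.9, 6.5]   # one mis-ranked pair out of three
print(concordance_index(y_true, y_pred), mse(y_true, y_pred))
```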
The proposed model was also evaluated on several compound–protein interaction (CPI) datasets. CPI prediction is a binary classification task, and the following metrics (Equations (13) and (14)) were used to assess the performance of our models:

precision = TP / (TP + FP),  (13)

recall = TP / (TP + FN),  (14)

where TP, FP, and FN represent the sample numbers of true positives, false positives, and false negatives, respectively. In addition, the area under the receiver operating characteristic curve (AUROC) and the area under the precision–recall curve (AUPRC) of the presented model were also calculated to facilitate comparisons with other models.
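These two classification metrics can be sketched as follows (AUROC and AUPRC require ranking over predicted probabilities and are omitted here):

```python
# Minimal sketch of the CPI classification metrics of Equations (13)-(14).
def precision_recall(labels, preds):
    tp = sum(1 for l, p in zip(labels, preds) if l == 1 and p == 1)
    fp = sum(1 for l, p in zip(labels, preds) if l == 0 and p == 1)
    fn = sum(1 for l, p in zip(labels, preds) if l == 1 and p == 0)
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return precision, recall

labels = [1, 1, 0, 0, 1]  # toy ground-truth interaction labels
preds  = [1, 0, 1, 0, 1]  # toy binary predictions
print(precision_recall(labels, preds))
```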

Setting of the Hyperparameters
The hyperparameters that were used in the SAG-DTA model are listed in Table 3. Most of these hyperparameters were derived from the baseline model (i.e., GraphDTA [24]), while the pooling ratio and the scoring method, two key factors for the performance of SAG-DTA, were determined using fivefold cross-validation. In this study, we evaluated these two hyperparameters thoroughly on both the global and hierarchical architectures. The search spaces of the hyperparameters and architectures are highlighted in bold in the last three lines of Table 3.

Performances of Various Pooling Ratios
The pooling ratio of SAGPool, which determines the percentage of nodes that should be retained, is a key factor to be considered in the model. To identify the best graph pooling ratio, values from 0.1 to 1 were evaluated for both the global and hierarchical pooling architectures, as illustrated in Figure 5. For the global architecture, the MSE showed a generally downward trend and achieved its lowest value of 0.217 at a pooling ratio of 1.0. The other metric, CI, oscillated between 0.892 and 0.894 when the pooling ratio was larger than 0.4. The best pooling ratio was finalized as 1.0 in this architecture based on the primary indicator, MSE.
For the hierarchical architecture, the MSE showed a similar downward trend, with the minimum value of 0.218 achieved at several pooling ratios, including 0.6, 0.8, and 1.0. These ratios were then compared using the secondary CI metric, as demonstrated in the bottom right panel of Figure 5, which is roughly in agreement with the MSE in showing better performance as the pooling ratio increases. The ratio value of 1.0 was finally chosen, as it achieved the best performance regarding both the MSE (0.218 ± 0.003) and CI (0.895 ± 0.004). The results demonstrate that all atoms in a drug molecule make their own specific contributions to the drug's interactions with protein targets. Although assigning weights to nodes could differentiate the contributions of different atoms and therefore benefit the performance of the prediction model, the results suggest that atoms with lower attention scores cannot be completely ignored.

Performances of Various Attention Scoring Methods
The self-attention pooling layer assigns each node an attention score, which has two functions. First, the scores of atoms are used as a criterion for ranking and pooling nodes within the graph. Second, the score is used directly as a weighting factor on the node features to differentiate the contributions of different atoms. Since the attention scores directly determine the importance of nodes within each layer, the scoring method acts as another important factor in the performance of the model and, therefore, needs to be carefully decided. As part of the self-attention pooling strategy, we used the feature of the node itself as the only input in the scoring model to obtain the score of each node, i.e., self-attention. For the scoring method, we adopted GNNs rather than hand-crafted functions to automatically learn the weights. In this section, we compare four GNN variants as scoring methods using fivefold cross-validation, namely, the GNN, GCN, GAT, and SAGE (introduced in Section 2.3.2). The results are illustrated in Figure 6. For the global architecture, the GNN achieved an MSE of 0.217, which was the lowest among the four scoring methods. The obtained CI values showed slight discrepancies, but the GNN, GAT, and SAGE all achieved a CI value of 0.893. For the hierarchical architecture, the GNN also achieved the best MSE (0.218 ± 0.003) and CI (0.895 ± 0.004). Together, these results demonstrate that the GNN is the most effective of the four scoring methods for both the global and hierarchical architectures.

Comparisons with Other Baseline Models
The optimal hierarchical and global SAG models obtained via the above hyperparameter tuning were compared to traditional machine learning methods (i.e., Kron-RLS [5,39] and SimBoost [6]) and recent cutting-edge DTA prediction approaches, including DeepDTA [12], WideDTA [13], AttentionDTA [27], DeepGS [25], and GraphDTA [24]. In these models, different descriptors were used to represent proteins and compounds (the column 'Proteins and Compounds' in Table 4), including the Smith–Waterman (S-W) descriptor [40]; the PubChem Sim descriptor [41]; and the descriptors obtained from convolutional networks, such as CNN (for SMILES) and GCN (for graph representation). For WideDTA, the protein sequence (PS) and protein motifs and domains (PDM) were specifically used for protein description, whereas the ligand SMILES (LS) and ligand maximum common substructure (LMCS) were used for drug description. For all of these methods, the same benchmark test sets were used, and the overall performances measured by MSE and CI are summarized in Table 4. It can be seen that the SAG-DTA approaches were superior to 1D representation-based approaches and other graph-based approaches. Of the two pooling architectures, the global architecture achieved the better performance, with an MSE of 0.209 and a CI of 0.903. Though slightly inferior to the global architecture, the hierarchical variant also obtained good results that were better than those of the other baseline models.
To further test the generalization of the proposed method, we evaluated the model on the KIBA dataset with the same hyperparameters as those used for the Davis dataset. The experimental results are shown in Table 5, and it can be observed that SAG-DTA is the most accurate among the evaluated methods. In detail, the global SAG-DTA achieved an MSE of 0.130 and a CI of 0.892, and the hierarchical SAG-DTA achieved an MSE of 0.131 and a CI of 0.893. These results demonstrate the effectiveness and good generalization ability of our model in DTA prediction.

Model Evaluations of the Compound-Protein Interaction Task
We also assessed the performance of SAG-DTA in CPI prediction. In this study, we refer to the binary classification task of drug–target interaction as CPI to distinguish it from DTA, which is a regression task. The two architectures of SAG-DTA were separately evaluated on two widely used benchmark datasets of CPI prediction, namely, the Human and BindingDB datasets. These datasets contain compound and protein pairs along with a binary label that indicates whether or not they interact. SAG-DTA was adjusted for the binary classification task only by adding a Sigmoid layer, so that the model could predict probabilities and binary labels for samples.
On the Human dataset, SAG-DTA was compared to traditional machine learning algorithms, including k-nearest neighbors (k-NN), random forest (RF), L2-logistic (L2), and support vector machines (SVMs), as well as some recent graph-based approaches, such as CPI-GNN [22], DrugVQA [42], and TransformerCPI [43]. The performances of these models were obtained from [43] and are summarized in Table 6. It can be observed that both SAG-DTA architectures were superior to the other methods in terms of AUROC, precision, and recall. Notably, SAG-DTA achieved a significant improvement over the baseline GraphDTA, raising the AUROC from 0.960 (±0.005) to 0.984 (±0.003). The evaluation results on the BindingDB dataset are summarized in Table 7. Among these graph-based methods, SAG-DTA with the global architecture achieved the best performance in terms of AUROC (0.963) and AUPRC (0.966), and the hierarchical variant was also superior to the other methods. Notably, the hyperparameters of the two SAG-DTA variants were not fine-tuned for either the Human or the BindingDB dataset. In summary, the superior performance of SAG-DTA on both DTA and CPI tasks suggests its good generalization ability. To provide insight into the improvement brought by the self-attention algorithm, we discuss the mechanism here from the machine learning perspective as well as from chemical intuition.
From a machine learning perspective, the self-attention algorithm in SAG-DTA is in essence a weighting operation that assigns weights, i.e., attention scores, to each atom node within the molecular graph. The features of different nodes are therefore weighted before they are aggregated into the final molecule descriptor. Molecular graph descriptors obtained in this way can be more effective because, in tasks such as DTA and CPI prediction, the nodes are not equally important for the final prediction. In contrast, for graph prediction models without self-attention, the node features are aggregated indiscriminately. As a result, the features of some critical nodes are not 'highlighted' in the final graph representation.
The above discussion can be naturally extended to the drug molecular graph and DTA/CPI tasks. It can be assumed that atoms in a drug molecule typically do not contribute equally to the final affinity value, and the attention scores can therefore differentiate the importance of different atoms. The critical atoms that play chemical roles in the process of drug–protein interaction gain more weight when involved in affinity prediction. Consequently, effective representations of molecules are obtained with the help of the self-attention algorithm.

Conclusions
Predicting drug–target affinity is of great importance to drug development, and an accurate DTA algorithm will benefit drug screening by minimizing experimental costs and reducing development durations. In this paper, we proposed a graph-based DTA prediction method named SAG-DTA, which utilizes self-attention mechanisms on the drug molecular graph to obtain the drug representation. Evaluation of the model on benchmark datasets demonstrated that both the hierarchical architecture-based and global architecture-based SAG-DTA achieved performance superior to that of various existing DTA prediction methods, suggesting the effectiveness of the proposed approach in predicting the affinity of drug and protein pairs. Furthermore, the good performance of SAG-CPI, the CPI version of SAG-DTA, demonstrated the good generalization ability of the proposed method as well as the effectiveness of the self-attention mechanisms.

Informed Consent Statement: Not applicable.
Data Availability Statement: All the relevant data are included within the paper.

Conflicts of Interest:
The authors declare no competing financial interest.