Forecasting the Walking Assistance Rehabilitation Level of Stroke Patients Using Artificial Intelligence

Cerebrovascular accidents (CVA) cause a range of impairments in coordination, including walking impairments that range from mild gait imbalance to complete loss of mobility. Effective rehabilitation of patients with CVA requires personalized approaches tailored to their degree of walking impairment. This paper evaluates the validity of various machine learning (ML) and deep learning (DL) classification models (support vector machine, Decision Tree, Perceptron, Light Gradient Boosting Machine, AutoGluon, SuperTML, and TabNet) for automatically classifying which walking assistance device a CVA patient requires. We reviewed prescription data for eight different walking assistance devices from 383 CVA patients (1623 observations) at five hospitals. Among the classification models, the advanced tree-based models (LightGBM and the tree models in AutoGluon) achieved over 90% accuracy, recall, precision, and F1-score. In particular, AutoGluon achieved the highest predictive performance (almost 92% in accuracy, recall, precision, and F1-score, and 86.8% in balanced accuracy). Its leaderboard also showed that the tree-based models outperformed the other models. Therefore, we believe that tree-based classification models have potential as practical diagnostic tools for medical rehabilitation.


Introduction
Cerebrovascular accidents (CVA), i.e., strokes, can lead to walking impairments ranging from mild gait imbalance to complete loss of mobility. Therefore, walking rehabilitation therapy for these patients starts with the proper prescription of walking assistance devices, such as a tilt table, a harness, a (hemi-)walker, or a (quarter or single) cane. When these devices are prescribed, the diagnostician's bias might act as noise, causing misdiagnosis and unnecessary costs for patients and hospitals [1]. Accordingly, this paper evaluates machine learning (ML) and deep learning (DL) classification algorithms to determine whether they can serve as supportive tools for diagnosticians by providing adequate predictive performance.
With great advances in ML and DL algorithms (although DL is a subfield of ML, we treat them separately for comparison), artificial intelligence (AI) techniques have been applied to areas ranging from image classification [2,3] to Go [4] and other games [5,6]. In the medical domain in particular, numerous studies have been conducted. Examples include cancer detection with image classification [7], a patient modeling system for clinical demonstration [8], an emergency screening system that differentiates acute cerebral ischemia from stroke mimics [9], and a gait monitoring system that predicts stroke [10]. In the rehabilitation domain, studies have addressed walking assistance robot development [11], AI-based virtual reality rehabilitation [12], and forecasting the mortality of stroke patients after completing rehabilitation with tree-based ML models [13]. Two studies are similar to ours [14,15]. The former employed only support vector machines (SVM) [16] for gait classification after extracting features with hidden Markov models [17]. The latter used only lasso regression [18], to prevent overfitting on a small sample, when investigating factors affecting stroke patients' clinical outcomes and predicting their discharge scores. In contrast, this paper evaluates seven different ML and DL classification models on a dataset of 383 stroke patients to determine which walking assistance device is most appropriate for a patient according to their condition.

Dataset and Experimental Settings
We first conducted an exploratory data analysis to extract the characteristics of the data. We then preprocessed the data to balance the number of observations per class using undersampling (TomekLinks), oversampling (SMOTE), and combined sampling (SMOTETomek) methods. The ML and DL classification models were trained on either the original (unpreprocessed) or the preprocessed dataset. For each model, we obtained a set of performance metrics (accuracy, precision, recall, F1-score, and balanced accuracy) using five-fold cross validation (5-CV).

Data Description
We collected anonymized data on the walking rehabilitation history of 383 stroke patients (1623 observations) from January 2019 to January 2021. The data came from the following five hospitals: Chung-Ang University Hospital (CAUH), Seoul National University Hospital (SNUH), National Traffic Injury Rehabilitation Hospital (NTIRH), The Catholic University of Korea Yeouido St. Mary's Hospital (CUYMH), and Asan Medical Center (AMC). Table 1 provides details on the number of patients and observations in the dataset. The features of the data (inputs of the algorithms) consist of 82 values arranged in six categories: anthropometry, stroke, blood tests, functional assessment, biosignal ward, and disease. We provide details of the data in Appendix C, including patient characteristics, category distributions, and more specific features in the six categories. The labels (outputs of the models) consist of eight classes that distinguish the types of walking assistance devices: tilt table (0), harness (1), walker (2), hemi-walker (3), quarter cane (4), single cane (5), walking (plane) (6), and advanced (stair) (7). Figure 1 displays the distribution of the number of observations for each class.

ML and DL Algorithm Settings
For ML, we employed four widely used classification algorithms: SVM [16], Perceptron (PT) [25], Decision Tree (DT) [26], and Light Gradient Boosting Machine (LightGBM) [27]. We also used AutoGluon [29], one of the most recently developed automated ML (AutoML) [28] Python library packages, to find the best predictive ML classification models for our dataset. For DL, we employed two classification models proposed for tabular data: SuperTML [30] and TabNet [31]. We provide their backgrounds in Appendix A.2.
• SVM, PT, and DT settings: we used the scikit-learn (Ver. 0.23) [23,24] Python ML library package. We adopted the radial basis function kernel [32] for SVM and the Gini impurity as the node-splitting criterion for DT. We did not use a regularization term in PT.
• LightGBM settings: we used the LightGBM package (Ver. 2.3.1) through its scikit-learn-compatible Python API [23,24]. We empirically chose the traditional gradient boosting decision tree as the boosting type, without limits on the number of leaf nodes or the tree depth. The best-performing learning rate was 0.1.
• AutoGluon settings: among the various AutoML Python library packages, we employed the latest and best-performing one, AutoGluon (Ver. 0.0.15) [29]. We empirically varied the "time_limit" parameter for the whole model from 60 to 120 s and found that the performance did not improve beyond 120 s. The evaluation metric for each model in the ensemble was set to "accuracy". We also set the "presets" parameter to "best_quality". This setting improves the predictive performance of the ensemble, which is based on stacking and bagging, within the allotted training time.
• SuperTML settings: because this model transforms tabular data into images, its performance depends on the convolutional architecture. We experimentally found that ResNet [2] with 152 layers performed best.
• TabNet settings: TabNet [31] consists of an encoder and a decoder for self-supervised learning [33], but we employed only its encoder network for supervised learning. To improve its predictive performance, we modified it into a six-step operation. In the feature transformers of steps 1-3, we omitted the layers "shared across decision steps". In steps 4-6, we changed the layers shared across decision steps into unshared ones.
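The settings above can be expressed as keyword arguments for the corresponding Python APIs. The following is a minimal sketch under stated assumptions.

- The estimator classes are the standard ones: scikit-learn's `SVC`, `DecisionTreeClassifier` and `Perceptron`, LightGBM's `LGBMClassifier`, and AutoGluon's `TabularPredictor`.
- The AutoGluon calls follow the newer `TabularPredictor` API rather than the 0.0.15 release named above.
- The label column name `device_class` is hypothetical.

Each listed value coincides with the library default, so the dictionaries mainly document the stated choices.

```python
# Sketch: the model settings above expressed as library keyword arguments.
# Assumption: all hyperparameters not listed stay at their library defaults.

MODEL_SETTINGS = {
    # sklearn.svm.SVC with the radial basis function kernel
    "SVM": {"kernel": "rbf"},
    # sklearn.tree.DecisionTreeClassifier with Gini impurity as split criterion
    "DT": {"criterion": "gini"},
    # sklearn.linear_model.Perceptron without a regularization term
    "PT": {"penalty": None},
    # lightgbm.LGBMClassifier: traditional GBDT, no depth limit (-1), lr = 0.1.
    # LightGBM has no "unlimited" value for num_leaves, so it keeps its default.
    "LightGBM": {"boosting_type": "gbdt", "max_depth": -1, "learning_rate": 0.1},
}

# Argument names follow AutoGluon's TabularPredictor API (assumption: older
# releases such as 0.0.15 exposed similar options through a different entry point).
AUTOGLUON_PREDICTOR_ARGS = {"eval_metric": "accuracy"}
AUTOGLUON_FIT_ARGS = {"time_limit": 120, "presets": "best_quality"}


def build_ml_models():
    """Instantiate the four ML baselines (requires scikit-learn and lightgbm)."""
    from lightgbm import LGBMClassifier
    from sklearn.linear_model import Perceptron
    from sklearn.svm import SVC
    from sklearn.tree import DecisionTreeClassifier

    return {
        "SVM": SVC(**MODEL_SETTINGS["SVM"]),
        "DT": DecisionTreeClassifier(**MODEL_SETTINGS["DT"]),
        "PT": Perceptron(**MODEL_SETTINGS["PT"]),
        "LightGBM": LGBMClassifier(**MODEL_SETTINGS["LightGBM"]),
    }


def fit_autogluon(train_df, label="device_class"):
    """Fit an AutoGluon ensemble on a pandas DataFrame.

    `label` is a hypothetical name for the column holding classes 0-7.
    """
    from autogluon.tabular import TabularPredictor

    predictor = TabularPredictor(label=label, **AUTOGLUON_PREDICTOR_ARGS)
    return predictor.fit(train_df, **AUTOGLUON_FIT_ARGS)
```

The imports are deferred into the functions so that the settings can be inspected without the libraries installed.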

Performance Measurement Settings
We measured the predictive performance of the classification models in terms of accuracy, precision, recall, F1-score, and balanced accuracy. Because most of these measurements are designed for binary classification problems, we extended them to multi-class classification using weighted averaging in the scikit-learn Python library package [23,24]. We describe the formulations of these measurements in Appendix B. For a fair comparison, we computed each metric by averaging the results of 5-CV. In each fold of 5-CV, the data were split in an 8:2 ratio, with 80% used for training and 20% used for testing (validation). For the experiments with balanced data, we applied the three sampling methods only to the training data, which were then used to train the ML or DL models. The models were also trained on the unpreprocessed original data. Finally, the trained models were evaluated on the test data. Figure 2 summarizes each step of the 5-CV process.
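The key property of this protocol is that resampling touches only the training split of each fold, so no synthetic or removed observation ever reaches the test fold. The following is a plain-Python sketch of that loop under several assumptions:

- Folds are plain shuffled, observation-level folds. The text does not say whether they were stratified or grouped by patient.
- Results are summarized as mean ± sample standard deviation.
- All function names are hypothetical.
- `random_oversample` is a naive stand-in for a sampling method, not SMOTE.

```python
import random
from collections import Counter
from statistics import mean, stdev


def kfold_indices(n, k=5, seed=0):
    """Yield (train, test) index lists; each test fold holds about 1/k of the data."""
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    folds = [idx[i::k] for i in range(k)]
    for i, test in enumerate(folds):
        train = [j for f, fold in enumerate(folds) if f != i for j in fold]
        yield train, test


def cross_validate(X, y, fit, predict, resample=None, k=5, seed=0):
    """Return per-fold accuracies; `resample` is applied to the training split only."""
    scores = []
    for train, test in kfold_indices(len(y), k, seed):
        X_tr, y_tr = [X[i] for i in train], [y[i] for i in train]
        if resample is not None:
            X_tr, y_tr = resample(X_tr, y_tr)  # the test fold is never resampled
        model = fit(X_tr, y_tr)
        correct = sum(predict(model, X[i]) == y[i] for i in test)
        scores.append(correct / len(test))
    return scores


# Hypothetical stand-ins: a majority-class "model" and naive random oversampling.
def fit_majority(X_tr, y_tr):
    return Counter(y_tr).most_common(1)[0][0]


def predict_majority(model, x):
    return model


def random_oversample(X_tr, y_tr, seed=0):
    """Duplicate random members of each class up to the largest class size."""
    rng = random.Random(seed)
    counts = Counter(y_tr)
    target = max(counts.values())
    X_out, y_out = list(X_tr), list(y_tr)
    for label, count in counts.items():
        members = [x for x, lab in zip(X_tr, y_tr) if lab == label]
        for _ in range(target - count):
            X_out.append(rng.choice(members))
            y_out.append(label)
    return X_out, y_out


if __name__ == "__main__":
    X = [[float(i)] for i in range(20)]
    y = [0] * 15 + [1] * 5
    scores = cross_validate(X, y, fit_majority, predict_majority, resample=random_oversample)
    print(f"accuracy: {mean(scores):.3f} ± {stdev(scores):.3f}")
```

Passing the sampler as a callable keeps the protocol fixed while the sampling method varies. The same loop serves the "original" setting with `resample=None`.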

Results and Discussion
Here, we report and discuss the classification results of the ML and DL models that we employed. Table 2 summarizes each model's results in terms of accuracy, precision, recall, F1-score, and balanced accuracy for each data preprocessing method: original (without sampling), SMOTE, TomekLinks, and SMOTETomek. The entries are given in the form mean ± standard deviation. For each sampling method (including the original data), the best accuracy, recall, precision, F1-score, and balanced accuracy among the seven algorithms are highlighted in bold typeface.
Table 2. Performance metrics (accuracy, recall, precision, F1-score, and balanced accuracy) of the ML and DL models according to sampling method. We measured recall, precision, and F1-score as weighted averages. Bold typeface marks the highest value of each metric.
In general, the three data preprocessing (sampling) methods did not improve most classification results; the exceptions were SVM and SuperTML. Only SVM exhibited dramatic improvements with these methods. For example, its balanced accuracy increased by approximately 11% with SMOTE and SMOTETomek but by only 0.6% with TomekLinks. SuperTML also benefited from SMOTE and SMOTETomek, although the increments were only about 0.2% to 1% across all metrics. For SuperTML, however, TomekLinks reduced all classification results by 0.2% to 0.7%.
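The weighted metrics above can be reproduced in plain Python as a check on how they relate. The sketch follows the definitions of scikit-learn's `average="weighted"` and `balanced_accuracy_score` as we understand them. One assumption: undefined ratios, such as a class that is never predicted, are counted as 0, which matches scikit-learn's default result apart from its warning. `multiclass_metrics` is a hypothetical helper name.

```python
from collections import Counter


def multiclass_metrics(y_true, y_pred):
    """Accuracy, support-weighted precision/recall/F1, and balanced accuracy."""
    labels = sorted(set(y_true) | set(y_pred))
    n = len(y_true)
    support = Counter(y_true)    # true observations per class
    predicted = Counter(y_pred)  # predictions per class
    tp = Counter(t for t, p in zip(y_true, y_pred) if t == p)

    def div(a, b):
        return a / b if b else 0.0  # undefined ratios count as 0

    precision = {c: div(tp[c], predicted[c]) for c in labels}
    recall = {c: div(tp[c], support[c]) for c in labels}
    f1 = {c: div(2 * precision[c] * recall[c], precision[c] + recall[c]) for c in labels}

    def weighted(per_class):  # weight each class by its share of true labels
        return sum(per_class[c] * support[c] for c in labels) / n

    present = [c for c in labels if support[c] > 0]
    return {
        "accuracy": sum(tp.values()) / n,
        "precision": weighted(precision),
        "recall": weighted(recall),
        "f1": weighted(f1),
        # unweighted mean of per-class recall over classes present in y_true
        "balanced_accuracy": sum(recall[c] for c in present) / len(present),
    }


m = multiclass_metrics([0, 0, 0, 1, 1, 2], [0, 0, 1, 1, 1, 2])
print({k: round(v, 3) for k, v in m.items()})
# {'accuracy': 0.833, 'precision': 0.889, 'recall': 0.833, 'f1': 0.833, 'balanced_accuracy': 0.889}
```

Each class's recall TP_c/n_c is weighted by n_c/N, so weighted recall reduces to the sum of TP_c over N. It is therefore always identical to accuracy. Balanced accuracy, the unweighted mean of per-class recall, is the metric in Table 2 that exposes weak performance on minority classes.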

Although most models showed a small decline in classification results with the sampling methods, AutoGluon's predictive performance became more stable. The standard deviations of its 5-CV metrics decreased from 0.3 to 0.2 in accuracy, recall, precision, and F1-score. A possible explanation is that, because AutoGluon is an ensemble learning method, some of the newly generated data benefit some of the algorithms within it.
Among the ML and DL classification models, LightGBM and AutoGluon achieved the highest classification results (over 90% accuracy, recall, precision, and F1-score). They also achieved the highest balanced accuracy, from 85% to 86.8%. Note that both are ML classification algorithms, not DL models. Next came the DL classification models SuperTML and TabNet, which produced very similar results: 88.4% to 90.6% in accuracy, recall, precision, and F1-score, and 82.4% to 85.3% in balanced accuracy. Despite their similar predictive performance, SuperTML required about 70 min of training time, whereas TabNet required only about 15 min, making TabNet the more efficient of the two. Finally, the results of DT were not far from those of the two DL models, trailing them by only about 3.4% to 5%. These observations indicate that tree-based ML algorithms are more suitable for our dataset.

Which Model Performed Best?
First, AutoGluon almost always produced the best performance regardless of the class distribution (except for balanced accuracy and precision with SMOTETomek sampling). As shown in Table 2, DT, LightGBM, and AutoGluon achieved reasonable classification results compared with the other models. In addition, the AutoGluon leaderboard (Table 3) showed that the best-ranked models were CatBoost boosted trees (CBT) [34], Random Forests (RF) [35], LightGBM, and extremely randomized trees (ERT) [36]. All of these are tree-based ML algorithms. In contrast, the DL-based models performed worse than LightGBM and AutoGluon. They also needed longer computational times for 5-CV than the ML models: LightGBM required only 0.09 min and AutoGluon 12 min, whereas TabNet needed 15 min and SuperTML 70 min.
The AutoGluon leaderboard ranks the individual classification models by Score_test, which is the negative log-loss on the test data, so values closer to zero are better. Notably, the tree-based algorithms in AutoGluon (CBT, LightGBM, RF, and ERT) achieved the highest classification results. Their Score_test values were −0.196, −0.2, −0.223, and −0.228 for CBT, LightGBM, RF, and ERT, respectively. These models use different node-splitting criteria: Gini, Entr, XT, and custom denote Gini impurity, information gain (entropy), extremely randomized splits, and a customized function, respectively. The DT results in Table 2 are also better than those of the non-tree ML algorithms (SVM and PT). Moreover, consider the time spent on 5-CV: DT, LightGBM, and AutoGluon took 0.07, 0.09, and 12 min, respectively, whereas TabNet and SuperTML needed 15 min and 70 min, respectively. On this basis, the tree-based classification models learned from our dataset more efficiently than the two DL models, although the performance of DT was 3.4% to 5% lower than that of the DL models.
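AutoGluon reports every metric in higher-is-better form, so the log-loss appears with a negative sign. This is why the Score_test values above are negative and why values closer to zero are better. The sketch below computes multi-class log-loss in plain Python. The clipping constant `eps` and the helper names are assumptions for illustration.

```python
import math


def log_loss(y_true, proba, eps=1e-15):
    """Mean negative log-probability assigned to each observation's true class."""
    total = 0.0
    for label, p in zip(y_true, proba):
        total -= math.log(min(max(p[label], eps), 1.0 - eps))  # clip to avoid log(0)
    return total / len(y_true)


def higher_is_better_score(y_true, proba):
    """Log-loss negated, as on the leaderboard: 0 is a perfect score."""
    return -log_loss(y_true, proba)


print(round(higher_is_better_score([0, 1], [[0.9, 0.1], [0.2, 0.8]]), 3))  # -0.164
```

Under this convention, CBT's Score_test of −0.196 corresponds to a log-loss of 0.196. That is equivalent to assigning the true class a geometric-mean probability of about e^(−0.196) ≈ 0.82 across the test observations.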
The leaderboard (Table 3) also includes the predictive performance of non-tree-based models: K-nearest neighbors (KNN) and a neural network classifier (NNC). Their Score_test values were substantially worse (i.e., larger log-loss) than that of CBT, by at least 0.107 for NNC and 0.862 for KNN. We now discuss why the tree-based classification models demonstrated better predictive performance than the other models.
Table 3. Leaderboard for AutoGluon listing the best-performing individual classification models from the ensemble model. The attributes Score_test and Score_val are (negative) log-loss values used to evaluate predictive performance, and the models are sorted by performance. Note that the closer the value is to zero, the better the model. For details on the other attributes, Stack_level and Fit_order, refer to [29].
Figure 3 shows a single sample tree from the entire set of trees generated by LightGBM. The square nodes denote features in the dataset, whereas the circular nodes are leaf nodes holding raw values before the sigmoid function is applied. The output of the sigmoid function is the probability that the input observation belongs to a given class. Most tree-based algorithms choose the feature at each node with metrics, such as Gini impurity or information gain, that reduce uncertainty about the decision boundary. In other words, the deeper the node, the more specific the decision. Once the trees have been built from the training data, test (unseen) data are classified by following the tree structures. We believe that this procedure closely resembles practical diagnostic reasoning [37]. The medical diagnostic process likewise prunes (narrows) an initial set of hypotheses by gathering more information to lower uncertainty for verification [38][39][40].
Analogously, tree-based models narrow the set of hypotheses by computing and comparing uncertainty-related metrics for each feature to learn the optimal decision boundary. This similarity may explain why these tree-based models have an advantage in predictive performance over the other models.
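The "narrowing" described above can be made concrete with a minimal sketch. It shows how a CART-style tree, using the Gini criterion of scikit-learn's `DecisionTreeClassifier`, compares candidate splits on every feature and keeps the one that most reduces impurity. It is plain Python for illustration only. LightGBM and CatBoost score splits with gradient-based gains rather than Gini impurity. The feature values and device labels in the example are hypothetical.

```python
from collections import Counter


def gini(labels):
    """Gini impurity 1 - sum_c p_c^2; 0 means the node is pure."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((count / n) ** 2 for count in Counter(labels).values())


def best_threshold(values, labels):
    """Best binary split `value <= threshold` on one feature.

    Returns (threshold, weighted child impurity); threshold is None when no
    split lowers the parent's impurity.
    """
    pairs = sorted(zip(values, labels), key=lambda p: p[0])
    n = len(pairs)
    best = (None, gini(labels))
    for i in range(1, n):
        if pairs[i - 1][0] == pairs[i][0]:
            continue  # no threshold separates equal feature values
        left = [lab for _, lab in pairs[:i]]
        right = [lab for _, lab in pairs[i:]]
        impurity = (len(left) * gini(left) + len(right) * gini(right)) / n
        if impurity < best[1]:
            best = ((pairs[i - 1][0] + pairs[i][0]) / 2, impurity)
    return best


def best_split(rows, labels):
    """Compare every feature and return (feature_index, threshold, impurity)."""
    best = (None, None, gini(labels))
    for j in range(len(rows[0])):
        threshold, impurity = best_threshold([row[j] for row in rows], labels)
        if threshold is not None and impurity < best[2]:
            best = (j, threshold, impurity)
    return best


rows = [[5, 1], [5, 2], [5, 3], [5, 10], [5, 11], [5, 12]]  # hypothetical features
labels = ["walker"] * 3 + ["cane"] * 3                       # hypothetical classes
print(best_split(rows, labels))  # (1, 6.5, 0.0)
```

The constant feature 0 offers no split. Feature 1 separates the two hypothetical classes perfectly, reducing the impurity from 0.5 at the parent to 0. Applying the same comparison recursively to each child node yields progressively more specific decisions.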

Conclusions
In this work, we evaluated the classification performance of ML and DL models for forecasting stroke patients' walking assistance levels using a dataset gathered from five hospitals. We found that tree-based ML algorithms yielded the most suitable classification results. We also discussed the similarities between the procedures of tree-based models and practical diagnostics. We believe that this similarity stems from the fact that both consist of steps for reducing uncertainty. Based on this similarity, we conclude that tree-based ML classification models are appropriate and effective for medical decision making, including efficient rehabilitation. We expect tree-based ML or DL models to be applied extensively in other medical domains, both to alleviate clinicians' biases during decision making [1] and to develop digital health care platforms such as Babylon check [41].
Data Availability Statement: Data sharing is not applicable.

Conflicts of Interest:
The authors declare no conflicts of interest.

Appendix A. Background of the Sampling Methods and Classification Models
We provide a brief summary of the conceptual background of the sampling methods and classification algorithms that we evaluated.
Appendix A.1. Background of the Sampling Methods

• Oversampling (SMOTE): proposed by Chawla et al. [19], the synthetic minority oversampling technique (SMOTE) first chooses an instance a from a minority class at random. It then randomly selects an instance b among the k nearest neighbors of a within the same class. A new synthetic instance is generated on the line segment between a and b as a convex combination of the two. This is repeated until the desired number of synthetic instances has been generated.
• Undersampling (TomekLinks): two observations a and b form a Tomek link [20] if the following conditions are satisfied: (1) the two observations are each other's nearest neighbors, as measured by Euclidean distance; (2) they belong to different class labels (e.g., a is in the minority class while b is in the majority class, or vice versa). The majority-class observations in these links are considered ambiguous and are removed to balance the class distribution.
• Combined sampling (SMOTETomek): Batista et al. [21] empirically demonstrated the effectiveness of combining SMOTE [19] and TomekLinks [20]. SMOTE is first applied for oversampling. TomekLinks is then applied to remove ambiguous majority-class observations.
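The three sampling methods can be sketched in plain Python as follows. This is a minimal illustration under simplifying assumptions: Euclidean distance, a single minority class and a single majority class, and brute-force neighbor search. The text does not state which implementation was used. Library versions, such as imbalanced-learn's `SMOTE`, `TomekLinks` and `SMOTETomek`, differ in details such as which members of a Tomek link are removed. All function names here are hypothetical.

```python
import math
import random


def smote(minority, n_new, k=5, seed=0):
    """Create n_new synthetic points by interpolating minority-class neighbours."""
    rng = random.Random(seed)
    synthetic = []
    for _ in range(n_new):
        i = rng.randrange(len(minority))
        a = minority[i]
        others = [p for j, p in enumerate(minority) if j != i]
        neighbours = sorted(others, key=lambda p: math.dist(a, p))[:k]
        b = rng.choice(neighbours)
        lam = rng.random()  # convex combination: a random point on segment a-b
        synthetic.append([ai + lam * (bi - ai) for ai, bi in zip(a, b)])
    return synthetic


def tomek_links(X, y):
    """Index pairs (i, j) that are mutual nearest neighbours with different labels."""
    nearest = [
        min((j for j in range(len(X)) if j != i), key=lambda j: math.dist(X[i], X[j]))
        for i in range(len(X))
    ]
    return [(i, j) for i, j in enumerate(nearest) if i < j and nearest[j] == i and y[i] != y[j]]


def remove_tomek(X, y, majority):
    """Drop the majority-class member of every Tomek link."""
    drop = {k for pair in tomek_links(X, y) for k in pair if y[k] == majority}
    keep = [i for i in range(len(X)) if i not in drop]
    return [X[i] for i in keep], [y[i] for i in keep]


def smote_tomek(X, y, minority, majority, k=5, seed=0):
    """Oversample the minority class to the majority size, then clean Tomek links."""
    minority_pts = [x for x, lab in zip(X, y) if lab == minority]
    n_new = max(y.count(majority) - len(minority_pts), 0)
    new = smote(minority_pts, n_new, k, seed)
    return remove_tomek(X + new, y + [minority] * len(new), majority)


X = [[0.0], [0.5], [1.0], [3.0], [3.5]]
y = [0, 0, 0, 1, 1]
X_bal, y_bal = smote_tomek(X, y, minority=1, majority=0, k=1)
print(y_bal.count(0), y_bal.count(1))  # 3 3
```

In this toy example, no Tomek link crosses the class boundary, so the cleaning step removes nothing.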

Appendix A.2. Background of the Classification Methods
• Support vector machines (SVM): SVM for classification [42] aims to find the hyperplane that best separates the instances of different classes. This is the hyperplane that maximizes the margin, i.e., the distance between the hyperplane and the closest training instances (the support vectors). SVM uses the kernel trick, replacing the dot product of two vectors with a kernel function, to obtain nonlinear decision boundaries.
• Decision Tree (DT): there are many tree-based ML algorithms, such as ID3 [43] and C4.5 [44], but scikit-learn [23,24] uses (an optimized version of) the classification and regression trees (CART) [45] algorithm. CART is a binary tree classifier in which nodes are repeatedly split into two child nodes using Gini's impurity index as the splitting criterion. Given training data, the tree is built so that each split reduces Gini's index.
• Perceptron (PT): PT [46,47] is a linear discriminant model for binary classification. The input vector x is first transformed by a fixed nonlinear transformation into a feature vector φ(x), which is then used to construct the following linear model:
$$y(\mathbf{x}) = f\left(\mathbf{w}^{\top}\phi(\mathbf{x})\right), \qquad f(a) = \begin{cases} +1, & a \ge 0,\\ -1, & a < 0, \end{cases}$$
where f(a) is a nonlinear (step) activation function and the target values 1 and −1 correspond to classes 0 and 1, respectively. The stochastic gradient descent algorithm is then applied to the perceptron criterion error function to learn the optimal parameter w.
• Light Gradient Boosting Machine (LightGBM): LightGBM [27] is a tree-based ML algorithm built on the gradient boosting framework. It is a gradient boosting decision tree (GBDT) with two newly proposed techniques: gradient-based one-side sampling (GOSS) and exclusive feature bundling (EFB). These techniques improve the efficiency of GBDT while maintaining its accuracy. With these components, it efficiently handles large numbers of data instances and features. It grows trees leaf-wise, repeatedly splitting the leaf that yields the largest decrease in loss. This procedure differs from other tree-based ML algorithms, such as GBT [48], GBDT [49], GBM [50], MART [51], and RF [35].
• Automated machine learning (AutoML): AutoML was proposed to automate ML processes, such as data preprocessing, algorithm learning, hyperparameter tuning, and evaluation, so that ML can be applied to real-world problems. AutoML addresses two problems: combined algorithm selection and hyperparameter optimization (CASH) [52] and neural architecture search (NAS) [53]. We focused on the CASH problem to find the optimal (best-fitting) algorithms for the collected data. We then drew parallels between the chosen models and the diagnostician's prescription process in the real world. Although numerous AutoML packages exist, we used the latest and best-performing one, the AutoGluon [29] library package.
• SuperTML: proposed by Sun et al. [30], SuperTML handles classification problems on tabular data with deep neural networks by embedding each instance's features into a two-dimensional image. It then uses a pretrained convolutional neural network (CNN) [54], here a residual network (ResNet) [2], to extract a representation of the image. Fully connected layers (with two hidden layers) then classify the input. It also handles categorical and missing values automatically, without any preprocessing.
• TabNet: Arik and Pfister [31] designed TabNet, a deep neural network for tabular data that mimics the way tree-based models select features. Tree-based algorithms select features globally using criteria such as information gain [26]. TabNet instead computes a weight for each feature of each instance at every decision step. In each step, an attentive transformer outputs a mask, which is multiplied element-wise with the features of each instance in the batch, yielding a sequence of feature-importance values. This process belongs to TabNet's encoder. TabNet also has a decoder, but it is used only for unsupervised (self-supervised) learning. Therefore, we used only the encoder for supervised learning, with six-step operations.
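The perceptron model and learning rule described above can be sketched in plain Python. With targets t ∈ {+1, −1}, stochastic gradient descent on the perceptron criterion updates w ← w + η φ(x) t only for misclassified patterns. This is an illustrative sketch, assuming that φ(x) simply prepends a constant bias feature to x. The scikit-learn `Perceptron` used in the text shares its implementation with `SGDClassifier` and handles the eight-class problem with a one-vs-all scheme. The function names are hypothetical.

```python
def phi(x):
    """Assumed feature map: a constant bias term followed by the raw inputs."""
    return [1.0] + list(x)


def predict(w, x):
    """y(x) = f(w^T phi(x)) with the step activation f(a) = +1 if a >= 0 else -1."""
    a = sum(wi * pi for wi, pi in zip(w, phi(x)))
    return 1 if a >= 0 else -1


def train_perceptron(X, t, lr=1.0, epochs=100):
    """SGD on the perceptron criterion; targets t must be +1 or -1."""
    w = [0.0] * (len(X[0]) + 1)
    for _ in range(epochs):
        mistakes = 0
        for x, target in zip(X, t):
            if predict(w, x) != target:  # only misclassified patterns contribute
                w = [wi + lr * pi * target for wi, pi in zip(w, phi(x))]
                mistakes += 1
        if mistakes == 0:  # every pattern is classified correctly
            break
    return w


X = [[0, 0], [0, 1], [1, 0], [1, 1]]
t = [-1, -1, -1, 1]  # logical AND: linearly separable
w = train_perceptron(X, t)
print([predict(w, x) for x in X])  # [-1, -1, -1, 1]
```

Because AND is linearly separable, the perceptron convergence theorem guarantees that the loop stops after finitely many updates.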

Appendix B. Formulations of Measurements: Accuracy, Precision, Recall, F1-Score, and Balanced Accuracy
The measurements for evaluating the performance of the classification models are computed as follows.
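The standard definitions corresponding to scikit-learn's weighted averaging, as stated in the main text, can be written as follows. The notation is introduced here:

- C is the number of classes (C = 8 in this study).
- N is the number of test observations.
- n_c is the number of observations whose true class is c.
- TP_c, FP_c, and FN_c are the true positives, false positives, and false negatives of class c.

```latex
\begin{aligned}
\text{Accuracy} &= \frac{1}{N}\sum_{c=1}^{C} TP_c, \\
P_c &= \frac{TP_c}{TP_c + FP_c}, \qquad
R_c = \frac{TP_c}{TP_c + FN_c}, \qquad
F1_c = \frac{2\,P_c R_c}{P_c + R_c}, \\
\text{Precision}_w &= \sum_{c=1}^{C} \frac{n_c}{N}\, P_c, \qquad
\text{Recall}_w = \sum_{c=1}^{C} \frac{n_c}{N}\, R_c, \qquad
\text{F1}_w = \sum_{c=1}^{C} \frac{n_c}{N}\, F1_c, \\
\text{Balanced accuracy} &= \frac{1}{C}\sum_{c=1}^{C} R_c .
\end{aligned}
```

Since n_c = TP_c + FN_c, the weighted recall simplifies to the sum of TP_c over N, which is exactly the accuracy. Balanced accuracy is the unweighted mean of the per-class recalls. In scikit-learn, classes absent from the test labels are excluded from this mean.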

Appendix C. Details of the Data
We present the collected dataset separately as numeric and categorical variables. The numeric variables (Table A1) cover anthropometry, stroke, blood test, functional assessment, and biosignal ward, summarized by mean, standard deviation (SD), and range. The categorical variables (Tables A2-A6) cover disease, stroke, and functional assessment, summarized by the number of observations (denoted as '#') and percentages (%).