1. Introduction
Survival analysis is an important and fundamental tool for modeling time-to-event data [1], which are encountered in medicine, reliability, safety, finance, etc. This is one of the reasons why many machine learning models have been developed to deal with time-to-event data and to solve the corresponding problems in the framework of survival analysis [2]. The crucial peculiarity of time-to-event data is that a training set consists of censored and uncensored observations. When the time to an event exceeds the duration of an observation, we have a censored observation. When the event is observed, i.e., the time to the event coincides with the duration of the observation, we deal with an uncensored observation.
Many survival models are able to cover various cases of time-to-event probability distributions and their parameters [
2]. One of the important models is the Cox proportional hazards model [
3], which can be regarded as a semi-parametric regression model. There are also many parametric and nonparametric models. When considering machine learning survival models, it is important to point out that, in contrast to other machine learning models, their outcomes are functions, for instance, survival functions, hazard functions or cumulative hazard functions. For instance, the well-known effective model called the random survival forest (RSF) [
4] predicts survival functions (SFs) or cumulative hazard functions.
An important area of application of survival models is treatment effect estimation, which is often solved in the framework of machine learning [5]. The treatment effect shows how effective a treatment may be depending on the characteristics of a patient. The problem is solved by dividing patients into two groups, called treatment and control, such that patients from the different groups can be compared. One of the popular measures of treatment efficiency used in machine learning models is the average treatment effect (ATE) [6], which is estimated on the basis of observed data about patients, for instance, as the mean difference between outcomes of patients from the treatment and control groups.
Due to the difference between characteristics of patients and their responses to a particular treatment, the treatment effect is measured using the conditional average treatment effect (CATE), which is defined as the mean difference between outcomes of patients from the treatment and control groups, conditional on a patient feature vector [
7]. In fact, most methods of CATE estimation are based on constructing two regression models for controls and treatments. However, two difficulties can be encountered in CATE estimation. The first one is that the treatment group is usually very small; therefore, many machine learning models cannot be accurately trained on such small datasets. The second difficulty is fundamental: each patient cannot be simultaneously in the treatment and control groups, i.e., we either observe the patient outcome under treatment or under control, but never both [
8]. Nevertheless, to overcome these difficulties, many methods for estimating CATE have been proposed and developed due to the importance of the problem in many areas [
9,
10,
11,
12,
13].
One of the approaches for constructing regression models for controls and treatments is the application of the Nadaraya–Watson kernel regression [
14,
15], which uses standard kernel functions, for instance, the Gaussian, uniform or Epanechnikov kernels. In order to avoid selecting a standard kernel, Konstantinov et al. [
16] proposed to implement the kernels, and thereby the whole Nadaraya–Watson kernel regression, by using a set of identical neural subnetworks with shared parameters and a specific way of training the network. The corresponding method, called TNW–CATE (Trainable Nadaraya–Watson regression for CATE), is based on the important assumption that the domains of the feature vectors from the treatment and control groups are similar. Indeed, we often treat patients after they have been in the control group, i.e., it is assumed that treated patients came to the treatment group from the control group. For example, it is difficult to expect that patients with pneumonia will be treated with new drugs for stomach disease. The neural kernels (kernels implemented as neural networks) are more flexible and can accurately model a complex location structure of feature vectors, for instance, when the feature vectors from the control and treatment groups are located on a spiral, as shown in
Figure 1, where small triangular and circle markers correspond to the treatment and control groups, respectively. This is another important peculiarity of the TNW–CATE. Results provided in [
16] showed that TNW–CATE outperforms other methods when the treatment group is very small and the feature vectors have a complex structure.
Following the ideas behind the TNW–CATE, we propose the CATE estimation method, called BENK (the Beran Estimator with Neural Kernels), dealing with censored time-to-event data in the framework of survival analysis. The main idea behind the proposed method is to apply the Beran estimator [
17] for estimating SFs of treatments and controls and to compare them for estimating the CATE. One of the important peculiarities of the Beran estimator is that it takes into account distances between feature vectors by using kernels which measure the similarity between any two feature vectors. On the one hand, the Beran estimator can be regarded as an extension of the Kaplan–Meier estimator. It allows us to obtain SFs that are conditional on the feature vectors, which can be viewed as outcomes of regression survival models for the treatment and control groups. On the other hand, the Beran estimator can also be viewed as an analogue of the Nadaraya–Watson kernel regression for survival analysis. However, typical kernels, for example, the Gaussian one, cannot cope with the possible complex structure of data. Therefore, similarly to the TNW–CATE model, we propose to implement kernels in the Beran estimator by means of neural subnetworks and to estimate CATE by using the obtained SFs. The whole neural network model is trained in an end-to-end manner.
Various numerical experiments illustrate BENK and its peculiarities. They also show that BENK outperforms well-known meta-models, namely the T-learner, the S-learner and the X-learner, for several control and treatment outcome functions, with the base survival models taken to be the Cox model, the RSF and the Beran estimator with Gaussian kernels.
BENK is implemented using the framework PyTorch with open code. The code of the proposed algorithms can be found at
https://github.com/Stasychbr/BENK (accessed on 27 October 2023).
The paper is organized as follows.
Section 2 is a review of the existing CATE estimation models, including CATE estimation survival models, the Nadaraya–Watson regression models and general survival models. A formal statement of the CATE estimation problem is provided in
Section 3. The CATE estimation problem in the case of censored data is stated in
Section 4. The Beran estimator is considered in
Section 5. A description of BENK is provided in
Section 6. Numerical experiments illustrating BENK and comparing it with other models can be found in
Section 7. Concluding remarks are provided in
Section 8.
2. Related Work
Estimating CATE. One of the important approaches to implementing personalized medicine is treatment effect estimation. As a result, many interesting machine learning models have been developed and implemented to estimate the CATE. First, we point out an approach which uses the Lasso model for estimating the CATE [
18]. The SVM was also applied to solve the problem [
19]. A unified framework for constructing fast tree-growing procedures for solving the CATE problem was provided in [
20]. McFowland et al. [
21] estimated CATE by using the anomaly detection model. A set of meta-algorithms or meta-learners, including the T-learner, the S-learner and the X-learner, were studied in [
12]. Many other models related to the CATE estimation problem are studied in [
22,
23].
The aforementioned models are constructed by using machine learning methods different from neural networks. However, neural networks have become a basis for developing many interesting and efficient models [
24,
25,
26,
27].
Due to the importance of the CATE problem, there are many other publications devoted to this problem [
28,
29,
30,
31].
The next generation of models that solve the CATE estimation problem is based on architectures of transformers with the attention operations [
32,
33,
34]. The transfer learning technique was successfully applied to the CATE estimation in [
35,
36]. Ideas of using the Nadaraya–Watson kernel regression in the CATE estimation were studied in [
37]. These ideas lead to good results when the numbers of examples in the treatment and control groups are large. At the same time, a small amount of training data may lead to overfitting and unsatisfactory results. The problem of overcoming this limitation motivated researchers to introduce a neural network of a special architecture that implements trainable kernels in the Nadaraya–Watson regression [
16].
Machine learning models in survival analysis. The importance of survival analysis applications can be regarded as one of the reasons for developing many machine learning methods that deal with censored and time-to-event data. A comprehensive review of machine learning survival models is presented in [
2]. A large portion of models use the Cox model, which can be viewed as a simple and applicable survival model that establishes a relationship between covariates and outcomes. Various extensions of the Cox model have been proposed. They can be conditionally divided into two groups. The first group retains the linear relationship between covariates and outcomes and includes various modifications of the Lasso models [
38]. The second group of models relaxes the linear relationship assumption accepted in the Cox model [
39].
Many survival models are based on using the RSFs, which can be regarded as powerful tools, especially when models learn on tabular data [
40,
41]. At the same time, there are many survival models based on neural networks [
42,
43].
Estimating CATE with censored data. Censored data can be regarded as an important type of data, especially for estimating the treatment effect, because in many applications the outcomes are time-to-event data. This peculiarity is a reason for developing many CATE models that deal with censored data in the framework of survival analysis [
44,
45,
46]. Modifications of the survival causal trees and forests for estimating the CATE based on censored observational data were proposed in [
44]. An approach combining a treatment-specific semi-parametric Cox loss with a treatment-balanced deep neural network was studied in [
47]. Nagpal et al. [
48] presented a latent variable approach to model the CATE under the assumption that an individual can belong to one of the latent clusters with distinct response characteristics. The problem of CATE estimation by focusing on learning (discrete-time) treatment-specific conditional hazard functions was studied in [
49]. A three-stage modular design for estimating CATE in the framework of survival analysis was proposed in [
50]. A comprehensive simulation study presenting a wide range of settings, describing CATE by taking into account the covariate overlap, was carried out in [
51]. Rytgaard et al. [
52] presented a data-adaptive estimation procedure for estimation of the CATE in a time-to-event setting based on generalized random forests. The authors proposed a two-step procedure for estimation, applying inverse probability weighting to construct time-point-specific weighted outcomes as input for the forest. A unified framework for counterfactual inference, applicable to survival outcomes and formulation of a nonparametric hazard ratio metric for evaluating the CATE, were proposed in [
53].
In spite of the many works and results devoted to estimating the CATE with censored data, the available methods are mainly based on the assumption that the treatment group contains a large number of examples. Moreover, there are no results implementing the Nadaraya–Watson regression by means of neural networks.
3. CATE Estimation Problem Statement
In the CATE estimation problem, all patients are divided into two groups: control and treatment. Let the control group be a set of c patients, such that the i-th patient is characterized by a feature vector and an observed outcome (time to event, temperature, blood pressure, etc.). It is also supposed that the treatment group is a set of s patients, such that the i-th patient is similarly characterized by a feature vector and an observed outcome. The group indicator of the i-th patient is denoted $T_i$, where $T_i = 0$ ($T_i = 1$) corresponds to the control (treatment) group.
We use different notations for the feature vectors of controls and treatments in order to avoid additional indices. However, a single common vector notation is used for both groups when estimating the CATE.
Suppose that the potential outcomes of patients from the control and treatment groups are
F and
H, respectively. The treatment effect for a new patient with the feature vector
is estimated by the individual treatment effect, defined as the difference between the treatment and control outcomes. The fundamental problem of computing the CATE is that only one of the outcomes
f or
h for each patient can be observed. An important assumption of unconfoundedness [
54] is used to allow the untreated patients to be used to construct an unbiased counterfactual for the treatment group [
55]. According to the assumption, potential outcomes are characteristics of a patient before the patient is assigned to a treatment condition, or, formally, the treatment assignment
T is independent of the potential outcomes F and H, conditional on the feature vector.
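In standard causal-inference notation, writing $X$ for the random feature vector (the symbols are our choice here), this condition can be written as
$$\{F, H\} \perp T \mid X = \mathbf{x}.$$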
The second assumption, called the overlap assumption, concerns the joint distribution of treatments and covariates. It states that, for each value of the feature vector, there is a positive probability of being both treated and untreated.
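In the same assumed notation, the overlap assumption means that, with probability 1,
$$0 < \Pr\{T = 1 \mid X = \mathbf{x}\} < 1.$$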
Let X denote the random feature vector. The treatment effect is estimated by means of the CATE, which is defined as the expected difference between the two potential outcomes conditional on the feature vector [56]. By using the unconfoundedness and overlap assumptions, the CATE can be rewritten in terms of the outcomes observed in the treatment and control groups.
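Using $F$ and $H$ for the control and treatment potential outcomes, $T$ for the treatment assignment and $X$ for the feature vector (again our choice of symbols), these two expressions can be written as
$$\tau(\mathbf{x}) = \mathbb{E}\left[H - F \mid X = \mathbf{x}\right],$$
$$\tau(\mathbf{x}) = \mathbb{E}\left[H \mid X = \mathbf{x}, T = 1\right] - \mathbb{E}\left[F \mid X = \mathbf{x}, T = 0\right].$$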
The motivation behind unconfoundedness is that nearby observations in the feature space can be treated as having come from a randomized experiment [
7].
Suppose that two functions express the outcomes of the control and treatment patients, respectively, as functions of the feature vector with additive noise governed by the normal distribution with zero expectation. The above implies that the CATE can be estimated as the difference between the treatment and control outcome functions.
An example illustrating the controls (circle markers), the treatments (triangle markers) and the corresponding unknown outcome functions is shown in
Figure 1.
4. CATE with Censored Data
Before considering the CATE estimation problem with censored data, we introduce basic notions of survival analysis. Let us define the control training set, which consists of c triplets, where each triplet contains the feature vector characterizing the i-th patient from the control group, the time to the event concerning the i-th control patient and the indicator of censoring. If the indicator equals 1, then the event of interest is observed (an uncensored observation); if it equals 0, then the observation is censored. Only right-censoring is considered, i.e., the observed survival time is less than or equal to the true survival time. Many applications of survival analysis deal with right-censored observations [
2]. The main goal of survival machine learning modeling is to use the control training set to estimate probabilistic characteristics of the time F to the event of interest for a new patient with a given feature vector.
In the same way, we define the treatment training set, which consists of s triplets, where each triplet contains the feature vector characterizing the i-th patient from the treatment group, the time to the event concerning the i-th treatment patient and the indicator of censoring.
The survival function (SF) can be regarded as an important concept in survival analysis. It represents the probability that a patient with a given feature vector survives beyond time t. The hazard function can be viewed as another important concept in survival analysis. It is defined as the rate of the event at time t given that no event occurred before time t. It can be expressed through the SF.
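Writing $S(t \mid \mathbf{x})$ for the SF and $h(t \mid \mathbf{x})$ for the hazard function (the notation is our assumption), the standard relation is
$$h(t \mid \mathbf{x}) = -\frac{\partial}{\partial t} \ln S(t \mid \mathbf{x}).$$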
The integral of the hazard function over the interval from 0 to t is called the cumulative hazard function; it accumulates the risk of the event up to time t. It can also be expressed through the SF.
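In the same assumed notation, the cumulative hazard function satisfies
$$H(t \mid \mathbf{x}) = \int_0^t h(u \mid \mathbf{x})\,\mathrm{d}u = -\ln S(t \mid \mathbf{x}).$$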
The above functions for controls and treatments are written with indices 0 and 1, respectively, for instance, $S_0(t \mid \mathbf{x})$ and $S_1(t \mid \mathbf{x})$.
In order to compare survival models, Harrell’s concordance index, or the C-index [
57], is usually used. The C-index measures the probability that, in a randomly selected pair of examples, the example that failed first had a worse predicted outcome. It is calculated as the ratio of the number of pairs correctly ordered by the model to the total number of admissible pairs. A pair is not admissible if the events of both examples are right-censored or if the earlier time in the pair corresponds to a censored observation. A survival model is perfect when the C-index equals 1; a C-index of 0.5 means that the model is equivalent to random guessing, and a C-index below 0.5 means that the model is worse than random guessing.
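As an illustration only (not the implementation used in the experiments), the C-index can be computed over all admissible pairs as in the following sketch; the array names and the convention that higher risk scores correspond to earlier expected events are assumptions.

```python
import numpy as np

def c_index(times, events, risks):
    """Harrell's C-index: fraction of admissible pairs ordered correctly.

    times  : observed times, shape (n,)
    events : 1 for an observed event, 0 for right-censoring, shape (n,)
    risks  : predicted risk scores (higher risk = earlier expected event)
    """
    correct, admissible = 0.0, 0
    n = len(times)
    for i in range(n):
        for j in range(n):
            # a pair is admissible only if the earlier time is an observed event
            if events[i] == 1 and times[i] < times[j]:
                admissible += 1
                if risks[i] > risks[j]:
                    correct += 1.0      # correctly ordered pair
                elif risks[i] == risks[j]:
                    correct += 0.5      # tied predictions count as half
    return correct / admissible
```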
In contrast to the standard CATE estimation problem statement given in the previous section, the CATE estimation problem with censored data has a different statement, because outcomes in survival analysis are random times to an event of interest with some conditional probability distribution. In other words, the predictions provided by a survival machine learning model for a patient characterized by a feature vector are represented as functions of time, for instance, as the SF. This implies that the CATE should be reformulated by taking this peculiarity into account. It is assumed that the SFs as well as the hazard functions for control and treatment patients, estimated by using the control and treatment datasets, have indices 0 and 1, respectively.
The following definitions of the CATE in the case of censored data can be found in [
58]:
Difference in expected lifetimes:
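In terms of the control and treatment SFs $S_0(t \mid \mathbf{x})$ and $S_1(t \mid \mathbf{x})$ and the potential outcomes $F$ and $H$ introduced above (a reconstruction in our notation), this definition reads
$$\tau(\mathbf{x}) = \mathbb{E}\left[H \mid \mathbf{x}\right] - \mathbb{E}\left[F \mid \mathbf{x}\right] = \int_0^\infty \left[S_1(t \mid \mathbf{x}) - S_0(t \mid \mathbf{x})\right] \mathrm{d}t.$$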
We will use the first integral definition of the CATE. Let $t_1 < t_2 < \dots < t_m$ be the distinct times to the event of interest obtained from the training set. The SF provided by a survival machine learning model is a step function, i.e., it can be represented as a sum over the intervals $[t_j, t_{j+1})$ of indicator functions of these intervals multiplied by the corresponding values $S_j$ of the SF in the intervals. Hence, the integral defining the expected lifetime reduces to a finite sum.
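With the step representation above, where $S_j$ denotes the value of the SF in the interval $[t_j, t_{j+1})$ and $t_0 = 0$ (the notation is ours), the integral of the SF becomes
$$\int_0^{t_m} S(t \mid \mathbf{x})\,\mathrm{d}t = \sum_{j=0}^{m-1} S_j\,(t_{j+1} - t_j).$$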
6. Neural Network for Estimating CATE
Let us consider how the Beran estimator with neural kernels can be implemented by means of a neural network of a special type. Our first aim is to implement kernels
by means of a neural subnetwork, which is called the neural kernel and is a part of the whole network for implementing the Beran estimator. The second aim is for this network to learn on the control data. Having the trained kernel, we can apply it to compute the conditional survival function for controls, as well as for treatments, because the kernels in (
14) do not directly depend on the times to events of controls or treatments. However, in order to train the kernel, we have to train the whole network because the loss function is defined through the SF
, which represents the probability of survival of a control patient up to time
t, which is estimated by means of the Beran estimator. This implies that the whole network contains blocks of the neural kernels for computing kernels
, normalization for computing the kernel weights
and the Beran estimator in accordance with (
14). In order to realize a training procedure for the network, we randomly select a portion (
n examples) from all control training examples and form a single specific example from
n selected ones. This random selection is repeated
N times to have
N examples for training. Thus, for every feature vector from the control group, we add other vectors from the same set of controls. By composing n pairs of vectors and including the other elements of the training examples (the times to events and the censoring indicators), we obtain one composite vector of data, representing one new training example for the entire neural network. Such new training examples can be constructed for each control example. The formal construction of the training set is considered below.
Having the trained neural kernel, we can successfully use it for computing the SF of controls and the SF of treatments for arbitrary feature vectors, again applying the Beran estimator.
Let us consider the training algorithm in detail. First, we return to the set of c controls. For every index i from this set, we construct N subsets, each containing n examples randomly selected from the set of controls and indexed by a corresponding index set. Here, N and n can be regarded as tuning hyperparameters. The upper index r indicates that the r-th example of a subset is randomly taken from the set of controls, i.e., it coincides with one of the original control examples. Each subset, jointly with the i-th control example, forms a training example for the control network. The total number of constructed examples is cN, and these examples are used for training the neural network, whose output is the estimate of the SF of controls.
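The construction of the composite training examples can be sketched as follows; the array names X_c, T_c, D_c (control features, times to events and censoring indicators) and the data layout are illustrative assumptions, not the notation of the released code.

```python
import numpy as np

def build_training_examples(X_c, T_c, D_c, n, N, seed=None):
    """For every control example i, draw N random subsets of n controls and form
    composite examples: the n pairs (x_i, x_r), the times and censoring indicators
    of the selected controls, and the target time/indicator of example i."""
    rng = np.random.default_rng(seed)
    c = len(X_c)
    examples = []
    for i in range(c):
        for _ in range(N):
            idx = rng.choice(c, size=n, replace=False)          # random index subset
            pairs = np.stack([np.concatenate([X_c[i], X_c[r]]) for r in idx])
            examples.append((pairs, T_c[idx], D_c[idx], T_c[i], D_c[i]))
    return examples                                             # c * N composite examples
```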
The architecture of the neural network, consisting of n subnetworks that implement the neural kernels, is shown in Figure 2. Examples produced from the dataset of controls are fed to the whole neural network, such that each pair of feature vectors is fed to one subnetwork, which implements the kernel function. The output of each subnetwork is the corresponding kernel value. All subnetworks are identical and have shared weights. After normalizing the kernels, we obtain n weights, which are used to estimate SFs by means of the Beran estimator in (14). The block of the whole neural network that implements the Beran estimator uses all weights and the corresponding times to events and censoring indicators. As a result, we obtain the estimate of the SF of controls for the i-th training example. In the same way, we compute the SFs for all training examples. These functions are the basis for training. In fact, the normalization block and the block that implements the Beran estimator can be regarded as parts of the neural network, and the whole network is trained in an end-to-end manner.
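A minimal sketch of the Beran estimator block is given below, assuming that the normalized kernel weights are already ordered by ascending times to events; it is an illustration of the standard Beran formula rather than the released implementation.

```python
import torch

def beran_sf(weights, delta):
    """Step SF produced by the Beran estimator.

    weights : (n,) normalized kernel weights ordered by ascending time to event
    delta   : (n,) censoring indicators in the same order (1 = uncensored)
    Returns the SF values at the ordered event times (a step function).
    """
    cum_before = torch.cumsum(weights, dim=0) - weights          # sum of w_j for j < i
    factors = 1.0 - weights / (1.0 - cum_before).clamp(min=1e-8)
    factors = torch.where(delta.bool(), factors, torch.ones_like(factors))
    return torch.cumprod(factors, dim=0)
```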
According to (
13), expected lifetimes are used to compute the CATE
. Therefore, the whole network is trained by means of a loss function defined as follows: it is the mean of the squared differences, over the subset of uncensored examples from the training set, between the observed time to the event of the k-th example and the expected lifetime computed through the estimated SF, which is obtained by integrating the SF. The sum in the loss function (18) is taken over uncensored examples only. However, the Beran estimator itself uses all the examples.
One of the loss functions that takes into account all data (censored and uncensored) is based on the C-index. However, our aim is not to estimate the SF or the CHF as such; we aim to make the expected time to event computed from the estimated SF close to the observed time to event. Therefore, we use the standard mean squared error (MSE) loss function. Censored times would introduce bias into the MSE and, therefore, they are not used.
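A sketch of such an MSE loss over uncensored examples, with the expected lifetime obtained by integrating the step SF, might look as follows (tensor names and shapes are illustrative assumptions).

```python
import torch

def benk_loss(sf_steps, time_grid, target_times, uncensored):
    """MSE between expected lifetimes and observed times over uncensored examples.

    sf_steps     : (batch, m) SF values on the m intervals of the time grid
    time_grid    : (m + 1,) ordered distinct event times with a leading zero
    target_times : (batch,) observed times to events of the target examples
    uncensored   : (batch,) boolean mask of uncensored target examples
    """
    dt = time_grid[1:] - time_grid[:-1]              # interval lengths
    expected_life = (sf_steps * dt).sum(dim=1)       # integral of the step SF
    sq_err = (expected_life - target_times) ** 2
    return sq_err[uncensored].mean()                 # censored examples are excluded
```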
It is important to point out that our aim is to train subnetworks with shared parameters, which implement the neural kernels. Having the trained neural kernels, we can use them to compute kernels for controls and for treatments and then to compute estimates of the SFs $S_0$ and $S_1$ for controls and treatments, respectively, i.e., we realize the idea of transferring the task from the control group to the treatment group. Let the ordered time moments corresponding to the times to events of controls and of treatments be given. Then, the CATE can be computed through the SFs $S_0$ and $S_1$, again by using the Beran estimators with the trained neural kernels, i.e., in accordance with (13), the CATE is the difference between the expected lifetimes of treatments and controls, where the expected lifetimes are obtained by integrating the step-wise estimates of $S_1$ and $S_0$ over the corresponding time intervals.
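Given the two estimated step SFs on their respective time grids, the CATE can be computed, for instance, as in the following sketch (an illustration under the same step-function convention as above).

```python
def estimate_cate(sf_treat, grid_treat, sf_contr, grid_contr):
    """CATE as the difference of expected lifetimes of treatments and controls.

    sf_*   : step SF values on the intervals of the corresponding time grid
    grid_* : ordered event times of the corresponding group with a leading zero
    """
    life_treat = (sf_treat * (grid_treat[1:] - grid_treat[:-1])).sum()
    life_contr = (sf_contr * (grid_contr[1:] - grid_contr[:-1])).sum()
    return life_treat - life_contr
```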
The illustration of the neural networks that predict $S_0$ and $S_1$ for a new feature vector is shown in Figure 3. It can be seen from Figure 3 that the first neural network consists of c subnetworks, such that pairs composed of the new feature vector and a control feature vector are fed to the subnetworks. Predictions of the first neural network are c kernels, which are used to compute $S_0$ by means of the Beran estimator (14). The neural network for predicting the kernels used to estimate the treatment SF $S_1$ has the same architecture. This network consists of s subnetworks and uses feature vectors from the dataset of treatments. After computing the estimates $S_0$ and $S_1$, we can find the CATE.
Phases of training and computing CATE
by means of neural kernels are schematically shown as Algorithms 1 and 2, respectively.
Algorithm 1 The algorithm for training neural kernels
Require: Dataset of c controls, dataset of s treatments, number N of generated subsets, number n of examples in each generated subset
Ensure: Trained neural kernels for their use in the Beran estimator for control and treatment data
1: for every control example i do
2:  for every generated subset k = 1, ..., N do
3:   Generate a subset of n randomly selected controls
4:   Form a composite training example from the subset and the i-th control example
5:  end for
6: end for
7: Train the weight-sharing neural network with the loss function given in (18) on the set of constructed examples
Algorithm 2 The algorithm for computing CATE for a new feature vector
Require: Trained neural kernels, control and treatment datasets, testing example
Ensure: CATE for the testing example
1: for every control example i = 1, ..., c do
2:  Form a pair of vectors from the testing example and the i-th vector of the dataset of controls
3:  Feed the pair to the trained neural kernel and predict the corresponding kernel value
4: end for
5: for every treatment example i = 1, ..., s do
6:  Form a pair of vectors from the testing example and the i-th vector of the dataset of treatments
7:  Feed the pair to the trained neural kernel and predict the corresponding kernel value
8: end for
9: Compute the normalized kernel weights for controls and for treatments
10: Estimate the control and treatment SFs using (14)
11: Compute the CATE using (20)
7. Numerical Experiments
Numerical experiments studying BENK and comparing it with available models are performed by using simulated datasets, because for real data the true CATEs are unknown due to the fundamental problem of causal inference [
8]. This implies that control and treatment datasets are randomly generated in accordance with predefined outcome functions.
7.1. CATE Estimators for Comparison and Their Parameters
For investigating BENK and comparing it with alternatives, we use nine models, which can be combined into three groups (the T-learner, the S-learner and the X-learner), such that each group is based on three base models for estimating SFs (the RSF, the Cox model and the Beran estimator with Gaussian kernels). The models are described below in terms of survival models:
The T-learner [
12] is a model which estimates the control SF
and the treatment SF
for every feature vector. The CATE in this case is defined in accordance with (
13);
The S-learner [
12] is a model which estimates a single SF instead of separate control and treatment SFs, where the treatment assignment indicator is included as an additional feature of the feature vector. As a result, we have a modified dataset combining the controls (with indicator value 0) and the treatments (with indicator value 1). The CATE is determined as the difference between the estimates obtained by setting the indicator to the treatment and control values;
The X-learner [
12] is based on computing the so-called imputed treatment effects and is represented by the following three steps. First, the outcome functions for controls and treatments are estimated using a regression algorithm. Second, the imputed treatment effects are computed as the differences between the observed outcomes and the estimated outcome functions of the opposite group. Third, two regression functions are estimated for the imputed treatment effects of controls and treatments, respectively. The CATE for a point is defined as a weighted linear combination of these two regression functions, where the weight is equal to the proportion of treated patients. The original X-learner does not deal with censored data; therefore, we propose a simple survival modification of the X-learner. It is assumed that the outcome functions are the expectations of the times to an event corresponding to the control and treatment data, respectively. These expectations are computed by means of one of the algorithms for determining estimates of the SFs. The regression functions for the imputed treatment effects are implemented using the random forest regression algorithm for all the base models.
Estimations of the SFs and the corresponding expected times to events are carried out by means of the following survival regression algorithms:
The RSF parameters of random forests used in experiments are the following:
The numbers of trees are 10, 50, 100, 200;
The depths are 3, 4, 5, 6;
The smallest numbers of examples which fall in a leaf are 1 example, 1%, 5% and 10% of the training set.
The above values for the hyperparameters are tested, choosing those leading to the best results;
The Cox proportional hazards model [
3], which is used with the elastic net regularization with a 3 to 1 ratio of the two penalty coefficients;
In contrast to the proposed BENK model, this baseline uses the Beran estimator with standard Gaussian kernels. Several values of the kernel parameters, including bandwidth values 5, 50, 200, 500 and 700, are tested, choosing those leading to the best results.
In sum, we have nine models for comparison, whose notations are given in
Table 1.
7.2. Generating Synthetic Datasets
As described above, we generate artificial complex feature spaces and outcomes in the numerical experiments. All the feature vectors, for both controls and treatments, are generated by means of three functions: the spiral function, the bell-shaped function and the circular function. The idea of using these functions stems from the goal of obtaining complex data structures, which are poorly processed by many standard methods. These functions are defined through a parameter as follows:
Spiral functions: The feature vectors, having dimensionality
d and being located on the Archimedean spirals, are defined by parametric spiral equations whose forms differ slightly for even and odd d. Values of the spiral parameter are uniformly generated from a fixed interval for all numerical experiments;
Bell-shaped functions: Features are represented as a set of almost non-overlapping Gaussians. As
the parameter is uniformly generated in the numerical experiments, we can define the corresponding lower and upper bounds of the uniform distribution. The feature vector of dimensionality d is composed of bell-shaped (Gaussian) components defined on almost non-overlapping sub-intervals of this range, so that each feature corresponds to its own region in the distribution;
Circular functions: The corresponding feature space is generated by using only even numbers of features. The feature vectors are located on two-dimensional circles defined through the parameter and an indicator function.
Each pair of features corresponds to their own two-dimensional circle and to their own region in the distribution.
In all experiments, the treatment feature vectors are generated in the same way as the control feature vectors. However, for feature vectors from the control and treatment groups, the corresponding times to events
f and
h are different and are generated by using the Weibull distribution, as follows:
where
u is a random variable uniformly distributed on the interval from 0 to 1; values
f and
h larger than 2000 are clipped to this value.
This way of generating f and h is in agreement with the Cox model. Hence, we can use the Cox model as a base model, along with RSFs and the Beran estimator with Gaussian kernels, in the numerical experiments.
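A Cox-consistent way of drawing such Weibull times via the inverse transform is sketched below; the baseline scale and shape constants and the coefficient vector beta are illustrative assumptions, not the exact values used in the experiments.

```python
import numpy as np

def weibull_cox_time(x, beta, lam=1e-5, shape=2.0, clip=2000.0, seed=None):
    """Draw a time to event from a Weibull distribution whose scale depends on
    the covariates through a Cox-type term exp(beta @ x)."""
    rng = np.random.default_rng(seed)
    u = rng.uniform(0.0, 1.0)                                   # u ~ U(0, 1)
    t = (-np.log(u) / (lam * np.exp(beta @ x))) ** (1.0 / shape)
    return min(t, clip)                                         # clip times larger than 2000
```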
The proportion of censored data, denoted p, is taken as 33% of all observations in the experiments. Hence, the censoring indicators are generated from the binomial distribution with probabilities 0.67 for an uncensored observation and 0.33 for a censored one.
The Precision in Estimation of Heterogeneous Effects metric (PEHE), proposed in [
61], is used to measure the performance of the models in the numerical experiments. According to [
61], this metric evaluates the ability of each method to capture treatment effect heterogeneity.
If we label the test dataset as
, then the PEHE can be defined as follows:
where the normalization is by the size of the test set, which is fixed to the same value for all numerical experiments.
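For completeness, the PEHE over a test set with known true CATEs can be computed, for instance, as follows (the root-mean-squared form is assumed here).

```python
import numpy as np

def pehe(cate_pred, cate_true):
    """Root mean squared difference between predicted and true CATEs."""
    diff = np.asarray(cate_pred) - np.asarray(cate_true)
    return np.sqrt(np.mean(diff ** 2))
```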
The proportion of treatments to controls in most experiments is 20%, except for the experiments studying how the proportion of treatments impacts the CATE, where the proportion of treatments to controls is denoted as q. For example, if 100 controls are generated for an experiment with q equal to 20%, then 20 treatments are generated in addition to the controls, such that the total number of examples is 120. The generated feature vectors in all experiments consist of 10 features; the corresponding dataset size is 300 unless otherwise stated. To select optimal hyperparameters of BENK, additional validation examples are generated, such that they belong only to the control group; the size of this additional validation set is a fraction of the training set size. After the BENK neural network training, this validation set is concatenated with the training set for the other models, which are trained using cross-validation with three splits. For studying the dependencies, we repeat the numerical experiments 100 times and report the mean values across these 100 iterations.
Each subnetwork is a fully connected neural network consisting of five layers, with the corresponding activation functions ReLU6, ReLU6, ReLU6, Tanh and Softplus. The inputs of each subnetwork are represented in a symmetric form of the pair of feature vectors to ensure the symmetry property of the kernels. The non-negativity of the neural kernels is achieved by using the Softplus activation function in the last layer of the subnetworks, which ensures that the output is always positive.
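A minimal PyTorch sketch of one such kernel subnetwork is given below; the hidden layer sizes and the exact symmetric encoding of the input pair are illustrative assumptions.

```python
import torch.nn as nn

class NeuralKernel(nn.Module):
    """One kernel subnetwork: five fully connected layers with the activations
    ReLU6, ReLU6, ReLU6, Tanh and Softplus; the final Softplus keeps the kernel
    value positive."""
    def __init__(self, in_dim, hidden=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden), nn.ReLU6(),
            nn.Linear(hidden, hidden), nn.ReLU6(),
            nn.Linear(hidden, hidden), nn.ReLU6(),
            nn.Linear(hidden, hidden), nn.Tanh(),
            nn.Linear(hidden, 1), nn.Softplus(),
        )

    def forward(self, pair):
        # pair: a symmetric encoding of (x, x_i), e.g., their elementwise
        # absolute difference (an assumption; the paper only states that the
        # input form ensures the symmetry of the kernel)
        return self.net(pair)
```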
7.3. Study of the BENK Properties
In all figures illustrating results of the numerical experiments, dotted curves correspond to the T-learner (triangle markers), the S-learner (triangle markers) or the X-learner (circle markers) under the condition of using the Beran estimator with Gaussian kernels. Dash-and-dot curves correspond to the Cox models. Dashed curves with the same markers correspond to the same models implemented using RSFs. The solid curve with cross markers corresponds to BENK. The PEHE metric is used to represent the results of the experiments; the smaller the PEHE values, the better the results. To avoid cluttering the figures with curves, we show, for each of the T-, S- and X-learners, the best base model obtained in each experiment.
First, we study different CATE estimators using different numbers
c of controls, taking the values 100, 200, 300, 500, 1000. The number of treatments
t is determined as 20% of the number of controls. The value of the hyperparameter n is fixed in these experiments.
Figure 4,
Figure 5 and
Figure 6 illustrate how values of the PEHE metric depend on the number
c of controls for different estimators when different functions are used for generating examples.
Figure 4 shows the difference between the PEHE metric of BENK and other models in the experiment, with the feature vectors located around the spiral. The T-SF, S-Beran and X-SF models are provided in
Figure 4 because they show the best competitive metric values. In order to illustrate how the variance in results depends on the amount of input data, the error bars are also depicted in
Figure 4. It can be seen from
Figure 4 that the variance in the results decreases as the number of controls grows. This indicates that the neural network is trained properly. We do not add error bars to the other graphs so as not to mask the relative positions of the corresponding curves.
Figure 5 illustrates similar dependencies when the bell-shaped function is used for generating the feature vectors. The selected models in this case are T-Cox, S-SF and X-Cox.
Figure 6 illustrates the relationship between different models obtained on the circular feature space. The competitive algorithms given in the picture are T-Beran, S-Beran and X-Beran. It can be seen from
Figure 4,
Figure 5 and
Figure 6 that the proposed model BENK provides better results in comparison with other models. The largest relative difference between BENK and other models can be observed when the feature vectors are generated in accordance with the spiral function. This function produces the most complex data structure, such that other studied models cannot cope with it.
Another interesting question is how the CATE estimators depend on the proportion
q of treatments and controls in the training set. Particularly, for the proposed BENK model, we try to study whether an increasing number of treatments (the set
) provides better CATE results with an unchanged number of controls (the set
). The corresponding numerical results are shown in
Figure 7,
Figure 8 and
Figure 9. One can see from
Figure 7,
Figure 8 and
Figure 9 that the improvement in the PEHE in comparison with other CATE estimators is substantial when
q is changed from 10% to 20% in the experiments with the spiral and bell-shaped functions. Moreover, we again observe the outperformance of BENK in comparison with other estimators.
In the previous experiments, the amount of censored data was taken as 33% of all observations. However, it is interesting to study how this amount impacts the PEHE of the CATE estimators.
Figure 10,
Figure 11 and
Figure 12 illustrate the corresponding dependencies when different generating functions are used. It can be seen from
Figure 10,
Figure 11 and
Figure 12 that the PEHE metrics for all estimators, including BENK, increase with the amount of censored data.
Table 2 aims to quantitatively compare the results under a fixed experimental setting (fixed numbers of controls and treatments, number of features and proportion of censored data). One can see from
Table 2 that BENK provides outperforming results. Let us compare results obtained for BENK with the results provided by other models in
Table 2. For comparison, we can apply the standard t-test. The obtained
p-values for all pairs of models are shown in the last column. We can see from
Table 2 that all
p-values are smaller than the significance level. Hence, we can conclude that the outperformance of BENK is statistically significant. It is interesting to note from
Table 2 that the methods based on the Cox model (T-Cox, S-Cox, X-Cox) show worse results. This can be explained by the restrictive assumption of a linear relationship of the features, which is made in the Cox model. This assumption contradicts the complex spiral, bell-shaped and circular functions and does not allow us to obtain better results. It should be pointed out that T-NW provides the best result for the bell-shaped generating function among the methods other than BENK. This is explained by the fact that the bell-shaped function is close to the Gaussian function; therefore, the method based on the Nadaraya–Watson kernel regression does not crucially differ from BENK. It is also interesting to note that efficient methods such as the S-learner and the X-learner often provide worse results in comparison with the T-learner, which is rather weak in standard CATE tasks. This is due to peculiarities of survival data, which differ from standard regression and classification data.
It should be noted that we did not provide results of various deep neural network extensions of the CATE estimators because they have not been successful. The problem is that neural networks require a large amount of data for training, and the considered small datasets led to overfitting of the networks. This is why we studied models that provide satisfactory predictions under the condition of small amounts of data.
8. Conclusions
A new method called BENK for solving the CATE problem under the condition of censored data has been presented. It extends the idea behind TNW–CATE proposed in [
16] to the case of censored data. In spite of many similar parts, TNW–CATE and BENK are different because BENK is based on the Beran estimator for training and can be successfully applied to survival analysis of controls and treatments. However, TNW–CATE and BENK use the same idea of training neural kernels: the kernels are implemented as neural networks instead of using standard kernel functions.
It is also interesting to point out that BENK does not require one to have a large dataset for training, even though a neural network is used for implementing the kernels. This is due to the special way proposed to train the network, which considers pairs of examples from the control group for training, as in Siamese neural networks. Our experiments have illustrated the outperforming characteristics of BENK. At the same time, we have to point out some disadvantages of BENK. First, it has many tuning parameters, including the parameters of the neural network and the training parameters n and N, such that the training time may be significantly increased in comparison with other methods of solving the CATE problem. Second, BENK assumes that the feature vector domains are similar for controls and treatments. This does not mean that they have to coincide totally, but the corresponding difference between domains should not be very large. A method which could take into account a possible difference between the feature vector domains for controls and treatments can be regarded as a direction for further research. An idea behind such a method is to combine domain adaptation models and BENK.
Another direction for further research is to study robust versions of BENK when there are anomalous observations that may impact training the neural network. An idea behind the robust version is to use attention weights for feature vectors and also to introduce additional attention weights for predictions.
It should be noted that the Beran estimator is one of several estimators that are used in survival analysis. Moreover, we have studied only the difference in expected lifetimes as a definition of the CATE in the case of censored data. There are other definitions, for instance, the difference in SFs and the hazard ratio, which may lead to more interesting models. Therefore, BENK implementations and studies using other estimators and definitions of the CATE can also be considered as directions for further research.
The proposed method can be used in applications that are different from medicine. For example, it can be applied to selection and control of the most efficient regimes in the Internet of Things. This is also an interesting direction for further research.