A Copula-Driven CNN-LSTM Framework for Estimating Heterogeneous Treatment Effects in Multivariate Outcomes
Abstract
1. Introduction
2. Methods
1. Empirical Copula CNN-LSTM: Inputs are transformed using rank-based empirical copulas before being fed into a CNN-LSTM multitask model (a minimal sketch of this transformation follows the list).
2. Plain CNN-LSTM: An identical model architecture, but without the copula preprocessing.
3. Causal Forest: A tree-based nonparametric estimator of HTEs built on generalized random forests.
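A minimal Python sketch of the rank-based empirical copula transformation is given below. It assumes the common pseudo-observation convention $u_{ij} = \operatorname{rank}(x_{ij})/(n+1)$; the helper name `empirical_copula_transform` is introduced here purely for illustration and is not taken from the paper.

```python
import numpy as np
from scipy.stats import rankdata

def empirical_copula_transform(X):
    """Map each covariate column to rank-based pseudo-observations in (0, 1).

    Each column x_j is replaced by rank(x_ij) / (n + 1), i.e., the empirical
    probability integral transform of its marginal distribution.
    """
    X = np.asarray(X, dtype=float)
    n = X.shape[0]
    # rankdata averages ties; dividing by n + 1 keeps values strictly inside (0, 1)
    return np.column_stack([rankdata(X[:, j]) / (n + 1) for j in range(X.shape[1])])

# Example: transform a small covariate matrix before feeding it to the network
rng = np.random.default_rng(0)
U = empirical_copula_transform(rng.normal(size=(5, 3)))
print(U)
```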
2.1. Empirical Copula CNN-LSTM
- The convolutional feature map is the intermediate feature representation learned by a 1D convolutional layer, which applies sliding filters across the covariates to extract local patterns.
- ReLU denotes the Rectified Linear Unit, a commonly used activation function defined as $\mathrm{ReLU}(x) = \max(0, x)$, which introduces nonlinearity and helps avoid vanishing gradients.
- The LSTM hidden state is the output of a Long Short-Term Memory (LSTM) network, a type of recurrent neural network (RNN) designed to capture long-range dependencies and temporal dynamics in sequences. The LSTM includes gating mechanisms that control memory updates and retention, making it well suited to sequential or structured input data.
- Mean squared error (MSE) for the continuous outcome (Y1);
- Poisson negative log-likelihood for the count outcome (Y2);
- Cox partial likelihood loss for the censored survival outcome (Y3) [21] (a sketch of this custom loss follows the list).
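The TensorFlow sketch below shows one standard formulation of the custom Cox partial likelihood loss (Breslow form, without special tie handling). The label layout, with the observed time and event indicator packed into a two-column tensor and the network outputting a log-risk score, is an assumption; the paper's exact implementation is not shown.

```python
import tensorflow as tf

def cox_partial_likelihood_loss(y_true, y_pred):
    """Negative Cox partial log-likelihood (Breslow form, no special tie handling).

    Assumed label layout: y_true has shape (batch, 2) with columns
    (observed_time, event_indicator); y_pred has shape (batch, 1) and holds
    the predicted log-risk score.
    """
    time = y_true[:, 0]
    event = y_true[:, 1]
    risk = tf.squeeze(y_pred, axis=-1)

    # Sort by descending observed time so that, for each subject, the running
    # sum over earlier rows equals the sum over that subject's risk set.
    order = tf.argsort(time, direction="DESCENDING")
    risk = tf.gather(risk, order)
    event = tf.gather(event, order)

    # Log of the cumulative risk-set sum (a cumulative logsumexp would be
    # numerically safer; this form keeps the sketch short).
    log_risk_set = tf.math.log(tf.cumsum(tf.exp(risk)))

    partial_ll = tf.reduce_sum((risk - log_risk_set) * event)
    n_events = tf.maximum(tf.reduce_sum(event), 1.0)
    return -partial_ll / n_events
```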
2.2. Plain CNN-LSTM
2.3. Causal Forest
3. Simulation Study
3.1. Data Simulation Framework
3.2. Outcome Generation
- Continuous outcome (Y1):
- Count outcome (Y2):
- Survival outcome (Y3): True event times are sampled from an exponential distribution. Censoring times are independently sampled from a uniform distribution to induce approximately 20% censoring. The observed time and event indicator are $\tilde{T}_i = \min(T_i, C_i)$ and $\delta_i = \mathbf{1}(T_i \le C_i)$ (a simulation sketch follows the list).
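The short sketch below illustrates this survival-outcome generation scheme. The linear predictor and its coefficients are placeholders (the paper's exact hazard specification is not reproduced here), and in practice the uniform censoring bound would be tuned until roughly 20% of subjects are censored.

```python
import numpy as np

rng = np.random.default_rng(42)
n = 1000

# Placeholder design: a binary treatment and one illustrative covariate.
A = rng.binomial(1, 0.5, size=n)
x = rng.normal(size=n)
rate = np.exp(0.5 * x - 0.8 * A)          # hypothetical subject-specific hazard rate

# True event times from an exponential distribution
T_event = rng.exponential(scale=1.0 / rate)

# Independent uniform censoring; the upper bound is adjusted until roughly
# 20% of subjects end up censored.
C = rng.uniform(0.0, np.quantile(T_event, 0.95) * 3.0, size=n)

time_obs = np.minimum(T_event, C)         # observed time  min(T, C)
event = (T_event <= C).astype(int)        # event indicator 1{T <= C}
print("censoring fraction:", 1.0 - event.mean())
```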
3.3. Copula-Based Covariate Transformation
3.4. Model Architectures
3.4.1. Empirical Copula CNN-LSTM
- A 1D convolutional layer with 32 filters (kernel size = 1) and ReLU activation.
- An LSTM layer with 64 hidden units.
- Three outcome-specific dense layers (a minimal Keras sketch of the full architecture follows this list):
  - A linear output head for Y1, trained using the mean squared error (MSE) loss.
  - A softplus output head for Y2, trained using the Poisson negative log-likelihood.
  - A linear output head for Y3, trained using a custom Cox partial likelihood loss.
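The Keras sketch below mirrors this architecture. Two details are assumptions made so the snippet runs on its own: the copula-transformed covariates are reshaped into a length-p sequence with a single channel, and "mse" stands in for the custom Cox partial likelihood loss sketched in Section 2.1.

```python
import tensorflow as tf
from tensorflow.keras import layers, Model

def build_copula_cnn_lstm(n_covariates: int) -> Model:
    """CNN-LSTM multitask network mirroring the architecture listed above."""
    inputs = layers.Input(shape=(n_covariates, 1))   # covariates as a length-p, 1-channel sequence
    h = layers.Conv1D(filters=32, kernel_size=1, activation="relu")(inputs)
    h = layers.LSTM(64)(h)

    y1 = layers.Dense(1, name="y1_continuous")(h)                    # linear head (MSE)
    y2 = layers.Dense(1, activation="softplus", name="y2_count")(h)  # softplus head (Poisson)
    y3 = layers.Dense(1, name="y3_risk")(h)                          # linear log-risk head (Cox)
    return Model(inputs, [y1, y2, y3])

model = build_copula_cnn_lstm(n_covariates=10)
model.compile(
    optimizer="adam",
    loss={
        "y1_continuous": "mse",
        "y2_count": tf.keras.losses.Poisson(),
        # The survival head would use the custom Cox partial likelihood loss
        # sketched in Section 2.1; "mse" is a stand-in so this snippet runs alone.
        "y3_risk": "mse",
    },
)
model.summary()
```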
3.4.2. Plain CNN-LSTM
3.4.3. Causal Forest
3.5. Treatment Effect Estimation
3.6. Bootstrapped Confidence Intervals
- A resample of the dataset is drawn with replacement.
- The ATE is recomputed on the resample (a minimal sketch of this procedure follows the list).
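A minimal sketch of this bootstrap, assuming a normal-approximation 95% CI (ATE ± 1.96 × bootstrap SE), which appears consistent with the SE and CI columns reported in Section 3.9. The `estimate_ate` callable is a placeholder for the full model-refitting pipeline.

```python
import numpy as np

def bootstrap_ate(estimate_ate, data, n_boot=500, seed=0):
    """Nonparametric bootstrap for the ATE.

    estimate_ate: callable that refits the model on a resampled dataset
    (rows = subjects, stored as a NumPy array) and returns a scalar ATE.
    """
    rng = np.random.default_rng(seed)
    n = len(data)
    ates = np.empty(n_boot)
    for b in range(n_boot):
        idx = rng.integers(0, n, size=n)      # draw a resample with replacement
        ates[b] = estimate_ate(data[idx])     # recompute the ATE on the resample
    ate_hat = ates.mean()
    se = ates.std(ddof=1)                     # bootstrap standard error
    ci = (ate_hat - 1.96 * se, ate_hat + 1.96 * se)   # normal-approximation 95% CI
    return ate_hat, se, ci
```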
3.7. Sensitivity Analysis with Treatment Perturbation
- A new treatment vector was generated by randomly flipping a proportion of the original treatment assignments (a minimal sketch follows the list).
- The outcome generation, model fitting, and treatment effect estimation procedures were repeated.
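One way to implement the treatment flipping is sketched below; the helper name `perturb_treatment` is illustrative rather than taken from the paper.

```python
import numpy as np

def perturb_treatment(treatment, rate, rng):
    """Randomly flip a proportion `rate` of the binary treatment assignments."""
    treatment = np.asarray(treatment).copy()
    n_flip = int(round(rate * treatment.size))
    flip_idx = rng.choice(treatment.size, size=n_flip, replace=False)
    treatment[flip_idx] = 1 - treatment[flip_idx]
    return treatment

rng = np.random.default_rng(1)
A = rng.binomial(1, 0.5, size=1000)
for rate in (0.00, 0.05, 0.10, 0.15):          # perturbation rates used in the study
    A_pert = perturb_treatment(A, rate, rng)
    print(rate, float((A_pert != A).mean()))   # fraction of assignments actually flipped
```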
3.8. Visualization and Reporting
- Point and error bar plots (95% CI) for ATEs.
- Bar plots with SE-based error bars for CATEs (an illustrative plotting sketch follows the list).
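The matplotlib sketch below reproduces the point-and-error-bar style of the ATE plots for a single outcome; the numbers are the baseline (perturbation rate 0.00) Y1 ATEs and 95% CIs from the results table in Section 3.9.

```python
import matplotlib.pyplot as plt

# Baseline ATEs for Y1, taken from the results table in Section 3.9.
models = ["Copula CNN-LSTM", "Plain CNN-LSTM", "Causal Forest"]
ate = [1.3179, 1.2354, 1.2849]
ci_lower = [1.3071, 1.2126, 1.2775]
ci_upper = [1.3288, 1.2582, 1.2923]

# Asymmetric error bars: distance from each point estimate to its CI bounds.
yerr = [[a - lo for a, lo in zip(ate, ci_lower)],
        [hi - a for a, hi in zip(ate, ci_upper)]]

xpos = range(len(models))
plt.errorbar(xpos, ate, yerr=yerr, fmt="o", capsize=4)
plt.xticks(xpos, models)
plt.ylabel("ATE for Y1 (perturbation rate 0.00)")
plt.title("Point-and-error-bar plot of ATEs with 95% CIs")
plt.tight_layout()
plt.show()
```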
3.9. Data Analysis Results
- Perturbation Rate = 0.00 (Baseline)
- Perturbation Rate = 0.05
- Perturbation Rate = 0.10
- Perturbation Rate = 0.15
- Model Comparisons and Robustness Insights
- The Empirical Copula CNN-LSTM consistently delivers stable ATE estimates with low standard errors, especially for survival outcomes. Its empirical copula transformation likely contributes to robustness by modeling rank-based dependencies and reducing sensitivity to input scale or distributional shifts.
- The Plain CNN-LSTM tends to overestimate ATEs for count outcomes and exhibits increased variability under perturbation, likely due to its lack of dependence modeling.
- The Causal Forest performs well on continuous outcomes but underperforms on survival risk, especially as perturbation increases, possibly due to limitations in modeling high-dimensional interactions over time.
- Baseline Performance (Perturbation Rate = 0.00)
- Impact of Low-Level Perturbation (Rate = 0.05)
- Medium Perturbation Effects (Rate = 0.10)
- High Perturbation Effects (Rate = 0.15)
- Comparative Insights and Robustness Evaluation
- The Empirical Copula CNN-LSTM demonstrates superior robustness, with consistently low SEs and interpretable CATEs across continuous, count, and survival outcomes. Its ability to preserve distributional structure under noise via empirical copula transformations enhances model generalizability.
- The Plain CNN-LSTM is relatively stable in estimating continuous outcomes (Y1), but consistently overestimates count outcomes (Y2) and underestimates survival benefit (Y3), particularly as perturbation increases.
- The Causal Forest yields competitive performance in Y1 and Y2 at moderate perturbation levels, but persistently underestimates treatment effects for survival outcomes (Y3), likely due to the limitations of tree-based partitioning in capturing censored, time-to-event dynamics.
4. Real Data Analysis: COMPAS Dataset
1. Continuous outcome: jail time duration.
2. Count outcome: number of prior offenses.
3. Survival outcome: time to reoffense or synthetic survival data.
- A dense layer with mean squared error (MSE) loss for the continuous outcome.
- A dense layer with Poisson loss for the count outcome.
- A dense layer with a custom Cox partial likelihood loss for the censored survival outcome.
1. A baseline CNN-LSTM model without the copula transformation, trained on the original (non-rank-transformed) covariates.
2. A Causal Forest model from the generalized random forest framework, trained separately for each outcome type. This nonparametric tree-based estimator is capable of capturing HTEs and serves as a strong benchmark (a hedged Python sketch follows this list).
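As one possible way to reproduce the causal forest benchmark in Python, the sketch below uses the econml package's CausalForestDML estimator, a Python implementation of generalized random forests [20]; the original analysis may instead use the R grf package. The data-generating lines and variable names are placeholders, so this should be read as an assumed setup rather than the paper's code.

```python
import numpy as np
from econml.dml import CausalForestDML

# Illustrative fit for a single outcome; the paper fits a separate forest
# for each of the three outcomes.
rng = np.random.default_rng(0)
n, p = 2000, 10
X = rng.normal(size=(n, p))                   # covariates
T = rng.binomial(1, 0.5, size=n)              # binary treatment
Y = X[:, 0] + 1.0 * T + rng.normal(size=n)    # synthetic outcome with a unit treatment effect

cf = CausalForestDML(discrete_treatment=True, n_estimators=500, random_state=0)
cf.fit(Y, T, X=X)

cate = cf.effect(X)                           # individual-level CATE estimates
print("ATE:", cate.mean())                    # average of the CATEs
```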
- Tables comparing ATEs, standard errors, and 95% confidence intervals across models and outcomes.
- CATE summaries showing the means and standard deviations of individual-level treatment effects.
- A point-range plot showing ATE estimates with confidence intervals across all outcomes and models.
- A bar plot displaying CATE mean effects with error bars (±1.96 standard errors) for each model and outcome type.
- Outcome Y1: Continuous (Days in Jail)
- Outcome Y2: Count (Number of Prior Offenses)
- Outcome Y3: Survival Risk
- Model Comparison
- High precision in ATE estimation across all outcomes;
- Consistent effect directionality with tighter confidence intervals;
- Superior handling of multivariate outcome structures.
- Outcome Y1: Continuous (Days in Jail)
- Outcome Y2: Count (Number of Prior Offenses)
- Outcome Y3: Survival Risk
- Model Comparisons
5. Discussion
Funding
Institutional Review Board Statement
Data Availability Statement
Acknowledgments
Conflicts of Interest
References
1. Imbens, G.W.; Rubin, D.B. Causal Inference for Statistics, Social, and Biomedical Sciences; Cambridge University Press: Cambridge, UK, 2015.
2. Hernán, M.A.; Robins, J.M. Causal Inference: What If; Chapman & Hall/CRC: Boca Raton, FL, USA, 2020.
3. Pearl, J. Causality: Models, Reasoning and Inference, 2nd ed.; Cambridge University Press: Cambridge, UK, 2009.
4. Shalit, U.; Johansson, F.D.; Sontag, D. Estimating individual treatment effect: Generalization bounds and algorithms. In Proceedings of the 34th International Conference on Machine Learning, Sydney, Australia, 6–11 August 2017; pp. 1–10.
5. Alaa, A.M.; van der Schaar, M. Deep multitask Gaussian processes for survival analysis with competing risks. In Proceedings of the Advances in Neural Information Processing Systems 30 (NIPS 2017), Long Beach, CA, USA, 4–9 December 2017; Guyon, I., Luxburg, U.V., Bengio, S., Wallach, H., Fergus, R., Vishwanathan, S., Garnett, R., Eds.; Curran Associates, Inc.: Red Hook, NY, USA, 2018; pp. 1–9.
6. Shi, C.; Blei, D.M.; Veitch, V. Adapting neural networks for the estimation of treatment effects. In Proceedings of the 33rd International Conference on Neural Information Processing Systems, Vancouver, BC, Canada, 8–14 December 2019; Article No. 225; pp. 2507–2517.
7. Yoon, J.; Jordon, J.; van der Schaar, M. GANITE: Estimation of Individualized Treatment Effects using Generative Adversarial Nets. In Proceedings of the International Conference on Learning Representations, Vancouver, BC, Canada, 30 April–3 May 2018. Available online: https://openreview.net/forum?id=ByKWUeWA- (accessed on 19 May 2025).
8. Nelsen, R.B. An Introduction to Copulas; Springer Science & Business Media: Berlin/Heidelberg, Germany, 2006.
9. Liu, H.; Han, F.; Yuan, M.; Lafferty, J.; Wasserman, L. The nonparanormal: Semiparametric estimation of high dimensional undirected graphs. J. Mach. Learn. Res. 2009, 10, 2295–2328.
10. Nagler, T. A generic approach to nonparametric function estimation with mixed data. Stat. Probab. Lett. 2018, 137, 326–330.
11. Carroll, R.J.; Ruppert, D.; Stefanski, L.A.; Crainiceanu, C.M. Measurement Error in Nonlinear Models: A Modern Perspective; Chapman & Hall/CRC: Boca Raton, FL, USA, 2006.
12. Penning de Vries, B.B.L.; van Smeden, M.; Groenwold, R.H.H. A weighting method for simultaneous adjustment for confounding and joint exposure-outcome misclassifications. Stat. Methods Med. Res. 2020, 30, 473–487.
13. Anoke, S.C.; Normand, S.-L.; Zigler, C.M. Approaches to treatment effect heterogeneity in the presence of confounding. Stat. Med. 2019, 38, 2797–2815.
14. Rosenbaum, P.R. Observational Studies; Springer: Berlin/Heidelberg, Germany, 2002.
15. Cinelli, C.; Hazlett, C. Making Sense of Sensitivity: Extending Omitted Variable Bias. J. R. Stat. Soc. Ser. B Stat. Methodol. 2020, 82, 39–67.
16. Kim, J.-M. Integrating Copula-Based Random Forest and Deep Learning Approaches for Analyzing Heterogeneous Treatment Effects in Survival Analysis. Mathematics 2025, 13, 1659.
17. Kim, J.-M. Treatment effect estimation in survival analysis using deep learning-based causal inference. Axioms 2025, 14, 458.
18. Shi, X.; Chen, Z.; Wang, H.; Yeung, D.-Y.; Wong, W.-K.; Woo, W.-C. Convolutional LSTM network: A machine learning approach for precipitation nowcasting. In Proceedings of the 29th International Conference on Neural Information Processing Systems, Montreal, QC, Canada, 7–12 December 2015; MIT Press: Cambridge, MA, USA, 2015; Volume 1, pp. 802–810.
19. Bai, S.; Kolter, J.Z.; Koltun, V. An Empirical Evaluation of Generic Convolutional and Recurrent Networks for Sequence Modeling. arXiv 2018, arXiv:1803.01271.
20. Athey, S.; Tibshirani, J.; Wager, S. Generalized random forests. Ann. Stat. 2019, 47, 1148–1178.
21. Cox, D.R. Regression Models and Life-Tables. J. R. Stat. Soc. Ser. B (Methodol.) 1972, 34, 187–202.
22. Wager, S.; Athey, S. Estimation and Inference of Heterogeneous Treatment Effects using Random Forests. J. Am. Stat. Assoc. 2018, 113, 1228–1242.
23. Angwin, J.; Larson, J.; Mattu, S.; Kirchner, L. Machine Bias: There’s Software Used Across the Country to Predict Future Criminals. And It’s Biased Against Blacks. ProPublica, 2016. Available online: https://www.propublica.org/article/machine-bias-risk-assessments-in-criminal-sentencing (accessed on 19 May 2025).
Outcome | Model | ATE | SE | CI Lower | CI Upper | Perturbation Rate |
---|---|---|---|---|---|---|
Y1 | Empirical Copula CNN-LSTM | 1.3179 | 0.0055 | 1.3071 | 1.3288 | 0.00 |
Y2 | Empirical Copula CNN-LSTM | 0.0648 | 0.0039 | 0.0571 | 0.0725 | 0.00 |
Y3 | Empirical Copula CNN-LSTM | −0.7792 | 0.0036 | −0.7863 | −0.7722 | 0.00 |
Y1 | Plain CNN-LSTM | 1.2354 | 0.0116 | 1.2126 | 1.2582 | 0.00 |
Y2 | Plain CNN-LSTM | 0.2206 | 0.0082 | 0.2046 | 0.2366 | 0.00 |
Y3 | Plain CNN-LSTM | −0.7666 | 0.0096 | −0.7856 | −0.7477 | 0.00 |
Y1 | Causal Forest | 1.2849 | 0.0038 | 1.2775 | 1.2923 | 0.00 |
Y2 | Causal Forest | 0.2057 | 0.0068 | 0.1924 | 0.2189 | 0.00 |
Y3 | Causal Forest | −0.1255 | 0.0029 | −0.1312 | −0.1198 | 0.00 |
Y1 | Empirical Copula CNN-LSTM | 1.4535 | 0.0063 | 1.4412 | 1.4658 | 0.05 |
Y2 | Empirical Copula CNN-LSTM | 0.1152 | 0.0016 | 0.1120 | 0.1184 | 0.05 |
Y3 | Empirical Copula CNN-LSTM | −0.3648 | 0.0024 | −0.3696 | −0.3600 | 0.05 |
Y1 | Plain CNN-LSTM | 1.3183 | 0.0112 | 1.2964 | 1.3402 | 0.05 |
Y2 | Plain CNN-LSTM | 0.1348 | 0.0032 | 0.1285 | 0.1411 | 0.05 |
Y3 | Plain CNN-LSTM | −0.3521 | 0.0044 | −0.3606 | −0.3435 | 0.05 |
Y1 | Causal Forest | 1.4624 | 0.0035 | 1.4555 | 1.4693 | 0.05 |
Y2 | Causal Forest | 0.2395 | 0.0030 | 0.2336 | 0.2454 | 0.05 |
Y3 | Causal Forest | −0.0855 | 0.0021 | −0.0895 | −0.0814 | 0.05 |
Y1 | Empirical Copula CNN-LSTM | 1.4563 | 0.0052 | 1.4462 | 1.4664 | 0.10 |
Y2 | Empirical Copula CNN-LSTM | 0.3549 | 0.0058 | 0.3436 | 0.3663 | 0.10 |
Y3 | Empirical Copula CNN-LSTM | −0.4057 | 0.0031 | −0.4117 | −0.3997 | 0.10 |
Y1 | Plain CNN-LSTM | 1.3751 | 0.0111 | 1.3532 | 1.3969 | 0.10 |
Y2 | Plain CNN-LSTM | 0.3978 | 0.0081 | 0.3819 | 0.4137 | 0.10 |
Y3 | Plain CNN-LSTM | −0.3271 | 0.0045 | −0.3360 | −0.3182 | 0.10 |
Y1 | Causal Forest | 1.4170 | 0.0078 | 1.4016 | 1.4323 | 0.10 |
Y2 | Causal Forest | 0.3706 | 0.0050 | 0.3607 | 0.3805 | 0.10 |
Y3 | Causal Forest | −0.0683 | 0.0017 | −0.0717 | −0.0648 | 0.10 |
Y1 | Empirical Copula CNN-LSTM | 1.3851 | 0.0034 | 1.3784 | 1.3918 | 0.15 |
Y2 | Empirical Copula CNN-LSTM | 0.1397 | 0.0027 | 0.1344 | 0.1450 | 0.15 |
Y3 | Empirical Copula CNN-LSTM | −0.3880 | 0.0027 | −0.3932 | −0.3827 | 0.15 |
Y1 | Plain CNN-LSTM | 1.4149 | 0.0097 | 1.3958 | 1.4340 | 0.15 |
Y2 | Plain CNN-LSTM | 0.2261 | 0.0041 | 0.2181 | 0.2341 | 0.15 |
Y3 | Plain CNN-LSTM | −0.3679 | 0.0044 | −0.3765 | −0.3592 | 0.15 |
Y1 | Causal Forest | 1.4153 | 0.0054 | 1.4048 | 1.4258 | 0.15 |
Y2 | Causal Forest | 0.2319 | 0.0071 | 0.2180 | 0.2457 | 0.15 |
Y3 | Causal Forest | −0.0458 | 0.0023 | −0.0503 | −0.0414 | 0.15 |
Outcome | Model | Mean CATE | SE of CATE | Perturbation Rate |
---|---|---|---|---|
Y1 | Empirical Copula CNN-LSTM | 1.3179 | 0.0053 | 0.00 |
Y2 | Empirical Copula CNN-LSTM | 0.0648 | 0.0040 | 0.00 |
Y3 | Empirical Copula CNN-LSTM | −0.7792 | 0.0037 | 0.00 |
Y1 | Plain CNN-LSTM | 1.2354 | 0.0118 | 0.00 |
Y2 | Plain CNN-LSTM | 0.2206 | 0.0083 | 0.00 |
Y3 | Plain CNN-LSTM | −0.7666 | 0.0098 | 0.00 |
Y1 | Causal Forest | 1.2849 | 0.0038 | 0.00 |
Y2 | Causal Forest | 0.2057 | 0.0073 | 0.00 |
Y3 | Causal Forest | −0.1255 | 0.0030 | 0.00 |
Y1 | Empirical Copula CNN-LSTM | 1.4535 | 0.0055 | 0.05 |
Y2 | Empirical Copula CNN-LSTM | 0.1152 | 0.0017 | 0.05 |
Y3 | Empirical Copula CNN-LSTM | −0.3648 | 0.0024 | 0.05 |
Y1 | Plain CNN-LSTM | 1.3183 | 0.0120 | 0.05 |
Y2 | Plain CNN-LSTM | 0.1348 | 0.0034 | 0.05 |
Y3 | Plain CNN-LSTM | −0.3521 | 0.0043 | 0.05 |
Y1 | Causal Forest | 1.4624 | 0.0037 | 0.05 |
Y2 | Causal Forest | 0.2395 | 0.0031 | 0.05 |
Y3 | Causal Forest | −0.0855 | 0.0021 | 0.05 |
Y1 | Empirical Copula CNN-LSTM | 1.4563 | 0.0054 | 0.10 |
Y2 | Empirical Copula CNN-LSTM | 0.3549 | 0.0061 | 0.10 |
Y3 | Empirical Copula CNN-LSTM | −0.4057 | 0.0033 | 0.10 |
Y1 | Plain CNN-LSTM | 1.3751 | 0.0107 | 0.10 |
Y2 | Plain CNN-LSTM | 0.3978 | 0.0081 | 0.10 |
Y3 | Plain CNN-LSTM | −0.3271 | 0.0043 | 0.10 |
Y1 | Causal Forest | 1.4170 | 0.0076 | 0.10 |
Y2 | Causal Forest | 0.3706 | 0.0051 | 0.10 |
Y3 | Causal Forest | −0.0683 | 0.0020 | 0.10 |
Y1 | Empirical Copula CNN-LSTM | 1.3851 | 0.0036 | 0.15 |
Y2 | Empirical Copula CNN-LSTM | 0.1397 | 0.0027 | 0.15 |
Y3 | Empirical Copula CNN-LSTM | −0.3880 | 0.0026 | 0.15 |
Y1 | Plain CNN-LSTM | 1.4149 | 0.0097 | 0.15 |
Y2 | Plain CNN-LSTM | 0.2261 | 0.0043 | 0.15 |
Y3 | Plain CNN-LSTM | −0.3679 | 0.0050 | 0.15 |
Y1 | Causal Forest | 1.4153 | 0.0058 | 0.15 |
Y2 | Causal Forest | 0.2319 | 0.0065 | 0.15 |
Y3 | Causal Forest | −0.0458 | 0.0025 | 0.15 |
Outcome | Model | ATE | SE | CI Lower | CI Upper |
---|---|---|---|---|---|
Y1: Continuous | Empirical Copula CNN-LSTM | −0.4670 | 0.0267 | −0.5193 | −0.4147 |
Y2: Count | Empirical Copula CNN-LSTM | 0.0826 | 0.0018 | 0.0790 | 0.0862 |
Y3: Survival Risk | Empirical Copula CNN-LSTM | 0.0018 | 0.0003 | 0.0012 | 0.0023 |
Y1: Continuous | CNN-LSTM Baseline | −0.4674 | 0.0254 | −0.5171 | −0.4176 |
Y2: Count | CNN-LSTM Baseline | 0.0853 | 0.0027 | 0.0799 | 0.0907 |
Y3: Survival Risk | CNN-LSTM Baseline | 0.0050 | 0.0019 | 0.0011 | 0.0088 |
Y1: Continuous | Causal Forest | −0.2525 | 0.0305 | −0.3123 | −0.1927 |
Y2: Count | Causal Forest | −0.0026 | 0.0003 | −0.0032 | −0.0020 |
Y3: Survival Risk | Causal Forest | −0.3003 | 0.0157 | −0.3311 | −0.2694 |
Outcome | Model | Mean CATE | SE CATE |
---|---|---|---|
Y1: Continuous | Empirical Copula CNN-LSTM | −0.4670 | 1.6014 |
Y2: Count | Empirical Copula CNN-LSTM | 0.0826 | 0.1033 |
Y3: Survival Risk | Empirical Copula CNN-LSTM | 0.0018 | 0.0164 |
Y1: Continuous | CNN-LSTM Baseline | −0.4674 | 1.6039 |
Y2: Count | CNN-LSTM Baseline | 0.0853 | 0.1451 |
Y3: Survival Risk | CNN-LSTM Baseline | 0.0050 | 0.1012 |
Y1: Continuous | Causal Forest | −0.2525 | 1.7610 |
Y2: Count | Causal Forest | −0.0026 | 0.0173 |
Y3: Survival Risk | Causal Forest | −0.3003 | 0.9271 |