Scalable SHAP-Informed Neural Network
Abstract
1. Introduction
1.1. Background
- Scaling the global learning rate based on aggregated feature importance;
- Adjusting individual parameter gradients in the first layer according to the importance of corresponding input features.
1.2. Related Work
2. Materials and Methods
2.1. Theory and Explanation
2.1.1. SHAP Values
2.1.2. SHAP-Informed Learning Rate Adjustment
- Computing SHAP values at fixed intervals.
- Aggregating and normalizing feature importance:
- Adjusting the learning rate by applying:
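As a concrete illustration of these steps (the exact expressions are given in our previous work [8], so the form below is a hedged sketch rather than the published update), let φ_j(x) denote the SHAP value of feature j for sample x, l the number of sampled instances, d the number of features, η₀ the base learning rate, and γ the tunable scaling factor:

```latex
% Illustrative sketch only; the exact update rule follows reference [8].
\bar{S}_t = \frac{1}{l\,d}\sum_{m=1}^{l}\sum_{j=1}^{d}\bigl|\phi_j\bigl(x^{(m)}\bigr)\bigr|,
\qquad
\eta_t = \eta_0\bigl(1 + \gamma\,\tilde{S}_t\bigr),
```

where \tilde{S}_t is \bar{S}_t normalized to the interval [0, 1].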
2.1.3. SHAP-Informed Gradient Adjustment
- Compute normalized SHAP importance as above.
- For gradient of feature i at time t, modify as follows:
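A hedged sketch of this per-feature update (again, the published form is given in [8]): let \tilde{S}_i be the normalized SHAP importance of feature i and g_i^{(t)} the gradient of the first-layer weights attached to input feature i at step t; then

```latex
% Illustrative sketch only; the exact scaling rule follows reference [8].
g_i^{(t)} \leftarrow g_i^{(t)}\bigl(1 + \gamma\,\tilde{S}_i\bigr), \qquad i = 1,\dots,d,
```

so that more important features receive proportionally larger weight updates.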
2.1.4. Scalable SHAP-Informed Adjustments
- C-SHAP: Learning Rate Scaling via Cluster-Based SHAP Approximation
- FastSHAP: Surrogate-Based SHAP Approximation for Gradient Scaling
- $x \in \mathbb{R}^n$ is an input sample with $n$ features.
- $\phi_i(x)$ is the exact SHAP value for feature $i$ on instance $x$.
- $\hat{\phi}_i(x;\theta)$ is the surrogate model’s prediction for that SHAP value.
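Using these definitions, a simplified view of the surrogate objective is a squared-error regression onto SHAP values. The form below is shown for intuition only; the actual FastSHAP formulation of Jethani et al. [11] avoids exact SHAP labels and instead optimizes a weighted least-squares characterization of Shapley values:

```latex
% Simplified illustration only; FastSHAP itself uses a weighted least-squares
% Shapley objective rather than regressing on exact SHAP values [11].
\mathcal{L}(\theta) = \mathbb{E}_{x}\!\left[\sum_{i=1}^{n}\bigl(\phi_i(x) - \hat{\phi}_i(x;\theta)\bigr)^{2}\right]
```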
2.2. Methodology
2.2.1. Workflow Overview
2.2.2. Datasets and Preprocessing
- The Adult Census Dataset contains 48,842 records with 14 features capturing demographic and employment-related information such as age, education level, occupation, and hours worked per week. The binary target variable indicates whether an individual earns more than USD 50 K annually. This is a classification task.
- The Ames Housing Dataset comprises 2930 instances and 80 engineered features representing residential property characteristics in Ames, Iowa. These features include physical measurements, neighborhood information, and quality ratings. The target variable is the sale price of each property, making this a regression task.
- The Breast Cancer Dataset consists of 569 samples, each with 30 numerical features describing properties of cell nuclei obtained from breast mass imagery. The task is binary classification, predicting whether the tumor is malignant or benign.
- The California Housing Dataset includes 20,640 instances from the 1990 California census. It contains 8 numerical features related to socioeconomic indicators such as median income, average occupancy, and housing age. The target variable is the median house value (in USD 100,000s) per census block group, treated as a regression task.
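For convenience, the snippet below sketches one way to obtain three of the four datasets from common public copies (scikit-learn built-ins and OpenML). These loaders are assumptions about data sources rather than necessarily the copies used in our experiments, and the Ames Housing data (2930 instances, 80 features) may need to be obtained separately.

```python
# Minimal loading sketch under assumed data sources (scikit-learn built-ins and OpenML).
from sklearn.datasets import fetch_california_housing, fetch_openml, load_breast_cancer

adult = fetch_openml("adult", version=2, as_frame=True)   # 48,842 records, binary income target
cancer = load_breast_cancer(as_frame=True)                # 569 samples, 30 numerical features
california = fetch_california_housing(as_frame=True)      # 20,640 samples, 8 numerical features
```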
2.2.3. Model Architecture
- Input layer with dimensionality equal to the number of numerical features in the dataset.
- First hidden layer with 128 units and ReLU activation.
- Second hidden layer with 64 units and ReLU activation.
- Final output layer with a single neuron:
- For regression tasks (Ames Housing, California Housing), this layer outputs a continuous value without activation.
- For classification tasks (Adult Census, Breast Cancer), a sigmoid activation is applied post hoc to convert logits to probabilities for binary prediction.
- Dropout is applied after each hidden layer, with the rate selected as a hyperparameter in the grid search.
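The list above maps directly onto a small feed-forward network. The following is a minimal PyTorch sketch under those assumptions (the dropout rate shown is a placeholder for the grid-searched value, and the post hoc sigmoid for classification is applied outside the module, as described):

```python
import torch.nn as nn

class FeedForwardNet(nn.Module):
    """Two hidden layers (128 -> 64, ReLU) with dropout, single linear output.

    For regression the raw output is used directly; for binary classification
    a sigmoid is applied to the logit outside the module, as described above.
    """

    def __init__(self, n_features: int, dropout: float = 0.3):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_features, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(64, 1),
        )

    def forward(self, x):
        return self.net(x)
```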
2.2.4. Optimization Methods
- C-SHAP, which mirrors the logic of SHAP but reduces overhead by computing SHAP values at cluster centroids only.
- FastSHAP, which mirrors SHAPG by learning a surrogate model to approximate SHAP values and applying them during gradient updates.
- Adam only: The baseline condition uses Adam with a fixed learning rate (tuned per dataset) and no SHAP integration. This configuration serves as a control to assess the incremental benefits of SHAP-informed guidance.
- SHAP-based learning rate scaling: This approach modifies the global learning rate based on the average SHAP values computed from the training samples. SHAP values are normalized and scaled by a tunable factor to dynamically adjust learning rate magnitude, promoting learning that aligns with feature importance. These SHAP values are recomputed at fixed intervals (e.g., every 10 or 20 epochs) to reflect updated feature attributions as training progresses, as proposed in our previous work and supported by broader optimization principles, where timely adaptation of the learning rate has been shown to improve convergence [8].
- SHAPG: gradient modification using SHAP: SHAPG extends the SHAP-based learning framework by directly modifying gradients during backpropagation. Specifically, the gradient of the input layer weights is scaled feature-wise in proportion to normalized SHAP values. This enforces a directional training signal that amplifies updates for more important features and suppresses less relevant ones, as described in our prior work [8]. While effective at directing gradient updates based on feature importance, SHAPG introduces higher computational cost than the other methods: exact SHAP values must be recomputed for a subset of training samples at fixed intervals, and the resulting attributions are then applied as per-feature scaling operations during backpropagation, both of which add runtime.
- C-SHAP: cluster-based SHAP-informed learning rate scaling: C-SHAP reduces the computational burden of SHAP-based learning rate scaling by leveraging clustering to approximate feature importance. Instead of computing SHAP values across a large batch of training samples, C-SHAP applies MiniBatch K-Means clustering to the training data and computes SHAP values only at the resulting centroids. These centroid-level attributions serve as representative approximations for the full dataset. The mean absolute SHAP values across centroids are normalized and used to adjust the global learning rate. This strategy retains the benefits of SHAP-based optimization while significantly improving efficiency, making it suitable for larger datasets. A similar clustering-based SHAP approximation approach was explored by Ranjbaran et al. [10], who demonstrated that SHAP values computed at cluster centroids preserved feature importance rankings while greatly reducing computation time. A minimal code sketch of this centroid-based learning rate update is given after this list.
- FastSHAP: surrogate-guided gradient scaling: FastSHAP employs a learned surrogate model to approximate SHAP values efficiently. A lightweight neural explainer is trained on the main model’s outputs using sampled data. The surrogate produces feature attribution vectors that are used to scale the gradients of the input layer in a manner similar to SHAPG. This method is computationally more efficient and enables more frequent attribution updates without high overhead, adapting the surrogate-based estimation strategy from the work of Jethani et al. [11].
- Hybrid: combined C-SHAP and FastSHAP: The hybrid method combines the global learning rate scaling of C-SHAP with the gradient modulation of FastSHAP. Centroid-based SHAP values are used to adjust the learning rate schedule, while the surrogate-based FastSHAP mechanism simultaneously informs gradient scaling. This fusion seeks to capture both coarse-grained (global) and fine-grained (directional) SHAP signals within the same training loop.
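As referenced in the C-SHAP item above, the sketch below illustrates the centroid-based learning rate update. It is a minimal sketch rather than the exact implementation: the shap package's KernelExplainer, the generic predict_fn callable, the cluster count, and the max-normalization of the aggregated importance are illustrative assumptions.

```python
# Minimal sketch of the C-SHAP learning-rate update (illustrative only; the exact
# clustering, explainer, and normalization choices follow Section 2.2.4 and [10]).
import numpy as np
import shap
from sklearn.cluster import MiniBatchKMeans

def cshap_learning_rate(predict_fn, X_train, base_lr, gamma, n_clusters=10):
    """Scale the global learning rate using SHAP values computed at cluster centroids.

    predict_fn : callable mapping an (m, d) array to a single model output per sample.
    gamma      : tunable LR scaling factor (0.5 / 1.0 / 2.0 in our experiments).
    """
    # 1. Summarize the training set with a small number of representative centroids.
    centroids = MiniBatchKMeans(n_clusters=n_clusters, random_state=0).fit(X_train).cluster_centers_

    # 2. Compute SHAP values only at the centroids (KernelExplainer assumed here).
    explainer = shap.KernelExplainer(predict_fn, centroids)
    phi = np.abs(np.asarray(explainer.shap_values(centroids)))   # shape (n_clusters, d)

    # 3. Aggregate centroid-level importance, normalize to (0, 1], and scale the LR.
    importance = phi.mean() / (phi.max() + 1e-12)
    return base_lr * (1.0 + gamma * importance)
```

Within a PyTorch training loop, the returned value would replace the optimizer's learning rate at each SHAP update interval (for example, by assigning optimizer.param_groups[0]["lr"]).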
2.2.5. Experimental Setup
- For C-SHAP, we apply MiniBatch K-Means clustering on the training set to generate a small number of representative centroids. The SHAP values are then computed only at these centroids to reduce computational cost while maintaining attribution fidelity.
- For FastSHAP, a lightweight surrogate model is trained during each update interval to approximate SHAP values efficiently. This surrogate produces importance vectors used for scaling input-layer gradients.
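To make the gradient-scaling step concrete, the sketch below shows how a normalized attribution vector (whether produced by the FastSHAP surrogate or by exact SHAP values in SHAPG) could rescale the first-layer weight gradients between the backward pass and the optimizer step. It assumes the FeedForwardNet sketch given earlier, whose first Linear layer is model.net[0], and is illustrative rather than the exact implementation.

```python
import torch

def scale_input_layer_grads(model, importance, gamma=1.0):
    """Rescale first-layer weight gradients feature-wise by normalized attributions.

    importance : 1-D tensor of length d with non-negative, normalized attributions
                 (e.g., from a FastSHAP-style surrogate or exact SHAP values).
    Call between loss.backward() and optimizer.step().
    """
    first_layer = model.net[0]   # assumed nn.Linear(d, 128) from the earlier sketch
    with torch.no_grad():
        scale = 1.0 + gamma * importance.to(first_layer.weight.grad.device)
        # weight.grad has shape (128, d); broadcast the per-feature scale across rows.
        first_layer.weight.grad.mul_(scale.unsqueeze(0))
```

The hyperparameter grid searched for all methods is listed next.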
- Learning rate: 0.01, 0.001, 0.0001.
- Dropout rate: 0.1, 0.3, 0.7.
- Epochs: 25, 50, 100.
- LR scaling factor: 0.5, 1.0, 2.0 (for SHAP-informed methods only).
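The grid above corresponds to an exhaustive search over all combinations; a minimal sketch follows, where run_trial is a hypothetical placeholder for training and evaluating one configuration.

```python
from itertools import product

def run_trial(**config):
    """Hypothetical placeholder: train and evaluate one configuration."""
    print("would run:", config)

learning_rates = [0.01, 0.001, 0.0001]
dropout_rates = [0.1, 0.3, 0.7]
epoch_counts = [25, 50, 100]
lr_scaling_factors = [0.5, 1.0, 2.0]   # used by the SHAP-informed methods only

for lr, dropout, epochs, gamma in product(
        learning_rates, dropout_rates, epoch_counts, lr_scaling_factors):
    run_trial(learning_rate=lr, dropout=dropout, epochs=epochs, lr_scale=gamma)
```

The evaluation metrics recorded for each trial are listed next.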
- Test loss: mean prediction error on the test set (MSE for regression, BCE for classification).
- Root Mean Square Error (RMSE): used for regression datasets.
- Accuracy: used for binary classification datasets (Adult Census and Breast Cancer).
- R² score: indicates variance explained in regression tasks.
- Training time: total runtime for each trial, including all epochs and SHAP-related computations.
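These metrics map directly onto standard scikit-learn calls; the helper functions below are a minimal sketch (the helper names are our own).

```python
import numpy as np
from sklearn.metrics import accuracy_score, mean_squared_error, r2_score

def regression_metrics(y_true, y_pred):
    """RMSE and R^2 for the regression datasets (Ames Housing, California Housing)."""
    rmse = float(np.sqrt(mean_squared_error(y_true, y_pred)))
    return rmse, r2_score(y_true, y_pred)

def classification_accuracy(y_true, y_prob, threshold=0.5):
    """Accuracy for the binary classification datasets (Adult Census, Breast Cancer)."""
    y_pred = (np.asarray(y_prob) >= threshold).astype(int)
    return accuracy_score(y_true, y_pred)
```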
2.2.6. Computational Complexity
- E: number of epochs.
- N: number of training samples.
- d: number of input features.
- k: number of SHAP update intervals.
- l: number of samples used in each SHAP update.
- B: number of samples used per FastSHAP surrogate inference.
- E′: number of epochs dedicated to training the FastSHAP surrogate model.
3. Results
3.1. Breast Cancer Dataset
- Test loss: The hybrid method achieved the lowest test loss (0.0028 ± 0.0015), outperforming all other methods including the baseline Adam optimizer (0.0033 ± 0.0018). The original SHAP method also showed competitive performance (0.0052 ± 0.0045), while the CSHAP and FastSHAP methods recorded higher variability in loss.
- Accuracy: Multiple SHAP-informed methods surpassed the baseline in accuracy. SHAPG and the hybrid method both achieved 97.56% accuracy, notably higher than Adam’s 96.86%, with SHAP and FastSHAP also exceeding 97%. CSHAP performed slightly lower at 97.09%, suggesting that the gradient-informed strategies delivered the most consistent accuracy gains.
- Training time: The FastSHAP and hybrid methods showed a favorable trade-off between accuracy and training time. FastSHAP completed training in 1.31 ± 0.893 s on average, offering a significant runtime reduction compared to SHAPG (35.65 ± 13.37 s), which had the highest training time due to full SHAP recomputation per update. The hybrid method also achieved a competitive runtime of 3.04 ± 1.13 s while maintaining strong accuracy.
3.2. Adult Census Dataset
- Test loss: The SHAPG method achieved the lowest test loss (0.01153 ± 0.00014), slightly outperforming all other SHAP-based methods as well as the Adam-only baseline (0.01160 ± 0.00008). The hybrid method also performed competitively (0.01159 ± 0.00022), while CSHAP and SHAP produced loss levels similar to the baseline.
- Accuracy: All SHAP-informed methods achieved accuracy comparable to the Adam baseline (83.02%). SHAPG reached the highest accuracy (83.11%), while the hybrid (82.92%) and FastSHAP (82.92%) methods trailed the baseline only marginally. The differences were small, suggesting that test loss was a more sensitive differentiator on this task.
- Training time: FastSHAP demonstrated the best runtime performance among SHAP-informed methods, with 129.81 ± 35.30 s—faster than even the Adam baseline (147.44 ± 22.35 s). In contrast, SHAPG required the most time (287.91 ± 151.11 s), due to the cost of computing exact SHAP values and applying gradient adjustments at each update interval. Although FastSHAP also performs updates each interval, it leverages a learned surrogate model to approximate SHAP values efficiently, resulting in significantly faster training. The hybrid and SHAP methods required moderate training time, with CSHAP achieving relatively efficient performance at 135.25 ± 27.95 s.
3.3. Ames Housing Dataset
- Test loss: SHAPG achieved the lowest test loss (0.00399 ± 0.00143), outperforming all other methods, including FastSHAP (0.00477 ± 0.00204) and the SHAP method (0.00482 ± 0.00408). Both the CSHAP and hybrid methods also notably improved upon the Adam-only baseline (0.00761 ± 0.00450), highlighting the value of SHAP-based guidance in regression tasks.
- RMSE: SHAPG delivered the best RMSE at 0.0621, with SHAP (0.0651) and FastSHAP (0.0675) close behind, all lower than Adam’s 0.0830, indicating better predictive calibration across samples. The CSHAP and hybrid methods produced slightly higher RMSE values, suggesting more variability in model predictions.
- R² Score: SHAPG achieved the highest R² score (0.7799), reflecting the model’s improved ability to explain variance in the target variable. FastSHAP and CSHAP also yielded strong scores (0.7644 and 0.7470), while the hybrid method showed the weakest performance (0.6210), suggesting that combining global and local SHAP signals may have introduced instability in this regression setting.
- Training Time: FastSHAP again delivered excellent efficiency, completing training in just 2.61 ± 0.89 s. SHAPG (57.82 ± 21.67 s) and SHAP (43.17 ± 9.32 s) required significantly more time due to SHAP value computations at regular intervals. CSHAP was also highly efficient (2.03 ± 0.98 s), though its gains on the other metrics were less consistent, while the hybrid method was slower than FastSHAP but faster than SHAP and SHAPG.
3.4. California Housing Dataset
- Test loss: The hybrid method recorded the lowest test loss (0.00802 ± 0.00029), slightly outperforming all other methods, including CSHAP (0.00803 ± 0.00035) and SHAPG (0.00808 ± 0.00044). The FastSHAP method also performed competitively, while the Adam-only baseline had the highest test loss overall.
- RMSE: All SHAP-based methods improved upon the Adam baseline (0.0905), with the hybrid and CSHAP methods achieving the lowest RMSEs at 0.0895 and 0.0896, respectively. These results suggest that incorporating gradient scaling and cluster-based SHAP approximations can enhance stability and consistency in model predictions for high-dimensional tabular regression.
- R² score: SHAPG yielded the highest R² score (0.8089), indicating the best variance explanation among all models. The hybrid (0.8086) and CSHAP (0.8085) methods followed closely behind. These results highlight the consistent ability of SHAP-informed methods to improve model expressiveness on large tabular datasets.
- Training time: CSHAP was the most efficient SHAP-based method, completing training in 54.79 ± 16.67 s. FastSHAP followed with a slightly longer runtime (63.93 ± 11.14 s), while SHAPG and SHAP both required over two minutes. The hybrid method offered a compromise between speed and performance, completing training in 68.67 ± 15.82 s.
4. Discussion
Author Contributions
Funding
Data Availability Statement
Conflicts of Interest
References
1. Jarrahi, M.H.; Memariani, A.; Guha, S. The Principles of Data-Centric AI (DCAI). Commun. ACM 2023, 66, 84–92.
2. Wilson, A.C.; Roelofs, R.; Stern, M.; Srebro, N.; Recht, B. The Marginal Value of Adaptive Gradient Methods in Machine Learning. arXiv 2017, arXiv:1705.08292.
3. Zou, D.; Cao, Y.; Li, Y.; Gu, Q. Understanding the Generalization of Adam in Learning Neural Networks with Proper Regularization. In Proceedings of the International Conference on Learning Representations (ICLR), Kigali, Rwanda, 1–5 May 2023.
4. Duchi, J.; Hazan, E.; Singer, Y. Adaptive subgradient methods for online learning and stochastic optimization. J. Mach. Learn. Res. 2011, 12, 2121–2159.
5. Ross, A.S.; Hughes, M.C.; Doshi-Velez, F. Right for the Right Reasons: Training Differentiable Models by Constraining Their Explanations. In Proceedings of the 26th International Joint Conference on Artificial Intelligence, Melbourne, Australia, 19–25 August 2017; pp. 2662–2668.
6. Lundberg, S.M.; Lee, S.-I. A Unified Approach to Interpreting Model Predictions. Adv. Neural Inf. Process. Syst. (NeurIPS) 2017, 30, 4765–4774.
7. Adadi, A.; Berrada, M. Peeking Inside the Black-Box: A Survey on Explainable Artificial Intelligence (XAI). IEEE Access 2018, 6, 52138–52160.
8. Graham, J.; Sheng, V. SHAP-Informed Neural Network Optimization. Mathematics 2024, 12, 456.
9. Hamilton, R.I.; Papadopoulos, P.N. Using SHAP Values and Machine Learning to Understand Trends in the Transient Stability Limit. IEEE Trans. Power Syst. 2024, 39, 1384–1397.
10. Ranjbaran, G.; Recupero, D.R.; Roy, C.K.; Schneider, K.A. C-SHAP: A Hybrid Method for Fast and Efficient Interpretability. Appl. Sci. 2025, 15, 672.
11. Jethani, N.; Covert, I.; Lee, S.-I.; Lundberg, S.M. FastSHAP: Real-Time Shapley Value Estimation. In Proceedings of the 10th International Conference on Learning Representations (ICLR 2022), Virtual, 25–29 April 2022.
12. Kingma, D.P.; Ba, J. Adam: A Method for Stochastic Optimization. In Proceedings of the International Conference on Learning Representations (ICLR), San Diego, CA, USA, 7–9 May 2015.
13. Bottou, L. Large-Scale Machine Learning with Stochastic Gradient Descent. In Proceedings of the COMPSTAT’2010: 19th International Conference on Computational Statistics, Paris, France, 22–27 August 2010; pp. 177–186.
14. Kabiri, H.; Ghanou, Y.; Khalifi, H.; Casalino, G. AMAdam: Adaptive modifier of Adam method. Knowl. Inf. Syst. 2024, 66, 3427–3458.
15. Huang, H.; Wang, C.; Dong, B. Nostalgic Adam: Weighting more of the past gradients when designing the adaptive learning rate. In Proceedings of the 28th International Joint Conference on Artificial Intelligence (IJCAI) 2019, Macao, China, 10–16 August 2019; pp. 2556–2562.
16. Luo, L.; Xiong, Y.; Liu, Y.; Sun, X. Adaptive Gradient Methods with Dynamic Bound of Learning Rate. In Proceedings of the International Conference on Learning Representations (ICLR 2019), New Orleans, LA, USA, 6–9 May 2019.
17. Ruder, S. An Overview of Gradient Descent Optimization Algorithms. arXiv 2016, arXiv:1609.04747.
18. Wang, X.; Magnússon, S.; Johansson, M. On the Convergence of Step Decay Step-Size for Stochastic Optimization. Adv. Neural Inf. Process. Syst. 2021, 34, 14226–14238.
19. Figueroa Barraza, J.; López Droguett, E.; Martins, M.R. Towards Interpretable Deep Learning: A Feature Selection Framework for Prognostics and Health Management Using Deep Neural Networks. Sensors 2021, 21, 5888.
20. Le, T.T.H.; Kim, H.; Kang, H.; Kim, H. Classification and explanation for intrusion detection system based on ensemble trees and SHAP method. Sensors 2022, 22, 1154.
21. Marcílio, W.E.; Eler, D.M. From explanations to feature selection: Assessing SHAP values as feature selection mechanism. In Proceedings of the 2020 33rd SIBGRAPI Conference on Graphics, Patterns and Images (SIBGRAPI), Porto de Galinhas, Brazil, 7–10 November 2020; pp. 340–347.
22. Potharlanka, J.L.; Bhat, N.M. Feature importance feedback with Deep Q process in ensemble-based metaheuristic feature selection algorithms. Sci. Rep. 2024, 14, 2923.
23. Albini, E.; Long, J.; Dervovic, D.; Magazzeni, D. Counterfactual Shapley Additive Explanations. In Proceedings of the 2022 ACM Conference on Fairness, Accountability, and Transparency (FAccT ’22), Seoul, Republic of Korea, 21–24 June 2022; pp. 1054–1070.
24. Alkhatib, A.; Boström, H. Fast Approximation of Shapley Values with Limited Data. In Proceedings of the 14th Scandinavian Conference on Artificial Intelligence SCAI 2024, Jönköping, Sweden, 10–11 June 2024; pp. 95–100.
Computational complexity of each optimization method (symbols as defined in Section 2.2.6).

Method | Training Complexity | SHAP Update Overhead | Test Complexity |
---|---|---|---|
Adam only | O(E × N × d) | None | O(d) |
SHAP | O(E × N × d) | O(k × l × d + k × d²) | O(d) |
SHAPG | O(E × N × d) | O(k × l × d + k × d²) | O(d) |
CSHAP | O(E × N × d) | O(k × l × d + k × d²) | O(d) |
FastSHAP | O(E × N × d) + O(E′ × B × d) | O(B × d) | O(d) |
Performance comparison on the Breast Cancer Dataset (mean ± standard deviation).

Method | Test Loss | Accuracy | Training Time (s) |
---|---|---|---|
Adam only | 0.0033 ± 0.0018 | 0.9686 ± 0.0165 | 1.9019 ± 0.4432 |
SHAP | 0.0052 ± 0.0045 | 0.9721 ± 0.0107 | 16.9329 ± 9.6288 |
SHAPG | 0.0041 ± 0.0028 | 0.9756 ± 0.0168 | 35.6517 ± 13.3737 |
CSHAP | 0.0118 ± 0.0240 | 0.9709 ± 0.0094 | 2.1851 ± 0.3981 |
FastSHAP | 0.0066 ± 0.0073 | 0.9721 ± 0.0182 | 1.3071 ± 0.8928 |
Hybrid | 0.0028 ± 0.0015 | 0.9756 ± 0.0121 | 3.0419 ± 1.1327 |
Performance comparison on the Adult Census Dataset (mean ± standard deviation).

Method | Test Loss | Accuracy | Training Time (s) |
---|---|---|---|
Adam only | 0.01160 ± 0.00008 | 0.83019 ± 0.00270 | 147.4386 ± 22.3537 |
SHAP | 0.01162 ± 0.00011 | 0.82970 ± 0.00171 | 224.6501 ± 59.8349 |
SHAPG | 0.01153 ± 0.00014 | 0.83109 ± 0.00385 | 287.9111 ± 51.1131 |
CSHAP | 0.01161 ± 0.00019 | 0.82851 ± 0.00476 | 135.2455 ± 27.9485 |
FastSHAP | 0.01167 ± 0.00018 | 0.82919 ± 0.00378 | 129.8111 ± 35.3003 |
Hybrid | 0.01159 ± 0.00022 | 0.82921 ± 0.00534 | 194.4226 ± 49.8023 |
Performance comparison on the Ames Housing Dataset (mean ± standard deviation).

Method | Test Loss | RMSE | R² | Training Time (s) |
---|---|---|---|---|
Adam Only | 0.00761 ± 0.00450 | 0.0830 ± 0.0267 | 0.7255 ± 0.1768 | 1.5512 ± 0.6533 |
SHAP | 0.00482 ± 0.00408 | 0.0651 ± 0.0242 | 0.7373 ± 0.2477 | 43.1751 ± 9.3204 |
SHAPG | 0.00399 ± 0.00143 | 0.0621 ± 0.0115 | 0.7799 ± 0.0946 | 57.8169 ± 21.6683 |
CSHAP | 0.00631 ± 0.00324 | 0.0769 ± 0.0200 | 0.7470 ± 0.1009 | 2.0344 ± 0.9757 |
FastSHAP | 0.00477 ± 0.00204 | 0.0675 ± 0.0147 | 0.7644 ± 0.0935 | 2.6061 ± 0.8900 |
Hybrid | 0.00719 ± 0.00487 | 0.0804 ± 0.0269 | 0.6210 ± 0.2354 | 4.2156 ± 1.5527 |
Performance comparison on the California Housing Dataset (mean ± standard deviation).

Method | Test Loss | RMSE | R² | Training Time (s) |
---|---|---|---|---|
Adam Only | 0.00818 ± 0.00025 | 0.09045 ± 0.00138 | 0.8032 ± 0.0069 | 43.3565 ± 5.2618 |
SHAP | 0.00813 ± 0.00037 | 0.09012 ± 0.00206 | 0.8016 ± 0.0097 | 129.8534 ± 25.2650 |
SHAPG | 0.00808 ± 0.00044 | 0.08984 ± 0.00244 | 0.8089 ± 0.0101 | 129.2503 ± 30.2712 |
CSHAP | 0.00803 ± 0.00035 | 0.08960 ± 0.00198 | 0.8085 ± 0.0070 | 54.7930 ± 16.6717 |
FastSHAP | 0.00812 ± 0.00030 | 0.09007 ± 0.00165 | 0.8066 ± 0.0046 | 63.9276 ± 11.1408 |
Hybrid | 0.00802 ± 0.00029 | 0.08954 ± 0.00164 | 0.8086 ± 0.0082 | 68.6670 ± 15.8209 |