Article

Scalable SHAP-Informed Neural Network

Department of Computer Science, College of Engineering, Texas Tech University, Lubbock, TX 79409, USA
* Author to whom correspondence should be addressed.
Mathematics 2025, 13(13), 2152; https://doi.org/10.3390/math13132152
Submission received: 11 May 2025 / Revised: 23 June 2025 / Accepted: 27 June 2025 / Published: 30 June 2025

Abstract

In the pursuit of scalable optimization strategies for neural networks, this study addresses the computational challenges posed by SHAP-informed learning methods introduced in prior work. Specifically, we extend the SHAP-based optimization family by incorporating two existing approximation methods, C-SHAP and FastSHAP, to reduce training time while preserving the accuracy and generalization benefits of SHAP-based adjustments. C-SHAP leverages clustered SHAP values for efficient learning rate modulation, while FastSHAP provides rapid approximations of feature importance for gradient adjustment. Together, these methods significantly improve the practical usability of SHAP-informed neural network training by lowering computational overhead without major sacrifices in predictive performance. The experiments conducted across four datasets—Breast Cancer, Ames Housing, Adult Census, and California Housing—demonstrate that both C-SHAP and FastSHAP achieve substantial reductions in training time compared to original SHAP-based methods while maintaining competitive test losses, RMSE, and accuracy relative to baseline Adam optimization. Additionally, a hybrid approach combining C-SHAP and FastSHAP is explored as an avenue for further balancing performance and efficiency. These results highlight the feasibility of using feature-importance-based guidance to enhance optimization in neural networks at a reduced computational cost, paving the way for broader applicability of explainability-informed training strategies.

1. Introduction

1.1. Background

Recent advances in neural network training have increasingly focused on tailoring the learning process to the structure of the data itself, rather than relying solely on general-purpose optimization routines. As model complexity continues to grow across tasks and domains, there is rising interest in integrating data-aware strategies that improve convergence efficiency and predictive performance. This shift reflects a broader trend toward data-centric AI, in which learning algorithms are explicitly designed to align more closely with data characteristics and underlying semantics [1]. One key area of investigation involves adapting the learning rate not just globally, but based on the relative importance of individual input features.
Traditional optimizers like Adam have been highly effective in handling diverse training conditions due to their adaptive mechanisms. By scaling updates based on gradient history, these optimizers offer stability and speed across many tasks. However, their updates remain blind to which features actually contribute most to prediction performance. Studies have shown that such methods can overemphasize features that are statistically significant in the training data but unimportant for generalization, especially when used without regularization or interpretability constraints [2,3]. These limitations have motivated more targeted optimization strategies—those that prioritize learning from relevant features—which may offer better generalization, particularly in high-dimensional or noisy environments. The development of AdaGrad and its successors exemplifies the long-standing interest in gradient- and parameter-wise adaptivity, although they still treat input data structure as incidental rather than central to learning [4].
One promising direction in this context is the use of feature attribution methods to guide the optimization process. Rather than relying on static learning schedules or uniform heuristics, these techniques introduce a layer of responsiveness by leveraging real-time feedback on which input features the model currently deems important. Ross et al. [5], for example, demonstrated that constraining a model to focus on “right for the right reasons” explanations—by aligning gradients with known important features—can lead to both improved accuracy and interpretability. Such integration of attribution-based feedback into the optimization loop creates new opportunities for combining data-centric learning with principled, theory-informed training dynamics. However, most XAI techniques, including SHAP, are traditionally designed for post hoc interpretation rather than real-time optimization. This distinction presents a challenge: while feature attributions offer insight into model behavior, directly embedding them into gradient-based training dynamics requires careful design to avoid instability or computational overhead.
Recent advances in Explainable AI (XAI) have introduced the possibility of incorporating feature importance information into the optimization process itself. SHAP (SHapley Additive exPlanations) values [6], grounded in cooperative game theory, provide a consistent and theoretically justified way to measure feature contributions to model predictions. Originally developed for post hoc explanation in XAI frameworks [7], SHAP has become a cornerstone in efforts to understand black-box model behavior, enabling novel applications beyond interpretability—such as the optimization techniques explored in this study. Building on these foundations, previous work has demonstrated that SHAP values can inform optimization by either:
  • Scaling the global learning rate based on aggregated feature importance;
  • Adjusting individual parameter gradients in the first layer according to the importance of corresponding input features.
This approach, termed SHAP-Informed Neural Network Optimization [8], offers a novel integration of interpretability and optimization, leveraging feature-level insights to guide the learning dynamics of neural networks.
However, the original SHAP-informed methods faced a significant challenge: computational inefficiency. Calculating exact SHAP values during training introduced substantial overhead, making the methods impractical for large datasets or models. Addressing this limitation forms the motivation for the present study. While SHAP libraries offer approximations that improve feasibility for post hoc explanations [9], these methods remain computationally expensive when applied repeatedly during training, as in SHAP-informed optimization. This limitation motivates the development of scalable alternatives.
In this work, we introduce the integration of C-SHAP (Cluster-Based SHAP) and FastSHAP (Fast Surrogate-based SHAP) within the attribution-aware optimization variants to enhance scalability. C-SHAP, a centroid-based approximation of SHAP values using clustering, dramatically reduces the computational burden of feature importance calculation by summarizing input data into representative centroids [10]. FastSHAP, a neural-network-based SHAP approximation method [11], offers an even faster way to estimate feature contributions by learning to predict SHAP values directly from data.
By embedding C-SHAP for learning rate adjustment and FastSHAP for gradient adjustment, we aim to retain the benefits of SHAP-informed optimization—feature-aware training dynamics—while significantly reducing computational costs. The goal is to make SHAP-informed neural network training more viable for larger-scale applications without sacrificing performance.
This research contributes to the ongoing exploration of how interpretability methods can be embedded into the optimization process itself, offering practical pathways for building machine learning models that are both accurate and efficient.

1.2. Related Work

Research on neural network optimization has advanced considerably over the past decade, with a central focus on improving convergence speed, accuracy, and computational efficiency. Among the most influential methods is the Adam optimizer [12], which adaptively adjusts learning rates using first and second moments of gradients. Its foundational mechanisms stem from AdaGrad and RMSProp [4], two earlier optimizers that also adapt step sizes based on gradient history. Complementing these, SGD remains a fundamental baseline due to its simplicity and strong generalization performance when paired with techniques like momentum and proper learning rate scheduling [13].
Efforts to enhance these baseline optimizers have produced various extensions. Kabiri et al. [14] proposed AMAdam, an adaptive modifier to Adam that improves convergence on noisy gradient landscapes. Huang et al. [15] introduced the Nostalgic Adam variant, which emphasizes older gradients to enhance long-term learning dynamics. Other studies, such as Luo et al. [16], proposed dynamically bounded adaptive gradient methods, which improve optimizer stability across tasks. Collectively, these works underscore a trend toward developing optimizers that not only adapt to local gradient information but also maintain global stability across training epochs.
Parallel to optimizer enhancements, learning rate scheduling has been a focal point for controlling the optimization trajectory. Classical techniques like step decay and annealing are widely adopted due to their simplicity and effectiveness. Bottou [13] explored these strategies in large-scale applications, while Ruder [17] provided an extensive taxonomy of scheduling methods. Wang et al. [18] delivered rigorous convergence guarantees for SGD trained with step-decay schedules, clarifying how carefully chosen learning-rate drops accelerate optimization while preserving generalization. These approaches aim to balance convergence speed with final model generalization, making them essential components in practical training pipelines.
Recent developments at the intersection of interpretability and optimization have introduced SHAP (SHapley Additive exPlanations) values into the training loop. Lundberg and Lee [6] proposed SHAP as a post hoc interpretability method based on cooperative game theory, allowing feature-level attribution of model predictions. Building on this, Graham and Sheng [8] pioneered the SHAP-informed optimization framework, where SHAP values were used to inform the learning rate and gradient scaling during training. Their follow-up work (this study) builds on that idea by investigating more scalable alternatives.
To address the computational inefficiency of calculating exact SHAP values during training, this study integrates C-SHAP [10] and FastSHAP [11]—two lightweight SHAP approximations—into the optimization process. C-SHAP leverages clustering (e.g., MiniBatchKMeans) to approximate SHAP values on representative data subsets, while FastSHAP uses a meta-model trained to mimic SHAP attributions in real time. These methods are particularly appealing for scaling XAI-informed optimization to larger datasets and deeper models.
Moreover, Droguett et al. [19] proposed integrating feature selection into deep learning, aligning with the philosophy behind SHAP-informed updates by prioritizing informative features during training. While many previous works use SHAP for model interpretation after training [6,20], this study exemplifies the trend of using feature importance metrics actively within the optimization loop—a notable shift in how interpretability tools are utilized in practice.
Overall, this work contributes to a growing body of research that blends optimization and explainability. While traditional learning rate schedules and adaptive optimizers have long been studied, incorporating feature importance metrics like SHAP values into the optimization process offers a promising new direction. This study extends earlier efforts [8] by improving the scalability of SHAP-informed training through novel integrations of C-SHAP and FastSHAP, providing a practical path toward interpretable, efficient neural network optimization.

2. Materials and Methods

2.1. Theory and Explanation

Portions of this section are adapted from our prior work on SHAP-guided approaches [8], extended here to include scalable methods such as C-SHAP and FastSHAP. In this section, we use bold symbols (e.g., g, m, θ) to denote vector quantities, while non-bold symbols (e.g., ϕᵢ, α, β1, ϵ) represent scalars. All operations involving vectors (e.g., addition, multiplication) are applied element-wise unless otherwise specified.
The core of this study lies in modifying the learning dynamics of neural networks by integrating interpretable, feature-level importance metrics into the optimization process. Specifically, we extend our previously proposed SHAP-informed optimization framework by evaluating two scalable variants, C-SHAP and FastSHAP, that significantly reduce computational cost while preserving or improving model performance.
The base optimization method in this study is the Adam optimizer, which has become a foundational component in modern deep learning due to its adaptive learning rates and efficient handling of sparse gradients. The update rules for Adam are as follows:
$$m_t = \beta_1 m_{t-1} + (1 - \beta_1)\, g_t$$
$$v_t = \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2$$
$$\hat{m}_t = \frac{m_t}{1 - \beta_1^t}$$
$$\hat{v}_t = \frac{v_t}{1 - \beta_2^t}$$
$$\theta_{t+1} = \theta_t - \alpha \, \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$$
where $\alpha$ is the base learning rate, $\beta_1$ and $\beta_2$ are exponential decay rates for the moment estimates, $g_t$ is the gradient of the loss with respect to the parameters at step $t$, $m_t$ and $v_t$ are the first and second moment estimates, $\epsilon$ is a small constant to avoid division by zero, and $\theta_t$ denotes the parameters of the network at step $t$.
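For concreteness, the update rules above can be written in a few lines of code. The following is a minimal NumPy sketch of one Adam step; the function name and default hyperparameters are illustrative, not the authors' implementation.

```python
import numpy as np

def adam_step(theta, g, m, v, t, alpha=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One element-wise Adam update for parameters theta given gradient g at step t (t >= 1)."""
    m = beta1 * m + (1 - beta1) * g              # first-moment (mean) estimate
    v = beta2 * v + (1 - beta2) * g ** 2         # second-moment (uncentered variance) estimate
    m_hat = m / (1 - beta1 ** t)                 # bias-corrected moment estimates
    v_hat = v / (1 - beta2 ** t)
    theta = theta - alpha * m_hat / (np.sqrt(v_hat) + eps)
    return theta, m, v
```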

2.1.1. SHAP Values

SHAP (SHapley Additive exPlanations) values [6] originate from cooperative game theory and quantify the contribution of each input feature to a model’s output. For a model f and an instance x, the SHAP value ϕᵢ for feature i is given by:
$$\phi_i = \sum_{S \subseteq N \setminus \{i\}} \frac{|S|!\,\left(|N| - |S| - 1\right)!}{|N|!} \left[ f\!\left(S \cup \{i\}\right) - f(S) \right]$$
This equation computes the average marginal contribution of feature i across all subsets S of the feature set N that exclude i.
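To make the combinatorial nature of this formula concrete, the brute-force computation below enumerates every subset explicitly. It is a didactic sketch only (its cost grows as 2^n, so it is feasible only for small n), and the "replace with baseline" value function is our illustrative assumption rather than how the SHAP library evaluates coalitions.

```python
from itertools import combinations
from math import factorial
import numpy as np

def exact_shapley(f, x, baseline):
    """Exact Shapley values for a scalar-valued model f at instance x (1-D NumPy arrays)."""
    n = len(x)
    phi = np.zeros(n)

    def value(subset):
        z = baseline.copy()
        idx = list(subset)
        z[idx] = x[idx]          # features in S keep their true values, the rest stay at the baseline
        return f(z)

    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(len(others) + 1):
            for S in combinations(others, size):
                weight = factorial(len(S)) * factorial(n - len(S) - 1) / factorial(n)
                phi[i] += weight * (value(S + (i,)) - value(S))
    return phi
```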
In this study, SHAP values are leveraged not as a post hoc interpretability tool, but as a real-time guide for model optimization. Two primary mechanisms are explored: global learning rate scaling and feature-wise gradient adjustment.

2.1.2. SHAP-Informed Learning Rate Adjustment

The first SHAP-informed mechanism modifies the global learning rate based on the aggregated SHAP values computed from the training data (later labeled as SHAP). The process involves:
  • Computing SHAP values at fixed intervals.
  • Aggregating and normalizing feature importance:
    $$\text{mean\_importance}_i = \frac{1}{m} \sum_{j=1}^{m} \left| \phi_i(x_j) \right|$$
    $$\text{normalized\_importance}_i = \frac{\text{mean\_importance}_i}{\max(\text{mean\_importance})}$$
  • Adjusting the learning rate by applying:
    $$\alpha \leftarrow \alpha \cdot \sum_{i=1}^{n} \text{normalized\_importance}_i$$
This integration condenses the feature-level SHAP values into a single figure that modulates the global learning rate, so that the overall importance encoded by the features is reflected in the step size [8]. This approach generates a coarse but informative signal by aggregating feature importance scores over a sampled subset of training data. The resulting scalar value scales the global learning rate, enabling the model to prioritize updates in proportion to overall feature relevance. In doing so, it suppresses the influence of noisy or less relevant features, promoting more efficient training and improved generalization.
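The steps above can be summarized in a short helper. This is a hedged sketch of our reading of Section 2.1.2, not released code; the scaling_factor argument corresponds to the tunable LR scaling factor introduced later in the experimental setup.

```python
import numpy as np

def shap_scaled_lr(shap_values, base_lr, scaling_factor=1.0):
    """shap_values: array of shape (m_samples, n_features); returns the adjusted learning rate."""
    mean_importance = np.abs(shap_values).mean(axis=0)              # per-feature mean |phi_i|
    normalized = mean_importance / (mean_importance.max() + 1e-12)  # divide by the largest importance
    return base_lr * scaling_factor * normalized.sum()              # single global adjustment figure
```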

2.1.3. SHAP-Informed Gradient Adjustment

The second SHAP-informed mechanism, SHAPG (gradient-based SHAP method), applies feature-wise adjustments directly to the gradients of the first-layer weights:
  • Compute the normalized SHAP importance as above:
    $$\text{normalized\_importance}_i = \frac{\text{mean\_importance}_i}{\max(\text{mean\_importance})}$$
  • For the gradient $g_t^{(i)}$ of feature $i$ at time $t$, modify it as follows:
    $$g_t^{(i)} \leftarrow g_t^{(i)} \cdot \text{normalized\_importance}_i$$
This fine-grained approach allows the optimizer to scale its updates based on individual feature relevance, promoting stronger updates for more informative inputs and reducing noise from less important ones. By limiting the modification to the first layer, this method avoids computational overhead and instability in deeper layers.
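A minimal PyTorch sketch of this first-layer gradient scaling is shown below. It assumes the model's first linear layer is exposed as model.fc1 (a hypothetical attribute name) and is meant to be called between loss.backward() and optimizer.step(); it is not the authors' exact implementation.

```python
import torch

def scale_first_layer_grads(model, normalized_importance):
    """Scale each input feature's gradient column in the first layer by its normalized SHAP importance."""
    with torch.no_grad():
        w_grad = model.fc1.weight.grad           # shape: (hidden_units, n_features)
        if w_grad is not None:
            imp = torch.as_tensor(normalized_importance, dtype=w_grad.dtype, device=w_grad.device)
            w_grad.mul_(imp)                     # broadcasts across the feature dimension
```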

2.1.4. Scalable SHAP-Informed Adjustments

While feature-attribution-guided optimization methods introduced in prior work have demonstrated effectiveness in guiding neural network training via feature importance, they often impose a substantial computational burden due to the cost of computing exact SHAP values during training. This limitation, particularly acute in large datasets or high-frequency SHAP update schedules, motivates the introduction of scalable alternatives. In this study, we explore two such alternatives—C-SHAP (Cluster-SHAP) and FastSHAP—that approximate SHAP-based importance while significantly reducing computational overhead.
  1. C-SHAP: Learning Rate Scaling via Cluster-Based SHAP Approximation
C-SHAP (cluster-based SHAP method) leverages clustering to reduce the number of SHAP computations required at each update interval. Instead of computing SHAP values for a large batch of samples, we cluster the training data and compute SHAP values only at the centroids. Let $C = \{c_1, c_2, \ldots, c_k\}$ be the set of $k$ centroids obtained via MiniBatch K-Means over the training data. At each SHAP update interval, we compute SHAP values $\phi_i(c_j)$ for each feature $i$ at each centroid $c_j \in C$. These values are then used to compute the average absolute importance per feature:
$$\text{mean\_importance}_i = \frac{1}{k} \sum_{j=1}^{k} \left| \phi_i(c_j) \right|$$
We normalize the mean importances as follows:
$$\text{normalized\_importance}_i = \frac{\text{mean\_importance}_i}{\max(\text{mean\_importance})}$$
These normalized values are used to scale the global learning rate following the same approach as in Section 2.1.2:
$$\alpha \leftarrow \alpha \cdot \sum_{i=1}^{n} \text{normalized\_importance}_i$$
This method significantly reduces SHAP evaluation overhead by replacing sample-wide calculations with a much smaller number of centroid-based evaluations. Since centroids represent regions of the input space, the resulting importance values still capture general trends in feature influence.
While exact SHAP computations are costly, C-SHAP reduces the cost of SHAP updates by computing attributions only for cluster centroids, under the assumption that in smooth feature spaces, nearby samples tend to exhibit similar SHAP patterns. By summarizing regions of the input space through clustering, the method avoids computing SHAP values for every training instance, making centroid-based approximations a practical and efficient alternative for periodically informing the training process.
Rather than relying on precise SHAP values for each individual data point, the method averages SHAP values across these centroids to produce a stable global feature importance signal. This average signal is not intended to capture exact importance rankings, but rather to serve as a coarse guide for adjusting the global learning rate. Empirically, this approximate signal has proven sufficient to improve training efficiency without significantly degrading performance, even though it does not perfectly preserve fine-grained attribution details.
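A sketch of one C-SHAP update is given below, combining MiniBatchKMeans with the SHAP library's GradientExplainer (the explainer used throughout this study, as noted in Section 2.2.5). The background size, cluster count, and the reshape used to flatten the explainer output are our assumptions; exact return shapes can vary with the SHAP version and the model's output dimensionality.

```python
import numpy as np
import torch
import shap
from sklearn.cluster import MiniBatchKMeans

def cshap_lr_multiplier(model, X_train, n_clusters=10, background_size=100):
    """Approximate the global learning-rate multiplier from SHAP values at cluster centroids."""
    centroids = MiniBatchKMeans(n_clusters=n_clusters).fit(X_train).cluster_centers_
    background = torch.tensor(X_train[:background_size], dtype=torch.float32)
    explainer = shap.GradientExplainer(model, background)
    phi = explainer.shap_values(torch.tensor(centroids, dtype=torch.float32))
    phi = np.asarray(phi).reshape(-1, X_train.shape[1])      # rows of per-feature attributions
    mean_importance = np.abs(phi).mean(axis=0)
    normalized = mean_importance / (mean_importance.max() + 1e-12)
    return normalized.sum()                                  # multiply the base learning rate by this
```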
  2. FastSHAP: Surrogate-Based SHAP Approximation for Gradient Scaling
FastSHAP (fast surrogate-based SHAP) offers another avenue for efficient SHAP estimation by training a surrogate explainer model to approximate SHAP values in real time, rather than computing SHAP values with kernel-based or gradient-based exact methods. In the following:
  • $x_j \in \mathbb{R}^n$ is an input sample with $n$ features.
  • $\phi_i(x_j)$ is the exact SHAP value for feature $i$ on instance $x_j$.
  • $\hat{\phi}_i(x_j)$ is the surrogate model’s prediction for that SHAP value.
FastSHAP trains a neural explainer $f_{\text{expl}}$ to learn the mapping from inputs to their SHAP attributions by minimizing the loss:
$$\mathcal{L}_{\text{FastSHAP}} = \frac{1}{m} \sum_{j=1}^{m} \sum_{i=1}^{n} \left( \hat{\phi}_i(x_j) - \phi_i(x_j) \right)^2$$
where m denotes the number of samples in the batch or SHAP explainer training set and n is the number of features.
Once trained, the surrogate model $f_{\text{expl}}$ estimates SHAP values for a batch $B$ during training. The estimated importances are averaged:
$$\text{mean\_importance}_i = \frac{1}{|B|} \sum_{x_j \in B} \left| \hat{\phi}_i(x_j) \right|$$
Normalization and scaling are conducted identically to earlier methods. For gradient adjustment, the SHAP-informed gradient for feature i at time t is modified as follows:
$$g_t^{(i)} \leftarrow g_t^{(i)} \cdot \text{normalized\_importance}_i$$
This lightweight approximation allows for feature-aware gradient scaling without invoking exact SHAP evaluations during training. By amortizing SHAP computation through a learned explainer, FastSHAP achieves near-real-time attribution while preserving fine-grained influence estimation.
FastSHAP avoids recomputing SHAP values during training by learning a surrogate model that predicts SHAP-like attributions in a single forward pass. This explainer is trained to mimic SHAP value patterns across samples and can be queried in real time, allowing feature importance updates to remain integrated with the training loop. While this trades off some exactness, the use of a learned attribution function enables frequent and low-latency updates to guide optimization. In gradient adjustment tasks, this allows the network to emphasize more influential features without computational bottlenecks. The surrogate model effectively serves as a lightweight estimator of feature relevance, and if trained well, it can capture enough of the true SHAP structure to support optimization strategies like gradient scaling or attention focusing.
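The sketch below illustrates the surrogate explainer as we describe it here: a small MLP regressed with MSE onto SHAP targets produced by GradientExplainer, then queried during training in place of exact SHAP values. Layer sizes, epoch counts, and class names are illustrative assumptions, and this is a simplification of the full FastSHAP training objective of Jethani et al. [11].

```python
import torch
import torch.nn as nn

class SurrogateExplainer(nn.Module):
    """Small MLP that maps an input sample to a vector of approximate SHAP values."""
    def __init__(self, n_features, hidden=64):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(n_features, hidden), nn.ReLU(),
                                 nn.Linear(hidden, n_features))  # one attribution per feature

    def forward(self, x):
        return self.net(x)

def fit_surrogate(X, phi_targets, epochs=50, lr=1e-3):
    """X, phi_targets: float tensors of shape (m, n_features); minimizes the MSE loss above."""
    explainer = SurrogateExplainer(X.shape[1])
    opt = torch.optim.Adam(explainer.parameters(), lr=lr)
    loss_fn = nn.MSELoss()
    for _ in range(epochs):
        opt.zero_grad()
        loss = loss_fn(explainer(X), phi_targets)
        loss.backward()
        opt.step()
    return explainer
```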

2.2. Methodology

Portions of this methodology section are adapted from our prior work on SHAP-informed optimization [8], with modifications to reflect the integration of C-SHAP, FastSHAP, and differing datasets.
This study investigates scalable SHAP-informed techniques integrated with the Adam optimizer to improve the training and generalization performance of neural networks. We extend prior work by implementing and evaluating four SHAP-informed learning strategies: SHAP-based learning rate scaling, SHAP-informed gradient adjustment (SHAPG), Clustered SHAP (C-SHAP), and FastSHAP. Additionally, we include a hybrid method that combines the scaling mechanism of C-SHAP with the directional gradient modifications of FastSHAP. These methods are compared against a baseline model trained using the Adam optimizer alone. Our experiments span four benchmark datasets: Adult Census, Ames Housing, Breast Cancer, and California Housing. To ensure rigorous and consistent evaluation, we use controlled preprocessing, a unified network architecture, and grid search for hyperparameter tuning.

2.2.1. Workflow Overview

To provide a comprehensive view of the experimental pipeline, we present a structured diagram outlining the full training and evaluation process used in this study. This workflow encompasses dataset preparation, model initialization, SHAP-informed training routines, and final evaluation. While each SHAP variant—SHAP, SHAPG, CSHAP, and FastSHAP—is implemented in separate scripts, the diagram generalizes the shared structure and logic used across methods. The training loop includes SHAP-based updates that adjust learning rates or gradients based on computed feature importance values. To ensure robustness, each configuration undergoes repeated training runs using independently shuffled data splits and a grid search over hyperparameter settings. Figure 1 summarizes this end-to-end workflow.

2.2.2. Datasets and Preprocessing

This study makes use of four well-established benchmark datasets:
  • The Adult Census Dataset contains 48,842 records with 14 features capturing demographic and employment-related information such as age, education level, occupation, and hours worked per week. The binary target variable indicates whether an individual earns more than USD 50 K annually. This is a classification task.
  • The Ames Housing Dataset comprises 2930 instances and 80 engineered features representing residential property characteristics in Ames, Iowa. These features include physical measurements, neighborhood information, and quality ratings. The target variable is the sale price of each property, making this a regression task.
  • The Breast Cancer Dataset consists of 569 samples, each with 30 numerical features describing properties of cell nuclei obtained from breast mass imagery. The task is binary classification, predicting whether the tumor is malignant or benign.
  • The California Housing Dataset includes 20,640 instances from the 1990 California census. It contains 8 numerical features related to socioeconomic indicators such as median income, average occupancy, and housing age. The target variable is the median house value (in USD 100,000s) per census block group, treated as a regression task.
All datasets undergo consistent preprocessing. Numerical features are normalized using standard scaling to ensure that all features operate on comparable scales. To maintain uniformity and avoid encoding complexities, only numerical features are retained in each dataset. After preprocessing, the data is randomly shuffled and partitioned into 70% for training, 15% for validation, and 15% for testing. This partitioning supports consistent training and fair model evaluation across experimental trials.
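An illustrative preprocessing sketch under this protocol is shown below; fitting the scaler on the training portion only is our choice (the split percentages match the text, and the random seed is arbitrary).

```python
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

def prepare_splits(X_numeric, y, seed=0):
    """Shuffle and split into 70/15/15 train/validation/test, then standard-scale the features."""
    X_train, X_tmp, y_train, y_tmp = train_test_split(X_numeric, y, test_size=0.30,
                                                      random_state=seed, shuffle=True)
    X_val, X_test, y_val, y_test = train_test_split(X_tmp, y_tmp, test_size=0.50,
                                                    random_state=seed, shuffle=True)
    scaler = StandardScaler().fit(X_train)        # fit on the training split only
    return ((scaler.transform(X_train), y_train),
            (scaler.transform(X_val), y_val),
            (scaler.transform(X_test), y_test))
```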
While the scope of this study is limited to structured, tabular datasets, we intentionally selected datasets that vary in size, dimensionality, and task type to evaluate robustness across common tabular learning scenarios. Breast Cancer represents a low-dimensional, small-scale classification task; Adult Census covers a larger, high-dimensional binary classification problem; Ames Housing and California Housing provide regression challenges with continuous targets. These datasets offer a meaningful diversity of tabular conditions for benchmarking optimization strategies. Generalization to unstructured domains such as images, time series, or natural language is left as future work.

2.2.3. Model Architecture

To ensure consistency across experiments, we use a unified fully connected feedforward neural network architecture for all datasets, with minor adjustments to accommodate input dimensionality. Each model consists of three hidden layers with ReLU activation functions and dropout applied for regularization. The architecture was selected for its simplicity and effectiveness across both regression and classification tasks, and it enables a controlled evaluation of SHAP-informed optimization strategies.
  • Input layer with dimensionality equal to the number of numerical features in the dataset.
  • First hidden layer with 128 units and ReLU activation.
  • Second hidden layer with 64 units and ReLU activation.
  • Final output layer with a single neuron:
    • For regression tasks (Ames Housing, California Housing), this layer outputs a continuous value without activation.
    • For classification tasks (Adult Census, Breast Cancer), a sigmoid activation is applied post hoc to convert logits to probabilities for binary prediction.
  • Dropout is applied after each hidden layer, with the rate selected as a hyperparameter in the grid search.
The architecture is lightweight and designed to isolate the effects of learning procedures augmented by SHAP attributions rather than maximizing model complexity. The same architecture is reused in each experimental trial, with only the input size and output activation adapted to the dataset type.
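A PyTorch sketch following the bulleted specification (input → 128 → 64 → 1, ReLU, dropout after each hidden layer) is given below; the class name and the choice to keep the output raw (applying sigmoid post hoc for classification) reflect our reading of the description rather than the authors' code.

```python
import torch.nn as nn

class TabularNet(nn.Module):
    def __init__(self, n_features, dropout=0.3):
        super().__init__()
        self.fc1 = nn.Linear(n_features, 128)
        self.fc2 = nn.Linear(128, 64)
        self.out = nn.Linear(64, 1)
        self.act = nn.ReLU()
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        x = self.drop(self.act(self.fc1(x)))
        x = self.drop(self.act(self.fc2(x)))
        return self.out(x)   # regression: raw value; classification: apply sigmoid post hoc
```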

2.2.4. Optimization Methods

We organize our SHAP-informed optimization methods into two conceptual groups: (1) methods that adjust the global learning rate and (2) methods that scale gradients based on feature importance. In prior work, we introduced SHAP (global learning rate scaling) and SHAPG (gradient-based scaling), both of which demonstrated strong performance improvements but incurred considerable computational cost due to frequent SHAP value calculations.
To address this, the current study introduces two complementary, scalable variants:
  • C-SHAP, which mirrors the logic of SHAP but reduces overhead by computing SHAP values at cluster centroids only.
  • FastSHAP, which mirrors SHAPG by learning a surrogate model to approximate SHAP values and applying them during gradient updates.
These new methods were designed to retain the original optimization intent—guiding training with feature importance—while making SHAP-informed training feasible for larger datasets and more frequent updates.
All models in this study are trained using the Adam optimizer, which serves as the baseline for all experiments. The standard Adam configuration applies a global learning rate selected through grid search. Building upon this baseline, we evaluate several SHAP-guided strategies designed to incorporate feature importance into the training process.
  • Adam only: The baseline condition uses Adam with a fixed learning rate (tuned per dataset) and no SHAP integration. This configuration serves as a control to assess the incremental benefits of SHAP-informed guidance.
  • SHAP-based learning rate scaling: This approach modifies the global learning rate based on the average SHAP values computed from the training samples. SHAP values are normalized and scaled by a tunable factor to dynamically adjust learning rate magnitude, promoting learning that aligns with feature importance. These SHAP values are recomputed at fixed intervals (e.g., every 10 or 20 epochs) to reflect updated feature attributions as training progresses, as proposed in our previous work and supported by broader optimization principles where timely adaptation of learning rate has been shown to improve convergence [8].
  • SHAPG: gradient modification using SHAP: SHAPG extends the SHAP-based learning framework by directly modifying gradients during backpropagation. Specifically, the gradient of the input layer weights is scaled feature-wise in proportion to normalized SHAP values. This enforces a directional training signal that amplifies updates for more important features and suppresses less relevant ones, as described in our prior work [8]. While effective at directing gradient updates based on feature importance, SHAPG introduces higher computational cost compared to other methods. This is due to the repeated computation of exact SHAP values across a subset of training samples at fixed intervals and their use in per-feature gradient scaling during backpropagation. The added overhead stems both from the complexity of SHAP computation and the need to apply feature-wise operations during training, which cumulatively increase runtime.
  • C-SHAP: cluster-based SHAP-informed learning rate scaling: C-SHAP reduces the computational burden of SHAP-based learning rate scaling by leveraging clustering to approximate feature importance. Instead of computing SHAP values across a large batch of training samples, C-SHAP applies MiniBatch K-Means clustering to the training data and computes SHAP values only at the resulting centroids. These centroid-level attributions serve as representative approximations for the full dataset. The mean absolute SHAP values across centroids are normalized and used to adjust the global learning rate. This strategy retains the benefits of SHAP-based optimization while significantly improving efficiency, making it suitable for larger datasets. A similar clustering-based SHAP approximation approach was explored by Ranjbaran et al. [10], who demonstrated that SHAP values computed at cluster centroids preserved feature importance rankings while greatly reducing computation time.
  • FastSHAP: surrogate-guided gradient scaling: FastSHAP employs a learned surrogate model to approximate SHAP values efficiently. A lightweight neural explainer is trained on the main model’s outputs using sampled data. The surrogate produces feature attribution vectors that are used to scale the gradients of the input layer in a manner similar to SHAPG. This method is computationally more efficient and enables more frequent attribution updates without high overhead, adapting the surrogate-based estimation strategy from the work of Jethani et al. [11].
  • Hybrid: combined C-SHAP and FastSHAP: The hybrid method combines the global learning rate scaling of C-SHAP with the gradient modulation of FastSHAP. Cumulative SHAP values are used to adjust the learning rate schedule, while the surrogate-based FastSHAP mechanism simultaneously informs gradient scaling. This fusion seeks to capture both coarse-grained (global) and fine-grained (directional) SHAP signals within the same training loop.

2.2.5. Experimental Setup

We follow a standardized protocol for training and evaluating each neural network model across various SHAP-informed optimization methods. The experimental process is divided into two phases: a grid search over key hyperparameters using the validation set, followed by a final evaluation on the test set using the best-performing configuration.
All neural network weights are initialized using PyTorch’s (version 2.1.1, Meta, Menlo Park, CA, USA) default initialization for fully connected layers (torch.nn.Linear), a Kaiming-uniform scheme scaled by the layer fan-in. This approach maintains stable variance across layers and helps improve training convergence. Biases are initialized to zero.
Each dataset is shuffled and split into training (70%), validation (15%), and test (15%) subsets. The training set is used to fit the model and apply SHAP-informed logic. The validation set guides hyperparameter selection, and the test set is reserved for final performance evaluation. This train/val/test structure is repeated across 10 independent trials to ensure statistical robustness.
For methods that require SHAP value computation (SHAP, SHAPG, C-SHAP, FastSHAP, and Hybrid), SHAP values are updated at fixed intervals every 20 epochs. These updates guide either learning rate scaling (SHAP, C-SHAP) or gradient scaling (SHAPG, FastSHAP, Hybrid).
  • For C-SHAP, we apply MiniBatch K-Means clustering on the training set to generate a small number of representative centroids. The SHAP values are then computed only at these centroids to reduce computational cost while maintaining attribution fidelity.
  • For FastSHAP, a lightweight surrogate model is trained during each update interval to approximate SHAP values efficiently. This surrogate produces importance vectors used for scaling input-layer gradients.
All SHAP values used in this study were computed using the GradientExplainer from the SHAP library, which provides gradient-based SHAP approximations suitable for deep neural networks. This approach computes attributions directly from model gradients and was consistently applied across all methods requiring SHAP input (SHAP, SHAPG, C-SHAP, and FastSHAP). No sampling-based approximations (e.g., KernelSHAP or TreeSHAP) were used. For FastSHAP, GradientExplainer was used to train a surrogate explainer, which then produced approximate SHAP values during training.
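Putting the pieces together, the simplified loop below shows where the periodic SHAP update (every 20 epochs) sits relative to the optimizer steps. It reuses the scale_first_layer_grads helper sketched earlier, trains full-batch, and omits early stopping, validation, and the C-SHAP/FastSHAP approximations for brevity; it is a schematic of the workflow, not the authors' implementation.

```python
import numpy as np
import shap
import torch

def train_with_shap_updates(model, optimizer, loss_fn, X_train, y_train,
                            epochs=100, update_interval=20, base_lr=1e-3, use_shapg=False):
    """X_train, y_train: float tensors. Periodically recompute SHAP values and adjust training."""
    background = X_train[:100]
    normalized = None
    for epoch in range(epochs):
        if epoch % update_interval == 0:                      # periodic SHAP update
            explainer = shap.GradientExplainer(model, background)
            phi = np.asarray(explainer.shap_values(X_train[:256])).reshape(-1, X_train.shape[1])
            mean_imp = np.abs(phi).mean(axis=0)
            normalized = mean_imp / (mean_imp.max() + 1e-12)
            for group in optimizer.param_groups:              # SHAP / C-SHAP style LR scaling
                group["lr"] = base_lr * normalized.sum()
        optimizer.zero_grad()
        loss = loss_fn(model(X_train).squeeze(-1), y_train)
        loss.backward()
        if use_shapg and normalized is not None:              # SHAPG / FastSHAP style gradient scaling
            scale_first_layer_grads(model, normalized)
        optimizer.step()
```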
Each method is evaluated across a consistent grid of hyperparameters:
  • Learning rate: 0.01, 0.001, 0.0001.
  • Dropout rate: 0.1, 0.3, 0.7.
  • Epochs: 25, 50, 100.
  • LR scaling factor: 0.5, 1.0, 2.0 (for SHAP-informed methods only).
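For completeness, the grid above can be enumerated with scikit-learn's ParameterGrid as sketched below; the dictionary keys are illustrative names, and the LR scaling factor is consumed only by the SHAP-informed runs.

```python
from sklearn.model_selection import ParameterGrid

search_space = ParameterGrid({
    "lr": [0.01, 0.001, 0.0001],
    "dropout": [0.1, 0.3, 0.7],
    "epochs": [25, 50, 100],
    "lr_scaling_factor": [0.5, 1.0, 2.0],   # SHAP-informed methods only
})
for config in search_space:
    # train one model per configuration and track average validation loss across trials
    print(config)
```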
During training, early stopping is applied with a patience of 10 epochs based on validation loss. The best model is chosen based on the lowest average validation loss across all trials. The final evaluation metrics are computed on the test set using the selected model from each trial.
We report on the following metrics to evaluate model performance:
  • Test loss: mean prediction error on the test set (MSE for regression, BCE for classification).
  • Root Mean Square Error (RMSE): used for regression datasets.
  • Accuracy: used for binary classification datasets (Adult Census and Breast Cancer).
  • R2 score: indicates variance explained in regression tasks.
  • Training time: total runtime for each trial, including all epochs and SHAP-related computations.
Final performance scores are averaged over 10 trials and reported as mean ± standard deviation.

2.2.6. Computational Complexity

To quantify the computational cost of the proposed methods, we analyze training and testing complexity, along with SHAP update overhead. Let:
  • E: number of epochs.
  • N: number of training samples.
  • d: number of input features.
  • k: number of SHAP update intervals.
  • l: number of samples used in each SHAP update.
  • B: number of samples used per FastSHAP surrogate inference.
  • E′: number of epochs dedicated to training the FastSHAP surrogate model.
The Adam-only baseline incurs a standard training cost of O(E × N × d) and testing cost of O(d), with no SHAP-related overhead.
C-SHAP introduces periodic SHAP updates at k intervals, with a total update overhead of O(k × l × d + k × d²) across training, yielding the overall cost O(E × N × d) plus this overhead. Inference remains O(d).
FastSHAP adds the cost of training a surrogate model, O(E′ × B × d), and amortizes SHAP estimation using learned mappings during training with cost O(B × d) per update. Final testing remains O(d), since surrogate models do not affect inference cost of the original network.
SHAP and SHAPG introduce the same total update overhead of O(k × l × d + k × d²) across the k SHAP computations. In SHAP, this is used to scale the global learning rate, while in SHAPG, it is applied to modulate the first-layer gradients during backpropagation. Both methods maintain an O(E × N × d) training cost and O(d) testing cost.
Table 1 below summarizes the computational complexities of each of the optimization methods mentioned.
While SHAP, SHAPG, and C-SHAP share the same theoretical update complexity class, their runtime profiles differ significantly. SHAP and SHAPG rely on exact SHAP computations, which are computationally expensive, particularly for large l and d. In contrast, C-SHAP employs centroid-based approximations, reducing the number of SHAP evaluations needed per update. Additionally, SHAPG modifies individual weight gradients, incurring further runtime overhead during backpropagation. These factors explain why C-SHAP executes much faster despite matching theoretical bounds.

3. Results

3.1. Breast Cancer Dataset

The results from our experiments on the Breast Cancer dataset provide insights into the comparative performance of attribution-informed training strategies when integrated with the Adam optimizer. Table 2 summarizes the average test loss, accuracy, and training time across 10 trials for each method. These results reflect both predictive performance and computational efficiency.
Key Observations:
  • Test loss: The hybrid method achieved the lowest test loss (0.0028 ± 0.0015), outperforming all other methods including the baseline Adam optimizer (0.0033 ± 0.0018). The original SHAP method also showed competitive performance (0.0052 ± 0.0045), while the CSHAP and hybrid methods recorded higher variability in loss.
  • Accuracy: Multiple SHAP-informed methods surpassed the baseline in accuracy. SHAPG and the hybrid methods both achieved 97.5% accuracy, notably higher than Adam’s (96.86%), with SHAP and FastSHAP also exceeding 97%. CSHAP performed slightly lower at 97.09%, suggesting more stable accuracy gains from gradient-informed strategies.
  • Training time: The FastSHAP and hybrid methods showed a favorable trade-off between accuracy and training time. FastSHAP completed training in 1.31 ± 0.893 s on average, offering a significant runtime reduction compared to SHAPG (35.65 ± 13.37 s), which had the highest training time due to full SHAP recomputation per update. The hybrid method also achieved a competitive runtime of 3.04 ± 1.13 s while maintaining strong accuracy.
These results highlight the benefits of SHAP-guided approaches in improving predictive performance on small-scale classification tasks. In particular, FastSHAP offered the best balance between test loss, accuracy, and computational cost, demonstrating its potential as an efficient surrogate-guided optimization strategy. SHAPG delivered the highest accuracy overall but incurred substantial training time. These findings suggest that SHAP-based attribution can meaningfully enhance neural network training even in low-dimensional biomedical datasets, especially when integrated through scalable surrogate or hybrid strategies.

3.2. Adult Census Dataset

The Adult Census dataset presents a higher-dimensional, mixed-type classification task that tests the scalability and robustness of feature-attribution-guided methods. Table 3 reports the test loss, accuracy, and training time across 10 trials for each method, all using the Adam optimizer as a consistent base. These results help evaluate the performance trade-offs between learning improvements and computational efficiency.
Key Observations:
  • Test loss: The SHAPG method achieved the lowest test loss (0.01159 ± 0.00022), slightly outperforming all other SHAP-based methods as well as the Adam-only baseline (0.01160 ± 0.00008). The hybrid method also performed competitively (0.01159 ± 0.00022), while CSHAP and SHAP produced similar loss levels to the baseline.
  • Accuracy: All SHAP-informed methods achieved comparable or slightly improved accuracy over the Adam baseline (83.09%). SHAPG reached the highest accuracy (83.10%), closely followed by the hybrid (82.92%) and FastSHAP methods (82.91%), indicating marginal improvements across methods. The differences were small, suggesting that test loss was a more sensitive differentiator on this task.
  • Training time: FastSHAP demonstrated the best runtime performance among SHAP-informed methods, with 129.81 ± 35.30 s—faster than even the Adam baseline (147.44 ± 22.35 s). In contrast, SHAPG required the most time (287.91 ± 151.11 s), due to the cost of computing exact SHAP values and applying gradient adjustments at each update interval. Although FastSHAP also performs updates each interval, it leverages a learned surrogate model to approximate SHAP values efficiently, resulting in significantly faster training. The hybrid and SHAP methods required moderate training time, with CSHAP achieving relatively efficient performance at 135.25 ± 27.95 s.
Overall, the Adult Census results suggest that SHAP-informed optimization methods can scale to larger datasets without significant drops in performance or excessive training costs. While the accuracy gains were relatively modest compared to the Breast Cancer dataset, the hybrid and FastSHAP methods demonstrated competitive predictive accuracy with favorable training times. These results reinforce the potential of scalable SHAP-guided methods to enhance training on tabular classification tasks with complex feature interactions.

3.3. Ames Housing Dataset

The Ames Housing dataset offers a moderately sized regression task with a large number of numerical features, providing a valuable test case for evaluating SHAP-driven parameter tuning in continuous prediction settings. Table 4 presents the results across 10 trials, including average test loss, RMSE, R2 score, and training time for each method, all trained using the Adam optimizer.
Key Observations:
  • Test loss: SHAPG achieved the lowest test loss (0.00399 ± 0.00143), outperforming all other methods, including FastSHAP (0.00477 ± 0.00204) and the SHAP method (0.00482 ± 0.00408). Both the CSHAP and hybrid methods also notably improved upon the Adam-only baseline (0.00761 ± 0.00450), highlighting the value of SHAP-based guidance in regression tasks.
  • RMSE: SHAPG and FastSHAP delivered the best RMSE scores at 0.0621 and 0.0675, respectively—lower than Adam’s 0.0830—indicating better predictive calibration across samples. The CSHAP and hybrid methods produced slightly higher RMSE values, suggesting more variability in model predictions.
  • R2 Score: SHAPG achieved the highest R2 score (0.7799), reflecting the model’s improved ability to explain variance in the target variable. FastSHAP and CSHAP also yielded strong scores (0.7644 and 0.7470), while the hybrid method showed the weakest performance (0.6210), suggesting that combining global and local SHAP signals may have introduced instability in this regression setting.
  • Training Time: FastSHAP again delivered excellent efficiency, completing training in just 2.61 ± 0.89 s. SHAPG (57.82 ± 21.67 s) and SHAP (43.17 ± 9.32 s) required significantly more time due to SHAP value computations at regular intervals. CSHAP achieved somewhat better efficiency than SHAP and SHAPG, though its other metrics were less consistent, while the hybrid method was slower than FastSHAP but faster than SHAP and SHAPG.
Overall, the Ames Housing results emphasize that training with SHAP guidance strategies can substantially enhance regression performance over standard Adam training. While FastSHAP stood out for its balance of low test loss, high R2, and minimal training time, SHAPG provided the strongest accuracy but at a greater computational cost. These findings underscore that surrogate-based methods like FastSHAP are particularly well suited for moderate-sized tabular regression problems where both speed and interpretability matter.

3.4. California Housing Dataset

The California Housing dataset presents a large-scale regression task with socioeconomic variables, offering a more challenging setting for evaluating the generalization and scalability of SHAP-enhanced learning methods. Table 5 summarizes the results across 10 trials for each method, reporting test loss, RMSE, R2 score, and training time using the Adam optimizer.
Key Observations:
  • Test loss: The hybrid method recorded the lowest test loss (0.00802 ± 0.00029), slightly outperforming all other methods, including CSHAP (0.00803 ± 0.00035) and SHAPG (0.00808 ± 0.00044). The FastSHAP method also performed competitively, while the Adam-only baseline had the highest test loss overall.
  • RMSE: All SHAP-based methods improved upon the Adam baseline (0.0905), with the hybrid and CSHAP methods achieving the lowest RMSEs at 0.0895 and 0.0896, respectively. These results suggest that incorporating gradient scaling and cluster-based SHAP approximations can enhance stability and consistency in model predictions for high-dimensional tabular regression.
  • R2 score: SHAPG yielded the highest R2 score (0.8089), indicating the best variance explanation among all models. The hybrid (0.8086) and CSHAP methods (0.8085) followed closely behind. These results highlight the consistent ability of SHAP-informed methods to improve model expressiveness on large tabular datasets.
  • Training time: CSHAP was the most efficient SHAP-based method, completing training in 54.79 ± 16.67 s. FastSHAP followed with a slightly longer runtime (63.93 ± 11.14 s), while SHAPG and SHAP both required over two minutes. The hybrid method offered a compromise between speed and performance, completing training in 68.67 ± 15.82 s.
These findings reinforce the value of optimization with SHAP for large-scale regression tasks. While the hybrid and CSHAP methods delivered the lowest test losses, with CSHAP doing so at minimal computational overhead, SHAPG achieved the highest R2 score, demonstrating strong representational capacity. FastSHAP again proved to be an efficient alternative, offering solid accuracy and fast training. Taken together, the California Housing results highlight the scalability of SHAP-based optimization strategies, especially when efficient attribution approximations are employed.

4. Discussion

This study examined the effectiveness of integrating SHAP-based feature attribution into neural network optimization using the Adam optimizer, with a focus on improving learning through an adaptive learning rate or gradient modulation. We evaluated five SHAP-informed strategies—SHAP, SHAPG, CSHAP, FastSHAP, and a hybrid method—across four diverse datasets: Breast Cancer, Adult Census, Ames Housing, and California Housing. By isolating the impact of these methods without confounding variables like additional optimizers or learning rate decay schedules, we aimed to assess their practical value and scalability in both classification and regression settings.
The SHAP-informed methods consistently outperformed the Adam-only baseline across all datasets, though the degree of improvement and cost varied. SHAPG delivered some of the strongest overall results, particularly in accuracy and R2 score, likely due to its reliance on exact SHAP values calculated at each update interval. These values preserve high-fidelity estimates of localized feature importance, which can guide gradient updates with greater precision—albeit at a substantial computational cost. Recent work has emphasized that high-fidelity SHAP-based feature guidance can improve learning outcomes, whether by informing gradient updates during optimization or through effective feature selection [19,21]. FastSHAP, in contrast, demonstrated highly competitive performance while maintaining low training times. Its surrogate model appears to learn stable patterns of feature influence, enabling it to provide sufficiently accurate importance estimates without incurring the overhead of full SHAP computation. This balance of efficiency and effectiveness makes FastSHAP especially compelling in both small-scale and high-dimensional settings, aligning with similar findings in real-time attribution models [11].
SHAP-based methods improved performance across all tasks, but the extent of improvement varied by dataset. In classification, gradient-based methods (SHAPG, FastSHAP) achieved the highest accuracy on the Breast Cancer dataset, while only gradient adjustment with exact SHAP values (SHAPG) surpassed the Adam-only method in accuracy on the Adult Census dataset. In regression, the SHAP gradient methods again led in RMSE and R2 on the Ames Housing dataset, while the methods utilizing CSHAP (pure CSHAP and the hybrid method) delivered the lowest test loss and RMSE on the California Housing dataset. These results suggest that methods leveraging either surrogate-based approximations or centroid-based aggregation can match or outperform exact SHAP-based approaches at a fraction of the cost, particularly when task complexity increases. This agrees closely with the work of Figueroa Baraza et al. and with our previous work, both of which indicate that exploiting feature importance on high-complexity datasets can be advantageous [8,19].
A broader pattern emerged: methods like CSHAP and FastSHAP that approximate SHAP values through clustering or learned surrogates provided the best trade-offs between performance and scalability. Their consistent success across tasks—especially on larger datasets like California Housing—demonstrates that effective approximation techniques can retain the benefits of SHAP-guided optimization while vastly improving feasibility. The fixed SHAP update interval of 20 epochs used throughout this study offered a reasonable balance between responsiveness and training overhead. Nevertheless, future work may explore tuning this interval to adapt to task-specific learning dynamics.
While SHAPG remains the most accurate in certain tasks, its high computational demand limits its use in time-sensitive or resource-constrained applications. The observed trade-offs highlight a core principle: there is measurable value in computing exact SHAP attributions, particularly for capturing sharp gradients or highly localized influence—but when approximation methods are carefully designed, they can achieve near-equivalent generalization at a far lower cost.
Importantly, the effectiveness of attribution-informed training appears to be dataset-dependent. These methods produced the most pronounced benefits in high-dimensional datasets with complex feature relationships—such as Adult Census, where gradient-informed methods improved accuracy over Adam, and California Housing, where both CSHAP and gradient altering methods outperformed the baseline in R2. In contrast, simpler datasets with less feature variability, like Breast Cancer, showed narrower margins of improvement. This suggests that SHAP-informed training is particularly useful in tasks where capturing fine-grained feature relevance is critical to generalization. Prior work in feature selection for neural networks supports this finding: interpretability-based methods tend to offer more predictive benefit when data complexity necessitates more selective learning [22].
Future work should explore strategies for further reducing computational overhead while maintaining performance. Adaptive SHAP update intervals may reduce unnecessary computation in early or late training phases. Although we evaluate SHAP-informed methods across four datasets, future work should provide finer-grained comparisons of computational cost against traditional optimizers on a per-sample or per-epoch basis, particularly as dataset size and dimensionality vary. Exploring integrations with alternative optimizers like RMSprop or AdaGrad may reveal new synergies. Although our prior work investigated SHAP-guided training with SGD, those methods yielded a lower performance than Adam-based counterparts in most settings, and were therefore excluded from this study to maintain clarity of scope. Nevertheless, evaluating the proposed C-SHAP and FastSHAP methods with optimizers such as SGD or L-BFGS could offer additional insight into their generality and convergence behavior under alternative optimization dynamics—particularly since quasi-Newton methods like L-BFGS leverage curvature information, which may interact favorably with SHAP-informed gradients due to their emphasis on feature-wise influence. Finally, extending these methods to more complex neural architectures or non-tabular domains will help test their generalizability beyond the scope of this work. Indeed, recent work on Shapley explanations (e.g., CF-SHAP, FF-SHAP) strives to increase efficiency while preserving explanatory power, which could serve as an avenue for expansion [23,24].
While this study focuses primarily on improving training performance through SHAP-informed optimization, we acknowledge that real-world deployment presents additional challenges and opportunities. In production ML pipelines, computational overhead, inference time latency, and model transparency are critical concerns. The SHAPG and FastSHAP methods, which apply gradient-level adjustments, may be suitable for online learning scenarios or retraining workflows, while C-SHAP’s efficiency via centroid approximations makes it more attractive for batch learning in resource-constrained settings. Future work should explore how these methods integrate with industry-standard pipelines, including AutoML frameworks, MLOps tools, and edge deployment contexts where explainability and efficiency must be jointly optimized.
In conclusion, this study demonstrates that SHAP-informed optimization can notably improve neural network performance across a range of tasks. Scalable approaches like CSHAP and FastSHAP offer the promise of matching or exceeding the performance of exact attribution-based methods while being far more computationally efficient. These findings establish a strong foundation for future research into principled, attribution-guided training strategies—particularly in settings where both generalization and scalability are required.

Author Contributions

Conceptualization, J.G.; Formal Analysis, J.G.; Investigation, J.G.; Methodology, V.S.S.; Project Administration, V.S.S.; Software, J.G.; Writing—Original Draft, J.G.; Writing—Review and Editing, J.G. and V.S.S. All authors have read and agreed to the published version of the manuscript.

Funding

This research received no external funding.

Data Availability Statement

The datasets used in this study are publicly accessible through the following sources: Breast Cancer Dataset: Available via the Scikit-learn library (load_breast_cancer function in Python). Documentation and access details: https://scikit-learn.org/stable/modules/generated/sklearn.datasets.load_breast_cancer.html (accessed on 13 April 2025). Adult Census Income Dataset: Available from OpenML (dataset ID: 1590) and accessible via the Scikit-learn library (fetch_openml(name = ‘adult’, version = 2) in Python). Documentation: https://www.openml.org/d/1590 (accessed on 17 April 2025). Ames Housing Dataset: Available from OpenML and accessible via the Scikit-learn library (fetch_openml(name = ‘house_prices’, as_frame = True) in Python). Documentation: https://www.openml.org/d/42165 (accessed on 16 April 2025). California Housing Dataset: Available via the Scikit-learn library (fetch_california_housing function in Python). Documentation: https://scikit-learn.org/stable/modules/generated/sklearn.datasets.fetch_california_housing.html (accessed on 21 April 2025). All datasets are publicly available and widely used for benchmarking in machine learning research.

Conflicts of Interest

The authors declare no conflicts of interest.

References

  1. Jarrahi, M.H.; Memariani, A.; Guha, S. The Principles of Data-Centric AI (DCAI). Commun. ACM 2023, 66, 84–92. [Google Scholar] [CrossRef]
  2. Wilson, A.C.; Roelofs, R.; Stern, M.; Srebro, N.; Recht, B. The Marginal Value of Adaptive Gradient Methods in Machine Learning. arXiv 2017, arXiv:1705.08292. [Google Scholar]
  3. Zou, D.; Cao, Y.; Li, Y.; Gu, Q. Understanding the Generalization of Adam in Learning Neural Networks with Proper Regularization. In Proceedings of the International Conference on Learning Representations (ICLR), Kigali, Rwanda, 1–5 May 2023. [Google Scholar]
  4. Duchi, J.; Hazan, E.; Singer, Y. Adaptive subgradient methods for online learning and stochastic optimization. J. Mach. Learn. Res. 2011, 12, 2121–2159. [Google Scholar]
  5. Ross, A.S.; Hughes, M.C.; Doshi-Velez, F. Right for the Right Reasons: Training Differentiable Models by Constraining Their Explanations. In Proceedings of the 26th International Joint Conference on Artificial Intelligence, Melbourne, Australia, 19–25 August 2017; pp. 2662–2668. [Google Scholar]
  6. Lundberg, S.M.; Lee, S.-I. A Unified Approach to Interpreting Model Predictions. Adv. Neural Inf. Process. Syst. (NeurIPS) 2017, 30, 4765–4774. [Google Scholar]
  7. Adadi, A.; Berrada, M. Peeking Inside the Black-Box: A Survey on Explainable Artificial Intelligence (XAI). IEEE Access 2018, 6, 52138–52160. [Google Scholar] [CrossRef]
  8. Graham, J.; Sheng, V. SHAP-Informed Neural Network Optimization. Mathematics 2024, 12, 456. [Google Scholar]
  9. Hamilton, R.I.; Papadopoulos, P.N. Using SHAP Values and Machine Learning to Understand Trends in the Transient Stability Limit. IEEE Trans. Power Syst. 2024, 39, 1384–1397. [Google Scholar] [CrossRef]
  10. Ranjbaran, G.; Recupero, D.R.; Roy, C.K.; Schneider, K.A. C-SHAP: A Hybrid Method for Fast and Efficient Interpretability. Appl. Sci. 2025, 15, 672. [Google Scholar] [CrossRef]
  11. Jethani, N.; Covert, I.; Lee, S.-I.; Lundberg, S.M. FastSHAP: Real-Time Shapley Value Estimation. In Proceedings of the 10th International Conference on Learning Representations (ICLR 2022), Virtual, 25–29 April 2022. [Google Scholar]
  12. Kingma, D.P.; Ba, J. Adam: A Method for Stochastic Optimization. In Proceedings of the International Conference on Learning Representations (ICLR), San Diego, CA, USA, 7–9 May 2015. [Google Scholar]
  13. Bottou, L. Large-Scale Machine Learning with Stochastic Gradient Descent. In Proceedings of the COMPSTAT’2010: 19th International Conference on Computational Statistics, Paris, France, 22–27 August 2010; pp. 177–186. [Google Scholar]
  14. Kabiri, H.; Ghanou, Y.; Khalifi, H.; Casalino, G. AMAdam: Adaptive modifier of Adam method. Knowl. Inf. Syst. 2024, 66, 3427–3458. [Google Scholar] [CrossRef]
  15. Huang, H.; Wang, C.; Dong, B. Nostalgic Adam: Weighting more of the past gradients when designing the adaptive learning rate. In Proceedings of the 28th International Joint Conference on Artificial Intelligence (IJCAI) 2019, Macao, China, 10–16 August 2019; pp. 2556–2562. [Google Scholar]
  16. Luo, L.; Xiong, Y.; Liu, Y.; Sun, X. Adaptive Gradient Methods with Dynamic Bound of Learning Rate. In Proceedings of the International Conference on Learning Representations (ICLR 2019), New Orleans, LA, USA, 6–9 May 2019. [Google Scholar]
  17. Ruder, S. An Overview of Gradient Descent Optimization Algorithms. arXiv 2016, arXiv:1609.04747. [Google Scholar]
  18. Wang, X.; Magnússon, S.; Johansson, M. On the Convergence of Step Decay Step-Size for Stochastic Optimization. Adv. Neural Inf. Process. Syst. 2021, 34, 14226–14238. [Google Scholar]
  19. Figueroa Barraza, J.; López Droguett, E.; Martins, M.R. Towards Interpretable Deep Learning: A Feature Selection Framework for Prognostics and Health Management Using Deep Neural Networks. Sensors 2021, 21, 5888. [Google Scholar] [CrossRef] [PubMed]
  20. Le, T.T.H.; Kim, H.; Kang, H.; Kim, H. Classification and explanation for intrusion detection system based on ensemble trees and SHAP method. Sensors 2022, 22, 1154. [Google Scholar] [CrossRef] [PubMed]
  21. Marcílio, W.E.; Eler, D.M. From explanations to feature selection: Assessing SHAP values as feature selection mechanism. In Proceedings of the 2020 33rd SIBGRAPI Conference on Graphics, Patterns and Images (SIBGRAPI), Porto de Galinhas, Brazil, 7–10 November 2020; pp. 340–347. [Google Scholar] [CrossRef]
  22. Potharlanka, J.L.; Bhat, N.M. Feature importance feedback with Deep Q process in ensemble-based metaheuristic feature selection algorithms. Sci. Rep. 2024, 14, 2923. [Google Scholar] [CrossRef] [PubMed]
  23. Albini, E.; Long, J.; Dervovic, D.; Magazzeni, D. Counterfactual Shapley Additive Explanations. In Proceedings of the 2022 ACM Conference on Fairness, Accountability, and Transparency (FAccT ’22), Seoul, Republic of Korea, 21–24 June 2022; ACM: New York, NY, USA, 2022; pp. 1054–1070. [Google Scholar]
  24. Alkhatib, A.; Boström, H. Fast Approximation of Shapley Values with Limited Data. In Proceedings of the 14th Scandinavian Conference on Artificial Intelligence SCAI 2024, Jönköping, Sweden, 10–11 June 2024; pp. 95–100. [Google Scholar]
Figure 1. Overview of the full training and evaluation workflow. The diagram illustrates key stages of the SHAP-informed training process, including data preparation, model initialization, SHAP value computation, learning rate or gradient adjustment, and evaluation. The outer blocks represent grid search over hyperparameters and repeated trials across randomized data splits. SHAP, SHAPG, CSHAP, FastSHAP, and hybrid logic are implemented across separate scripts.
Table 1. Summary of training, SHAP update, and testing complexity for each optimization method.
Method | Training Complexity | SHAP Update Overhead | Test Complexity
Adam only | O(E × N × d) | None | O(d)
SHAP | O(E × N × d) | O(k × l × d + k × d²) | O(d)
SHAPG | O(E × N × d) | O(k × l × d + k × d²) | O(d)
CSHAP | O(E × N × d) | O(k × l × d + k × d²) | O(d)
FastSHAP | O(E × N × d) + O(E′ × B × d) | O(B × d) | O(d)
Table 2. Summary of average performance metrics across SHAP-informed optimization methods on the Breast Cancer dataset. Each method’s results include the mean and standard deviation of test loss, accuracy, and training time across 10 trials. The comparison includes SHAP-based learning rate scaling (SHAP, CSHAP), SHAP-informed gradient adjustments (SHAPG, FastSHAP), and the hybrid method combining both approaches. Adam-only training is used as the baseline.
Method | Test Loss | Accuracy | Training Time
Adam only | 0.0033 ± 0.0018 | 0.9686 ± 0.0165 | 1.9019 ± 0.4432
SHAP | 0.0052 ± 0.0045 | 0.9721 ± 0.0107 | 16.9329 ± 9.6288
SHAPG | 0.0041 ± 0.0028 | 0.9756 ± 0.0168 | 35.6517 ± 13.3737
CSHAP | 0.0118 ± 0.0240 | 0.9709 ± 0.0094 | 2.1851 ± 0.3981
FastSHAP | 0.0066 ± 0.0073 | 0.9721 ± 0.0182 | 1.3071 ± 0.8928
Hybrid | 0.0028 ± 0.0015 | 0.9756 ± 0.0121 | 3.0419 ± 1.1327
Bolded values represent the top-performing method for each metric; italicized values denote the second-best.
Table 3. Summary of average performance metrics across SHAP-informed optimization methods on the Adult Census dataset. Each method’s results include the mean and standard deviation of test loss, accuracy, and training time across 10 trials. The methods compared include SHAP-based learning rate scaling (SHAP, CSHAP), SHAP-informed gradient adjustment (SHAPG, FastSHAP), and a hybrid strategy combining both gradient and learning rate scaling. Adam-only training serves as the baseline.
Method | Test Loss | Accuracy | Training Time
Adam only | 0.01160 ± 0.00008 | 0.83019 ± 0.00270 | 147.4386 ± 22.3537
SHAP | 0.01162 ± 0.00011 | 0.82970 ± 0.00171 | 224.6501 ± 59.8349
SHAPG | 0.01153 ± 0.00014 | 0.83109 ± 0.00385 | 287.9111 ± 51.1131
CSHAP | 0.01161 ± 0.00019 | 0.82851 ± 0.00476 | 135.2455 ± 27.9485
FastSHAP | 0.01167 ± 0.00018 | 0.82919 ± 0.00378 | 129.8111 ± 35.3003
Hybrid | 0.01159 ± 0.00022 | 0.82921 ± 0.00534 | 194.4226 ± 49.8023
Bolded values represent the top-performing method for each metric; italicized values denote the second-best.
Table 4. Summary of average performance metrics across SHAP-informed optimization methods on the Ames Housing dataset. Each method’s results include the mean and standard deviation of test loss, RMSE, R², and training time. Methods compared include SHAP, SHAPG, CSHAP, FastSHAP, and a hybrid strategy, with Adam-only training serving as the baseline.
Method | Test Loss | RMSE | R² | Training Time
Adam only | 0.00761 ± 0.00450 | 0.0830 ± 0.0267 | 0.7255 ± 0.1768 | 1.5512 ± 0.6533
SHAP | 0.00482 ± 0.00408 | 0.0651 ± 0.0242 | 0.7373 ± 0.2477 | 43.1751 ± 9.3204
SHAPG | 0.00399 ± 0.00143 | 0.0621 ± 0.0115 | 0.7799 ± 0.0946 | 57.8169 ± 21.6683
CSHAP | 0.00631 ± 0.00324 | 0.0769 ± 0.0200 | 0.7470 ± 0.1009 | 2.0344 ± 0.9757
FastSHAP | 0.00477 ± 0.00204 | 0.0675 ± 0.0147 | 0.7644 ± 0.0935 | 2.6061 ± 0.8900
Hybrid | 0.00719 ± 0.00487 | 0.0804 ± 0.0269 | 0.6210 ± 0.2354 | 4.2156 ± 1.5527
Bolded values represent the top-performing method for each metric; italicized values denote the second-best.
Table 5. Average performance metrics from the grid search for each method on the California Housing dataset, including average test loss, RMSE, R², and training time along with their respective standard deviations. Each row corresponds to a SHAP-informed optimization method or the Adam-only baseline.
Method | Test Loss | RMSE | R² | Training Time
Adam only | 0.00818 ± 0.00025 | 0.09045 ± 0.00138 | 0.8032 ± 0.0069 | 43.3565 ± 5.2618
SHAP | 0.00813 ± 0.00037 | 0.09012 ± 0.00206 | 0.8016 ± 0.0097 | 129.8534 ± 25.2650
SHAPG | 0.00808 ± 0.00044 | 0.08984 ± 0.00244 | 0.8089 ± 0.0101 | 129.2503 ± 30.2712
CSHAP | 0.00803 ± 0.00035 | 0.08960 ± 0.00198 | 0.8085 ± 0.0070 | 54.7930 ± 16.6717
FastSHAP | 0.00812 ± 0.00030 | 0.09007 ± 0.00165 | 0.8066 ± 0.0046 | 63.9276 ± 11.1408
Hybrid | 0.00802 ± 0.00029 | 0.08954 ± 0.00164 | 0.8086 ± 0.0082 | 68.6670 ± 15.8209
Bolded values represent the top-performing method for each metric; italicized values denote the second-best.
Disclaimer/Publisher’s Note: The statements, opinions and data contained in all publications are solely those of the individual author(s) and contributor(s) and not of MDPI and/or the editor(s). MDPI and/or the editor(s) disclaim responsibility for any injury to people or property resulting from any ideas, methods, instructions or products referred to in the content.
