1. Introduction
Trust is an essential component of the adoption and utilization of new artificial intelligence (AI) tools [1]. Trust is especially important when AI attempts to replace a non-black-box technology but no mechanism is offered for determining accountability or mitigating errors and bias. Therefore, developing explainable AI techniques and methods, especially those that are highly interpretable, is beneficial to user adoption [2,3].
The most advanced explainable AI techniques are typically designed for image classification tasks because the continuous values and ready interpretability of image subsets lend themselves to the two major types of explainable AI techniques. The first of these uses the model’s activations, or backpropagated signals, to determine the most important features. Techniques that use the activation signals include the class activation map (CAM) [4], which uses global average pooling; techniques like Grad-CAM [5], Zeiler and Fergus [6], and Wagner et al. [7], which use gradients; and combinations of activations and gradients, like Grad-CAM++ [8], among others. These methods have been most successfully applied to convolutional neural networks and fully connected networks. While convolutional neural networks can be useful for text classification, recurrent neural networks and transformer architectures are more frequently used for text; their interactions are substantially more complicated, and applying these methods to them can produce maps of disconnected, important tokens.
The second major type of explainable AI technique is perturbation-based, where small changes are made to the inputs to identify which features have the highest impact on the final output. Shapley additive explanation (SHAP) values [9] and local interpretable model-agnostic explanations (LIME) [10] fall into this category. Both of these methods are computationally intensive and incur most of the computational cost at runtime, rather than upfront.
Unlike either of these two major types of explainable AI techniques, the “what-you-see” algorithm conceived by Stalder et al. [11] works by creating a new neural network—the Explainer—that generates a mask for an input image. This mask blocks out all the irrelevant parts of the image, and the masked image is passed through the original AI model that needs explaining—the Explanandum. This explainable AI method incurs all the computational costs when training the Explainer. Then, an image need only be passed through the trained Explainer to generate an explanation (i.e., the unmasked parts of the image). Additionally, the method uses a loss function that values continuous regions, so the masks themselves are easily interpreted.
Neither this technique nor many of the existing explainable AI techniques can be applied outright to tokenized data, such as text or nucleotide/amino acid sequences; for example, perturbing a token is not meaningful because tokens are categorical. An embedding can be perturbed, but that reveals more about the components of the embedding vector than about the tokens that preceded the embeddings. The most common explainable AI algorithms for tokenized data are SHAP values [9] and LIME [10], but the results are often difficult to interpret because tokens are evaluated one at a time. Text-based transformer and recurrent neural network architectures, which outperform CNN-based architectures, do not consider the local structure of the inputs, so important tokens may be separated by large distances. Additionally, because the importance values range from −1 to 1, a large number of small positive values can overrule a single large negative value, or vice versa, obscuring the true importance landscape. Rationalization approaches [12,13,14], which use models to either select or generate text to explain the classification, are somewhat less common and tend to require complicated model training processes.
We propose a novel approach to explainable AI for text that combines the extractive capabilities of rationalization approaches with the simplicity of training in Stalder et al. [11] to provide meaningful explanations for classification decisions on tokenized sequences. We demonstrate our method on a genomic taxonomic classification task at three nested levels of classification—superkingdom, phylum, and genus—using the Nucleotide Transformer 50 million parameter model (NT50m) [15] and the BERTax dataset [16].
2. Related Studies
Explainable AI techniques for tokenized sequences are relatively limited compared to those for other input types. The majority rely on methods like LIME and SHAP, or derive from them. These include the application of SHAP or SHAP-based methods [17] to sentiment analysis for language models and large language models (LLMs) [18]; LIME for hate speech detection using LSTMs [19] and an XLM-RoBERTa model [20]; and LIME for topic classification using transformer architectures [21]. In some instances, both LIME and SHAP are used together to explain the classification of AI-generated versus human-generated text, where the classification decision is issued by models with transformer architectures [22,23].
Gradient-based approaches such as AGrad [24] and RePAGrad [24] have been used to explain review-based datasets, including Yelp, Amazon, and IMDB reviews. However, AGrad and RePAGrad rely on attention architectures that are specific to transformers and are thus not generally applicable.
Additionally, rationalization approaches are those in which a model produces an explanation for the classification decision. This can be accomplished with a generative model producing explanations [25]. However, like all generative approaches, there is a risk of hallucinations and of explanations that are not tied to the input. Among these, supervised approaches extract the relevant portions of text using a generative model, which is trained using reinforcement learning to select sequences, together with a classifier to evaluate those selected sequences [12]. These approaches are difficult to train and unstable. The training can be simplified by using re-parameterized gradients [13] or by using importance scores from an approach like LIME to create an extraction model [14].
3. Methods
In the Stalder et al. [11] approach, as in our approach (Figure 1), the input is fed to the Explainer, which provides a set of masks M = {m_1, …, m_K}, one for each of the K classes. Each value of a mask ranges from [0, 1], having passed through a sigmoid function. The masks that correspond to a ground truth target of 1 are aggregated using a max function to form the target mask, m_t. The masks that correspond to a ground truth target of 0 are aggregated using a max function to form the non-target mask, m_nt. The target mask also has a complement, 1 − m_t. The mask is then used to cover up parts of the input as it is fed into the Explanandum.
The approach proposed in this paper is very similar to that of Stalder et al. [11], with two slight modifications, the first of which is depicted in Figure 1 and pertains to how the tokens are masked. Because the pixel values in images are continuous, it is trivial to take the Hadamard product of the mask of real values against the image and pass the result to the Explanandum. The product of the mask and the image is a masked image, where masked values nearer to zero are ‘darker’. Because the mask values are continuous, the entire process of passing an input image through the Explainer, then its output through the Explanandum, is differentiable. As a result, the weights of the Explainer can be updated while the Explanandum weights are frozen.
However, directly masking a token can only be done with a zero or a one: the product of a mask value of 0.25 and a discrete token is not a valid operation. One way around this is to use a specific [mask] token represented by the zero vector. There are two problems with this. The first is that not all models have this built into the model architecture, making it difficult to implement. The second problem is that, if using the Explainer–Explanandum paradigm, some type of step function would be needed to force a zero or one, which can make learning difficult. We overcome this challenge by changing where the Hadamard product is taken. In our approach, we separate the Explanandum into pre- and post-embedding layers and take the Hadamard product with the embedded tokens after they pass through the pre-embedding layers, rather than applying the mask to the tokens themselves. The masks for images and for tokenized text are the same shape as the input, so in order to apply this mask to embedded tokens, the mask vector is repeated, expanding it from a vector of length L, where L is the length of the sequence, to a matrix of dimensions L × d, where d is the dimension of the embedding vector. The Hadamard product of this repeated mask is then taken against the embedding vectors. This operation changes the magnitude of each embedding vector without changing its orientation in the embedding space (Figure 2). The masking operation is then analogous to that for an image, preserving differentiability.
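Below is a minimal PyTorch sketch of this masking step; the function and tensor names are illustrative rather than taken verbatim from our implementation.

```python
import torch

def apply_mask_to_embeddings(embeddings: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # embeddings: (batch, seq_len, embed_dim) output of the Explanandum's embedding layers
    # mask:       (batch, seq_len) Explainer output in [0, 1] (post-sigmoid)
    # Repeat each token's mask value across the embedding dimension,
    # expanding (batch, seq_len) -> (batch, seq_len, embed_dim).
    expanded = mask.unsqueeze(-1).expand_as(embeddings)
    # The Hadamard product scales each token embedding's magnitude
    # without changing its direction in the embedding space.
    return embeddings * expanded
```

Because the operation is a simple elementwise product, gradients flow from the Explanandum's loss back to the Explainer's mask values.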
This approach works with any architecture, with one caveat: the softmax outputs for a blank sequence and for a sequence masked in its entirety should both be approximately uniform. Conceptually, this makes sense, since inputting nothing should not result in any classification. We found that this was not always the case with all architectures, specifically in transformer architectures when the classifier uses mean pooling across all the tokens, as opposed to just the [CLS] token output. With NT50m, mean pooling is taken over the tokens prior to the fully connected layers for classification, meaning an empty sequence or a completely masked sequence did not result in a uniform output of the Explanandum classifier. To fix this, we found it necessary to take the Hadamard product of the mask with the embeddings and, additionally, with the output of the transformer prior to mean pooling and classification. The difference in architecture is shown in Figure 3. This is not necessary when using the [CLS] output token for classification.
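The modified forward pass can be sketched as follows; embed, encoder, and classifier are generic stand-ins for the pre-embedding layers, the transformer body, and the classification heads of the Explanandum, and the pooling denominator is one reasonable choice rather than the exact NT50m implementation.

```python
import torch
import torch.nn as nn

def masked_forward(embed: nn.Module, encoder: nn.Module, classifier: nn.Module,
                   input_ids: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Explanandum forward pass with the Explainer's mask applied twice:
    # once to the token embeddings and once to the encoder outputs before pooling.
    emb = embed(input_ids)                      # (B, L, d) pre-embedding layers
    emb = emb * mask.unsqueeze(-1)              # first Hadamard product
    hidden = encoder(emb)                       # (B, L, d) token-level hidden states
    hidden = hidden * mask.unsqueeze(-1)        # second Hadamard product, before mean pooling
    # Pool so that fully masked tokens contribute nothing; the clamp guards
    # against division by zero when the entire sequence is masked.
    pooled = hidden.sum(dim=1) / mask.sum(dim=1, keepdim=True).clamp(min=1e-6)
    return classifier(pooled)                   # (B, num_classes) logits
```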
The second modification to adapt the what-you-see model to what-you-read functionality involves the loss function. This modification is the simple flattening of the formulation from two dimensions (for images) to one dimension (for text). More detail can be found in Stalder et al. [11]. The new loss function (Equation (1)) has the same four components with the same hyperparameters. In Equation (1), x is the input sequence, y is the class(es) the input x belongs to, m_c is the mask for a given class c, M is the set of all masks, m_t and 1 − m_t are the target mask and its complement, and m_nt is formed from the masks of the non-correct classes. For our experiment, the three hyperparameters weighting the loss components are all set to 1. All hyperparameter values were selected based on the values from Stalder et al. [11].
The loss term L_cls is the standard classification loss using cross-entropy (Equation (2)), to accommodate multi-class classification. The bracket in Equation (2) is the Iverson bracket, which returns 1 if the input x belongs to class c and 0 otherwise.
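A sketch of this classification term follows, assuming the Explanandum returns class probabilities for the target-masked input and that the ground truth is encoded as a 0/1 (Iverson-bracket) matrix; it illustrates the form of Equation (2) rather than reproducing our exact code.

```python
import torch

def classification_loss(probs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    # probs:   (batch, num_classes) Explanandum probabilities for the target-masked input
    # targets: (batch, num_classes) 0/1 Iverson-bracket ground truth
    eps = 1e-8
    return -(targets * torch.log(probs + eps)).sum(dim=1).mean()
```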
The loss term L_E (Equation (3)) is the negative entropy of the classification probabilities obtained when the complement of the mask, 1 − m_t, is applied. When the complement of the mask is applied, only the least important tokens are exposed; maximizing the entropy of the downstream classification probabilities is then equivalent to maximizing the uncertainty, pushing the probabilities as near to uniform as possible when important tokens are masked. Text that has been stripped of relevant information should not produce a meaningful classification result. In Equation (3), Φ denotes the Explanandum, which outputs a set of probabilities.
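A sketch of this entropy term is shown below, assuming probs_complement holds the Explanandum's probabilities for the input masked with 1 − m_t; minimizing the returned value maximizes the entropy, as described above.

```python
import torch

def negative_entropy_loss(probs_complement: torch.Tensor) -> torch.Tensor:
    # probs_complement: (batch, num_classes) Explanandum probabilities when the
    # complement mask (1 - m_t) is applied, exposing only the least important tokens.
    eps = 1e-8
    entropy = -(probs_complement * torch.log(probs_complement + eps)).sum(dim=1)
    return -entropy.mean()  # minimizing this maximizes classification uncertainty
```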
As explained in Stalder et al. [11], the first two loss terms of Equation (1), L_cls and L_E, do not directly impact the mask itself. The first loss term, L_cls, incentivizes the model to create a mask that does not hide anything, while the second loss term, L_E, incentivizes the model to create a mask that hides everything. The loss term L_A (Equation (4)) is the area loss term, which introduces some balance in the mask generation such that the mask neither obscures nor reveals too much. The first two terms in L_A are the mean value of the target mask, (1/L) Σ_i m_t,i, where L is the length of the sequence, and the mean value of the non-target mask, (1/L) Σ_i m_nt,i.
The third term in L_A (Equation (5)) regulates how much of the sequence is covered, penalizing masks that cover too much or too little. It does this by first sorting the values of the class segmentation mask m_c. A defined fractional minimum area a_min and fractional maximum area a_max, with a_min < a_max, determine at what point a penalty is incurred: any mask with more area visible than a_max covers too much, and any mask with less area visible than a_min covers too little. Two threshold vectors are then defined from a_min and a_max, and the bounding measure in Equation (5) compares the sorted mask values against them. There is only a penalty if the visible mask area is smaller than a_min or larger than a_max.
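The following sketch illustrates our reading of the area term: the two mean-value components plus a bound on the sorted mask values. The default a_min and a_max values are placeholders, and the exact construction of the threshold vectors in Stalder et al. [11] may differ.

```python
import torch

def area_loss(m_t: torch.Tensor, m_nt: torch.Tensor,
              a_min: float = 0.1, a_max: float = 0.5) -> torch.Tensor:
    # m_t, m_nt: (batch, seq_len) target and non-target masks in [0, 1]
    seq_len = m_t.shape[1]
    mean_terms = m_t.mean(dim=1) + m_nt.mean(dim=1)

    # Bound term: sort mask values in descending order and compare against a
    # profile that requires at least a_min*L and at most a_max*L visible tokens.
    sorted_mask, _ = torch.sort(m_t, dim=1, descending=True)
    positions = torch.arange(seq_len, device=m_t.device, dtype=m_t.dtype) / seq_len
    v_min = (positions < a_min).to(m_t.dtype)     # positions that must be visible
    v_max = (positions < a_max).to(m_t.dtype)     # positions that may be visible
    too_little = torch.relu(v_min - sorted_mask)  # penalized if fewer than a_min*L are exposed
    too_much = torch.relu(sorted_mask - v_max)    # penalized if more than a_max*L are exposed
    bound = (too_little + too_much).mean(dim=1)

    return (mean_terms + bound).mean()
```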
The fourth loss term, L_TV (Equation (6)), encourages local smoothness amongst neighboring tokens, expressed as a total variation loss on the target and non-target masks.
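A sketch of the one-dimensional total variation term, together with the combined objective of Equation (1) with all three weights set to 1, is given below; the way the weights enter is our assumption based on the description above.

```python
import torch

def total_variation_loss(m_t: torch.Tensor, m_nt: torch.Tensor) -> torch.Tensor:
    # 1D total variation: absolute differences between neighboring mask values.
    tv_t = (m_t[:, 1:] - m_t[:, :-1]).abs().mean(dim=1)
    tv_nt = (m_nt[:, 1:] - m_nt[:, :-1]).abs().mean(dim=1)
    return (tv_t + tv_nt).mean()

# Combined Explainer objective corresponding to Equation (1), with the three
# weighting hyperparameters set to 1 as in our experiments:
# loss = classification_loss(...) + negative_entropy_loss(...) \
#        + area_loss(...) + total_variation_loss(...)
```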
3.1. Explainer
The architecture of the Explainer is a relatively straightforward recurrent neural network (Figure 4). As sequences are variable in length, the core of the Explainer model is a bidirectional Long Short-Term Memory (LSTM) network with two layers and 40 nodes, followed by a rectified linear unit (ReLU) activation function. The input to the LSTM is an embedding layer with a vocabulary of 4107, equal to the vocabulary of the NT50m model, and an embedding dimension of one hundred. The bidirectional LSTM layers produce a vector for every input token, each with a length of 160 after the memory and cell states of the forward and backward passes are concatenated. Following batch normalization, a dense layer with a sigmoid activation function is applied, with output length equal to the number of classes, K. The final output shape for a sequence of length L is an L × K matrix. No hyperparameter search was performed for the Explainer design; an LSTM node size of 40 was chosen to yield a concatenated state size of 160, approximately 10% of the largest number of classes, while balancing resource constraints.
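An illustrative PyTorch version of the Explainer is sketched below. Because nn.LSTM does not expose per-token cell states, the 160-dimensional per-token features are approximated here by concatenating the outputs of two stacked bidirectional LSTM layers; this detail is an assumption, not the exact released architecture.

```python
import torch
import torch.nn as nn

class Explainer(nn.Module):
    """Bidirectional-LSTM Explainer producing one mask value per token per class."""

    def __init__(self, vocab_size: int = 4107, embed_dim: int = 100,
                 hidden_size: int = 40, num_classes: int = 1878):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # Two stacked bidirectional LSTM layers; their per-token outputs are
        # concatenated to reach 160 features per token.
        self.lstm1 = nn.LSTM(embed_dim, hidden_size, batch_first=True, bidirectional=True)
        self.lstm2 = nn.LSTM(2 * hidden_size, hidden_size, batch_first=True, bidirectional=True)
        self.relu = nn.ReLU()
        self.norm = nn.BatchNorm1d(4 * hidden_size)
        self.head = nn.Linear(4 * hidden_size, num_classes)  # e.g., 1878 classes for genus

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        emb = self.embedding(token_ids)                       # (B, L, 100)
        out1, _ = self.lstm1(emb)                             # (B, L, 80)
        out2, _ = self.lstm2(out1)                            # (B, L, 80)
        feats = self.relu(torch.cat([out1, out2], dim=-1))    # (B, L, 160)
        feats = self.norm(feats.transpose(1, 2)).transpose(1, 2)
        return torch.sigmoid(self.head(feats))                # (B, L, K) masks in [0, 1]
```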
3.2. Experiment
3.2.1. Taxonomic Data
The Explanandum in this experiment is a finetuned genomic language model (gLM) trained to perform taxonomic classification of DNA sequences, which we introduced in an earlier paper [26]. The data was obtained from the open-access dataset published as part of the Supplemental Material of Mock et al. [16]: https://osf.io/dwkte (accessed on 1 October 2024) [27]. All sequences are 1500 base pairs in length, standardized to all uppercase letters, and belong to four superkingdom categories: Archaea, Viruses, Eukaryota, and Bacteria. As the NT50m models use a 6-mer tokenizer (six nucleotides grouped together, e.g., ACTGCC), the input sequences are expected to be 250 tokens, less than the NT50m maximum token length of 1000 tokens. Phylum and genus labels are retrieved from the Environment for Tree Exploration (ETE) toolkit using the ncbi_taxonomy module. We used the version of the publicly available taxdump archive from December 2024 (https://ftp.ncbi.nlm.nih.gov/pub/taxonomy/taxdump_archive/ (accessed on 12 December 2024) [28]). Only sequences with both a phylum and a genus label are included in our analysis.
The complete dataset after processing includes 5,181,880 sequences spanning 4 superkingdoms: 2,601,890 from Eukaryota, 1,828,018 from Bacteria, 524,276 from Archaea, and 227,696 from Viruses. They are annotated with 1573 unique taxonomy IDs, including 55 phyla and 1878 genera. The test dataset is split using the scikit-learn train_test_split function [29] with the seed set to 42, such that 2% (103,638 sequences) is reserved for the hold-out test dataset. This test set is used to evaluate all experimental runs. An additional 10% of the full dataset is set aside to monitor model performance by evaluating validation loss over the course of training. The Explanandum was trained on the remainder of the dataset. The Explainer was trained on a subset of 239,391 examples randomly sampled from the training set, amounting to 4.7% of the training set, to reduce training time (20 min/epoch). The results on this subset were sufficient to render training on the full dataset unnecessary.
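A sketch of the splitting procedure follows; the sequences and labels shown are placeholders, and the validation fraction is expressed relative to the remaining 98% so that it corresponds to 10% of the full dataset.

```python
from sklearn.model_selection import train_test_split

# Placeholder data: 1500-bp sequences and matching taxonomy labels
sequences = ["ACGT" * 375] * 1000
labels = ["Bacteria"] * 1000

# 2% hold-out test set with a fixed seed for reproducibility
train_val_seqs, test_seqs, train_val_y, test_y = train_test_split(
    sequences, labels, test_size=0.02, random_state=42)

# A further 10% of the full dataset (about 10.2% of the remainder) for validation monitoring
train_seqs, val_seqs, train_y, val_y = train_test_split(
    train_val_seqs, train_val_y, test_size=0.10 / 0.98, random_state=42)
```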
3.2.2. Explanandum
The genomic language model NT50m is the backbone of the classifier. We utilize pretrained weights alongside a custom hierarchical classification layer. The classification layer consists of three classification heads, each a linear layer that maps the transformer's hidden size to the number of possible values for a given taxonomic rank: genus (1878 possibilities), phylum (55), and superkingdom (4). No hyperparameter selection was done when creating the NT50m Explanandum model, as we were implementing the pretrained model with a different classification head. The forward pass sends token IDs and attention masks through the model backbone to produce token-level hidden states, which are then averaged using mean pooling across the sequence to get a fixed-length vector of length 512. Weighting this pooling by the total number of valid (non-padded) tokens ensures that only meaningful tokens contribute to the mean embedding vector. Finally, the classification heads output logits for the three taxonomic levels (the loss function is the sum of the three taxonomic levels' cross-entropies). An example of a masked sequence is shown in Figure 5.
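The hierarchical classification head can be sketched as below; the class name, attribute names, and the assumption that the backbone returns token-level hidden states directly are illustrative rather than the released code.

```python
import torch
import torch.nn as nn

class HierarchicalClassifier(nn.Module):
    """Mean-pooled backbone representation feeding three taxonomic heads."""

    def __init__(self, backbone: nn.Module, hidden_size: int = 512):
        super().__init__()
        self.backbone = backbone
        self.genus_head = nn.Linear(hidden_size, 1878)
        self.phylum_head = nn.Linear(hidden_size, 55)
        self.superkingdom_head = nn.Linear(hidden_size, 4)

    def forward(self, input_ids, attention_mask):
        hidden = self.backbone(input_ids, attention_mask=attention_mask)  # (B, L, 512)
        # Mean pooling weighted by the attention mask so padded tokens do not contribute
        mask = attention_mask.unsqueeze(-1).float()
        pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
        return (self.superkingdom_head(pooled),
                self.phylum_head(pooled),
                self.genus_head(pooled))

# Training loss: sum of the three cross-entropies, e.g.
# loss = ce(sk_logits, sk_y) + ce(ph_logits, ph_y) + ce(g_logits, g_y)
```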
The Explanandum achieves balanced classification accuracies of 99.05%, 97.39%, and 71.65% for superkingdom, phylum, and genus, respectively, suggesting that there is a signal-rich feature set for the Explainer to find. We use balanced accuracy to ensure that each taxonomic category is weighted equally.
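For reference, balanced accuracy can be computed with scikit-learn as in the toy example below.

```python
from sklearn.metrics import balanced_accuracy_score

# Balanced accuracy averages per-class recall, so rare taxa such as Viruses
# count as much as abundant taxa such as Eukaryota.
y_true = ["Bacteria", "Viruses", "Bacteria", "Archaea"]
y_pred = ["Bacteria", "Bacteria", "Bacteria", "Archaea"]
print(balanced_accuracy_score(y_true, y_pred))  # 0.666...
```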
3.2.3. Training and Resources
Training of the Explainer and the Explanandum was done on a single NVIDIA H100 GPU with 80 GB of VRAM running Python 3.11.6 [30] with the PyTorch 2.5.1 [31] library. Both models were trained with a learning rate of 0.0002 using a cosine learning rate scheduler to adjust the learning rate throughout training. The Explanandum was trained for five epochs with a batch size of 128. The Explainer was trained with a batch size of 48 for 50 epochs; however, an early stopping mechanism ended training at 38 epochs. Training was conducted using FP32 precision, zero weight decay, and no warm-up phase. The competing interpretability method LIME was implemented using the captum 0.7.0 [32] Python package. Random seeds were set to 42 for training. Scikit-learn [29] version 1.6.1 was used for performance assessment and for the training/test split.
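A sketch of the Explainer training configuration implied by these settings is shown below; the choice of the Adam optimizer and the per-epoch scheduler stepping are assumptions, as only the learning rate, scheduler type, precision, and weight decay are specified above.

```python
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR

torch.manual_seed(42)
model = Explainer()  # Explainer as sketched in Section 3.1
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4, weight_decay=0.0)
num_epochs = 50
scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs)

for epoch in range(num_epochs):
    # ... iterate over the training subset with batch size 48 in FP32,
    #     compute the combined loss, and update the Explainer weights ...
    scheduler.step()
    # early stopping on validation loss ended training at epoch 38 in practice
```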
4. Results
The Explainer successfully learned masks that cover the input sequences without compromising the classification performance of the Explanandum (Table 1). The results include performance on the unmasked sequences, the sequences with the masks, the rounded masks, the inverted masks, the inverted rounded masks, and, finally, separated chunks. For the separated chunks, we took the rounded masks and separated the sequence into subsequences based on continuous runs of ones and zeros. Each chunk was classified individually, and the class was then aggregated by averaging over the important and unimportant subsequences. This was done to demonstrate that the residual effects of masked spacing were not significantly impacting the classification.
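A sketch of the chunking procedure follows; the function and variable names are illustrative.

```python
from itertools import groupby

def split_into_chunks(tokens, rounded_mask):
    """Split a token sequence into contiguous chunks of unmasked (1) and masked (0) tokens."""
    chunks, index = [], 0
    for mask_value, group in groupby(rounded_mask):
        length = len(list(group))
        chunks.append((mask_value, tokens[index:index + length]))
        index += length
    return chunks

# Example: important chunks are those with mask_value == 1
# chunks = split_into_chunks(token_list, mask_list)
# important = [toks for value, toks in chunks if value == 1]
```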
The unmasked classifier achieves balanced accuracies of 99.05%, 97.39%, and 71.65% for superkingdom, phylum and genus labels, respectively, while the masked classifier has comparable accuracies of 98.39%, 95.51%, and 69.88%. We confirm that the masks do not achieve this accuracy by simply exposing the entirety of all sequences: 67.28% of mask values over all tokens were greater than 0.5.
While the mask values vary within any given sequence, there is a consistent trend, with many values close to 1 and a distribution spread out among the lower values, not quite reaching zero (Figure 6A). The majority of mask values are well above 0.5, with a mean mask value of 0.73 and 67.28% greater than 0.5 across the test dataset. However, Figure 6B shows that some parts of sequences are more consistently relevant than others, as there is a consistent peak at the beginning of the sequence and two peaks in the middle. Looking at the maximum and minimum mask values at each position across test sequences nevertheless indicates substantial variation in the mask values.
The purpose of this approach was to place a premium on the continuity of the explanatory features, so it is essential to verify that it actually does so. Figure 7A shows that the vast majority of sequences have fewer than five chunks of unmasked sequence, with an average of 3.35 and about 90% having four or fewer; only 0.4% of sequences have zero masked tokens. The mean unmasked chunk length was 51.39 tokens, with some exceeding 200 tokens, as shown in Figure 7B. This indicates that the chunks that are unmasked are not trivial in length, either. This mean chunk length corresponds to about 20% of the sequence, showing that the chunks neither cover the entirety of the sequence nor consist of small, disparate parts.
Checking both the inverted masks and the rounded masks confirms that the masks are isolating relevant information. When the masks are inverted, the balanced accuracy of superkingdom, phylum, and genus classification plummets from 98.39% to 46.70%, from 95.51% to 28.21%, and from 69.88% to 5.71%, respectively. The rounded masks show a drop, but nowhere near as steep, with superkingdom dropping to 76.42%, phylum to 52.6%, and genus to 40.86%. Inverting the rounded masks gave similar results to the inverted masks, with balanced accuracies of 44.85%, 26.20%, and 4.92%. All told, the masks effectively expose relevant information for classification while obscuring irrelevant information.
To confirm that blanks and masked-out sequence segments do not play an influential role in classification, the sequences were broken into relevant and irrelevant chunks, which were then passed through the Explanandum. The relevant chunks had balanced classification accuracies of 69.98%, 49.92%, and 36.32%, versus 47.13%, 31.68%, and 9.58% for the irrelevant chunks. These performance metrics are similar to those of the masked and inverted-masked sequences.
Due to the emphasis on globality that underpins transformer-based architectures, the most popular explainable AI frameworks for text—LIME and SHAP—are inherently ill-equipped to clarify taxonomic sequence classification decisions. This, combined with the non-serializable architecture of the Evolutionary Scale Modeling (ESM)-based backbone of NT50m, made the Python packages for LIME, transformers-interpret, and SHAP prohibitively difficult to use. Runtime is an additional limiting factor. LIME took an average of 26 s per sequence, per taxonomic rank, which would amount to months if run without parallelization on the 103,638-sequence test set. In comparison, with our variant of Stalder et al. [11], the Explainer was trained in 15 h and 40 min and took twenty minutes to generate masks for the 103,638 test sequences. Even if it were feasible to run LIME at scale, the outputs would not be reliably helpful. Figure 8 shows the output of LIME for an example at the phylum level that was classified correctly. From the image, it is unclear which of the green chunks (which contributed positively to the classification) were most significant, and the presence of red (contributing negatively to the classification) only creates additional confusion. Filtering down to only the tokens that contributed positively, or even only those that contributed substantially (>0.5), leaves tokens that are widely spread out, providing little interpretability. Meanwhile, the masks demonstrated in Figure 8 are substantially more interpretable.