1. Introduction
Originally designed to solve machine translation problems [1,2], the Transformer [3,4] model has been widely introduced into computer vision (CV) [5,6,7], natural language processing (NLP) [8], speech processing [9,10,11], audio processing [12,13], chemistry [14], and life sciences [15] due to its powerful modelling capabilities and broad applicability. It has contributed significantly to the development of these fields.
In computer vision, Convolutional Neural Networks (CNNs) [16,17,18] are traditionally used as the primary means of processing. Convolution is well suited to regular, high-dimensional data and allows for automatic feature extraction. However, convolution suffers from obvious locality constraints: the underlying assumption is that a point in space is associated only with its neighbouring grid cells, while distant cells are unrelated. Although this limitation can be alleviated to some extent by enlarging the convolution kernel, it cannot be solved fundamentally. Following the introduction of the Transformer, some researchers have tried to bring its architecture into computer vision. The Transformer has a larger receptive field than a CNN, so it captures rich global information and can better understand the whole image. Ramachandran et al. [19] constructed a vision model without convolution, using a full-attention mechanism instead to overcome the locality constraint of convolution. In addition, the Transformer has shown excellent performance in other CV areas such as image classification [6,20], object detection [5,21], semantic segmentation [22], image processing [22], and video understanding [5].
Sequential data are even better suited to the Transformer than computer vision data. Traditional time series prediction mostly relies on Recurrent Neural Network (RNN) [23,24] models, among which the most influential are the Gated Recurrent Unit (GRU) [25] and Long Short-Term Memory (LSTM) [26,27] networks. For example, Mou et al. [28] proposed a Time-Aware LSTM (T-LSTM) with temporal information enhancement. Its main idea is to divide the memory state into short-term and long-term memory, adjust the influence of the short-term memory according to the time interval between inputs (the longer the interval, the smaller the influence), and then recombine the adjusted short-term memory and the long-term memory into a new memory state. However, the emergence of the Transformer soon shook the dominance of RNN-family models in time series prediction, because RNNs face the following bottlenecks when dealing with long-term prediction problems.
(1) Parallelism bottleneck: The RNN family of models requires the input data to be arranged in temporal order and computed sequentially in that order. This serial structure has the advantage of inherently encoding positional relationships, but it prevents parallel computation. Especially for long sequences, the inability to parallelise means more time and cost.
(2) Gradient bottleneck [29]: RNN networks frequently suffer from vanishing or exploding gradients during training. Most neural network models optimise their parameters by computing gradients, and vanishing or exploding gradients cause the model to fail to converge or to converge too slowly. For the RNN family, this means it is difficult to improve the model simply by increasing the number of iterations or the size of the network.
(3) Memory bottleneck: At each time step, an RNN takes a positional input ${x}_{t}$ and a hidden input ${h}_{t-1}$, which are fused inside the model according to fixed rules to produce a hidden state ${h}_{t}$. When the sequence is too long, ${h}_{t}$ retains almost nothing of the earliest inputs; that is, the “forgetting” phenomenon occurs.
Compared with the RNN family of models, the Transformer represents the positional relationships within a sequence by positional encoding, without feeding the data in recursively. This makes the model more flexible and allows maximal parallelisation over time series data. The positional encoding also ensures that no forgetting occurs: the information at each position has equal status for the Transformer. Additionally, using an attention mechanism to extract internal features allows the model to focus on important information, and ignoring irrelevant and redundant information helps avoid vanishing or exploding gradients. Based on these advantages, many scholars are now trying to apply Transformer models to time series tasks.
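For illustration, the positional encoding mentioned above can be sketched with the sinusoidal scheme of the original Transformer. The following NumPy sketch (the function name is ours, and an even model dimension is assumed) computes the encoding that is added to the token embeddings:

```python
import numpy as np

def sinusoidal_positional_encoding(length: int, d_model: int) -> np.ndarray:
    """Sinusoidal positional encoding from the original Transformer:
    PE[pos, 2i]   = sin(pos / 10000^(2i/d_model))
    PE[pos, 2i+1] = cos(pos / 10000^(2i/d_model))
    Assumes d_model is even.
    """
    pos = np.arange(length)[:, None]              # (length, 1)
    i = np.arange(0, d_model, 2)[None, :]         # (1, d_model/2)
    angle = pos / np.power(10000.0, i / d_model)  # (length, d_model/2)
    pe = np.zeros((length, d_model))
    pe[:, 0::2] = np.sin(angle)
    pe[:, 1::2] = np.cos(angle)
    return pe
```

Because the encoding depends only on position, every token carries its location explicitly, which is what makes fully parallel processing possible.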
2. Research Background
Transformer is a typical encoder–decoder-based sequence-to-sequence [30] model, and this structure is well suited to processing sequence data. Several researchers have tried to improve the Transformer to meet the needs of more complex applications. For example, Kitaev et al. [31] proposed the Reformer model, which uses Locality-Sensitive Hashing (LSH) attention to reduce the complexity of the original model from $O({L}^{2})$ to $O(L\log L)$. Zhou et al. [32] proposed the Informer model for Long Sequence Time-Series Forecasting (LSTF), which accurately captures the long-term dependence between output and input and exhibits high predictive power. Wu et al. [33] proposed the Autoformer model, which uses a deep decomposition architecture and an auto-correlation mechanism to improve LSTF accuracy. Autoformer achieves desirable results even when the prediction horizon is much longer than the input length, i.e., it can predict the longer-term future from limited information. Zhou et al. [34] proposed the FEDformer model, which applies the attention mechanism in the frequency domain and serves as an essential complement to time-domain analysis.
The Transformer variants described above focus on reducing temporal and spatial complexity, but pay little attention to the diversity of the information they capture. The attention mechanism is the core component the Transformer uses for feature extraction. It is designed to let the model focus on the more important information, which implies a certain amount of information loss. The multi-head attention mechanism can compensate for this. However, since every head performs the same attention computation, there is no guarantee that each head captures different vital features. Moreover, since the multi-head attention mechanism essentially divides the space into mutually independent subspaces, it completely cuts off the connections between them, so the information captured by the multiple heads lacks interaction. To address these problems, this paper proposes a hierarchical attention mechanism in which each layer uses a different attention mechanism to capture features. The higher layers reuse the information captured by the lower layers, thus enhancing the Transformer’s ability to perceive deeper information.
3. Research Methodology
3.1. Problem Description
Initially, the Transformer model was proposed by Vaswani et al. to solve the machine translation problem, so the Vanilla Transformer is better suited to processing textual data. Its primary processing unit is a word vector, and each word vector is called a token. In the time series prediction problem, by contrast, our basic processing unit becomes a timestamp. If we want to apply the Transformer to a time series problem, the reasonable idea is to encode the multivariate sequence information of each timestamp into a token vector. This modelling approach is also the treatment adopted by many mainstream Transformer-like models.
Here, for convenience in the subsequent description, we define the token dimension as $d$, the model’s input length as $I$, and its output length as $O$. The model’s input can then be defined as $\mathcal{X}=\{{\mathbf{x}}_{1},\cdots ,{\mathbf{x}}_{I}\}\in {\mathbb{R}}^{I\times d}$ and its output as $\widehat{\mathcal{X}}=\{{\widehat{\mathbf{x}}}_{1},\cdots ,{\widehat{\mathbf{x}}}_{O}\}\in {\mathbb{R}}^{O\times d}$. Therefore, this paper aims to learn a mapping $\mathcal{T}(\cdot)$ from the input space to the output space.
3.2. Model Architecture
Our model (Figure 1) retains the Transformer architecture in its main body, and we add a decomposer by referring to Autoformer’s series decomposition module. The decomposer separates the trend-cyclical and seasonal parts. Removing the trend from the series allows the model to focus better on the hidden periodic information, and Wu et al. [33] have shown this decomposition to be effective. The model uses an encoder–decoder structure, where the encoder maps information from the input space to the feature space and the decoder maps it from the feature space to the target space. Since both the input and the output are sequences, the model is a typical sequence-to-sequence model. In addition, we use a hierarchical attention mechanism instead of the original multi-head attention mechanism and a graph network instead of the original feed-forward network inside the encoder and decoder, which improve the diversity of the captured information and mitigate the token-uniformity inductive bias [35,36] of the model, respectively.
3.2.1. Decomposer
The main difficulty of time series forecasting lies in discovering the hidden trend-cyclical and seasonal information in the historical series. The trend-cyclical part records the overall trend of the series, which has an essential influence on its long-term behaviour. The seasonal part records the hidden cyclical pattern, which mainly shows the regular short-term fluctuation of the series. It is generally difficult to predict both simultaneously. The basic idea is to separate the two: extract the trend-cyclical part from the sequence using average pooling and obtain the seasonal part by subtracting it from the sequence, which is how the decomposer implements the decomposition, as shown in Algorithm 1.
Algorithm 1 Decomposer
Require: $\mathcal{X}$
Ensure: $\mathcal{S},\mathcal{T}$
1: $\mathcal{T}\leftarrow \mathrm{avgpool}(\mathrm{padding}(\mathcal{X}))$
2: $\mathcal{S}\leftarrow \mathcal{X}-\mathcal{T}$

Here, $\mathcal{X}\in {\mathbb{R}}^{L\times d}$ is the input sequence of length L, and $\mathcal{T},\mathcal{S}\in {\mathbb{R}}^{L\times d}$ are the decomposed trend-cyclical and seasonal parts. The role of padding is to ensure that the decomposed series keeps the same dimensions as the input sequence.
The decomposer module has a relatively simple structure. However, it can decompose the forecasting task into two subtasks, i.e., mining hidden periodic patterns and forecasting overall trends. This decomposition can reduce the difficulty of prediction to a certain extent and, thus, improve the final prediction results.
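A minimal NumPy sketch of this decomposition follows. The edge-replication padding and the kernel size are illustrative assumptions (the paper does not fix them here); the essential structure matches Algorithm 1: the trend is a moving average and the seasonal part is the residual.

```python
import numpy as np

def decompose(x: np.ndarray, kernel: int = 25):
    """Series decomposition: trend = moving average, seasonal = residual.

    x: (L, d) input series. Edge padding keeps the trend the same length L
    as the input. The kernel size is a hypothetical choice for illustration.
    """
    pad = kernel // 2
    # Replicate the first/last time step so the moving average stays length L.
    padded = np.pad(x, ((pad, kernel - 1 - pad), (0, 0)), mode="edge")
    # Moving average along the time axis (avgpool with stride 1).
    window = np.ones(kernel) / kernel
    trend = np.apply_along_axis(
        lambda c: np.convolve(c, window, mode="valid"), 0, padded)
    seasonal = x - trend                      # S <- X - T
    return seasonal, trend
```

On a perfectly linear series, the interior of the trend reproduces the series itself, so the seasonal residual there is zero, which matches the intent of filtering the trend out.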
3.2.2. Encoder
The encoder is mainly responsible for encoding the input data, realising the transformation from the input space to the feature space. The decomposer in the encoder acts more like a filter because, in the encoder, we focus on the seasonal parts of the sequence and discard the trend-cyclical parts. The input data first pass through a hierarchical attention layer for initial key-feature extraction; the decomposer then extracts the seasonal features of the sequence, which are further fed into the graph network to mitigate inductive bias. After stacking $N$ such layers, the resulting seasonal features serve as auxiliary inputs to the decoder. Algorithm 2 describes the computation procedure.
Algorithm 2 Encoder
Require: ${\mathcal{X}}_{\mathrm{en}}$
Ensure: ${\mathcal{X}}_{\mathrm{en}}^{N}$
1: for $l=1,\cdots ,N$ do
2:  if $l=1$ then
3:   ${\mathcal{X}}_{\mathrm{en}}^{l-1}\leftarrow {\mathcal{X}}_{\mathrm{en}}$
4:  end if
5:  ${\mathcal{S}}_{\mathrm{en}}^{l,1},\_\leftarrow \mathcal{D}\left(\mathcal{H}({\mathcal{X}}_{\mathrm{en}}^{l-1})+{\mathcal{X}}_{\mathrm{en}}^{l-1}\right)$
6:  ${\mathcal{S}}_{\mathrm{en}}^{l,2},\_\leftarrow \mathcal{D}\left(\mathcal{G}({\mathcal{S}}_{\mathrm{en}}^{l,1})+{\mathcal{S}}_{\mathrm{en}}^{l,1}\right)$
7:  ${\mathcal{X}}_{\mathrm{en}}^{l}\leftarrow {\mathcal{S}}_{\mathrm{en}}^{l,2}$
8: end for

Here, ${\mathcal{X}}_{\mathrm{en}}\in {\mathbb{R}}^{I\times d}$ denotes the historical observation sequence, $N$ denotes the number of stacked encoder layers, and ${\mathcal{X}}_{\mathrm{en}}^{N}$ denotes the output of the $N$th encoder layer. $\mathcal{D}$ denotes the decomposer operator, $\mathcal{G}$ the graph network operator, and $\mathcal{H}$ the hierarchical attention mechanism; their concrete implementations are described later.
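The control flow of Algorithm 2 can be sketched as a simple loop; here $\mathcal{H}$, $\mathcal{G}$, and $\mathcal{D}$ are passed in as placeholder callables rather than the concrete modules described in the text:

```python
import numpy as np

def encoder(x_en: np.ndarray, layers, decompose):
    """Sketch of Algorithm 2. `layers` is a list of (H, G) pairs, where H is a
    hierarchical-attention callable and G a graph-network callable, both
    mapping (L, d) -> (L, d); `decompose` returns (seasonal, trend). All three
    are stand-ins for the modules defined elsewhere in the paper."""
    x = x_en
    for H, G in layers:
        s1, _ = decompose(H(x) + x)    # keep the seasonal part, drop the trend
        s2, _ = decompose(G(s1) + s1)  # graph network, then decompose again
        x = s2
    return x                           # X_en^N, auxiliary input to the decoder
```

Note the residual connections (`H(x) + x`, `G(s1) + s1`) before each decomposition, mirroring lines 5–6 of the algorithm.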
3.2.3. Decoder
The structure of the decoder is more complex than that of the encoder; its internal modules are identical to the encoder’s, but it uses a multi-input structure and goes through two hierarchical attention calculations and three sequence decompositions in turn. If the encoder is a feature catcher, the decoder is a feature fuser that fuses and corrects inputs from different sources to obtain the prediction sequence. The decoder has three primary inputs: the seasonal part ${\mathcal{X}}_{\mathrm{des}}$ and the trend-cyclical part ${\mathcal{X}}_{\mathrm{det}}$ extracted from the original series, and the seasonal features ${\mathcal{X}}_{\mathrm{en}}^{N}$ captured by the encoder. The trend-cyclical and seasonal computations are kept relatively independent throughout; only at the final output is a linear layer used to fuse the two into the final prediction ${\mathcal{X}}_{\mathrm{pred}}$. The computation process is described in Algorithm 3.
Algorithm 3 Decoder
Require: ${\mathcal{X}}_{\mathrm{en}},{\mathcal{X}}_{\mathrm{en}}^{N}$
Ensure: ${\mathcal{X}}_{\mathrm{pred}}$
1: ${\mathcal{X}}_{\mathrm{ens}},{\mathcal{X}}_{\mathrm{ent}}\leftarrow \mathcal{D}({\mathcal{X}}_{\mathrm{en},\frac{I}{2}:I})$
2: ${\mathcal{X}}_{\mathrm{des}}\leftarrow {\mathcal{X}}_{\mathrm{ens}}\Vert {\mathbf{0}}_{0:\frac{I}{2}}$
3: ${\mathcal{X}}_{\mathrm{det}}\leftarrow {\mathcal{X}}_{\mathrm{ent}}\Vert {\overline{\mathcal{X}}}_{0:\frac{I}{2}}$
4: for $l=1,\cdots ,M$ do
5:  if $l=1$ then
6:   ${\mathcal{X}}_{\mathrm{de}}^{l-1}\leftarrow {\mathcal{X}}_{\mathrm{des}}$
7:   ${\mathcal{T}}_{\mathrm{de}}^{l-1}\leftarrow {\mathcal{X}}_{\mathrm{det}}$
8:  end if
9:  ${\mathcal{S}}_{\mathrm{de}}^{l,1},{\mathcal{T}}_{\mathrm{de}}^{l,1}\leftarrow \mathcal{D}\left(\mathcal{H}({\mathcal{X}}_{\mathrm{de}}^{l-1})+{\mathcal{X}}_{\mathrm{de}}^{l-1}\right)$
10: ${\mathcal{S}}_{\mathrm{de}}^{l,2},{\mathcal{T}}_{\mathrm{de}}^{l,2}\leftarrow \mathcal{D}\left(\mathcal{H}({\mathcal{S}}_{\mathrm{de}}^{l,1},{\mathcal{X}}_{\mathrm{en}}^{N})+{\mathcal{S}}_{\mathrm{de}}^{l,1}\right)$
11: ${\mathcal{S}}_{\mathrm{de}}^{l,3},{\mathcal{T}}_{\mathrm{de}}^{l,3}\leftarrow \mathcal{D}\left(\mathcal{G}({\mathcal{S}}_{\mathrm{de}}^{l,2})+{\mathcal{S}}_{\mathrm{de}}^{l,2}\right)$
12: ${\mathcal{X}}_{\mathrm{de}}^{l}\leftarrow {\mathcal{S}}_{\mathrm{de}}^{l,3}$
13: ${\mathcal{T}}_{\mathrm{de}}^{l}\leftarrow {\mathcal{T}}_{\mathrm{de}}^{l-1}+{\mathcal{W}}_{l,1}\ast {\mathcal{T}}_{\mathrm{de}}^{l,1}+{\mathcal{W}}_{l,2}\ast {\mathcal{T}}_{\mathrm{de}}^{l,2}+{\mathcal{W}}_{l,3}\ast {\mathcal{T}}_{\mathrm{de}}^{l,3}$
14: end for
15: ${\mathcal{X}}_{\mathrm{pred}}\leftarrow {\mathcal{W}}_{\mathcal{S}}\ast {\mathcal{X}}_{\mathrm{de}}^{M}+{\mathcal{T}}_{\mathrm{de}}^{M}$

Here, ${\mathcal{X}}_{\mathrm{en}}$ denotes the original sequence, which is also the input to the encoder. It is decomposed into the seasonal and trend-cyclical parts ${\mathcal{X}}_{\mathrm{ens}},{\mathcal{X}}_{\mathrm{ent}}$, which are fed to the decoder as its initial input.
3.3. Hierarchical Attention Mechanism
The hierarchical attention mechanism, as the first feature-capture unit of Metaformer, is at the model’s core and therefore has a significant impact on all subsequent computation. Most Transformer-like models use the multi-head attention mechanism for the first step of feature extraction. However, the multi-head attention mechanism has significant drawbacks: (1) every head uses the same attention computation, which cannot guarantee the diversity of the captured information and may even miss some critical information; (2) each head operates in a separate subspace, and the lack of information interaction between heads hinders the model’s deep understanding of the information. Therefore, we propose, for the first time, a hierarchical attention mechanism. First, a hierarchical structure is used in which each layer captures features with a different attention mechanism, which ensures the diversity of the information circulating in the network; second, a cascading interaction is used in which the information captured by a lower layer is reused by the layer above, which deepens the model’s understanding of the information. When we humans understand language, we not only focus on the surface meaning of words but can also grasp the metaphors behind them. Inspired by this, we use a hierarchical structure to model this phenomenon and thus improve the network’s ability to perceive information at multiple levels.
3.3.1. Traditional MultiHead Attention Mechanism
In the multi-head attention mechanism, only one type of attention computation, scaled dot-product attention, is used. The mechanism takes as input three vectors of queries, keys, and values with dimension ${d}_{m}$, and each head projects them to ${d}_{k}$, ${d}_{k}$, and ${d}_{v}$ dimensions using linear layers. The attention function is then computed to produce a ${d}_{v}$-dimensional output. Finally, the outputs of all attention heads are concatenated and passed through a linear layer to obtain the final output.
Equation (2) gives the multi-head attention calculation, where ${\mathcal{L}}_{{\theta}_{q}},{\mathcal{L}}_{{\theta}_{k}},{\mathcal{L}}_{{\theta}_{v}},{\mathcal{L}}_{{\theta}_{o}}$ denote the linear layers with projection parameter matrices ${W}^{Q}\in {\mathbb{R}}^{{d}_{m}\times {d}_{k}},{W}^{K}\in {\mathbb{R}}^{{d}_{m}\times {d}_{k}},{W}^{V}\in {\mathbb{R}}^{{d}_{m}\times {d}_{v}},{W}^{O}\in {\mathbb{R}}^{h{d}_{v}\times {d}_{m}}$, respectively; $h$ denotes the number of attention heads; $\mathcal{A}$ denotes scaled dot-product attention; and ∐ denotes sequential concatenation.
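The computation just described can be sketched in NumPy as follows; the random matrices merely stand in for the learned projections ${W}^{Q},{W}^{K},{W}^{V},{W}^{O}$, so this is an illustration of the data flow, not of a trained model:

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(Q, K, V, h, rng):
    """Multi-head scaled dot-product attention (sketch).
    Q: (Lq, dm); K, V: (Lk, dm). Each head projects to dm // h dimensions;
    head outputs are concatenated and mapped back to dm."""
    dm = Q.shape[-1]
    dk = dv = dm // h
    heads = []
    for _ in range(h):
        Wq = rng.standard_normal((dm, dk))
        Wk = rng.standard_normal((dm, dk))
        Wv = rng.standard_normal((dm, dv))
        q, k, v = Q @ Wq, K @ Wk, V @ Wv
        A = softmax(q @ k.T / np.sqrt(dk))   # scaled dot-product attention
        heads.append(A @ v)
    Wo = rng.standard_normal((h * dv, dm))   # output projection W^O
    return np.concatenate(heads, axis=-1) @ Wo   # (Lq, dm)
```

Every head runs the identical computation with different projections, which is exactly the uniformity the hierarchical mechanism below is designed to break.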
3.3.2. Hierarchical Attention Mechanism
We propose a hierarchical attention mechanism to address the shortcomings of the multi-head attention mechanism, aiming to enhance the model’s deep understanding of the information. Figure 2 depicts the central architecture of the hierarchical attention mechanism, and Algorithm 4 describes its implementation.
Algorithm 4 Hierarchical Attention
Require: $\mathbf{Q},\mathbf{K},\mathbf{V}$
Ensure: $\mathbf{Y}$
1: for $i=1,\cdots ,N$ do
2:  if $i=1$ then
3:   $\mathbf{H}\leftarrow \mathrm{Random\ Initialisation}$
4:  end if
5:  $\mathbf{H}\leftarrow \mathcal{R}\left({\mathcal{A}}_{i}\left({\mathcal{L}}_{{\theta}_{q}}^{i}(\mathbf{Q}),{\mathcal{L}}_{{\theta}_{k}}^{i}(\mathbf{K}),{\mathcal{L}}_{{\theta}_{v}}^{i}(\mathbf{V})\right),\mathbf{H}\right)$
6:  $\mathcal{Y}\leftarrow \mathcal{Y}\Vert \mathbf{H}$
7: end for
8: $\mathbf{Y}\leftarrow {\mathcal{L}}_{{\theta}_{o}}(\mathcal{Y})$

Here, ${\mathcal{L}}_{{\theta}_{q}},{\mathcal{L}}_{{\theta}_{k}},{\mathcal{L}}_{{\theta}_{v}},{\mathcal{L}}_{{\theta}_{o}}$ have the same meaning as in Equation (2). $\mathcal{R}$ denotes a GRU unit. $\mathcal{Y}$ records the information of each layer and is finally mapped to the specified dimension by a linear layer as the model’s output. ${\mathcal{A}}_{i}$ denotes the attention calculation used at layer $i$. This paper mainly uses four common attention mechanisms: Vanilla Attention, ProbSparse Attention, LSH Attention, and AutoCorrelation. Strictly speaking, AutoCorrelation is not part of the attention family; however, its effect is similar to or even better than that of attention mechanisms, so it is introduced into our model and participates in feature extraction.
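The loop in Algorithm 4 can be sketched as follows. The GRU-cell fusion and the output projection use small random weights purely for illustration (a trained model would learn them), and the per-layer attention mechanisms are passed in as callables:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def hierarchical_attention(Q, K, V, attentions, rng):
    """Sketch of Algorithm 4: each layer applies a *different* attention
    mechanism, and a GRU cell fuses its output with the previous layer's
    state H, so higher layers reuse lower-layer information. `attentions`
    is a list of callables (Q, K, V) -> (L, d). Weights are random stand-ins."""
    L, d = Q.shape
    # Random GRU-cell parameters (update gate z, reset gate r, candidate n).
    Wz, Wr, Wn = (rng.standard_normal((2 * d, d)) * 0.1 for _ in range(3))
    H = rng.standard_normal((L, d)) * 0.1        # random initialisation (line 3)
    outs = []
    for attn in attentions:
        x = attn(Q, K, V)                        # this layer's attention output
        z = sigmoid(np.concatenate([x, H], -1) @ Wz)
        r = sigmoid(np.concatenate([x, H], -1) @ Wr)
        n = np.tanh(np.concatenate([x, r * H], -1) @ Wn)
        H = (1 - z) * n + z * H                  # GRU fusion across layers
        outs.append(H)                           # Y <- Y || H (line 6)
    Wo = rng.standard_normal((len(attentions) * d, d))
    return np.concatenate(outs, -1) @ Wo         # final linear map (line 8)
```

The recurrent state $\mathbf{H}$ is what distinguishes this from plain multi-head attention: the layers are chained rather than independent.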
Attention is the core building block of Transformer and is considered an essential tool for information capture in both CV and NLP domains. Many researchers have worked on designing more efficient attention, so many variants based on Vanilla Attention have been proposed in succession. The following briefly describes the four attention mechanisms used in our model.
3.3.3. Vanilla Attention
Vanilla Attention was first proposed in the Transformer [3]. Its input consists of three vectors: queries, keys, and values ($\mathbf{Q},\mathbf{K},\mathbf{V}$), whose dimensions are ${d}_{k},{d}_{k},{d}_{v}$, respectively. Vanilla Attention is also known as Scaled Dot-Product Attention because it computes the dot product of $\mathbf{Q}$ and $\mathbf{K}$ and scales it by $\sqrt{{d}_{k}}$. The specific calculation is shown in Equation (3).
Here, $\mathcal{A}$ denotes the attention or autocorrelation mechanism, and ${\sigma}^{\dagger}$ denotes the softmax activation function.
3.3.4. ProbSparse Attention
This attention mechanism, first proposed in Informer, exploits the sparsity of the attention coefficients and sparsifies the query matrix $\mathbf{Q}$ using an explicit query sparsity measurement (Algorithm 5). Equation (4) gives the ProbSparse Attention calculation. Here, $\tilde{\mathbf{Q}}$ is the sparse matrix obtained by the sparsity measure. The prototype of $\tilde{M}({\mathbf{q}}_{i},\mathbf{K})$ is the Kullback–Leibler (KL) divergence; see Equation (5).
Algorithm 5 Explicit Query Sparsity Measurement
Require: $\mathbf{Q},\mathbf{K}$
Ensure: $\tilde{\mathbf{Q}}$
1: Define $\tilde{M}({\mathbf{q}}_{i},\mathbf{K})={max}_{j}\left\{{\displaystyle \frac{{\mathbf{q}}_{i}{\mathbf{k}}_{j}^{\top}}{\sqrt{{d}_{K}}}}\right\}-{\displaystyle \frac{1}{{L}_{K}}}{\sum}_{j=1}^{{L}_{K}}{\displaystyle \frac{{\mathbf{q}}_{i}{\mathbf{k}}_{j}^{\top}}{\sqrt{{d}_{K}}}}$
2: Define $\mathcal{U}={\mathrm{argTopu}}_{{q}_{i}\in [1,\cdots ,{L}_{Q}]}(\tilde{M}({\mathbf{q}}_{i},\mathbf{K}))$
3: for $u\in [1,\cdots ,{L}_{Q}]$ do
4:  if $u\in \mathcal{U}$ then
5:   ${\tilde{\mathbf{Q}}}_{u,:}\leftarrow {\mathbf{Q}}_{u,:}$
6:  else
7:   ${\tilde{\mathbf{Q}}}_{u,:}\leftarrow \mathbf{0}$
8:  end if
9: end for

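Algorithm 5 translates directly into NumPy; the function names below are ours, and the max-minus-mean measure follows the definition on line 1:

```python
import numpy as np

def sparsity_measure(Q, K):
    """Query sparsity measurement from Algorithm 5:
    M(q_i, K) = max_j(q_i . k_j / sqrt(d_K)) - mean_j(q_i . k_j / sqrt(d_K)).
    A large value means the query's attention distribution is far from
    uniform, i.e., the query is informative."""
    scores = Q @ K.T / np.sqrt(K.shape[-1])      # (L_Q, L_K)
    return scores.max(axis=1) - scores.mean(axis=1)

def sparse_queries(Q, K, u):
    """Keep the top-u queries under the measure; zero out the rest
    (lines 2-9 of Algorithm 5)."""
    m = sparsity_measure(Q, K)
    top = np.argsort(m)[-u:]                     # argTop-u
    Qs = np.zeros_like(Q)
    Qs[top] = Q[top]
    return Qs
```

Attention is then computed with the sparse $\tilde{\mathbf{Q}}$, so only $u$ of the $L_Q$ queries contribute, which is the source of the complexity reduction.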
3.3.5. LSH Attention
Like ProbSparse Attention, LSH Attention uses a sparsification method to reduce the complexity of Vanilla Attention. The main idea is that each query attends only to its nearest keys, where nearest-neighbour selection is achieved by locality-sensitive hashing. The attention process of LSH Attention is given in Equation (6), and the hash function used is Equation (7), where ${\mathcal{P}}_{i}=\{j:h({\mathbf{q}}_{i})=h({\mathbf{k}}_{j})\}$ denotes the set of key vectors that the $i$th query attends to, and $a({\mathbf{q}}_{i},{\mathbf{k}}_{j})=\mathrm{exp}\left({\displaystyle \frac{{\mathbf{q}}_{i}{\mathbf{k}}_{j}^{\top}}{\sqrt{d}}}\right)$ measures the association of nodes $i$ and $j$.
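The bucketing idea behind the hash function can be sketched as follows; this is a simplified single-round version (Reformer uses multiple hash rounds), with the random projection matrix generated inside the function for illustration:

```python
import numpy as np

def lsh_hash(x: np.ndarray, n_buckets: int, rng) -> np.ndarray:
    """Angular LSH in the style of Equation (7):
    h(x) = argmax([xR; -xR]) for a random projection R, so vectors with a
    small angle between them tend to land in the same bucket. Attention is
    then restricted to keys sharing the query's bucket."""
    R = rng.standard_normal((x.shape[-1], n_buckets // 2))
    proj = x @ R                                           # (L, n_buckets/2)
    return np.argmax(np.concatenate([proj, -proj], axis=-1), axis=-1)
```

Because identical (or nearly parallel) vectors hash to the same bucket, each query's candidate key set ${\mathcal{P}}_{i}$ is small, which yields the $O(L\log L)$ behaviour cited earlier.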
3.3.6. AutoCorrelation
The AutoCorrelation mechanism differs from the attention mechanisms above: whereas the self-attention family focuses on the correlation between points, AutoCorrelation focuses on the correlation between sub-series. It is therefore an excellent complement to self-attention mechanisms.
Equation (8) gives the AutoCorrelation computation, where Equation (9) measures the correlation between two sequences and $\tau$ denotes the order of the lag term. $\mathrm{roll}(\mathbf{V},\tau )$ denotes the $\tau$-order lagged version of $\mathbf{V}$ obtained by circular shifting. Equation (10) is the Top-k algorithm used to select the set $\mathcal{T}$ of the $k$ lagged terms with the highest correlation.
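A minimal sketch of this mechanism, in the spirit of Autoformer: lag correlations are estimated efficiently via FFT (the Wiener–Khinchin relation), the top-k lags are kept, and rolled copies of $\mathbf{V}$ are aggregated with softmax weights. The single-head, projection-free form below is a simplification for illustration:

```python
import numpy as np

def autocorrelation(Q, K, V, k):
    """AutoCorrelation (sketch): estimate series-level correlation R(tau)
    via FFT, keep the top-k lags, and aggregate time-rolled copies of V
    with softmax weights. Q, K, V: (L, d); no learned projections here."""
    L = Q.shape[0]
    # R(tau) for all lags at once, averaged over channels (Wiener-Khinchin).
    fq, fk = np.fft.rfft(Q, axis=0), np.fft.rfft(K, axis=0)
    corr = np.fft.irfft(fq * np.conj(fk), n=L, axis=0).mean(axis=1)  # (L,)
    lags = np.argsort(corr)[-k:]                  # top-k most correlated lags
    w = np.exp(corr[lags] - corr[lags].max())
    w = w / w.sum()                               # softmax over the k lags
    out = np.zeros_like(V)
    for lag, wi in zip(lags, w):
        out += wi * np.roll(V, -lag, axis=0)      # roll(V, tau) aggregation
    return out
```

The FFT makes all $L$ lag correlations available in $O(L\log L)$ time, which is why the mechanism scales comparably to the sparse attention variants above.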
3.4. GAT Network
The Vanilla Transformer model embeds a Feedforward Network (FFN) [37] layer at the end of each encoder–decoder layer. The FFN plays a crucial role in mitigating token-uniformity inductive bias. Inductive bias can be viewed as the heuristic or “value” a learning algorithm uses to select hypotheses in a large hypothesis space. For example, convolutional networks assume that information is spatially local, spatially invariant, and translation-equivariant, so the parameter space can be reduced by sharing sliding convolutional weights; recurrent neural networks assume that information is sequential and invariant to temporal shifts, so weight sharing is also possible. Similarly, the attention mechanism makes assumptions of its own, such as the uselessness of some information. If attention layers are simply stacked, some critical information is lost, so adding an FFN layer can alleviate the accumulation of inductive bias and avoid network collapse. The FFN layer is not the only module with this mitigating effect: we find that a similar effect can be achieved with a Graph Neural Network (GNN) [38,39,40]. Here, we use a two-layer GAT [41,42] network instead of the original FFN layer. A graph network aggregates the information of neighbouring nodes, i.e., through aggregation, each node fuses some features of its neighbours. Additionally, we use random sampling to reduce the complexity, because our goal is not feature aggregation but mitigating the loss of crucial information. In particular, when the number of samples per node is 0, the graph network degenerates into an FFN layer with a role similar to the original.
Here, we model each token as a node in the graph and mine the dependencies between nodes using the graph attention algorithm. The input to GAT is defined as $\mathcal{H}=\{{\overrightarrow{h}}_{1},{\overrightarrow{h}}_{2},\cdots ,{\overrightarrow{h}}_{N}\}$. Here, ${\overrightarrow{h}}_{i}\in {\mathbb{R}}^{F}$ denotes the input vector of the ith node, N denotes the number of nodes in the graph, and F denotes the dimensionality of the input vector. Through the computation of the GAT network, this layer generates a new set of node features ${\mathcal{H}}^{\prime}=\{{\overrightarrow{h}}_{1}^{\prime},{\overrightarrow{h}}_{2}^{\prime},\cdots ,{\overrightarrow{h}}_{N}^{\prime}\}$. Similarly, here ${\overrightarrow{h}}_{i}^{\prime}\in {\mathbb{R}}^{{F}^{\prime}}$ denotes the output vector of the ith node, and ${F}^{\prime}$ denotes the dimensionality of the output vector.
Figure 3 gives the general flow of information aggregation for a single node. Equation (11) is a concrete implementation that calculates the attention coefficient ${e}_{ij}$ between the $i$th node and each of its neighbour nodes $j$. Equation (12) is used to calculate the normalised attention factor ${\alpha}_{ij}$.
Here, ${\mathcal{N}}_{i}$ denotes the set of all neighbouring nodes of the $i$th node, and $\mathbf{W}$ is a shared parameter matrix for the linear mapping of node features. $\mathcal{F}$ is a single-layer feedforward neural network that maps the concatenated high-dimensional features to a real number ${e}_{ij}$. ${e}_{ij}$ is the attention coefficient of node $j\to i$, and ${\alpha}_{ij}$ is its normalised value.
Finally, the new feature vector ${\overrightarrow{h}}_{i}^{\prime}$ of the current node $i$ is obtained by weighting and summing the feature vectors of its neighbouring nodes according to the calculated attention coefficients, so that ${\overrightarrow{h}}_{i}^{\prime}$ records the neighbourhood information of the current node.
Here, $\sigma$ denotes the logistic sigmoid activation function applied at the end.
Furthermore, if information aggregation is performed with a $K$-head attention mechanism, the final output vector can be obtained by averaging the heads.
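A single aggregation step of Equations (11) and (12) and the sigmoid output might look like the following sketch (boolean adjacency, single head; the LeakyReLU slope follows the original GAT paper, and the function name is ours):

```python
import numpy as np

def gat_layer(H, adj, W, a):
    """Single GAT layer (sketch).
    H: (N, F) node features; adj: (N, N) boolean adjacency (True = edge,
    self-loops included); W: (F, F') shared linear map; a: (2F',) attention
    vector implementing the single-layer network F of Equation (11)."""
    Wh = H @ W                                              # (N, F')
    N = H.shape[0]
    Hp = np.zeros((N, W.shape[1]))
    for i in range(N):
        nbrs = np.flatnonzero(adj[i])
        # e_ij = LeakyReLU(a . [Wh_i || Wh_j]) for each neighbour j (Eq. 11).
        e = np.array([np.concatenate([Wh[i], Wh[j]]) @ a for j in nbrs])
        e = np.where(e > 0, e, 0.2 * e)                     # LeakyReLU
        alpha = np.exp(e - e.max())
        alpha = alpha / alpha.sum()                         # Equation (12)
        # Weighted aggregation of neighbour features, then logistic sigmoid.
        Hp[i] = 1.0 / (1.0 + np.exp(-(alpha[:, None] * Wh[nbrs]).sum(0)))
    return Hp
```

Restricting `nbrs` to a random subset of the neighbourhood would give the sampled variant described above; with zero samples per node, only the node's own projected feature remains, recovering FFN-like behaviour.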