Common cross-modal image fusion frameworks include single-step multi-layer network direct fusion frameworks and encoder–decoder frameworks. To ensure that the fused image can fully retain the features of the original infrared image, Wave-Cross adopts a multi-step algorithm structure of encoder–fusion module–decoder and proposes a wavelet transform-based cross-attention mechanism framework for image fusion. In the pre-trained encoding and decoding stages, encoders and decoders with the same structure but different parameters are designed for each modality image; in the subsequent fusion and decoding stages, the decoder used in pre-training is discarded, and an additional set of fusion-decoders is trained to efficiently coordinate the fusion and decoding stages. The key components involved in the entire framework include independent modality encoders, cross-modal multi-frequency band cross-attention modules, and additional multi-layer mixed feature decoders.
3.3. Fusion Module
The task of the fusion module is to add visible-light details to the infrared image without destroying the relative gray level relationship in the infrared image. To this end, the WA module employs wavelet transform to decompose the dual-modal features from the encoder into different frequency bands. The WA-H module, used for the visible-light branch, extracts the main structure and global features using multiple convolution layers in the low-frequency band. In the mid- and high-frequency bands, it employs self-attention mechanisms to extract high-frequency details in the horizontal, vertical, and diagonal directions. In contrast, the WA-L module, used for the infrared branch, operates in the opposite manner. It captures global and brightness features using self-attention mechanisms in the low-frequency band and extracts high-frequency texture details using multiple convolution layers in the mid- and high-frequency bands. Finally, in the fusion module, we design multiple cross-attention mechanisms to achieve efficient multi-frequency band cross-modal fusion and output the result to the final decoder.
Here, we detail the most important WA module in the fusion stage, as shown in Figure 3.
Two feature branches from the encoder, representing the infrared and visible modalities, respectively, enter the WA-L and WA-H modules with different parameters. Within the WA module, we first apply LayerNorm to the encoder outputs of both modalities. After that, we apply a 2D discrete wavelet transform (DWT) using the Haar wavelet to decompose the encoder feature maps of each modality $I \in \{ir, vi\}$ into four sub-bands for subsequent feature extraction. The Haar wavelet transform filters consist of a low-pass filter $L$ and a high-pass filter $H$, which can be represented as:
$$L = \frac{1}{\sqrt{2}}\,[1,\ 1], \qquad H = \frac{1}{\sqrt{2}}\,[1,\ -1].$$
After the wavelet transform, the input feature map is decomposed into four sub-bands:
where $N$ denotes the number of patches, $C$ represents the number of channels, and $H$ and $W$ are the height and width of the feature map, respectively. The low-frequency sub-band (LL) retains the main structure and global information of the image, such as overall brightness, contours, and major textures. The horizontal high-frequency sub-band preserves the detail information of the image in the horizontal direction, such as horizontal edges, horizontal textures, and abrupt changes in the horizontal direction. The vertical high-frequency sub-band and the diagonal high-frequency sub-band (HH) preserve the detail information of the image in the vertical and diagonal directions, respectively. The DWT does not change the number of channels, and because the Haar transform is orthogonal and therefore exactly invertible, no information is lost in the entire process, which meets our previously proposed requirement that the fusion algorithm should fully retain infrared features.
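The lossless property can be checked numerically. The following is a minimal sketch in plain Python, not the implementation used in this work: it applies a single-level 2D Haar DWT to one channel stored as a nested list and then inverts it. The function names `haar_dwt2` and `haar_idwt2` are hypothetical, and the LH/HL naming of the two directional bands is an assumption, since conventions differ between libraries.

```python
def haar_dwt2(x):
    """Single-level 2D Haar DWT of a 2D list with even height and width.

    Returns four sub-bands (ll, lh, hl, hh), each half the input size.
    """
    h, w = len(x), len(x[0])
    ll, lh, hl, hh = ([[0.0] * (w // 2) for _ in range(h // 2)] for _ in range(4))
    for i in range(h // 2):
        for j in range(w // 2):
            a, b = x[2 * i][2 * j], x[2 * i][2 * j + 1]
            c, d = x[2 * i + 1][2 * j], x[2 * i + 1][2 * j + 1]
            # Two passes of the 1/sqrt(2) filters give an overall factor of 1/2.
            ll[i][j] = (a + b + c + d) / 2  # local average (low-low)
            lh[i][j] = (a + b - c - d) / 2  # top row minus bottom row
            hl[i][j] = (a - b + c - d) / 2  # left column minus right column
            hh[i][j] = (a - b - c + d) / 2  # diagonal difference
    return ll, lh, hl, hh


def haar_idwt2(ll, lh, hl, hh):
    """Inverse of haar_dwt2: rebuilds the full-resolution 2D list."""
    h2, w2 = len(ll), len(ll[0])
    x = [[0.0] * (2 * w2) for _ in range(2 * h2)]
    for i in range(h2):
        for j in range(w2):
            s, p, q, r = ll[i][j], lh[i][j], hl[i][j], hh[i][j]
            x[2 * i][2 * j] = (s + p + q + r) / 2
            x[2 * i][2 * j + 1] = (s + p - q - r) / 2
            x[2 * i + 1][2 * j] = (s - p + q - r) / 2
            x[2 * i + 1][2 * j + 1] = (s - p - q + r) / 2
    return x


# Toy 4 x 4 single-channel feature map (values are arbitrary).
feature = [[float((3 * r + 5 * c) % 7) for c in range(4)] for r in range(4)]
bands = haar_dwt2(feature)
restored = haar_idwt2(*bands)
max_err = max(abs(u - v) for ru, rv in zip(feature, restored) for u, v in zip(ru, rv))
print("sub-band size:", len(bands[0]), "x", len(bands[0][0]))
print("max reconstruction error:", max_err)
```

Each sub-band has half the spatial size of the input, and the inverse transform recovers the input up to floating-point precision, which is the property the fusion module relies on.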
Taking the infrared branch using the WA-L module as an example, we perform patch embedding on the low-frequency sub-band to split it into $N_{LL}$ patches, where $N_{LL}$ is the number of patches of the sub-band, and $h = 4$ and $w = 4$ are the height and width of each patch, respectively. The patch-embedded feature is transposed and reshaped into a contiguous token sequence with $d = C \times h \times w = 512$ before entering the attention module. Then, we employ the self-attention mechanism to extract and preserve the global background features.
In this equation, $W$ is a learnable transformation matrix; $Q$, $K$, and $V$ represent the input of the low-frequency sub-band of a certain modality after transformation; $d$ represents the feature dimension obtained from patch embedding; $\mathrm{LN}(\cdot)$ denotes the layer normalization operation; and $\mathrm{MLP}(\cdot)$ is a multi-layer perceptron.
The high-frequency sub-bands are combined into a single tensor by concatenation along the 0th dimension. First, the high-frequency details are amplified by a convolution block containing LeakyReLU.
Specifically, $\mathrm{Conv}_{3 \times 3}$ denotes a 2D convolution with a kernel size of $3 \times 3$, and BN stands for batch normalization. After the entire WA module operation, the output is reconstructed using the inverse wavelet transform (IWT).
The output is patch-embedded, reshaped into a token sequence, and then input into the next two cross-attention mixing modules, as shown in Figure 1. Concretely, given an encoder feature map of size $C \times H \times W$, the DWT produces four sub-bands of size $C \times \frac{H}{2} \times \frac{W}{2}$ each. The low-frequency sub-band is patch-embedded into a token sequence, while the IWT reconstructs a fused feature map of the original size $C \times H \times W$, ensuring consistent tensor shapes throughout the WA module.
In the WA-L module, the high-frequency sub-bands (LH, HL, and HH) are refined using a two-layer convolutional block after wavelet decomposition. Specifically, the input channels are first expanded from C to 2C through a 3 × 3 convolution with stride 1 and padding 1, followed by batch normalization and a LeakyReLU activation to enhance local contrast and suppress noise. A second 3 × 3 convolution then reduces the channels from 2C back to C, also followed by batch normalization. This configuration, which directly corresponds to our PyTorch (version: 2.5.1) implementation, preserves fine textures while stabilizing feature distributions. In the concrete implementation, C is set to 128.
For the low-frequency branch, we adopt a custom head-partitioned dot-product attention mechanism. The feature dimension dim is evenly divided into 16 partitions, and attention is computed independently within each partition by linearly projecting the input into Q, K, and V spaces and applying scaled dot-product attention with a scaling factor of $\mathrm{dim}^{-1/2}$. The outputs are aggregated through a linear projection with a dropout rate of 0–0.1. After attention mixing and high-frequency refinement, all four sub-bands are reconstructed via the inverse wavelet transform.
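The head-partitioned attention described above can be sketched as follows. This is an illustrative plain-Python version under stated assumptions, not the PyTorch implementation: the function name `partitioned_attention` and the random toy weights are hypothetical, dropout is omitted (as at inference time), and the scale uses the full feature dimension, following the $\mathrm{dim}^{-1/2}$ factor stated in the text rather than the per-partition dimension.

```python
import math
import random


def matmul(a, b):
    """Matrix product of two nested lists (rows of a times columns of b)."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def softmax(row):
    """Numerically stable softmax over one list of scores."""
    m = max(row)
    e = [math.exp(v - m) for v in row]
    s = sum(e)
    return [v / s for v in e]


def partitioned_attention(x, wq, wk, wv, wo, num_parts=16):
    """Attention computed independently inside each channel partition.

    x: N tokens of width dim (N x dim list); wq, wk, wv, wo: dim x dim lists.
    """
    dim = len(x[0])
    part = dim // num_parts
    scale = dim ** -0.5  # assumption: full-dimension scale, as stated in the text
    q, k, v = matmul(x, wq), matmul(x, wk), matmul(x, wv)
    out = [[] for _ in x]
    for p in range(num_parts):
        sl = slice(p * part, (p + 1) * part)
        qp, kp, vp = [r[sl] for r in q], [r[sl] for r in k], [r[sl] for r in v]
        scores = [[scale * sum(a * b for a, b in zip(qi, kj)) for kj in kp] for qi in qp]
        mixed = matmul([softmax(r) for r in scores], vp)
        for row, m in zip(out, mixed):
            row.extend(m)  # concatenate partition outputs along the channel axis
    return matmul(out, wo)  # final linear projection (dropout omitted)


# Toy example: 5 tokens of width 32 split into 16 partitions of width 2.
rng = random.Random(0)
dim, n_tokens = 32, 5
tokens = [[rng.uniform(-1, 1) for _ in range(dim)] for _ in range(n_tokens)]
weights = [[[rng.uniform(-0.2, 0.2) for _ in range(dim)] for _ in range(dim)] for _ in range(4)]
mixed_tokens = partitioned_attention(tokens, *weights, num_parts=16)
print(len(mixed_tokens), "tokens of width", len(mixed_tokens[0]))
```

Because the partitions never exchange information before the output projection, each one attends over the full token sequence but only within its own slice of channels.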
WA-H and WA-L share the same architecture but maintain independent parameter sets for the visible and infrared branches.
The cross-modal interaction after the WA module is implemented by two sequential cross-attention stages operating on the same patch-level tokens rather than on different wavelet sub-bands. Let $X_{IR}$ and $X_{VI}$ denote the patch-embedded infrared and visible features after the WA module.
In the first cross-attention stage, we perform bidirectional cross-attention with two blocks, one querying with the infrared tokens and attending to the visible tokens, and the other in the reverse direction. Each block $\mathrm{CA}_i$ has its own learnable query, key, and value transformation matrices. The first argument of $\mathrm{CA}_i$ is used as queries and the second as keys/values. The two outputs are then aggregated by element-wise addition to form a shared cross-modal pattern.
In the second cross-attention stage, this shared pattern is used as the query to further refine both modalities: it attends to the infrared tokens in one block and to the visible tokens in another. The final fused tokens are obtained from these two refined outputs, followed by a two-layer feed-forward network with GELU activation and layer normalization.
In summary, the first stage symmetrically exchanges information between infrared and visible tokens, while the second stage uses the fused cross-modal pattern to re-attend to and jointly refine both modalities. All four cross-attention blocks operate on the same token sequence (same sub-bands after WA), and the two stages are applied sequentially.
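The data flow of the two stages can be sketched as follows. This is a simplified single-head illustration in plain Python, not the implementation used in this work. The names `cross_attention` and `two_stage_fusion` are hypothetical, the final aggregation of the two second-stage outputs is assumed to be element-wise addition, and the feed-forward network and layer normalization that follow are omitted.

```python
import math


def matmul(a, b):
    """Matrix product of two nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def softmax(row):
    """Numerically stable softmax over one list of scores."""
    m = max(row)
    e = [math.exp(v - m) for v in row]
    s = sum(e)
    return [v / s for v in e]


def add(a, b):
    """Element-wise sum of two token matrices of equal shape."""
    return [[x + y for x, y in zip(r, s)] for r, s in zip(a, b)]


def cross_attention(q_src, kv_src, wq, wk, wv):
    """First argument supplies the queries, second the keys and values."""
    q, k, v = matmul(q_src, wq), matmul(kv_src, wk), matmul(kv_src, wv)
    scale = len(q[0]) ** -0.5
    scores = [[scale * sum(a * b for a, b in zip(qi, kj)) for kj in k] for qi in q]
    return matmul([softmax(r) for r in scores], v)


def two_stage_fusion(x_ir, x_vi, params):
    """params: four (wq, wk, wv) tuples, one per cross-attention block."""
    # Stage 1: bidirectional exchange, then a shared pattern by addition.
    ir_from_vi = cross_attention(x_ir, x_vi, *params[0])
    vi_from_ir = cross_attention(x_vi, x_ir, *params[1])
    shared = add(ir_from_vi, vi_from_ir)
    # Stage 2: the shared pattern queries each modality again.
    refined_ir = cross_attention(shared, x_ir, *params[2])
    refined_vi = cross_attention(shared, x_vi, *params[3])
    # Assumption: the two refined outputs are summed before the FFN (omitted).
    return add(refined_ir, refined_vi)


def identity(n):
    return [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]


# Toy example: 3 tokens of width 2 per modality, identity projections.
x_ir = [[1.0, 0.0], [0.5, 0.5], [0.0, 1.0]]
x_vi = [[0.2, 0.8], [0.9, 0.1], [0.4, 0.6]]
blocks = [(identity(2), identity(2), identity(2)) for _ in range(4)]
fused = two_stage_fusion(x_ir, x_vi, blocks)
print(len(fused), "fused tokens of width", len(fused[0]))
```

The token count and width are unchanged by both stages, so the fused tokens can be passed to the decoder with the same shape as either input sequence.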
3.4. Decoder Architecture
The decoder used in the fusion stage is shown in Figure 4. After the first and fifth convolutions, residual structures are employed to mix the features from the pre-trained encoders of the two modalities with the features from the decoding stage in order to preserve the important features of the original modalities before mixing.
As shown in Figure 4, the decoder of the fusion stage receives three inputs: the fused mid-level feature and the modality-specific deep features from the infrared decoder and visible decoder. These three tensors are first summed to form a tensor of shape $(B, C, H, W)$, which is processed by a convolution layer (orange block) to refine the channel size $C$. A subsequent convolution layer reduces the channels to $C/2$, followed by an upsampling layer that doubles the spatial resolution. The output then passes through a three-stage convolutional block, which sequentially performs convolution → upsampling, convolution → upsampling, and a final convolution, progressively increasing spatial resolution and reducing channel dimensionality.
During reconstruction, shallow features from both encoders—infrared shallow feature and visible shallow feature—are injected through skip connections (green and purple arrows) and concatenated with the decoder input.
The pre-training encoder stage is shown in Figure 5, where the decoder structures used in the two modalities are the same, but the parameters are different; a new decoder is used in the fusion stage to decode the mixed features. The synchronous fusion and decoding enable the output mixed image to present the significant temperature-related features from the original infrared image, as well as the fine texture features from the visible-light image.
As shown in Figure 5, the module takes as inputs the current decoder feature and the shallow encoder feature. The input feature first passes through a convolution layer, which reduces the channel dimension from $C$ to $C/2$, followed by an upsampling layer that doubles the spatial resolution. The output is then processed by a three-stage convolutional block consisting of convolution → upsampling and convolution → upsampling. Finally, a convolution maps the fused feature to the reconstruction output. This sequence progressively increases spatial resolution while reducing channel dimensionality.
3.5. Loss Function
3.5.1. The Loss Function Used During Encoder Training
To independently encode features from different modalities and ensure thorough mixing, our method employs distinct loss functions during the pre-training of the encoder and the training of the fusion stage. The loss function used during encoder training consists of two parts and can be written as follows:
where $I_e$ is the reconstructed image encoded under a certain modality (infrared or visible-light), $I_c$ is the infrared image or visible-light image, and $\varepsilon$ is a hyperparameter.
3.5.2. The Loss Function Used During the Training of Mixing Part
Since the fused image should contain more complementary features and reduce redundant information from different modalities, a novel loss function is proposed to train our network. The formula of our loss function is given as follows.
The loss function $L_{all}$ consists of two parts and can be written as follows:
Specifically, $L_{int}$ measures the main part of the mixed image, such as illumination and contour information; the Heat-Consistency Loss $L_{heat}$ measures how readable infrared temperature differences remain in the fused image, so that the fused image is visually clear while still retaining an interpretable thermal perception capability. The specific calculation method of $L_{heat}$ is described in Section 3.5.3.
3.5.3. Heat-Consistency Loss Function
In practical infrared image acquisition, sensor noise and nonlinear responses may exist, but the grayscale intensity of pixels generally maintains an approximate monotonic relationship with temperature. Thus, the relative magnitude of pixel values can be regarded as ranking information for thermal levels. This property is crucial for tasks such as human detection or fault diagnosis. However, most current fusion algorithms, especially those dominated by visible-detail compensation, emphasize visual clarity while neglecting thermal fidelity. As a result, fused images often suffer from temperature inversion, grayscale drift, or energy leakage: the thermal ordering of pixels may be reversed, or local energy may become blurred. Such artifacts reduce the interpretability of fused images in temperature-sensitive applications.
To address the retention of thermal radiation information in infrared images, we propose a Heat-Consistency Loss function. This loss function, in a self-supervised manner, explicitly guides the model during training to preserve the temperature ranking structure, local energy distribution, and the significance of heat source regions in the infrared images. By doing so, it maintains the thermal radiation information in a physically reasonable way, ensuring that the fused image achieves visual clarity while maintaining interpretable thermal perception.
The overall loss function is composed of two complementary sub-terms, the Weighted Ranking Preservation Loss $L_{rank}$ and the Local Energy Preservation Loss $L_{energy}$, which together constrain the fused image from both ordinal and energetic perspectives, and are jointly defined as follows:
$$L_{heat} = \lambda_1 L_{rank} + \lambda_2 L_{energy}$$
where $\lambda_1$ and $\lambda_2$ are the balance coefficients of the two sub-losses.
From a theoretical standpoint, this design is grounded in the physical interpretability of infrared imaging. In infrared images, pixel intensity correlates monotonically with object temperature; thus, preserving the ranking order of pixel values ensures that the fused output maintains consistent thermal semantics—a form of monotonic mapping constraint. However, ranking preservation alone cannot guarantee radiometric accuracy, as it may allow global brightness drift. Therefore, the local energy preservation loss complements this by enforcing energy conservation within local neighborhoods, ensuring that average intensity distributions remain consistent with the thermal domain.
Together, these two constraints establish a balance between ordinal consistency and radiometric fidelity, providing both perceptual and physical interpretability for the fusion process. To further validate the rationality of this formulation, the mathematical definitions and derivations of the two sub-loss components are presented below.
a. Weighted Ranking Preservation Loss
To maintain the relative ranking relationship of pixel grayscale values in infrared images, we construct a pixel-pair weighted ranking preservation loss:
In this term, $I_{IR}(i)$ and $I_{F}(i)$ represent the grayscale values of the infrared image and the fused image at pixel $i$, respectively. If the infrared image has a higher grayscale value at $i$ than at $j$ but the fused image has a lower value at $i$ than at $j$ (i.e., the ordering is reversed), the loss term is greater than zero. This loss penalizes inconsistent ordering, encouraging the fused image to maintain a consistent relative temperature structure in terms of thermal perception. The set $\rho$ typically consists of randomly sampled pixel pairs or neighboring pixel pairs within the same image, ensuring low computational cost and local robustness. The saliency-guided weighting factor $\omega_{i,j}$, which highlights the importance of heat source regions, is calculated as follows:
where $S_i$ and $S_j$ represent the saliency values at the two pixels of the pair $(i, j)$, respectively. The pixel saliency map is introduced to highlight the important heat source regions in the infrared image. The pixel saliency map $S$ is generated based on the grayscale of the infrared image and is defined as follows:
$$S = \sigma\big(k\,(I_{IR} - t)\big)$$
Here, $\sigma(\cdot)$ denotes the Sigmoid function; $k$ is the stretching factor that controls the steepness of the response boundary (set to 10); $t$ is the significance threshold used to distinguish whether the grayscale region is a “heat source” region, which is set to 0.6; and $I_{IR}$ represents the normalized grayscale values of the infrared image. The saliency-weighting mechanism ensures that greater loss is incurred when ranking errors occur in heat source regions, guiding the model to prioritize learning the ranking consistency of key target regions.
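The behaviour of the saliency weighting can be illustrated with a small plain-Python sketch. It is not the training code: images are flattened lists of normalized grayscale values, the function names are hypothetical, the pair weight is assumed to be the mean of the two saliency values, and the ranking penalty is assumed to be a hinge on the sign-corrected difference of the fused values. Only the saliency map, a Sigmoid with k = 10 and t = 0.6, follows the definition above directly.

```python
import math


def saliency(ir, k=10.0, t=0.6):
    """Sigmoid saliency of normalized infrared values (k: steepness, t: threshold)."""
    return [1.0 / (1.0 + math.exp(-k * (v - t))) for v in ir]


def ranking_loss(ir, fused, pairs, k=10.0, t=0.6):
    """Saliency-weighted penalty on pixel pairs whose order is reversed.

    ir, fused: flattened images with values in [0, 1]; pairs: list of (i, j).
    """
    s = saliency(ir, k, t)
    total = 0.0
    for i, j in pairs:
        # +1 if the infrared image is brighter at i, -1 if darker, 0 if equal.
        sign = (ir[i] > ir[j]) - (ir[i] < ir[j])
        weight = 0.5 * (s[i] + s[j])  # assumption: mean saliency of the pair
        # Assumption: hinge penalty, positive only when the order is reversed.
        total += weight * max(0.0, -sign * (fused[i] - fused[j]))
    return total / max(len(pairs), 1)


# Toy example: two hot pixels (0, 1) and two cold pixels (2, 3).
ir = [0.9, 0.8, 0.2, 0.1]
pairs = [(0, 1), (2, 3)]
hot_swapped = [0.8, 0.9, 0.2, 0.1]   # order reversed inside the heat source
cold_swapped = [0.9, 0.8, 0.1, 0.2]  # order reversed in the background
print("no inversion:  ", ranking_loss(ir, ir, pairs))
print("hot inversion: ", ranking_loss(ir, hot_swapped, pairs))
print("cold inversion:", ranking_loss(ir, cold_swapped, pairs))
```

With these assumptions, an inversion of the same magnitude costs far more inside the heat source than in the cold background, which is the intended effect of the weighting.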
b. Local Energy Preservation Loss
To maintain the energy consistency within local regions of the image, that is, to ensure that the total amount of thermal radiation does not shift, we design the following energy preservation term:
$$L_{energy} = \frac{1}{K} \sum_{k=1}^{K} \left( \mu_k^{IR} - \mu_k^{F} \right)^2$$
where $\mu_k^{IR}$ represents the grayscale mean of the $k$-th local window in the infrared image, $\mu_k^{F}$ denotes the grayscale mean of the corresponding window in the fused image, and $K$ is the number of local windows. Essentially, this term penalizes brightness changes at the window level to prevent heat leakage, overexposure, or underexposure distortion in the fused image. It ensures that the “thermal intensity” in the local regions of the fused image is consistent with that of the infrared image, providing a reliable basis for subsequent tasks, such as local temperature rise detection, that rely on regional temperature judgments.
In our implementation, the local energy is computed using a fixed window of size $w \times w$ with $w = 7$. For each pixel location $(i, j)$, we apply a sliding 7 × 7 averaging window centered at $(i, j)$ with a stride of 1, resulting in fully overlapping neighborhoods across the entire image. Zero-padding of $\lfloor w/2 \rfloor = 3$ pixels is used at the borders so that the local energy is defined for all spatial positions. In practice, this operation is implemented as a 2D convolution with a normalized box filter of size 7 × 7 (all ones divided by $w^2$) applied to both the infrared image and the fused image, and the Local Energy Preservation Loss is defined as the mean squared error between the two resulting local mean maps.
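The computation described above can be sketched without a deep learning framework. The following plain-Python version is illustrative only, and the names `box_mean` and `local_energy_loss` are hypothetical. It mirrors the stated settings: a 7 × 7 normalized box filter, stride 1, and zero-padding, so border windows still divide by $w^2$.

```python
def box_mean(img, w=7):
    """Local mean map from a w x w box filter, stride 1, zero-padded borders."""
    pad = w // 2
    rows, cols = len(img), len(img[0])
    out = [[0.0] * cols for _ in range(rows)]
    for r in range(rows):
        for c in range(cols):
            total = 0.0
            # Only in-image pixels contribute; padded positions add zero.
            for rr in range(max(0, r - pad), min(rows, r + pad + 1)):
                for cc in range(max(0, c - pad), min(cols, c + pad + 1)):
                    total += img[rr][cc]
            out[r][c] = total / (w * w)  # always divide by the full window area
    return out


def local_energy_loss(ir, fused, w=7):
    """Mean squared error between the two local mean maps."""
    m_ir, m_f = box_mean(ir, w), box_mean(fused, w)
    n = len(ir) * len(ir[0])
    return sum((a - b) ** 2 for ra, rb in zip(m_ir, m_f) for a, b in zip(ra, rb)) / n


# Toy 9 x 9 example: a horizontal gradient and a uniformly brightened copy.
ir_img = [[c / 8.0 for c in range(9)] for _ in range(9)]
brighter = [[v + 0.1 for v in row] for row in ir_img]
print("identical images: ", local_energy_loss(ir_img, ir_img))
print("brightened fusion:", local_energy_loss(ir_img, brighter))
```

A global brightness shift leaves every pixel ordering intact, so a ranking term alone would not penalize it, whereas the local mean maps differ everywhere and this term becomes positive.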