Article

M3AE-Distill: An Efficient Distilled Model for Medical Vision–Language Downstream Tasks

Xudong Liang, Jiang Xie, Mengfei Zhang and Zhuo Bi
1 School of Computer Engineering and Science, Shanghai University, Shanghai 200444, China
2 School of Information Technology, Shanghai Jian Qiao University, Shanghai 201306, China
* Author to whom correspondence should be addressed.
† These authors contributed equally to this work.
Bioengineering 2025, 12(7), 738; https://doi.org/10.3390/bioengineering12070738
Submission received: 11 April 2025 / Revised: 26 June 2025 / Accepted: 1 July 2025 / Published: 6 July 2025
(This article belongs to the Section Biosignal Processing)

Abstract

The multi-modal masked autoencoder (M3AE) is a widely studied medical vision–language (VL) model that can be applied to various clinical tasks. However, its large parameter count poses challenges for deployment in real-world settings. Knowledge distillation (KD) has proven effective for compressing task-specific uni-modal models, yet its application to medical VL backbone models during pre-training remains underexplored. To address this, M3AE-Distill, a lightweight medical VL model, is proposed to uphold high performance while enhancing efficiency. During pre-training, two key strategies are developed: (1) both hidden state and attention map distillation are employed to guide the student model, and (2) an attention-guided masking strategy is designed to enhance fine-grained image–text alignment. Extensive experiments on five medical VL datasets across three tasks validate the effectiveness of M3AE-Distill. Two student variants, M3AE-Distill-Small and M3AE-Distill-Base, are provided to support a flexible trade-off between efficiency and accuracy. M3AE-Distill-Base consistently outperforms existing models and achieves performance comparable to the teacher model, while delivering 2.11× and 2.61× speedups during inference and fine-tuning, respectively.

1. Introduction

Recent advances in computational resources and the increasing availability of large-scale public datasets have significantly accelerated progress in uni-modal deep learning [1,2,3]. Human perception naturally integrates multi-modal information, including vision, audio signals, and language, to construct a comprehensive understanding of the environment. In the medical domain, data are generally multi-modal, such as chest X-ray images accompanied by corresponding clinical reports [4]. Such inherently multi-modal data have driven significant interest in medical vision–language (VL) models, which aim to use multi-modal data sources to enhance modeling capabilities and comprehension [5,6]. These models have demonstrated potential in improving diagnostic precision and reducing clinical workload through applications such as medical image–text retrieval (Med-ITR), which retrieves relevant textual reports given a medical image (image-to-text) or relevant images given a textual query (text-to-image). Another application is medical visual question answering (Med-VQA), which aims to answer clinically relevant questions based on medical images.
In recent studies, medical VL research has trended toward increasingly complex architectures, larger parameter scales, and higher computational demands during pre-training [5,7]. For example, PubMedCLIP [6] and multi-modal masked autoencoders (M3AE) [7] contain 470 M and 347 M parameters, respectively, requiring substantial computing resources and incurring high computational costs. Although these large-scale models achieve strong performance on medical VL tasks, their deployment in real-world clinical settings remains challenging due to their computational overhead. Therefore, exploring model compression techniques for large-scale pre-trained VL models is a promising direction to improve their practicality and scalability in clinical applications.
Knowledge distillation (KD) is a widely adopted model compression technique that transfers knowledge from a large teacher model to a more compact student model [8]. The teacher–student framework is widely used in scenarios where deploying large models is impractical due to computational or memory constraints. In such cases, the teacher model can guide the student to achieve comparable performance while significantly reducing model size and inference cost. While traditional KD methods often rely on mimicking the soft outputs of the teacher, empirical studies indicate that such approaches may result in suboptimal performance, particularly in complex tasks such as medical image analysis. Consequently, recent research has focused on enhancing the distillation process to improve its effectiveness.
Although KD has been extensively studied in uni-modal medical tasks, recent works such as MSKD [9] and DSP-KD [10] have further demonstrated its effectiveness in segmentation and classification by introducing task-specific KD. However, these approaches are typically designed for single-modality settings, which limits their generalization. In the medical VL domain, KD remains relatively underexplored. For example, MHKD-MVQA [11] incorporates KD specifically for Med-VQA, making it less adaptable to other tasks. In contrast, this work introduces KD at the pre-training stage, aiming to construct a generalizable medical VL backbone that enables efficient adaptation to various downstream tasks.
Models such as M3AE [7] have shown promising results on medical VL tasks by adopting a unified pre-training framework. A concise overview of M3AE is provided in Section 2. Despite recent progress in medical VL models, two challenges limit the applicability of current methods. First, existing models adopt complex architectures with large-scale parameters, resulting in high computational costs and limited feasibility for real-world clinical deployment. Second, existing KD approaches in the medical domain are typically designed for specific tasks and uni-modal settings, making them difficult to generalize across diverse VL scenarios.
Therefore, there is a pressing need for a lightweight and generalizable medical VL model that enables efficient adaptation to various downstream tasks. To address these limitations, an efficient VL pre-trained model, M3AE-Distill, is proposed for medical VL tasks. Built upon the M3AE model, M3AE-Distill aims to effectively transfer knowledge from a high-capacity teacher model to a lightweight student model during the pre-training phase. To facilitate this KD process, three key components are developed to distill knowledge from the teacher model into a compact student model. Specifically, M3AE-Distill-Base consistently outperforms existing models and achieves performance comparable to the teacher model, while delivering 2.11× and 2.61× speedups in inference and fine-tuning, respectively. The small variant achieves even higher speedups of 3.51× and 4.83×, albeit with slightly reduced performance. The primary contributions are summarized as follows:
  • KD is integrated into the pre-training pipeline by aligning both attention maps and hidden states. This distillation enables the student to approximate the teacher’s intermediate representations.
  • An attention-guided masking strategy is proposed for the MIM objective. This strategy leverages attention maps from the teacher model to identify and mask semantically salient regions in the image. By reconstructing these regions, the student is encouraged to leverage complementary visual and textual features, thereby facilitating fine-grained cross-modal alignment.

2. Related Work

2.1. Vision–Language Model

Recent VL models in the general domain have shown that unified Transformer architectures can effectively integrate visual and textual information without relying on external object detectors [12,13]. These approaches establish key design principles such as joint multi-modal encoding, masked modeling, and contrastive learning, which have informed the development of domain-specific VL models.
In the medical field, VL models are adapted to better capture domain-specific semantics and structural characteristics. MedViLL [5] incorporates multi-modal attention mechanisms and shows improved alignment between medical images and text. PubMedCLIP [6], a CLIP-based model pre-trained on image–text pairs from PubMed articles, highlights the benefits of contrastive learning in medical VL settings, particularly for improving visual representations for tasks like medical VQA. Among these, M3AE [7] has emerged as a strong baseline for medical VL tasks. It is trained on approximately 300,000 image–text pairs and introduces a unified self-supervised learning framework that combines Masked Language Modeling (MLM) and Masked Image Modeling (MIM). This design enables effective representation learning from both modalities without requiring task-specific supervision. While M3AE performs well across various medical VL tasks, its large number of parameters and complex architecture present challenges for practical deployment.

2.2. Knowledge Distillation

KD aims to transfer knowledge from a large, complex neural network (teacher) to a smaller, more efficient network (student) [8], inspired by the way humans learn from a more knowledgeable teacher. In the vanilla KD approach, the student model is trained to mimic the soft logits produced by the teacher model. However, as these logits only capture information from the final layer, they provide weak supervision, limiting the effectiveness of the distillation process. To enhance supervision, FitNets [14] introduces an intermediate feature-based distillation approach, where the student model learns from the teacher's intermediate feature representations (hints), providing richer guidance beyond output logits. For Transformer-based architectures, self-attention plays a crucial role in capturing long-range dependencies between tokens. Leveraging this insight, TinyBERT [15] proposes an attention-based distillation mechanism to transfer the teacher's linguistic and structural knowledge to the student model. In this study, the teacher–student framework is adopted to retain the strong cross-modal alignment and semantic reasoning capabilities learned by the teacher, while enabling efficient deployment through a lightweight student.

3. Method

The proposed M3AE-Distill model is a distilled version of M3AE, incorporating CLIP-ViT (ViT-B/16) and RoBERTa as uni-modal encoders. The image encoder consists of 12 Transformer layers with a 16×16 patch size and is initialized from CLIP pre-trained weights. Multi-modal feature fusion is achieved through cross-modal encoders.
The training process consists of two stages, as illustrated in Figure 1. In Stage 1 (Figure 1A), the student model is trained on pre-training tasks under the guidance of the teacher model. This stage includes two key components: (1) both attention and hidden state distillation are employed to transfer knowledge from the teacher model to the student, facilitating effective representation learning; and (2) an attention-guided masking strategy is applied to identify and mask semantically salient regions in the image, promoting the fine-grained alignment of visual and textual features. In Stage 2 (Figure 1B), the student model is fine-tuned for downstream medical VL tasks, including Med-VQA, Med-CLS, and Med-ITR. The joint representations learned during pre-training are utilized, while the teacher model is not involved in this stage. This setup highlights the transferability and efficiency of the student model.

3.1. Pre-Training Tasks

The M3AE-Distill model was pre-trained with three tasks: masked language modeling (MLM), masked image modeling (MIM), and image–text matching (ITM). The objectives of these tasks are outlined below.

3.1.1. Masked Language Modeling

MLM is a self-supervised objective that enables the model to learn word features by predicting masked tokens within a given text. In VL models, incorporating visual features into MLM facilitates enhanced textual understanding and cross-modal feature integration.
Given an image–text pair $(x, t)$, where $x$ represents the image and $t$ denotes the text, the text $t$ is tokenized using a RoBERTa tokenizer with a maximum sequence length of 64 tokens, resulting in $t = \{w_1, w_2, \dots, w_n\}$. A total of 15% of the tokens were randomly and uniformly selected for potential masking. This randomness helps the model learn robust contextual representations and avoid overfitting. To avoid over-reliance on local context (i.e., the tendency to focus only on nearby tokens and ignore long-range dependencies), 80% of the masked tokens were replaced with the special token $[\mathrm{MASK}]$, 10% were substituted with random words, and the remaining 10% remained unchanged. The processed text $\tilde{t} = \{w_1, w_2, [\mathrm{MASK}], [\mathrm{MASK}], w_5, \dots, w_n\}$ was encoded by RoBERTa, while the image was processed by CLIP-ViT. These features were fused via a cross-modal encoder to produce a joint representation $z_{x\tilde{t}}$. A fully connected layer followed by a softmax activation was used to produce a probability distribution over the entire vocabulary for each masked token, and the model was optimized using the cross-entropy loss:
$$\mathcal{L}_{\mathrm{mlm}} = \mathbb{E}_{(x, \tilde{t}) \sim D}\, H\big(y_{\mathrm{msk}}^{t},\, p_{\mathrm{msk}}^{t}(x, \tilde{t})\big),$$
where $y_{\mathrm{msk}}^{t}$ denotes the ground-truth tokens for the masked positions, and $p_{\mathrm{msk}}^{t}(x, \tilde{t})$ represents the predicted probability distribution over the vocabulary.
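As a rough illustration of this token-corruption recipe (a sketch, not the authors' released code), the following PyTorch snippet applies the 15% selection and the 80/10/10 replacement rule; `mask_tokens`, `special_ids`, and the commented `mlm_head` are hypothetical names introduced here for clarity.

```python
import torch
import torch.nn.functional as F

def mask_tokens(input_ids, mask_token_id, vocab_size, special_ids, mlm_prob=0.15):
    """Select 15% of tokens; of these, 80% -> [MASK], 10% -> random token, 10% -> unchanged."""
    input_ids = input_ids.clone()
    labels = input_ids.clone()

    # Sample the positions eligible for masking (special tokens are excluded).
    prob = torch.full(input_ids.shape, mlm_prob, device=input_ids.device)
    for sid in special_ids:
        prob[input_ids == sid] = 0.0
    selected = torch.bernoulli(prob).bool()
    labels[~selected] = -100                                   # ignored by the loss

    # 80% of the selected positions are replaced with [MASK].
    masked = torch.bernoulli(torch.full_like(prob, 0.8)).bool() & selected
    input_ids[masked] = mask_token_id

    # Half of the remaining selected positions (10% overall) become random tokens.
    rand_pos = torch.bernoulli(torch.full_like(prob, 0.5)).bool() & selected & ~masked
    input_ids[rand_pos] = torch.randint(vocab_size, input_ids.shape, device=input_ids.device)[rand_pos]
    return input_ids, labels                                   # the rest stay unchanged but are still predicted

# Given the joint features z from the cross-modal encoder:
# logits = mlm_head(z)                                         # (B, n, vocab_size)
# loss_mlm = F.cross_entropy(logits.transpose(1, 2), labels, ignore_index=-100)
```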

3.1.2. Masked Image Modeling

Inspired by the success of the MLM task and aiming to enhance the model’s ability to learn visual context and spatial relationships, MIM was designed to predict masked regions within an image. Here, spatial relationships refer to structural dependencies among image patches, modeled via self-attention.
The image $x$ was partitioned into non-overlapping $16 \times 16$ pixel patches, denoted as $x = \{x_1, x_2, \dots, x_m\}$, where $m$ is the number of patches. Masking was performed at the patch level, where 75% of the patches were randomly selected and masked. The remaining unmasked sequence $\tilde{x} = \{x_1, x_3, x_7, \dots\}$ was fed to CLIP-ViT to generate visual features $z_{\tilde{x}}$. Random masking encourages the model to capture contextual dependencies and prevents overfitting to fixed spatial patterns. The final cross-modal representation was obtained by concatenating the visual and textual features, which were then processed by a cross-modal encoder to generate joint features. A decoder was subsequently used to predict the missing patches based on these joint features. The objective was to minimize the reconstruction error, using mean squared error (MSE) between the predicted and original image patches:
$$\mathcal{L}_{\mathrm{mim}} = \mathbb{E}_{(\tilde{x}, t) \sim D}\, H\big(y_{\mathrm{msk}}^{x},\, p_{\mathrm{msk}}^{x}(\tilde{x}, t)\big),$$
where $y_{\mathrm{msk}}^{x}$ represents the ground-truth pixels of the masked patches, and $p_{\mathrm{msk}}^{x}(\tilde{x}, t)$ denotes the predicted pixel values.
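A minimal sketch of MAE-style random patch masking at a 75% ratio is shown below; this is an illustration under stated assumptions rather than the authors' implementation, and `patch_embeds` is a hypothetical tensor of already-embedded patches.

```python
import torch

def random_patch_mask(patch_embeds, mask_ratio=0.75):
    """Randomly keep (1 - mask_ratio) of the patches; patch_embeds has shape (B, m, d)."""
    B, m, d = patch_embeds.shape
    n_keep = int(m * (1.0 - mask_ratio))

    noise = torch.rand(B, m, device=patch_embeds.device)       # one random score per patch
    shuffle_idx = torch.argsort(noise, dim=1)                  # patches in random order
    restore_idx = torch.argsort(shuffle_idx, dim=1)            # undoes the shuffle later

    keep_idx = shuffle_idx[:, :n_keep]
    visible = torch.gather(patch_embeds, 1, keep_idx.unsqueeze(-1).expand(-1, -1, d))

    mask = torch.ones(B, m, device=patch_embeds.device)        # 1 = masked, 0 = visible
    mask[:, :n_keep] = 0
    mask = torch.gather(mask, 1, restore_idx)                  # back to the original patch order
    return visible, mask, restore_idx

# Reconstruction loss on masked patches only, with `pred` and `target` of shape (B, m, p):
# loss_mim = (((pred - target) ** 2).mean(dim=-1) * mask).sum() / mask.sum()
```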

3.1.3. Image–Text Matching

To enhance cross-modal understanding ability, ITM loss was employed to predict the semantic alignment between image and text. Semantic alignment ensures that the paired image and text representations correspond at a conceptual level. This is critical for downstream tasks such as Med-ITR and Med-VQA.
The training samples for ITM consist of both positive and negative pairs. Positive samples are matched image–text pairs from the dataset, while negative samples are generated by substituting either the image or the text with a randomly selected alternative. This random selection is essential as it introduces diversity and unpredictability into the negative pairs, preventing the model from overfitting to specific patterns and encouraging it to learn robust semantic relationships. Then, the model computes the above complete image–text pairs to obtain cross-modal representation, followed by a classification head to predict the semantic alignment probability. The ITM loss is formulated using binary cross-entropy:
$$\mathcal{L}_{\mathrm{itm}} = \mathbb{E}_{(x, t) \sim D}\, H\big(y_{\mathrm{itm}},\, p_{\mathrm{itm}}(x, t)\big),$$
where $y_{\mathrm{itm}}$ indicates whether the image–text pair is correctly matched (1 for positive pairs, 0 for negative pairs).
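The snippet below sketches one common way to build such positive/negative pairs in-batch, here by rolling the text batch so that every image meets a mismatched caption; this is a simplification of the random substitution described above, and `cross_modal_encoder` and `itm_head` are hypothetical names.

```python
import torch
import torch.nn.functional as F

def build_itm_batch(image_feats, text_feats):
    """Original pairs are positives; negatives pair each image with another sample's text."""
    B = image_feats.size(0)
    neg_text = text_feats.roll(shifts=1, dims=0)               # mismatched captions (assumes B > 1)

    itm_images = torch.cat([image_feats, image_feats], dim=0)
    itm_texts = torch.cat([text_feats, neg_text], dim=0)
    itm_labels = torch.cat([torch.ones(B), torch.zeros(B)]).long().to(image_feats.device)
    return itm_images, itm_texts, itm_labels

# logits = itm_head(cross_modal_encoder(itm_images, itm_texts))   # (2B, 2)
# loss_itm = F.cross_entropy(logits, itm_labels)
```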
Finally, the overall pre-training loss of M3AE-Distill is defined as
$$\mathcal{L}_{\mathrm{pre\text{-}training}} = \mathcal{L}_{\mathrm{mlm}} + \mathcal{L}_{\mathrm{mim}} + \mathcal{L}_{\mathrm{itm}}$$

3.2. Knowledge Distillation

KD is a widely adopted model compression technique designed to transfer knowledge from a high-capacity teacher model to a more compact student model, thereby improving computational efficiency. In this study, KD is applied during the pre-training stage, rather than during downstream task-specific fine-tuning, allowing the student to learn transferable VL representations. The student model adopts a pure Transformer architecture to model uni-modal and multi-modal features. In this context, hidden states encode rich semantic representations from visual and textual modalities, while attention maps capture structural and alignment patterns between modalities. Following the approach in [16], both hidden layer distillation and attention map distillation are employed to facilitate effective knowledge transfer from the teacher to the student model.

3.2.1. Hidden Layer Distillation

Hidden layer distillation [8] minimizes the discrepancy between intermediate feature representations of the teacher and student models, enabling the student to acquire deep representations from the teacher. Formally, let $H^{T} \in \mathbb{R}^{L \times d}$ and $H^{S} \in \mathbb{R}^{l \times d}$ denote the intermediate feature representations of the teacher and student models, respectively, where $L$ and $l$ represent the number of layers in each model, and $d$ denotes the hidden layer dimension. From the teacher model, $l$ hidden layers are uniformly sampled to align with all student layers. The similarity between corresponding layers is enforced using the MSE loss:
$$\mathcal{L}_{\mathrm{hid}} = \sum_{i=1}^{l} \mathrm{MSE}\big(H_{i}^{S},\, H_{m(i)}^{T}\big),$$
where $m(i)$ indicates the index of the teacher layer selected to align with the $i$-th student layer via uniform sampling across the teacher's depth.
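A compact sketch of this layer-mapping and MSE objective is given below, assuming both models expose their per-layer hidden states as lists of (B, seq, d) tensors with matching hidden size; projection layers would be needed if the dimensions differed.

```python
import torch.nn.functional as F

def hidden_state_distill_loss(student_hiddens, teacher_hiddens):
    """Uniformly map the l student layers onto the L teacher layers and sum the MSE terms."""
    l, L = len(student_hiddens), len(teacher_hiddens)
    stride = L // l                                            # e.g. 12 teacher / 6 student -> 2
    loss = 0.0
    for i, h_s in enumerate(student_hiddens):
        h_t = teacher_hiddens[(i + 1) * stride - 1]            # last teacher layer of the i-th interval
        loss = loss + F.mse_loss(h_s, h_t.detach())            # the teacher is frozen
    return loss
```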

3.2.2. Attention Distillation

Attention distillation [15] encourages the student model to mimic the self-attention distributions of the teacher model, thereby improving its representational capacity. In Transformer-based architectures, the attention weight matrix A is computed as
$$A = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}}\right),$$
where $d_k$ denotes the key dimension, and $Q$ and $K$ represent the query and key matrices, respectively. To ensure that the student model effectively learns the teacher's attention patterns, attention maps from the student model are aligned with those from the teacher model. Specifically, teacher layers are evenly sampled by selecting one layer every few layers, and each student layer is aligned with the last layer in the corresponding interval. The attention distillation loss is defined as
$$\mathcal{L}_{\mathrm{attn}} = \frac{1}{h} \sum_{j=1}^{l} \sum_{i=1}^{h} \mathrm{MSE}\big(A_{S}^{(i,j)},\, A_{T}^{(i,m(j))}\big),$$
where $h$ denotes the number of attention heads, $A_{S}^{(i,j)}$ represents the attention matrix of the $i$-th head at the $j$-th student layer, and $A_{T}^{(i,m(j))}$ denotes the attention matrix from the selected teacher layer corresponding to the $j$-th student layer.
For the Small model, a 3:1 mapping is applied in the uni-modal encoders, and a 6:1 mapping is used in the multi-modal encoder. For the Base model, a 2:1 mapping is used for the uni-modal encoders, and a 3:1 mapping is applied in the multi-modal encoder.
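The following sketch mirrors the same interval mapping for attention maps, each of shape (B, h, seq, seq); it is an illustration under the stated k:1 ratios rather than the exact training code.

```python
import torch.nn.functional as F

def attention_distill_loss(student_attns, teacher_attns):
    """Align each student attention map with the last teacher layer of its interval."""
    l, L = len(student_attns), len(teacher_attns)
    k = L // l                                                 # mapping ratio, e.g. 2:1 or 3:1
    loss = 0.0
    for j, a_s in enumerate(student_attns):
        a_t = teacher_attns[k * (j + 1) - 1].detach()          # m(j) = k(j + 1) - 1
        # F.mse_loss averages over batch, heads and positions, absorbing the 1/h factor
        loss = loss + F.mse_loss(a_s, a_t)
    return loss
```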
Thus, the distillation loss is summarized as
$$\mathcal{L}_{\mathrm{distill}} = \mathcal{L}_{\mathrm{hid}} + \mathcal{L}_{\mathrm{attn}}$$

3.3. Attention-Guided Masked Image Modeling

Medical images are characterized by low contrast and high structural similarity between different tissues, which together make the extraction of regions of interest (ROIs) particularly challenging [17]. These characteristics reduce the effectiveness of conventional MIM tasks in the medical domain. Ideally, the MIM task should prioritize these ROIs to encourage the model to focus on both local image features and textual information during reconstruction. In the medical domain, however, conventional random masking is suboptimal: the background regions surrounding the ROIs occupy most of the image and are therefore masked with high probability. Consequently, the MIM task may degenerate into relying on visual features alone to predict masked pixel values, thereby diminishing the contribution of textual information.
Recent studies [18,19] have demonstrated that attention scores in Transformer-based models effectively identify salient image regions and enhance model interpretability. Motivated by this observation, an attention-guided MIM strategy is introduced. While AttMask [20] demonstrates the effectiveness of attention-guided masking in uni-modal image modeling, its direct application to VL models remains limited. Inspired by the core idea of AttMask, the proposed method extends this concept to the VL setting by leveraging attention scores from both teacher and student models to guide selective masking and reconstruction. Furthermore, by progressively increasing the masking ratio on key regions during training, the approach encourages fine-grained cross-modal alignment between visual and textual representations.
The proposed strategy consists of two stages, as illustrated in Figure 2 and Figure 3. The first stage computes an aggregated attention matrix by combining attention maps from both the teacher and student models across uni-modal and multi-modal branches. The second stage uses this matrix to progressively generate masks that prioritize high-attention areas during training, enabling the model to gradually enhance cross-modal alignment.

3.3.1. Stage 1: Attention Score Matrix Computation

As shown in Figure 2, given each image–text pair $(x, t)$, both the teacher and the student model are used to extract uni-modal and multi-modal attention maps. Let $N$ denote the number of image patches. Four attention matrices are computed: the uni-modal attention from the student model $A_{S}^{uni} \in \mathbb{R}^{N \times N}$, the uni-modal attention from the teacher model $A_{T}^{uni} \in \mathbb{R}^{N \times N}$, the cross-modal (text-to-image) attention from the student model $A_{S}^{mul} \in \mathbb{R}^{N \times N}$, and the corresponding cross-modal attention from the teacher model $A_{T}^{mul} \in \mathbb{R}^{N \times N}$. These four matrices are summed and then normalized to produce the final attention guidance matrix:
$$A = \mathrm{Normalize}\big(A_{S}^{uni} + A_{T}^{uni} + A_{S}^{mul} + A_{T}^{mul}\big) \in [0, 1]^{N \times N}$$
This matrix reflects the relative importance of each image region to the model’s overall judgment, with higher attention values indicating more critical regions.
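A minimal sketch of this aggregation step is shown below, under the assumption that the four inputs are head-averaged (B, N, N) attention maps that are reduced to a per-patch saliency score before min-max normalization; the exact reduction used in the paper may differ.

```python
import torch

def aggregate_attention(a_s_uni, a_t_uni, a_s_mul, a_t_mul):
    """Sum the four attention maps and normalize the per-patch saliency to [0, 1]."""
    a = a_s_uni + a_t_uni + a_s_mul + a_t_mul                  # (B, N, N)
    saliency = a.mean(dim=1)                                   # (B, N): attention received by each patch
    a_min = saliency.min(dim=1, keepdim=True).values
    a_max = saliency.max(dim=1, keepdim=True).values
    return (saliency - a_min) / (a_max - a_min + 1e-6)         # (B, N) in [0, 1]
```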

3.3.2. Stage 2: Attention-Guided Progressive Mask Generation

As shown in Figure 3, based on the attention matrix A obtained in Stage 1, this stage performs progressive masking over the image to guide the fine-grained alignment between the image and text. The procedure is as follows:
1.
Sort the attention scores in ascending order to obtain the sorted attention matrix $A_{\mathrm{sort}}$.
2.
Progressively adjust the split ratio $r$, which determines the size of the high-attention group, based on training progress:
$$r = r_{\mathrm{start}} - (r_{\mathrm{start}} - r_{\mathrm{end}}) \cdot \frac{\mathrm{step}}{\mathrm{max\_steps}},$$
where $r_{\mathrm{start}} = 0.95$ means that, at the beginning of training, only the top 5% of patches are treated as high-attention. Over the course of training, $r$ decreases to $r_{\mathrm{end}} = 0.3$, meaning that up to 70% of patches may be included in the high-attention group. In this way, the threshold becomes more inclusive as training progresses, allowing the model to gradually expand its focus beyond the most salient regions.
3.
Divide the patches into high-attention and low-attention groups according to $r$. All patches in the high-attention group are treated as key areas and are masked. For the low-attention group, a proportion of $75\% - r$ is randomly masked to introduce noise and enhance robustness.
4.
The two masked groups are then recombined and restored to their original order, yielding the final binary mask matrix $M \in \{0, 1\}^{N \times N}$.
This strategy enables the mask to initially concentrate on a small set of key regions, encouraging the model to learn localized discriminative features. As training progresses, the masked area expands to cover more high-scoring regions, promoting fine-grained alignment.
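The sketch below illustrates the progressive schedule and grouping step described above. The rule used here for randomly masking the low-attention group (a probability chosen so that the expected overall ratio stays near 75%) is an assumption for illustration, not necessarily the paper's exact formula.

```python
import torch

def attention_guided_mask(attn_scores, step, max_steps,
                          r_start=0.95, r_end=0.3, total_ratio=0.75):
    """attn_scores: (B, N) normalized saliency from Stage 1; returns a (B, N) 0/1 patch mask."""
    B, N = attn_scores.shape
    # Linearly anneal the quantile threshold r from r_start down to r_end.
    r = r_start - (r_start - r_end) * step / max_steps

    threshold = torch.quantile(attn_scores, r, dim=1, keepdim=True)
    high = attn_scores >= threshold                            # high-attention group: always masked

    # Randomly mask part of the low-attention group so the overall ratio stays near total_ratio.
    high_frac = high.float().mean(dim=1, keepdim=True)
    low_prob = (total_ratio - high_frac).clamp(min=0.0) / (1.0 - high_frac + 1e-6)
    low = (torch.rand(B, N, device=attn_scores.device) < low_prob) & ~high

    return (high | low).float()                                # 1 = masked patch, 0 = visible
```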

3.3.3. Case Study

As depicted in Figure 4, for the image–text pair (A), the red bounding box highlights the abnormal region. The abnormal region is only partially masked under the conventional masking strategy (B). In contrast, the proposed attention-guided masking strategy (C) adaptively selects mask regions based on cross-modal attention scores, ensuring that the abnormal region is more likely to be masked while reducing the masking ratio for background areas.
The overall pre-training loss is formulated as
$$\mathcal{L}_{\mathrm{total}} = \mathcal{L}_{\mathrm{pre\text{-}training}} + \mathcal{L}_{\mathrm{distill}}$$

4. Experiments

4.1. Pre-Training Datasets

Following M3AE [7], the ROCO [21] and MedICaT [22] datasets were utilized for pre-training. Both datasets are derived from PubMed and are widely adopted in medical VL research. The ROCO dataset consists of approximately 81,000 radiology image–text pairs, covering modalities such as CT, MRI, X-ray, PET, ultrasound, and angiography. Each image is paired with a caption, and many also include metadata such as UMLS Concept Unique Identifiers (CUIs) and semantic types, enabling structured semantic alignment. The MedICaT dataset contains around 217,000 medical figures extracted from over 130,000 biomedical papers. About 93% of the images are medical (e.g., radiology, histology, endoscopy), and 75% are compound figures. Each figure includes a caption with an average length of 74 tokens, and 74% have additional inline textual references (average 67 tokens), providing rich multi-modal context.

4.2. Downstream Datasets

To evaluate the model’s performance, experiments were conducted across three medical VL tasks using five datasets.
For the Med-VQA task, the VQA-RAD [23], SLAKE [24], and VQA-2019 [25] datasets were employed, with answer accuracy serving as the evaluation metric:
$$\mathrm{Score} = \frac{\text{Correct Predictions}}{\text{Total Predictions}}.$$
For the Med-CLS task, performance was assessed on the MELINDA [26] dataset using the accuracy metric:
$$\mathrm{Acc} = \frac{TP + TN}{TP + TN + FP + FN},$$
where $TP$ represents the number of true positives, $TN$ the number of true negatives, $FP$ the number of false positives, and $FN$ the number of false negatives.
For the Med-ITR task, evaluations were conducted on the ROCO [21] dataset. This task includes both image-to-text retrieval (I2T), which retrieves the most relevant textual description given an input image, and text-to-image retrieval (T2I), which retrieves the most relevant image based on a textual query. Performance is measured using Recall@K (R@K), which indicates the proportion of queries for which the correct match appears within the top K retrieved results:
$$\mathrm{Recall@}K = \frac{\text{Queries with Correct Match in Top } K}{\text{Total Queries}}.$$
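For concreteness, a small PyTorch helper for Recall@K on a query–candidate similarity matrix is sketched below, assuming the ground-truth match of query i is candidate i, as in standard retrieval evaluation.

```python
import torch

def recall_at_k(sim_matrix, k=5):
    """sim_matrix[i, j]: similarity between query i and candidate j; match of i is candidate i."""
    ranks = sim_matrix.argsort(dim=1, descending=True)         # candidates per query, best first
    gt = torch.arange(sim_matrix.size(0), device=sim_matrix.device).unsqueeze(1)
    hit = (ranks[:, :k] == gt).any(dim=1)                      # is the correct match in the top K?
    return hit.float().mean().item()

# Example: recall_at_k(torch.randn(100, 100), k=10) returns a value in [0, 1].
```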

4.3. Implementation Details

4.3.1. Experiment Settings

All experiments were conducted using PyTorch-1.9.0. During pre-training, images were resized to 288 × 288 , and text sequences were either truncated or padded to 64 tokens. For fine-tuning, images were resized to 384 × 384 , except for the image–text retrieval task, where a resolution of 288 × 288 was maintained. All reported metrics were evaluated on the test sets.
AdamW was used as the optimizer for all experiments. The learning rate was set to $1 \times 10^{-5}$, and a linear scheduler with a 10% warm-up ratio was applied, gradually decreasing the learning rate to zero after warm-up. Both pre-training and fine-tuning were conducted on two NVIDIA RTX 3090 GPUs with a total batch size of 32.
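This optimizer and schedule can be reproduced roughly as follows; the snippet is a generic sketch of AdamW with a 10% linear warm-up followed by linear decay, with the model and `total_steps` left as placeholders.

```python
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR

def linear_warmup_schedule(optimizer, total_steps, warmup_ratio=0.1):
    """Linear warm-up to the base LR over the first 10% of steps, then linear decay to zero."""
    warmup_steps = int(total_steps * warmup_ratio)

    def lr_lambda(step):
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))

    return LambdaLR(optimizer, lr_lambda)

# optimizer = AdamW(model.parameters(), lr=1e-5)
# scheduler = linear_warmup_schedule(optimizer, total_steps)
# Call optimizer.step() followed by scheduler.step() after every training step.
```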

4.3.2. Teacher–Student Model Comparison

As shown in Table 1, the two student models designed in this study differ from the teacher model in both architectural complexity and parameter scale, enabling flexible deployment under varying computational resource constraints.
The teacher model adopts a 12-layer Transformer for both image and text encoders, each accompanied by a 6-layer multi-modal interaction module. Specifically, two student variants, M3AE-Distill-Small and M3AE-Distill-Base, are developed with different parameter budgets. The Small version represents the most lightweight configuration, retaining only 5 image encoder layers, 4 text encoder layers, and a single-layer multi-modal module, resulting in a total of 141.6 M parameters. This design prioritizes fast inference and low-resource applicability, while preserving essential representational capacity. The Base version moderately increases the network depth to improve modeling capacity, resulting in 188.8 M parameters—approximately 55% of the teacher model’s size. This version aims to strike a balance between efficiency and performance, particularly in scenarios where moderate computational resources are available.

4.4. Results and Discussion

4.4.1. Medical VQA Results

The experimental results are presented in Table 2. MEVF is pre-trained on approximately 10,000 samples, whereas CPRD is trained on a larger dataset of roughly 20,000 samples. CPRD achieves overall scores of 67.80 and 81.10 on the VQA-RAD and SLAKE datasets, respectively, outperforming MEVF by 1.7 and 2.5 points. These improvements are attributed to the increased scale of pre-training data and the more complex pre-training tasks. Notably, PubMedCLIP, pre-trained on the ROCO dataset comprising about 81,000 samples, achieves an overall score of 72.10 on VQA-RAD, exceeding CPRD by 4.3 points, but shows a slight decrease of 1.0 points on the SLAKE dataset.
M3AE demonstrates the best performance among all baselines, achieving 77.01, 83.25, and 79.87 on VQA-RAD, SLAKE, and VQA-2019, respectively. M3AE-98% maintains 98% of the full model’s performance across all evaluation metrics.
The proposed M3AE-Distill models achieve strong results while maintaining compact model sizes. Specifically, M3AE-Distill-Base delivers competitive performance, ranking second on both VQA-RAD (75.55) and SLAKE (82.16), and achieving 78.46 on VQA-2019. Meanwhile, M3AE-Distill-Small yields reasonable performance (73.45 on VQA-RAD and 80.32 on SLAKE) with substantially fewer parameters.
Compared to its teacher model M3AE, M3AE-Distill-Base reduces the parameter count by approximately 150 M while incurring only marginal performance degradation. The accuracy gap remains minimal on SLAKE and VQA-2019 (82.16 vs. 83.25 and 78.46 vs. 79.87, respectively), and it generally performs better than M3AE-98%. Although M3AE-Distill-Small underperforms the Base variant, it still surpasses earlier models such as CPRD-BAN and PubMedCLIP.

4.4.2. Medical Classification Results

Table 3 reports the performance of models fine-tuned on the MELINDA dataset. As shown, uni-modal models perform relatively poorly, with ResNet-101 and RoBERTa achieving 63.84 and 74.60 accuracy, respectively. In contrast, multi-modal models generally achieve superior results.
Among all evaluated models, M3AE achieves the highest accuracy (78.50), followed closely by M3AE-Distill-Base with 77.37, and M3AE-98% with 76.93. M3AE-Distill-Small achieves 74.31 accuracy, slightly lower than other multi-modal models. This demonstrates its effectiveness in resource-constrained scenarios, offering a favorable balance between model compactness and performance.
Overall, both variants of M3AE-Distill exhibit strong performance on the MELINDA classification task. M3AE-Distill-Base provides a competitive alternative to the full M3AE model with a significant reduction in parameter count, while M3AE-Distill-Small offers a lightweight option suitable for low-resource deployment.

4.4.3. Medical Retrieval Results

The retrieval results on the ROCO dataset are summarized in Table 4. Among all evaluated methods, M3AE achieves the best performance across all retrieval metrics. These results significantly outperform prior models such as ViLT and PubMedCLIP, demonstrating the effectiveness of multi-modal pre-training with large-scale datasets.
The proposed M3AE-Distill-Base ranks second overall, achieving R@1 scores of 14.36 (T2I) and 14.80 (I2T), along with competitive results at R@5 and R@10. Despite its reduced parameter count, it retains most of the retrieval performance of the full M3AE model. In contrast, M3AE-Distill-Small shows a marked drop on the retrieval task (e.g., an R@1 of 4.05 for T2I), despite its acceptable performance on the Med-VQA and Med-CLS tasks. This contrast suggests that retrieval is more sensitive to representation capacity and cross-modal alignment quality. While the Small variant strikes a good balance for simpler tasks like VQA and classification, it appears insufficient for retrieval, which demands stronger fine-grained modeling across modalities.

4.4.4. Efficiency Comparison Results

Table 5 compares the training and inference efficiency of different models. As expected, the student models demonstrate substantial improvements in both throughput and CPU inference latency compared to the teacher model.
M3AE-Distill-Small achieves the highest efficiency, with a training throughput of 409.81 pairs/s and inference throughput of 417.29 pairs/s, representing a 4.83× and 3.51× speedup over the teacher model, respectively. Additionally, its CPU inference latency is reduced to 100 ms per sample—about 3.7× faster than M3AE, which makes it highly suitable for deployment in resource-constrained environments. M3AE-Distill-Base also delivers notable gains, achieving 2.61× training and 2.11× inference speedup, with a CPU latency of 208 ms (1.78× faster).
These results demonstrate that both student variants offer considerable efficiency advantages, with Small prioritizing lightweight deployment and Base offering a more balanced trade-off between performance and speed.
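Throughput and latency figures of this kind can be obtained with a simple timing loop such as the sketch below; the model/batch interface is hypothetical, and the paper's exact measurement protocol may differ.

```python
import time
import torch

@torch.no_grad()
def measure_throughput(model, batch, n_iters=50, device="cuda"):
    """Return processed pairs per second for repeated forward passes on one batch."""
    model.eval().to(device)
    batch = {k: v.to(device) for k, v in batch.items()}
    for _ in range(5):                                         # warm-up iterations
        model(**batch)
    if device == "cuda":
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(n_iters):
        model(**batch)
    if device == "cuda":
        torch.cuda.synchronize()
    elapsed = time.perf_counter() - start
    batch_size = next(iter(batch.values())).size(0)
    return n_iters * batch_size / elapsed
```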

4.4.5. Ablation Studies

Module Ablation

To assess the effectiveness of the proposed modules, models are first pre-trained on the pre-training datasets and then fine-tuned on the SLAKE dataset. The results are shown in Table 6, where the ablations are conducted on M3AE-Distill-Base.
ID 0 corresponds to the M3AE model, serving as an upper-bound reference. ID 1 represents the student model trained with pre-training only, yielding an overall accuracy of 81.26. In ID 2, KD is introduced, resulting in a performance gain to 81.64, demonstrating the effectiveness of distillation in improving student model alignment with the teacher. ID 3 incorporates the proposed attention-guided masking strategy, which further enhances the overall performance to 82.16, while achieving the highest open metric (80.57), highlighting the effectiveness of this module.

Attention Score Source

To evaluate the influence of attention score composition on the masking strategy, three configurations are compared, as shown in Table 7. In the Student-only setting, attention maps are derived solely from the student model; in the Teacher-only setting, they are derived from the frozen teacher model; and in the Student+Teacher setting, they are derived from the element-wise summation of both.
The highest overall accuracy (82.16) is obtained when combining student and teacher attention, indicating that attention information from both models contributes complementary guidance for identifying salient visual regions. In contrast, using attention scores from either the student or teacher alone resulted in slightly lower performance (81.36 and 81.40, respectively), suggesting that a single-source attention signal may be insufficient for optimal mask generation.

4.4.6. Case Study

Feature Distribution Visualization via PCA

To qualitatively assess the representation quality of the student models, 5000 samples from the ROCO dataset are selected, and their feature embeddings from the image encoder, text encoder, and multi-modal encoder are visualized using PCA, as shown in Figure 5.
Clearer cluster boundaries are observed in the image encoder of the Base model, whereas the Small model shows more redundant and dispersed clusters, suggesting weaker visual representation modeling. In the text encoder, both models produce similarly compact distributions, indicating that each is capable of effectively capturing semantic information from the text modality. For the multi-modal encoder, the Base model exhibits well-separated clusters, reflecting better cross-modal alignment and semantic integration, while the Small model displays a more diffuse pattern, revealing limitations in joint representation learning.
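A visualization of this kind can be produced with a few lines of scikit-learn and matplotlib, as in the generic sketch below; the feature matrix and category labels are assumed to be precomputed NumPy arrays.

```python
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

def plot_pca(features, labels, title):
    """Project (n_samples, d) encoder features to 2D with PCA and color points by category."""
    coords = PCA(n_components=2).fit_transform(features)
    plt.figure(figsize=(4, 4))
    plt.scatter(coords[:, 0], coords[:, 1], c=labels, s=3, cmap="tab10")
    plt.title(title)
    plt.tight_layout()
    plt.show()
```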
Six samples from the SLAKE dataset are selected to assess the effectiveness of M3AE-Distill. The cross-attention weights from text to image are employed to visualize the attention distribution in the image branch, as illustrated in Figure 6, where (A) shows the input image and question, (B) shows the prediction from the Small model, and (C) shows the prediction from the Base model.
Across all cases, both models demonstrate reasonable comprehension and grounding capabilities, while the Base model generally exhibits more focused and anatomically accurate attention distributions. For example, in the last sample, the Small model incorrectly identifies the “large bowel” due to diffuse attention, whereas the Base model correctly locates and answers the “spinal cord” with a more precise activation map.

5. Conclusions

This study builds on the M3AE medical VL model and introduces M3AE-Distill, an efficient framework for medical VL tasks, including VQA, classification, and image–text retrieval. Although M3AE benefits from large-scale pre-training datasets and extensive tunable parameters, its deployment is constrained by the substantial computational demands. To address this limitation, KD is employed to compress M3AE into a lightweight model, M3AE-Distill. To enhance knowledge transfer from teacher to student, an attention-guided masking strategy is developed.
The experimental results demonstrate that M3AE-Distill outperforms prior medical VL models. It attains performance comparable to that of the teacher model while significantly enhancing computational efficiency.
For future work, more advanced KD techniques will be explored to further compress VL models without sacrificing performance. Furthermore, improving the robustness and generalization of medical VL models to better adapt to diverse real-world medical datasets remains a key direction for future research.

Limitations and Prospects

Despite its demonstrated effectiveness across several medical VL tasks, the proposed M3AE-Distill framework still faces certain limitations. One key limitation lies in the nature of the pre-training data, which consist exclusively of 2D medical images, such as X-rays and 2D slices of CT or ultrasound. As a result, the model may not effectively capture the spatial continuity and contextual depth required for interpreting full 3D volumetric data, such as complete CT or MRI scans. Addressing this limitation may require adapting the model architecture to incorporate 3D spatial representations or designing pre-training strategies that account for volumetric information. In addition, the current framework assumes access to clean and fully aligned image–text pairs. Extending the model to handle weak supervision, noisy annotations, or partially missing modalities would further increase its practical applicability in clinical settings.

Author Contributions

Conceptualization, X.L. and M.Z.; methodology, X.L.; validation, X.L. and M.Z.; visualization, X.L. and J.X.; writing—original draft preparation, X.L.; writing—review and editing, X.L., J.X. and Z.B.; supervision, J.X. and Z.B.; project administration, J.X. All authors have read and agreed to the published version of the manuscript.

Funding

This research received no external funding.

Institutional Review Board Statement

Not applicable.

Informed Consent Statement

Not applicable.

Acknowledgments

This work is supported by the Shanghai Technical Service Computing Center of Science and Engineering, Shanghai University.

Conflicts of Interest

The authors declare no conflict of interest.

References

  1. He, K.; Zhang, X.; Ren, S.; Sun, J. Deep residual learning for image recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, Las Vegas, NV, USA, 27–30 June 2016; pp. 770–778. [Google Scholar]
  2. Liu, Y.; Ott, M.; Goyal, N.; Du, J.; Joshi, M.; Chen, D.; Levy, O.; Lewis, M.; Zettlemoyer, L.; Stoyanov, V. Roberta: A robustly optimized bert pretraining approach. arXiv 2019, arXiv:1907.11692. [Google Scholar]
  3. Devlin, J.; Chang, M.W.; Lee, K.; Toutanova, K. Bert: Pre-training of deep bidirectional transformers for language understanding. In Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Minneapolis, MN, USA, 2–7 June 2019; Volume 1, pp. 4171–4186. [Google Scholar]
  4. Singh, S.; Karimi, S.; Ho-Shon, K.; Hamey, L. From chest x-rays to radiology reports: A multimodal machine learning approach. In Proceedings of the 2019 Digital Image Computing: Techniques and Applications (DICTA), Perth, Australia, 2–4 December 2019; pp. 1–8. [Google Scholar]
  5. Moon, J.H.; Lee, H.; Shin, W.; Kim, Y.H.; Choi, E. Multi-modal understanding and generation for medical images and text via vision-language pre-training. IEEE J. Biomed. Health Inform. 2022, 26, 6070–6080. [Google Scholar] [CrossRef] [PubMed]
  6. Eslami, S.; Meinel, C.; De Melo, G. PubMedCLIP: How Much Does CLIP Benefit Visual Question Answering in the Medical Domain? In Proceedings of the Findings of the Association for Computational Linguistics: EACL 2023, Dubrovnik, Croatia, 2–6 May 2023; pp. 1151–1163. [Google Scholar]
  7. Chen, Z.; Du, Y.; Hu, J.; Liu, Y.; Li, G.; Wan, X.; Chang, T.H. Multi-modal masked autoencoders for medical vision-and-language pre-training. In Proceedings of the International Conference on Medical Image Computing and Computer-Assisted Intervention, Singapore, 18–22 September 2022; pp. 679–689. [Google Scholar]
  8. Hinton, G.; Vinyals, O.; Dean, J. Distilling the knowledge in a neural network. arXiv 2015, arXiv:1503.02531. [Google Scholar]
  9. Zhao, L.; Qian, X.; Guo, Y.; Song, J.; Hou, J.; Gong, J. MSKD: Structured knowledge distillation for efficient medical image segmentation. Comput. Biol. Med. 2023, 164, 107284. [Google Scholar] [CrossRef] [PubMed]
  10. Zeng, X.; Ji, Z.; Zhang, H.; Chen, R.; Liao, Q.; Wang, J.; Lyu, T.; Zhao, L. DSP-KD: Dual-stage progressive knowledge distillation for skin disease classification. Bioengineering 2024, 11, 70. [Google Scholar] [CrossRef] [PubMed]
  11. Wang, J.; Huang, S.; Du, H.; Qin, Y.; Wang, H.; Zhang, W. MHKD-MVQA: Multimodal hierarchical knowledge distillation for medical visual question answering. In Proceedings of the 2022 IEEE International Conference on Bioinformatics and Biomedicine (BIBM), Las Vegas, NV, USA, 6–8 December 2022; pp. 567–574. [Google Scholar]
  12. Kim, W.; Son, B.; Kim, I. Vilt: Vision-and-language transformer without convolution or region supervision. In Proceedings of the International Conference on Machine Learning, Virtual, 18–24 July 2021; pp. 5583–5594. [Google Scholar]
  13. Li, J.; Selvaraju, R.; Gotmare, A.; Joty, S.; Xiong, C.; Hoi, S.C.H. Align before fuse: Vision and language representation learning with momentum distillation. Adv. Neural Inf. Process. Syst. 2021, 34, 9694–9705. [Google Scholar]
  14. Romero, A.; Ballas, N.; Kahou, S.E.; Chassang, A.; Gatta, C.; Bengio, Y. Fitnets: Hints for thin deep nets. arXiv 2014, arXiv:1412.6550. [Google Scholar]
  15. Jiao, X.; Yin, Y.; Shang, L.; Jiang, X.; Chen, X.; Li, L.; Wang, F.; Liu, Q. Tinybert: Distilling bert for natural language understanding. arXiv 2019, arXiv:1909.10351. [Google Scholar]
  16. Wang, T.; Zhou, W.; Zeng, Y.; Zhang, X. Efficientvlm: Fast and accurate vision-language models via knowledge distillation and modal-adaptive pruning. arXiv 2022, arXiv:2210.07795. [Google Scholar]
  17. Xie, X.; Pan, X.; Zhang, W.; An, J. A context hierarchical integrated network for medical image segmentation. Comput. Electr. Eng. 2022, 101, 108029. [Google Scholar] [CrossRef]
  18. Han, Y.; Holste, G.; Ding, Y.; Tewfik, A.; Peng, Y.; Wang, Z. Radiomics-guided global-local transformer for weakly supervised pathology localization in chest X-rays. IEEE Trans. Med. Imaging 2022, 42, 750–761. [Google Scholar] [CrossRef] [PubMed]
  19. Leem, S.; Seo, H. Attention guided CAM: Visual explanations of vision transformer guided by self-attention. In Proceedings of the AAAI Conference on Artificial Intelligence, Vancouver, BC, Canada, 26–27 February 2024; Volume 38, pp. 2956–2964. [Google Scholar]
  20. Kakogeorgiou, I.; Gidaris, S.; Psomas, B.; Avrithis, Y.; Bursuc, A.; Karantzalos, K.; Komodakis, N. What to hide from your students: Attention-guided masked image modeling. In Proceedings of the European Conference on Computer Vision, Tel Aviv, Israel, 23–27 October 2022; pp. 300–318. [Google Scholar]
  21. Pelka, O.; Koitka, S.; Rückert, J.; Nensa, F.; Friedrich, C.M. Radiology Objects in COntext (ROCO): A Multimodal Image Dataset. In Proceedings of the International Conference on Medical Image Computing and Computer-Assisted Intervention (MICCAI) Workshop, Granada, Spain, 16–20 September 2018; pp. 180–189. [Google Scholar]
  22. Subramanian, S.; Wang, L.L.; Mehta, S.; Bogin, B.; van Zuylen, M.; Parasa, S.; Singh, S.; Gardner, M.; Hajishirzi, H. Medicat: A dataset of medical images, captions, and textual references. arXiv 2020, arXiv:2010.06000. [Google Scholar]
  23. Lau, J.J.; Gayen, S.; Ben Abacha, A.; Demner-Fushman, D. A dataset of clinically generated visual questions and answers about radiology images. Sci. Data 2018, 5, 1180251. [Google Scholar] [CrossRef] [PubMed]
  24. Liu, B.; Zhan, L.M.; Xu, L.; Ma, L.; Yang, Y.; Wu, X.M. Slake: A semantically-labeled knowledge-enhanced dataset for medical visual question answering. In Proceedings of the 2021 IEEE 18th International Symposium on Biomedical Imaging (ISBI), Nice, France, 13–16 April 2021; pp. 1650–1654. [Google Scholar]
  25. Ben Abacha, A.; Hasan, S.A.; Datla, V.V.; Demner-Fushman, D.; Müller, H. Vqa-med: Overview of the medical visual question answering task at imageclef 2019. In Proceedings of the CLEF (Conference and Labs of the Evaluation Forum) 2019 Working Notes, Lugano, Switzerland, 9–12 September 2019. [Google Scholar]
  26. Wu, T.L.; Singh, S.; Paul, S.; Burns, G.; Peng, N. Melinda: A multimodal dataset for biomedical experiment method classification. In Proceedings of the AAAI Conference on Artificial Intelligence, Virtual, 19–21 May 2021; Volume 35, pp. 14076–14084. [Google Scholar]
  27. Nguyen, B.D.; Do, T.T.; Nguyen, B.X.; Do, T.; Tjiputra, E.; Tran, Q.D. Overcoming data limitation in medical visual question answering. In Proceedings of the Medical Image Computing and Computer Assisted Intervention–MICCAI 2019: 22nd International Conference, Shenzhen, China, 13–17 October 2019; pp. 522–530. [Google Scholar]
  28. Liu, B.; Zhan, L.M.; Wu, X.M. Contrastive pre-training and representation distillation for medical visual question answering based on radiology images. In Proceedings of the Medical Image Computing and Computer Assisted Intervention—MICCAI 2021: 24th International Conference, Strasbourg, France, 27 September–1 October 2021; pp. 210–220. [Google Scholar]
  29. Yang, Z.; He, X.; Gao, J.; Deng, L.; Smola, A. Stacked attention networks for image question answering. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, Las Vegas, NV, USA, 27–30 June 2016; pp. 21–29. [Google Scholar]
Figure 1. Overview of the M3AE-Distill model.
Figure 2. Computation of the aggregated attention score matrix.
Figure 3. Attention-guided masking strategy.
Figure 4. Comparison of random masking and attention-guided masking strategies.
Figure 5. PCA visualizations of image, text, and multi-modal embeddings.
Figure 6. Case study on the SLAKE dataset. Attention maps highlight the regions in the image attended to by the model in response to the question.
Table 1. Comparison of model architectures and parameter counts.

| Model Name | Uni-Modal Image (Layers/Params) | Uni-Modal Text (Layers/Params) | Multi-Modal Image (Layers/Params) | Multi-Modal Text (Layers/Params) | Total Params (M) |
|---|---|---|---|---|---|
| Teacher (M3AE) | 12 / 104 M | 12 / 124 M | 6 / 56.7 M | 6 / 56.7 M | 341.4 M |
| Student (Small) | 5 / 54.7 M | 4 / 67.9 M | 1 / 9.5 M | 1 / 9.5 M | 141.6 M |
| Student (Base) | 7 / 68.9 M | 6 / 82.1 M | 2 / 18.9 M | 2 / 18.9 M | 188.8 M |
Table 2. VQA results on the VQA-RAD, SLAKE, and VQA-2019 datasets (Best; Second-Best).

| Methods | VQA-RAD Open | VQA-RAD Closed | VQA-RAD Overall | SLAKE Open | SLAKE Closed | SLAKE Overall | VQA-2019 Overall |
|---|---|---|---|---|---|---|---|
| MEVF-SAN [27] | 49.20 | 73.90 | 64.10 | 75.30 | 78.40 | 76.50 | 68.90 |
| MEVF-BAN [27] | 49.20 | 77.20 | 66.10 | 77.80 | 79.80 | 78.60 | 77.86 |
| CPRD-BAN [28] | 52.50 | 77.90 | 67.80 | 79.50 | 83.40 | 81.10 | - |
| PubMedCLIP [6] | 60.10 | 80.00 | 72.10 | 78.40 | 82.50 | 80.10 | - |
| M3AE [7] | 67.23 | 83.46 | 77.01 | 80.31 | 87.82 | 83.25 | 79.87 |
| M3AE-98% [7] | 65.89 | 81.79 | 75.47 | 78.70 | 86.06 | 81.59 | 79.27 |
| M3AE-Distill-Small | 64.25 | 79.49 | 73.45 | 78.17 | 83.65 | 80.32 | 74.07 |
| M3AE-Distill-Base | 65.92 | 81.87 | 75.55 | 80.57 | 84.62 | 82.16 | 78.46 |
Table 3. Classification results on the MELINDA dataset (Best; Second-Best).

| Methods | Modality | Acc |
|---|---|---|
| ResNet-101 [1] | Image | 63.84 |
| RoBERTa [2] | Text | 74.60 |
| NLF [26] | Image + Text | 76.60 |
| SAN [29] | Image + Text | 72.30 |
| M3AE [7] | Image + Text | 78.50 |
| M3AE-98% [7] | Image + Text | 76.93 |
| M3AE-Distill-Small | Image + Text | 74.31 |
| M3AE-Distill-Base | Image + Text | 77.37 |
Table 4. Image–text retrieval results on the ROCO dataset (Best; Second-Best).

| Methods | T2I R@1 | T2I R@5 | T2I R@10 | I2T R@1 | I2T R@5 | I2T R@10 |
|---|---|---|---|---|---|---|
| ViLT [12] | 9.75 | 28.95 | 41.40 | 11.90 | 31.90 | 43.20 |
| PubMedCLIP * [6] | 8.61 | 26.73 | 38.49 | 8.16 | 25.78 | 38.24 |
| M3AE * [7] | 16.96 | 46.47 | 61.33 | 17.65 | 46.40 | 60.95 |
| M3AE-Distill-Small | 4.05 | 14.96 | 24.36 | 5.40 | 19.95 | 29.20 |
| M3AE-Distill-Base | 14.36 | 37.17 | 51.23 | 14.80 | 36.90 | 50.20 |

* denotes our reproduced results.
Table 5. Training and inference efficiency comparison of different models.

| Model Variant | Training (pairs/s) | Inference (pairs/s) | CPU Inference (ms/pair) | Speedup (Train/Inference/CPU) |
|---|---|---|---|---|
| Teacher (M3AE) | 84.77 | 119.07 | 370 | 1.00 / 1.00 / 1.00 |
| Student (Small) | 409.81 | 417.29 | 100 | 4.83 / 3.51 / 3.70 |
| Student (Base) | 221.61 | 251.47 | 208 | 2.61 / 2.11 / 1.78 |
Table 6. Ablation experiments of the employed modules (Best; Second-Best).

| ID | Strategy | Open | Closed | Overall |
|---|---|---|---|---|
| 0 | M3AE (Teacher) | 80.31 | 87.82 | 83.25 |
| 1 | Pre-training | 79.10 | 84.62 | 81.26 |
| 2 | Pre-training + KD | 79.26 | 85.37 | 81.64 |
| 3 | Pre-training + KD + Attention Mask | 80.57 | 84.62 | 82.16 |
Table 7. Ablation experiments of the attention score source (Best; Second-Best).

| Attention Score Type | Open | Closed | Overall |
|---|---|---|---|
| Student-only | 80.03 | 83.41 | 81.36 |
| Teacher-only | 80.03 | 83.53 | 81.40 |
| Student + Teacher | 80.57 | 84.62 | 82.16 |
Disclaimer/Publisher’s Note: The statements, opinions and data contained in all publications are solely those of the individual author(s) and contributor(s) and not of MDPI and/or the editor(s). MDPI and/or the editor(s) disclaim responsibility for any injury to people or property resulting from any ideas, methods, instructions or products referred to in the content.

