To enhance clarity, the overall framework is divided into smaller sub-modules, each with its own specific role in the algorithm’s functionality.
Figure 2 illustrates the Dynamic Sample Reweighting Strategy (DSR), responsible for processing unlabeled data and adaptively adjusting sample weights. This helps to improve pseudo-label quality and reduce the impact of noisy samples on model training. Next,
Figure 3 presents the Medical Multi-scale Feature Fusion Network (MedFuseNet), which integrates feature maps of unlabeled images at multiple scales. The fusion of low-level and high-level features enhances the classification accuracy, providing rich representations for further learning tasks. Finally,
Figure 4 depicts the Pseudo-label Guided Contrastive Learning (PGC) module. Here, pseudo-labels are used to define positive and negative sample pairs for contrastive learning, the feature representations are updated with momentum, dynamic weights are applied to the contrastive loss, and the feature queue is refreshed, improving learning on the unlabeled data. Executed in this sequence, the sub-modules make the framework operate cohesively, ensuring effective pseudo-label generation, high-quality feature fusion, and efficient contrastive learning; together, they drive the model’s performance on semi-supervised medical image classification tasks.
3.1. Dynamic Sample Reweighting Strategy
In semi-supervised medical image classification, leveraging unlabeled data is essential for improving model performance. However, one of the key challenges lies in the potential for error propagation through noisy and imbalanced data, especially in the context of pseudo-labeling. Pseudo-labeling, while effective in utilizing unlabeled data, often suffers from the problem of noisy pseudo-labels, particularly during the early stages of training when the model’s predictions are less accurate. In such situations, the model may generate incorrect pseudo-labels that are subsequently used to train the model, amplifying errors as training progresses. This error propagation is further exacerbated when the data are imbalanced, as the model may rely too heavily on noisy pseudo-labels from underrepresented classes, leading to poor generalization.
To address this issue, we propose a Dynamic Sample Reweighting strategy that adjusts the weight of unlabeled samples based on their predicted reliability. As illustrated in
Figure 2, by incorporating meta-learning concepts and gradient descent algorithms, our method dynamically reweights the unlabeled data throughout the training process. This approach helps reduce the influence of unreliable pseudo-labels by giving more weight to samples that the model is confident about. Consequently, this strategy mitigates the risk of error propagation, ensuring that the model’s learning process remains stable and reliable, even when faced with noisy or imbalanced data.
Additionally, in our Dynamic Sample Reweighting strategy, the model adjusts the weight of unlabeled samples based on their predicted uncertainty, which is calculated using metrics like entropy. During the early training stages, when pseudo-labels are less reliable, we assign lower weights to uncertain samples, reducing their impact on training and mitigating error propagation. As training progresses and the model’s predictions become more accurate, the weights of more reliable pseudo-labels are increased, allowing these samples to contribute more to the training process. This dynamic adjustment ensures that noisy pseudo-labels have minimal influence initially, while the model gradually learns from more confident pseudo-labels, leading to a stable and reliable learning process that can better generalize to unseen data.
Overall optimization loss function. In semi-supervised medical image classification, we aim to utilize both labeled and unlabeled data to improve model performance. To achieve this, we introduce an overall optimization loss function, denoted $\mathcal{L}(\theta, w)$, which balances the learning process across the different types of data, where $\theta$ represents the model parameters and $w$ represents the weights assigned to unlabeled samples based on their reliability. The optimization process incorporates the dynamic reweighting of unlabeled samples, which is crucial in the presence of noisy or imbalanced data. By adjusting the weights $w$ based on the predicted entropy, the model focuses more on reliable unlabeled data while minimizing the impact of unreliable samples. This dynamic adjustment helps maintain the stability of the model’s learning process, especially in the early training stages, and ensures that the performance gains are robust to noise and data imbalance. The overall optimization loss function can be expressed as
$$\mathcal{L}(\theta, w) = \mathcal{L}_{\text{outer}}(w) + \lambda\, \mathcal{L}_{\text{inner}}(\theta, w), \qquad (1)$$
where $\mathcal{L}_{\text{outer}}$ is the loss function of the outer loop, which focuses on optimizing the meta-parameters $w$; the outer loop adjusts the weights of the unlabeled samples based on their predicted reliability. Meanwhile, $\mathcal{L}_{\text{inner}}$ is the inner loop loss function, responsible for updating the model parameters $\theta$ to minimize the training loss on both labeled and weighted unlabeled data. The term $\lambda$ is a hyperparameter that balances the contribution of the inner and outer loop objectives [34].
Gradient Calculation. Next, we calculate the gradient of the overall optimization loss function $\mathcal{L}(\theta, w)$ with respect to the parameters $\theta$ and $w$. The gradients are essential for updating the model parameters and are computed as
$$\nabla \mathcal{L}(\theta, w) = \left( \frac{\partial \mathcal{L}}{\partial \theta},\; \frac{\partial \mathcal{L}}{\partial w} \right), \qquad (2)$$
where $\nabla \mathcal{L}(\theta, w)$ represents the gradient of the loss function with respect to the model parameters $\theta$ and the meta-parameters $w$. The gradient $\frac{\partial \mathcal{L}}{\partial \theta}$ indicates how the loss function changes with respect to changes in the model parameters, while $\frac{\partial \mathcal{L}}{\partial w}$ shows the sensitivity of the loss to changes in the weights assigned to the unlabeled samples. These gradients are computed with standard backpropagation, allowing iterative optimization over both labeled and unlabeled data during training.
By using this gradient-based approach, the model iteratively adjusts both the parameters and the weights of the samples. This two-loop optimization process not only updates the model to fit reliable data but also continuously refines the pseudo-label reliability assessment throughout training. As a result, it significantly reduces the risk of error amplification that often plagues pseudo-labeling approaches in noisy and imbalanced medical datasets.
Parameter update. We use the gradient descent algorithm to iteratively update the model parameters $\theta$ and the weights $w$ as follows:
$$\theta_{t+1} = \theta_t - \alpha\, \frac{\partial \mathcal{L}_{\text{inner}}}{\partial \theta}, \qquad (3)$$
$$w_{t+1} = w_t - \beta\, \frac{\partial \mathcal{L}_{\text{outer}}}{\partial w}. \qquad (4)$$
In these equations, $\alpha$ and $\beta$ are the learning rates for updating the model parameters $\theta$ and the weights $w$, respectively. The learning rates control the size of each update step, ensuring stable convergence during training.
In practice, this dynamic reweighting strategy directs the model toward unlabeled samples with higher predicted reliability, reducing the influence of noisy pseudo-labels during the early stages of training. By adjusting the sample weights according to the predicted entropy (uncertainty), the model avoids relying heavily on samples with uncertain or erroneous pseudo-labels, limiting error propagation. Early in training, when predictions are less accurate and pseudo-labels are noisy, uncertain samples receive low weights and therefore have little effect on learning. As the model improves and its confidence in the pseudo-labels increases, the weights are dynamically updated so that only the most reliable samples contribute substantially to training. This is particularly important in medical image classification, where noisy and imbalanced data are common, as it keeps the model focused on reliable data and leads to better generalization on unseen samples.
Unlabeled Data Weight. For each unlabeled sample $x_i$, we calculate the predicted entropy to assign a weight based on its predicted uncertainty. The entropy $H$ of the prediction is given by
$$H(x_i) = -\sum_{j=1}^{C} p_{ij} \log p_{ij},$$
where $p_{ij}$ is the predicted probability of the unlabeled sample $x_i$ belonging to class $j$, and $C$ is the total number of classes.
The weight $w_i$ for each unlabeled sample is then defined as a monotonically decreasing function of its predicted entropy, converting the entropy $H(x_i)$ into a confidence score. The entropy reflects the model’s uncertainty about the sample’s class: higher entropy indicates greater uncertainty and results in a lower weight, while lower entropy suggests higher confidence and leads to a higher weight. This dynamic weighting mechanism ensures that the model prioritizes more reliable data and reduces the impact of noisy or incorrect pseudo-labels during training. As the model’s confidence improves, the weights of the unlabeled samples are updated to better reflect their reliability, minimizing the risk of error propagation from uncertain pseudo-labels.
In the early stages of training, when the model’s predictions are less reliable, the entropy is higher for most of the unlabeled samples, meaning they will receive lower weights. This ensures that uncertain samples have minimal influence on the model’s learning in the initial phases. As training progresses and the model’s confidence increases, the entropy decreases for more confident samples, and the weights are gradually adjusted, allowing the model to learn more effectively from reliable pseudo-labels and reducing the influence of noisy ones.
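As a concrete illustration of this entropy-to-weight mapping, the sketch below computes the prediction entropy for a batch of unlabeled samples and converts it to a normalized confidence weight; the specific mapping $w_i = 1 - H(x_i)/\log C$ is our assumption for illustration, not necessarily the exact formula used in the paper.

```python
import torch
import torch.nn.functional as F

def entropy_weights(logits: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Map per-sample prediction entropy to a confidence weight in [0, 1].

    logits: (batch, C) raw outputs for unlabeled samples.  The mapping
    w = 1 - H / log(C) is one plausible (assumed) choice; any monotonically
    decreasing function of the entropy fits the description in the text.
    """
    probs = F.softmax(logits, dim=1)                        # p_ij
    entropy = -(probs * torch.log(probs + eps)).sum(dim=1)  # H(x_i)
    max_entropy = torch.log(torch.tensor(float(logits.size(1))))
    return (1.0 - entropy / max_entropy).clamp(min=0.0, max=1.0)

# Confident predictions receive weights near 1, near-uniform ones near 0.
logits = torch.tensor([[8.0, 0.1, 0.1], [1.0, 1.0, 1.0]])
print(entropy_weights(logits))  # roughly [0.99, 0.00]
```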
The entire Dynamic Sample Reweighting strategy is a meta-learning-based parameter update process. The model takes as input the weights $w$ of the unlabeled samples and the parameters $\theta$ of the neural network, and training proceeds through two nested loops. The inner loop updates the model parameters $\theta$ to minimize the inner loop loss $\mathcal{L}_{\text{inner}}$ given the current weights $w$, as shown in Equation (3). The outer loop then adjusts the weights $w$ based on the updated model parameters $\theta$ to optimize the outer loop loss $\mathcal{L}_{\text{outer}}$, as shown in Equation (4). This dynamic optimization ensures that the weights of the unlabeled samples are continuously updated to reflect their reliability, improving the efficiency of unlabeled data utilization during training. Moreover, the meta-learning framework allows the model to iteratively adjust these weights as training progresses, reducing the reliance on noisy pseudo-labels, particularly in the early training stages. Unlike traditional pseudo-labeling approaches that apply fixed weights, our method continuously updates the weights based on the model’s current confidence, thereby mitigating the risk of error propagation. This adaptive strategy ultimately enhances the performance of the semi-supervised medical image classification model, particularly through the consistency loss and the pseudo-label-guided contrastive learning loss.
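Bringing the pieces together, the following compressed sketch performs one such nested update for a simple linear classifier; the single differentiable inner step, the cross-entropy losses, and the use of the labeled loss as the outer objective are simplifying assumptions rather than the authors’ exact procedure.

```python
import torch
import torch.nn.functional as F

def dsr_step(W, x_lab, y_lab, x_unlab, w, alpha=0.1, beta=0.01):
    """One illustrative DSR update for a linear classifier W of shape (C, D).

    Inner loop (Equation (3)): update the model on labeled data plus
    weighted pseudo-labeled data.  Outer loop (Equation (4)): adjust the
    sample weights w so that the *updated* model performs well on the
    labeled data (an assumed choice of outer objective).
    """
    pseudo = (x_unlab @ W.t()).argmax(dim=1)                      # current pseudo-labels
    loss_inner = F.cross_entropy(x_lab @ W.t(), y_lab) + (
        w * F.cross_entropy(x_unlab @ W.t(), pseudo, reduction="none")
    ).mean()
    g_W = torch.autograd.grad(loss_inner, W, create_graph=True)[0]
    W_new = W - alpha * g_W                                       # differentiable inner step

    loss_outer = F.cross_entropy(x_lab @ W_new.t(), y_lab)
    g_w = torch.autograd.grad(loss_outer, w)[0]
    with torch.no_grad():
        w_new = (w - beta * g_w).clamp(0.0, 1.0)

    return W_new.detach().requires_grad_(True), w_new.requires_grad_(True)

# Toy usage: D = 4 features, C = 3 classes, 8 labeled and 5 unlabeled samples.
W = torch.randn(3, 4, requires_grad=True)
w = torch.full((5,), 0.5, requires_grad=True)
W, w = dsr_step(W, torch.randn(8, 4), torch.randint(0, 3, (8,)), torch.randn(5, 4), w)
```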
3.2. Enhancing Consistency Regularization
We apply consistency regularization to train the model with the goal of producing similar predictions for different perturbations of the same image. In our approach, we remove the Gaussian noise used in the original Mean Teacher (MT) model [10], since such small Gaussian perturbations do not significantly enhance model performance. Instead, we use a diverse set of augmentation strategies: for each unlabeled sample, both weak and strong augmentations are applied before the sample is processed by the student and teacher models. In addition, Dynamic Sample Reweighting ensures the reliability of the data. This enhancement strategy better simulates real data diversity and improves the model’s utilization efficiency of unlabeled data.
Weak and Strong Augmentation. For each unlabeled sample $x_i$, we apply a weak augmentation $A_w$ and a strong augmentation $A_s$:
$$x_i^{w} = A_w(x_i), \qquad x_i^{s} = A_s(x_i).$$
Here, $x_i^{w}$ and $x_i^{s}$ are the augmented versions of the sample $x_i$ under the weak and strong augmentations, respectively. The choice of augmentation strategies follows methods used in prior work on enhancing model robustness [35].
Weak Augmentation: Weak augmentation refers to mild perturbations applied to the original image, such as small geometric transformations or slight changes in color. These augmentations are designed to simulate slight variations in the input while preserving the core structure and information of the image. The goal of weak augmentation is to enforce consistency under minimal perturbation, ensuring that the model learns to generalize well under minor changes.
Strong Augmentation: In contrast, strong augmentation introduces more significant transformations to the image, such as larger crops, rotations, or applying more extreme color distortions. These augmentations are intended to challenge the model more and promote robustness to more substantial variations in the input data. The purpose of strong augmentation is to ensure that the model can maintain consistent predictions even under more substantial perturbations, simulating real-world variations in the data.
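As an illustration of such an augmentation pair, the sketch below builds a weak and a strong pipeline with torchvision; the specific transforms, magnitudes, and image size are assumptions and not the exact policies used in the paper.

```python
from torchvision import transforms

# Hypothetical weak/strong augmentation pair (A_w, A_s); the concrete
# transforms and their magnitudes are illustrative assumptions.
weak_aug = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomAffine(degrees=5, translate=(0.05, 0.05)),
    transforms.ToTensor(),
])

strong_aug = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.5, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
    transforms.RandomRotation(degrees=30),
    transforms.ToTensor(),
])

# For each unlabeled PIL image x_u, the two views would be
#   x_w = weak_aug(x_u)    # weak view (here assumed to feed the teacher)
#   x_s = strong_aug(x_u)  # strong view (here assumed to feed the student)
```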
Model Outputs. The student model and the teacher model generate outputs from the augmented samples:
$$p_i^{s} = f_s\!\left(x_i^{s}; \theta_s\right), \qquad p_i^{t} = f_t\!\left(x_i^{w}; \theta_t\right),$$
where $f_s$ denotes the student model, $f_t$ denotes the teacher model, and $\theta_s$ and $\theta_t$ are their respective parameters. The student model produces its prediction $p_i^{s}$ from the strongly augmented view $x_i^{s}$, while the teacher model produces $p_i^{t}$ from the weakly augmented view $x_i^{w}$.
Enhanced Consistency Loss. The enhanced consistency loss is defined as
$$\mathcal{L}_{\text{cons}} = \frac{1}{N_u} \sum_{i=1}^{N_u} w_i \left\| p_i^{s} - p_i^{t} \right\|^{2},$$
where $N_u$ is the total number of unlabeled samples and $w_i$ is the dynamic weight of each sample. The loss measures the consistency between the predictions of the student and teacher models on the augmented samples, and the dynamic weights $w_i$ adjust the importance of each sample according to its reliability, as determined by our Dynamic Sample Reweighting strategy [36].
The design of the enhanced consistency loss aims to improve the model’s robustness and generalization by leveraging different augmentation strategies. By ensuring that the outputs of the student and teacher models remain consistent under various perturbations, we can utilize unlabeled data more effectively, significantly enhancing classification performance in semi-supervised learning scenarios. This approach not only improves the model’s efficiency in utilizing unlabeled data but also strengthens its adaptability to data diversity. Additionally, the model uses an Exponential Moving Average (EMA) strategy to smooth the teacher’s parameter updates, reducing fluctuations during training and keeping the outputs consistent, which is crucial for effective semi-supervised learning.
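A minimal sketch of this weighted consistency term and the EMA teacher update, assuming an MSE distance between softmax outputs; the function names and decay value are placeholders rather than the paper’s exact settings.

```python
import torch
import torch.nn.functional as F

def consistency_loss(student_logits, teacher_logits, weights):
    """Weighted consistency between student and teacher predictions.

    Assumes an MSE distance over softmax outputs; `weights` are the
    per-sample DSR weights w_i (logits: (B, C), weights: (B,)).
    """
    p_s = F.softmax(student_logits, dim=1)
    p_t = F.softmax(teacher_logits, dim=1).detach()   # no gradient through the teacher
    per_sample = ((p_s - p_t) ** 2).mean(dim=1)
    return (weights * per_sample).mean()

@torch.no_grad()
def ema_update(teacher, student, decay=0.99):
    """Exponential moving average update of the teacher parameters."""
    for p_t, p_s in zip(teacher.parameters(), student.parameters()):
        p_t.mul_(decay).add_(p_s, alpha=1.0 - decay)
```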
3.3. Medical Multi-Scale Feature Fusion Network (MedFuseNet)
To better capture and integrate image features at different levels and optimize unlabeled data feature learning in semi-supervised medical image classification, this paper designs the Medical Multi-scale Feature Fusion Network (MedFuseNet). In MedFuseNet, feature extraction is conducted at multiple levels to capture a wide range of details from lesion images. The low-level features, extracted by the initial layers (conv1), focus on fine-grained details such as edges, textures, and local patterns, which are crucial for identifying subtle variations in the image. The high-level features, derived from the deeper layers (conv3), represent abstract, global information, including the overall shape and contextual relationships of lesions, which contribute to a holistic understanding of the image. This multi-level approach ensures that both detailed and abstract information are effectively utilized, leading to optimal classification results.
To address the challenges in semi-supervised learning (SSL) where labeled data are scarce and unlabeled data may be noisy or imbalanced, MedFuseNet incorporates multi-scale feature fusion in an SSL setting, optimizing the learning of both labeled and unlabeled data. By integrating low-level and high-level features in a manner that prioritizes the most informative parts of the image, we enable the model to better adapt to the complexities of semi-supervised tasks. Specifically, this fusion strategy boosts the model’s ability to classify images with minimal labeled data while leveraging the abundant unlabeled data in the training process.
As shown in Figure 3, the network extracts low- and high-level features using convolutional layers. These features are standardized to a common number of channels using 1 × 1 convolutions, then flattened and multiplied to form a similarity matrix that evaluates the correlation between them. The similarity matrix can be expressed as
$$S_{ij} = P_i^{\top} Q_j, \qquad (10)$$
where $S_{ij}$ represents the similarity between the $i$-th feature of the low-level feature map $P$ and the $j$-th feature of the high-level feature map $Q$, and $N$ is the number of pixels, given by the product of the feature map height $H$ and width $W$, i.e., $N = H \times W$; both maps are flattened along their spatial dimensions before the multiplication.
Specifically, we first multiply the similarity matrix $S$ obtained above with the low-level features, weighting the feature at each position $i$ of the low-level map $P$ to obtain the weighted low-level features. These features are then reshaped to match the shape of the high-level features $Q$. Finally, the adjusted low-level features are element-wise added to the high-level features to obtain the final fused features:
$$F_j = Q_j + \sum_{i} S_{ij}\, P_i,$$
where $Q_j$ represents the feature of the $j$-th channel of the high-level feature map and $F_j$ is the corresponding fused feature. The fused features are further processed to map them to the number of channels required for the classification task, and the feature map size is adjusted through upsampling to match the original size of the input image, thereby producing the final classification result.
In contrast to traditional feature fusion techniques, which typically fuse features from different scales through simple weighted sums or concatenation, MedFuseNet uses a similarity matrix to quantify the correlation between low-level and high-level features. By calculating the similarity matrix (Equation (10)), our design precisely adjusts the fusion of low-level and high-level features, allowing the fused features to better capture both the detailed and global information of the image. Additionally, the 1 × 1 convolution layers and the matrix weighting mechanism in MedFuseNet provide a more effective way to adjust the low-level features, preventing the loss of detailed information in the context of high-level features, a common issue in traditional methods.
Regarding feature fusion, all extracted features from low and high levels are used. We apply 1 × 1 convolutions (convP, convQ) to standardize the channel dimensions of these features, ensuring they are compatible for the subsequent fusion process. The fusion process involves the use of a similarity matrix, which measures the correlation between low-level and high-level features. This matrix is used to weigh and combine the most relevant features from both levels. The weighted low-level features are then adjusted to match the shape of the high-level features before being fused via element-wise addition. This ensures that the fusion captures the most important information from each level, enhancing feature diversity without excluding any important details.
In MedFuseNet, the fusion of low-level and high-level features serves not only to enhance feature representation but also to closely interact with the pseudo-label mechanism in semi-supervised learning. By dynamically generating pseudo-labels for unlabeled data and weighting the low-level and high-level features based on these labels, MedFuseNet more effectively utilizes unlabeled data, thereby boosting the overall model’s generalization ability. Unlike traditional feature fusion methods, MedFuseNet’s pseudo-label-guided fusion significantly improves the model’s adaptability to unlabeled data, making it more robust in real-world medical imaging scenarios with sparse labeled data.
3.4. Pseudo-Label Guided Contrastive Learning (PGC)
Momentum Update Strategy. In this paper, we adopt the momentum update method from MoCo V2 [37] to implement dynamic updates of the queue mechanism. This method introduces a momentum parameter to retain historical update information and combines it with the current gradient information, enabling real-time updates of features and weights and optimizing the feature learning of unlabeled data.
Specifically, we apply momentum updates to both the features of unlabeled data stored in the queue and their corresponding weights, ensuring that the feature information in the queue is effectively updated and maintained as new data features are captured. Let $q_t$ represent the features stored in the queue at time step $t$. The momentum update can be expressed as
$$v_t = \mu\, v_{t-1} + g_t, \qquad q_t = q_{t-1} - \eta\, v_t,$$
where $v_t$ is the momentum vector at time $t$, which integrates the historical momentum $v_{t-1}$ with the current gradient $g_t$. The momentum coefficient $\mu$ governs the retention of historical information, while the gradient $g_t$ at time step $t$ captures the instantaneous rate of change. The learning rate $\eta$ determines the step size of the update. Together, these parameters optimize the learning process and enhance the convergence of the model.
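The sketch below mirrors these update rules for a queue of feature vectors, treating $\mu$, $\eta$, and the gradient $g_t$ as given; how the gradient of the queued features is obtained in practice is left abstract here and is an assumption of this illustration.

```python
import torch

def momentum_update(q_prev, v_prev, grad, mu=0.9, eta=0.05):
    """Momentum-style update of queued features (illustrative sketch).

    v_t = mu * v_{t-1} + g_t      # accumulate historical update information
    q_t = q_{t-1} - eta * v_t     # move the queued features along the momentum
    """
    v_t = mu * v_prev + grad
    q_t = q_prev - eta * v_t
    return q_t, v_t

# Toy usage: a queue holding 4 feature vectors of dimension 8.
queue = torch.randn(4, 8)
velocity = torch.zeros_like(queue)
queue, velocity = momentum_update(queue, velocity, grad=torch.randn(4, 8))
```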
Pseudo-label Guided Contrastive (PGC) Loss Function. For each unlabeled sample $x_i$, we obtain a class probability distribution from the trained model and select the category with the highest probability as the pseudo-label:
$$\hat{y}_i = \arg\max_{c \in \{1, \dots, C\}} p(c \mid x_i),$$
where $c$ denotes the category index. For the current unlabeled sample, the generated pseudo-label is compared with the pseudo-labels of the sample features stored in the queue. If the pseudo-labels match, the queued sample forms a positive pair (same category) with the current sample; otherwise, it forms a negative pair (different categories).
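For clarity, the sketch below assigns pseudo-labels to a batch of unlabeled samples and marks which queue entries form positive pairs with each sample; the queue layout (features plus stored pseudo-labels) is an assumption.

```python
import torch
import torch.nn.functional as F

def pseudo_labels_and_pairs(logits, queue_labels):
    """Assign pseudo-labels and mark queue entries as positive or negative pairs.

    logits:       (B, C) model predictions for the current unlabeled batch.
    queue_labels: (K,)   pseudo-labels previously stored alongside the queued
                  features (the queue layout is an illustrative assumption).
    """
    probs = F.softmax(logits, dim=1)
    pseudo = probs.argmax(dim=1)                                       # argmax_c p(c | x_i)
    # positive_mask[i, k] is True when queue entry k shares sample i's pseudo-label.
    positive_mask = pseudo.unsqueeze(1) == queue_labels.unsqueeze(0)   # (B, K)
    return pseudo, positive_mask
```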
To measure the similarity between samples, a contrastive loss function is constructed. The loss function must satisfy the following conditions:
For each sample $x_i$, when it is similar to its positive sample pairs, the similarity is high and the contrastive loss is small.
When the sample is dissimilar to its positive sample pairs, or is similar to negative sample pairs, the contrastive loss should be high.
To address these requirements, this paper proposes the pseudo-label guided contrastive loss function
$$\mathcal{L}_{\text{PGC}} = -\sum_{i} w_i \log \frac{\sum_{k \in \mathcal{P}(i)} w_k^{+} \exp\!\left(\mathrm{sim}(z_i, z_k)/\tau\right)}{\sum_{a \in \mathcal{A}(i)} w_a \exp\!\left(\mathrm{sim}(z_i, z_a)/\tau\right)},$$
where $z_i$ denotes the feature embedding of sample $x_i$, $\mathrm{sim}(\cdot,\cdot)$ is the similarity between two embeddings, and $w_i$ denotes the weight of the unlabeled sample after dynamic reweighting, which ensures the reliability of the unlabeled data. The weight $w_k^{+}$ signifies the importance of a positive sample pair, emphasizing the similarity between samples of the same category, while $w_a$ is the weight applied to all sample pairs, balancing the similarity differences across categories and minimizing overlap in the feature space. The sets $\mathcal{P}(i)$ and $\mathcal{A}(i)$ correspond to the positive sample pairs and all sample pairs (both positive and negative), respectively. Finally, $\tau$ is the temperature parameter that controls the smoothness of the similarity distribution.
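One plausible instantiation of this weighted, pseudo-label-guided contrastive loss, written against the positive/negative mask from the previous sketch; the exact placement of the weights and the similarity measure are assumptions, so this is an illustration rather than the paper’s definitive formula (here a single set of queue weights stands in for both the positive-pair and all-pair weights).

```python
import torch
import torch.nn.functional as F

def pgc_loss(features, queue_feats, positive_mask, sample_w, queue_w, tau=0.07):
    """Pseudo-label guided contrastive loss (illustrative, assumption-laden form).

    features:      (B, D) L2-normalized embeddings of the current unlabeled batch.
    queue_feats:   (K, D) L2-normalized embeddings stored in the queue.
    positive_mask: (B, K) True where a queue entry shares the sample's pseudo-label.
    sample_w:      (B,)  DSR weights w_i of the current samples.
    queue_w:       (K,)  weights of the queued samples (stands in for both the
                   positive-pair and the all-pair weights).
    """
    sim = features @ queue_feats.t() / tau                 # (B, K) scaled similarities
    exp_sim = queue_w.unsqueeze(0) * torch.exp(sim)        # weight every pair
    pos = (exp_sim * positive_mask).sum(dim=1)             # positive pairs only
    denom = exp_sim.sum(dim=1)                             # all pairs
    per_sample = -torch.log(pos / denom + 1e-8)
    return (sample_w * per_sample).mean()

# Toy usage with random normalized features (B = 4, K = 16, D = 32).
f = F.normalize(torch.randn(4, 32), dim=1)
q = F.normalize(torch.randn(16, 32), dim=1)
mask = torch.randint(0, 2, (4, 16), dtype=torch.bool)
loss = pgc_loss(f, q, mask, sample_w=torch.ones(4), queue_w=torch.ones(16))
```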
As shown in Figure 4, the Pseudo-Label Guided Contrastive (PGC) loss function uses a momentum update strategy to continuously refresh the features in the queue, enhancing the model’s effective use of unlabeled data. Pseudo-labels are used to partition positive and negative sample pairs, which reduces the dependence on labeled data for contrastive learning. Importantly, PGC applies contrastive learning to the features fused by MedFuseNet, optimizing the classification decision boundary and thereby enhancing classification performance.