Article

ProMatch: Semi-Supervised Learning with Prototype Consistency

1 School of Computer Science and Cyber Engineering, Guangzhou University, Guangzhou 510002, China
2 Institute of Artificial Intelligence and Blockchain, Guangzhou University, Guangzhou 511442, China
* Authors to whom correspondence should be addressed.
Mathematics 2023, 11(16), 3537; https://doi.org/10.3390/math11163537
Submission received: 13 July 2023 / Revised: 7 August 2023 / Accepted: 14 August 2023 / Published: 16 August 2023
(This article belongs to the Special Issue Applications of Big Data Analysis and Modeling)

Abstract

Recent state-of-the-art semi-supervised learning (SSL) methods have made significant advances by combining consistency regularization and pseudo-labeling in a joint learning paradigm. The core idea of these methods is to identify consistency targets (pseudo-labels) by selecting high-confidence predicted distributions on weakly augmented unlabeled samples. However, they often suffer from erroneous high-confidence pseudo-labels, which lead to noisy training. This issue has two main causes: (1) when the model is poorly calibrated, the prediction for a single sample may be overconfident and incorrect, and (2) propagating pseudo-labels from unlabeled samples can cause errors to accumulate because of the margin between the pseudo-label and the ground-truth label. To address this problem, we propose a novel consistency criterion called Prototype Consistency (PC), which improves the reliability of pseudo-labeling by leveraging the prototype similarities between labeled and unlabeled samples. First, we instantiate semantic-prototypes (centers of embeddings) and prediction-prototypes (centers of predictions) for each category using memory buffers that store the features of labeled examples. Second, for a given unlabeled sample, we determine the most similar semantic-prototype and prediction-prototype by measuring the similarities between the features of the unlabeled sample and the prototypes of the labeled samples. Finally, instead of using the prediction for the unlabeled sample as the pseudo-label, we use the most similar prediction-prototype as the consistency target, provided that the predicted category of the most similar prediction-prototype, the ground-truth category of the most similar semantic-prototype, and the ground-truth category of the most similar prediction-prototype are identical. 
By combining the PC approach with the techniques developed by the MixMatch family, our proposed ProMatch framework demonstrates significant performance improvements compared to previous algorithms on datasets such as CIFAR-10, CIFAR-100, SVHN, and Mini-ImageNet.

1. Introduction

In the past few decades, machine learning has demonstrated remarkable success across various visual tasks [1,2,3,4,5,6,7,8]. This success can be attributed to advances in learning algorithms and the availability of large labeled datasets. However, in real-world scenarios, constructing large labeled datasets can be costly and is often impractical. Therefore, learning effectively from a limited number of labeled data points has become a major concern. This is where semi-supervised learning (SSL) [9,10,11] comes into play. SSL, an important branch of machine learning theory and algorithms, has emerged as a promising solution to this challenge by leveraging abundant unlabeled data, and it has achieved notable success in the field.
The goal of SSL is to improve generalization performance by exploiting unlabeled data. One widely accepted assumption, known as the Low-density Separation Assumption [12], posits that the decision boundary should lie in low-density regions in order to improve generalization. Building upon this assumption, two prominent paradigms have emerged: pseudo-labeling [11] and consistency regularization [13]. Both have become popular as effective ways to leverage unlabeled data for better generalization. Consistency-regularization-based methods, which are widely adopted in SSL, aim to keep the network outputs stable when the inputs are perturbed with noise [13,14]. However, one limitation of these methods is their heavy reliance on extensive data augmentation, which may restrict their effectiveness in domains such as videos and medical images. Pseudo-labeling-based methods are an alternative approach that has also gained popularity in SSL. These methods use high-confidence predictions on unlabeled samples as training targets (pseudo-labels) [11]. One notable advantage of pseudo-labeling-based methods is their simplicity: they do not require multiple data augmentations and can be easily applied to various domains.
Recently, combinations of pseudo-labeling and consistency regularization have shown promising results [15,16,17,18]. The underlying idea of these methods is to train a classifier using labeled samples and to use its predicted distributions on unlabeled samples as pseudo-labels. These pseudo-labels are typically generated from weakly augmented views [16,19] or by averaging the predictions over multiple augmented views [9]. The objective is then constructed by applying the cross-entropy loss between the pseudo-labels and the predictions on different, strongly augmented views. Notably, the pseudo-labels are often sharpened or converted to hard labels with argmax, so that each instance is assigned to a specific category, which further refines the learning process.
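The sharpening and argmax operations described above can be illustrated with a minimal pure-Python sketch. The function names `average_views`, `sharpen`, and `hard_label` are hypothetical, and the default temperature `T = 0.5` is an assumption chosen for illustration (MixMatch-style temperature sharpening), not a value stated in this paper.

```python
def average_views(dists):
    """Average predicted class distributions from several augmented views of one sample."""
    n = len(dists)
    return [sum(col) / n for col in zip(*dists)]


def sharpen(p, T=0.5):
    """Temperature sharpening: raise each probability to 1/T, then renormalize.

    T < 1 lowers the entropy of p; as T approaches 0 the result approaches one-hot.
    """
    powered = [pi ** (1.0 / T) for pi in p]
    total = sum(powered)
    return [x / total for x in powered]


def hard_label(p):
    """Turn a predicted distribution into a one-hot pseudo-label via argmax."""
    k = max(range(len(p)), key=lambda i: p[i])
    return [1.0 if i == k else 0.0 for i in range(len(p))]


# Usage: two augmented views of the same unlabeled sample (made-up numbers).
views = [[0.5, 0.3, 0.2], [0.7, 0.2, 0.1]]
soft_target = sharpen(average_views(views))     # soft (sharpened) pseudo-label
hard_target = hard_label(average_views(views))  # one-hot pseudo-label
```

Sharpening keeps a soft target whose entropy is reduced, whereas argmax discards all but the top class; both push the model toward confident predictions on unlabeled data.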
Building upon this concept, several methods have been proposed. MixMatch [9] adopts a sharpened, averaged prediction over multiple augmented views as the pseudo-label and incorporates the mix-up trick [15] to enhance the quality of pseudo-labels. ReMixMatch [16] improves upon this idea by generating pseudo-labels from weakly augmented views. Additionally, it introduces a distribution alignment strategy, which aligns the distribution of pseudo-labels with the distribution of ground-truth class labels. FixMatch [19] simplifies these ideas by employing a confidence threshold to select high-confidence pseudo-labels; it achieved state-of-the-art performance among augmentation-anchoring-based methods. SimPLE [20] extends previous approaches with a novel unsupervised objective called Pair Loss, which minimizes the statistical distance between high-confidence pseudo-labels whose similarity exceeds a specific threshold. SimMatch [21], on the other hand, simultaneously matches similarity relationships in both the semantic and instance spaces across different augmentations, and it lets the semantic and instance pseudo-labels interact through a memory buffer. These methods represent significant advances in semi-supervised learning by effectively utilizing pseudo-labels and incorporating various strategies to improve model performance.
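The mix-up step used by MixMatch can be sketched as follows. This is a minimal illustration on flat lists, assuming the MixMatch convention of drawing λ from Beta(α, α) and keeping λ' = max(λ, 1 − λ); the function name `mixmatch_mixup` is hypothetical, and the default α = 0.75 mirrors the α listed in Section 4.3, which we assume is the mix-up Beta parameter.

```python
import random


def mixmatch_mixup(x1, y1, x2, y2, alpha=0.75, rng=random):
    """Mix two (input, label) pairs in the MixMatch style.

    lam ~ Beta(alpha, alpha), then lam = max(lam, 1 - lam) so that the mixed
    example stays closer to the first pair. Inputs and labels are flat lists.
    """
    lam = rng.betavariate(alpha, alpha)
    lam = max(lam, 1.0 - lam)
    x = [lam * a + (1.0 - lam) * b for a, b in zip(x1, x2)]
    y = [lam * a + (1.0 - lam) * b for a, b in zip(y1, y2)]
    return x, y, lam


# Usage with a seeded generator: mix a labeled example with a pseudo-labeled one.
rng = random.Random(0)
x_mix, y_mix, lam = mixmatch_mixup([1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.2, 0.8], rng=rng)
```

Taking the maximum of λ and 1 − λ means the order of the two arguments matters: the mixed sample is always dominated by the first pair, which lets labeled and unlabeled batches be mixed without losing track of which loss each mixed sample belongs to.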
The aforementioned methods have one thing in common: they estimate the pseudo-label from the high-confidence predicted distribution of a single unlabeled sample. One major issue with these methods is their heavy reliance on the quality of pseudo-labeling, which limits their effectiveness when pseudo-labeling is unreliable. This problem stems from two facts: (1) pseudo-labels estimated from a single sample may be inaccurate, especially when the model is poorly calibrated, and erroneous high-confidence predictions on single samples can produce many incorrect pseudo-labels, leading to noisy training; (2) propagating pseudo-labels from unlabeled samples is unstable, since there is a margin between the pseudo-label and the ground-truth label, which is likely to cause errors to accumulate during pseudo-labeling.
To address these challenges, we leverage the prototypes (i.e., centers of representations) of labeled samples to generate more reliable pseudo-labels. To this end, we introduce a novel loss called the Prototype Consistency (PC) Loss, which stabilizes label propagation and improves the accuracy of pseudo-labels. By incorporating the PC Loss into the techniques developed by the MixMatch family [9,16,19], we propose the ProMatch algorithm, an effective training framework for SSL. The framework of ProMatch is depicted in Figure 1. First, we construct two sets of prototypes, one in the semantic space and one in the prediction space, from weakly augmented labeled data. Next, for a given unlabeled sample, we generate the pseudo-label by selecting the most similar prediction-prototype according to the PC criterion, which requires the categories of the most similar semantic-prototype and the most similar prediction-prototype to agree with the predicted category of the most similar prediction-prototype. Finally, we formulate the PC Loss as the cross-entropy between this pseudo-label and the prediction on a strongly augmented view of the unlabeled sample, and we integrate it with the supervised and unsupervised losses of the MixMatch family methods to establish the ProMatch training framework. By exploiting the stability and consistency of prototypes, this approach improves both the reliability of the pseudo-labeling process and the accuracy of the pseudo-labels. It is worth noting that, compared to state-of-the-art SSL methods, the principal characteristic of our proposal is that we use the prototypes of labeled samples to exploit and propagate the pseudo-labels.
Our contributions can be summarized as follows:
  • We introduce a novel loss component called PC Loss, which addresses the limitations of traditional pseudo-labeling methods. By deriving the pseudo-label from the prediction-prototype of labeled data, the PC Loss ensures more precise and stable label propagation in semi-supervised learning.
  • By integrating the PC Loss with the techniques employed in the MixMatch family methods, we establish the ProMatch training framework. This framework combines the benefits of the PC Loss and the existing approaches, improving performance in semi-supervised learning tasks.
  • Extensive experimental results demonstrate that ProMatch achieves significant performance gains over previous algorithms on popular benchmark datasets such as CIFAR-10, CIFAR-100, SVHN, and Mini-ImageNet. These results validate the effectiveness and superiority of our proposed approach in the field of SSL.

2. Related Work

2.1. Consistency Regularization

Consistency regularization is a commonly used technique in machine learning for improving the generalization ability and stability of models. It often employs input perturbation techniques [10,22]. For example, in image classification, it is common to elastically deform an input image or add noise to it, which can dramatically change the pixel content of the image without altering its label. In effect, this artificially expands the training set by generating a near-infinite stream of new, modified data. Many methods based on consistency regularization have been proposed. For instance, the method in [22] increases data variability by incorporating stochastic transformations and perturbations into deep semi-supervised learning, while minimizing the discrepancy between the predictions obtained for different perturbed passes of the same sample. Temporal Ensembling [23] enforces consistency by minimizing the mean squared difference between the predicted probability distributions of two data-augmented views. Mean Teacher [10] further extends this concept by replacing the aggregated predictions with the output of an exponential moving average (EMA) model. In VAT [24], consistency regularization is implemented through a virtual adversarial loss: it finds a small input perturbation that maximally changes the model's output and then trains the model to produce consistent predictions for the perturbed samples. In summary, in semi-supervised learning, a classifier should output the same class distribution for an unlabeled example whether or not it has been augmented. For unlabeled points x, in the simplest case, this is achieved by adding the following regularization term to the loss function:
$$\left\lVert p_{\text{model}}\big(y \mid \text{Augment}(x); \theta\big) - p_{\text{model}}\big(y \mid \text{Augment}(x); \theta\big) \right\rVert_2^2 \qquad (1)$$
Note that Augment(x) is a stochastic transformation, so the two terms in Equation (1) are not identical. VAT [24] computes an additive perturbation to the input that maximally changes the output class distribution. MixMatch [9] applies a form of consistency regularization through standard image data augmentation (random horizontal flips and crops). FixMatch [19] distinguishes between two degrees of data augmentation, weak and strong: weak augmentation uses standard data augmentation, whereas strong augmentation may additionally include stronger random clipping, rotation, scaling, affine transformations, etc.
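The regularization term in Equation (1) can be sketched in a few lines. This is a toy illustration, not the paper's code: the Gaussian-noise `augment`, the linear-softmax `p_model`, and the name `consistency_term` are hypothetical stand-ins for a real augmentation pipeline and network.

```python
import math
import random


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    s = sum(exps)
    return [e / s for e in exps]


def augment(x, noise, rng):
    """Hypothetical stochastic augmentation: add Gaussian noise to each feature."""
    return [xi + rng.gauss(0.0, noise) for xi in x]


def p_model(x, weights):
    """Toy linear classifier standing in for p_model(y | x; theta)."""
    return softmax([sum(w * xi for w, xi in zip(row, x)) for row in weights])


def consistency_term(x, weights, noise=0.1, rng=random):
    """Squared L2 distance between predictions on two independent augmentations (Eq. 1)."""
    p1 = p_model(augment(x, noise, rng), weights)
    p2 = p_model(augment(x, noise, rng), weights)
    return sum((a - b) ** 2 for a, b in zip(p1, p2))


W = [[1.0, -0.5], [-0.3, 0.8], [0.2, 0.1]]  # 3 classes, 2 input features
x = [0.4, 1.2]
reg = consistency_term(x, W, rng=random.Random(0))
```

Because the two calls to `augment` draw independent noise, the term is generally nonzero and penalizes predictions that change under perturbation; with `noise=0.0` both views coincide and the term vanishes.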

2.2. Pseudo-Labeling

Pseudo-labels are artificial labels generated by the model itself; they help the model learn more robust and generalizable representations. However, pseudo-labels should be used with caution, as the predictions may contain errors or uncertainty and can therefore introduce noise. Among pseudo-labeling-based approaches, the method in [11] performs entropy minimization implicitly by constructing hard (one-hot) labels from high-confidence predictions on unlabeled samples. TSSDL [25] introduces confidence scores, determined by the density of the local neighborhood around each unlabeled sample, to measure the reliability of pseudo-labels. The method in [26] trains a teacher model on labeled data to generate pseudo-labels for unlabeled data, and then trains a noisy student model on the unlabeled data with these pseudo-labels. R2-D2 [27] updates the pseudo-labels through an optimization framework, generating them with a decipher model during repeated prediction on unlabeled data. In summary, appropriate measures can ensure the reliability of pseudo-labeled data, which encourages us to focus on samples with high confidence (low entropy) that lie away from the decision boundaries.
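The link between high confidence, low entropy, and hard pseudo-labels can be made concrete with a small sketch. The function names are hypothetical, and the confidence threshold `tau = 0.95` is an assumption borrowed from the setting in Section 4.3; the original Pseudo-Label method [11] is not claimed to use this exact filter.

```python
import math


def entropy(p):
    """Shannon entropy (nats); low entropy corresponds to a confident prediction."""
    return -sum(pi * math.log(pi) for pi in p if pi > 0)


def select_pseudo_labels(probs, tau=0.95):
    """Return (sample index, hard label) for predictions whose confidence exceeds tau."""
    selected = []
    for idx, p in enumerate(probs):
        conf = max(p)
        if conf > tau:
            selected.append((idx, p.index(conf)))  # argmax gives the one-hot class index
    return selected


probs = [[0.97, 0.02, 0.01], [0.50, 0.30, 0.20], [0.01, 0.01, 0.98]]
pseudo = select_pseudo_labels(probs)  # the uncertain middle sample is skipped
```

Training on the selected one-hot labels with cross-entropy drives the model's predictions on those samples toward zero entropy, which is the implicit entropy minimization mentioned above.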

2.3. The Combination of Consistency Regularization and Pseudo-Labeling

Some methods [28,29,30,31] integrate both approaches in a unified framework, which is often called the holistic approach. As one of the pioneering works, FixMatch [19] first generates a pseudo-label from the model's prediction on the weakly augmented instance and then encourages the prediction on the strongly augmented instance to follow the pseudo-label. Its success inspired many variants, e.g., methods based on curriculum learning [30,31]. FlexMatch [30] dynamically adjusts the pre-defined threshold in a class-specific manner according to the estimated learning status of each class, which is measured by the number of confident unlabeled samples predicted as that class. Dash [31] dynamically selects, at each optimization step, the unlabeled data whose loss value does not exceed a dynamic threshold. These methods achieve high accuracy, comparable to that of supervised learning in a fully labeled setting.
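The class-specific thresholding idea behind FlexMatch can be sketched as below. This is a simplified reading of the FlexMatch paper, not its reference implementation: the per-class count σ(c) of confident predictions is normalized by its maximum (with a warm-up term that counts still-unused samples), and the result is passed through the convex mapping M(x) = x / (2 − x) to scale the fixed threshold τ. The function name and arguments are hypothetical.

```python
def flexmatch_thresholds(pred_classes, confidences, num_classes, tau=0.95):
    """Simplified class-wise flexible thresholds in the spirit of FlexMatch.

    pred_classes[n] / confidences[n]: latest predicted class and its confidence
    for unlabeled sample n (N = len(pred_classes) samples in total).
    """
    n_total = len(pred_classes)
    # sigma[c]: number of unlabeled samples confidently predicted as class c
    sigma = [0] * num_classes
    for c, conf in zip(pred_classes, confidences):
        if conf > tau:
            sigma[c] += 1
    # Warm-up: samples not yet used enter the normalizer as an extra count;
    # the trailing 1 avoids division by zero when no samples are tracked.
    unused = n_total - sum(sigma)
    denom = max(max(sigma), unused, 1)
    beta = [s / denom for s in sigma]  # normalized learning effect per class
    # Convex mapping M(x) = x / (2 - x) scales the fixed threshold tau.
    return [b / (2.0 - b) * tau for b in beta]


# Class 0 is learned faster than class 1, so class 1 gets a lower threshold.
th = flexmatch_thresholds([0, 0, 0, 1], [0.99, 0.99, 0.99, 0.99], num_classes=2)
```

Classes that the model has not yet learned well receive lower thresholds, so more of their unlabeled samples contribute to training; this is the curriculum effect described above.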

3. Methods

In this section, we will give a detailed explanation of our newly proposed ProMatch framework. In Section 3.1, we will give a brief definition of the semi-supervised learning (SSL) task. Then, we will introduce the generation of prototypes in Section 3.2. Next, we present the Prototype Consistency Loss (PC Loss) in Section 3.3. Finally, we combine the PC Loss into the MixMatch family methods and introduce the total objective in Section 3.4.

3.1. Preliminaries

We consider the $K$-class semi-supervised classification setting, with labeled data $\mathcal{X} = \{(x_b, y_b) : b \in (1, \dots, B_x)\}$ and unlabeled data $\mathcal{U} = \{u_b : b \in (1, \dots, B_u)\}$ used to train the model $f$. The model is defined as $f(\cdot) = f^p_\phi\big(f^s_\theta(\cdot)\big)$, i.e., a semantic encoder $f^s_\theta$ followed by a prediction classifier $f^p_\phi$, where $\theta$ and $\phi$ are the parameter sets of $f^s_\theta$ and $f^p_\phi$, respectively. The weak augmentation (i.e., random crop and flip) and the strong augmentation (i.e., RandAugment [32]) are denoted by $\omega(\cdot)$ and $\Omega(\cdot)$, respectively.

3.2. Prototype Generation

We create a set of semantic-prototypes $\{c^s_k\}_{k=1}^K$ and a set of prediction-prototypes $\{c^p_k\}_{k=1}^K$ from $\mathcal{X}$; the process is illustrated in Figure 2. Specifically, we construct two buffers of memory queues, $\{Q^s_k\}_{k=1}^K$ and $\{Q^p_k\}_{k=1}^K$, where each key corresponds to a class. $Q^s_k$ and $Q^p_k$ are the memory queues for class $k$, with a fixed size $|Q^p_k| = |Q^s_k|$, and they store the feature points output by $f^s_\theta$ and $f^p_\phi$, respectively. Each $c^s_k$ and $c^p_k$ is calculated by averaging the feature points in $Q^s_k$ and $Q^p_k$, respectively. Meanwhile, at every step we update $Q^s_k$ and $Q^p_k$ for all $k$ by pushing in new features from the labeled data in the batch and discarding the oldest ones when a queue is full. Each prototype represents the cluster center of one category of the labeled data, which makes label propagation more stable than relying on a single sample.
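The per-class memory queues and their mean prototypes can be sketched with `collections.deque`, whose `maxlen` argument gives the "discard the oldest when full" behavior directly. The class `PrototypeBank` and its methods are hypothetical names; in a real implementation, features would be tensors produced by $f^s_\theta$ and $f^p_\phi$ on weakly augmented labeled data, whereas here they are plain lists.

```python
from collections import deque


class PrototypeBank:
    """Per-class FIFO memory queues; each class prototype is the mean of its queue."""

    def __init__(self, num_classes, queue_size):
        # deque(maxlen=...) silently discards the oldest entry once it is full
        self.queues = [deque(maxlen=queue_size) for _ in range(num_classes)]

    def push(self, features, labels):
        """Enqueue the features of labeled samples under their ground-truth class."""
        for feat, y in zip(features, labels):
            self.queues[y].append(list(feat))

    def prototypes(self):
        """Return one mean vector per class (None for a class with an empty queue)."""
        protos = []
        for q in self.queues:
            if not q:
                protos.append(None)
                continue
            dim = len(q[0])
            protos.append([sum(f[d] for f in q) / len(q) for d in range(dim)])
        return protos


# One bank for semantic features (and, analogously, one for prediction features).
semantic_bank = PrototypeBank(num_classes=2, queue_size=2)
semantic_bank.push([[1.0, 0.0], [3.0, 0.0], [5.0, 0.0], [0.0, 2.0]], [0, 0, 0, 1])
```

With a queue size of 2, the first class-0 feature `[1.0, 0.0]` is evicted, so the class-0 prototype is the mean of the two most recent features, `[4.0, 0.0]`.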

3.3. Prototype Consistency Loss

In this section, we propose the PC Loss to help select the most reliable pseudo-labels. First, for a given unlabeled sample $u_b$, we obtain the semantic feature $z^s_b = f^s_\theta(\omega(u_b))$ and the prediction feature $z^p_b = f^p_\phi(z^s_b)$ from its weakly augmented version.
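Second, we construct two similarity vectors $\{s^s_i\}_{i=1}^K$ and $\{s^p_i\}_{i=1}^K$. Each element $s^s_i$ (respectively $s^p_i$) represents the similarity between the semantic (respectively prediction) feature and the semantic (respectively prediction) prototype of category $i$, computed as
$$s^s_i = \frac{z^s_b \cdot c^s_i}{\lVert z^s_b \rVert \, \lVert c^s_i \rVert} \qquad (2)$$
$$s^p_i = \sqrt{z^p_b} \cdot \sqrt{c^p_i} \qquad (3)$$
where the square roots in Equation (3) are taken element-wise.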
In Equation (2), $z^s_b$ is the semantic feature of the weakly augmented unlabeled sample, $c^s_i$ is the semantic-prototype of the $i$-th category, and $s^s_i$, the $i$-th element of the similarity vector $\{s^s_i\}_{i=1}^K$, is the similarity between $z^s_b$ and $c^s_i$. In Equation (3), $z^p_b$ is the prediction feature of the weakly augmented unlabeled sample, $c^p_i$ is the prediction-prototype of the $i$-th category, and $s^p_i$, the $i$-th element of the similarity vector $\{s^p_i\}_{i=1}^K$, is the similarity between $z^p_b$ and $c^p_i$. Note that we employ the cosine measure and the Bhattacharyya coefficient [33] to calculate the semantic similarity $s^s_i$ and the prediction similarity $s^p_i$, respectively.
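The two similarity measures can be sketched as follows: cosine similarity for the semantic space and the Bhattacharyya coefficient, $\sum_k \sqrt{p_k q_k}$, for the prediction space. The helper names are hypothetical, the inputs are plain lists standing in for network outputs, and the sketch assumes nonzero semantic vectors (a zero vector would make the cosine undefined).

```python
import math


def cosine(a, b):
    """Cosine similarity, used for the semantic similarity s_i^s (Eq. 2)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def bhattacharyya(p, q):
    """Bhattacharyya coefficient between two discrete distributions (s_i^p, Eq. 3)."""
    return sum(math.sqrt(x * y) for x, y in zip(p, q))


def similarity_vectors(z_s, z_p, sem_protos, pred_protos):
    """Return ({s_i^s}, {s_i^p}) for one weakly augmented unlabeled sample."""
    s_sem = [cosine(z_s, c) for c in sem_protos]
    s_pred = [bhattacharyya(z_p, c) for c in pred_protos]
    return s_sem, s_pred


s_sem, s_pred = similarity_vectors(
    z_s=[1.0, 0.1], z_p=[0.9, 0.1],
    sem_protos=[[2.0, 0.0], [0.0, 1.0]],
    pred_protos=[[0.8, 0.2], [0.1, 0.9]],
)
```

The Bhattacharyya coefficient equals 1 for identical distributions and 0 for distributions with disjoint support, which makes it a natural similarity for probability vectors, while the cosine measure is insensitive to the scale of the embeddings.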
Third, based on the similarity vectors, we select the most similar semantic-prototype $\hat{c}^s_i$ and the most similar prediction-prototype $\hat{c}^p_j$, i.e., the semantic-prototype and prediction-prototype closest to the semantic and prediction features, respectively. Here $i = \arg\max_k s^s_k$ and $j = \arg\max_k s^p_k$ are the categories of the most similar semantic-prototype and the most similar prediction-prototype, computed over the similarity vectors $\{s^s_k\}_{k=1}^K$ and $\{s^p_k\}_{k=1}^K$ in the semantic and prediction spaces, respectively.
Finally, we formulate the PC loss as
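$$\mathcal{L}_{\text{PC}} = \frac{1}{B_u} \sum_{b=1}^{B_u} \mathbb{1}\left[\, i = j = \arg\max\left(\hat{c}^p_j\right) \right] H\left(\hat{c}^p_j, q^s_b\right) \qquad (4)$$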
where $H(\cdot)$ is the cross-entropy, $\mathbb{1}[\cdot]$ is a filter (indicator) function, $\hat{c}^p_j$ is the most similar prediction-prototype, $\arg\max(\hat{c}^p_j)$ returns the category with the maximum predicted probability, and $q^s_b$ is the prediction for the strongly augmented unlabeled sample, i.e., $q^s_b = f(\Omega(u_b))$. Equation (4) is the core of our approach: it states the prototype consistency criterion and shows how the loss function is constructed from it. The most similar prediction-prototype $\hat{c}^p_j$ is used as a pseudo-label only if the category of the most similar semantic-prototype ($i$), the category of the most similar prediction-prototype ($j$), and the predicted category of the most similar prediction-prototype ($\arg\max(\hat{c}^p_j)$) are all equal. This condition imposes two constraints: (1) $i = j$ requires the features of the unlabeled sample to be consistent with the prototypes of the labeled samples in both the semantic and the prediction space, which ensures the reliability of the pseudo-labeling process; (2) $j = \arg\max(\hat{c}^p_j)$ requires the ground-truth category and the predicted category of the prediction-prototype to agree, which safeguards the accuracy of the pseudo-label. The details are shown in Figure 3.
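The selection rule and the loss in Equation (4) can be sketched for a batch as follows. This is a minimal pure-Python illustration, not the authors' implementation: similarity vectors and prototypes are plain lists, `pc_loss` and its helpers are hypothetical names, and a small `eps` is added inside the logarithm for numerical safety.

```python
import math


def argmax(v):
    return max(range(len(v)), key=lambda k: v[k])


def cross_entropy(target, pred, eps=1e-12):
    """H(target, pred) = -sum_k target_k * log(pred_k); eps guards against log(0)."""
    return -sum(t * math.log(p + eps) for t, p in zip(target, pred))


def pc_loss(sim_sem, sim_pred, pred_protos, q_strong):
    """Prototype Consistency loss averaged over a batch of B_u unlabeled samples.

    sim_sem[b], sim_pred[b]: similarity vectors {s^s_i}, {s^p_i} of sample b.
    pred_protos[k]: prediction-prototype c^p_k (a class distribution).
    q_strong[b]: prediction on the strongly augmented view, f(Omega(u_b)).
    """
    total = 0.0
    for s_s, s_p, q in zip(sim_sem, sim_pred, q_strong):
        i = argmax(s_s)          # category of the most similar semantic-prototype
        j = argmax(s_p)          # category of the most similar prediction-prototype
        c_hat = pred_protos[j]   # candidate pseudo-label
        if i == j == argmax(c_hat):  # the PC criterion (filter in Eq. 4)
            total += cross_entropy(c_hat, q)
    return total / len(q_strong)  # filtered samples still count in the 1/B_u average


protos = [[0.9, 0.1], [0.2, 0.8]]
loss = pc_loss(
    sim_sem=[[0.9, 0.1], [0.9, 0.1]],
    sim_pred=[[0.95, 0.3], [0.2, 0.9]],  # sample 2: i = 0 but j = 1, so it is filtered
    pred_protos=protos,
    q_strong=[[0.7, 0.3], [0.4, 0.6]],
)
```

Only the first sample passes the criterion, so the batch loss is its cross-entropy divided by the batch size of 2.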

3.4. Total Objective

The proposed PC loss is generic and can easily be combined with other SSL algorithms. In this paper, we combine the PC loss with the MixMatch family methods [9,16,19] to obtain the ProMatch framework, whose final objective is:
$$\mathcal{L}_{\text{total}} = \mathcal{L}_x + \lambda_u \mathcal{L}_u + \lambda_{\text{PC}} \mathcal{L}_{\text{PC}} \qquad (5)$$
In Equation (5), $\mathcal{L}_x$ is the supervised loss, computed as $\mathcal{L}_x = \frac{1}{B_x} \sum_{b=1}^{B_x} H\big(f(\omega(x_b)), y_b\big)$. $\mathcal{L}_u$ is the unsupervised loss, formulated as $\mathcal{L}_u = \frac{1}{B_u} \sum_{b=1}^{B_u} \mathbb{1}\big(\max(z^p_b) > \tau\big) H\big(z^p_b, q^s_b\big)$, where $\mathbb{1}(\cdot > \tau)$ is the filter function for confidence-based thresholding, $\tau$ is the threshold, $z^p_b = f(\omega(u_b))$, and $q^s_b = f(\Omega(u_b))$. $\mathcal{L}_{\text{PC}}$ is the PC Loss, and $\lambda_u$ and $\lambda_{\text{PC}}$ are hyperparameters that weight the two unsupervised losses.
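The thresholded unsupervised loss and the weighted sum in Equation (5) can be sketched as follows. The function names are hypothetical, the inputs are plain lists of class distributions, and the default weights $\lambda_u = \lambda_{\text{PC}} = 1$ and $\tau = 0.95$ follow the values reported in Section 4.3.

```python
import math


def cross_entropy(target, pred, eps=1e-12):
    return -sum(t * math.log(p + eps) for t, p in zip(target, pred))


def unsupervised_loss(z_weak, q_strong, tau=0.95):
    """L_u: confidence-thresholded cross-entropy between weak and strong predictions."""
    total = 0.0
    for z, q in zip(z_weak, q_strong):
        if max(z) > tau:  # filter 1(max(z_b^p) > tau)
            total += cross_entropy(z, q)
    return total / len(z_weak)


def total_loss(l_x, l_u, l_pc, lambda_u=1.0, lambda_pc=1.0):
    """Eq. (5): L_total = L_x + lambda_u * L_u + lambda_PC * L_PC."""
    return l_x + lambda_u * l_u + lambda_pc * l_pc


l_u = unsupervised_loss([[0.97, 0.03], [0.6, 0.4]], [[0.8, 0.2], [0.5, 0.5]])
```

Note that the two unsupervised terms are complementary: $\mathcal{L}_u$ relies on the sample's own confident prediction, whereas $\mathcal{L}_{\text{PC}}$ relies on a labeled-data prototype, so a sample may contribute to either, both, or neither term.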
The concrete algorithm of our method is shown in Algorithm 1.

4. Experiments

This section validates the effectiveness of ProMatch for semi-supervised learning on four popular benchmarks: CIFAR-10 [34], CIFAR-100 [34], SVHN [35], and Mini-ImageNet. First, we introduce the datasets in Section 4.1. Second, the baselines related to our task are introduced in Section 4.2. Third, we describe the implementation details of all experiments in Section 4.3. Fourth, we report experiments on the standard datasets with different amounts of labeled data in Section 4.4. Finally, we perform ablation experiments on each component in Section 4.5. All experiments are implemented with the PyTorch deep learning framework and run on a Linux server with an Intel(R) Core(TM) i7-10700F CPU @ 2.90 GHz and a GeForce RTX 3090 GPU, using CUDA 11.2 and cuDNN 8.

4.1. Datasets

CIFAR-10 [34]: CIFAR-10 is a widely used benchmark dataset in computer vision and machine learning. It consists of 60,000 color images of 32 × 32 pixels, covering 10 categories with 6000 images each. CIFAR-10 is split into a training set of 50,000 images and a test set of 10,000 images. We randomly select 1000 or 4000 samples from the training set as labeled data, and the rest of the training set is used as unlabeled data.
CIFAR-100 [34]: CIFAR-100 is an extension of CIFAR-10. It consists of 60,000 color images of 32 × 32 pixels covering 100 diverse categories, with 600 images per class. Like CIFAR-10, CIFAR-100 is split into a training set and a test set: the training set contains 50,000 images (500 per category), and the remaining 10,000 images (100 per category) form the test set. We randomly select 10,000 samples from the training set as labeled data, and the rest of the training set is used as unlabeled data.
SVHN [35]: SVHN consists of images of house numbers captured from Google Street View and contains over 600,000 color images. Each image is a cropped 32 × 32 patch containing a single digit from 0 to 9, so SVHN has 10 classes. The training set contains 73,257 images and the test set contains 26,032 images. We randomly select 1000 or 4000 samples from the training set as labeled data.
Mini-ImageNet: Mini-ImageNet is a subset of the ImageNet Large Scale Visual Recognition Challenge (ILSVRC) dataset and was first used for few-shot learning. It consists of 100 randomly selected categories from ImageNet-1K, each containing 600 labeled images of size 84 × 84. For the SSL evaluation, we select 500 images from each class to form the training set, and the remaining 100 images per class form the test set.

4.2. Baseline

We compare ProMatch with twelve typical baseline methods.
  • VAT [24]: maintains the consistency of the model’s output under adversarial perturbations.
  • MeanTeacher [10]: enforces consistency between the predictions of the model and those of a teacher model whose parameters are a moving average of the model's parameters.
  • MixMatch [9]: guesses low-entropy labels for data-augmented unlabeled examples and mixes labeled and unlabeled data using Mix-up.
  • PLCB [36]: proposes to learn from unlabeled data by generating soft pseudo-labels using the network predictions.
  • ReMixMatch [16]: introduces the distribution alignment and augmentation anchoring to upgrade MixMatch.
  • FixMatch [19]: simplifies its previous works by introducing a confidence threshold into its unsupervised objective function. For the same unlabeled sample, FixMatch encourages consistency between weak and strong augmented images.
  • SimPLE [20]: introduces a similarity threshold and focuses on the similarity among unlabeled samples.
  • FlexMatch [30]: proposes Curriculum Pseudo Labeling, a curriculum learning approach to utilize unlabeled samples according to the status of model learning.
  • DoubleMatch [37]: combines the pseudo-labeling technique with a self-supervised loss, enabling the model to utilize all unlabeled data in the training process.
  • NP-Match [38]: adapts neural processes (NPs) to semi-supervised learning and proposes an uncertainty-guided skew-geometric JS divergence to replace the original KL divergence in NPs.
  • Bad GAN [39]: a generative-model-based SSL method built upon the assumption that good semi-supervised learning requires a bad generator.
  • Triple-GAN [40]: also a generative-model-based SSL method, which is formulated as a three-player minimax game consisting of a generator, a classifier, and a discriminator.

4.3. Implementation Details

For most experiments, ProMatch adopts Wide ResNet [41] as the backbone (WRN-28-2 for CIFAR-10 and SVHN, WRN-28-8 for CIFAR-100), following [19]. For Mini-ImageNet, we use ResNet-18 as the backbone. We train the model using either SGD [42] with a momentum of 0.9 or AdamW [43]. The initial learning rate is 0.03, with a cosine learning-rate decay schedule $\eta = \eta_0 \cos\left(\frac{7\pi s}{16 S}\right)$, where $\eta_0$ is the initial learning rate, $s$ is the current training step, and $S$ is the total number of training steps. For CIFAR-100 on WRN-28-2 and Mini-ImageNet on ResNet-18, we use AdamW without learning-rate scheduling. In addition, we use the exponential moving average (EMA) of the network parameters for evaluation and label guessing. Note that we use an identical set of hyperparameters for all datasets: $\lambda_u = 1$, $\lambda_{\text{PC}} = 1$, $\tau = 0.95$, and $\alpha = 0.75$. More specific implementation details are shown in Table 1.
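The cosine decay schedule $\eta = \eta_0 \cos(7\pi s / (16 S))$ can be evaluated directly, which makes its endpoint explicit. The function name `cosine_lr` is hypothetical, and the total step count below is an illustrative value, not one reported in this section.

```python
import math


def cosine_lr(step, total_steps, lr0=0.03):
    """eta = eta_0 * cos(7 * pi * s / (16 * S)); decays smoothly but never reaches 0."""
    return lr0 * math.cos(7 * math.pi * step / (16 * total_steps))


S = 1_048_576  # hypothetical total number of training steps
schedule = [cosine_lr(s, S) for s in (0, S // 2, S)]
```

Because the argument only reaches $7\pi/16$, the final learning rate is $\eta_0 \cos(7\pi/16) \approx 0.195\,\eta_0$, i.e., about 0.00585 for $\eta_0 = 0.03$, rather than decaying to zero.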

4.4. Results

For all datasets, we split the data into labeled and unlabeled sets by randomly sampling, without replacement, an equal number of images from each class.
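A class-balanced split of this kind can be sketched as follows; for example, 4000 labels on CIFAR-10 corresponds to 400 labeled images per class. The function `balanced_split` and its seed handling are hypothetical, and the sketch assumes, as in Section 4.1, that all remaining training samples form the unlabeled set.

```python
import random
from collections import defaultdict


def balanced_split(labels, num_labeled, seed=0):
    """Sample an equal number of labeled indices per class without replacement."""
    by_class = defaultdict(list)
    for idx, y in enumerate(labels):
        by_class[y].append(idx)
    classes = sorted(by_class)
    if num_labeled % len(classes):
        raise ValueError("num_labeled must be divisible by the number of classes")
    per_class = num_labeled // len(classes)
    rng = random.Random(seed)
    labeled = []
    for c in classes:
        labeled.extend(rng.sample(by_class[c], per_class))  # no replacement within a class
    labeled_set = set(labeled)
    unlabeled = [i for i in range(len(labels)) if i not in labeled_set]
    return sorted(labeled), unlabeled


labels = [i % 10 for i in range(1000)]  # toy 10-class label list
lab, unl = balanced_split(labels, 40)
```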
CIFAR-100: We evaluate the performance of ProMatch on CIFAR-100 with 10,000 labels under two different settings. First, we use Wide ResNet 28-8 [41] as the backbone and set the number of weak augmentations K to 4. We use the SGD [42] optimizer with a weight decay of 0.001. As shown in Table 2, compared to FixMatch [19], ProMatch improves the accuracy rate from 77.40% to 78.85%. Next, we evaluate ProMatch on CIFAR-100 with the Wide ResNet 28-2 [41] network, changing the optimizer from SGD [42] to AdamW [43]. The hyperparameters that differ in this experiment are K = 2, lr = 0.02, and EMA decay = 0.04. The results show that ProMatch outperforms the baseline methods by a large margin: compared to SimPLE [20], our method improves the accuracy rate from 70.82% to 72.71%. Moreover, ProMatch reaches nearly its best accuracy rate at about 500 epochs, which suggests that our approach is more stable and converges faster.
CIFAR-10: We use the Wide ResNet 28-2 [41] network to evaluate the accuracy of ProMatch on CIFAR-10 with 1000 and 4000 labels, using SGD [42] with a momentum of 0.9 as the optimizer. As shown in Table 3, the accuracy rate of ProMatch reaches 95.01% and 95.83% with 1000 and 4000 labeled samples, respectively. In contrast, Bad GAN achieves 79.37% and 85.59%, and Triple-GAN-V2 achieves 85.00% and 89.99%, with 1000 and 4000 labeled samples, respectively. This indicates that our method significantly outperforms the classical generative-model-based SSL methods. Moreover, ProMatch outperforms recent state-of-the-art methods such as SimPLE, FlexMatch, and DoubleMatch. Note that the accuracy gain of ProMatch on CIFAR-10 is relatively small compared to that on CIFAR-100. This is because CIFAR-10 is a comparatively simple dataset, on which mainstream semi-supervised algorithms such as FixMatch and SimPLE already achieve desirable results: their accuracy rates with partially labeled samples (i.e., semi-supervised learning) are very close to those obtained with all labeled samples (i.e., supervised learning). The room for improvement of semi-supervised methods on CIFAR-10 is therefore limited, since the accuracy rate of the fully supervised benchmark can be regarded as an upper bound on the performance of semi-supervised learning. Nevertheless, the accuracy rate of our proposal reaches 95.83% with 4000 labeled samples of CIFAR-10, which is not only the highest among the compared methods but also very close to this upper bound.
SVHN: We use the Wide ResNet 28-2 [41] network with the same hyperparameters as for CIFAR-10 to evaluate ProMatch on SVHN with 1000 and 4000 labels. As shown in Table 3, ProMatch achieves Top-1 accuracy rates of 97.79% and 97.88% with 1000 and 4000 labeled samples, respectively, a clear improvement over Bad GAN and Triple-GAN-V2. Although the baseline methods already perform strongly on SVHN, ProMatch still achieves a higher accuracy rate than they do.
Mini-ImageNet: We use the ResNet-18 network to assess the performance of ProMatch on the more complex Mini-ImageNet dataset with 4000 labeled samples, setting both the learning rate and the weight decay to 0.02. Moreover, to elucidate the effect of erroneous high-confidence pseudo-labels, we record the error rates of pseudo-labels whose confidence exceeds the threshold τ. The results are shown in Table 4. The error rates of pseudo-labels are inversely correlated with the test accuracy rates of the semi-supervised models. Among the compared methods, ours has the lowest pseudo-label error rate (13.21%) and the highest test accuracy rate (50.83%), indicating that ProMatch can effectively increase the accuracy of pseudo-labels and boost the performance of semi-supervised learning. Note that, compared to SimPLE, the accuracy rate of our method increases by 1.44%, which demonstrates the effectiveness of our proposal on complex datasets.
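One way to compute the pseudo-label error rate described above is sketched below, assuming it is the fraction of wrong argmax labels among predictions whose confidence exceeds τ; the exact evaluation protocol behind Table 4 is not specified here. Ground-truth labels of the unlabeled data are hidden during training but can be used for this kind of post hoc analysis. The function name is hypothetical.

```python
def pseudo_label_error_rate(probs, true_labels, tau=0.95):
    """Fraction of confident pseudo-labels (max prob > tau) whose argmax is wrong.

    Returns float('nan') when no prediction passes the threshold.
    """
    kept = wrong = 0
    for p, y in zip(probs, true_labels):
        conf = max(p)
        if conf > tau:
            kept += 1
            wrong += int(p.index(conf) != y)
    return wrong / kept if kept else float("nan")


rate = pseudo_label_error_rate([[0.97, 0.03], [0.02, 0.98], [0.6, 0.4]], [0, 0, 1])
```

Here two predictions pass the threshold and one of them is wrong, so the error rate is 0.5; the third prediction is ignored because its confidence is below τ.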

4.5. Ablation Study

We present extensive ablation studies to verify the effect of each component. These studies are conducted on CIFAR-100 with 10,000 labels, using Wide ResNet 28-2 as the backbone, for two reasons: (1) CIFAR-100 covers a relatively large number of classes while its images are small; (2) Wide ResNet 28-2 is much faster to train than Wide ResNet 28-8.
The Effect of the Weak Augmentation Number K. We report the accuracy rate under different numbers of weak augmentations K; the results are shown in Table 5. Comparing the first and second rows, increasing K from 2 to 7 improves the accuracy of ProMatch from 72.71% to 73.38%. Comparing the third and fourth rows, the same change improves ProMatch without the PC Loss from 69.07% to 69.94%. These results show that increasing K from 2 to 7 improves the accuracy rate both with and without the PC Loss.
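Algorithm 1 averages the predictions of the K weakly augmented views and then sharpens the average with a temperature T (Table 1 lists T = 0.5). The sketch below assumes the MixMatch-style [9] sharpening p_c^(1/T) / Σ_j p_j^(1/T); the helper names are hypothetical and the exact Sharpen used by ProMatch may differ in detail.

```python
from typing import List, Sequence


def average_predictions(view_probs: Sequence[Sequence[float]]) -> List[float]:
    """Hypothetical helper: mean class distribution over K weakly augmented views."""
    k = len(view_probs)
    num_classes = len(view_probs[0])
    return [sum(p[c] for p in view_probs) / k for c in range(num_classes)]


def sharpen(p: Sequence[float], temperature: float = 0.5) -> List[float]:
    """Assumed MixMatch-style sharpening: raise to 1/T, then renormalize.

    T < 1 pushes probability mass toward the largest entry (lower entropy).
    """
    powered = [pc ** (1.0 / temperature) for pc in p]
    total = sum(powered)
    return [x / total for x in powered]
```

With more views, the averaged distribution is less sensitive to any single augmentation before sharpening, which is one plausible reading of why K = 7 outperforms K = 2 in Table 5.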
The Effect of PC Loss. To demonstrate the effect of the PC Loss, we conducted three comparisons, shown in Table 5. First, comparing the first and third rows, ProMatch without the PC Loss reaches only 69.07%, which is 3.64 percentage points lower than ProMatch (72.71%). Second, comparing the second and fourth rows, where K is increased from 2 to 7, ProMatch without the PC Loss reaches 69.94%, which is 3.44 percentage points lower than ProMatch (73.38%). Finally, comparing the fifth and sixth rows, where RandAugment [32] is replaced with a fixed augmentation, ProMatch reaches 67.66% and ProMatch without the PC Loss reaches 67.41%; in this setting, the improvement brought by the PC Loss is relatively small (0.25 percentage points).
To further demonstrate the effect of the PC Loss, we evaluate the recall rate, Top-1 accuracy rate, and F1 score on CIFAR-100. Because CIFAR-100 has many categories, we randomly select seven categories to present. As shown in Figure 4a,b, ProMatch improves the per-class precision while also achieving a higher recall rate. As shown in Figure 4c, ProMatch also reaches a higher F1 score, which further demonstrates that introducing the PC Loss improves the overall performance of the algorithm. In addition, we visualize the features learned by ProMatch without the PC Loss and by ProMatch using t-SNE. As shown in Figure 5, ProMatch produces better-separated clusters, which is consistent with its higher classification accuracy. Finally, we plot the Top-1 and Top-5 accuracy rates on the test and validation datasets. As shown in Figure 6, ProMatch converges quickly on both datasets and reaches a higher accuracy rate earlier in training than ProMatch without the PC Loss.
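The per-class metrics in Figure 4 can be computed one class at a time, treating that class as positive and all others as negative. The sketch below assumes the standard one-vs-rest definitions of precision, recall, and F1 (F1 being the harmonic mean of precision and recall); the function name is hypothetical.

```python
from typing import Sequence, Tuple


def per_class_prf(
    y_true: Sequence[int], y_pred: Sequence[int], cls: int
) -> Tuple[float, float, float]:
    """Hypothetical helper: one-vs-rest precision, recall, and F1 for class `cls`."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == cls and p == cls)  # correctly predicted as cls
    fp = sum(1 for t, p in pairs if t != cls and p == cls)  # wrongly predicted as cls
    fn = sum(1 for t, p in pairs if t == cls and p != cls)  # cls samples that were missed
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1
```

Because F1 rises only when precision and recall are both reasonably high, an F1 gain on a class indicates that the gain in one metric is not bought by a loss in the other.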
Algorithm 1 ProMatch algorithm.
Input: A batch of labeled data $\mathcal{X} = \{(x_b, y_b)\}_{b=1}^{B_x}$ and unlabeled data $\mathcal{U} = \{u_b\}_{b=1}^{B_u}$; semantic encoder $f_{\theta}^{s}(\cdot)$ and prediction classifier $f_{\phi}^{p}(\cdot)$; memory-queue buffers $\{Q_k^s\}_{k=1}^{K}$ and $\{Q_k^p\}_{k=1}^{K}$; weak augmentation $\omega(\cdot)$ and strong augmentation $\Omega(\cdot)$; cosine similarity function $\sigma(\cdot)$; Bhattacharyya coefficient $bha(\cdot)$; the number of weak augmentations $K$; cross-entropy loss function $H(\cdot)$.
Prototype Generation:
    $f_{\theta}^{s}(x_b) \rightarrow Q_k^s \rightarrow \{c_k^s\}_{k=1}^{K}$ ▹ Generate semantic-prototypes and store them by label $k$.
    $f_{\phi}^{p}(f_{\theta}^{s}(x_b)) \rightarrow Q_k^p \rightarrow \{c_k^p\}_{k=1}^{K}$ ▹ Generate prediction-prototypes and store them by label $k$.
Prototype Consistency:
for $b = 1$ to $B_u$ do
    $\tilde{u}_b = \omega(u_b)$ ▹ Apply weak data augmentation to $u_b$
    $\hat{u}_b = \Omega(u_b)$ ▹ Apply strong data augmentation to $u_b$
    $z_b^s = f_{\theta}^{s}(\tilde{u}_b)$ ▹ Compute the semantic feature of $\tilde{u}_b$ using EMA
    $z_b^p = \frac{1}{K}\sum_{k=1}^{K} f_{\phi}^{p}(z_b^s)$ ▹ Compute the average prediction across the weak views $\tilde{u}_b$ using EMA
    $z_b^p = \mathrm{Sharpen}(z_b^p, T)$ ▹ Apply temperature sharpening to the average prediction
    $q_b^s = f_{\phi}^{p}(f_{\theta}^{s}(\hat{u}_b))$ ▹ Compute predictions on $\hat{u}_b$ using EMA
    $\{s_i^s\}_{i=1}^{K} = \sigma(z_b^s, \{c_k^s\}_{k=1}^{K})$ ▹ Similarity between $z_b^s$ and $\{c_k^s\}_{k=1}^{K}$
    $\{s_i^p\}_{i=1}^{K} = bha(z_b^p, \{c_k^p\}_{k=1}^{K})$ ▹ Similarity between $z_b^p$ and $\{c_k^p\}_{k=1}^{K}$
    $\hat{c}_i^s = c^s_{\arg\max \{s_i^s\}_{i=1}^{K}}$ ▹ Obtain the most similar semantic-prototype
    $\hat{c}_j^p = c^p_{\arg\max \{s_i^p\}_{i=1}^{K}}$ ▹ Obtain the most similar prediction-prototype
end for
Loss:
    $\mathcal{L}_x = \frac{1}{B_x}\sum_{b=1}^{B_x} H(y_b, f(\omega(x_b)))$
    $\mathcal{L}_u = \frac{1}{B_u}\sum_{b=1}^{B_u} \mathbb{1}(\max z_b^p \geq \tau) \cdot H(z_b^p, q_b^s)$
    $\mathcal{L}_{PC} = \frac{1}{B_u}\sum_{b=1}^{B_u} \mathbb{1}(i = j = \arg\max \hat{c}_j^p) \cdot H(\hat{c}_j^p, q_b^s)$
return $\mathcal{L}_x + \lambda_u \mathcal{L}_u + \lambda_{PC} \mathcal{L}_{PC}$ ▹ Total loss
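The Prototype Consistency step and the three losses of Algorithm 1 can be sketched per unlabeled sample in plain Python. This is a minimal sketch, not the authors' implementation; it assumes cosine similarity for σ, the Bhattacharyya coefficient Σ_c √(p_c q_c) [33] for bha, soft-target cross-entropy for H, and a "≥ τ" confidence mask. All function names are hypothetical, and the network outputs (z_b^s, z_b^p, q_b^s) and prototypes are passed in as precomputed lists.

```python
import math
from typing import List, Optional, Sequence, Tuple

# Hypothetical helper names throughout; not the authors' code.


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity sigma(a, b) between two feature vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def bhattacharyya(p: Sequence[float], q: Sequence[float]) -> float:
    """Bhattacharyya coefficient sum_c sqrt(p_c * q_c); 1.0 for identical distributions."""
    return sum(math.sqrt(pc * qc) for pc, qc in zip(p, q))


def argmax(values: Sequence[float]) -> int:
    return max(range(len(values)), key=lambda i: values[i])


def cross_entropy(target: Sequence[float], pred: Sequence[float], eps: float = 1e-12) -> float:
    """Soft-target cross-entropy H(target, pred) = -sum_c target_c * log(pred_c)."""
    return -sum(t * math.log(max(p, eps)) for t, p in zip(target, pred))


def pc_target(
    z_s: Sequence[float],
    z_p: Sequence[float],
    sem_protos: Sequence[Sequence[float]],
    pred_protos: Sequence[Sequence[float]],
) -> Optional[List[float]]:
    """Most similar prediction-prototype if the PC criterion holds, else None.

    sem_protos[k] and pred_protos[k] are c_k^s and c_k^p, so a prototype's
    list index is its ground-truth category.
    """
    i = argmax([cosine(z_s, c) for c in sem_protos])
    j = argmax([bhattacharyya(z_p, c) for c in pred_protos])
    # Both ground-truth categories and the category predicted by c_j^p must agree.
    if i == j == argmax(pred_protos[j]):
        return list(pred_protos[j])
    return None


def unlabeled_losses(
    z_s: Sequence[float],
    z_p: Sequence[float],
    q_strong: Sequence[float],
    sem_protos: Sequence[Sequence[float]],
    pred_protos: Sequence[Sequence[float]],
    tau: float = 0.95,
) -> Tuple[float, float]:
    """Per-sample (L_u, L_PC) terms; masked samples contribute 0."""
    l_u = cross_entropy(z_p, q_strong) if max(z_p) >= tau else 0.0
    target = pc_target(z_s, z_p, sem_protos, pred_protos)
    l_pc = cross_entropy(target, q_strong) if target is not None else 0.0
    return l_u, l_pc


def total_loss(
    l_x: float,
    per_sample: Sequence[Tuple[float, float]],
    lambda_u: float = 1.0,
    lambda_pc: float = 1.0,
) -> float:
    """L_x + lambda_u * L_u + lambda_PC * L_PC, averaging over all B_u samples."""
    b = len(per_sample)
    l_u = sum(u for u, _ in per_sample) / b
    l_pc = sum(pc for _, pc in per_sample) / b
    return l_x + lambda_u * l_u + lambda_pc * l_pc
```

Note that a sample rejected by the PC criterion still counts in the B_u denominator, matching the 1/B_u normalization in Algorithm 1.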
The Effect of Different Augmentation Strategies. To verify the effect of the augmentation strategy, we replace RandAugment [32] with a fixed augmentation; the results are shown in Table 5. First, comparing the first and fifth rows, ProMatch with fixed augmentation reaches only 67.66%, much lower than ProMatch with RandAugment (72.71%). Second, comparing the third and sixth rows, we evaluate the effect of the augmentation strategy on ProMatch without the PC Loss: with RandAugment it reaches 69.07%, whereas with fixed augmentation it reaches only 67.41%. These results suggest that RandAugment [32] improves the model's robustness to diverse samples and alleviates overfitting.
The Effect of Hyperparameters. In this section, we evaluate the effect of several hyperparameters, namely the confidence threshold $\tau$, the PC loss coefficient $\lambda_{PC}$, and the memory buffer size $Q_k^p$. The results are shown in Table 5 and Figure 7.
  • The first row and the 7th–9th rows of Table 5 show the accuracy rates of ProMatch under confidence thresholds $\tau = 0, 0.5, 0.75, 0.95$; these results are also plotted in Figure 7a. The accuracy of our method generally increases with the confidence threshold, and ProMatch achieves its best accuracy rate of 72.71% at $\tau = 0.95$.
  • The first row and the 10th–13th rows of Table 5 show the accuracy rates of ProMatch under PC loss coefficients $\lambda_{PC} = 0.5, 1, 2, 3, 4$; these results are also plotted in Figure 7b. ProMatch achieves its best accuracy rate (72.76%) at $\lambda_{PC} = 4$, only slightly above the 72.71% obtained at $\lambda_{PC} = 1$.
  • The first row and the 14th–17th rows of Table 5 show the accuracy rates of ProMatch under buffer sizes $Q_k^p = 3, 4, 5, 10, 20$; these results are also plotted in Figure 7c. ProMatch achieves its best accuracy rate (72.71%) at $Q_k^p = 4$.
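The buffer size $Q_k^p$ controls how many labeled features per class are kept when prototypes are formed. A minimal sketch, assuming each class keeps a first-in-first-out queue of its most recent labeled features and that a prototype is the mean of its queue (the abstract describes prototypes as centers of embeddings or predictions); the class and method names are hypothetical.

```python
from collections import deque
from typing import Deque, Dict, List, Sequence


class PrototypeBank:
    """Hypothetical per-class FIFO memory; `buffer_size` plays the role of Q_k^p."""

    def __init__(self, num_classes: int, buffer_size: int = 4) -> None:
        self.queues: Dict[int, Deque[List[float]]] = {
            k: deque(maxlen=buffer_size) for k in range(num_classes)
        }

    def push(self, feature: Sequence[float], label: int) -> None:
        # Store by ground-truth label; deque(maxlen=...) evicts the oldest entry when full.
        self.queues[label].append(list(feature))

    def prototype(self, label: int) -> List[float]:
        """Assumed prototype: element-wise mean of the stored features of one class."""
        q = self.queues[label]
        if not q:
            raise ValueError(f"no features stored for class {label}")
        dim = len(q[0])
        return [sum(f[d] for f in q) / len(q) for d in range(dim)]
```

Under this reading, a small buffer tracks the current model closely but averages over few samples, while a large buffer is smoother but holds features from older model states; this trade-off is one possible explanation for the intermediate optimum in Table 5.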

5. Conclusions

This paper proposed a new framework, ProMatch, to improve the performance of SSL. ProMatch enhances the reliability and accuracy of pseudo-labeling through the PC criterion, which requires the ground-truth category of the most similar semantic-prototype, the ground-truth category of the most similar prediction-prototype, and the predicted category of that prediction-prototype to agree. Based on the PC criterion, we formulated a new PC loss that uses, as the pseudo-label, the prediction-prototype most similar to the predicted distribution of the unlabeled sample. Extensive experiments were conducted on the CIFAR-10, CIFAR-100, SVHN, and Mini-ImageNet datasets, and the results demonstrate the effectiveness of our proposal. It should be noted, however, that the performance of ProMatch depends mainly on the reliability of the prototypes computed from the labeled samples. Our proposal may therefore be less suited to tasks such as class-imbalanced learning and fine-grained recognition, where the prototypes of different categories can be hard to distinguish. In future work, we will aim to make our method more robust in such settings.

Author Contributions

Z.C. contributed to the conception of the study, performed the experiments, and wrote the manuscript. X.W. contributed significantly to the validation of ideas and the revision of the manuscript. J.L. contributed to the design of the experiments and the analysis of the data. All authors have read and agreed to the published version of the manuscript.

Funding

This work was supported by the National Natural Science Foundation of China (No. 62072127, No. 62002076, No. 61906049), Natural Science Foundation of Guangdong Province (No. 2023A1515011774, No. 2020A1515010423), Project 6142111180404 supported by CNKLSTISS, and the Scientific research project for Guangzhou University (No. RC2023031).

Data Availability Statement

Not applicable.

Conflicts of Interest

The authors declare that they have no conflicts of interest in this work, and that they do not have any commercial or associative interests that represent a conflict of interest in connection with the work submitted.

References

  1. Deng, J.; Dong, W.; Socher, R.; Li, L.J.; Li, K.; Fei-Fei, L. ImageNet: A large-scale hierarchical image database. In Proceedings of the 2009 IEEE Conference on Computer Vision and Pattern Recognition, IEEE, Miami, FL, USA, 20–25 June 2009; pp. 248–255.
  2. Everingham, M.; Van Gool, L.; Williams, C.K.; Winn, J.; Zisserman, A. The PASCAL visual object classes (VOC) challenge. Int. J. Comput. Vis. 2010, 88, 303–338.
  3. Girshick, R. Fast R-CNN. In Proceedings of the IEEE International Conference on Computer Vision, Santiago, Chile, 7–13 December 2015; pp. 1440–1448.
  4. He, K.; Gkioxari, G.; Dollár, P.; Girshick, R. Mask R-CNN. In Proceedings of the IEEE International Conference on Computer Vision, Venice, Italy, 22–29 October 2017; pp. 2961–2969.
  5. Lin, T.Y.; Maire, M.; Belongie, S.; Hays, J.; Perona, P.; Ramanan, D.; Dollár, P.; Zitnick, C.L. Microsoft COCO: Common objects in context. In Proceedings of the Computer Vision–ECCV 2014: 13th European Conference, Zurich, Switzerland, 6–12 September 2014; Proceedings, Part V. Springer: Berlin/Heidelberg, Germany, 2014; pp. 740–755.
  6. Su, X.; Huang, T.; Li, Y.; You, S.; Wang, F.; Qian, C.; Zhang, C.; Xu, C. Prioritized architecture sampling with Monto-Carlo tree search. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, Nashville, TN, USA, 20–25 June 2021; pp. 10968–10977.
  7. Tang, K.; Ma, Y.; Miao, D.; Song, P.; Gu, Z.; Tian, Z.; Wang, W. Decision fusion networks for image classification. IEEE Trans. Neural Netw. Learn. Syst. 2022, 2022, 1–14.
  8. Zhu, P.; Hong, J.; Li, X.; Tang, K.; Wang, Z. SGMA: A novel adversarial attack approach with improved transferability. Complex Intell. Syst. 2023, 1–13.
  9. Berthelot, D.; Carlini, N.; Goodfellow, I.; Papernot, N.; Oliver, A.; Raffel, C.A. MixMatch: A holistic approach to semi-supervised learning. Adv. Neural Inf. Process. Syst. 2019, 32, 5049–5059.
  10. Tarvainen, A.; Valpola, H. Mean teachers are better role models: Weight-averaged consistency targets improve semi-supervised deep learning results. Adv. Neural Inf. Process. Syst. 2017, 30, 1195–1204.
  11. Lee, D.H. Pseudo-label: The simple and efficient semi-supervised learning method for deep neural networks. In Proceedings of the Workshop on Challenges in Representation Learning, ICML, Atlanta, GA, USA, 20–21 June 2013; Volume 3, p. 896.
  12. Chapelle, O.; Zien, A. Semi-supervised classification by low density separation. In Proceedings of the International Workshop on Artificial Intelligence and Statistics, PMLR, Bridgetown, Barbados, 6–8 January 2005; pp. 57–64.
  13. Verma, V.; Kawaguchi, K.; Lamb, A.; Kannala, J.; Solin, A.; Bengio, Y.; Lopez-Paz, D. Interpolation consistency training for semi-supervised learning. Neural Netw. 2022, 145, 90–106.
  14. Tang, K.; Shi, Y.; Lou, T.; Peng, W.; He, X.; Zhu, P.; Gu, Z.; Tian, Z. Rethinking perturbation directions for imperceptible adversarial attacks on point clouds. IEEE Internet Things J. 2022, 10, 5158–5169.
  15. Zhang, H.; Cisse, M.; Dauphin, Y.N.; Lopez-Paz, D. mixup: Beyond empirical risk minimization. arXiv 2017, arXiv:1710.09412.
  16. Berthelot, D.; Carlini, N.; Cubuk, E.D.; Kurakin, A.; Sohn, K.; Zhang, H.; Raffel, C. ReMixMatch: Semi-supervised learning with distribution alignment and augmentation anchoring. arXiv 2019, arXiv:1911.09785.
  17. Kim, D.J.; Choi, J.; Oh, T.H.; Yoon, Y.; Kweon, I.S. Disjoint multi-task learning between heterogeneous human-centric tasks. In Proceedings of the 2018 IEEE Winter Conference on Applications of Computer Vision (WACV), IEEE, Lake Tahoe, NV, USA, 12–15 March 2018; pp. 1699–1708.
  18. Kuo, C.W.; Ma, C.Y.; Huang, J.B.; Kira, Z. FeatMatch: Feature-based augmentation for semi-supervised learning. In Proceedings of the Computer Vision–ECCV 2020: 16th European Conference, Glasgow, UK, 23–28 August 2020; Proceedings, Part XVIII. Springer: Berlin/Heidelberg, Germany, 2020; pp. 479–495.
  19. Sohn, K.; Berthelot, D.; Carlini, N.; Zhang, Z.; Zhang, H.; Raffel, C.A.; Cubuk, E.D.; Kurakin, A.; Li, C.L. FixMatch: Simplifying semi-supervised learning with consistency and confidence. Adv. Neural Inf. Process. Syst. 2020, 33, 596–608.
  20. Hu, Z.; Yang, Z.; Hu, X.; Nevatia, R. SimPLE: Similar pseudo label exploitation for semi-supervised classification. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, Virtual, 19–25 June 2021; pp. 15099–15108.
  21. Zheng, M.; You, S.; Huang, L.; Wang, F.; Qian, C.; Xu, C. SimMatch: Semi-Supervised Learning With Similarity Matching. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), New Orleans, LA, USA, 18–24 June 2022; pp. 14471–14481.
  22. Sajjadi, M.; Javanmardi, M.; Tasdizen, T. Regularization with stochastic transformations and perturbations for deep semi-supervised learning. Adv. Neural Inf. Process. Syst. 2016, 29, 1171–1179.
  23. Laine, S.; Aila, T. Temporal ensembling for semi-supervised learning. arXiv 2016, arXiv:1610.02242.
  24. Miyato, T.; Maeda, S.i.; Koyama, M.; Ishii, S. Virtual adversarial training: A regularization method for supervised and semi-supervised learning. IEEE Trans. Pattern Anal. Mach. Intell. 2018, 41, 1979–1993.
  25. Shi, W.; Gong, Y.; Ding, C.; Ma, Z.; Tao, X.; Zheng, N. Transductive Semi-Supervised Deep Learning Using Min-Max Features. In Proceedings of the Computer Vision–ECCV 2018, Munich, Germany, 8–14 September 2018; pp. 311–327.
  26. Xie, Q.; Luong, M.T.; Hovy, E.; Le, Q.V. Self-training with noisy student improves ImageNet classification. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, Seattle, WA, USA, 13–19 June 2020; pp. 10687–10698.
  27. Wang, G.H.; Wu, J. Repetitive Reprediction Deep Decipher for Semi-Supervised Learning. In Proceedings of the Thirty-Fourth AAAI Conference on Artificial Intelligence, New York, NY, USA, 7–12 February 2020.
  28. Rizve, M.N.; Duarte, K.; Rawat, Y.S.; Shah, M. In defense of pseudo-labeling: An uncertainty-aware pseudo-label selection framework for semi-supervised learning. arXiv 2021, arXiv:2101.06329.
  29. Xie, Q.; Dai, Z.; Hovy, E.; Luong, T.; Le, Q. Unsupervised data augmentation for consistency training. Adv. Neural Inf. Process. Syst. 2020, 33, 6256–6268.
  30. Zhang, B.; Wang, Y.; Hou, W.; Wu, H.; Wang, J.; Okumura, M.; Shinozaki, T. FlexMatch: Boosting semi-supervised learning with curriculum pseudo labeling. Adv. Neural Inf. Process. Syst. 2021, 34, 18408–18419.
  31. Xu, Y.; Shang, L.; Ye, J.; Qian, Q.; Li, Y.F.; Sun, B.; Li, H.; Jin, R. Dash: Semi-supervised learning with dynamic thresholding. In Proceedings of the International Conference on Machine Learning, PMLR, Virtual, 18–24 July 2021; pp. 11525–11536.
  32. Cubuk, E.D.; Zoph, B.; Shlens, J.; Le, Q.V. RandAugment: Practical automated data augmentation with a reduced search space. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition Workshops, Seattle, WA, USA, 14–19 June 2020; pp. 702–703.
  33. Bhattacharyya, A. On a measure of divergence between two multinomial populations. Sankhyā Indian J. Stat. 1946, 401–406.
  34. Krizhevsky, A.; Hinton, G. Learning Multiple Layers of Features from Tiny Images; Technical Report TR-2009; University of Toronto: Toronto, ON, Canada, 2009; Volume 5, p. 6.
  35. Netzer, Y.; Wang, T.; Coates, A.; Bissacco, A.; Wu, B.; Ng, A.Y. Reading Digits in Natural Images with Unsupervised Feature Learning. In Proceedings of the NIPS Workshop on Deep Learning and Unsupervised Feature Learning 2011, Granada, Spain, 12–17 December 2011.
  36. Arazo, E.; Ortego, D.; Albert, P.; O’Connor, N.E.; McGuinness, K. Pseudo-labeling and confirmation bias in deep semi-supervised learning. In Proceedings of the 2020 International Joint Conference on Neural Networks (IJCNN), Glasgow, UK, 19–24 July 2020; pp. 1–8.
  37. Wallin, E.; Svensson, L.; Kahl, F.; Hammarstrand, L. DoubleMatch: Improving semi-supervised learning with self-supervision. In Proceedings of the 2022 26th International Conference on Pattern Recognition (ICPR), Montreal, QC, Canada, 21–25 August 2022; pp. 2871–2877.
  38. Wang, J.; Lukasiewicz, T.; Massiceti, D.; Hu, X.; Pavlovic, V.; Neophytou, A. NP-Match: When neural processes meet semi-supervised learning. In Proceedings of the International Conference on Machine Learning, PMLR, Baltimore, MD, USA, 17–23 July 2022; pp. 22919–22934.
  39. Dai, Z.; Yang, Z.; Yang, F.; Cohen, W.W.; Salakhutdinov, R.R. Good semi-supervised learning that requires a bad GAN. Adv. Neural Inf. Process. Syst. 2017, 30, 6510–6520.
  40. Li, C.; Xu, T.; Zhu, J.; Zhang, B. Triple generative adversarial nets. Adv. Neural Inf. Process. Syst. 2017, 30, 4088–4098.
  41. Zagoruyko, S.; Komodakis, N. Wide residual networks. arXiv 2016, arXiv:1605.07146.
  42. Bottou, L. Stochastic gradient descent tricks. In Neural Networks: Tricks of the Trade, Second Edition; Springer: Berlin/Heidelberg, Germany, 2012; pp. 421–436.
  43. Loshchilov, I.; Hutter, F. Decoupled weight decay regularization. arXiv 2017, arXiv:1711.05101.
Figure 1. An overview of the ProMatch algorithm. ProMatch optimizes the classification network with three training objectives: (1) supervised loss L x for augmented labeled data; (2) unsupervised loss L u that aligns the strongly augmented unlabeled data with pseudo labels generated from weakly augmented data; (3) PC Loss L PC that minimizes the statistical distance between the prediction-prototype and the predictions of strongly augmented data.
Figure 2. The process of prototype generation. We maintain two buffers to store the features of labeled samples. By clustering the features, we construct a set of semantic-prototypes and a set of prediction-prototypes.
Figure 3. PC criterion. The predicted category of the most similar prediction-prototype, the ground-truth category of the most similar semantic-prototype, and the ground-truth category of the most similar prediction-prototype must be identical.
Figure 4. The recall rate, Top-1 accuracy rate, and F1 score of ProMatch and ProMatch without PC Loss. The blue bars represent our method ProMatch and the orange bars represent ProMatch without PC Loss.
Figure 5. The t-SNE visualization of features for ProMatch without PC Loss (left) and ProMatch (right) on CIFAR-100.
Figure 6. The variation trends of the Top-1 and Top-5 accuracy rates. We show the trends on the test and validation set. Note that the red line represents ProMatch and the blue line represents ProMatch without PC Loss.
Figure 7. The accuracy rates of our method under different values of hyperparameters.
Table 1. Hyperparameters of the experiments for CIFAR-10, SVHN, CIFAR-100, and Mini-ImageNet; \ indicates that the hyperparameter is not used.
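| Hyperparameter | CIFAR-10 | SVHN | CIFAR-100 | CIFAR-100 | Mini-ImageNet |
| --- | --- | --- | --- | --- | --- |
| $\tau_c$ | 0.95 | 0.95 | 0.95 | 0.95 | 0.95 |
| $\lambda_u$ | 1 | 1 | 1 | 1 | 1 |
| $\lambda_{pc}$ | 1 | 1 | 1 | 1 | 1 |
| $Q_k^p$ | 4 | 4 | 4 | 4 | 4 |
| lr | 0.003 | 0.003 | 0.003 | 0.003 | 0.02 |
| K | 7 | 7 | 4 | 2 | 7 |
| T | 0.5 | 0.5 | 0.5 | 0.5 | 0.5 |
| $\alpha$ | 0.75 | 0.75 | 0.75 | 0.75 | 0.75 |
| weight decay | 0.0005 | 0.0005 | 0.001 | 0.04 | 0.02 |
| batch size | 64 | 64 | 64 | 64 | 16 |
| EMA decay | 0.999 | 0.999 | 0.999 | 0.999 | 0.999 |
| backbone network | WRN 28-2 | WRN 28-2 | WRN 28-8 | WRN 28-2 | ResNet 18 |
| optimizer | SGD | SGD | SGD | SGD | AdamW |
| momentum | 0.9 | 0.9 | 0.9 | 0.9 | \ |
| lr scheduler | cosine decay | cosine decay | cosine decay | cosine decay | \ |
| lr decay rate | $7\pi/16$ | $7\pi/16$ | $7\pi/16$ | $7\pi/16$ | \ |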
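Table 1 lists SGD with momentum 0.9, a cosine learning-rate schedule, and an "lr decay rate" of 7π/16. One common reading, assumed here, is the FixMatch [19] schedule η_s = η_0 · cos(7πs / (16S)), where s is the current training step and S the total number of steps, so the rate never fully reaches zero. The sketch pairs this schedule with a plain heavy-ball momentum update with L2 weight decay on Python lists; the function names are hypothetical, and the actual training code may differ (for example, by using Nesterov momentum).

```python
import math
from typing import List, Sequence


def cosine_lr(
    base_lr: float, step: int, total_steps: int, decay_rate: float = 7 * math.pi / 16
) -> float:
    """Hypothetical helper: assumed FixMatch-style decay base_lr * cos(decay_rate * step / total_steps)."""
    return base_lr * math.cos(decay_rate * step / total_steps)


def sgd_momentum_step(
    params: List[float],
    grads: Sequence[float],
    velocity: List[float],
    lr: float,
    momentum: float = 0.9,
    weight_decay: float = 5e-4,  # illustrative default
) -> None:
    """Hypothetical helper: one in-place heavy-ball SGD update.

    v <- momentum * v + (g + weight_decay * w);  w <- w - lr * v
    """
    for i, (w, g) in enumerate(zip(params, grads)):
        g = g + weight_decay * w  # L2 penalty folded into the gradient
        velocity[i] = momentum * velocity[i] + g
        params[i] = w - lr * velocity[i]
```

At the last step the assumed schedule leaves the learning rate at cos(7π/16) ≈ 0.195 of its initial value.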
Table 2. Top-1 test accuracy of the experiments for CIFAR-100. * : using SimPLE [20]’s implementation. : reported in FixMatch [19].
CIFAR-100

| Method | 10,000 Labels | Backbone |
| --- | --- | --- |
| MixMatch [9] * | 64.01% | WRN 28-2 |
| MixMatch Enhanced | 67.12% | WRN 28-2 |
| SimPLE [20] | 70.82% | WRN 28-2 |
| ProMatch | 72.71% | WRN 28-2 |
| MixMatch [9] | 71.69% | WRN 28-8 |
| ReMixMatch [16] | 76.97% | WRN 28-8 |
| FixMatch [19] | 77.40% | WRN 28-8 |
| SimPLE [20] | 78.11% | WRN 28-8 |
| FlexMatch [30] | 78.10% | WRN 28-8 |
| DoubleMatch [37] | 78.78% | WRN 28-8 |
| NP-Match [38] | 78.78% | WRN 28-8 |
| ProMatch | 78.85% | WRN 28-8 |
Table 3. Top-1 test accuracy rate of the experiments for CIFAR-10 and SVHN. Bad GAN and Triple-GAN-V2 use CNN-13 as the backbone, while others use Wide ResNet 28-2. : The accuracy is reported in ReMixMatch [16] and using its own implementation. : Fully supervised baseline using all the labels and simple augmentation (flip-and-crop).
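| Method | CIFAR-10, 1000 Labels | CIFAR-10, 4000 Labels | SVHN, 1000 Labels | SVHN, 4000 Labels |
| --- | --- | --- | --- | --- |
| VAT [24] | 81.36% | 88.95% | 94.02% | 95.80% |
| MeanTeacher [10] | 82.68% | 89.64% | 96.25% | 96.61% |
| Bad GAN [39] | 79.37% | 85.59% | 95.75% | 96.03% |
| MixMatch [9] | 92.25% | 93.76% | 96.73% | 97.11% |
| ReMixMatch [16] | 94.27% | 94.86% | 97.17% | 97.58% |
| Triple-GAN-V2 [40] | 85.00% | 89.99% | 96.55% | 96.92% |
| FixMatch [19] | - | 95.69% | 97.64% | - |
| SimPLE [20] | 94.84% | 94.95% | 97.54% | 97.31% |
| FlexMatch [30] | - | 95.81% | 93.28% | - |
| DoubleMatch [37] | - | 95.35% | 97.90% | - |
| ProMatch | 95.01% | 95.83% | 97.79% | 97.88% |
| Fully Supervised | 96.23% (all labels) | | 98.17% (all labels) | |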
Table 4. Experimental results on Mini-ImageNet with 4000 labels. We report the Top-1 test accuracy rate and the error rate of pseudo-labels whose confidence is at least τ.
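Mini-ImageNet

| Method | Top-1 Accuracy (4000 Labels) | Error Rate of Pseudo-Labels |
| --- | --- | --- |
| Mean Teacher [10] | 27.49% | 38.87% |
| PLCB [36] | 43.51% | 22.63% |
| MixMatch [9] | 48.46% | 18.32% |
| FixMatch [19] | 50.21% | 14.77% |
| SimPLE [20] | 49.39% | 15.35% |
| ProMatch | 50.83% | 13.21% |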
Table 5. Ablation experiment on CIFAR-100. All experiments use WRN 28-2. w/o represents without PC Loss.
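Ablations: CIFAR-100

| Ablation | Augmentation Type | $\lambda_{PC}$ | $\tau_c$ | K | $Q_k^p$ | Accuracy |
| --- | --- | --- | --- | --- | --- | --- |
| ProMatch | RandAugment | 1 | 0.95 | 2 | 4 | 72.71% |
| ProMatch | RandAugment | 1 | 0.95 | 7 | 4 | 73.38% |
| w/o $L_{PC}$ | RandAugment | 0 | 0.95 | 2 | 4 | 69.07% |
| w/o $L_{PC}$ | RandAugment | 0 | 0.95 | 7 | 4 | 69.94% |
| w/o RandAugment | fixed | 1 | 0.95 | 2 | 4 | 67.66% |
| w/o RandAugment, w/o $L_{PC}$ | fixed | 0 | 0.95 | 2 | 4 | 67.41% |
| $\tau = 0$ | RandAugment | 1 | 0 | 2 | 4 | 71.54% |
| $\tau = 0.5$ | RandAugment | 1 | 0.5 | 2 | 4 | 71.53% |
| $\tau = 0.75$ | RandAugment | 1 | 0.75 | 2 | 4 | 71.89% |
| $\lambda_{PC} = 0.5$ | RandAugment | 0.5 | 0.95 | 2 | 4 | 72.55% |
| $\lambda_{PC} = 2$ | RandAugment | 2 | 0.95 | 2 | 4 | 71.81% |
| $\lambda_{PC} = 3$ | RandAugment | 3 | 0.95 | 2 | 4 | 72.42% |
| $\lambda_{PC} = 4$ | RandAugment | 4 | 0.95 | 2 | 4 | 72.76% |
| $Q_k^p = 3$ | RandAugment | 1 | 0.95 | 2 | 3 | 71.83% |
| $Q_k^p = 5$ | RandAugment | 1 | 0.95 | 2 | 5 | 72.01% |
| $Q_k^p = 10$ | RandAugment | 1 | 0.95 | 2 | 10 | 71.82% |
| $Q_k^p = 20$ | RandAugment | 1 | 0.95 | 2 | 20 | 72.13% |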

Share and Cite

MDPI and ACS Style

Cheng, Z.; Wang, X.; Li, J. ProMatch: Semi-Supervised Learning with Prototype Consistency. Mathematics 2023, 11, 3537. https://doi.org/10.3390/math11163537

AMA Style

Cheng Z, Wang X, Li J. ProMatch: Semi-Supervised Learning with Prototype Consistency. Mathematics. 2023; 11(16):3537. https://doi.org/10.3390/math11163537

Chicago/Turabian Style

Cheng, Ziyu, Xianmin Wang, and Jing Li. 2023. "ProMatch: Semi-Supervised Learning with Prototype Consistency" Mathematics 11, no. 16: 3537. https://doi.org/10.3390/math11163537

APA Style

Cheng, Z., Wang, X., & Li, J. (2023). ProMatch: Semi-Supervised Learning with Prototype Consistency. Mathematics, 11(16), 3537. https://doi.org/10.3390/math11163537
