FocalMatch: Mitigating Class Imbalance of Pseudo Labels in Semi-Supervised Learning

Abstract: Semi-supervised learning (SSL) is a popular research area in machine learning that uses both labeled and unlabeled data. Pseudo-labeling, an important method for generating artificial hard labels for unlabeled data, is implemented in most state-of-the-art SSL models by applying a high, fixed confidence threshold. However, in early training these models favor certain classes that are easy to learn, which produces a highly skewed class imbalance in the generated hard labels. This imbalance leads to less effective learning of the minority classes and slower convergence of the model. This paper aims to mitigate the performance degradation caused by class imbalance and to gradually reduce the class imbalance in the unsupervised part. To this end, we propose FocalMatch, a novel SSL method that combines FixMatch and focal loss. FocalMatch adjusts the loss weight of each unlabeled sample according to how well its prediction matches its pseudo label, which accelerates learning and model convergence and achieves state-of-the-art performance on several semi-supervised learning benchmarks. Its effectiveness is particularly evident on datasets with extremely limited labeled data.


Introduction
Machine learning (ML) is one of the most important and popular fields in artificial intelligence. The core concept of ML is the data-driven model [1]. ML has evolved rapidly in recent decades owing to the explosive growth in the amount of available data and the increase in computational power. A defining feature of machine learning is that it improves performance automatically through experience [1,2]. Because of this feature, machine learning has rapidly become a fundamental technique in many modern applications, including computer vision, natural language processing (NLP), fraud detection, medical analysis (of both physical and mental health), agriculture, the energy sector, mechanical engineering, and network security [3][4][5][6][7][8][9][10][11][12]. Machine learning has spawned many branches, such as supervised learning and unsupervised learning. The main difference between the two is whether labeled data are used. By leveraging the information contained in labeled data, supervised learning usually achieves better performance than unsupervised learning on the same task.
Although the amount of available data has increased dramatically over the last few decades, labeled data still make up only a small fraction of it. Labeling data is complicated, costly (time-consuming, expensive, or both), or both. For example, some samples, such as those in medical analysis, can only be labeled by experts. Furthermore, as the amount of available data grows, people are increasingly concerned about data privacy. Even if a significant number of labels have already been collected, it remains uncertain whether those labels can be used for model training or whether the data will require additional care [13]. It is therefore critical for the model to generate artificial labels for unlabeled data itself, instead of relying on manual labeling, which raises privacy concerns. The lack of labeled data has given rise to a separate research area called semi-supervised learning (SSL). SSL uses datasets that contain only a small amount of labeled data and a large amount of unlabeled data. By leveraging massive volumes of unlabeled data, SSL can significantly improve model performance with far fewer labels.
Pseudo-labeling [14] has been widely used in recently proposed semi-supervised learning frameworks. Its idea is that the learning model generates hard labels for unlabeled data on its own (i.e., from its predicted class distributions) and then uses these generated labels as training targets. FixMatch [15] is a state-of-the-art semi-supervised learning method that produces pseudo (one-hot) labels from weakly augmented samples and uses the cross-entropy loss to enforce consistency between these pseudo labels and the model's predictions on strongly augmented versions of the same samples. The generated pseudo labels help FixMatch achieve entropy minimization [16] in the unsupervised learning part. However, FixMatch and other semi-supervised learning methods that use pseudo-labeling tend to assign pseudo labels to certain classes, particularly in the early stages of training, which introduces class imbalance. The overall loss can then be dominated by classes with many pseudo labels, so the model learns useful information mainly from these majority classes and neglects the others. Consequently, samples of the majority pseudo-label classes usually become easy samples (predicted with high accuracy), while samples of the other classes become hard samples. This further exacerbates the class imbalance, because the model rarely makes high-confidence predictions on hard samples. Their pseudo labels are therefore less likely to be considered valid, and the corresponding loss terms are not added to the overall unsupervised loss.
In this paper, we propose FocalMatch, a new method that combines FixMatch with focal loss [17] to address the class imbalance that occurs in the unsupervised learning part. Based on a systematic analysis of the unsupervised part of FixMatch, we use focal loss to adjust the loss contribution of each sample automatically, according to how close its prediction is to its pseudo label. As a result, the overall unsupervised loss is no longer overwhelmed by the large number of easy samples. The workflow of FocalMatch is shown in Figure 1. Experiments show that FocalMatch significantly reduces the differences in the number of pseudo labels generated for each class and provides a smoother learning curve than other state-of-the-art models.
In brief, the main contributions of this work are as follows:
• We propose FocalMatch, a simple yet novel semi-supervised learning method that combines FixMatch and focal loss. It effectively mitigates the performance degradation caused by class imbalance and gradually reduces the class imbalance that arises in the unsupervised learning part when pseudo labels are generated.
• FocalMatch adjusts the loss weights of different unlabeled samples based on how close their predictions are to their pseudo labels. Hence, the loss is not overwhelmed by easy samples, and the model can effectively learn valuable information from all classes.
• FocalMatch outperforms most state-of-the-art semi-supervised learning methods on several benchmarks, especially when labeled data are severely limited. Experiments show that FocalMatch significantly reduces the differences in the number of pseudo labels generated for each class. FocalMatch also has a smoother training curve and converges faster than FixMatch.
The rest of this paper is organized as follows. Section 2 discusses related work on semi-supervised learning and class imbalance. Section 3 presents the materials and methods of our proposed FocalMatch. Section 4 describes the experiments we performed, including the experimental setting and the baseline methods. Section 5 compares the experimental results, and Section 6 presents an ablation study that investigates the effectiveness of our method. Finally, Section 7 summarizes the paper, followed by Appendix A, which gives the detailed experimental setting.

Figure 1. Framework of FocalMatch. Unlabeled data are fed into the model, which generates pseudo labels from the weakly augmented unlabeled data. At early iterations, the model favors certain classes when generating pseudo labels, which causes severe class imbalance. The model then makes predictions on the same unlabeled data under strong augmentation and computes the loss between these predictions and the pseudo labels with focal loss [17] to enforce consistency. Focal loss adjusts the weight of each sample based on how close its prediction is to its pseudo label: for well-classified samples (i.e., the majority of pseudo labels), the loss contribution is reduced. Therefore, the model can rapidly reduce class imbalance at early iterations.

Semi-Supervised Learning
Semi-supervised learning (SSL) is a popular machine learning technique that combines supervised and unsupervised learning, i.e., it trains on both labeled and unlabeled data. Because one of the main motivations of SSL is to address the problem of insufficient labeled data, it typically works with small labeled datasets [18,19]. Semi-supervised learning has become increasingly important in many areas, for example, medical image analysis [20] and natural language processing [21], since labeled data in these scenarios are usually expensive or hard to obtain. A detailed introduction to semi-supervised learning can be found in [22].
In semi-supervised learning, consistency regularization [23] is a commonly used technique. It assumes that distorted versions of the same input sample should yield similar predictions from the model, and it has been applied in many state-of-the-art SSL methods. Such distorted samples are usually generated by data augmentation: ReMixMatch [24] and UDA [25], for example, both use strong data augmentation to strengthen consistency regularization between different versions of an image. Pseudo-labeling [14] is another widely used SSL technique, which generates hard (one-hot) artificial labels for unlabeled data from the model's predictions. Several recently proposed SSL methods combine pseudo-labeling with consistency regularization. FixMatch [15], for instance, generates one-hot pseudo labels from predictions on weakly augmented data with a predefined high threshold and enforces consistency with predictions on strongly augmented data.

Class Imbalance
Most machine learning methods assume that the training dataset is well balanced (i.e., each class contains a similar number of samples) [26]. In real-world scenarios, however, class distributions are often imbalanced (i.e., some classes contain considerably more samples than others), for example in fraud detection [27], medical diagnosis [28], and software failure prediction [29]. Class imbalance is detrimental to machine learning models, affecting, for example, their convergence and generalization ability [30]. A more extensive introduction to class imbalance is provided in [31].
As [32] suggests, mainstream solutions to class imbalance fall into two approaches: data-level methods and algorithmic-level methods. Data-level methods aim to eliminate class imbalance by modifying the training dataset, for example through oversampling [33] (randomly selecting samples from the minority class and duplicating them) or undersampling [34] (randomly selecting samples from the majority class and discarding them). However, these sampling methods may degrade the final performance, for example by causing overfitting [35]. Algorithmic-level methods instead address class imbalance by modifying the learning algorithm. Threshold moving is one of the best-known algorithmic-level methods; its main idea is to continuously adjust the output (e.g., weights) of the model to accommodate the imbalanced distribution of samples [36]. Both algorithmic-level and data-level methods are commonly used in machine learning. For example, [37] proposes a novel hybrid sampling method based on a generative adversarial network to address class imbalance.
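To make the two data-level strategies concrete, the following is a minimal pure-Python sketch of random oversampling and random undersampling as described above. The function names and toy data are hypothetical illustrations, not code from [33] or [34].

```python
import random
from collections import Counter


def _group_by_class(samples, labels):
    """Group samples into a dict {label: [samples...]}, preserving order."""
    by_class = {}
    for x, y in zip(samples, labels):
        by_class.setdefault(y, []).append(x)
    return by_class


def random_oversample(samples, labels, seed=0):
    """Duplicate randomly chosen samples of the smaller classes until every
    class has as many samples as the largest class."""
    rng = random.Random(seed)
    by_class = _group_by_class(samples, labels)
    target = max(len(xs) for xs in by_class.values())
    out_x, out_y = [], []
    for y, xs in by_class.items():
        extra = [rng.choice(xs) for _ in range(target - len(xs))]
        out_x.extend(xs + extra)
        out_y.extend([y] * target)
    return out_x, out_y


def random_undersample(samples, labels, seed=0):
    """Randomly discard samples of the larger classes until every class has
    as many samples as the smallest class."""
    rng = random.Random(seed)
    by_class = _group_by_class(samples, labels)
    target = min(len(xs) for xs in by_class.values())
    out_x, out_y = [], []
    for y, xs in by_class.items():
        out_x.extend(rng.sample(xs, target))
        out_y.extend([y] * target)
    return out_x, out_y


# Toy imbalanced data: class "a" has 6 samples, class "b" has 2.
xs = list(range(8))
ys = ["a"] * 6 + ["b"] * 2
print(Counter(random_oversample(xs, ys)[1]))   # Counter({'a': 6, 'b': 6})
print(Counter(random_undersample(xs, ys)[1]))  # Counter({'a': 2, 'b': 2})
```

Oversampling only repeats existing samples and undersampling throws information away, which is why both can hurt final performance as noted above.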
Our proposed FocalMatch combines FixMatch (a state-of-the-art semi-supervised learning framework) and focal loss (an algorithmic-level method) to address the class imbalance that occurs when pseudo labels are generated in the unsupervised learning part. Rather than modifying the original dataset, it replaces the standard cross-entropy loss used by the model, which is more in line with the setting of semi-supervised learning and preserves data privacy. FocalMatch adjusts the loss weights of different samples based on how close their predictions are to their pseudo labels, decreasing the weights of easy samples so that the overall unsupervised loss is not overwhelmed by them. A detailed ablation study of the effectiveness of FocalMatch is presented in Section 6. Compared with other state-of-the-art models (including the Π model [38], Mean Teacher [39], MixMatch [40], ReMixMatch [24], UDA [25], and FixMatch [15]), our experiments show that FocalMatch significantly reduces the differences in the number of pseudo labels generated for each class and has a smoother learning curve. FocalMatch surpasses most state-of-the-art semi-supervised learning algorithms on several benchmarks, particularly when the amount of labeled data is severely limited.

Consistency Regularization and Pseudo-Labeling
Consistency regularization [23] is one of the most popular ideas in semi-supervised learning. It assumes that distorting a sample should not change the model's prediction. The authors of [15] formulate the consistency loss as Equation (1):

$$\sum_{b=1}^{\mu B} \left\| p_m(y \mid \alpha(u_b)) - p_m(y \mid \alpha(u_b)) \right\|_2^2 \qquad (1)$$

where $\mu$ is the relative size of the unlabeled data to the labeled data, $B$ is the batch size of labeled data, $u_b$ is an unlabeled sample, $\alpha$ is a data augmentation function, and $p_m(y \mid \alpha(u_b))$ is the model's prediction (soft label) on the augmented $u_b$. Since both $\alpha$ and $p_m$ are stochastic, the two terms in Equation (1) generally take different values. The squared $\ell_2$ norm $\|\cdot\|_2^2$ measures the distance between these two predictions.
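The sketch below evaluates Equation (1) in pure Python. The noise-based `augment` and the softmax-only `model` are hypothetical stand-ins for a real augmentation $\alpha$ and network $p_m$; only the loss computation follows the equation.

```python
import math
import random


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def consistency_loss(model, augment, unlabeled_batch):
    """Equation (1): sum over the mu*B unlabeled samples of the squared L2
    distance between predictions on two independently augmented copies."""
    loss = 0.0
    for u in unlabeled_batch:
        p1 = model(augment(u))  # first stochastic pass
        p2 = model(augment(u))  # second stochastic pass
        loss += sum((a - b) ** 2 for a, b in zip(p1, p2))
    return loss


# Hypothetical stand-ins: "augmentation" adds Gaussian noise to a feature
# vector, and the "model" simply treats the features as logits.
rng = random.Random(0)


def augment(u):
    return [x + rng.gauss(0.0, 0.1) for x in u]


def model(u):
    return softmax(u)


batch = [[2.0, 0.5, -1.0], [0.0, 0.0, 3.0]]
loss = consistency_loss(model, augment, batch)
print(f"consistency loss: {loss:.6f}")
```

With a deterministic augmentation and model the two passes coincide and the loss is exactly zero; the loss is nonzero only because of the stochasticity the text describes.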
In modern semi-supervised learning techniques, pseudo-labeling [14] is closely related to consistency regularization. Its idea is that the model should generate artificial labels for the unlabeled data itself. The authors of [15] define the pseudo-labeling loss as Equation (2):

$$\frac{1}{\mu B} \sum_{b=1}^{\mu B} \mathbb{1}\big(\max p_m(y \mid u_b) \geq \tau\big)\, H\big(\hat{p}_m(y \mid u_b),\, p_m(y \mid u_b)\big) \qquad (2)$$

where $p_m(y \mid u_b)$ is the model's prediction (soft label) on $u_b$, whereas $\hat{p}_m(y \mid u_b)$ is the one-hot pseudo label obtained from $p_m(y \mid u_b)$. $\mathbb{1}(\cdot)$ is the indicator function, $\tau$ is the hyperparameter that defines the threshold, and $H$ is the cross-entropy loss.

FixMatch
FixMatch [15] is a recently proposed state-of-the-art semi-supervised learning algorithm. First, the learning model predicts a probability distribution over the classes (soft label) for a weakly augmented sample. If the probability of a specific class exceeds the predefined threshold, that class is adopted as the pseudo label (one-hot label) of the sample. Second, the model makes a prediction on a strongly augmented version of the same sample, and a cross-entropy loss enforces consistency between this prediction and the pseudo label. Because FixMatch combines consistency regularization and pseudo-labeling, Equation (2) can be reformulated as:

$$\ell_u = \frac{1}{\mu B} \sum_{b=1}^{\mu B} \mathbb{1}\big(\max p_m(y \mid \alpha(u_b)) \geq \tau\big)\, H\big(\hat{p}_m(y \mid \alpha(u_b)),\, p_m(y \mid \mathcal{A}(u_b))\big) \qquad (3)$$

Equation (3) is simply the combination of Equations (1) and (2) provided in [15], where $\alpha$ refers to weak data augmentation and $\mathcal{A}$ refers to strong data augmentation, $\tau$ is the threshold, and $H$ is the cross-entropy loss. In FixMatch, the one-hot pseudo label $\hat{p}_m(y \mid \alpha(u_b))$ is obtained by applying arg max to the soft label $p_m(y \mid \alpha(u_b))$ of the weakly augmented image.
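As an illustration of Equation (3), here is a minimal sketch that takes precomputed class distributions for the weak and strong views of each unlabeled sample. The function name is hypothetical, and the default τ = 0.95 is assumed from the FixMatch paper rather than taken from this text.

```python
import math


def fixmatch_unsup_loss(weak_probs, strong_probs, tau=0.95):
    """Equation (3) for one batch of mu*B unlabeled samples.

    weak_probs[b]:   p_m(y | alpha(u_b)), prediction on the weak view
    strong_probs[b]: p_m(y | A(u_b)), prediction on the strong view
    """
    total = 0.0
    for p_weak, p_strong in zip(weak_probs, strong_probs):
        confidence = max(p_weak)
        if confidence >= tau:                        # indicator 1(max p >= tau)
            pseudo_label = p_weak.index(confidence)  # arg max gives the one-hot label
            # Cross-entropy against a one-hot target is -log of that class's probability.
            total += -math.log(p_strong[pseudo_label])
    return total / len(weak_probs)                   # average over all mu*B samples


weak = [[0.97, 0.02, 0.01], [0.50, 0.30, 0.20]]     # only the first passes tau
strong = [[0.80, 0.10, 0.10], [0.20, 0.70, 0.10]]
print(fixmatch_unsup_loss(weak, strong))
```

Samples below the threshold contribute zero loss but still count in the 1/(µB) normalization, which matches the equation.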

FocalMatch
The standard cross-entropy loss measures the distance between two probability distributions (i.e., the ground truth and the prediction): the lower the loss, the closer the prediction is to the ground truth. Because of this property, the cross-entropy loss is widely used in classification tasks. However, the standard cross-entropy loss weights the loss contribution of every class equally. This is generally acceptable when classes are balanced. When classes are imbalanced (i.e., the sample sizes of some classes are significantly larger than others), however, the loss from the majority classes can dominate the overall cross-entropy loss. The model then struggles to learn useful information from the minority classes, which further decreases their prediction accuracy. Moreover, because of the difference in sample size, the total loss from the majority classes may still dominate even if the loss of a single minority-class sample is higher than that of a single majority-class sample (due to its lower accuracy). For simplicity, we use the binary classification case in the following. Equation (4) shows the standard cross-entropy loss:

$$\mathrm{CE}(p, y) = \begin{cases} -\log(p) & \text{if } y = 1 \\ -\log(1 - p) & \text{otherwise} \end{cases} \qquad (4)$$

where $y$ is the ground-truth label and $p \in [0, 1]$ is the model's predicted probability for the class $y = 1$. Defining $p_t = p$ if $y = 1$ and $p_t = 1 - p$ otherwise, Equation (4) can be written compactly as $\mathrm{CE}(p_t) = -\log(p_t)$. To solve the class imbalance problem, [17] proposed an improved version of the cross-entropy loss called focal loss. Its main idea is to adjust the contributions of different samples by adding a modulating factor to the standard cross-entropy loss:

$$\mathrm{FL}(p_t) = -(1 - p_t)^{\gamma} \log(p_t)$$

where $(1 - p_t)^{\gamma}$ is the modulating factor and $\gamma \geq 0$ is a hyperparameter; with $\gamma = 0$, focal loss reduces to the standard cross-entropy loss.

The modulating factor lets the model adjust the weights of different samples. For correctly classified samples, $p_t$ is close to 1, so the modulating factor is close to 0 and the loss weights of these (easy) samples are reduced. For misclassified samples, $p_t$ is close to 0, so the modulating factor is close to 1 and the loss weights of these (hard) samples remain nearly unchanged. Hence, even if easy samples greatly outnumber hard samples, the loss from hard samples still accounts for a significant portion of the total loss after this reweighting, and the model can learn valuable information from hard samples, which further improves its performance.
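The effect of the modulating factor can be checked numerically. The sketch below implements the binary cross-entropy of Equation (4) and the binary focal loss; γ = 2 is used only as an illustrative default (the ablation in Section 6 uses γ = 1).

```python
import math


def binary_cross_entropy(p, y):
    """Equation (4): -log(p_t), where p is the predicted probability of class 1."""
    p_t = p if y == 1 else 1.0 - p
    return -math.log(p_t)


def binary_focal_loss(p, y, gamma=2.0):
    """Focal loss: -(1 - p_t)^gamma * log(p_t); gamma = 0 recovers cross-entropy."""
    p_t = p if y == 1 else 1.0 - p
    return -((1.0 - p_t) ** gamma) * math.log(p_t)


# The ratio FL / CE equals the modulating factor (1 - p_t)^gamma.
for p_t in (0.9, 0.5, 0.1):
    ratio = binary_focal_loss(p_t, 1) / binary_cross_entropy(p_t, 1)
    print(f"p_t = {p_t}: focal / CE = {ratio:.2f}")
```

With γ = 2, an easy sample with $p_t = 0.9$ keeps only 1% of its cross-entropy loss ($0.1^2 = 0.01$), while a hard sample with $p_t = 0.1$ keeps 81% ($0.9^2 = 0.81$).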
Following the method described in [17], Figure 2 shows the loss curves for different values of γ. Focal loss down-weights easy samples, and the larger γ is, the stronger this down-weighting becomes. In FixMatch [15], the cross-entropy loss is computed between the pseudo labels and the predictions on strongly augmented images. During training, however, we found that the learning model tends to favor certain classes when generating pseudo labels (e.g., it may generate far more pseudo labels for cats than for airplanes). As a result, the pseudo labels become class-imbalanced in the unsupervised learning phase; detailed counts of the generated pseudo labels are provided in Section 6. In this scenario, the cross-entropy loss is no longer optimal. To address this problem, we propose FocalMatch, which combines FixMatch and focal loss [17]. We replace the cross-entropy loss with focal loss in the unsupervised learning part so that the model focuses more on the minority pseudo-label classes. The unsupervised loss $L_u$ of our method can therefore be formulated as:

$$L_u = \frac{1}{\mu B} \sum_{b=1}^{\mu B} \mathbb{1}\big(\max p_m(y \mid \alpha(u_b)) \geq \tau\big)\, \mathrm{FL}\big(\hat{p}_m(y \mid \alpha(u_b)),\, p_m(y \mid \mathcal{A}(u_b))\big)$$

where $u_b$ is an unlabeled sample, and $\alpha$ and $\mathcal{A}$ refer to weak and strong data augmentation, respectively. $p_m(y \mid \alpha(u_b))$ and $p_m(y \mid \mathcal{A}(u_b))$ are the predicted probability distributions on the weakly and strongly augmented samples, respectively; the former is also the soft label of the unlabeled sample. $\hat{p}_m(y \mid \alpha(u_b))$ is the pseudo label generated from the soft label when the confidence of a specific class is at least τ (the hyperparameter that defines the threshold), and $\mathbb{1}(\cdot)$ is the indicator function. $\mathrm{FL}(\hat{p}, p) = -(1 - p_t)^{\gamma} \log(p_t)$, where $p_t$ is the probability that $p$ assigns to the class selected by $\hat{p}$. The supervised loss $L_s$ is the same as in FixMatch [15]:

$$L_s = \frac{1}{B} \sum_{b=1}^{B} H\big(y_b,\, p_m(y \mid \alpha(x_b))\big)$$

where $x_b$ is a labeled sample, $y_b$ is its label, and $H$ is the standard cross-entropy loss. The overall loss of FocalMatch is:

$$L = L_s + \lambda L_u$$

where λ is another hyperparameter that defines the weight of the unsupervised loss $L_u$. The detailed algorithm of our method is shown in Algorithm 1.
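A minimal sketch of the FocalMatch objective on a toy batch follows. It takes precomputed class distributions instead of running a network, and all names and numbers are hypothetical; τ = 0.95 and λ = 1 are assumed FixMatch-style defaults, while γ = 1 matches the value used in the ablation study.

```python
import math


def focal_term(target_class, probs, gamma):
    """-(1 - p_t)^gamma * log(p_t), with p_t the probability of target_class."""
    p_t = probs[target_class]
    return -((1.0 - p_t) ** gamma) * math.log(p_t)


def focalmatch_unsup_loss(weak_probs, strong_probs, tau=0.95, gamma=1.0):
    """L_u: thresholded pseudo-labeling as in FixMatch, but with focal loss."""
    total = 0.0
    for p_weak, p_strong in zip(weak_probs, strong_probs):
        confidence = max(p_weak)
        if confidence >= tau:
            pseudo_label = p_weak.index(confidence)  # arg max of the weak-view soft label
            total += focal_term(pseudo_label, p_strong, gamma)
    return total / len(weak_probs)  # normalize by mu*B


def supervised_loss(labels, probs_on_weak_labeled):
    """L_s: mean cross-entropy on weakly augmented labeled samples."""
    return sum(-math.log(p[y]) for y, p in zip(labels, probs_on_weak_labeled)) / len(labels)


# Toy batch (hypothetical numbers).
weak = [[0.97, 0.02, 0.01], [0.50, 0.30, 0.20]]
strong = [[0.80, 0.10, 0.10], [0.20, 0.70, 0.10]]
labels = [0, 2]
probs_labeled = [[0.70, 0.20, 0.10], [0.10, 0.30, 0.60]]

lam = 1.0  # weight of the unsupervised loss (assumed value)
l_s = supervised_loss(labels, probs_labeled)
l_u = focalmatch_unsup_loss(weak, strong, gamma=1.0)
total = l_s + lam * l_u
print(f"L_s = {l_s:.4f}, L_u = {l_u:.4f}, L = {total:.4f}")
```

Setting `gamma=0.0` makes `focalmatch_unsup_loss` identical to the FixMatch unsupervised loss, which is the sense in which FocalMatch only swaps the loss function.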

Algorithm 1 FocalMatch Algorithm

Setup
To compare our approach fairly with other SSL methods, all experiments are implemented in PyTorch [43] using the TorchSSL codebase [44]. We use hyperparameter settings similar to [15,44]: all baseline methods use Wide ResNet-28-2 [45] as the backbone network, a labeled batch size of 64, standard stochastic gradient descent with a momentum of 0.9 as the optimizer [46,47], and an initial learning rate of 0.03 with cosine learning rate decay [48]. Other hyperparameters are method-dependent: µ (ratio of unlabeled to labeled data), τ (threshold for generating pseudo labels), λ (weight of the unsupervised loss), and the temperature (for sharpening soft labels). As suggested by [44], all method-dependent hyperparameters follow the original papers, including those that belong only to specific methods (e.g., the weight of the distribution matching loss in ReMixMatch). In addition, [15] emphasizes the importance of combining weak and strong data augmentation. For weak data augmentation, we use random horizontal flipping (with 50% probability) and random cropping (crop size 32) on all datasets; for strong data augmentation, we use RandAugment [49]. The complete set of hyperparameters is given in Appendix A.
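The text specifies only "cosine learning rate decay". As an assumption, the sketch below uses the schedule described in the FixMatch paper, $\eta \cos\!\big(\tfrac{7\pi k}{16K}\big)$ with step $k$ and total steps $K$; the exact TorchSSL implementation may differ in details such as warmup, and `total_steps` here is illustrative only.

```python
import math


def cosine_lr(step, total_steps, base_lr=0.03):
    """Assumed FixMatch-style decay: base_lr * cos(7 * pi * step / (16 * total_steps))."""
    return base_lr * math.cos(7 * math.pi * step / (16 * total_steps))


total_steps = 1000  # illustrative value, not the paper's training length
for step in (0, total_steps // 2, total_steps):
    print(step, round(cosine_lr(step, total_steps), 5))
```

Because the cosine argument stops at 7π/16 rather than π/2, the learning rate decays smoothly but never reaches zero during training.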

Results
Our experiments use top-1 classification accuracy as the evaluation metric for all baseline methods and FocalMatch. Table 1 shows that FocalMatch outperforms all baseline methods on most of the benchmarks and performs particularly well when labeled data are extremely scarce (i.e., four labels per class). However, FocalMatch does not perform as well as expected on SVHN with 10 labels per class, where FixMatch and UDA outperform it by around 0.2% and 0.5%, respectively. We attribute this to the relative simplicity of the SVHN dataset. When labeled data are extremely scarce (four labels per class), the model's overall prediction confidence is not high enough to produce valid pseudo labels evenly for all classes. This causes a severe class imbalance, which FocalMatch effectively alleviates. When we increase the number of labeled data in the SVHN experiment (10 labels per class), the model is confident enough to generate pseudo labels evenly, and the accuracy of each class is relatively high. In this case, the loss-contribution adjustment of FocalMatch can instead slow down the learning of the model. This may also explain why FocalMatch achieves significant performance improvements when labeled data are extremely scarce, whereas its improvement shrinks as the amount of labeled data increases (i.e., as the task becomes less challenging). By resolving the latent class imbalance issue, FocalMatch substantially extends the learning ability of FixMatch. Our method not only outperforms FixMatch in classification accuracy on CIFAR-10, CIFAR-100, and SVHN (except on SVHN with 10 labels per class) but also speeds up model convergence. Figure 3a,b compares the convergence of our method and FixMatch in terms of overall top-1 accuracy and loss.
The loss curve of FocalMatch is noticeably smoother and converges faster than that of FixMatch. Following the approach described in [44], we also compare the per-class accuracy of FixMatch and FocalMatch in Figure 3c,d. FixMatch shows large gaps between the accuracies of different classes, which stem from the class imbalance of the pseudo labels generated in the unsupervised learning part: its total unsupervised loss tends to be dominated by classes with many pseudo labels rather than reflecting the overall unlabeled data. This could explain why the accuracy of some classes is appreciably lower than that of other classes, or does not improve at all (e.g., class 5).
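Overall top-1 accuracy can hide a class that the model never learns, which is what the per-class view in Figure 3c,d exposes. The hypothetical helpers below compute both metrics from arg-max predictions on toy data; they are not the TorchSSL evaluation code.

```python
from collections import defaultdict


def top1_accuracy(predictions, labels):
    """Fraction of samples whose arg-max prediction equals the true label."""
    correct = sum(int(p == y) for p, y in zip(predictions, labels))
    return correct / len(labels)


def per_class_accuracy(predictions, labels):
    """Accuracy computed separately for each true class."""
    correct, total = defaultdict(int), defaultdict(int)
    for p, y in zip(predictions, labels):
        total[y] += 1
        correct[y] += int(p == y)
    return {c: correct[c] / total[c] for c in sorted(total)}


labels = [0, 0, 1, 1, 2, 2]
predictions = [0, 0, 1, 0, 0, 0]  # the model never predicts class 2
print(top1_accuracy(predictions, labels))       # 0.5
print(per_class_accuracy(predictions, labels))  # {0: 1.0, 1: 0.5, 2: 0.0}
```

The overall accuracy of 0.5 looks moderate, while the per-class breakdown shows that class 2 is not learned at all, analogous to class 5 under FixMatch.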
In contrast, the per-class accuracy of FocalMatch increases evenly, with no significant differences between classes. This indicates that FocalMatch extracts useful features from all classes uniformly rather than from a few specific classes. In Section 6, we conduct an ablation study to investigate the effect of focal loss on the class imbalance of pseudo labels.

Discussion and Ablation Study
Our method combines FixMatch [15] and focal loss [17]. Focal loss was designed to address class imbalance, which makes it difficult for a model to learn useful information from minority classes. It is commonly used in object detection, where candidate regions containing target objects are far outnumbered by background regions. We find that focal loss is also useful for image classification when the number of labels per class is imbalanced. Our experiments use the same number of labels for each class, so class imbalance can hardly occur in the supervised learning part. In FixMatch, however, pseudo labels are self-generated from unlabeled data, so class imbalance can arise in the generated pseudo labels and affect the unsupervised learning of the model. To investigate the effectiveness of focal loss, we conduct an ablation study with different values of γ. Figure 4 shows the number of pseudo labels generated by the model (i.e., samples for which the predicted confidence of a specific class in the soft label is greater than τ) at each iteration on the CIFAR-10 dataset. Figure 4a shows the result without focal loss (i.e., γ = 0), whereas Figure 4b shows the result with focal loss (γ = 1). Without focal loss, the numbers of pseudo labels per class are clearly imbalanced. This class imbalance can seriously impair the model's ability to learn from classes with few pseudo labels. In the early stages, the number of pseudo labels generated for a single class (class 2) even exceeds the combined number generated for all other classes, indicating that the unsupervised loss is dominated by a single class rather than by all classes.
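The quantity plotted in Figure 4 can be computed per iteration roughly as follows. This is a hypothetical sketch on toy soft labels, using ≥ τ as in FixMatch; a dominant-class share above 0.5 corresponds to the situation described above, where one class receives more pseudo labels than all other classes combined.

```python
from collections import Counter


def count_pseudo_labels(weak_probs, tau=0.95):
    """Count, per class, the samples whose top soft-label confidence reaches tau."""
    counts = Counter()
    for p in weak_probs:
        confidence = max(p)
        if confidence >= tau:                 # only confident predictions become pseudo labels
            counts[p.index(confidence)] += 1  # arg max class
    return counts


def dominant_share(counts):
    """Share of all pseudo labels that belong to the most frequent class."""
    total = sum(counts.values())
    return max(counts.values()) / total if total else 0.0


weak = [
    [0.01, 0.01, 0.98],
    [0.00, 0.02, 0.98],
    [0.03, 0.00, 0.97],
    [0.96, 0.02, 0.02],
    [0.40, 0.30, 0.30],  # below tau: no pseudo label
]
counts = count_pseudo_labels(weak)
print(counts)                  # Counter({2: 3, 0: 1})
print(dominant_share(counts))  # 0.75
```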
In contrast, with focal loss the differences between the numbers of pseudo labels generated for each class are notably reduced, so the model can extract useful information from all classes. Focal loss does not achieve this by imposing a stricter condition that reduces the number of pseudo labels; on the contrary, the total number of pseudo labels generated with focal loss is much higher than without it (rising from 7 billion to 10 billion). Focal loss provides a smoother learning curve for the model to learn from all unlabeled data and also reduces the number of iterations required to reach a stable phase of pseudo-label generation.

Conclusions and Future Work
This paper proposes FocalMatch, a new semi-supervised learning approach that combines FixMatch and focal loss. By replacing the original cross-entropy loss with focal loss in the unsupervised learning part, FocalMatch effectively alleviates the class imbalance of the pseudo labels generated during unsupervised learning. FocalMatch makes the model focus more on hard samples by adjusting the loss weights of different samples. Experiments show that FocalMatch dramatically reduces the variation in the number of pseudo labels generated for each class. In addition, FocalMatch outperforms the baseline methods and achieves state-of-the-art performance on most of the commonly used benchmarks evaluated, especially when labeled data are extremely scarce. FocalMatch also provides a smoother learning curve and faster convergence than FixMatch. The original focal loss contains an additional hyperparameter α (unrelated to the weak augmentation α used above) that further adjusts loss contributions according to class frequency [17]. For semi-supervised learning methods that use pseudo-labeling, the number of pseudo labels generated for each class usually fluctuates during training, so it is hard to choose α beforehand. In future work, we plan to add α to FocalMatch and to adjust α and the modulating factor γ dynamically, so that the model converges more smoothly in different stages of training and handles different tasks more efficiently.

Data Availability Statement: Three public datasets were used for performance evaluation in this study: CIFAR-10 [41], CIFAR-100 [41], and SVHN [42].