Article

Multiloss Joint Gradient Control Knowledge Distillation for Image Classification

1 The School of Microelectronics and Communication Engineering, Chongqing University, Chongqing 400044, China
2 China Railway Design Corporation, Tianjin 300450, China
* Author to whom correspondence should be addressed.
Electronics 2024, 13(20), 4102; https://doi.org/10.3390/electronics13204102
Submission received: 29 September 2024 / Revised: 13 October 2024 / Accepted: 16 October 2024 / Published: 17 October 2024
(This article belongs to the Special Issue Knowledge Information Extraction Research)

Abstract

Knowledge distillation (KD) techniques aim to transfer knowledge from complex teacher neural networks to simpler student networks. In this study, we propose a novel knowledge distillation method called Multiloss Joint Gradient Control Knowledge Distillation (MJKD), which effectively combines feature- and logit-based knowledge distillation with gradient control. The proposed method considers the gradients of the task loss (cross-entropy loss), feature distillation loss, and logit distillation loss separately. The experimental results suggest that logits may contain more information and should, consequently, be assigned greater weight during the gradient update process. The empirical findings on the CIFAR-100 and Tiny-ImageNet datasets indicate that MJKD generally outperforms traditional knowledge distillation methods, significantly enhancing the generalization ability and classification accuracy of student networks. For instance, MJKD achieves 63.53% accuracy on Tiny-ImageNet for the ResNet18-MobileNetV2 teacher–student pair. Furthermore, we present visualizations and analyses to explore its potential working mechanisms.

1. Introduction

Deep convolutional neural networks (CNNs) have been the predominant design paradigm for numerous computer vision tasks over the past few decades [1,2]. However, high-capacity models typically necessitate massive architectures, resulting in substantial computational and memory requirements. The deployment of state-of-the-art models into resource-constrained embedded systems remains a significant challenge. As the demand for low-cost networks on embedded systems continues to increase, there is an urgent need for more compact networks with reduced computational and memory requirements while maintaining performance comparable to their larger counterparts.
To address these issues, numerous techniques have been developed for compact neural networks, including knowledge distillation (KD) [3], network pruning [4,5,6], and quantization [7,8,9]. In this study, we focused on knowledge distillation due to its efficacy and potential for robust model compression capabilities. Knowledge distillation functions as a methodology for guiding the training of “student” neural networks through the encoding and propagation of knowledge from previously trained “teacher” networks. KD methods can be categorized into two primary classifications: logit-based and feature-based.
The original logit-based KD method introduced by Hinton [3] operates by minimizing the KL-divergence between the predicted logits of the teacher and student models. In comparison with logit-based methods, feature distillation demonstrates superior performance across various tasks, prompting numerous researchers to focus on investigating feature-based distillation methods [10,11,12,13,14,15]. To enhance the transfer of deep feature knowledge from intermediate layers to lightweight student networks, feature-based distillation methods such as FitNets [16] were proposed. FitNets initially addressed the issue of intermediate representation mismatches between the student and teacher networks through intermediate layer mapping, thus enabling feature-based knowledge distillation. Building on the FitNet method, AT [12] achieves further optimization by redefining the attention of convolutional neural networks to transfer knowledge through attention maps of features. In addition, the PKT [10] models the intermediate layer features of both the teacher and student as probability distributions, thereby facilitating a more effective knowledge transfer from the teacher. CRD [11] employs contrastive learning to transfer the deep knowledge of convolutional networks. Due to their outstanding performance and improved interpretability, feature-based distillation has long been a focus of research in knowledge distillation. However, recent perspectives suggest that logit-based distillation also holds significant potential, as the features from the final layer of the network may encapsulate comprehensive information about the entire model. DKD [17] effectively decouples the classical KD loss, achieving state-of-the-art performance in classification and detection tasks across various datasets solely using logit distillation. Furthermore, Hao et al. [18] have observed that the classical KD method can also achieve leading performance in numerous cases, particularly with larger datasets and extended distillation training periods. This observation underscores the simplicity and efficiency of logit-based distillation methods. Additionally, the LD [19] method highlights that the output features from the network’s final layer may also contain information related to target localization, further underscoring the potential of output feature-based distillation. From the perspective of deep learning optimizers, DOT [20] addresses the trade-off between the two losses in logit-based methods by employing a momentum buffer, thereby further enhancing the performance of logit-based distillation methods. Indeed, logits, as the final output of the network, encapsulate the overall network information and hold a comparably significant position alongside feature-based methods.
The aforementioned methods primarily focus on either feature-based or logit-based aspects. A more effective combination of the strengths of both methods can significantly enhance the students’ assimilation of knowledge, thereby improving their performance. However, merely combining these two methods does not substantially enhance the effectiveness of distillation. To address this issue, we propose a novel knowledge distillation method called Multiloss Joint Gradient Control Knowledge Distillation (MJKD). Inspired by recent advancements in deep learning optimizers [20], our method integrates the task loss (cross-entropy loss) between the student network outputs and the labels, feature distillation loss, and logit distillation loss with gradient control. The feature distillation loss and logit distillation loss were derived from the previous works, [10] and [17], respectively. In contrast to traditional methods, we control the proportions of different losses by adjusting the momentum of each loss independently, thus fully utilizing each component’s loss to achieve superior student model performance. The primary contributions of this study are as follows:
  • We present a novel method for independently controlling the momentum of task loss, feature distillation loss, and logit distillation loss. This innovation facilitates a more effective transfer of knowledge, resulting in enhanced student network performance.
  • We empirically validate the efficacy of the proposed MJKD method on two widely used image classification datasets, CIFAR-100 and Tiny-ImageNet. The empirical results demonstrate that training the student network using the MJKD distillation method achieves superior performance compared with traditional knowledge distillation methods. Our findings indicate that logits may contain richer network information. Furthermore, the robustness and efficiency of MJKD have been substantiated.
  • We provide a comprehensive analysis of the MJKD method, including visualizations of loss landscapes and correlation matrices between student and teacher logits. These insights offer a more profound understanding of the mechanisms underlying the improved performance of student networks trained with MJKD.
The remainder of this paper is organized as follows. In Section 2, we describe the primary methods employed in our study, focusing on feature distillation loss, logit distillation loss, and multiloss joint gradient control. Section 3 details the implementation of our experiments, including descriptions of the datasets and experimental setup. We present the main results, including motivation validation, comparative results, and visualization analysis. Finally, in Section 4, we summarize our research and briefly discuss the practical significance of MJKD.

2. Main Methods

In this section, we delineate the three key components of our knowledge distillation method. First, we define the feature distillation loss for our distillation method. Subsequently, we introduce the logit distillation loss. Finally, we present a detailed explanation of the core component of our method, the multiloss joint gradient control mechanism.

2.1. Feature Distillation Loss

Feature distillation loss refers to the supervision derived from the intermediate layers of the teacher network, which is used to guide the intermediate layers of the student network. To facilitate the student model’s effective assimilation of the teacher model’s features, it is imperative to characterize the geometric structure of the feature spaces of both networks for each batch of data. Specifically, the joint probability density between any two data points within the network’s feature space can be utilized to calculate the distance between them. The objective is to minimize the discrepancy in joint density between the student probability distribution S and the teacher probability distribution T. Kernel Density Estimation (KDE) [21] provides an efficacious method to estimate these joint probability density functions as follows:
s_{ij} = s_{i|j}\, s_j = \frac{1}{N} K(x_i, x_j; 2\sigma_s^2),
and
t_{ij} = t_{i|j}\, t_j = \frac{1}{N} K(y_i, y_j; 2\sigma_t^2),
where x_i and y_i denote the output representations of the student model and the teacher model, respectively, for the i-th sample in the batch, and N represents the number of samples in the batch. K(x_i, x_j; 2\sigma_s^2) is a symmetric kernel with width 2\sigma_s^2. Additionally, since the conditional probability distribution can represent the similarity between data points within each batch [22], employing the conditional probability distribution allows for a more precise description of the local regions between samples [10]. Consequently, the joint probability density function can be substituted with the conditional probability distribution of the samples. For the student model and the teacher model, the conditional probability distributions are defined as follows:
s_{i|j} = \frac{K(x_i, x_j; 2\sigma_s^2)}{\sum_{k=1, k \neq j}^{N} K(x_k, x_j; 2\sigma_s^2)} \in [0, 1],
and
t_{i|j} = \frac{K(y_i, y_j; 2\sigma_t^2)}{\sum_{k=1, k \neq j}^{N} K(y_k, y_j; 2\sigma_t^2)} \in [0, 1].
In this study, we employ an approach that utilizes an affinity metric based on cosine similarity. The similarity metric employed is defined as follows:
K_{\mathrm{cosine}}(x_i, x_j) = \frac{1}{2} \left( \frac{x_i^{T} x_j}{\|x_i\|_2 \, \|x_j\|_2} + 1 \right) \in [0, 1].
Cosine metrics often demonstrate superior performance compared with Euclidean metrics, particularly in high-dimensional spaces [23]; utilizing cosine similarity as the kernel metric therefore enables a more accurate computation of affinity. To measure the discrepancy between the two distributions, we employ the Kullback–Leibler (KL) divergence. Given that the distributions S and T can only be estimated through a finite number of data points, the feature distillation loss is defined as follows:
\mathcal{L}_{\mathrm{feature}} = \sum_{i=1}^{N} \sum_{j=1, j \neq i}^{N} s_{j|i} \log \frac{s_{j|i}}{t_{j|i}}.
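As a concrete illustration, the following is a minimal PyTorch sketch of this feature distillation loss written by us, not the authors’ implementation: it builds the symmetric cosine affinity defined above, row-normalizes it (which yields the conditional probabilities, since the kernel is symmetric), and evaluates the KL-style divergence. The tensor shapes, the epsilon smoothing, and the batch-mean reduction are our own assumptions.

```python
import torch
import torch.nn.functional as F

def cosine_conditional_probs(feats, eps=1e-7):
    # feats: (N, d) flattened feature vectors for one batch.
    # Symmetric cosine affinity in [0, 1]: 0.5 * (cosine similarity + 1).
    feats = F.normalize(feats, p=2, dim=1)
    affinity = 0.5 * (feats @ feats.t() + 1.0)
    # Zero the diagonal so each row excludes the sample itself, then row-normalize;
    # because the kernel is symmetric, entry [i, j] approximates s_{j|i}.
    affinity = affinity * (1.0 - torch.eye(feats.size(0), device=feats.device))
    return affinity / (affinity.sum(dim=1, keepdim=True) + eps)

def feature_distillation_loss(student_feats, teacher_feats, eps=1e-7):
    s = cosine_conditional_probs(student_feats, eps)
    t = cosine_conditional_probs(teacher_feats, eps)
    # Sum of s_{j|i} * log(s_{j|i} / t_{j|i}) over j, averaged over the batch.
    return (s * torch.log((s + eps) / (t + eps))).sum(dim=1).mean()
```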

2.2. Logit Distillation Loss

“Logit” refers to the raw output of a model, which is converted into class probabilities through the application of a “softmax” output layer. The logit distillation loss represents the supervision from the teacher model’s logits over the student model’s logits. Within each batch, the classification probabilities for a training sample belonging to the t-th class are represented as n = [n_1, n_2, \ldots, n_t, \ldots, n_D] \in \mathbb{R}^{1 \times D}, where n_i denotes the probability of the i-th class, and D is the total number of classes. Given that z_i represents the logit of the i-th class, the softmax function can be formally written as
n_i = \frac{\exp(z_i)}{\sum_{k=1}^{D} \exp(z_k)}.
The softmax function can be decoupled into two components. Excluding the t-th class, the probabilities are represented by \hat{n} = [\hat{n}_1, \ldots, \hat{n}_{t-1}, \hat{n}_{t+1}, \ldots, \hat{n}_D] \in \mathbb{R}^{1 \times (D-1)}, which independently models the probabilities among nontarget classes. Each element is computed as follows:
\hat{n}_i = \frac{\exp(z_i)}{\sum_{k=1, k \neq t}^{D} \exp(z_k)}.
Subsequently, we define the binary probabilities as c = [c_t, 1 - c_t] \in \mathbb{R}^{1 \times 2}, where c_t represents the probability of the target class, and 1 - c_t represents the probability of all other nontarget classes. The specific calculation is as follows:
c_t = \frac{\exp(z_t)}{\sum_{k=1}^{D} \exp(z_k)}.
Thus, the logit distillation loss can be defined as
\mathcal{L}_{\mathrm{logits}} = \alpha \, \mathrm{KL}(c^{T} \,\|\, c^{S}) + \beta \, \mathrm{KL}(\hat{n}^{T} \,\|\, \hat{n}^{S}).
Here, \mathrm{KL}(\cdot) denotes the KL divergence. \mathrm{KL}(c^{T} \| c^{S}) represents the Target Class Knowledge Distillation (TCKD) component, which measures the similarity between the teacher and student models’ binary probabilities for the target class. Concurrently, \mathrm{KL}(\hat{n}^{T} \| \hat{n}^{S}) represents the NonTarget Class Knowledge Distillation (NCKD) component, which measures the similarity between the teacher and student models’ probabilities over the nontarget classes. Following the approach in [17], this study sets the values of \alpha and \beta to 1 and 8, respectively.
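For reference, below is a minimal PyTorch sketch of this decoupled logit loss; it is our own reading of the equations above rather than the authors’ code. The function name, the temperature argument (included because DKD-style implementations commonly apply one, defaulting to 1 here), and the large negative masking constant are assumptions of this sketch.

```python
import torch
import torch.nn.functional as F

def logit_distillation_loss(z_s, z_t, target, alpha=1.0, beta=8.0, temperature=1.0):
    # z_s, z_t: (B, D) student / teacher logits; target: (B,) ground-truth class indices.
    B, D = z_s.shape
    t_mask = F.one_hot(target, num_classes=D).bool()

    # TCKD: binary distributions [c_t, 1 - c_t] for teacher and student.
    p_s = F.softmax(z_s / temperature, dim=1)
    p_t = F.softmax(z_t / temperature, dim=1)
    b_s = torch.stack([p_s[t_mask], 1.0 - p_s[t_mask]], dim=1)
    b_t = torch.stack([p_t[t_mask], 1.0 - p_t[t_mask]], dim=1)
    tckd = F.kl_div(torch.log(b_s + 1e-7), b_t, reduction="batchmean")

    # NCKD: probabilities re-normalized over the nontarget classes only,
    # obtained by masking the target logit with a large negative value.
    z_s_nt = z_s.masked_fill(t_mask, -1e9)
    z_t_nt = z_t.masked_fill(t_mask, -1e9)
    nckd = F.kl_div(F.log_softmax(z_s_nt / temperature, dim=1),
                    F.softmax(z_t_nt / temperature, dim=1),
                    reduction="batchmean")

    return alpha * tckd + beta * nckd
```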

2.3. Multiloss Joint Gradient Control

The primary contribution of this paper is the introduction of multiloss joint gradient control, which enables the independent regulation of multiple loss functions by updating the gradients of each loss separately and storing them in distinct momentum buffers. This approach facilitates the management of the optimization process for each individual loss component.
Prior to elucidating the specifics of the gradient control update, it is necessary to provide a brief introduction to Stochastic Gradient Descent (SGD). The SGD optimizer is currently a widely utilized method for training deep learning models. SGD with momentum updates the network parameters using both the current gradient and historical gradients, where the current gradient is denoted as g = \nabla_{\theta} \mathcal{L}(\theta) and \theta represents the network parameters. Specifically, SGD maintains a “momentum buffer” to store gradients for network parameter updates. For each training batch, the momentum buffer update in SGD can be expressed as follows:
v \leftarrow g + \mu v,
Here, v and \mu represent the momentum buffer and the momentum coefficient, respectively. Subsequently, the model parameters \theta are updated according to the gradient descent rule with learning rate \gamma:
\theta \leftarrow \theta - \gamma v.
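For clarity, a minimal sketch of this standard momentum update for a single parameter tensor, assuming PyTorch (no weight decay or dampening; the function name is ours):

```python
import torch

@torch.no_grad()
def sgd_momentum_step(param, grad, buf, lr, momentum=0.9):
    # v <- g + mu * v
    buf.mul_(momentum).add_(grad)
    # theta <- theta - gamma * v
    param.add_(buf, alpha=-lr)
```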
Subsequently, we introduce Multiloss Joint Gradient Control Knowledge Distillation (MJKD). For each batch of input data, MJKD initially computes three distinct types of losses, then independently updates the momentum buffers, and finally updates the student network parameters.
As illustrated in Figure 1, Figure 1a depicts the classical knowledge distillation (KD) method, while Figure 1b presents an overview of the MJKD approach. The KD method employs the Stochastic Gradient Descent (SGD) update mechanism, wherein different loss terms are numerically summed directly. The gradient of the total loss is then computed and stored in a momentum buffer, followed by backpropagation to update the model parameters. In contrast to the SGD optimizer, MJKD independently computes the gradients for the task loss, logit distillation loss, and feature distillation loss. Subsequently, it utilizes a hyperparameter Δ to individually control the gradient of each loss, storing them separately in their respective momentum buffers. Finally, the parameters of the student network are updated accordingly.
To further investigate the differences between this method and the SGD optimizer, the mathematical formulation of the update process will be analyzed. The three momentum buffers are denoted as v_task, v_logits, and v_feature, respectively. MJKD introduces the hyperparameters \Delta_1, \Delta_2, and \Delta_3 and employs \mu + \Delta_1, \mu + \Delta_2, and \mu + \Delta_3 to update the momentum buffers v_task, v_logits, and v_feature, respectively. The specific formulas are as follows:
v_{\mathrm{task}} \leftarrow g_{\mathrm{task}} + (\mu + \Delta_1)\, v_{\mathrm{task}},
v_{\mathrm{logits}} \leftarrow g_{\mathrm{logits}} + (\mu + \Delta_2)\, v_{\mathrm{logits}},
v_{\mathrm{feature}} \leftarrow g_{\mathrm{feature}} + (\mu + \Delta_3)\, v_{\mathrm{feature}}.
The parameters of the student network are updated based on the sum of three momentum buffers:
\theta \leftarrow \theta - \gamma \left( v_{\mathrm{task}} + v_{\mathrm{logits}} + v_{\mathrm{feature}} \right).
To further examine the differences between the two updating approaches, we begin by analyzing the momentum buffer formulation for SGD, given as follows:
v_{\mathrm{sgd}} = g_{\mathrm{task}} + g_{\mathrm{logits}} + g_{\mathrm{feature}} + \mu \left( v_{\mathrm{task}} + v_{\mathrm{logits}} + v_{\mathrm{feature}} \right).
For MJKD, the updated momentum buffer is
v_{\mathrm{mjkd}} = g_{\mathrm{task}} + g_{\mathrm{logits}} + g_{\mathrm{feature}} + (\mu + \Delta_1)\, v_{\mathrm{task}} + (\mu + \Delta_2)\, v_{\mathrm{logits}} + (\mu + \Delta_3)\, v_{\mathrm{feature}}
             = g_{\mathrm{task}} + g_{\mathrm{logits}} + g_{\mathrm{feature}} + \mu \left( v_{\mathrm{task}} + v_{\mathrm{logits}} + v_{\mathrm{feature}} \right) + \Delta_1 v_{\mathrm{task}} + \Delta_2 v_{\mathrm{logits}} + \Delta_3 v_{\mathrm{feature}}.
Contrasting the SGD and MJKD momentum buffer formulations above, the difference between the two methods is as follows:
v_{\mathrm{diff}} = v_{\mathrm{mjkd}} - v_{\mathrm{sgd}} = \Delta_1 v_{\mathrm{task}} + \Delta_2 v_{\mathrm{logits}} + \Delta_3 v_{\mathrm{feature}}.
MJKD can enhance the impact of each loss branch on the gradient updates by adjusting \Delta_1, \Delta_2, and \Delta_3. Furthermore, when \Delta_1 = \Delta_2 = \Delta_3 = 0, the update mechanism of MJKD is equivalent to traditional SGD with momentum.
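To make the update above concrete, the following is a minimal PyTorch-style sketch of an MJKD-like step under stated assumptions: the class name, the gradient-gathering strategy (three autograd passes), and where weight decay is applied are our own choices, not the authors’ released implementation. The defaults mirror settings reported later in the experiments (\mu = 0.9, \Delta = (−0.05, 0.05, −0.05), weight decay 5e-4).

```python
import torch

class MJKDStyleSGD:
    """Sketch of multiloss joint gradient control: one momentum buffer per
    (parameter, loss branch) pair, each updated with its own effective momentum
    mu + delta_k; the parameter moves by the sum of the three buffers."""

    def __init__(self, params, lr=0.05, momentum=0.9,
                 deltas=(-0.05, 0.05, -0.05), weight_decay=5e-4):
        self.params = [p for p in params if p.requires_grad]
        self.lr, self.mu, self.deltas, self.wd = lr, momentum, deltas, weight_decay
        # Buffers ordered as (task, logits, feature) for every parameter.
        self.buffers = [[torch.zeros_like(p) for _ in deltas] for p in self.params]

    @torch.no_grad()
    def step(self, branch_grads):
        """branch_grads: [grads_task, grads_logits, grads_feature], each aligned
        with self.params (entries may be None for unused parameters)."""
        for i, p in enumerate(self.params):
            update = torch.zeros_like(p)
            for k, (grads, delta) in enumerate(zip(branch_grads, self.deltas)):
                g = grads[i]
                if g is None:
                    continue
                buf = self.buffers[i][k]
                buf.mul_(self.mu + delta).add_(g)   # v_k <- g_k + (mu + delta_k) * v_k
                update += buf
            update += self.wd * p                   # where to apply weight decay is our assumption
            p.add_(update, alpha=-self.lr)

# Usage sketch (three autograd passes give the per-branch gradients):
# params = [p for p in student.parameters() if p.requires_grad]
# opt = MJKDStyleSGD(params)
# g_task = torch.autograd.grad(task_loss, params, retain_graph=True, allow_unused=True)
# g_logit = torch.autograd.grad(logit_loss, params, retain_graph=True, allow_unused=True)
# g_feat = torch.autograd.grad(feature_loss, params, allow_unused=True)
# opt.step([g_task, g_logit, g_feat])
```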

3. Experiments and Discussion

In this section, we first detail the experimental settings for the two datasets. Subsequently, we empirically validate our hypotheses and discuss the primary results on widely used image classification benchmarks. We also provide analyses, including visualizations, for further insights.

3.1. Dataset and Experimental Setup

To ensure the reliability of the experimental results, unless otherwise specified, all data are presented as the mean (± standard deviation) of five independent experiments. The values in bold within the tables indicate the highest values in the context of the presented comparisons. We conduct experiments on two image classification datasets, CIFAR-100 [24] and Tiny-ImageNet [25].
We start with the CIFAR-100 dataset to evaluate our method. CIFAR-100 is a well-known image classification dataset containing 32 × 32 images from 100 categories. The training and validation sets comprise 50k and 10k images, respectively. For CIFAR-100, we train all models for 240 epochs, with the learning rate decayed by 0.1 at the 150th, 180th, and 210th epochs. The batch size is 64 for all models. We use SGD with 0.9 momentum and 0.0005 weight decay as the optimizer.
Tiny-ImageNet consists of 200 classes with an image size of 64 × 64; its training set contains 100k images and its validation set contains 10k images. ImageNet-1k is a large-scale classification dataset consisting of 1k classes, with 1.28 million training images and 50k validation images, all cropped to 224 × 224. For Tiny-ImageNet, all models are trained for 200 epochs, with the learning rate decayed by 0.1 at the 60th, 120th, and 160th epochs. The initial learning rate is 0.05 for a batch size of 64. We use SGD with 0.9 momentum and 0.0005 weight decay as the optimizer.
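For reference, the following is a minimal PyTorch sketch of the optimizer and learning-rate schedule described above, shown with the Tiny-ImageNet settings; the helper name and the use of MultiStepLR are our own choices, and for CIFAR-100 only the milestones (150/180/210 over 240 epochs) would change. When MJKD is used, the plain SGD step is replaced by the multi-buffer update from Section 2.3.

```python
import torch

def make_optimizer_and_schedule(params, base_lr=0.05):
    # SGD with momentum 0.9 and weight decay 5e-4, as described above.
    optimizer = torch.optim.SGD(params, lr=base_lr, momentum=0.9, weight_decay=5e-4)
    # Tiny-ImageNet schedule: decay the learning rate by 0.1 at epochs 60, 120, and 160.
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=[60, 120, 160],
                                                     gamma=0.1)
    return optimizer, scheduler
```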

3.2. Main Results

3.2.1. Motivation Validations

The verification of our conjectures and hypotheses constitutes a critical experiment. We initially investigated three questions: (1) Is the effectiveness of independent gradient control genuinely substantiated? (2) Among task loss, logit distillation loss, and feature distillation loss, which plays a more critical role? (3) Is it possible to improve distillation performance solely by adjusting the relative contributions of different loss functions without independent gradient control?
To address the first question, we initially investigated the impact of Δ1, Δ2, and Δ3 on model performance. Δ1, Δ2, and Δ3 are the momentum buffer update parameters for the task loss, logit distillation loss, and feature distillation loss, respectively. A positive Δ indicates that the corresponding loss receives a greater update weight. Utilizing ResNet32x4 as the teacher network and ResNet8x4 as the student network, we conducted distillation training on the CIFAR-100 dataset. As illustrated in Table 1, we observe that the performance of the student network is significantly influenced by the values of Δ1, Δ2, and Δ3. When these parameters are set to −0.05, 0.05, and −0.05, respectively, the model achieves the optimal validation accuracy of 76.52%, indicating that the logit distillation loss, with Δ2 > 0, plays a crucial role in enhancing performance. This finding suggests that increasing the weight of the logit distillation loss allows it to contribute more substantially to the overall optimization, resulting in improved knowledge transfer between the teacher and student networks. This configuration outperforms the others. Therefore, the subsequent experiments will continue to maintain the configuration where Δ2 > 0.
To verify that the effectiveness of independent gradient control is not merely due to meticulous hyperparameter tuning, we investigated the impact of varying Δ values in MJKD. As shown in Table 2, setting Δ1 = Δ2 = Δ3 = 0 (i.e., training all losses with the same momentum, equivalent to SGD) results in a top-1 accuracy of 75.78%. In contrast, introducing multiloss joint gradient control significantly improves performance. For example, when Δ1, Δ2, and Δ3 are set to −0.0375, 0.0375, and −0.0375, respectively, the model achieves the highest top-1 accuracy of 76.44%. Under most of the other parameter configurations, the model’s accuracy also exceeded that observed in the Δ = 0 case. This demonstrates the effectiveness of independent gradient control in improving model performance without requiring meticulous hyperparameter tuning.
Based on the results above, it can be observed that amplifying the dominance of the logit distillation loss in gradient updates contributes to improving the student network’s performance. This raises the second question: What is the relationship between task loss, logit distillation loss, and feature distillation loss, and which plays a more critical role? To investigate this issue, we conducted a series of experiments employing ResNet32x4 as the teacher network and ResNet8x4 as the student network on the CIFAR-100 dataset. Additionally, we utilized the ResNet18-MobileNetV2 teacher–student pair for experiments on the Tiny-ImageNet dataset. For simplicity, in our subsequent experiments, we set the parameters Δ1, Δ2, and Δ3 to −0.05, 0.05, and −0.05, respectively.
We plot the training loss in Figure 2, where the left plot (Figure 2a) shows results on the CIFAR-100 training set and the right plot (Figure 2b) shows results on the Tiny-ImageNet training set. In these figures, the baseline refers to training using only the task loss (i.e., cross-entropy loss), while MJKD without gradient control refers to MJKD with equal momentum for all losses (i.e., Δ = 0). MJKD with Δ1 = −0.05 implies that the task loss receives a smaller update momentum, which would typically impede the task loss from converging to a lower level. However, the results suggest that reducing the supervision of the task loss while enhancing the role of the logit loss may be more advantageous for knowledge transfer. As demonstrated in Figure 2, the final task-loss level reached by MJKD (depicted by blue lines) on the CIFAR-100 dataset is comparable to that of the cross-entropy baseline (represented by black lines), with a slight improvement over MJKD without gradient control (shown by green lines). Furthermore, MJKD consistently achieves a lower task loss on the Tiny-ImageNet dataset compared with both the baseline and MJKD without gradient control.
This phenomenon can be attributed to the fact that the information carried by the labels is not readily assimilated by the student network. Utilizing independent gradient control allows the logit distillation loss from the teacher network to dominate the updates. Furthermore, the label information carried by the logit distillation loss, after being processed by the teacher network, is more readily assimilated by the student network, thus more effectively guiding it to achieve enhanced model performance.
Subsequently, we investigate the relationship between the logit distillation loss and feature distillation loss. As shown in Figure 3, “MJKD without feature loss” and “MJKD without logits loss” indicate that these losses are excluded from the backpropagation process, contributing only their computed values to the results. We adopt a warm-up strategy for the distillation loss, gradually increasing it over the first 20 epochs. Consistent with the previous analysis, the left Figure 3a and Figure 3c represent results on the CIFAR-100 training set, while the right Figure 3b and Figure 3d represent results on the Tiny-ImageNet training set. The network and parameter settings remain consistent with those previously mentioned.
In Figure 3, it is evident that the blue MJKD curve is the lowest in both the logit distillation loss and feature distillation loss diagrams. This demonstrates that utilizing both logits and feature distillation losses concurrently is beneficial for enhanced knowledge transfer from the teacher network, leading to improved student network performance. Omitting either distillation loss results in performance degradation. However, it is noteworthy in Figure 3c and Figure 3d that the absence of logit distillation loss leads to a higher feature distillation loss compared with the absence of feature distillation loss. This indicates that the “logits” of the fully connected output contain overall feature information of the network. Nevertheless, this phenomenon may not occur with other distillation configurations, as we employed specific methods for calculating the logit loss and feature loss. Consequently, the absence of supervision from the logit distillation loss hinders the student network’s ability to effectively learn feature information from the teacher network. However, the absence of feature distillation loss results in a decrease in student network performance. Although the role of feature distillation loss is relatively secondary, it remains significant. Feature-based distillation methods provide a more precise representation of intermediate network layer characteristics and have consistently demonstrated robust distillation performance. These methods have long been a central focus in knowledge distillation research. However, recent studies have revealed the substantial potential of logit-based methods, which has not been fully realized due to certain limiting factors [17,19]. The information contained in the logit distillation loss from the teacher network encompasses, to some extent, the comprehensive information present in the feature distillation loss. Consequently, logit distillation loss should assume a primary role in our distillation training.
Nevertheless, the third question arises as to whether it is possible to improve distillation performance solely by adjusting the relative contributions of different loss functions, without employing independent gradient control. To address this question, we introduced a weight parameter α to regulate the proportion of each loss, without utilizing gradient control. The logit distillation loss is multiplied by the coefficient α, while the task loss and feature distillation loss are multiplied by the coefficient 1 − α. A larger α value amplifies the influence of the logit distillation loss on optimization and diminishes the influence of the task loss and feature distillation loss.
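Concretely, the fixed-weight baseline used in this comparison amounts to the combination sketched below (a hypothetical helper written by us, not the authors’ code); the combined value is backpropagated once and optimized with plain SGD.

```python
def weighted_total_loss(task_loss, logit_loss, feature_loss, alpha):
    # alpha scales the logit distillation loss; (1 - alpha) scales the task and
    # feature losses. No per-loss momentum buffers are involved in this baseline.
    return alpha * logit_loss + (1.0 - alpha) * (task_loss + feature_loss)
```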
Table 3 and Table 4 present the experimental results with varying weightings of α on the CIFAR-100 and Tiny-ImageNet datasets, respectively. As demonstrated, regardless of the value of α , the top-1 accuracy achieved with different weightings is consistently lower than that obtained with the MJKD method. For CIFAR-100, the best accuracy of 76.42% is achieved with MJKD, whereas the highest accuracy obtained with any α setting is 75.14%. Similarly, for Tiny-ImageNet, the MJKD method achieves a top-1 accuracy of 63.53%, while the highest accuracy with any α setting is 60.98%. MJKD significantly outperformed the method of only adjusting the relative contributions of different loss functions without employing independent gradient control on both datasets.
Figure 4 illustrates the impact of different α values on the three types of loss during training and compares them with the MJKD method. The upper three subfigures are derived from experiments using the CIFAR-100 dataset, while the lower three subfigures are based on the Tiny-ImageNet dataset. The MJKD method generally demonstrates a consistent ability to achieve the lowest loss, indicating that employing independent momentums to control different loss weights is both necessary and effective. These results indicate that adjusting α alone does not attain the performance level of MJKD, and the final student network performance is substantially inferior to that of the MJKD method, highlighting the importance of gradient control in achieving superior model performance.

3.2.2. Comparative Results

For CIFAR-100 image classification, Table 5 presents a comparative analysis of our proposed MJKD method with other widely used knowledge distillation methods. The evaluation includes teacher–student pairs involving both identical and heterogeneous network architectures. It is evident that our MJKD consistently achieves superior performance, yielding the highest validation accuracy across all configurations. This demonstrates the robustness and generalizability of our method, which maintains high accuracy across various architectures. For example, when using ResNet32x4 as the teacher and ResNet8x4 as the student, our MJKD method achieves a validation accuracy of 76.42%, surpassing other advanced techniques such as AT, DKD, and PKT, which reach 73.61%, 76.06%, and 74.10%, respectively. In a similar manner, when utilizing MobileNet-V2 as the student network in conjunction with VGG13 as the teacher, the MJKD approach demonstrates a performance improvement of nearly 1% over classical DKD, underscoring its efficacy in addressing teacher–student pairs involving heterogeneous network architectures. These results validate the effectiveness and robustness of our method under diverse teacher–student framework configurations.
For Tiny-ImageNet classification, Table 6 presents results in terms of both top-1 and top-5 accuracy. Here, we conducted experiments comparing MJKD to other distillation methods, with ResNet18 as the teacher network and either MobileNet-V2 or ShuffleNetV2 as the student networks. The results clearly demonstrate that MJKD achieves superior performance across all metrics. When MobileNet-V2 is the student, MJKD improves the top-1 accuracy by approximately 2% over DKD, achieving 63.53% compared with DKD’s 61.17%. Similarly, for ShuffleNetV2 as the student, MJKD attains a significant accuracy boost, with top-1 accuracy reaching 57.68%, outperforming DKD’s 56.09%. Furthermore, MJKD also excels in top-5 accuracy, achieving 84.64% for MobileNet-V2 and 80.80% for ShuffleNetV2, confirming the effectiveness of the proposed method. 

3.2.3. Visualization Analysis

To further illustrate the superiority of the proposed knowledge distillation method, we visualized the model’s loss landscapes and the differences in correlation matrices between the student and teacher logits. Three knowledge distillation methods were employed for comparative analysis: KD, which represents the classical knowledge distillation approach; DKD, indicating the current state-of-the-art framework of knowledge distillation; and MJKD, which is the knowledge distillation method proposed by us. We used the ResNet18-MobileNetV2 teacher–student pair on the Tiny-ImageNet dataset. As shown in Table 6, the experimental top-1 accuracies for the KD, DKD, and MJKD methods are 57.43%, 61.17%, and 63.53%, respectively.
Leveraging three distinct types of loss can flatten the loss convergence minima of a student network. A well-established hypothesis posits that the flatness of the loss convergence minima is related to the generalization capability of the neural networks [27]. Networks with flatter minima exhibit enhanced robustness and generalization ability [28,29,30]. Due to the inherent bias between the training and testing datasets, the model’s optimal parameters for peak performance may exhibit slight variations. In instances where the minima are sharp, even minor differences in the model’s parameters can result in significant fluctuations in accuracy. Conversely, flatter minima offer greater flexibility, as small changes in parameters do not substantially impact performance. Consequently, flatter minima demonstrate increased resilience to stochastic perturbations in the loss landscape, which mitigates the sensitivity to these parameter shifts and effectively reduces the generalization error.
Utilizing the work of Li et al. [31] to visualize the loss landscape can facilitate the observation of loss convergence in models after knowledge distillation, thus providing enhanced insights into their operational mechanisms. To further investigate the impact of our MJKD method on model generalization, we visualized the loss landscapes of the converged student networks after training. As shown in Figure 5, Figure 5a corresponds to the KD method and Figure 5b to the DKD method, while Figure 5c depicts the corresponding landscape when the student network is trained using our proposed MJKD method. The primary area of interest in the figure is the size of the purple region at the center of the loss landscapes, which indicates the flatness of the loss minima. The comparative analysis of these loss landscapes reveals that the loss functions of student networks trained with MJKD exhibit significantly flatter minima compared with those obtained with KD and DKD. This enhanced flatness suggests that our method encourages the convergence of the task loss to smoother and broader regions in the parameter space, which explains the improved generalization performance observed in the student networks.
In addition, we visualized the differences in the correlation matrices between the logits generated by the student networks and those produced by the teacher networks. This analysis provides a more comprehensive understanding of how effectively the student networks learn to replicate the behavior of the teacher networks. In Figure 6, the subfigures show the differences in the correlation matrices when utilizing KD (a), DKD (b), and MJKD (c) on the Tiny-ImageNet dataset. The comparison elucidates that, in contrast to KD, the MJKD method enables the student networks to produce logits that are more closely aligned with those of the teacher. Although the disparity is less pronounced compared with the advanced DKD method, the student networks trained with MJKD still produce logits that align more closely with those of the teacher, demonstrating a subtle improvement. This greater similarity in the logits indicates that MJKD enhances the distillation process by enabling the student to more effectively emulate the output distributions of the teacher network. This improved alignment between the student and teacher logits contributes to the enhanced performance of MJKD.

4. Conclusions

In this study, we propose a novel knowledge distillation method called Multiloss Joint Gradient Control Knowledge Distillation (MJKD) to enhance the performance of student networks by effectively combining feature-based and logit-based knowledge distillation with gradient control. Our approach introduces independent gradient control for task loss, feature distillation loss, and logit distillation loss, which facilitates a more efficient and effective knowledge transfer from teacher to student networks.
The introduction of independent gradient control for the three loss components is a critical innovation of our method. A detailed theoretical analysis of the proposed method was conducted, followed by several experiments to evaluate the knowledge distillation approach. The experimental results suggest that logits may encapsulate substantial information about the model’s overall architecture. Our empirical evaluation of the proposed method on the CIFAR-100 and Tiny-ImageNet datasets demonstrated significant improvements over the classical knowledge distillation methods. By independently controlling the gradients of task loss, feature distillation loss, and logit distillation loss, the MJKD method consistently achieves higher validation accuracies across various teacher–student network pairs, regardless of whether the architectures are identical or diverse. For instance, MJKD achieves a validation accuracy of 76.42% with the ResNet32x4–ResNet8x4 pair. On Tiny-ImageNet, MJKD reaches 63.53% for the ResNet18–MobileNet-V2 pair. Our experiments validated the robustness and effectiveness of the MJKD method, confirming its potential for broader application in various image classification tasks. Furthermore, we present visualizations and analyses to explore the potential working mechanisms of our method by examining the model’s loss landscapes and the differences in correlation matrices between the student’s and teacher’s logits.
The practical implications of our research are significant for the development of efficient and high-performance neural networks for real-world applications. The MJKD method enables the deployment of lightweight student networks that maintain high accuracy, making it suitable for resource-constrained environments such as mobile devices and edge computing.

Author Contributions

Conceptualization, W.H. and X.Z.; methodology, W.H. and X.Z.; software, J.P.; validation, J.P., J.Z. and X.H.; formal analysis, J.Z.; investigation, J.P. and X.H.; resources, X.Z.; data curation, J.P.; writing—original draft preparation, J.P.; writing—review and editing, W.H. and Y.L.; visualization, J.Z. and X.H.; supervision, X.Z. and Y.L.; project administration, J.L.; funding acquisition, Y.L. All authors have read and agreed to the published version of the manuscript.

Funding

This work was funded by the Class A key project of China Railway Design Corporation, grant number 2023A0203602.

Data Availability Statement

The data presented in this study are available on request from the corresponding author.

Conflicts of Interest

Author Jianyu Zhang was employed by the company China Railway Design Corporation. The remaining authors declare that the research was conducted in the absence of any commercial or financial relationships that could be construed as a potential conflict of interest.

References

  1. He, K.; Zhang, X.; Ren, S.; Sun, J. Deep residual learning for image recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, Las Vegas, NV, USA, 27–30 June 2016; pp. 770–778. [Google Scholar]
  2. Krizhevsky, A.; Sutskever, I.; Hinton, G.E. ImageNet classification with deep convolutional neural networks. Commun. ACM 2017, 60, 84–90. [Google Scholar] [CrossRef]
  3. Hinton, G.; Vinyals, O.; Dean, J. Distilling the knowledge in a neural network. arXiv 2015, arXiv:1503.02531. [Google Scholar]
  4. Anwar, S.; Hwang, K.; Sung, W. Structured pruning of deep convolutional neural networks. ACM J. Emerg. Technol. Comput. Syst. 2017, 13, 1–18. [Google Scholar] [CrossRef]
  5. Liu, Z.; Sun, M.; Zhou, T.; Huang, G.; Darrell, T. Rethinking the value of network pruning. arXiv 2018, arXiv:1810.05270. [Google Scholar]
  6. Han, S.; Pool, J.; Tran, J.; Dally, W. Learning both weights and connections for efficient neural network. Adv. Neural Inf. Process. Syst. 2015, 28, 1135–1143. [Google Scholar]
  7. Courbariaux, M.; Bengio, Y.; David, J.P. Binaryconnect: Training deep neural networks with binary weights during propagations. Adv. Neural Inf. Process. Syst. 2015, 28, 3123–3131. [Google Scholar]
  8. Jacob, B.; Kligys, S.; Chen, B.; Zhu, M.; Tang, M.; Howard, A.; Adam, H.; Kalenichenko, D. Quantization and training of neural networks for efficient integer-arithmetic-only inference. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, Salt Lake City, UT, USA, 18–23 June 2018; pp. 2704–2713. [Google Scholar]
  9. Rastegari, M.; Ordonez, V.; Redmon, J.; Farhadi, A. Xnor-net: Imagenet classification using binary convolutional neural networks. In Proceedings of the European Conference on Computer Vision, Amsterdam, The Netherlands, 11–14 October 2016; Springer: Berlin/Heidelberg, Germany, 2016; pp. 525–542. [Google Scholar]
  10. Passalis, N.; Tzelepi, M.; Tefas, A. Probabilistic Knowledge Transfer for Lightweight Deep Representation Learning. IEEE Trans. Neural Netw. Learn. Syst. 2021, 32, 2030–2039. [Google Scholar] [CrossRef]
  11. Tian, Y.; Krishnan, D.; Isola, P. Contrastive Representation Distillation. In Proceedings of the International Conference on Learning Representations, Addis Ababa, Ethiopia, 26–30 April 2020. [Google Scholar]
  12. Zagoruyko, S.; Komodakis, N. Paying more attention to attention: Improving the performance of convolutional neural networks via attention transfer. arXiv 2016, arXiv:1612.03928. [Google Scholar]
  13. Peng, B.; Jin, X.; Liu, J.; Li, D.; Wu, Y.; Liu, Y.; Zhou, S.; Zhang, Z. Correlation congruence for knowledge distillation. In Proceedings of the IEEE/CVF International Conference on Computer Vision, Seoul, Republic of Korea, 27–28 October 2019; pp. 5007–5016. [Google Scholar]
  14. Park, W.; Kim, D.; Lu, Y.; Cho, M. Relational knowledge distillation. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, Long Beach, CA, USA, 15–20 June 2019; pp. 3967–3976. [Google Scholar]
  15. Heo, B.; Kim, J.; Yun, S.; Park, H.; Kwak, N.; Choi, J.Y. A comprehensive overhaul of feature distillation. In Proceedings of the IEEE/CVF International Conference on Computer Vision, Seoul, Republic of Korea, 27–28 October 2019; pp. 1921–1930. [Google Scholar]
  16. Adriana, R.; Nicolas, B.; Samira, E.; Antoine, C.; Carlo, G.; Yoshua, B. Fitnets: Hints for Thin Deep Nets. In Proceedings of the International Conference on Learning Representations, San Diego, CA, USA, 7–9 May 2015. [Google Scholar]
  17. Zhao, B.; Cui, Q.; Song, R.; Qiu, Y.; Liang, J. Decoupled knowledge distillation. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, New Orleans, LA, USA, 18–24 June 2022; pp. 11953–11962. [Google Scholar]
  18. Hao, Z.; Guo, J.; Han, K.; Hu, H.; Xu, C.; Wang, Y. VanillaKD: Revisit the Power of Vanilla Knowledge Distillation from Small Scale to Large Scale. arXiv 2023, arXiv:2305.15781. [Google Scholar]
  19. Zheng, Z.; Ye, R.; Hou, Q.; Ren, D.; Wang, P.; Zuo, W.; Cheng, M.M. Localization distillation for object detection. IEEE Trans. Pattern Anal. Mach. Intell. 2023, 45, 10070–10083. [Google Scholar] [CrossRef]
  20. Zhao, B.; Cui, Q.; Song, R.; Liang, J. DOT: A Distillation-Oriented Trainer. In Proceedings of the IEEE/CVF International Conference on Computer Vision, Vancouver, BC, Canada, 18–22 June 2023; pp. 6189–6198. [Google Scholar]
  21. Scott, D.W. Multivariate Density Estimation: Theory, Practice, and Visualization; John Wiley & Sons: Hoboken, NJ, USA, 2015. [Google Scholar]
  22. Van der Maaten, L.; Hinton, G. Visualizing data using t-SNE. J. Mach. Learn. Res. 2008, 9, 2579–2605. [Google Scholar]
  23. Wang, D.; Lu, H.; Bo, C. Visual tracking via weighted local cosine similarity. IEEE Trans. Cybern. 2014, 45, 1838–1850. [Google Scholar] [CrossRef]
  24. Krizhevsky, A.; Hinton, G. Learning Multiple Layers of Features from Tiny Images; University of Toronto: Toronto, ON, Canada, 2009. [Google Scholar]
  25. Le, Y.; Yang, X. Tiny imagenet visual recognition challenge. CS 231N 2015, 7, 3. [Google Scholar]
  26. Ahn, S.; Hu, S.X.; Damianou, A.; Lawrence, N.D.; Dai, Z. Variational information distillation for knowledge transfer. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, Long Beach, CA, USA, 16–20 June 2019; pp. 9163–9171. [Google Scholar]
  27. Keskar, N.S.; Mudigere, D.; Nocedal, J.; Smelyanskiy, M.; Tang, P.T.P. On large-batch training for deep learning: Generalization gap and sharp minima. arXiv 2016, arXiv:1609.04836. [Google Scholar]
  28. Dinh, L.; Pascanu, R.; Bengio, S.; Bengio, Y. Sharp minima can generalize for deep nets. In Proceedings of the International Conference on Machine Learning, PMLR, Sydney, Australia, 6–11 August 2017; pp. 1019–1028. [Google Scholar]
  29. Izmailov, P.; Podoprikhin, D.; Garipov, T.; Vetrov, D.; Wilson, A.G. Averaging weights leads to wider optima and better generalization. arXiv 2018, arXiv:1803.05407. [Google Scholar]
  30. Cho, J.H.; Hariharan, B. On the efficacy of knowledge distillation. In Proceedings of the IEEE/CVF International Conference on Computer Vision, Seoul, Republic of Korea, 27–28 October 2019; pp. 4794–4802. [Google Scholar]
  31. Li, H.; Xu, Z.; Taylor, G.; Studer, C.; Goldstein, T. Visualizing the loss landscape of neural nets. Adv. Neural Inf. Process. Syst. 2018, 31, 6391–6401. [Google Scholar] [CrossRef]
Figure 1. The figure illustrates the concept of knowledge distillation [3] alongside our proposed Multiloss Joint Gradient Control Knowledge Distillation (MJKD) approach. In MJKD, the gradients associated with the task loss, logit distillation loss, and feature distillation loss are computed independently and subsequently utilized to update their respective momentum buffers.
Figure 2. Task loss on CIFAR-100 (a) and Tiny-ImageNet (b).
Figure 3. Distillation Loss on CIFAR-100 (a,c) and Tiny-ImageNet (b,d).
Figure 4. Illustration of loss for weighing α on CIFAR-100 and Tiny-ImageNet.
Figure 5. Loss landscapes of (a) KD, (b) DKD, and (c) MJKD on the Tiny-ImageNet dataset.
Figure 6. Difference in the correlation matrices of student and teacher logits on the Tiny-ImageNet dataset.
Table 1. Different Δ values on CIFAR-100 validation. The teacher network is ResNet32x4 and the student network is ResNet8x4.
(Δ1, Δ2, Δ3)   (−0.05, −0.05, 0.05)   (−0.05, 0.05, −0.05)   (0.05, −0.05, −0.05)
Top-1          75.41 ± 0.21           76.52 ± 0.37           75.26 ± 0.29
(Δ1, Δ2, Δ3)   (−0.05, 0.05, 0.05)    (0.05, −0.05, 0.05)    (0.05, 0.05, −0.05)
Top-1          76.08 ± 0.20           74.64 ± 0.22           75.47 ± 0.33
Table 2. Different Δ values on CIFAR-100 validation under logit distillation loss dominance. The teacher network is ResNet32x4 and the student network is ResNet8x4.
(Δ1, Δ2, Δ3)   (0.00, 0.00, 0.00)     (−0.0125, 0.0125, −0.0125)   (−0.025, 0.025, −0.025)      (−0.0375, 0.0375, −0.0375)
Top-1          75.78 ± 0.37           76.17 ± 0.31                 76.15 ± 0.12                 76.44 ± 0.23
(Δ1, Δ2, Δ3)   (−0.05, 0.05, −0.05)   (−0.0625, 0.0625, −0.0625)   (−0.0750, 0.0750, −0.0750)   (−0.0875, 0.0875, −0.0875)
Top-1          76.42 ± 0.26           76.33 ± 0.27                 76.15 ± 0.19                 75.29 ± 0.13
Table 3. Different weight ( α ) results on CIFAR-100 validation.
α       0.1            0.25           0.5            0.75           0.9            MJKD
Top-1   74.88 ± 0.19   75.14 ± 0.39   74.87 ± 0.20   75.01 ± 0.15   74.93 ± 0.23   76.42 ± 0.26
Table 4. Different weight ( α ) results on Tiny-ImageNet validation.
α       0.1            0.25           0.5            0.75           0.9            MJKD
Top-1   60.25 ± 0.31   60.98 ± 0.24   60.78 ± 0.22   60.47 ± 0.05   59.97 ± 0.16   63.53 ± 0.26
Table 5. Results on the CIFAR-100 validation.
Teacher       ResNet32x4     VGG13          VGG13          WRN-40-2       ResNet50       ResNet32x4      ResNet32x4
              79.42          74.64          74.64          75.61          79.34          79.42           79.42
Student       ResNet8x4      VGG8           MobileNet-V2   WRN-16-2       MobileNet-V2   ShuffleNet-V1   ShuffleNet-V2
              73.06 ± 0.16   70.94 ± 0.24   65.85 ± 0.22   73.50 ± 0.25   65.85 ± 0.22   72.40 ± 0.38    73.80 ± 0.22
AT [12]       73.61 ± 0.26   71.76 ± 0.17   60.42 ± 0.47   74.30 ± 0.17   58.06 ± 1.44   73.57 ± 0.34    73.66 ± 0.24
VID [26]      72.98 ± 0.12   71.00 ± 0.30   65.72 ± 0.53   73.87 ± 0.15   65.77 ± 0.47   72.78 ± 0.23    73.84 ± 0.38
FITNET [16]   73.71 ± 0.17   71.48 ± 0.32   64.44 ± 0.99   73.61 ± 0.21   64.33 ± 0.56   74.39 ± 0.35    75.09 ± 0.15
PKT [10]      74.10 ± 0.35   73.36 ± 0.13   68.34 ± 0.18   75.12 ± 0.18   68.59 ± 0.82   75.71 ± 0.35    76.10 ± 0.20
KD [3]        73.70 ± 0.37   73.31 ± 0.22   68.02 ± 0.27   75.07 ± 0.23   68.50 ± 0.42   74.88 ± 0.24    75.35 ± 0.12
RKD [14]      72.72 ± 0.16   71.71 ± 0.39   65.97 ± 0.29   73.82 ± 0.09   65.95 ± 0.52   73.84 ± 0.23    74.87 ± 0.37
DKD [17]      76.06 ± 0.25   74.66 ± 0.17   67.11 ± 0.51   75.40 ± 0.19   68.23 ± 0.43   74.34 ± 0.21    76.97 ± 0.46
MJKD          76.42 ± 0.26   74.75 ± 0.23   68.76 ± 0.69   75.52 ± 0.23   69.78 ± 0.55   76.38 ± 0.09    77.25 ± 0.43
Table 6. Results on the Tiny-ImageNet validation.
         Teacher   Student        AT [12]        KD [3]         PKT [10]       DKD [17]       MJKD
ResNet18 as the teacher, MobileNet-V2 as the student
Top-1    63.74     55.51 ± 0.48   57.20 ± 0.44   57.26 ± 0.25   57.43 ± 0.16   61.17 ± 0.22   63.53 ± 0.26
Top-5    83.55     79.76 ± 0.21   80.91 ± 0.24   80.80 ± 0.37   81.37 ± 0.22   83.55 ± 0.19   84.64 ± 0.18
ResNet18 as the teacher, ShuffleNetV2 as the student
Top-1    63.74     51.65 ± 0.36   53.56 ± 0.11   52.32 ± 0.42   52.50 ± 0.45   56.09 ± 0.18   57.68 ± 0.23
Top-5    83.55     76.67 ± 0.30   78.16 ± 0.21   77.31 ± 0.23   77.40 ± 0.37   79.96 ± 0.13   80.80 ± 0.27
Disclaimer/Publisher’s Note: The statements, opinions and data contained in all publications are solely those of the individual author(s) and contributor(s) and not of MDPI and/or the editor(s). MDPI and/or the editor(s) disclaim responsibility for any injury to people or property resulting from any ideas, methods, instructions or products referred to in the content.

Share and Cite

MDPI and ACS Style

He, W.; Pan, J.; Zhang, J.; Zhou, X.; Liu, J.; Huang, X.; Lin, Y. Multiloss Joint Gradient Control Knowledge Distillation for Image Classification. Electronics 2024, 13, 4102. https://doi.org/10.3390/electronics13204102
