1. Introduction
Federated learning (FL) enables multiple clients to contribute to training a global machine learning model without sharing their data with a central server [1]. Clients perform computations on their local data and send only the model updates to a central server, which aggregates them to improve the global model. The global model is then redistributed to the clients for subsequent training rounds. This framework ensures data safety by storing the data only on client devices, thereby minimizing the risk of breaches [2]. In the context of healthcare, this approach is particularly valuable because it enables collaborative research and model training across multiple medical institutions while complying with strict privacy regulations and minimizing the risk of exposing sensitive patient data [3].
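To make the round structure described above concrete, the following is a minimal sketch of one parameter-averaging round in classical FL (FedAvg-style). The `train_locally` interface and the training details are illustrative assumptions, not part of the cited methods.

```python
# Minimal sketch of one classical FL round: clients train locally on their own
# data, and the server averages the resulting parameters and redistributes them.
import copy
import torch

def federated_round(global_model, clients):
    """One round of FedAvg-style training (illustrative interface)."""
    client_states = []
    for client in clients:
        local_model = copy.deepcopy(global_model)
        client.train_locally(local_model)   # hypothetical hook: local data never leaves the client
        client_states.append(local_model.state_dict())

    # Server: average the model parameters across clients and redistribute.
    avg_state = {
        key: torch.stack([s[key].float() for s in client_states]).mean(dim=0)
        for key in client_states[0]
    }
    global_model.load_state_dict(avg_state)
    return global_model
```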
Distillation is a machine learning technique that trains a simpler model, called the student, to mimic the behavior of a more complex model, called the teacher, typically improving efficiency without sacrificing accuracy. Federated distillation extends this approach to a decentralized setting, allowing many devices to train student models collaboratively while keeping their data local [4]. Federated distillation has recently attracted considerable attention. It captures and communicates the learning experience via logits, the pre-softmax outputs of the individually trained models, which significantly reduces communication overhead compared to traditional FL [4]. It also balances flexibility and security: clients can use models suited to their computational capabilities [5], and because only distilled knowledge is transmitted via logits rather than raw data, the risk of information exposure is substantially reduced, strengthening data privacy [6].
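The following is a minimal sketch of how logits, rather than parameters, can be exchanged in one federated-distillation round. The function names, the KL-based distillation loss, and the shared public loader are illustrative assumptions, not the exact protocol of [4].

```python
# Minimal sketch of one federated-distillation round: clients compute logits on
# a shared public dataset, the server averages them into a global logit, and
# each client distills from that global logit.
import torch
import torch.nn.functional as F

def client_logits(model, public_loader):
    """Compute this client's logits on the shared public dataset."""
    model.eval()
    outs = []
    with torch.no_grad():
        for x, _ in public_loader:          # assumes a fixed iteration order (shuffle=False)
            outs.append(model(x))
    return torch.cat(outs)                  # shape: (n_public, n_classes)

def aggregate(client_outputs):
    """Server step: average the client logits into a single global logit."""
    return torch.stack(client_outputs).mean(dim=0)

def distill(model, public_loader, global_logits, optimizer, T=3.0):
    """Client step: fit the local model to the global logits via KL distillation."""
    model.train()
    offset = 0
    for x, _ in public_loader:
        target = global_logits[offset:offset + x.size(0)]
        offset += x.size(0)
        loss = F.kl_div(F.log_softmax(model(x) / T, dim=1),
                        F.softmax(target / T, dim=1),
                        reduction="batchmean") * T * T
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```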
Traditional federated distillation methods rely on a single uniform global logit, which reduces accuracy, particularly when the data have a distinct group structure. Label distribution skew [1], in which data labels are not distributed evenly across clients, is a common problem in federated learning, yet traditional federated distillation trains all clients with the same global logit, which is likely to yield sub-optimal results for each client's data distribution. To illustrate this, consider FL between hospitals specializing in different types of medical treatment: a hospital specializing in cancer will have a dataset containing only different types of cancer, while a hospital specializing in infectious diseases will have images labeled "infection". Training all clients with a single global logit in this situation biases the global model and degrades its accuracy. It would be better to divide the clients into groups based on the labels they hold and to train each group with a logit that matches the group's label distribution.
Although clustering techniques exist in FL, to the best of our knowledge, no method has integrated clustering with federated distillation. Most clustering algorithms in FL cluster clients by their model parameters [7,8]. However, federated distillation exchanges only the outputs of the client models, not the model parameters, so this approach is not applicable. A clustering method suitable for federated distillation must therefore rely on the model outputs. To focus on label distribution skew, we use the labels predicted by the client models: we propose a method that clusters client models based on how often they predict each label.
Figure 1 illustrates our algorithm, which utilizes information about clusters for effective distillation. In practice, the number of groups is often unknown. The algorithm we propose addresses this issue by using hierarchical clustering, which eliminates the need for prior knowledge of the number of clusters.
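The following is a minimal sketch of this clustering step, assuming each client's logits on the public data are available as NumPy arrays. The Ward linkage, the distance threshold, and the helper names are illustrative choices for exposition.

```python
# Minimal sketch: cluster clients by how often their models predict each label
# on the public data, using hierarchical clustering (no fixed number of groups),
# then aggregate logits per cluster instead of one uniform global logit.
import numpy as np
from scipy.cluster.hierarchy import linkage, fcluster

def prediction_histogram(model_logits, n_classes):
    """Count how often a client's model predicts each label on the public data."""
    preds = model_logits.argmax(axis=1)
    counts = np.bincount(preds, minlength=n_classes).astype(float)
    return counts / counts.sum()             # normalize to a label distribution

def cluster_clients(all_client_logits, n_classes, distance_threshold=0.3):
    """Hierarchical clustering over prediction histograms; the number of groups
    is not fixed in advance but follows from the cut threshold."""
    features = np.stack([prediction_histogram(l, n_classes)
                         for l in all_client_logits])
    Z = linkage(features, method="ward")      # illustrative linkage choice
    return fcluster(Z, t=distance_threshold, criterion="distance")

def groupwise_logits(all_client_logits, cluster_ids):
    """Aggregate logits per cluster so each group distills from its own logit."""
    return {c: np.mean([l for l, cid in zip(all_client_logits, cluster_ids)
                        if cid == c], axis=0)
            for c in set(cluster_ids)}
```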
In FL, fairness requires that sensitive groups, such as those defined by gender or race, do not experience disparate outcomes such as differences in accuracy [1]. Unfortunately, minority social groups are often underrepresented in the training data, so their accuracy is often degraded. When client group sizes vary, existing methods significantly undermine the performance of minority client groups. In contrast, our method performs well regardless of group size by assigning each group a logit that fits its data distribution, bringing us closer to fair FL.
Through an empirical analysis on the widely used MNIST and CIFAR datasets, we demonstrate that clustering based on predictions achieves over 90% clustering accuracy. We also achieve higher accuracy for each client model than traditional federated distillation methods in settings with an apparent group structure; performance improves by up to 75%, and the larger the difference in data distribution between groups, the greater the advantage of our algorithm. We also show that our algorithm is effective even when data is sparse.
The main contributions of this paper are:
We propose the first federated distillation approach that clusters clients using the client models' predictions on public data.
We show that our approach results in successful clustering even when the boundaries between each client group are unclear.
We demonstrate the effectiveness of our approach under challenging conditions such as insufficient data and ambiguous group boundaries, and show that it improves the performance of minority groups, bringing us closer to fairer FL.
Our paper is organized as follows. In Section 2, we review the relevant literature on clustering techniques in federated learning and federated distillation methods. In Section 3, we formally define the problem and present the proposed clustered federated distillation algorithm using a label-based group structure. In Section 4, we evaluate the proposed method: in Section 4.1, we present the setup used in our experiments; in Section 4.2, we compare the clustering accuracy of our algorithm with existing clustering-based FL methods in a label-based group structure; and in Section 4.3, we compare the performance of our method with existing federated distillation methods in a label-based group structure. In Section 5, we summarize the main results and limitations and suggest directions for future work.
5. Discussion
In this study, we address the scenario in which data distributions differ between client groups in federated distillation. We introduce a method that uses hierarchical clustering to group clients according to how often each client's model predicts each label on public data. This approach overcomes the limitation of traditional federated distillation techniques, which assume a uniform data distribution, when a label-based group structure exists. Our method can be used when different groups (e.g., demographic groups) have significantly different data distributions, ensuring that all groups receive equally good results.
The experiments show that our method correctly assigns clients with different labels to separate groups. The accuracy of the client models exceeds that of traditional methods when there is a clear label-based cluster structure. In particular, the accuracy of minority groups, which is problematic in traditional federated distillation, is significantly improved; this may pave the way for fair FL. Furthermore, our method does not require knowledge of the number of clusters, making it applicable in a wider range of environments. However, as the group structure becomes less clear, the performance gap between our method and existing algorithms narrows. We will continue to improve our method to perform better under ambiguous group structures.
Combining our method with other data types, such as text, more complex images, or time-series data, would be an interesting research direction. Our method could also be combined with data-free distillation, where no public data exists. Our algorithm could also be effective in the presence of malicious clients that send false predictions to the server: by grouping malicious clients together, we can ensure that other clients are not affected by them.