1. Introduction
Deep learning is widely used in clinical scenarios, such as disease screening, health management, diagnosis and treatment. Obtaining models that perform various medical tasks well often requires a large amount of training data; however, due to privacy limitations in the medical field, data from different medical sites cannot be pooled into larger datasets. This isolates each medical site, so models can only be trained with small amounts of local data, resulting in poorly performing models. Federated learning [
1] has been proposed as an effective solution to this problem. Firstly, as a kind of distributed machine learning, federated learning can jointly train global models for multiple medical institutions by combining data and annotations from each institution to expand the sample data volume and the number of annotations [
2], thereby making it possible to solve unbalanced data distributions. Secondly, federated learning does not require data exchanges among healthcare institutions, which satisfies requirements such as patient privacy protection, data security and government regulations. Additionally, the results of federated learning can be shared among medical institutions, which can alleviate the problem of the uneven distribution of medical resources to a certain extent.
The training process of federated learning involves medical institutions training model parameters based on local datasets, then sharing the model parameters among medical institutions and finally fusing all model parameters in an aggregated manner to form better-performing models. When the data distributions of medical institutions are inconsistent, i.e., the assumption of independent and identically distributed (IID) data is not satisfied among medical institutions, the complexity of the problem modelling, theoretical analysis and empirical evaluation of solutions increases, resulting in the degradation of model performance [
3]. A feasible idea for solving this problem is to share data distributions on top of the model sharing in federated learning, i.e., to share the data distribution information of different medical institutions with other medical institutions. This is similar to doctors at multiple medical institutions sharing and exchanging treatment experiences, which improves treatment levels through mutual learning. In addition, the shared information must still satisfy certain data security requirements.
The initial federated learning framework was the centralised federated learning framework, which faced the problem that it is difficult to find trusted third parties to perform parameter aggregation [
4]. To solve this problem, decentralised federated learning frameworks have been developed, such as peer-to-peer network structures; however, they place certain requirements on the computing power of each client, and the frequent information exchanges between multiple clients make the communication costs relatively high. Decentralised federated learning architectures remove the central server, perform task model aggregation locally and only exchange information between adjacent clients on the communication graph, which reduces the probability of network congestion and the communication overheads while improving data privacy protection. These architectures are therefore well suited to the model exchange framework of federated learning and the exchange of shared data. Based on this, we propose our approach to improve the task model performance of federated learning on non-IID data.
To sum up, the main contributions of this work can be summarised as follows:
A novel unidirectional synchronous cyclic decentralised federated learning framework and an effective evaluation of the convergence of the model;
A new distribution information sharing and knowledge distillation model aggregation algorithm for the federated task model, which solves the problem of data distribution inconsistency both at the algorithm level and the data level;
The first attempt to use federated learning to diagnose Alzheimer’s disease based on medical datasets;
A way to measure the inconsistent distributions of data features using the maximum mean discrepancy (MMD).
The rest of our paper is organised as follows.
Section 2 introduces related work.
Section 3 details our proposed approach.
Section 4 describes the experimental environment and our experimental results.
Section 5 concludes the paper and proposes future work.
2. Related Work
Since federated learning was first proposed, four main types of challenges have arisen: communication challenges, system challenges, statistical challenges and privacy challenges [
4]. We can refer to these two articles [
5,
6] for the communication challenges and system challenges of a cyclic federated learning framework, which have been analyzed and solved by predecessors. For privacy challenges, we can refer to the solutions in these two articles [
7,
8]. The privacy security protection strategies proposed in both papers consist of a privacy protection module and an attack detection module; the major difference between the two is that the first scheme uses a two-level privacy data protection module. This scheme uses perturbation-based privacy: it converts categorical values into numeric values and normalises feature values into the range [0, 1] before transforming the data using DL-based encoder techniques, which strengthens privacy and increases the utility of DL models. The statistical challenges, e.g., the non-independent and identically distributed (non-IID) data problem, are among the most non-negligible challenges in the application of federated learning in the medical field. Therefore, in this paper, we mainly focus on non-IID problems. In response to non-IID problems, existing research has mainly solved the problems at the algorithm and data levels.
The algorithm-level solutions mainly include objective function modification and solution mode optimisation. Objective function modification involves adding regularisation terms on the client side, achieving a trade-off between optimising local models and reducing the differences between local models and global models to address the non-IID data at each node [
9,
10,
11,
12]. The measure of the differences between local models and global models by the regularisation terms can be either the distance between them or the differences in model behaviour. The distance measures between local and global models are Euclidean distances [
9] and weighted distances [
10]. For example, the federated proximal optimisation (FedProx) algorithm that has been proposed in the literature [
9] corrects the client-side drift that occurs in FedAvg by restricting the Euclidean distances between local models and global models as proximal terms. This means that the local updates do not excessively deviate from the global models, which alleviates any inconsistencies in the client-side data and improves the stability of global model convergence. The federated curvature (FedCurv) algorithm that has been proposed in the literature [
10] uses Fisher information from global models obtained during the previous rounds of training to weight the distances, which can reduce excessive errors in the model parameters. The differences in model behaviour between local and global models can be measured by the degree of inconsistency in the model output distributions on local datasets or by the gradient of the global models on local datasets. For example, in the literature [
11], the maximum mean discrepancy (MMD) has been used as a metric to measure the inconsistency in model output distributions on local datasets. The stochastic controlled averaging (SCAFFOLD) algorithm that has been proposed in the literature [
12] improves the FedProx algorithm by adding a control variable on the client side. This control variable can take either the gradient norm of global models on local datasets or the Euclidean distances between local and global models, thus preventing local models from deviating from the globally correct training direction. These methods can improve the performance of federated learning for model learning on non-IID datasets to some extent, but the degree of improvement is limited by the consistency of the client-side data sampling [
3].
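Since the MMD appears both in the method of [11] above and as our own measure of feature distribution inconsistency, a concrete illustration may help. The sketch below computes a sample-based squared MMD with a Gaussian kernel; the kernel choice and the bandwidth `sigma` are illustrative assumptions, not values taken from the cited works:

```python
import numpy as np

def gaussian_kernel(x, y, sigma=1.0):
    """Gaussian (RBF) kernel matrix between two sample sets."""
    d2 = np.sum(x**2, 1)[:, None] + np.sum(y**2, 1)[None, :] - 2 * x @ y.T
    return np.exp(-d2 / (2 * sigma**2))

def mmd2(x, y, sigma=1.0):
    """Squared maximum mean discrepancy between samples x and y."""
    kxx = gaussian_kernel(x, x, sigma).mean()
    kyy = gaussian_kernel(y, y, sigma).mean()
    kxy = gaussian_kernel(x, y, sigma).mean()
    return kxx + kyy - 2 * kxy

rng = np.random.default_rng(0)
# MMD is near zero for samples from the same distribution and
# clearly larger when the two distributions differ.
same = mmd2(rng.normal(0, 1, (200, 2)), rng.normal(0, 1, (200, 2)))
diff = mmd2(rng.normal(0, 1, (200, 2)), rng.normal(3, 1, (200, 2)))
```

A larger MMD between two clients' feature sets indicates a stronger non-IID setting.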
In solution optimisation, the good performance of federated learning models is mainly achieved by improving the server-side aggregation method. The ideal application conditions for federated learning are IID-based datasets (such as the initially proposed FedAvg algorithm) and weights for clients that are proportional to the number of samples. The accuracy of global models is greatly degraded in the case of the inconsistent, unbalanced and non-independent distribution of client data [
13]. For this reason, most scholars have aimed to improve the shortcomings of aggregation methods for federated averaging algorithms. Accuracy-based averaging (ABAvg) has been proposed in the literature [
14], in which the server-side tests the accuracy of temporary models on validation datasets to obtain the accuracy of the models on the client side and then normalises them before aggregating all parameters. The federated learning with matched averaging (FedMA) algorithm that has been proposed in the literature [
15] uses Bayesian non-parametric methods to match and average weights in a hierarchical manner. The federated averaging with momentum (FedAvgM) algorithm that has been proposed in the literature [
16] applies momentum when updating global models on a server. The federated normalised averaging (FedNova) algorithm that has been proposed in the literature [
17] normalises local updates before averaging. However, these methods have limited success in improving the performance of global models [
12], so some scholars have proposed approaches that evade this problem, such as personalised federated learning, multitask federated learning and federated meta-learning, which can also improve the performance of federated learning on non-IID data to some extent.
The source of global model performance degradation is the non-IID problem; thus, data-level approaches to sharing client-side data have become new options for solving the non-IID problem. Client-side data sharing can be divided into two types: direct data sharing and indirect data sharing. In terms of direct data sharing for federated learning, one approach is to use a global sharing strategy [
18,
19,
20], in which the server-side shares small amounts of public data with the client for training to reduce the variance between trained local models, thus increasing the robustness and stability of the training process. This sharing approach relies on task-specific public datasets, and, in practice, there is a risk of privacy violation during both the acquisition and sharing of public data. Another approach is to use a local sharing strategy [
21,
22], in which small amounts of data are shared directly through trusted communication links between clients; however, this approach also violates the privacy preservation conventions of federated learning.
Indirectly shared federated learning does not share data directly, but rather makes the distributions of client datasets consistent by sharing data distribution information on the client side and then augmenting local training datasets with the shared distribution information [
23,
24]. The data distribution information can be learned using generator networks, which can be divided into global and local generators, depending on how the generators are trained. For example, a global generator shared approach has been proposed in the literature [
23] that trains conditional generative adversarial network (CGAN) [
25] generators on central servers and then shares the generators with clients to share distribution information. However, the data required for training CGANs using central servers are extracted from all clients, and there is a risk of privacy violation during the transmission of extracted data from the clients to the server side. A local generator shared approach has also been proposed in the literature [
24] that trains earth mover's distance-based generative adversarial networks (i.e., Wasserstein generative adversarial networks, WGANs) [
23] on local datasets on the client side and shares them with other clients. An image translation network is then trained using local generators and other generators to solve the federated learning problem for client-side heterogeneous data. Implicit data sharing through generators does not cause any privacy problems and is more practical than direct data sharing because it meets the need for patient privacy protection in healthcare organisations.
The data-immobile and model-immobile nature of federated learning has led to its increasingly widespread application in fields with high requirements for sensitive data protection, such as medicine. To address the degradation of federated learning performance due to inconsistent data distributions among federated learning participants, client-side data sharing has become an effective solution strategy. Among the different options, sharing data distributions rather than the data themselves is more appropriate for application because it does not create the risk of privacy violation. Therefore, we addressed this issue by integrating solutions at both the data and algorithm levels. See
Figure 1 for details of classification guidelines.
3. The Distribution Information Sharing- and Knowledge Distillation-Based Cyclic Federated Learning Method
The ultimate goal of federated learning is to jointly train optimal models for multiple clients; in this paper, we refer to these as task models, which multiple medical institutions build to obtain target models. Task models can be for the diagnosis of diseases, lesion segmentation, etc. In federated learning, local task models tend to be consistent with global task models; however, in the case of non-IID local client data, local task models deviate from global task models. In existing state-of-the-art cyclic decentralised federated learning schemes, the model parameters of nodes are updated after multiple steps of weighted summation and then averaged, which is a complex and costly communication strategy. In addition, the weighted average approach to model parameter aggregation often yields poor task model performance on non-IID datasets because the client data distributions of neighbouring nodes may differ significantly and, thus, the trained task models are biased. To address this, a natural idea is to reduce this bias by sharing data distributions to generate augmented datasets while preserving data privacy and then using the augmented data to learn the data distributions of other clients to achieve the implicit aggregation of model parameters. For this purpose, we used generators to learn the data distribution information of clients and shared the local task models of clients, together with the local data generators, with neighbouring clients. Since both the generators carrying the data distribution information of the neighbouring clients and the task models were trained on the same datasets, this facilitated the use of the transfer learning idea to aggregate the task models of two neighbouring clients. Based on this, we proposed a teacher–student model-based transfer learning approach for task model aggregation.
Figure 2 shows a general block diagram of our proposed approach.
Supposing that there are C clients involved in the federated learning task (where G denotes the shared generator model parameters, which are trained locally offline, and w denotes the task model parameters, which are dynamically shared weights), the overall process can be divided into two stages as follows:
Stage 1: The offline process. All clients participating in the federated learning task train the generator network offline on local datasets to obtain the generator network G that reflects local distribution information. Then, each client passes the trained generator G to the next client in turn. The next client c+1 generates the corresponding virtually shared local data after receiving the generator from the preceding client c.
Stage 2: The online process, which can mainly be divided into two steps. The first step is the knowledge distillation learning process, in which all clients first initialise the task model on local datasets and share it with the next client; the next client then uses the shared task model to teach its own task model on the virtually shared data via knowledge distillation. The second step simply re-updates the trained task model on local datasets and shares it with the next client.
3.1. Distribution Information Acquisition Based on Deep Learning
To eliminate the adverse effects of the non-IID problem on the performance of medical institution federated learning, an effective approach is to augment the local datasets of medical institutions by sharing their data distributions. To obtain information about the data distributions of healthcare institutions, the current state-of-the-art approach is to use a generator model with deep learning. Generators are the most effective tools for data augmentation because they not only learn the distribution information of data effectively but also generate data that match the real distributions. Generative adversarial networks (GANs), as one of the current types of mainstream deep neural network generators, are powerful in terms of image enhancement and image-to-image conversion [
22]. Therefore, we adopted a GAN as the data generator to obtain the data distribution information of local clients [
26,
27,
28] and added conditional information to generate the type of data that we needed, i.e., the final generator model was a CGAN. Specifically, let the total number of clients (federated learning participants) participating in the federated learning task be
C, let the local datasets of the c-th (1 ≤ c ≤ C) client be D_c and let n_c be the number of training samples of that client. The client c trains a generator G_c that reflects the distribution information of its local datasets D_c. Thus, the C clients are trained to obtain C generator models. The distribution information obtained in this way is relatively safe from privacy breaches.
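Our actual generators are CGANs; as a library-free stand-in that illustrates the same idea (fitting class-conditional distribution information locally and later sampling virtual data from it), the hypothetical sketch below fits a per-class Gaussian. It only shows the interface a client-side generator would expose, not how a CGAN is trained:

```python
import numpy as np

class ToyConditionalGenerator:
    """Stand-in for a CGAN generator: learns a per-class Gaussian
    from local data and can later sample 'virtual' data per class."""

    def fit(self, features, labels):
        self.stats = {}
        for c in np.unique(labels):
            xc = features[labels == c]
            # Store per-feature mean and std for each class.
            self.stats[int(c)] = (xc.mean(0), xc.std(0) + 1e-6)
        return self

    def sample(self, label, n, rng):
        mean, std = self.stats[label]
        return rng.normal(mean, std, size=(n, mean.size))

rng = np.random.default_rng(1)
# Toy local dataset: two classes of 4-dimensional features.
x = np.vstack([rng.normal(0, 1, (100, 4)), rng.normal(5, 1, (100, 4))])
y = np.array([0] * 100 + [1] * 100)
gen = ToyConditionalGenerator().fit(x, y)
virtual = gen.sample(1, 50, rng)  # 50 virtual samples for class 1
```

The conditional information (the class label passed to `sample`) is what lets the receiving client generate exactly the type of data it needs, as with the CGAN in our method.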
3.2. Distribution Information Sharing
The purpose of sharing distribution information is to enable later clients in the cyclic communication graph to hold virtually shared data carrying the previous client's data distribution information, thus enabling two adjacent clients to achieve a consistent distribution of data and improve the performance of task models. To this end, we combined the features of a cyclic federated learning architecture and model parameters to accomplish this process. Let 1 ≤ c ≤ C and let the client c transmit the generator G_c to the client c+1; when c = C, the successor is the client 1, thus forming a ring-shaped communication link. Under this cyclic communication link, the client c+1 receives the generator G_c from the client c, where n_c is the number of local data points of the client c. Accordingly, the client c+1 can use G_c to generate n_c virtually shared data points, denoted D'_c. Therefore, only the client c+1 has the distribution information of the client c, which indirectly realises distribution information sharing while protecting patient privacy. The distribution information sharing process is schematically illustrated in
Figure 3.
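The ring indexing described above (client C wraps around to client 1) can be sketched as follows; the helper names are illustrative, not from our implementation:

```python
def next_client(c, num_clients):
    """Successor of client c (1-indexed) on the ring: C wraps to 1."""
    return c % num_clients + 1

def share_generators(generators):
    """Each client c sends its generator to client c+1 on the ring.
    Returns, for each client, the generator it receives."""
    C = len(generators)
    received = {}
    for c in range(1, C + 1):
        received[next_client(c, C)] = generators[c]
    return received

gens = {1: "G_1", 2: "G_2", 3: "G_3"}
print(share_generators(gens))  # {2: 'G_1', 3: 'G_2', 1: 'G_3'}
```

Each client thus receives exactly one generator (from its predecessor) and sends exactly one (to its successor), which is what keeps the communication cost of this scheme low.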
3.3. Task Model Parameter Aggregation
The task model parameter aggregation process focuses on how to use shared distribution information for model parameter aggregation to eliminate the adverse effects of the non-IID problem on federated learning performance. In our cyclic federated learning framework, the client c+1 not only receives the task model parameters w_c from the client c through a trusted channel but also the generator model G_c. The virtually shared data D'_c can be generated locally via G_c. Since D'_c and the local datasets D_c of the client c have consistent distributions, the task model w_c obtained by the client c after training on D_c also performs well on D'_c. However, the distributions of the local datasets of the client c+1 are usually not consistent with those of D'_c, such that w_c performs worse on the local datasets of the client c+1 than on D'_c. As a result, existing model aggregation algorithms, such as federated averaging and its various improvements, performed poorly in our cyclic federated learning framework. To this end, we proposed a new method for model aggregation for federated learning tasks based on knowledge distillation.
Since the locally trained task model w_c of the client c has a similarly optimal performance on the datasets D_c and D'_c, the local task model w_{c+1} of the client c+1 can be trained using the task model w_c of the client c on the datasets D'_c to improve performance. This idea could be implemented using the teacher–student model for transfer learning, as shown in
Figure 4 and
Figure 5.
The training goal of our cyclic federated learning method based on distribution information sharing and knowledge distillation was the minimisation of a total loss function (Equation (1)), in which a hyperparameter controls the propensity of the local task model and w_{c+1} is the parameter of the local task model of the client c+1 (the task model to be trained can be the same or different for each client). The loss function corresponding to w_{c+1} combines the loss of the task model on each local data sample with a term measuring the difference between the models of the adjacent clients c and c+1 in the cyclic communication graph (Equation (3)). This difference term is a weighted sum of the knowledge distillation loss and the student loss on the datasets D'_c (which are defined in the same way as in the standard teacher–student model); its two hyperparameters take values of 0 when the adjacent client models are the same and values greater than 0 when they are different; the smaller the difference, the smaller the values (and vice versa). According to incremental convex optimisation theory, the minimisation problem (Equation (1)) can be solved using the following iteration. At the k-th iteration, a gradient descent update with a given step size is first performed on an intermediate variable (Equation (6)), where the superscripts k and k+1 denote the values of the k-th and (k+1)-th iterations, respectively; the model parameters w_{c+1} are then updated from this intermediate variable (Equation (7)). Using Equation (6), the iteration of the intermediate variable learns the behaviour of w_c on the datasets D'_c, thus optimising the performance of the local model w_{c+1} that is updated using Equation (7) on the local datasets. After multiple further iterations of training, as shown in
Figure 4, all clients can learn the features of the data distributions of the other clients via this cyclic framework, i.e., the training effect of a global model is reached. Ultimately, the adverse effects of the non-IID problem on medical institution federated learning performance can be eliminated.
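To make the knowledge distillation and student loss terms concrete, the following sketch computes a standard temperature-softened distillation loss and an ordinary cross-entropy student loss on a batch of virtually shared data. The temperature value and exact formulation follow the standard teacher–student model rather than specific settings from our method:

```python
import numpy as np

def softmax(z, T=1.0):
    """Row-wise softmax with temperature T (numerically stabilised)."""
    z = z / T
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def distillation_losses(teacher_logits, student_logits, labels, T=2.0):
    """Return (kd_loss, student_loss) on one batch of virtual data.
    kd_loss: cross-entropy between softened teacher and student outputs.
    student_loss: ordinary cross-entropy against the hard labels."""
    p_teacher = softmax(teacher_logits, T)
    p_student = softmax(student_logits, T)
    kd = -np.mean(np.sum(p_teacher * np.log(p_student + 1e-12), axis=1))
    p_hard = softmax(student_logits)
    student = -np.mean(np.log(p_hard[np.arange(len(labels)), labels] + 1e-12))
    return kd, student

# Toy batch: the teacher's behaviour on two virtual samples.
teacher = np.array([[4.0, 0.0], [0.0, 4.0]])
good_student = np.array([[3.9, 0.1], [0.1, 3.9]])   # mimics the teacher
bad_student = np.array([[0.0, 4.0], [4.0, 0.0]])    # contradicts the teacher
labels = np.array([0, 1])
kd_good, _ = distillation_losses(teacher, good_student, labels)
kd_bad, _ = distillation_losses(teacher, bad_student, labels)
# A student that mimics the teacher incurs a much smaller KD loss.
```

Minimising the weighted sum of these two losses is what drives the student model w_{c+1} to learn the behaviour of the teacher w_c on the virtually shared data.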
The above solution process can be described in pseudo-code as shown in Algorithm 1.
Algorithm 1 Federated learning algorithm based on distribution information sharing and knowledge distillation.
Input: C clients, each with its own training datasets D_c, generator G_c and its own task model w_c
Output: Trained model parameter set {w_1, ..., w_C}
1: for c = 1, ..., C do
2:     Client c sends G_c to Client c+1
3:     Client c+1 generates virtual shared data D'_c with G_c
4: end for
5: for each communication round do
6:     for c = 1, ..., C do
7:         Client c sends w_c to Client c+1
8:         Client c+1 updates the intermediate variable according to (3) and (6)
9:         Client c+1 updates w_{c+1} according to (1) and (7)
10:    end for
11: end for
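The control flow of Algorithm 1 can be mirrored in executable form as below; the `distill` and `local_update` callables are stubs standing in for the updates via Equations (3)/(6) and (1)/(7), and the toy usage values are purely illustrative:

```python
def run_cyclic_federated_learning(clients, rounds, distill, local_update):
    """Mirror of Algorithm 1's control flow with stub update steps.
    `clients` maps 1..C to dicts holding 'G' (generator), 'w' (task
    model) and 'data' (local dataset stand-in)."""
    C = len(clients)
    nxt = lambda c: c % C + 1  # ring successor: C wraps to 1
    # Stage 1 (lines 1-4): pass generators around the ring.
    for c in range(1, C + 1):
        clients[nxt(c)]["virtual"] = clients[c]["G"]()
    # Stage 2 (lines 5-11): cyclic rounds of distillation + local update.
    for _ in range(rounds):
        for c in range(1, C + 1):
            w_prev = clients[c]["w"]
            receiver = clients[nxt(c)]
            receiver["w"] = distill(w_prev, receiver["virtual"])
            receiver["w"] = local_update(receiver["w"], receiver["data"])
    return {c: clients[c]["w"] for c in clients}

# Toy usage: models are scalars, 'data' is a scalar target per client.
clients = {c: {"G": (lambda c=c: f"virtual_{c}"), "w": 0.0, "data": float(c)}
           for c in (1, 2, 3)}
final = run_cyclic_federated_learning(
    clients, rounds=2,
    distill=lambda w_prev, virtual: w_prev,       # stub: copy the teacher
    local_update=lambda w, d: 0.5 * w + 0.5 * d)  # stub: nudge toward local data
```

Note that each round performs only C point-to-point transmissions, which is the communication advantage of the ring structure over all-to-all decentralised schemes.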
5. Summary
To address the non-IID problem in medical institution federated learning, which cannot be effectively solved using existing federated learning techniques, this paper proposed a cyclic federated learning method (CFL_DS_KT) based on distribution information sharing and knowledge distillation. This is a novel and effective federated learning approach and, to the best of our knowledge, the first to use this unidirectional synchronous cyclic decentralised federated learning framework and to effectively evaluate the convergence of a model with this structure. The experimental results also show that the task model achieves convergence under our proposed approach. Furthermore, in contrast to existing research solutions, we solve the non-IID problem through the combined approach of distribution sharing and knowledge distillation. By considering both data-level and algorithm-level optimisation approaches, we achieve better performance of the federated learning model under non-IID conditions while safeguarding client data privacy. In our extensive experiments on medical and public datasets, CFL_DS_KT shows a good improvement over various state-of-the-art methods, and its accuracy is closer to that of centralised learning. Further improvements in privacy preservation were achieved by using a cyclic federated learning method. Our approach also provides an idea for training federated learning models on heterogeneous data, which could eliminate data heterogeneity by transferring the data distribution information from one client to another.
However, our proposed approach has some shortcomings. When the client data are extremely heterogeneous, it is difficult to train a good generator that produces high-quality images because of the small amount of training data. Additionally, the approach is not suitable for training federated learning models with large numbers of clients, as this could increase breakpoint failures and model training cycle times. Therefore, this method is mainly suitable for federated learning across medical institutions.