2. System Model and Problem Description
As shown in Figure 1, we consider an edge-aided UAV network that includes one cloud server, $L$ edge servers located in base stations, and $N$ UAVs. The UAVs are divided into $L$ groups, and the $l$th group of UAVs, denoted by $\mathcal{N}_l$ with cardinality $N_l$, is assigned to the $l$th edge server.
Let $n$ be the total number of data samples across the UAVs, where the $i$th UAV has a dataset, denoted as $\mathcal{D}_i$, consisting of $n_i$ data samples. The objective of the FL is to minimize the following (global) loss function:
$$F(w) = \sum_{i=1}^{N} \frac{n_i}{n} F_i(w), \quad F_i(w) = \frac{1}{n_i} \sum_{j=1}^{n_i} f_{i,j}(w). \tag{1}$$
In Equation (1), $F(w)$ denotes the loss function for the global model $w$, and $f_{i,j}(w)$ is the loss function for the $j$th data sample of the $i$th UAV.
In the naive FL, the training process begins with the central cloud server sending the global model $w$ to the UAVs [5]. Then, at each step $t$, the $i$th UAV trains the global model locally, which results in a local model $w_i^t$, using its private dataset $\mathcal{D}_i$ based on the gradient descent method as
$$w_i^t = w_i^{t-1} - \eta \nabla F_i(w_i^{t-1}), \tag{2}$$
where $\eta$ denotes a step size. After sending the local models $w_i^t$ back to the central cloud server, the global model $w$ is updated via the aggregation $w^t = \sum_{i=1}^{N} \frac{n_i}{n} w_i^t$. The above procedure is repeated until a desired accuracy is achieved.
Suppose that the global model is updated every $k$ steps; otherwise, the local models are trained. Then it follows that
$$w_i^t = \begin{cases} w_i^{t-1} - \eta \nabla F_i(w_i^{t-1}), & \mathrm{mod}(t, k) \neq 0, \\ \sum_{i=1}^{N} \frac{n_i}{n} \left[ w_i^{t-1} - \eta \nabla F_i(w_i^{t-1}) \right], & \mathrm{mod}(t, k) = 0, \end{cases} \tag{3}$$
where $\mathrm{mod}(a, b)$ denotes the remainder when $a$ is divided by $b$.
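The k-step periodic update described above can be sketched numerically. The following is a minimal NumPy illustration, where the toy quadratic per-client losses and the synthetic sample counts are assumptions for demonstration, not the paper's learning task:

```python
import numpy as np

def fedavg(centers, n_samples, eta=0.1, k=5, total_steps=50):
    """Simulate the k-step periodic FL update: local gradient descent at
    every step, global aggregation weighted by n_i / n every k steps.

    Each client i minimizes a toy quadratic F_i(w) = 0.5 * ||w - c_i||^2
    (gradient: w - c_i); these losses are illustrative assumptions only.
    """
    n = float(sum(n_samples))
    w_local = [np.zeros_like(centers[0]) for _ in centers]
    w_global = w_local[0].copy()
    for t in range(1, total_steps + 1):
        # Local gradient-descent step at every client.
        w_local = [w - eta * (w - c) for w, c in zip(w_local, centers)]
        if t % k == 0:
            # Global aggregation weighted by n_i / n.
            w_global = sum((ni / n) * w for ni, w in zip(n_samples, w_local))
            w_local = [w_global.copy() for _ in centers]
    return w_global

centers = [np.array([1.0, 0.0]), np.array([0.0, 1.0])]
w = fedavg(centers, n_samples=[30, 10])
# w approaches the n_i/n-weighted average of the per-client minimizers.
```

With unequal sample counts, the aggregate is pulled toward the clients holding more data, which is exactly the weighting in the aggregation step.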
Non-Identical Distributions among UAVs
Given the differences in commercial types of UAVs and their hardware, the data acquired by different UAVs are highly likely to be non-i.i.d. This is especially the case for FL in practical UAV networks, in contrast to traditional machine learning, where the training data are expected to be uniform. Therefore, the heterogeneity of the UAVs leads to poor performance and convergence behavior of the FL due to the large deviation among the local models trained at the devices [20]. Under these circumstances, a model $f_i$ trained on its own data $\mathcal{D}_i$ will not be representative of the joint global model $f$ [19]:
$$\mathbb{E}_{\mathcal{D}_i}\left[ F_i(w) \right] \neq F(w).$$
There are several ways in which the data among devices can deviate from being i.i.d.:
Feature distribution skew: The marginal distributions $P(x)$ vary among the devices; that is, the features of the data differ between devices. For example, pictures of the same object might differ in terms of brightness, occlusion, camera sensor, etc.
Label distribution skew: The marginal distributions $P(y)$ vary among the devices, where each device has access to only a small subset of all available labels. For example, each device may hold images of only a couple of digits.
Concept shift (different features, same label): The conditional distributions $P(x \mid y)$ vary among the devices. This is the case where the same label $y$ might have different features $x$ across devices. In digit recognition, digits might be written in drastically different ways, resulting in different underlying features for the same digit.
Concept shift (same features, different label): The conditional distributions $P(y \mid x)$ vary among the devices. Here, similar features might be labeled differently across devices. For example, different digits may be written in very similar ways, such as 5 and 6, or 3 and 8.
In real-world scenarios, each of the above cases can occur in practice, and most datasets contain a mixture of them. The problem becomes even more severe in the edge-aided UAV network due to the existence of additional intermediate nodes (i.e., edge servers).
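For instance, label distribution skew of the kind described above can be emulated by partitioning a labeled dataset so that each client receives samples from only a few classes. The partitioner below is an illustrative sketch; the classes-per-client count and the subsampling rule are assumptions, not the paper's exact procedure:

```python
import numpy as np

def label_skew_partition(labels, num_clients, classes_per_client, seed=0):
    """Assign sample indices to clients so that each client only sees
    `classes_per_client` distinct labels (label distribution skew)."""
    rng = np.random.default_rng(seed)
    classes = np.unique(labels)
    # Each client draws a small subset of the available classes.
    client_classes = [rng.choice(classes, classes_per_client, replace=False)
                      for _ in range(num_clients)]
    partition = []
    for cls_set in client_classes:
        idx = np.flatnonzero(np.isin(labels, cls_set))
        # Subsample so clients do not all hold the full class data.
        partition.append(rng.choice(idx, size=len(idx) // num_clients + 1,
                                    replace=False))
    return partition

labels = np.repeat(np.arange(10), 100)  # 10 classes, 100 samples each
parts = label_skew_partition(labels, num_clients=5, classes_per_client=1)
# Each client's samples now carry exactly one distinct label.
```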
3. Hierarchical FL Algorithm
The key idea of the proposed hierarchical FL algorithm is that the edge servers are used as intermediate aggregators with commonly shared data to improve the learning performance, even with non-i.i.d. data. For this purpose, in practice, one can collect exemplary data samples from the UAVs and employ them as the commonly shared data.
In the hierarchical FL, the commonly shared data are used to train the local models at the edge servers. In addition, we suggest aggregating the local models of both the UAVs and edge servers hierarchically. Detailed explanations are given in the following.
The proposed hierarchical FL algorithm for the edge-aided UAV network is presented in Algorithm 1, where $T$ is the total number of global aggregation steps. In addition, $C$ denotes the fraction of UAVs participating in the hierarchical FL, selected from the total of $N$ UAVs.
Algorithm 1 works as follows: First, the local models of the UAVs and the edge servers are all initialized with random weights $w^0$, and each edge server is assigned a commonly shared public dataset $\mathcal{D}_l^{\mathrm{s}}$ equivalent to 5% of the overall dataset. Then, the UAVs and edge servers start training their local models (i.e., the global model of the previous round) in parallel using their private and commonly shared data, respectively. In every step of the global aggregation, the UAVs update their models with the globally aggregated parameters $w^{t-1}$ from the previous round. By averaging the model updates, the magnitude of poisoned models can be reduced in the case of an attack, which ensures that a single backdoor has less effect on the overall update procedure.
Algorithm 1 Proposed hierarchical FL algorithm.
- 1: Initialize $w^0$ and $\mathcal{D}_l^{\mathrm{s}}$ for $l = 1, \dots, L$
- 2: for $t = 1, \dots, T$ do
- 3: for each edge server $l$ do // in parallel
- 4: $w_l^t \leftarrow w^{t-1}$
- 5: for $k_2 = 1, \dots, \tau_2$ do
- 6: for each UAV $i \in \mathcal{N}_l$ do // in parallel
- 7: for $k_1 = 1, \dots, \tau_1$ do
- 8: $w_i^t \leftarrow$ LocalUpdate($i$, $w_l^t$) in Algorithm 2
- 9: end for
- 10: end for
- 11: $w_l^t \leftarrow$ EdgeAggregation($l$, $\{w_i^t\}_{i \in \mathcal{N}_l}$) in Algorithm 3
- 12: end for
- 13: end for
- 14: $w^t \leftarrow$ GlobalAggregation($\{w_l^t\}_{l=1}^{L}$) in Algorithm 4
- 15: end for
Algorithm 2 Local update procedure.
- 1: function LocalUpdate($i$, $w$)
- 2: $\mathcal{B} \leftarrow$ (split $\mathcal{D}_i$ into batches of size $B$)
- 3: for each local epoch from 1 to $E$ do
- 4: for batch $b \in \mathcal{B}$ do
- 5: $w \leftarrow w - \eta \nabla \ell(w; b)$
- 6: end for
- 7: end for
- 8: return $w$
- 9: end function
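The LocalUpdate procedure is plain mini-batch SGD. A minimal NumPy sketch follows, where `grad_fn` is a placeholder (not from the paper) returning the mini-batch gradient of the local loss, and the least-squares example is an illustrative assumption:

```python
import numpy as np

def local_update(w, data, grad_fn, eta=0.01, epochs=1, batch_size=32, seed=0):
    """Mini-batch SGD as in Algorithm 2: split the private dataset into
    batches of size B and take one gradient step per batch."""
    rng = np.random.default_rng(seed)
    w = w.copy()
    for _ in range(epochs):
        order = rng.permutation(len(data))
        for start in range(0, len(data), batch_size):
            batch = data[order[start:start + batch_size]]
            w = w - eta * grad_fn(w, batch)   # w <- w - eta * grad
    return w

# Toy local loss: mini-batch least squares, rows are [features..., target].
def grad_fn(w, batch):
    X, y = batch[:, :-1], batch[:, -1]
    return 2 * X.T @ (X @ w - y) / len(batch)
```

Running this on noiseless synthetic data drives the local model to the minimizer of the client's own loss, which is precisely why non-i.i.d. clients drift apart between aggregations.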
After $\tau_1$ local iterations, each UAV sends its local model $w_i^t$, trained with its private dataset $\mathcal{D}_i$, to the edge servers. Upon receiving the local models from the corresponding UAVs, the edge servers perform the EdgeAggregation procedure in Algorithm 3, wherein the local models $w_l^t$ of the edge servers are trained with the shared dataset $\mathcal{D}_l^{\mathrm{s}}$ and are then aggregated together with the local models of the UAVs.
Algorithm 3 Edge aggregation procedure.
- 1: function EdgeAggregation($l$, $\{w_i^t\}_{i \in \mathcal{N}_l}$)
- 2: $\bar{w}_l^t \leftarrow \sum_{i \in \mathcal{N}_l} \frac{n_i}{n_l} w_i^t$
- 3: $w_l^t \leftarrow$ Aggregate($\bar{w}_l^t$, LocalUpdate($l$, $w_l^t$)) // Edge local update
- 4: return $w_l^t$
- 5: end function
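The edge aggregation step can be sketched as follows, under the assumption that the edge model trained on the shared dataset is averaged with the UAV models using FedAvg-style size weighting; the exact combination rule is set by Algorithm 3, and treating the shared-data model as one more participant is an illustrative choice:

```python
import numpy as np

def edge_aggregation(uav_models, uav_sizes, edge_model, shared_size):
    """Aggregate UAV models with the edge model trained on shared data.

    Assumption for illustration: the edge model (trained on the shared
    dataset) is treated as one extra participant, weighted by the shared
    dataset size, mirroring the n_i / n weighting of FedAvg.
    """
    models = uav_models + [edge_model]
    sizes = np.array(uav_sizes + [shared_size], dtype=float)
    weights = sizes / sizes.sum()
    return sum(wt * m for wt, m in zip(weights, models))

w_edge = edge_aggregation(
    [np.array([1.0, 1.0]), np.array([3.0, 3.0])], [100, 100],
    edge_model=np.array([2.0, 2.0]), shared_size=200)
# → array([2., 2.])
```

Because the shared data are common across edges, this term pulls every edge model toward a shared reference point, counteracting the divergence caused by non-i.i.d. UAV data.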
After $\tau_2$ iterations of the EdgeAggregation procedure, the edge servers send their aggregated models $w_l^t$ to the cloud server, where the global model $w^t$ is obtained according to the GlobalAggregation procedure in Algorithm 4. Overall, the local update at the $i$th UAV assigned to the $l$th edge server takes the following form:
$$w_i^t = \begin{cases} w_i^{t-1} - \eta \nabla F_i(w_i^{t-1}), & \mathrm{mod}(t, \tau_1) \neq 0, \\ \mathrm{EdgeAggregation}\left(l, \{w_j^t\}_{j \in \mathcal{N}_l}\right), & \mathrm{mod}(t, \tau_1) = 0 \text{ and } \mathrm{mod}(t, \tau_1 \tau_2) \neq 0, \\ \mathrm{GlobalAggregation}\left(\{w_l^t\}_{l=1}^{L}\right), & \mathrm{mod}(t, \tau_1 \tau_2) = 0, \end{cases} \tag{4}$$
which is clearly different from that of the traditional FL in Equation (3). We note that when the intermediate aggregators are unable to perform the EdgeAggregation procedure due to low system resources, which is the case when $\tau_2 = 1$ and $\mathcal{D}_l^{\mathrm{s}} = \varnothing$, the overall process reduces to FedAvg in Equation (3).
Algorithm 4 Global aggregation procedure.
- 1: function GlobalAggregation($\{w_l^t\}_{l=1}^{L}$)
- 2: $w^t \leftarrow \sum_{l=1}^{L} \frac{n_l}{n} w_l^t$
- 3: return $w^t$
- 4: end function
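Putting the local update, edge aggregation, and global aggregation procedures together, one global round can be sketched as follows. The toy quadratic losses, the size-based weighting, and the treatment of the shared-data model as an extra participant are assumptions for illustration, not the paper's exact rules:

```python
import numpy as np

def sgd_steps(w, center, eta, steps):
    """Toy local update: gradient steps on F(w) = 0.5 * ||w - center||^2."""
    for _ in range(steps):
        w = w - eta * (w - center)
    return w

def hierarchical_round(w_global, groups, eta=0.1, tau1=5, tau2=2):
    """One global round of the hierarchical FL loop.

    `groups` lists the edges; each edge is (uav_centers, uav_sizes,
    shared_center, shared_size). The quadratic losses stand in for the
    real training tasks.
    """
    edge_models, edge_sizes = [], []
    for uav_centers, uav_sizes, shared_center, shared_size in groups:
        w_edge = w_global
        for _ in range(tau2):
            # tau1 local iterations at every UAV (local update).
            uav_models = [sgd_steps(w_edge, c, eta, tau1) for c in uav_centers]
            # Edge-side model trained on the commonly shared data.
            w_shared = sgd_steps(w_edge, shared_center, eta, tau1)
            # Edge aggregation: size-weighted average incl. shared model.
            sizes = np.array(uav_sizes + [shared_size], dtype=float)
            models = uav_models + [w_shared]
            w_edge = sum(s / sizes.sum() * m for s, m in zip(sizes, models))
        edge_models.append(w_edge)
        edge_sizes.append(sum(uav_sizes) + shared_size)
    # Global aggregation: size-weighted average over the edge models.
    total = float(sum(edge_sizes))
    return sum(s / total * m for s, m in zip(edge_sizes, edge_models))
```

Iterating `hierarchical_round` drives the global model toward the size-weighted consensus of all local objectives, while each edge only forwards a single aggregated model to the cloud per round.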
In any FL algorithm, there is a decrease in the accuracy of the trained machine learning model compared to centralized learning due to weight divergence, which is mainly caused by the following two factors: different initialization of the UAVs' models in the training process and the non-i.i.d. nature of the underlying data distribution [20]. As a result, two important factors influence the performance of the proposed hierarchical FL algorithm. The first is the number $\tau_1$ of iterations in the local updates of the UAVs and the number $\tau_2$ of aggregation steps at the edge server before transmitting the update result to the global server. Higher values of $\tau_1$ and $\tau_2$, that is, more iteration steps between global aggregations, reduce the communication cost in practice. The percentage of commonly shared data is the second factor. Since the edge servers act as aggregators in the hierarchical FL, each edge server can fine-tune the size of its shared dataset independently, depending on the data distributions of the UAVs assigned to it. The overall training process is shown in Figure 2.
Complexity Analysis
Suppose that the completion time for a UAV to finish a single training round is $t_{\mathrm{cmp}}$, the transmission time of a UAV's update to the edge server is $t_{\mathrm{UE}}$, and the transmission time of the aggregated model updates from the edge servers to the central server is $t_{\mathrm{EC}}$; the overall communication time complexity of the proposed algorithm in each round is then $O(t_{\mathrm{cmp}} + CN\, t_{\mathrm{UE}} + L\, t_{\mathrm{EC}})$. Since the edge server also acts as a base station between the UAVs and the central server, the communication time complexity of FedAvg is $O(t_{\mathrm{cmp}} + CN(t_{\mathrm{UE}} + t_{\mathrm{EC}}))$, as every UAV update must be relayed to the central server. Since the number of active users $CN$ is orders of magnitude higher than the number of edge servers $L$, the proposed algorithm yields a smaller communication complexity than FedAvg.
4. Numerical Results
In the simulations, we consider an image classification task to evaluate and compare the performance of the various FL algorithms. In this task, we consider three scenarios with different degrees of non-i.i.d. data distribution and participation.
First, in Scenario I, the widely used MNIST dataset [29] serves as both the private dataset $\mathcal{D}_i$ at the UAVs and the commonly shared dataset $\mathcal{D}_l^{\mathrm{s}}$ at the edge servers. To consider a situation with extremely non-i.i.d. data distribution, 100 UAVs and 10 edge servers are selected such that each UAV is given data samples from only one class, and each edge server is assigned 10 UAVs covering 2 different classes in total (for example, the first edge server may be assigned the labels 3 and 5, so that each UAV assigned to it has data samples with either the label 3 or the label 5 only). This scenario well describes the case with label distribution skew, i.e., each UAV holds data samples of only one class, and each edge server is assigned UAVs that share the same labels.
Second, in Scenario II, the Federated Extended MNIST (FEMNIST) dataset [30] is used to classify 52 handwritten uppercase and lowercase letters in addition to the 10 digits, and the dataset is divided according to the writer of the characters with an unbalanced number of samples per UAV. The purpose of this scenario is to study the impact of feature distribution skew on the FL, where $n_i$ is set to be different among the UAVs. In total, 360 UAVs are assigned to 18 edge servers randomly. In both Scenarios I and II, 5% of the dataset is selected as the shared dataset for the edge servers. The dynamic nature of UAV networks can lead to some of the devices becoming the bottleneck in the system (i.e., the straggler effect).
Finally, in Scenario III, we also perform experiments using a very low value of $C$ to demonstrate the robustness of our system to high dropout or low participation rates caused by straggler effects. We use settings similar to Scenario II, while increasing the number of users from the FEMNIST dataset to 3500.
In addition, for the purpose of performance comparison, we report the accuracy of the model at every global aggregation step.
For the MNIST dataset in Scenario I, we construct a convolutional neural network (CNN) with four layers: two convolutional layers using 10 and 20 filters, respectively, with a kernel size of 5, followed by two fully connected layers with 50 and 10 units, respectively. The FEMNIST dataset in Scenario II is evaluated using a similar CNN: two convolutional layers using 32 and 64 filters with a kernel size of 5, and two fully connected layers with 1024 and 62 units. At each UAV, stochastic gradient descent is used to update the local models, where the batch size is set to 32 and the learning rate is set to 0.01 with an exponential decay factor of 0.995 applied after every global aggregation step.
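The Scenario I architecture might be sketched in PyTorch as follows; the pooling, activation placement, and lack of padding are assumptions, since the text only fixes the filter counts, kernel size, and fully connected widths:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MnistCNN(nn.Module):
    """Sketch of the Scenario I model: two conv layers (10 and 20 filters,
    kernel size 5) and two fully connected layers (50 and 10 units).
    Max-pooling and ReLU placement are assumptions for illustration."""
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.fc1 = nn.Linear(20 * 4 * 4, 50)   # 28 -> 24 -> 12 -> 8 -> 4
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.flatten(1)
        return self.fc2(F.relu(self.fc1(x)))

# SGD with the stated hyperparameters: batch size 32 (at loading time),
# lr 0.01, and a 0.995 exponential decay stepped per global aggregation.
model = MnistCNN()
opt = torch.optim.SGD(model.parameters(), lr=0.01)
sched = torch.optim.lr_scheduler.ExponentialLR(opt, gamma=0.995)
```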
4.1. Evaluation Metrics
We split the data of each user into 90% training and 10% test sets and report the results on the test set. To evaluate the performance of the model, we measure the top-1 accuracy of all users and average them to obtain the average test accuracy of the whole network. Since the average accuracy might not account for poorly performing users, we also measure the percentage of users achieving a desired target accuracy threshold. This allows us to better understand the fairness of the model in non-i.i.d. scenarios, where the underlying data distribution can heavily affect the accuracy of the global model for a particular user. We set the target accuracy threshold to 98% and 80% for Scenario I and Scenario II, respectively.
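The two metrics described above can be computed as in the following self-contained sketch (the per-user accuracies shown are made-up example values):

```python
import numpy as np

def network_metrics(per_user_accuracy, target):
    """Compute the two evaluation metrics: the average top-1 test accuracy
    across all users, and the fraction of users at or above the target."""
    acc = np.asarray(per_user_accuracy, dtype=float)
    return acc.mean(), float((acc >= target).mean())

avg, frac = network_metrics([0.99, 0.97, 0.985, 0.92], target=0.98)
# avg  -> average test accuracy over the users
# frac -> share of users meeting the target threshold
```

Reporting both numbers is what exposes unfairness: a model can have a high average while leaving a tail of users far below the threshold.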
4.2. Experimental Results
Figure 3 shows a performance comparison between the proposed hierarchical FL algorithm and other existing FL algorithms, namely FedAvg [5], HierFAVG [22], and HFEL [23], for Scenario I. In this figure, we set $\tau_1$ to 10 and $C$ to 0.2. The naive FL algorithm in [5] performs the worst, achieving low accuracy (below 70% in terms of average accuracy after 50 rounds). Although the FL algorithms developed in [22,23] perform better than that of [5], they still fail to achieve the desired accuracy level (e.g., above 98%) for the case with label distribution skew.
On the other hand, the proposed hierarchical FL algorithm not only achieves the highest accuracy (98.3% average test accuracy across the UAVs) but also converges quickly and stably (fewer than 20 iterations of the global aggregation). Experiments conducted with the FEMNIST dataset in Scenario II demonstrate a similar trend, as shown in Figure 4, where the proposed hierarchical FL algorithm still significantly outperforms the other FL algorithms.
In Table 1, we compare the percentage of UAVs that achieve the target accuracy in both Scenarios I and II (98% and 80%, respectively). From Table 1, it can be seen that the proposed hierarchical FL algorithm considerably outperforms the others in both scenarios. Specifically, in Scenario I, the proposed hierarchical FL algorithm has 66% of the UAVs reaching the target accuracy level, more than twice the percentages of the FL algorithms in [22,23]. Note that the naive FL algorithm in [5] has only 6% of the UAVs above 98% accuracy. Experiments for Scenario II also show the clear advantage of the proposed hierarchical FL algorithm over the other FL algorithms. Specifically, we observe a performance gain of around 10% compared to [22,23] and an almost 10-fold improvement compared to [5]. The performance gain is even greater for the case with label distribution skew. The trend continues in Scenario III, where the proposed algorithm performs even better in both metrics compared to the other schemes.
Table 2 lists the performance of the proposed hierarchical FL algorithm in terms of the average test accuracy and the percentage of UAVs reaching the target accuracy levels for both scenarios, obtained by varying the hyperparameters $C$, $\tau_1$, and $\tau_2$. In Scenario I, almost all settings of $C$ converge to more than 98% average accuracy across the UAVs. While higher values of $C$ make only an insignificant difference in average accuracy, the percentage of UAVs with more than 98% test accuracy increases noticeably, with the change in $C$ accounting for approximately a 10% difference. Increasing the number of iterations also leads to a significant improvement in the percentage of UAVs reaching the target accuracy in both scenarios for all values of $C$. Based on these results, there is little advantage in increasing the number of participating UAVs if the other hyperparameters are tuned carefully. In addition, a high fraction of participating UAVs leads to significantly more communication and computational overhead, which might be an issue in practical resource-constrained edge-aided networks.
To further analyze the generalization of the proposed algorithm, we perform an additional set of experiments on NLP (natural language processing) tasks, using the Shakespeare and Sent140 datasets from the LEAF [30] benchmark suite. We sample the users in a non-i.i.d. fashion, where each speaking role and Twitter user represents an individual user in the FL setting. For the Shakespeare dataset, we consider a model with an embedding layer that maps the input into 8 dimensions and a 2-layer LSTM with 256 units, followed by a fully connected layer for prediction. We use input sequences of 80 characters, a learning rate of 1.0, and 549 users with a participation rate $C$ of 0.1. For the Sent140 dataset, we construct a similar model with a 2-layer LSTM with 100 units on top of pretrained 300-dimensional GloVe embeddings [31], taking a sequence of 25 characters as input. We drop the users with fewer than 50 samples and set the learning rate to 0.1 for all experiments.
As shown in Figure 5 and Figure 6, our proposed method yields better results than the existing schemes in both metrics, validating that the proposed algorithm is robust across various FL applications.