
Safeguarding Privacy with Federated Learning

TL;DR

1. What is federated learning?
    a. Federated learning is a distributed machine learning framework where a shared model is trained across multiple devices without the direct exchange of sensitive data such as financial or healthcare records.
    b. The philosophy behind federated learning is to move the model to the data, not the data to the model as in non-federated frameworks.
2. What are the advantages of federated learning?
    a. Due to privacy and security concerns, certain datasets cannot be shared directly, and therefore cannot be used by non-federated learning frameworks. However, these datasets hold untapped reserves of information that are useful for building more powerful models and enabling new, innovative services.
    b. Federated learning was developed as a solution to this problem, allowing the training of machine learning models on data that cannot be used directly.
3. What are some current applications of federated learning?
    a. Google uses federated learning to train Gboard, the Google keyboard, locally on Android devices.
    b. The healthcare sector is exploring the application of federated learning to digital health solutions.
    c. The finance sector has started to apply federated learning to the next generation of secure fintech products and services.
4. How do federated learning models compare against non-federated learning models?
    a. With sufficient training iterations, federated learning models can achieve performances comparable to non-federated learning models while preserving data privacy and security.

The Challenge of Privacy

Imagine you are the CEO of a pharmaceutical company. You want to bring your company into the modern age and exploit machine learning to refine its products. You want to use all the data generated by wearables, such as smartwatches and smartphones, to understand your patients better and speed up the development of new products. You ask your digital transformation (DX) team to create an app that will gather all the essential data from wearables and upload it to your company's servers. That is when you learn about privacy and security issues.

Centralizing private data is a perilous exercise, and collecting personal health data complicates it even further. Depending on the country, different legislation exists to protect patients from companies seeking to exploit their personal data, which complicates the app your company wants to create. Another potential problem is guaranteeing the security of the collected data: centralized storage is practical, but it also increases the risk of hacking.

Fortunately for you, your DX team has machine learning engineers who follow the latest technologies, and they propose another approach: instead of collecting the data to train a model, the model is trained on each patient’s device and only the changes to the model are collected. These changes are sent back to the company and incorporated into the model stored on the company servers. None of the private data is centralized, yet the model is still trained. This approach is called federated learning, and it belongs to a family of privacy-preserving machine learning algorithms.

What is Federated Learning?

Generally, machine learning and data science are performed on data aggregated from multiple sources at a central server (figure 1). This poses both security and privacy issues: hackers only need to breach a single centralized database to access the entire dataset, and system administrators have direct access to all the data. Such security and privacy issues are among the main reasons certain industries, such as healthcare and financial services, are reluctant to adopt machine learning and data science.


Figure 1: Non-federated machine learning architecture. Data from all clients are aggregated at a central server to directly train a global machine learning model.

One viable solution to the central database problem is federated learning [1]. Federated learning is a privacy-preserving machine learning framework implemented across a loose federation of clients managed by a central server (figure 2). Each client holds a local dataset which is not shared with other clients or the central server. The central server manages the global model. During model training, each individual client calculates its local update to the current global model and shares its updated model parameters with the central server. In contrast with non-federated learning algorithms, the individual datasets never leave the clients. This method therefore decouples model training from direct access to the original data, and can significantly reduce both privacy and security risks.


Figure 2: Federated learning architecture. Data from each client is used to calculate individual updates to the global model, which are aggregated at the central server.

Although federated learning was originally developed around the TensorFlow library and neural networks [1], the underlying framework is model agnostic, and can be applied to other machine learning models such as logistic regression [2] or random forests [3]. We expect to see a proliferation of federated learning libraries and packages in the near future as the technology matures.

Implementing Federated Learning

In this section, we present the implementation of the federated averaging algorithm proposed in [1]. Federated averaging is a concrete realization of federated learning which has been implemented in libraries such as TensorFlow Federated. We will not go through the full derivation of the federated averaging algorithm; interested readers should refer to the original paper for the full details [1]. Figure 3 shows the process flow of one iteration of the federated averaging algorithm, which is described below.


Figure 3: One iteration of the federated averaging algorithm. Global model weights 𝑤 are broadcast to the local clients. Each local client uses its data to calculate local updates to 𝑤. These updates are then sent back to and averaged at the central server.

Due to the typographical limitations of note.com, we are unable to write certain terms in proper mathematical notation within the text. While we have done our best to use Unicode characters for subscripts, there is no well-supported Unicode character for subscript 𝑘, so it is written as "_𝑘" instead.

As with non-federated learning problems, we want to find the set of global model weights w that minimizes the global loss function 𝑓(𝑤) for all 𝑛 data points

f(w) = \frac{1}{n} \sum_{i=1}^{n} f_i(w) \qquad (1)

where 𝑓ᵢ(𝑤) is the loss of the model with weights 𝑤 on the 𝑖th data point, and 𝑥ᵢ and 𝑦ᵢ are the features and targets of the 𝑖th data point respectively. For non-federated learning problems, all 𝑛 data points are located in a centralized location, and can therefore be summed over directly in equation (1). However, in federated learning, the sum over all 𝑛 data points cannot be carried out directly as the data is spread out over 𝐾 individual clients. Therefore, we first calculate the value of the local loss function 𝑓_𝑘(𝑤) for the 𝑛_𝑘 data points on client 𝑘, where "_𝑘" denotes subscript 𝑘

f_k(w) = \frac{1}{n_k} \sum_{i \in P_k} f_i(w) \qquad (2)

where 𝑃_𝑘 is the set of indices for the data on client 𝑘. The value of the global loss function 𝑓(𝑤) is then calculated by taking the weighted average of all 𝐾 local loss functions 𝑓_𝑘(𝑤)

f(w) = \sum_{k=1}^{K} \frac{n_k}{n} f_k(w) \qquad (3)

Therefore, the sum over all 𝑛 data points on the centralized server in equation (1) is replaced by an inner sum over all 𝑛_𝑘 data points on client 𝑘, and an outer sum over all 𝐾 clients in equation (3).
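As a quick sanity check of equations (1) to (3), the short NumPy sketch below uses made-up per-sample losses and an arbitrary split of the data point indices across three clients, and verifies that the weighted average of the local losses in equation (3) reproduces the global loss in equation (1).

import numpy as np

# Toy example: per-sample losses f_i(w) for n = 10 data points at some fixed w
f_i = np.array([0.2, 0.5, 0.1, 0.9, 0.3, 0.7, 0.4, 0.6, 0.8, 0.05])
n = len(f_i)

# Arbitrary split of the data point indices across K = 3 clients (the sets P_k)
P = [np.array([0, 1, 2]), np.array([3, 4, 5, 6]), np.array([7, 8, 9])]

# Equation (1): global loss as the average over all n data points
f_global = f_i.mean()

# Equations (2) and (3): local losses f_k(w), then their weighted average
f_k = np.array([f_i[P_k].mean() for P_k in P])   # equation (2)
n_k = np.array([len(P_k) for P_k in P])
f_weighted = np.sum(n_k / n * f_k)               # equation (3)

print(f_global, f_weighted)  # identical up to floating point error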

The 𝑛_𝑘 data points for client 𝑘 are split into several smaller batches with size 𝑩, and stochastic gradient descent [4] is used to calculate the local updates 𝑤_𝑘 to the global model weights 𝑤 for all 𝑩 data points in each batch

w_k \leftarrow w_k - \eta \nabla f_k(w_k) \qquad (4)

where η is the local stochastic gradient descent learning rate, the gradient is evaluated on the current batch of 𝑩 data points, and 𝑤_𝑘 is initialized to the current global model weights 𝑤. The operation in equation (4) is usually performed over multiple local training epochs 𝐸 to ensure convergence of the numerical solution. The local updates 𝑤_𝑘 are then aggregated at the central server, and the weighted average of 𝑤_𝑘 is taken as the update to the global model weights 𝑤

w \leftarrow \sum_{k=1}^{K} \frac{n_k}{n} w_k \qquad (5)

The set of operations outlined in equations (1) to (5) above is usually performed over several federated learning rounds to ensure that the global model weights converge to the correct numerical solution.

In practice, a real ecosystem might contain thousands of clients, and using all of them during the training process is not feasible. Therefore, for the sake of efficiency, instead of using all the clients, only a fraction 𝐶 of the total number of clients is randomly selected for training during each training round.

The operations above are outlined in the algorithm below [1]. For global model weights 𝑤, federated learning rounds 𝑡, 𝐾 clients with index 𝑘, client training fraction 𝐶 ∈ [0, 1], local training batch size 𝑩, number of local training epochs 𝐸 and local stochastic gradient descent learning rate η,

Central server performs:
   Initialize global weights w
   for t = 0, 1, … do:
       m := max(C·K, 1)
       S_t := set of m randomly chosen local clients
       for k ∈ S_t do:
           w_k := client_update(k, w)
       w := Σ_{k∈S_t} w_k·n_k / Σ_{k∈S_t} n_k

client_update(k, w):
   β_k := Split client k data into several batches of size B
   for i = 0, 1, …, E-1 do:
       for batch b ∈ β_k do:
           w := stochastic_gradient_descent(w, b, η)
   return w

stochastic_gradient_descent(w, b, η):
   Randomly shuffle the order of the data in batch b
   for (xᵢ, yᵢ) ∈ b do:
       w := w - η∇f(xᵢ, yᵢ, w)
   return w
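
To make the algorithm concrete, here is a minimal NumPy sketch of federated averaging applied to a toy linear regression problem. The data, model and hyperparameters are all made up for illustration; this is not the implementation used in the experiments below.

import numpy as np

rng = np.random.default_rng(0)

# Synthetic federated data: K clients, each holding a different amount of data
# generated from the same underlying linear model y = x·w_true + noise.
K, d = 20, 5
w_true = rng.normal(size=d)
clients = []
for _ in range(K):
    n_k = rng.integers(20, 200)
    X = rng.normal(size=(n_k, d))
    y = X @ w_true + 0.1 * rng.normal(size=n_k)
    clients.append((X, y))

def client_update(w, X, y, epochs=5, batch_size=10, lr=0.01):
    """Local SGD on one client's data, starting from the global weights w."""
    w = w.copy()
    n_k = len(y)
    for _ in range(epochs):
        order = rng.permutation(n_k)
        for start in range(0, n_k, batch_size):
            idx = order[start:start + batch_size]
            Xb, yb = X[idx], y[idx]
            grad = 2 * Xb.T @ (Xb @ w - yb) / len(yb)  # gradient of the squared error
            w -= lr * grad
    return w

# Federated averaging: a fraction C = 0.25 of the clients participates each round.
w = np.zeros(d)
C, rounds = 0.25, 50
for t in range(rounds):
    m = max(int(C * K), 1)
    selected = rng.choice(K, size=m, replace=False)
    updates, sizes = [], []
    for k in selected:
        X, y = clients[k]
        updates.append(client_update(w, X, y))
        sizes.append(len(y))
    sizes = np.array(sizes)
    # Weighted average of the local updates (equation (5)), restricted to the
    # clients selected in this round.
    w = np.sum([n_k / sizes.sum() * w_k for n_k, w_k in zip(sizes, updates)], axis=0)

print("distance to w_true:", np.linalg.norm(w - w_true))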

Federated Learning vs Non-Federated Learning

Now that the theory is out of the way, let’s look at a simple demonstration of federated learning based on the TensorFlow Federated tutorial “Federated Learning for Image Classification” [5], using the TensorFlow Federated version of the extended MNIST handwritten digits dataset [6]. Compared to the original tutorial, we changed the original multilayer perceptron model to a convolutional neural network, which is more suitable for analyzing handwriting images. We also implemented classification functionality, not present in the original tutorial, in order to test the model’s performance.

The dataset contains handwritten digits from 0 to 9 written by 3383 different individuals. A sample of the dataset is shown in figure 4. The dataset is heavily imbalanced: some individuals have only 10 handwriting submissions, while others have more than 100, as shown in figure 5. This is representative of the datasets encountered in actual federated learning problems, where data will not be distributed equally amongst all the clients.


Figure 4: Sample of the MNIST handwritten dataset. Each image is accompanied by a target label at the top of the image.

For the demonstration, we split the entire dataset into 3383 individual clients according to their unique ID values. For each training iteration, 20 of the 3383 clients are randomly selected for training, and training is performed for a total of 50 iterations. This represents the real-life scenario where only a subset of all the clients will be available for training at any given time.
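
For reference, the sketch below shows how the federated dataset and the per-round client sampling might be set up with TensorFlow Federated, following the approach of the tutorial [5]. The exact API details may differ between TensorFlow Federated versions, so treat this as an approximate sketch rather than our exact code.

import numpy as np
import tensorflow_federated as tff

# Load the federated version of the extended MNIST digits dataset.
# Each client corresponds to one writer, identified by a unique client ID.
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data(only_digits=True)
print(len(emnist_train.client_ids))  # 3383 writers/clients

# Randomly select 20 of the 3383 clients for one federated training round.
rng = np.random.default_rng(0)
sampled_ids = rng.choice(emnist_train.client_ids, size=20, replace=False)
sampled_datasets = [
    emnist_train.create_tf_dataset_for_client(str(cid)) for cid in sampled_ids
]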


Figure 5: Distribution of the number of train and test handwriting samples per individual in the MNIST data. The amount of data per individual is not the same.

Instead of using the neural network presented in the original TensorFlow Federated tutorial, we use a convolutional neural network (CNN) with the architecture shown in figure 6. CNNs are better suited than multilayer perceptrons for modelling the image data found in the MNIST dataset, and should converge to the correct solution more quickly. This matters in federated learning, where transferring the model parameters between the clients and the central server during training can consume significant bandwidth.


Figure 6: Convolutional neural network architecture.
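
For illustration, a small Keras CNN along these lines is sketched below. The exact layer configuration shown in figure 6 is not reproduced here, so the sizes should be read as an assumed example rather than the precise architecture we used.

import tensorflow as tf

def create_cnn():
    # Illustrative CNN for 28x28 grayscale handwritten digit images with 10 classes.
    return tf.keras.Sequential([
        tf.keras.layers.Input(shape=(28, 28, 1)),
        tf.keras.layers.Conv2D(32, kernel_size=3, activation='relu'),
        tf.keras.layers.MaxPooling2D(pool_size=2),
        tf.keras.layers.Conv2D(64, kernel_size=3, activation='relu'),
        tf.keras.layers.MaxPooling2D(pool_size=2),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dense(10, activation='softmax'),
    ])

In the TensorFlow Federated framework, a Keras model built this way is wrapped with the library's Keras model wrapper before being passed to the federated averaging process, as described in the tutorial [5].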

In addition to the federated learning model, we also use the same CNN architecture trained in the non-federated framework in order to investigate the differences in performance between federated and non-federated learning models.

Figure 7 shows the accuracy and loss over all 50 training iterations for both the federated and non-federated neural networks. The federated learning model converged to an accuracy of about 0.9 by the 30th training iteration. A total of 847 out of the 3383 unique individuals were randomly selected during the training process, corresponding to 85884 out of 341873 handwriting submissions (~25%) in the training dataset actually being used for federated model training.


Figure 7: Top: federated vs. non-federated learning sparse categorical accuracy. Bottom: federated vs. non-federated learning sparse categorical cross entropy.

With non-federated learning models, all the data is aggregated on a single server and used to train the model. It is therefore to be expected that convergence is attained more quickly, and that accuracy is higher than in the federated learning version, as the model has seen all the data at least once. However, figure 7 shows that the gap is small: accuracy increases only from about 0.94 for the federated learning model to about 0.99 for the standard neural network. Furthermore, the federated model achieved its accuracy despite seeing only about 25% of the full dataset. Both models can therefore be said to have comparable performance.

Figure 8 shows some predictions by the federated learning model for a randomly chosen client. Each individual handwritten digit image in figure 8 is accompanied by the true value on the top left and the predicted value on the top right. In general, the predictions are accurate, although some misclassifications were made, such as a 7 being misclassified as a 9.


Figure 8: Test set of handwritten digits. The true label is on the top left, and the predicted label is on the top right of each handwritten digit image.

Conclusion

We have explored the fundamentals behind federated learning using a simple CNN demonstration. We built a handwriting classification model with good performance using the federated learning framework and showed that, with sufficient training, federated learning models can achieve performance comparable to non-federated learning models. Although most federated learning libraries and packages are currently still in the experimental stages, we expect that federated learning will become widely used by the machine learning community, especially in the financial and healthcare sectors, in the coming years.

References

[1] Communication-Efficient Learning of Deep Networks from Decentralized Data, H. Brendan McMahan, Eider Moore, Daniel Ramage, Seth Hampson and Blaise Agüera y Arcas. Proceedings of the 20th International Conference on Artificial Intelligence and Statistics (AISTATS) 2017. JMLR: W&CP volume 54. arXiv:1602.05629v3.
[2] Parallel Distributed Logistic Regression for Vertical Federated Learning without Third-Party Coordinator, Shengwen Yang, Bing Ren, Xuhui Zhou and Liping Liu. https://arxiv.org/abs/1911.09824
[3] Federated Forest, Yang Liu, Yingting Liu, Zhijie Liu, Junbo Zhang, Chuishi Meng and Yu Zheng. https://arxiv.org/abs/1905.10053
[4] https://en.wikipedia.org/wiki/Stochastic_gradient_descent
[5] https://www.tensorflow.org/federated/tutorials/federated_learning_for_image_classification
[6] http://yann.lecun.com/exdb/mnist/