Data, combined with the power of machine learning (ML), enables companies to make more informed decisions, optimize their workflows, and gain an edge in a competitive market. However, handling data has become increasingly difficult. Today, firms face the challenge of adapting to ever-changing regulatory and cybersecurity requirements while ensuring the privacy of data owners. Because of these challenges, many industries still have limited access to the cutting-edge data technologies of the 21st century.
So what can we do about it? Well, if we cannot move the data to the ML models, how about moving the ML models to the data? This is exactly what federated learning aims to achieve.
Federated learning is all about moving machine learning operations from the cloud to edge devices and interacting with the data locally. The two basic components of a federated learning setup are a server and multiple clients. Below you can see the fundamental steps of a federated learning system:
As shown in the graphic (author's work), we start by initializing a global model. We then send this model to the individual clients, who train it on their local data. The clients then return their updates to the server, where they are aggregated. At this point, one iteration is complete and we transition back to the first step: the server sends the updated global model back to the clients and the whole process starts over.
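To make the cycle concrete, here is a minimal sketch of one federated round in plain Python. Everything in it is hypothetical: the "model" is a single weight, `local_train` is a toy stand-in for real gradient descent, and aggregation is a plain unweighted average.

```python
def local_train(weights, data, lr=0.1):
    """Toy 'training': nudge each weight toward the mean of the client's
    local data. A stand-in for real gradient descent on local data."""
    target = sum(data) / len(data)
    return [w + lr * (target - w) for w in weights]

def federated_round(global_weights, client_datasets):
    """One iteration: broadcast the model, train locally, aggregate."""
    # 1. Server sends the current global model to every client,
    #    and each client trains a private copy on its own data.
    updates = [local_train(list(global_weights), data) for data in client_datasets]
    # 2. Clients return their updates; the server averages them.
    n = len(updates)
    return [sum(u[i] for u in updates) / n for i in range(len(global_weights))]

# Hypothetical setup: three clients, each holding private local data.
clients = [[1.0, 2.0], [3.0], [2.0, 2.0, 2.0]]
weights = [0.0]            # initial global model
for _ in range(5):         # repeat the broadcast/train/aggregate cycle
    weights = federated_round(weights, clients)
```

With each round, the global weight drifts toward a consensus over all clients' data, even though the raw data never leaves the clients.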
That's the basic concept of federated learning! As you might guess, there is more to each step in the diagram than meets the eye. The big concepts include strategy, privacy-enhancing techniques (PETs), and secure multiparty computation (SMPC). The goal of this article series, which consists of two blog posts, is to provide a fundamental overview of these topics and how each of them relates to federated learning.
Strategy for federated learning
What does "strategy" mean in the context of federated learning? Strategy is how we decide to train the model and aggregate the weights; the weights are what the model learns during training. Consider client selection and assume we have 1,000 clients. Should we pick all of them when we start training? More clients means more data, but is more data always better for ML? Research suggests that using all clients can actually lead to slower convergence, meaning the model learns more slowly than if we had used fewer clients with high-quality datasets and less communication overhead (Németh et al. 4). Keep in mind that this is what empirical research suggests, not a hard rule: the "best" strategy is heavily context dependent, so all the ideas mentioned here should be taken with a grain of salt.

So how do we choose the clients to include in our update step? Clients differ in their communication and computational capabilities as well as in the data points they can access. Research has suggested many selection strategies, some prioritizing clients with unique data points (Németh et al. 6), others choosing clients in the most energy-efficient way possible (Németh et al. 5).
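Two common ingredients of such selection schemes can be sketched in a few lines. This is an illustrative sketch only: the `fraction` value, the capability fields, and the thresholds are hypothetical, not taken from any particular paper.

```python
import random

def sample_clients(client_ids, fraction=0.1, seed=None):
    """Uniformly sample a fraction of all registered clients for one round."""
    rng = random.Random(seed)
    k = max(1, int(len(client_ids) * fraction))
    return rng.sample(client_ids, k)

def filter_capable(clients, min_bandwidth, min_samples):
    """A stand-in for capability-aware selection: only consider clients
    that meet minimum bandwidth and local-data requirements."""
    return [c for c in clients
            if c["bandwidth"] >= min_bandwidth and c["num_samples"] >= min_samples]

# Instead of using all 1,000 clients, sample 10% per round.
pool = list(range(1000))
selected = sample_clients(pool, fraction=0.1, seed=42)
```

Real strategies combine such filters with smarter scoring (data uniqueness, energy cost), but the skeleton of "filter, then sample a fraction per round" is the same.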
Now let's look at it from the server (aggregator) side. The number of data points available to each client varies, and the trained weights therefore have different "experience levels": a model trained on 1,000 data points is at a different stage than one that has only seen 100. To account for this, the FedAvg algorithm, for example, averages the client updates weighted by the number of examples each client trained on, so each update's influence on the global model is proportional to the data behind it.
In the diagram above (Németh et al. 5), you can see the broad range of ideas that researchers have proposed over the years.
Proof of Concept: Predictive Maintenance
Now, let's take a deeper dive in an industrial context. I would like to present a proof of concept I worked on to give you a glimpse of how federated learning works in practice. The proof of concept is about predictive maintenance, and it's a great example of how federated learning can tackle key industrial challenges by enabling greater use of ML.
Importance of ML-supported predictive maintenance
Machines are at the heart of industry, and their downtime comes with considerable costs. Predictive maintenance refers to predicting when machines will fail and performing maintenance before they do. According to a study by McKinsey & Company, with the help of artificial intelligence, "availability can sometimes increase by more than 20%. Inspection costs may be reduced by up to 25% and an overall reduction of up to 10% of annual maintenance costs is possible" (McKinsey & Company, Inc. 8).
However, the elephant in the room is the limited availability of data to individual factory owners. How often do modern machines actually fail? Perhaps once or twice a month? And how often do these failures share the same cause? With only a small number of machines available to each organization, it is hard to collect enough quality failure data.
In a traditional ML environment, after acknowledging the lack of sufficient data, we could look for similar data outside our organization. Understandably, however, factory owners are hesitant to share their data with external organizations. That is a tricky scenario for conventional ML, but not so much for federated learning, which allows a large number of machines to contribute to one larger central ML model while preserving privacy.
We chose the Flower framework because it is very beginner friendly and has an active community that is ready to help with discussions and questions. We used the "Machine Predictive Maintenance Classification" dataset from Kaggle. It is a synthetic dataset and therefore perfect for our proof of concept, since not much preprocessing is involved. The dataset was partitioned into smaller pieces so that each client had a unique subset available to train the model locally. Let's take a closer look at the server and the client side.
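The partitioning step can be done in many ways; one simple scheme that guarantees disjoint, roughly equal shards is round-robin assignment. This is a generic sketch, not the exact code from the proof of concept, and the integer "rows" merely stand in for rows of the Kaggle CSV.

```python
def partition(rows, num_clients):
    """Split a dataset into disjoint, roughly equal-sized shards, so each
    simulated client trains on a unique subset of the data."""
    shards = [[] for _ in range(num_clients)]
    for i, row in enumerate(rows):
        shards[i % num_clients].append(row)   # round-robin assignment
    return shards

# Hypothetical stand-in for the dataset: 10,000 numbered rows, 10 clients.
dataset = list(range(10_000))
clients_data = partition(dataset, num_clients=10)
```

In a real deployment there is no partitioning step at all: each factory already holds its own data. Splitting a public dataset like this simply simulates that situation on one machine.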
On the server side, you first define and compile your model as you normally would. As our strategy, we chose FedAvg; with similarly sized client datasets, its example-weighted average gives each client roughly equal influence on the global model when the updates are aggregated. It is not the most sophisticated algorithm, but its simplicity makes it a good fit for our proof of concept. For the initial parameters, we use random values. In a real-world scenario, however, you could ask one client device to supply the initial weights by running local training for that client alone. This provides more realistic starting weights, which can lead to faster convergence.
The image (author's work) gives you an idea of what the server outputs and the steps it goes through. As you can see, after initialization the server follows the steps described above: each round, it samples a group of clients, uses them for training (fit_rounds), then aggregates their updates and provides an evaluation for that round.
On the client side, we use the same model architecture. You might wonder why. The reason is that it would be impossible to aggregate the weights without knowing which model architecture they belong to. Think of the weights as materials for constructing a building: without the blueprint (the model architecture), placing the materials in the correct positions would be impossible.
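This "same blueprint" requirement boils down to a simple invariant: every client must report weight tensors with identical layer shapes. A minimal, hypothetical check might look like this (the layer shapes below are made up):

```python
def check_compatible(client_shapes):
    """Aggregation only makes sense if every client reports weights with
    identical layer shapes, i.e. the same model architecture."""
    reference = client_shapes[0]
    return all(shapes == reference for shapes in client_shapes[1:])

# Two clients sharing one architecture, a third with a wider hidden layer.
a = [(4, 16), (16, 1)]
b = [(4, 16), (16, 1)]
c = [(4, 32), (32, 1)]
compatible = check_compatible([a, b])        # same blueprint: aggregatable
mismatch = check_compatible([a, b, c])       # client c breaks aggregation
```

Frameworks effectively enforce this for you, since mismatched shapes make the element-wise averaging of weights impossible.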
In the image (author's work), you can see that the client first establishes a connection with the server. It then trains the model on its local data and reports how the model performs on that data.
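Conceptually, a client only needs to expose three operations: hand out its parameters, train on local data, and evaluate on local data. The class below mirrors that interface in plain Python as an illustration; it is a toy stand-in with a single weight and made-up training logic, not Flower's actual client class.

```python
class Client:
    """Minimal client sketch: fit() trains locally, evaluate() reports
    local performance. The 'model' is a single weight for simplicity."""

    def __init__(self, data):
        self.data = data   # private local data, never sent to the server
        self.w = 0.0

    def get_parameters(self):
        return [self.w]

    def fit(self, parameters):
        """Receive global weights, run toy local training, return update."""
        self.w = parameters[0]
        for x in self.data:
            self.w += 0.1 * (x - self.w)   # move toward each local sample
        return [self.w], len(self.data)

    def evaluate(self, parameters):
        """Mean squared error of the given model on the local data."""
        mse = sum((x - parameters[0]) ** 2 for x in self.data) / len(self.data)
        return mse, len(self.data)

c = Client([1.0, 2.0, 3.0])
params, n = c.fit([0.0])       # train starting from the global weights
loss, n_eval = c.evaluate(params)
```

The counts returned alongside the update and the loss matter: they are exactly what the server's weighted aggregation uses.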
Federated learning is a fascinating concept, interesting not only from a technical perspective but also from a business perspective. Training models directly on users' distributed data sources makes it possible to extend data-intensive ML applications into areas that were previously out of reach due to privacy concerns or limited data access. In the next blog post, we will dive into the advanced techniques and the problems they try to solve.
Németh, Gergely Dániel, et al. "A Snapshot of the Frontiers of Client Selection in Federated Learning." Transactions on Machine Learning Research, 2022, https://openreview.net/forum?id=vwOKBldzFu.
McKinsey & Company, Inc. Smartening up with Artificial Intelligence (AI) - What's in it for Germany and its Industrial Sector? Digital McKinsey, 2017, www.mckinsey.com/~/media/mckinsey/industries/semiconductors/our%20insights/smartening%20up%20with%20artificial%20intelligence/smartening-up-with-artificial-intelligence.ashx. Accessed 24 July 2023.