
Federated Learning – The AI Revolution Protecting Privacy

Federated Learning (FL), introduced in 2016 by a research team at Google, offers a groundbreaking approach to training artificial intelligence models using distributed data—without compromising user privacy. Unlike traditional methods where personal data is collected and centralized on servers, FL allows each device to train the model locally on its own data. Only the trained parameters (like weights and biases) are sent back to a central server, where they’re aggregated into a global model. This method not only ensures strong privacy protection but also makes efficient use of the computational power already available on millions of devices.

Let’s dive deeper into how this cutting-edge technology works in the following sections!

I. Introduction

Federated Learning emerged around 2016 in response to the urgent need for training artificial intelligence and machine learning models on distributed data—while still preserving user privacy.

Before the advent of Federated Learning, most machine learning systems relied on collecting data from millions of devices (like smartphones, IoT gadgets, etc.) and transferring it to centralized data centers for model training. However, this approach posed two major challenges:

  • Privacy concerns: Users have become increasingly worried about their personal data being collected, misused, or leaked by companies and online services. New data protection regulations—such as the GDPR in Europe—have compelled organizations to rethink how they handle user data.
  • Scalability and latency issues: Continuously transferring large volumes of data from millions of devices to a central server consumes significant bandwidth and introduces latency. Moreover, modern end-user devices (like smartphones and smartwatches) are now powerful enough to handle local training tasks on their own.

To tackle these challenges, a research team at Google proposed a novel approach called Federated Learning (FL). Instead of transferring data from user devices to a central server, FL enables each device to train an AI model locally using its own private data. The device then sends only the trained model parameters (such as weights and biases) to a central server, which aggregates them (often by averaging) to form a global model. This eliminates the need to transmit raw data, thereby enhancing privacy, reducing bandwidth consumption, and leveraging the computing power already available on edge devices.

Real-world applications of Federated Learning:

  • Finance: Data-residency and protection rules (such as the GDPR in the EU) can require that customer data remain within specific jurisdictions (e.g., EU customer data stays in the EU). FL allows each region to keep its data local while still contributing to the training of a unified model across the distributed datasets.
  • Mobile input prediction: Google uses FL in its Gboard predictive keyboard. Each device updates its local model based on user interactions, and only the trained model parameters—not the raw typing data—are shared to refine the global model.

II. Theoretical Foundations

1. Definition

“In a Federated Learning system, multiple parties collaborate to train machine learning models without sharing their raw data. The outcome is a trained model for each party (which may be identical or personalized). A practical Federated Learning system is expected to produce models that outperform individually trained local models on a given evaluation metric, such as test accuracy, while using the same architecture.”

From this definition, we can draw two key insights:

  • First, each party involved in training does not exchange raw data (which may contain sensitive information). Instead, only updates—such as gradients or model weights—are sent to a central server.
  • Second, after the collaborative training process, each participant receives a trained machine learning model. These models may be identical or customized for each party, but in all cases, their performance should surpass what could be achieved by training solely on local data.

2. Core Components

  • Server: The server orchestrates the training process: it selects clients for each communication round, collects model updates, and aggregates them to update the global model. Crucially, the server never accesses any raw training data.
  • Clients: Clients are distributed entities—such as hospitals or user devices—that:
    • Train models locally on their own private data.
    • Send the locally trained updates (e.g., model weights or gradients) back to the server for aggregation.

3. Algorithm

a) Overview

In general, Federated Learning (FL) algorithms follow a standard set of steps:

  • The server initializes a global model.
  • For each communication round:
    • The server sends the current global model to a selected group of clients.
    • Each client receives the model and performs local training on its private dataset for K epochs.
    • Clients then send their locally updated models back to the server.
    • The server aggregates these local updates using a specific aggregation algorithm.
      • For instance, in the popular FedAvg (Federated Averaging) algorithm, the server computes a weighted average of the model parameters (e.g., weights and biases) received from the clients to produce the updated global model.

b) The FedAvg Algorithm

To better understand the mathematical formulation behind these algorithms, readers are encouraged to refer to the paper "Communication-Efficient Learning of Deep Networks from Decentralized Data". This foundational work introduces the FedAvg algorithm and lays the theoretical groundwork for modern Federated Learning systems.
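
In brief, the aggregation rule at the heart of FedAvg can be sketched as follows (using the paper's notation, where n_k is the number of training samples on client k, n = \sum_k n_k the total across participating clients, and w_{t+1}^k the parameters of client k after its local training in round t):

w_{t+1} \;=\; \sum_{k} \frac{n_k}{n} \, w_{t+1}^{k}

In other words, the new global model is a weighted average of the client models, with each client weighted by its share of the total training data.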

III. Illustrative Example

To demonstrate how Federated Learning works in practice, let’s walk through a hands-on example. We'll simulate non-IID data (which we'll explain later) using a sine wave function with added noise.

# Required imports (PyTorch and matplotlib)
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

# Function to generate noisy sine wave data
# Used to simulate non-IID data distribution
def generate_data(n_samples=100, start=0, end=1):
    x = torch.rand(n_samples, 1) * (end - start) + start
    y = torch.sin(x) + torch.rand(n_samples, 1) * 0.5
    return x, y

Next, we generate non-IID datasets for 10 clients, each covering a distinct input range:

# Generate data for 10 clients with different distributions
num_clients = 10
data_clients = [generate_data(50, start=(i - num_clients / 2), end=(i + 1.5 - num_clients / 2)) for i in range(num_clients)]

# Visualize the data distribution for each client using different colors
for i, (X, y) in enumerate(data_clients):
    plt.scatter(X, y, alpha=0.5)
plt.show()

We then define a simple neural network architecture to solve this regression task:

# Define a simple neural network for regression
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.fc1 = nn.Linear(1, 64)   # First layer: input → 64 nodes
        self.relu = nn.ReLU()         # ReLU activation
        self.fc2 = nn.Linear(64, 32)  # Second hidden layer
        self.fc3 = nn.Linear(32, 16)  # Third hidden layer
        self.fc4 = nn.Linear(16, 1)   # Output layer: 1 regression output

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        x = self.relu(x)
        x = self.fc4(x)
        return x

Now, let’s configure the training parameters for our federated learning simulation:

# Training configuration
num_rounds = 1000               # Total number of federated rounds
local_epochs = 5                # Number of epochs per client per round
learning_rate = 0.01            # Learning rate
global_model = NeuralNetwork()  # Initialize the global model

Define the local training function used by each client:

def train_local_model(model, X, y, epochs=10):
    criterion = nn.MSELoss()                                     # Mean Squared Error loss for regression
    optimizer = optim.SGD(model.parameters(), lr=learning_rate)  # SGD optimizer

    for epoch in range(epochs):
        optimizer.zero_grad()        # Reset gradients
        y_pred = model(X)            # Forward pass
        loss = criterion(y_pred, y)  # Compute loss
        loss.backward()              # Backward pass
        optimizer.step()             # Update weights

    return model.state_dict()  # Return the trained model parameters

Now let’s run the training process and update the global model using the Federated Averaging (FedAvg) method:

# Federated Learning training loop
for round in range(num_rounds):
    local_params = []

    # Each client trains their local model
    for i, (X, y) in enumerate(data_clients):
        local_model = NeuralNetwork()
        local_model.load_state_dict(global_model.state_dict())  # Sync local model with the global model
        local_params.append(train_local_model(local_model, X, y, epochs=local_epochs))

    # Aggregate local models to update the global model (FedAvg)
    global_params = {}
    for k in local_params[0].keys():
        global_params[k] = torch.stack([local_params[i][k] for i in range(num_clients)]).mean(0)

    global_model.load_state_dict(global_params)

    # Print loss every 100 rounds
    if round % 100 == 0:
        loss = 0
        for i, (X, y) in enumerate(data_clients):
            y_pred = global_model(X)
            loss += nn.MSELoss()(y_pred, y)
        print(f'Round {round}, Loss: {loss.item() / num_clients}')
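
One detail worth noting: the aggregation above takes a plain (unweighted) mean of the client parameters. Because every client in this simulation holds the same number of samples (50 each), this coincides with FedAvg's weighted average; with unequal local dataset sizes, each client's parameters should instead be weighted by n_k / n, as in the formula shown in Section II.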

Finally, let’s evaluate and visualize how well the global model approximates the sine function on a test dataset:

# Evaluate the global model

# Set model to evaluation mode
global_model.eval()

# Create test dataset to evaluate predictions
X_test = torch.linspace(-5, 5, 200).unsqueeze(1)
with torch.no_grad():
    y_pred = global_model(X_test)

y_true = torch.sin(X_test)  # Ground truth sine values

# Plot client data, the ground truth, and the global model's predictions
for i, (X, y) in enumerate(data_clients):
    plt.scatter(X, y, alpha=0.5)

plt.plot(X_test.numpy(), y_true.numpy(), label="True sin(x)", color="black", linewidth=1)
plt.plot(X_test.numpy(), y_pred.numpy(), label="Global Model Prediction", color="red", linewidth=2)
plt.xlabel("x")
plt.ylabel("sin(x)")
plt.title("Global Model Approximation of sin(x)")
plt.legend()
plt.show()

IV. The Non-IID Problem

In Federated Learning (FL), the term non-IID stands for Non-Independent and Identically Distributed data. It refers to situations where the data on each client device does not follow the same statistical distribution. In other words, the data available on one device might not be representative of the overall data distribution across all clients.

For example, consider training an image classification model for animals. A dog lover’s phone might mostly contain dog pictures, while a cat owner’s phone may have mostly cat pictures. Each client has its own local bias, which creates data heterogeneity across the system.

Plotting each client's local data (one color per client) makes this clear: the individual distributions vary widely, and none of them captures the full diversity of the overall dataset.

To better understand this, let’s consider training a digit recognition model using the MNIST dataset. If we sort the dataset by label (e.g., all "1" images first, then "2", and so on), and then train the model in batches without shuffling, each batch would contain only one digit. These batches wouldn’t represent the overall distribution, leading to biased gradients that can steer the optimization away from the global minimum. This reduces training efficiency and convergence quality.
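
For instance, here is a minimal sketch of how such a pathological split can be constructed in code (assuming PyTorch and torchvision are available; the 200-shards-of-300 setup mirrors the non-IID partition used in the FedAvg paper's MNIST experiments):

import torch
from torchvision import datasets, transforms

# Load MNIST and sort the training examples by digit label
mnist = datasets.MNIST(root="./data", train=True, download=True,
                       transform=transforms.ToTensor())
sorted_indices = torch.argsort(mnist.targets)

# Split the label-sorted indices into 200 shards of 300 images each,
# then deal 2 shards to each of 100 clients
n_clients = 100
shards_per_client = 2
n_shards = n_clients * shards_per_client
shard_size = len(mnist) // n_shards
shards = [sorted_indices[i * shard_size:(i + 1) * shard_size] for i in range(n_shards)]

perm = torch.randperm(n_shards).tolist()  # shuffle shard order before dealing
client_indices = [
    torch.cat([shards[perm[i * shards_per_client + j]] for j in range(shards_per_client)])
    for i in range(n_clients)
]

# Each client's local dataset is now highly label-skewed
print(mnist.targets[client_indices[0]].unique())  # typically only 1-2 distinct digits

Mini-batches drawn from shards like these are exactly the unrepresentative, sorted-by-label batches described above.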

Non-IID data presents one of the key challenges in Federated Learning, especially when trying to train a robust global model across highly diverse and personalized datasets. To gain a deeper understanding of this topic, readers are encouraged to consult the paper "Communication-Efficient Learning of Deep Networks from Decentralized Data".

Examining FedAvg in Detail: Two Scenarios

Case 1: K = 1 (FedSGD)
  • Full-batch:
    • Theoretically, if each client computes the gradient over its entire local dataset and the server aggregates the gradients proportionally to each client's data size, the resulting update is identical to that of centralized training, assuming no data overlap across clients (see the formula sketched after this list).
    • However, full-batch training is computationally inefficient, especially when the number of clients is large and each has a sizable local dataset.
  • Mini-batch:
    • In practice, mini-batch training is preferred for its efficiency, but it becomes significantly more sensitive to non-IID data.
    • The main reason is that each mini-batch within a client may be biased relative to the global distribution. When model updates are aggregated after just one step of mini-batch gradient descent (K=1), gradient variance increases due to the heterogeneity. It’s similar to the earlier MNIST example—where sorted-by-label batches lead to poor convergence because they fail to represent the overall data distribution.

Case 2: K > 1 (FedAvg)
  • As the number of local epochs (K) increases per round (e.g., K = 5, 10, ...), each client strays further from the global model before sending updates.
  • If the data across clients is heavily non-IID, this causes client drift—each client’s model increasingly adapts to its own distribution rather than the collective one.
  • When the server aggregates these "drifted" models, the convergence process may oscillate or slow down, and the global model's performance can degrade compared to the IID case.
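
To make the full-batch case concrete, here is a short sketch (with F_k(w) denoting client k's local objective, n_k its sample count, n the total, and f(w) = \sum_k \frac{n_k}{n} F_k(w) the global objective): one round of full-batch FedSGD with learning rate \eta computes

w_{t+1} \;=\; w_t - \eta \sum_{k} \frac{n_k}{n} \nabla F_k(w_t) \;=\; w_t - \eta \nabla f(w_t)

which is exactly one step of centralized gradient descent on f. Once K > 1, each client takes several local steps guided only by its own \nabla F_k before averaging, so the aggregated update is no longer a gradient step on f; the mismatch grows with the heterogeneity of the F_k, which is precisely the client drift described above.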

V. Other Federated Systems

a) Federated Database Systems (FDBSs)

  • As early as the 1990s, FDBSs were studied extensively. An FDBS is a collection of autonomous databases that collaborate for shared benefits.
  • FDBSs are typically characterized by three properties:
    • Autonomy: Each participating database is independently managed and continues to function even without the federation.
    • Heterogeneity: Systems can differ in schema, query languages, software, or communication protocols.
    • Distribution: A single logical record might be horizontally or vertically partitioned across databases, and even replicated for fault tolerance.

b) Federated Cloud Systems (FCSs)

  • Federated Cloud Systems involve deploying and managing services across multiple cloud providers. They help reduce cost by outsourcing parts of workloads to regions where it's more economical.
  • Two main features include:
    • Data migration: Resources can be moved between different cloud vendors as needed.
    • Data redundancy: Services can run simultaneously across multiple regions, with data processed in parallel across providers following a shared computational logic.

Example workflow:

  • A user (e.g., Company A) needs additional servers for an upcoming promotional campaign and submits performance and budget requirements to a Cloud Broker.
  • The Cloud Broker forwards the request to a Cloud Exchange, which maintains a catalog of services from multiple Cloud Coordinators (cloud providers).
  • The Cloud Exchange matches the request, collects quotes and service conditions, and returns the most suitable package to the Cloud Broker.
  • The Cloud Broker finalizes negotiations, signs an agreement with the selected Cloud Coordinator, and deploys the resources for the user.

c) Comparison with Federated Learning

  • Common Ground:
    • All three systems share the underlying principle of collaborative architecture among autonomous entities. They emphasize decentralization, diversity, and independent operation of participants.
  • Key Differences:
    • Federated Database Systems focus on managing distributed data across heterogeneous databases.
    • Federated Cloud Systems concentrate on orchestrating and optimizing resource scheduling across multiple cloud environments.
    • Federated Learning prioritizes secure, privacy-preserving collaborative computation, especially in environments where raw data cannot be shared.

Figure: comparison of the number of research papers on “federated database”, “federated cloud”, and “federated learning” from 1990 to 2020.

References

  • H. B. McMahan, E. Moore, D. Ramage, S. Hampson, and B. Agüera y Arcas, “Communication-Efficient Learning of Deep Networks from Decentralized Data,” AISTATS 2017.