Commit

Merge branch 'main' into Add_contrastive_losses
sanaAyrml authored Nov 1, 2023
2 parents 810a980 + 0479620 commit 75fb81d
Showing 37 changed files with 2,127 additions and 1 deletion.
41 changes: 41 additions & 0 deletions examples/models/moon_cnn.py
@@ -0,0 +1,41 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class HeadCnn(nn.Module):
    """Classification head mapping 256-dimensional features to the 10 MNIST classes."""

    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(256, 10)

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        x = self.fc1(input_tensor)
        return x


class ProjectionCnn(nn.Module):
    """Projection module mapping 120-dimensional base features to a 256-dimensional space."""

    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(120, 256)
        self.fc2 = nn.Linear(256, 256)

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        x = F.relu(self.fc1(input_tensor))
        x = self.fc2(x)
        return x


class BaseCnn(nn.Module):
    """Feature extractor for 1x28x28 MNIST images, producing 120-dimensional features."""

    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pool(F.relu(self.conv1(x)))  # 1x28x28 -> 6x24x24 -> 6x12x12
        x = self.pool(F.relu(self.conv2(x)))  # 6x12x12 -> 16x8x8 -> 16x4x4
        x = x.view(-1, 16 * 4 * 4)
        x = F.relu(self.fc1(x))
        return x
48 changes: 48 additions & 0 deletions examples/moon_example/README.md
@@ -0,0 +1,48 @@
# Moon Federated Learning Example
This example demonstrates training a MOON-type model on a non-IID subset of the MNIST data. The FL server
expects three clients to be spun up (i.e., it will wait until three clients report in before starting training). Each
client has a modified version of the MNIST dataset. This modification subsamples examples of certain digits from the
original training and validation sets of MNIST in order to synthetically induce local variations in the statistical
properties of the clients' training/validation data. In theory, the models should be able to perform well on their
local data while learning from other clients' data that has different statistical properties. The subsampling is
specified by sending a list of integers between 0 and 9 to the clients when they are run with the argument
`--minority_numbers`.
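
Conceptually, the sampler keeps all examples of majority digits and only a fraction of examples of the minority
digits. Below is a minimal sketch of the idea (illustrative only; this is not the repository's
`MinorityLabelBasedSampler` implementation, and the helper name is hypothetical):
```python
from typing import Set

import torch


def subsample_minority_indices(
    labels: torch.Tensor, minority_numbers: Set[int], downsampling_ratio: float
) -> torch.Tensor:
    kept = []
    # Keep a random fraction of the indices for each minority digit.
    for digit in sorted(minority_numbers):
        digit_indices = torch.where(labels == digit)[0]
        n_keep = int(len(digit_indices) * downsampling_ratio)
        kept.append(digit_indices[torch.randperm(len(digit_indices))[:n_keep]])
    # Keep every index whose label is not a minority digit.
    majority_mask = ~torch.isin(labels, torch.tensor(sorted(minority_numbers)))
    kept.append(torch.where(majority_mask)[0])
    return torch.cat(kept)
```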

The server performs custom metrics aggregation and uses Federated Averaging (FedAvg) as its server-side optimization. The implementation uses a special type of weight exchange based on named-layer identification.
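
For reference, FedAvg combines client updates by averaging each layer's weights in proportion to the clients' local
sample counts. A minimal sketch of that aggregation step (illustrative only; the example relies on Flower's built-in
`FedAvg` strategy rather than this code):
```python
from typing import List, Tuple

import numpy as np


def fedavg_aggregate(results: List[Tuple[List[np.ndarray], int]]) -> List[np.ndarray]:
    # Each result pairs one client's layer weights with its number of training examples.
    total_examples = sum(n_examples for _, n_examples in results)
    n_layers = len(results[0][0])
    return [
        sum(weights[layer] * (n_examples / total_examples) for weights, n_examples in results)
        for layer in range(n_layers)
    ]
```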

## Running the Example
In order to run the example, first ensure you have the virtual env of your choice activated and run
```
pip install --upgrade pip
pip install -r requirements.txt
```
to install all of the dependencies for this project.

## Starting Server

The next step is to start the server by running
```
python -m examples.moon_example.server --config_path /path/to/config.yaml
```
from the FL4Health directory. The following arguments must be present in the specified config file:
* `n_clients`: number of clients the server waits for in order to run the FL training
* `local_epochs`: number of epochs each client will train for locally
* `batch_size`: size of the batches each client will train on
* `n_server_rounds`: number of rounds to run FL
* `downsampling_ratio`: amount of downsampling to perform for minority digits

## Starting Clients

Once the server has started and logged "FL starting," the next step, in separate terminals, is to start the three
clients. This is done by simply running (remembering to activate your environment)
```
python -m examples.moon_example.client --dataset_path /path/to/data --minority_numbers <sequence of numbers>
```
**NOTE**: The argument `dataset_path` serves two purposes, depending on whether the dataset already exists locally. If
the dataset already exists at the path specified, it will be loaded from there. Otherwise, the dataset will be
automatically downloaded to the path specified and used in the run.

The argument `minority_numbers` specifies which digits (0-9) in the MNIST dataset the client will subsample to
simulate non-IID data between clients. For example, `--minority_numbers 1 2 3 4 5` ensures that the client
downsamples these digits (using the `downsampling_ratio` specified in the config).
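
For example, a three-client run might assign each client a different set of minority digits, with each of the
commands below run in its own terminal (the dataset path is illustrative):
```
python -m examples.moon_example.client --dataset_path examples/datasets/MNIST --minority_numbers 0 1 2
python -m examples.moon_example.client --dataset_path examples/datasets/MNIST --minority_numbers 3 4 5
python -m examples.moon_example.client --dataset_path examples/datasets/MNIST --minority_numbers 6 7 8
```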

After all three clients have been started, federated learning should commence.
Empty file.
63 changes: 63 additions & 0 deletions examples/moon_example/client.py
@@ -0,0 +1,63 @@
import argparse
from pathlib import Path
from typing import Sequence, Set, Tuple

import flwr as fl
import torch
import torch.nn as nn
from flwr.common.typing import Config
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
from torch.utils.data import DataLoader

from examples.models.moon_cnn import BaseCnn, HeadCnn, ProjectionCnn
from fl4health.clients.moon_client import MoonClient
from fl4health.model_bases.moon_base import MoonModel
from fl4health.utils.load_data import load_mnist_data
from fl4health.utils.metrics import Accuracy, Metric
from fl4health.utils.sampler import MinorityLabelBasedSampler


class MnistMoonClient(MoonClient):
    def __init__(
        self,
        data_path: Path,
        metrics: Sequence[Metric],
        device: torch.device,
        minority_numbers: Set[int],
    ) -> None:
        super().__init__(data_path=data_path, metrics=metrics, device=device)
        self.minority_numbers = minority_numbers

    def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]:
        batch_size = self.narrow_config_type(config, "batch_size", int)
        downsample_percentage = self.narrow_config_type(config, "downsampling_ratio", float)
        sampler = MinorityLabelBasedSampler(list(range(10)), downsample_percentage, self.minority_numbers)
        train_loader, val_loader, _ = load_mnist_data(self.data_path, batch_size, sampler)
        return train_loader, val_loader

    def get_model(self, config: Config) -> nn.Module:
        model: nn.Module = MoonModel(BaseCnn(), HeadCnn(), ProjectionCnn()).to(self.device)
        return model

    def get_optimizer(self, config: Config) -> Optimizer:
        return torch.optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9)

    def get_criterion(self, config: Config) -> _Loss:
        return torch.nn.CrossEntropyLoss()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="FL Client Main")
    parser.add_argument("--dataset_path", action="store", type=str, help="Path to the local dataset")
    parser.add_argument(
        "--minority_numbers", default=[], nargs="*", help="MNIST numbers to be in the minority for the current client"
    )
    args = parser.parse_args()

    DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    data_path = Path(args.dataset_path)
    minority_numbers = {int(number) for number in args.minority_numbers}
    client = MnistMoonClient(data_path, [Accuracy("accuracy")], DEVICE, minority_numbers)
    fl.client.start_numpy_client(server_address="0.0.0.0:8080", client=client)
    client.shutdown()
10 changes: 10 additions & 0 deletions examples/moon_example/config.yaml
@@ -0,0 +1,10 @@
# Parameters that describe server
n_server_rounds: 3 # The number of rounds to run FL

# Parameters that describe clients
n_clients: 3 # The number of clients in the FL experiment
local_epochs: 1 # The number of epochs each client completes locally per round
batch_size: 32 # The batch size for client training

# Downsampling settings per client
downsampling_ratio: 0.1 # Percentage of the original MNIST data to keep for minority numbers
79 changes: 79 additions & 0 deletions examples/moon_example/server.py
@@ -0,0 +1,79 @@
import argparse
from functools import partial
from typing import Any, Dict

import flwr as fl
from flwr.common.parameter import ndarrays_to_parameters
from flwr.common.typing import Config, Parameters
from flwr.server.strategy import FedAvg

from examples.models.moon_cnn import BaseCnn, HeadCnn, ProjectionCnn
from examples.simple_metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn
from fl4health.model_bases.moon_base import MoonModel
from fl4health.utils.config import load_config


def get_initial_model_parameters() -> Parameters:
    # Initializing the model parameters on the server side.
    # Currently uses the PyTorch default initialization for the model parameters.
    initial_model = MoonModel(BaseCnn(), HeadCnn(), ProjectionCnn())
    return ndarrays_to_parameters([val.cpu().numpy() for _, val in initial_model.state_dict().items()])


def fit_config(
    local_epochs: int, batch_size: int, n_server_rounds: int, downsampling_ratio: float, current_round: int
) -> Config:
    return {
        "local_epochs": local_epochs,
        "batch_size": batch_size,
        "n_server_rounds": n_server_rounds,
        "downsampling_ratio": downsampling_ratio,
        "current_server_round": current_round,
    }


def main(config: Dict[str, Any]) -> None:
    # This function will be used to produce a config that is sent to each client to initialize their own environment
    fit_config_fn = partial(
        fit_config,
        config["local_epochs"],
        config["batch_size"],
        config["n_server_rounds"],
        config["downsampling_ratio"],
    )

    # Server performs simple federated averaging (FedAvg) as its server-side optimization strategy
    strategy = FedAvg(
        min_fit_clients=config["n_clients"],
        min_evaluate_clients=config["n_clients"],
        # Server waits for min_available_clients before starting FL rounds
        min_available_clients=config["n_clients"],
        on_fit_config_fn=fit_config_fn,
        # We use the same fit config function, as nothing changes for eval
        on_evaluate_config_fn=fit_config_fn,
        fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
        evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
        initial_parameters=get_initial_model_parameters(),
    )

    fl.server.start_server(
        server_address="0.0.0.0:8080",
        config=fl.server.ServerConfig(num_rounds=config["n_server_rounds"]),
        strategy=strategy,
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="FL Server Main")
    parser.add_argument(
        "--config_path",
        action="store",
        type=str,
        help="Path to configuration file.",
        default="examples/moon_example/config.yaml",
    )
    args = parser.parse_args()

    config = load_config(args.config_path)

    main(config)
119 changes: 119 additions & 0 deletions fl4health/clients/moon_client.py
@@ -0,0 +1,119 @@
import copy
from pathlib import Path
from typing import Dict, List, Optional, Sequence

import torch
from flwr.common.typing import Config, NDArrays

from fl4health.checkpointing.checkpointer import TorchCheckpointer
from fl4health.clients.basic_client import BasicClient
from fl4health.model_bases.moon_base import MoonModel
from fl4health.utils.losses import Losses, LossMeterType
from fl4health.utils.metrics import Metric, MetricMeterType


class MoonClient(BasicClient):
    """
    This client implements the MOON algorithm from Model-Contrastive Federated Learning. The key idea of MOON
    is to utilize the similarity between model representations to correct the local training of individual parties,
    i.e., to conduct contrastive learning at the model level.
    """

    def __init__(
        self,
        data_path: Path,
        metrics: Sequence[Metric],
        device: torch.device,
        loss_meter_type: LossMeterType = LossMeterType.AVERAGE,
        metric_meter_type: MetricMeterType = MetricMeterType.AVERAGE,
        checkpointer: Optional[TorchCheckpointer] = None,
        temperature: float = 0.5,
        contrastive_weight: float = 10.0,
        len_old_models_buffer: int = 1,
    ) -> None:
        super().__init__(
            data_path=data_path,
            metrics=metrics,
            device=device,
            loss_meter_type=loss_meter_type,
            metric_meter_type=metric_meter_type,
            checkpointer=checkpointer,
        )
        self.cos_sim = torch.nn.CosineSimilarity(dim=-1)
        self.ce_criterion = torch.nn.CrossEntropyLoss().to(self.device)
        self.contrastive_weight = contrastive_weight
        self.temperature = temperature

        # Saving previous local models and the global model at each communication round to compute contrastive loss
        self.len_old_models_buffer = len_old_models_buffer
        self.old_models_list: List[MoonModel] = []
        self.global_model: MoonModel

    def predict(self, input: torch.Tensor) -> Dict[str, torch.Tensor]:
        preds = self.model(input)
        # Collect feature representations from the buffered old local models and the current global model. These
        # serve, respectively, as the negative and positive anchors for the contrastive loss.
        preds["old_features"] = torch.zeros(self.len_old_models_buffer, *preds["features"].size()).to(self.device)
        for i, old_model in enumerate(self.old_models_list):
            old_preds = old_model(input)
            preds["old_features"][i] = old_preds["features"]
        global_preds = self.global_model(input)
        preds["global_features"] = global_preds["features"]
        if isinstance(preds, dict):
            return preds
        elif isinstance(preds, torch.Tensor):
            return {"prediction": preds}
        else:
            raise ValueError("Model forward did not return a tensor or dictionary of tensors")

    def get_contrastive_loss(
        self, features: torch.Tensor, global_features: torch.Tensor, old_features: torch.Tensor
    ) -> torch.Tensor:
        """
        This contrastive loss is implemented based on https://github.com/QinbinLi/MOON.
        The primary idea is to enhance the similarity between the current local features and the global features
        as positive pairs while reducing the similarity between the current local features and the previous local
        features as negative pairs. For local features z, global features z_g, previous local features z_k, and
        temperature t, the per-example loss is

            -log( exp(sim(z, z_g) / t) / (exp(sim(z, z_g) / t) + sum_k exp(sim(z, z_k) / t)) ),

        computed below as a cross-entropy over the similarity logits with the positive pair placed at index 0.
        """
        assert len(features) == len(global_features)
        # Column 0 of the logits holds the positive-pair similarities; later columns hold negative pairs.
        posi = self.cos_sim(features, global_features)
        logits = posi.reshape(-1, 1)
        for old_feature in old_features:
            assert len(features) == len(old_feature)
            nega = self.cos_sim(features, old_feature)
            logits = torch.cat((logits, nega.reshape(-1, 1)), dim=1)
        logits /= self.temperature
        labels = torch.zeros(features.size(0)).to(self.device).long()

        return self.ce_criterion(logits, labels)

    def set_parameters(self, parameters: NDArrays, config: Config) -> None:
        assert isinstance(self.model, MoonModel)

        # Save the parameters of the old local model
        old_model = copy.deepcopy(self.model)
        for param in old_model.parameters():
            param.requires_grad = False
        old_model.eval()
        self.old_models_list.append(old_model)
        if len(self.old_models_list) > self.len_old_models_buffer:
            self.old_models_list.pop(0)

        # Set the parameters of the model
        output = super().set_parameters(parameters, config)

        # Save the parameters of the global model
        self.global_model = copy.deepcopy(self.model)
        for param in self.global_model.parameters():
            param.requires_grad = False
        self.global_model.eval()
        return output

    def compute_loss(self, preds: Dict[str, torch.Tensor], target: torch.Tensor) -> Losses:
        if len(self.old_models_list) == 0:
            return super().compute_loss(preds, target)
        loss = self.criterion(preds["prediction"], target)
        contrastive_loss = self.get_contrastive_loss(
            preds["features"], preds["global_features"], preds["old_features"]
        )
        total_loss = loss + self.contrastive_weight * contrastive_loss
        losses = Losses(checkpoint=loss, backward=total_loss, additional_losses={"contrastive_loss": contrastive_loss})
        return losses
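
As an aside (not part of this diff), the mechanics of `get_contrastive_loss` can be exercised standalone with random
features. With one buffered old model, the logits have shape `(batch_size, 2)` and label 0 marks the positive
(global) pair:
```python
import torch

cos_sim = torch.nn.CosineSimilarity(dim=-1)
ce_criterion = torch.nn.CrossEntropyLoss()
temperature = 0.5

features = torch.randn(4, 256)         # current local features
global_features = torch.randn(4, 256)  # positive anchors
old_features = torch.randn(1, 4, 256)  # one old model's features: negative anchors

logits = cos_sim(features, global_features).reshape(-1, 1)
for old_feature in old_features:
    logits = torch.cat((logits, cos_sim(features, old_feature).reshape(-1, 1)), dim=1)
logits /= temperature
labels = torch.zeros(4).long()  # the positive pair sits in column 0
loss = ce_criterion(logits, labels)  # small when local features align with the global ones
```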
23 changes: 23 additions & 0 deletions fl4health/model_bases/moon_base.py
@@ -0,0 +1,23 @@
from typing import Dict, Optional

import torch
import torch.nn as nn


class MoonModel(nn.Module):
    def __init__(
        self, base_module: nn.Module, head_module: nn.Module, projection_module: Optional[nn.Module] = None
    ) -> None:
        super().__init__()
        self.base_module = base_module
        self.projection_module = projection_module
        self.head_module = head_module

    def forward(self, input: torch.Tensor) -> Dict[str, torch.Tensor]:
        x = self.base_module(input)
        if self.projection_module:
            p = self.projection_module(x)
        else:
            p = x
        output = self.head_module(p)
        return {"prediction": output, "features": p.view(len(p), -1)}
1 change: 0 additions & 1 deletion fl4health/utils/metrics.py
@@ -239,7 +239,6 @@ def __init__(self, key_to_meter_map: Dict[str, MetricMeter]):
        self.key_to_meter_map = key_to_meter_map

    def update(self, preds: Dict[str, torch.Tensor], target: torch.Tensor) -> None:
        for pred_key in preds.keys():
            if pred_key in self.key_to_meter_map.keys():
                self.key_to_meter_map[pred_key].update(preds[pred_key], target)