generated from VectorInstitute/aieng-template
-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into Add_contrastive_losses
- Loading branch information
Showing
37 changed files
with
2,127 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
|
||
|
||
class HeadCnn(nn.Module): | ||
def __init__(self) -> None: | ||
super().__init__() | ||
self.fc1 = nn.Linear(256, 10) | ||
|
||
def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: | ||
x = self.fc1(input_tensor) | ||
return x | ||
|
||
|
||
class ProjectionCnn(nn.Module): | ||
def __init__(self) -> None: | ||
super().__init__() | ||
self.fc1 = nn.Linear(120, 256) | ||
self.fc2 = nn.Linear(256, 256) | ||
|
||
def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: | ||
x = F.relu(self.fc1(input_tensor)) | ||
x = self.fc2(x) | ||
return x | ||
|
||
|
||
class BaseCnn(nn.Module): | ||
def __init__(self) -> None: | ||
super().__init__() | ||
self.conv1 = nn.Conv2d(1, 6, 5) | ||
self.pool = nn.MaxPool2d(2, 2) | ||
self.conv2 = nn.Conv2d(6, 16, 5) | ||
self.fc1 = nn.Linear(16 * 4 * 4, 120) | ||
|
||
def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
x = self.pool(F.relu(self.conv1(x))) | ||
x = self.pool(F.relu(self.conv2(x))) | ||
x = x.view(-1, 16 * 4 * 4) | ||
x = F.relu(self.fc1(x)) | ||
return x |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
# Moon Federated Learning Example | ||
This example provides an example of training a Moon type model on a non-IID subset of the MNIST data. The FL server | ||
expects three clients to be spun up (i.e. it will wait until three clients report in before starting training). Each client | ||
has a modified version of the MNIST dataset. This modification essentially subsamples a certain number from the original | ||
training and validation sets of MNIST in order to synthetically induce local variations in the statistical properties | ||
of the clients training/validation data. In theory, the models should be able to perform well on their local data | ||
while learning from other clients data that has different statistical properties. The subsampling is specified by | ||
sending a list of integers between 0-9 to the clients when they are run with the argument `--minority_numbers`. | ||
|
||
The server has some custom metrics aggregation and uses Federated Averaging as its server-side optimization. The implementation uses a special type of weight exchange based on named-layer identification. | ||
|
||
## Running the Example | ||
In order to run the example, first ensure you have the virtual env of your choice activated and run | ||
``` | ||
pip install --upgrade pip | ||
pip install -r requirements.txt | ||
``` | ||
to install all of the dependencies for this project. | ||
|
||
## Starting Server | ||
|
||
The next step is to start the server by running | ||
``` | ||
python -m examples.moon_example.server --config_path /path/to/config.yaml | ||
``` | ||
from the FL4Health directory. The following arguments must be present in the specified config file: | ||
* `n_clients`: number of clients the server waits for in order to run the FL training | ||
* `local_epochs`: number of epochs each client will train for locally | ||
* `batch_size`: size of the batches each client will train on | ||
* `n_server_rounds`: The number of rounds to run FL | ||
* `downsampling_ratio`: The amount of downsampling to perform for minority digits | ||
|
||
## Starting Clients | ||
|
||
Once the server has started and logged "FL starting," the next step, in separate terminals, is to start the three | ||
clients. This is done by simply running (remembering to activate your environment) | ||
``` | ||
python -m examples.moon_example.client --dataset_path /path/to/data --minority_numbers <sequence of numbers> | ||
``` | ||
**NOTE**: The argument `dataset_path` has two functions, depending on whether the dataset exists locally or not. If | ||
the dataset already exists at the path specified, it will be loaded from there. Otherwise, the dataset will be | ||
automatically downloaded to the path specified and used in the run. | ||
|
||
The argument `minority_numbers` specifies which digits (0-9) in the MNIST dataset the client will subsample to | ||
simulate non-IID data between clients. For example `--minority_numbers 1 2 3 4 5` will ensure that the client | ||
downsamples these digits (using the `downsampling_ratio` specified to the config). | ||
|
||
After both clients have been started federated learning should commence. |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
import argparse | ||
from pathlib import Path | ||
from typing import Sequence, Set, Tuple | ||
|
||
import flwr as fl | ||
import torch | ||
import torch.nn as nn | ||
from flwr.common.typing import Config | ||
from torch.nn.modules.loss import _Loss | ||
from torch.optim import Optimizer | ||
from torch.utils.data import DataLoader | ||
|
||
from examples.models.moon_cnn import BaseCnn, HeadCnn, ProjectionCnn | ||
from fl4health.clients.moon_client import MoonClient | ||
from fl4health.model_bases.moon_base import MoonModel | ||
from fl4health.utils.load_data import load_mnist_data | ||
from fl4health.utils.metrics import Accuracy, Metric | ||
from fl4health.utils.sampler import MinorityLabelBasedSampler | ||
|
||
|
||
class MnistMoonClient(MoonClient): | ||
def __init__( | ||
self, | ||
data_path: Path, | ||
metrics: Sequence[Metric], | ||
device: torch.device, | ||
minority_numbers: Set[int], | ||
) -> None: | ||
super().__init__(data_path=data_path, metrics=metrics, device=device) | ||
self.minority_numbers = minority_numbers | ||
|
||
def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: | ||
batch_size = self.narrow_config_type(config, "batch_size", int) | ||
downsample_percentage = self.narrow_config_type(config, "downsampling_ratio", float) | ||
sampler = MinorityLabelBasedSampler(list(range(10)), downsample_percentage, self.minority_numbers) | ||
train_loader, val_loader, _ = load_mnist_data(self.data_path, batch_size, sampler) | ||
return train_loader, val_loader | ||
|
||
def get_model(self, config: Config) -> nn.Module: | ||
model: nn.Module = MoonModel(BaseCnn(), HeadCnn(), ProjectionCnn()).to(self.device) | ||
return model | ||
|
||
def get_optimizer(self, config: Config) -> Optimizer: | ||
return torch.optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9) | ||
|
||
def get_criterion(self, config: Config) -> _Loss: | ||
return torch.nn.CrossEntropyLoss() | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser(description="FL Client Main") | ||
parser.add_argument("--dataset_path", action="store", type=str, help="Path to the local dataset") | ||
parser.add_argument( | ||
"--minority_numbers", default=[], nargs="*", help="MNIST numbers to be in the minority for the current client" | ||
) | ||
args = parser.parse_args() | ||
|
||
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | ||
data_path = Path(args.dataset_path) | ||
minority_numbers = {int(number) for number in args.minority_numbers} | ||
client = MnistMoonClient(data_path, [Accuracy("accuracy")], DEVICE, minority_numbers) | ||
fl.client.start_numpy_client(server_address="0.0.0.0:8080", client=client) | ||
client.shutdown() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
# Parameters that describe server | ||
n_server_rounds: 3 # The number of rounds to run FL | ||
|
||
# Parameters that describe clients | ||
n_clients: 3 # The number of clients in the FL experiment | ||
local_epochs: 1 # The number of epochs to complete for client | ||
batch_size: 32 # The batch size for client training | ||
|
||
# Downsampling settings per client | ||
downsampling_ratio: 0.1 # percentage of original mnist data to keep for minority numbers |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
import argparse | ||
from functools import partial | ||
from typing import Any, Dict | ||
|
||
import flwr as fl | ||
from flwr.common.parameter import ndarrays_to_parameters | ||
from flwr.common.typing import Config, Parameters | ||
from flwr.server.strategy import FedAvg | ||
|
||
from examples.models.moon_cnn import BaseCnn, HeadCnn, ProjectionCnn | ||
from examples.simple_metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn | ||
from fl4health.model_bases.moon_base import MoonModel | ||
from fl4health.utils.config import load_config | ||
|
||
|
||
def get_initial_model_parameters() -> Parameters: | ||
# Initializing the model parameters on the server side. | ||
# Currently uses the Pytorch default initialization for the model parameters. | ||
initial_model = MoonModel(BaseCnn(), HeadCnn(), ProjectionCnn()) | ||
return ndarrays_to_parameters([val.cpu().numpy() for _, val in initial_model.state_dict().items()]) | ||
|
||
|
||
def fit_config( | ||
local_epochs: int, batch_size: int, n_server_rounds: int, downsampling_ratio: float, current_round: int | ||
) -> Config: | ||
return { | ||
"local_epochs": local_epochs, | ||
"batch_size": batch_size, | ||
"n_server_rounds": n_server_rounds, | ||
"downsampling_ratio": downsampling_ratio, | ||
"current_server_round": current_round, | ||
} | ||
|
||
|
||
def main(config: Dict[str, Any]) -> None: | ||
# This function will be used to produce a config that is sent to each client to initialize their own environment | ||
fit_config_fn = partial( | ||
fit_config, | ||
config["local_epochs"], | ||
config["batch_size"], | ||
config["n_server_rounds"], | ||
config["downsampling_ratio"], | ||
) | ||
|
||
# Server performs simple FedAveraging as its server-side optimization strategy | ||
strategy = FedAvg( | ||
min_fit_clients=config["n_clients"], | ||
min_evaluate_clients=config["n_clients"], | ||
# Server waits for min_available_clients before starting FL rounds | ||
min_available_clients=config["n_clients"], | ||
on_fit_config_fn=fit_config_fn, | ||
# We use the same fit config function, as nothing changes for eval | ||
on_evaluate_config_fn=fit_config_fn, | ||
fit_metrics_aggregation_fn=fit_metrics_aggregation_fn, | ||
evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn, | ||
initial_parameters=get_initial_model_parameters(), | ||
) | ||
|
||
fl.server.start_server( | ||
server_address="0.0.0.0:8080", | ||
config=fl.server.ServerConfig(num_rounds=config["n_server_rounds"]), | ||
strategy=strategy, | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser(description="FL Server Main") | ||
parser.add_argument( | ||
"--config_path", | ||
action="store", | ||
type=str, | ||
help="Path to configuration file.", | ||
default="examples/moon_example/config.yaml", | ||
) | ||
args = parser.parse_args() | ||
|
||
config = load_config(args.config_path) | ||
|
||
main(config) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
import copy | ||
from pathlib import Path | ||
from typing import Dict, Optional, Sequence | ||
|
||
import torch | ||
from flwr.common.typing import Config, NDArrays | ||
|
||
from fl4health.checkpointing.checkpointer import TorchCheckpointer | ||
from fl4health.clients.basic_client import BasicClient | ||
from fl4health.model_bases.moon_base import MoonModel | ||
from fl4health.utils.losses import Losses, LossMeterType | ||
from fl4health.utils.metrics import Metric, MetricMeterType | ||
|
||
|
||
class MoonClient(BasicClient): | ||
""" | ||
This client implements the MOON algorithm from Model-Contrastive Federated Learning. The key idea of MOON | ||
is to utilize the similarity between model representations to correct the local training of individual parties, | ||
i.e., conducting contrastive learning in model-level. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
data_path: Path, | ||
metrics: Sequence[Metric], | ||
device: torch.device, | ||
loss_meter_type: LossMeterType = LossMeterType.AVERAGE, | ||
metric_meter_type: MetricMeterType = MetricMeterType.AVERAGE, | ||
checkpointer: Optional[TorchCheckpointer] = None, | ||
temperature: float = 0.5, | ||
contrastive_weight: float = 10, | ||
len_old_models_buffer: int = 1, | ||
) -> None: | ||
super().__init__( | ||
data_path=data_path, | ||
metrics=metrics, | ||
device=device, | ||
loss_meter_type=loss_meter_type, | ||
metric_meter_type=metric_meter_type, | ||
checkpointer=checkpointer, | ||
) | ||
self.cos_sim = torch.nn.CosineSimilarity(dim=-1) | ||
self.ce_criterion = torch.nn.CrossEntropyLoss().to(self.device) | ||
self.contrastive_weight = contrastive_weight | ||
self.temperature = temperature | ||
|
||
# Saving previous local models and global model at each communication round to compute contrastive loss | ||
self.len_old_models_buffer = len_old_models_buffer | ||
self.old_models_list: list[MoonModel] = [] | ||
self.global_model: MoonModel | ||
|
||
def predict(self, input: torch.Tensor) -> Dict[str, torch.Tensor]: | ||
preds = self.model(input) | ||
preds["old_features"] = torch.zeros(self.len_old_models_buffer, *preds["features"].size()).to(self.device) | ||
for i, old_model in enumerate(self.old_models_list): | ||
old_preds = old_model(input) | ||
preds["old_features"][i] = old_preds["features"] | ||
global_preds = self.global_model(input) | ||
preds["global_features"] = global_preds["features"] | ||
if isinstance(preds, dict): | ||
return preds | ||
elif isinstance(preds, torch.Tensor): | ||
return {"prediction": preds} | ||
else: | ||
raise ValueError("Model forward did not return a tensor or dictionary or tensors") | ||
|
||
def get_contrastive_loss( | ||
self, features: torch.Tensor, global_features: torch.Tensor, old_features: torch.Tensor | ||
) -> torch.Tensor: | ||
""" | ||
This constrastive loss is implemented based on https://github.com/QinbinLi/MOON. | ||
The primary idea is to enhance the similarity between the current local features and the global feature | ||
as positive pairs while reducing the similarity between the current local features and the previous local | ||
features as negative pairs. | ||
""" | ||
assert len(features) == len(global_features) | ||
posi = self.cos_sim(features, global_features) | ||
logits = posi.reshape(-1, 1) | ||
for old_feature in old_features: | ||
assert len(features) == len(old_feature) | ||
nega = self.cos_sim(features, old_feature) | ||
logits = torch.cat((logits, nega.reshape(-1, 1)), dim=1) | ||
logits /= self.temperature | ||
labels = torch.zeros(features.size(0)).to(self.device).long() | ||
|
||
return self.ce_criterion(logits, labels) | ||
|
||
def set_parameters(self, parameters: NDArrays, config: Config) -> None: | ||
assert isinstance(self.model, MoonModel) | ||
|
||
# Save the parameters of the old local model | ||
old_model = copy.deepcopy(self.model) | ||
for param in old_model.parameters(): | ||
param.requires_grad = False | ||
old_model.eval() | ||
self.old_models_list.append(old_model) | ||
if len(self.old_models_list) > self.len_old_models_buffer: | ||
self.old_models_list.pop(0) | ||
|
||
# Set the parameters of the model | ||
output = super().set_parameters(parameters, config) | ||
|
||
# Save the parameters of the global model | ||
self.global_model = copy.deepcopy(self.model) | ||
for param in self.global_model.parameters(): | ||
param.requires_grad = False | ||
self.global_model.eval() | ||
return output | ||
|
||
def compute_loss(self, preds: Dict[str, torch.Tensor], target: torch.Tensor) -> Losses: | ||
if len(self.old_models_list) == 0: | ||
return super().compute_loss(preds, target) | ||
loss = self.criterion(preds["prediction"], target) | ||
contrastive_loss = self.get_contrastive_loss( | ||
preds["features"], preds["global_features"], preds["old_features"] | ||
) | ||
total_loss = loss + self.contrastive_weight * contrastive_loss | ||
losses = Losses(checkpoint=loss, backward=total_loss, additional_losses={"contrastive_loss": contrastive_loss}) | ||
return losses |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
from typing import Dict, Optional | ||
|
||
import torch | ||
import torch.nn as nn | ||
|
||
|
||
class MoonModel(nn.Module): | ||
def __init__( | ||
self, base_module: nn.Module, head_module: nn.Module, projection_module: Optional[nn.Module] = None | ||
) -> None: | ||
super().__init__() | ||
self.base_module = base_module | ||
self.projection_module = projection_module | ||
self.head_module = head_module | ||
|
||
def forward(self, input: torch.Tensor) -> Dict[str, torch.Tensor]: | ||
x = self.base_module.forward(input) | ||
if self.projection_module: | ||
p = self.projection_module.forward(x) | ||
else: | ||
p = x | ||
output = self.head_module.forward(p) | ||
return {"prediction": output, "features": p.view(len(p), -1)} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.