Skip to content

Commit

Permalink
Merge pull request #62 from VectorInstitute/update-apfl-client
Browse files Browse the repository at this point in the history
Update apfl client
  • Loading branch information
jewelltaylor authored Oct 18, 2023
2 parents 79e0f5d + 9959c89 commit 78cea5e
Show file tree
Hide file tree
Showing 16 changed files with 414 additions and 486 deletions.
37 changes: 18 additions & 19 deletions examples/apfl_example/client.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,40 @@
import argparse
from pathlib import Path
from typing import Sequence
from typing import Dict, Tuple

import flwr as fl
import torch
import torch.nn as nn
from flwr.common.typing import Config
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
from torch.utils.data import DataLoader

from examples.models.cnn_model import MnistNetWithBnAndFrozen
from fl4health.clients.apfl_client import ApflClient
from fl4health.model_bases.apfl_base import ApflModule
from fl4health.parameter_exchange.layer_exchanger import FixedLayerExchanger
from fl4health.utils.load_data import load_mnist_data
from fl4health.utils.metrics import Accuracy, Metric
from fl4health.utils.metrics import Accuracy
from fl4health.utils.sampler import DirichletLabelBasedSampler


class MnistApflClient(ApflClient):
def __init__(
self,
data_path: Path,
metrics: Sequence[Metric],
device: torch.device,
) -> None:
super().__init__(data_path=data_path, metrics=metrics, device=device)

def setup_client(self, config: Config) -> None:
def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]:
batch_size = self.narrow_config_type(config, "batch_size", int)
self.model: ApflModule = ApflModule(MnistNetWithBnAndFrozen()).to(self.device)
self.criterion = torch.nn.CrossEntropyLoss()
self.local_optimizer = torch.optim.AdamW(self.model.local_model.parameters(), lr=0.01)
self.global_optimizer = torch.optim.AdamW(self.model.global_model.parameters(), lr=0.01)
sampler = DirichletLabelBasedSampler(list(range(10)), sample_percentage=0.75)
train_loader, val_loader, _ = load_mnist_data(self.data_path, batch_size, sampler)
return train_loader, val_loader

self.train_loader, self.val_loader, self.num_examples = load_mnist_data(self.data_path, batch_size, sampler)
self.parameter_exchanger = FixedLayerExchanger(self.model.layers_to_exchange())
def get_model(self, config: Config) -> nn.Module:
return ApflModule(MnistNetWithBnAndFrozen()).to(self.device)

super().setup_client(config)
def get_optimizer(self, config: Config) -> Dict[str, Optimizer]:
local_optimizer = torch.optim.AdamW(self.model.local_model.parameters(), lr=0.01)
global_optimizer = torch.optim.AdamW(self.model.global_model.parameters(), lr=0.01)
return {"local": local_optimizer, "global": global_optimizer}

def get_criterion(self, config: Config) -> _Loss:
return torch.nn.CrossEntropyLoss()


if __name__ == "__main__":
Expand Down
1 change: 1 addition & 0 deletions examples/apfl_example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def get_initial_model_parameters() -> Parameters:

def fit_config(local_epochs: int, batch_size: int, n_server_rounds: int, current_round: int) -> Config:
return {
"current_server_round": current_round,
"local_epochs": local_epochs,
"batch_size": batch_size,
"n_server_rounds": n_server_rounds,
Expand Down
2 changes: 1 addition & 1 deletion examples/fedprox_example/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ local_epochs: 1 # The number of epochs to complete for client
batch_size: 128 # The batch size for client training

reporting_config:
enabled: True
enabled: False
project_name: FL4Health # Name of the project under which everything should be logged
run_name: "FedProx Server" # Name of the run on the server-side, each client will also have it's own run name
group_name: "FedProx Experiment" # Group under which each of the FL run logging will be stored
Expand Down
Loading

0 comments on commit 78cea5e

Please sign in to comment.