Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update apfl client #62

Merged
merged 14 commits into from
Oct 18, 2023
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 18 additions & 19 deletions examples/apfl_example/client.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,40 @@
import argparse
emersodb marked this conversation as resolved.
Show resolved Hide resolved
from pathlib import Path
from typing import Sequence
from typing import Dict, Tuple

import flwr as fl
import torch
import torch.nn as nn
from flwr.common.typing import Config
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
from torch.utils.data import DataLoader

from examples.models.cnn_model import MnistNetWithBnAndFrozen
from fl4health.clients.apfl_client import ApflClient
from fl4health.model_bases.apfl_base import ApflModule
from fl4health.parameter_exchange.layer_exchanger import FixedLayerExchanger
from fl4health.utils.load_data import load_mnist_data
from fl4health.utils.metrics import Accuracy, Metric
from fl4health.utils.metrics import Accuracy
from fl4health.utils.sampler import DirichletLabelBasedSampler


class MnistApflClient(ApflClient):
def __init__(
self,
data_path: Path,
metrics: Sequence[Metric],
device: torch.device,
) -> None:
super().__init__(data_path=data_path, metrics=metrics, device=device)

def setup_client(self, config: Config) -> None:
def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]:
batch_size = self.narrow_config_type(config, "batch_size", int)
self.model: ApflModule = ApflModule(MnistNetWithBnAndFrozen()).to(self.device)
self.criterion = torch.nn.CrossEntropyLoss()
self.local_optimizer = torch.optim.AdamW(self.model.local_model.parameters(), lr=0.01)
self.global_optimizer = torch.optim.AdamW(self.model.global_model.parameters(), lr=0.01)
sampler = DirichletLabelBasedSampler(list(range(10)), sample_percentage=0.75)
train_loader, val_loader, _ = load_mnist_data(self.data_path, batch_size, sampler)
return train_loader, val_loader

self.train_loader, self.val_loader, self.num_examples = load_mnist_data(self.data_path, batch_size, sampler)
self.parameter_exchanger = FixedLayerExchanger(self.model.layers_to_exchange())
def get_model(self, config: Config) -> nn.Module:
return ApflModule(MnistNetWithBnAndFrozen()).to(self.device)

super().setup_client(config)
def get_optimizer(self, config: Config) -> Dict[str, Optimizer]:
local_optimizer = torch.optim.AdamW(self.model.local_model.parameters(), lr=0.01)
global_optimizer = torch.optim.AdamW(self.model.global_model.parameters(), lr=0.01)
return {"local": local_optimizer, "global": global_optimizer}

def get_criterion(self, config: Config) -> _Loss:
return torch.nn.CrossEntropyLoss()


if __name__ == "__main__":
Expand Down
1 change: 1 addition & 0 deletions examples/apfl_example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def get_initial_model_parameters() -> Parameters:

def fit_config(local_epochs: int, batch_size: int, n_server_rounds: int, current_round: int) -> Config:
return {
"current_server_round": current_round,
"local_epochs": local_epochs,
"batch_size": batch_size,
"n_server_rounds": n_server_rounds,
Expand Down
2 changes: 1 addition & 1 deletion examples/fedprox_example/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ local_epochs: 1 # The number of epochs to complete for client
batch_size: 128 # The batch size for client training

reporting_config:
enabled: True
enabled: False
emersodb marked this conversation as resolved.
Show resolved Hide resolved
project_name: FL4Health # Name of the project under which everything should be logged
run_name: "FedProx Server" # Name of the run on the server-side, each client will also have it's own run name
group_name: "FedProx Experiment" # Group under which each of the FL run logging will be stored
Expand Down
Loading