Update apfl client #62

Merged
merged 14 commits on Oct 18, 2023
Changes from 3 commits
35 changes: 16 additions & 19 deletions examples/apfl_example/client.py
@@ -1,41 +1,38 @@
import argparse
from pathlib import Path
from typing import Sequence
from typing import Tuple

import flwr as fl
import torch
import torch.nn as nn
from flwr.common.typing import Config
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
from torch.utils.data import DataLoader

from examples.models.cnn_model import MnistNetWithBnAndFrozen
from fl4health.clients.apfl_client import ApflClient
from fl4health.model_bases.apfl_base import APFLModule
from fl4health.parameter_exchange.layer_exchanger import FixedLayerExchanger
from fl4health.utils.load_data import load_mnist_data
from fl4health.utils.metrics import Accuracy, Metric
from fl4health.utils.metrics import Accuracy
from fl4health.utils.sampler import DirichletLabelBasedSampler


class MnistApflClient(ApflClient):
def __init__(
self,
data_path: Path,
metrics: Sequence[Metric],
device: torch.device,
) -> None:
super().__init__(data_path=data_path, metrics=metrics, device=device)

def setup_client(self, config: Config) -> None:
def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]:
batch_size = self.narrow_config_type(config, "batch_size", int)
self.model: APFLModule = APFLModule(MnistNetWithBnAndFrozen()).to(self.device)
self.criterion = torch.nn.CrossEntropyLoss()
self.local_optimizer = torch.optim.AdamW(self.model.local_model.parameters(), lr=0.01)
self.global_optimizer = torch.optim.AdamW(self.model.global_model.parameters(), lr=0.01)
sampler = DirichletLabelBasedSampler(list(range(10)), sample_percentage=0.75)
train_loader, val_loader, _ = load_mnist_data(self.data_path, batch_size, sampler)
return train_loader, val_loader

self.train_loader, self.val_loader, self.num_examples = load_mnist_data(self.data_path, batch_size, sampler)
self.parameter_exchanger = FixedLayerExchanger(self.model.layers_to_exchange())
def get_model(self, config: Config) -> nn.Module:
return APFLModule(MnistNetWithBnAndFrozen()).to(self.device)

super().setup_client(config)
def get_optimizer(self, config: Config) -> Optimizer:
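# Note: this is a single optimizer over the whole APFLModule; ApflClient.split_optimizer
# (see fl4health/clients/apfl_client.py below) divides it into separate global and local optimizers.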
return torch.optim.AdamW(self.model.parameters(), lr=0.01)

def get_criterion(self, config: Config) -> _Loss:
return torch.nn.CrossEntropyLoss()


if __name__ == "__main__":
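The refactored client above no longer overrides setup_client; it only supplies the get_data_loaders, get_model, get_optimizer and get_criterion hooks. A rough sketch of how a BasicClient-style base class is assumed to wire these hooks together (attribute names follow the old code above; the actual BasicClient implementation is not shown in this diff):

# Sketch only: an assumed shape for the base-class setup, not the actual BasicClient code.
def setup_client(self, config: Config) -> None:
    self.train_loader, self.val_loader = self.get_data_loaders(config)
    self.model = self.get_model(config)
    self.optimizer = self.get_optimizer(config)
    self.criterion = self.get_criterion(config)
    self.parameter_exchanger = self.get_parameter_exchanger(config)
    self.initialized = True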
1 change: 1 addition & 0 deletions examples/apfl_example/server.py
@@ -22,6 +22,7 @@ def get_initial_model_parameters() -> Parameters:

def fit_config(local_epochs: int, batch_size: int, n_server_rounds: int, current_round: int) -> Config:
return {
"current_server_round": current_round,
"local_epochs": local_epochs,
"batch_size": batch_size,
"n_server_rounds": n_server_rounds,
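With the addition above, clients can read the current server round from the config each round. An illustrative call, with made-up hyperparameter values:

# Illustrative only: example config produced for server round 2.
fit_config(local_epochs=1, batch_size=32, n_server_rounds=5, current_round=2)
# -> {"current_server_round": 2, "local_epochs": 1, "batch_size": 32, "n_server_rounds": 5, ...}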
288 changes: 67 additions & 221 deletions fl4health/clients/apfl_client.py
@@ -1,255 +1,101 @@
from logging import INFO
import copy
from pathlib import Path
from typing import Dict, Sequence, Tuple
from typing import Optional, Sequence, Tuple

import torch
from flwr.common.logger import log
from flwr.common.typing import Config, NDArrays, Scalar
from torch.nn.modules.loss import _Loss
from torch.utils.data import DataLoader
from flwr.common.typing import Config
from torch.optim import Optimizer

from fl4health.clients.numpy_fl_client import NumpyFlClient
from fl4health.checkpointing.checkpointer import TorchCheckpointer
from fl4health.clients.basic_client import BasicClient
from fl4health.model_bases.apfl_base import APFLModule
from fl4health.utils.metrics import Metric, MetricAverageMeter, MetricMeter
from fl4health.parameter_exchange.layer_exchanger import FixedLayerExchanger
from fl4health.utils.losses import Losses, LossMeterType
from fl4health.utils.metrics import Metric, MetricMeterType

LocalLoss = torch.Tensor
GlobalLoss = torch.Tensor
PersonalLoss = torch.Tensor

LocalPreds = torch.Tensor
GlobalPreds = torch.Tensor
PersonalPreds = torch.Tensor

ApflTrainStepOutputs = Tuple[LocalLoss, GlobalLoss, PersonalLoss, LocalPreds, GlobalPreds, PersonalPreds]


class ApflClient(NumpyFlClient):
class ApflClient(BasicClient):
def __init__(
self,
data_path: Path,
metrics: Sequence[Metric],
device: torch.device,
loss_meter_type: LossMeterType = LossMeterType.AVERAGE,
metric_meter_type: MetricMeterType = MetricMeterType.AVERAGE,
use_wandb_reporter: bool = False,
checkpointer: Optional[TorchCheckpointer] = None,
) -> None:
super().__init__(data_path, device)
self.metrics = metrics
self.model: APFLModule
self.train_loader: DataLoader
self.val_loader: DataLoader
self.num_examples: Dict[str, int]
self.criterion: _Loss
self.local_optimizer: torch.optim.Optimizer
self.global_optimizer: torch.optim.Optimizer

def is_start_of_local_training(self, epoch: int, step: int) -> bool:
return epoch == 0 and step == 0

def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, Dict[str, Scalar]]:
if not self.initialized:
self.setup_client(config)

self.set_parameters(parameters, config)
local_epochs = self.narrow_config_type(config, "local_epochs", int)

# Default APFL uses an average meter
global_meter = MetricAverageMeter(self.metrics, "global")
local_meter = MetricAverageMeter(self.metrics, "local")
personal_meter = MetricAverageMeter(self.metrics, "personal")
# By default the APFL client trains by epochs
metric_values = self.train_by_epochs(local_epochs, global_meter, local_meter, personal_meter)
# FitRes should contain local parameters, number of examples on client, and a dictionary holding metrics
# calculation results.
return (
self.get_parameters(config),
self.num_examples["train_set"],
metric_values,
)

def evaluate(self, parameters: NDArrays, config: Config) -> Tuple[float, int, Dict[str, Scalar]]:
if not self.initialized:
self.setup_client(config)

self.set_parameters(parameters, config)
# Default APFL uses an average meter
global_meter = MetricAverageMeter(self.metrics, "global")
local_meter = MetricAverageMeter(self.metrics, "local")
personal_meter = MetricAverageMeter(self.metrics, "personal")
loss, metric_values = self.validate(global_meter, local_meter, personal_meter)
# EvaluateRes should return the loss, number of examples on client, and a dictionary holding metrics
# calculation results.
return (
loss,
self.num_examples["validation_set"],
metric_values,
)

def _handle_logging(
self, loss_dict: Dict[str, float], metrics_dict: Dict[str, Scalar], is_validation: bool = False
) -> None:
loss_string = "\t".join([f"{key}: {str(val)}" for key, val in loss_dict.items()])
metric_string = "\t".join([f"{key}: {str(val)}" for key, val in metrics_dict.items()])
metric_prefix = "Validation" if is_validation else "Training"
log(
INFO,
f"alpha: {self.model.alpha} \n"
f"Client {metric_prefix} Losses: {loss_string} \n"
f"Client {metric_prefix} Metrics: {metric_string}",
super().__init__(
data_path, metrics, device, loss_meter_type, metric_meter_type, use_wandb_reporter, checkpointer
)
# APFLModule, which holds both the local and global models
# and provides personal, local and global predictions
self.model: APFLModule

def train_step(self, input: torch.Tensor, target: torch.Tensor) -> ApflTrainStepOutputs:
# local_optimizer is used to update the local model
# The usual self.optimizer is used to update the global model
self.local_optimizer: Optimizer

def is_start_of_local_training(self, step: int) -> bool:
return step == 0

def update_after_step(self, step: int) -> None:
if self.is_start_of_local_training(step) and self.model.adaptive_alpha:
self.model.update_alpha()

def split_optimizer(self, global_optimizer: Optimizer) -> Tuple[Optimizer, Optimizer]:
"""
The optimizer from get_optimizer is for the entire APFLModule. We need one optimizer
for the local model and one optimizer for the global model.
"""
global_optimizer.param_groups.clear()
global_optimizer.state.clear()
local_optimizer = copy.deepcopy(global_optimizer)

global_optimizer.add_param_group({"params": [p for p in self.model.global_model.parameters()]})
local_optimizer.add_param_group({"params": [p for p in self.model.local_model.parameters()]})
return global_optimizer, local_optimizer

def setup_client(self, config: Config) -> None:
"""
Set dataloaders, optimizers, parameter exchangers and other attributes derived from these.
"""
super().setup_client(config)

# Split optimizer from get_optimizer into two distinct optimizers
# One for local model and one for global model
global_optimizer, local_optimizer = self.split_optimizer(self.optimizer)
self.optimizer = global_optimizer
self.local_optimizer = local_optimizer

def train_step(self, input: torch.Tensor, target: torch.Tensor) -> Tuple[Losses, torch.Tensor]:
# Mechanics of training loop follow from original implementation
# https://github.com/MLOPTPSU/FedTorch/blob/main/fedtorch/comms/trainings/federated/apfl.py

# Forward pass on global model and update global parameters
self.global_optimizer.zero_grad()
global_pred = self.model(input, personal=False)["global"]
self.optimizer.zero_grad()
global_pred = self.model.global_forward(input)
global_loss = self.criterion(global_pred, target)
global_loss.backward()
self.global_optimizer.step()
self.optimizer.step()

# Make sure gradients are zero prior to forward passes of the global and local models
# to generate personalized predictions
# NOTE: We zero the global optimizer grads because they are used (after the backward calculation below)
# to update the scalar alpha (see update_alpha() where .grad is called.)
self.global_optimizer.zero_grad()
self.optimizer.zero_grad()
self.local_optimizer.zero_grad()

# Personal predictions are generated as a convex combination of the output
# of local and global models
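# Per the APFL formulation (an assumption; the formula is not shown in this diff):
# personal_pred = alpha * local_pred + (1 - alpha) * global_pred, where alpha is the
# learned mixing weight held by the APFLModule.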
pred_dict = self.model(input, personal=True)
personal_pred, local_pred = pred_dict["personal"], pred_dict["local"]
personal_pred = self.predict(input)

# Parameters of local model are updated to minimize loss of personalized model
personal_loss = self.criterion(personal_pred, target)
personal_loss.backward()
losses = self.compute_loss(personal_pred, target)
losses.backward.backward()
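# Assumption about the Losses container: losses.backward holds the tensor to backpropagate
# (here the personal loss), while any additional entries are kept for logging.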
self.local_optimizer.step()

with torch.no_grad():
local_loss = self.criterion(local_pred, target)

return local_loss, global_loss, personal_loss, local_pred, global_pred, personal_pred

def train_by_steps(
self, steps: int, global_meter: MetricMeter, local_meter: MetricMeter, personal_meter: MetricMeter
) -> Dict[str, Scalar]:
self.model.train()
loss_dict = {"personal": 0.0, "local": 0.0, "global": 0.0}
global_meter.clear()
local_meter.clear()
personal_meter.clear()

train_iterator = iter(self.train_loader)

for step in range(steps):
try:
input, target = next(train_iterator)
except StopIteration:
# StopIteration is thrown if dataset ends
# reinitialize data loader
train_iterator = iter(self.train_loader)
input, target = next(train_iterator)

input, target = input.to(self.device), target.to(self.device)
local_loss, global_loss, personal_loss, local_preds, global_preds, personal_preds = self.train_step(
input, target
)

# Only update alpha if it is the first step of local training and adaptive alpha is true
if step == 0 and self.model.adaptive_alpha:
self.model.update_alpha()

loss_dict["local"] += local_loss.item()
loss_dict["global"] += global_loss.item()
loss_dict["personal"] += personal_loss.item()

global_meter.update(global_preds, target)
local_meter.update(local_preds, target)
personal_meter.update(personal_preds, target)

loss_dict = {key: val / steps for key, val in loss_dict.items()}
global_metrics = global_meter.compute()
local_metrics = local_meter.compute()
personal_metrics = personal_meter.compute()
metrics: Dict[str, Scalar] = {**global_metrics, **local_metrics, **personal_metrics}
log(INFO, f"Performed {steps} Steps of Local training")
self._handle_logging(loss_dict, metrics)

# return final training metrics
return metrics

def train_by_epochs(
self, epochs: int, global_meter: MetricMeter, local_meter: MetricMeter, personal_meter: MetricMeter
) -> Dict[str, Scalar]:
self.model.train()
for epoch in range(epochs):
loss_dict = {"personal": 0.0, "local": 0.0, "global": 0.0}
global_meter.clear()
local_meter.clear()
personal_meter.clear()

for step, (input, target) in enumerate(self.train_loader):
input, target = input.to(self.device), target.to(self.device)
local_loss, global_loss, personal_loss, local_preds, global_preds, personal_preds = self.train_step(
input, target
)

# Only update alpha if it is the first epoch and first step of training
# and adaptive alpha is true
if self.is_start_of_local_training(epoch, step) and self.model.adaptive_alpha:
self.model.update_alpha()

loss_dict["local"] += local_loss.item()
loss_dict["global"] += global_loss.item()
loss_dict["personal"] += personal_loss.item()

global_meter.update(global_preds, target)
local_meter.update(local_preds, target)
personal_meter.update(personal_preds, target)

loss_dict = {key: val / len(self.train_loader) for key, val in loss_dict.items()}

global_metrics = global_meter.compute()
local_metrics = local_meter.compute()
personal_metrics = personal_meter.compute()
metrics: Dict[str, Scalar] = {**global_metrics, **local_metrics, **personal_metrics}
log(INFO, f"Performed {epochs} Epochs of Local training")
self._handle_logging(loss_dict, metrics)

return metrics

def validate(
self, global_meter: MetricMeter, local_meter: MetricMeter, personal_meter: MetricMeter
) -> Tuple[float, Dict[str, Scalar]]:
self.model.eval()
loss_dict = {"global": 0.0, "personal": 0.0, "local": 0.0}
global_meter.clear()
local_meter.clear()
personal_meter.clear()

with torch.no_grad():
for input, target in self.val_loader:
input, target = input.to(self.device), target.to(self.device)

global_pred = self.model(input, personal=False)["global"]
global_loss = self.criterion(global_pred, target)

pred_dict = self.model(input, personal=True)
personal_pred, local_pred = pred_dict["personal"], pred_dict["local"]
personal_loss = self.criterion(personal_pred, target)
local_loss = self.criterion(local_pred, target)

loss_dict["global"] += global_loss.item()
loss_dict["personal"] += personal_loss.item()
loss_dict["local"] += local_loss.item()

global_meter.update(global_pred, target)
local_meter.update(local_pred, target)
personal_meter.update(personal_pred, target)

loss_dict = {key: val / len(self.val_loader) for key, val in loss_dict.items()}
global_metrics = global_meter.compute()
local_metrics = local_meter.compute()
personal_metrics = personal_meter.compute()
metrics: Dict[str, Scalar] = {**global_metrics, **local_metrics, **personal_metrics}
self._handle_logging(loss_dict, metrics, is_validation=True)
self._maybe_checkpoint(loss_dict["personal"])
return loss_dict["personal"], metrics
return losses, personal_pred

def get_parameter_exchanger(self, config: Config) -> FixedLayerExchanger:
return FixedLayerExchanger(self.model.layers_to_exchange())
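For context, FixedLayerExchanger restricts the weights exchanged with the server to a fixed list of layer names. In APFL only the global model is aggregated at the server; the local model and the mixing weight alpha remain on the client. A minimal sketch of the kind of filtering APFLModule.layers_to_exchange is assumed to perform (the actual implementation is not part of this diff; List is assumed imported from typing):

# Sketch only: assumed filtering of the state dict by a global-model name prefix.
def layers_to_exchange(self) -> List[str]:
    return [name for name in self.state_dict().keys() if name.startswith("global_model.")]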