From 2a96ad0c36fce4d6a8b30e825acada82b52f3048 Mon Sep 17 00:00:00 2001 From: John Jewell Date: Thu, 28 Sep 2023 15:45:34 -0400 Subject: [PATCH 01/12] First pass at changing apfl client to leverage base client --- examples/apfl_example/client.py | 35 +-- examples/apfl_example/server.py | 1 + fl4health/clients/apfl_client.py | 279 ++++-------------- fl4health/clients/basic_client.py | 9 +- fl4health/model_bases/apfl_base.py | 19 +- .../flamby/fed_heart_disease/apfl/client.py | 86 +++--- research/flamby/fed_isic2019/apfl/client.py | 89 +++--- research/flamby/fed_ixi/apfl/client.py | 87 +++--- research/flamby/flamby_clients/__init__.py | 0 .../flamby_clients/flamby_apfl_client.py | 70 ----- .../parameter_exchange/test_apfl_exchange.py | 21 +- 11 files changed, 264 insertions(+), 432 deletions(-) delete mode 100644 research/flamby/flamby_clients/__init__.py delete mode 100644 research/flamby/flamby_clients/flamby_apfl_client.py diff --git a/examples/apfl_example/client.py b/examples/apfl_example/client.py index 70b1f9a39..dda1d9c9c 100644 --- a/examples/apfl_example/client.py +++ b/examples/apfl_example/client.py @@ -1,41 +1,38 @@ import argparse from pathlib import Path -from typing import Sequence +from typing import Tuple import flwr as fl import torch +import torch.nn as nn from flwr.common.typing import Config +from torch.nn.modules.loss import _Loss +from torch.optim import Optimizer +from torch.utils.data import DataLoader from examples.models.cnn_model import MnistNetWithBnAndFrozen from fl4health.clients.apfl_client import ApflClient from fl4health.model_bases.apfl_base import APFLModule -from fl4health.parameter_exchange.layer_exchanger import FixedLayerExchanger from fl4health.utils.load_data import load_mnist_data -from fl4health.utils.metrics import Accuracy, Metric +from fl4health.utils.metrics import Accuracy from fl4health.utils.sampler import DirichletLabelBasedSampler class MnistApflClient(ApflClient): - def __init__( - self, - data_path: Path, - metrics: Sequence[Metric], - device: torch.device, - ) -> None: - super().__init__(data_path=data_path, metrics=metrics, device=device) - - def setup_client(self, config: Config) -> None: + def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: batch_size = self.narrow_config_type(config, "batch_size", int) - self.model: APFLModule = APFLModule(MnistNetWithBnAndFrozen()).to(self.device) - self.criterion = torch.nn.CrossEntropyLoss() - self.local_optimizer = torch.optim.AdamW(self.model.local_model.parameters(), lr=0.01) - self.global_optimizer = torch.optim.AdamW(self.model.global_model.parameters(), lr=0.01) sampler = DirichletLabelBasedSampler(list(range(10)), sample_percentage=0.75) + train_loader, val_loader, _ = load_mnist_data(self.data_path, batch_size, sampler) + return train_loader, val_loader - self.train_loader, self.val_loader, self.num_examples = load_mnist_data(self.data_path, batch_size, sampler) - self.parameter_exchanger = FixedLayerExchanger(self.model.layers_to_exchange()) + def get_model(self, config: Config) -> nn.Module: + return APFLModule(MnistNetWithBnAndFrozen()).to(self.device) - super().setup_client(config) + def get_optimizer(self, config: Config) -> Optimizer: + return torch.optim.AdamW(self.model.parameters(), lr=0.01) + + def get_criterion(self, config: Config) -> _Loss: + return torch.nn.CrossEntropyLoss() if __name__ == "__main__": diff --git a/examples/apfl_example/server.py b/examples/apfl_example/server.py index de90a4172..3296e906d 100644 --- a/examples/apfl_example/server.py 
+++ b/examples/apfl_example/server.py @@ -22,6 +22,7 @@ def get_initial_model_parameters() -> Parameters: def fit_config(local_epochs: int, batch_size: int, n_server_rounds: int, current_round: int) -> Config: return { + "current_server_round": current_round, "local_epochs": local_epochs, "batch_size": batch_size, "n_server_rounds": n_server_rounds, diff --git a/fl4health/clients/apfl_client.py b/fl4health/clients/apfl_client.py index 4187660be..1d1e97aa8 100644 --- a/fl4health/clients/apfl_client.py +++ b/fl4health/clients/apfl_client.py @@ -1,255 +1,92 @@ -from logging import INFO +import copy from pathlib import Path -from typing import Dict, Sequence, Tuple +from typing import Optional, Sequence, Tuple import torch -from flwr.common.logger import log -from flwr.common.typing import Config, NDArrays, Scalar -from torch.nn.modules.loss import _Loss -from torch.utils.data import DataLoader +import torch.nn as nn +from flwr.common.typing import Config +from torch.optim import Optimizer -from fl4health.clients.numpy_fl_client import NumpyFlClient +from fl4health.checkpointing.checkpointer import TorchCheckpointer +from fl4health.clients.basic_client import BasicClient from fl4health.model_bases.apfl_base import APFLModule -from fl4health.utils.metrics import Metric, MetricAverageMeter, MetricMeter +from fl4health.parameter_exchange.layer_exchanger import FixedLayerExchanger +from fl4health.utils.losses import Losses, LossMeterType +from fl4health.utils.metrics import Metric, MetricMeterType -LocalLoss = torch.Tensor -GlobalLoss = torch.Tensor -PersonalLoss = torch.Tensor -LocalPreds = torch.Tensor -GlobalPreds = torch.Tensor -PersonalPreds = torch.Tensor - -ApflTrainStepOutputs = Tuple[LocalLoss, GlobalLoss, PersonalLoss, LocalPreds, GlobalPreds, PersonalPreds] - - -class ApflClient(NumpyFlClient): +class ApflClient(BasicClient): def __init__( self, data_path: Path, metrics: Sequence[Metric], device: torch.device, + loss_meter_type: LossMeterType = LossMeterType.AVERAGE, + metric_meter_type: MetricMeterType = MetricMeterType.AVERAGE, + use_wandb_reporter: bool = False, + checkpointer: Optional[TorchCheckpointer] = None, ) -> None: - super().__init__(data_path, device) - self.metrics = metrics - self.model: APFLModule - self.train_loader: DataLoader - self.val_loader: DataLoader - self.num_examples: Dict[str, int] - self.criterion: _Loss - self.local_optimizer: torch.optim.Optimizer - self.global_optimizer: torch.optim.Optimizer - - def is_start_of_local_training(self, epoch: int, step: int) -> bool: - return epoch == 0 and step == 0 - - def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, Dict[str, Scalar]]: - if not self.initialized: - self.setup_client(config) - - self.set_parameters(parameters, config) - local_epochs = self.narrow_config_type(config, "local_epochs", int) - - # Default APFL uses an average meter - global_meter = MetricAverageMeter(self.metrics, "global") - local_meter = MetricAverageMeter(self.metrics, "local") - personal_meter = MetricAverageMeter(self.metrics, "personal") - # By default the APFL client trains by epochs - metric_values = self.train_by_epochs(local_epochs, global_meter, local_meter, personal_meter) - # FitRes should contain local parameters, number of examples on client, and a dictionary holding metrics - # calculation results. 
- return ( - self.get_parameters(config), - self.num_examples["train_set"], - metric_values, - ) - - def evaluate(self, parameters: NDArrays, config: Config) -> Tuple[float, int, Dict[str, Scalar]]: - if not self.initialized: - self.setup_client(config) - - self.set_parameters(parameters, config) - # Default APFL uses an average meter - global_meter = MetricAverageMeter(self.metrics, "global") - local_meter = MetricAverageMeter(self.metrics, "local") - personal_meter = MetricAverageMeter(self.metrics, "personal") - loss, metric_values = self.validate(global_meter, local_meter, personal_meter) - # EvaluateRes should return the loss, number of examples on client, and a dictionary holding metrics - # calculation results. - return ( - loss, - self.num_examples["validation_set"], - metric_values, + super().__init__( + data_path, metrics, device, loss_meter_type, metric_meter_type, use_wandb_reporter, checkpointer ) - def _handle_logging( - self, loss_dict: Dict[str, float], metrics_dict: Dict[str, Scalar], is_validation: bool = False - ) -> None: - loss_string = "\t".join([f"{key}: {str(val)}" for key, val in loss_dict.items()]) - metric_string = "\t".join([f"{key}: {str(val)}" for key, val in metrics_dict.items()]) - metric_prefix = "Validation" if is_validation else "Training" - log( - INFO, - f"alpha: {self.model.alpha} \n" - f"Client {metric_prefix} Losses: {loss_string} \n" - f"Client {metric_prefix} Metrics: {metric_string}", - ) - - def train_step(self, input: torch.Tensor, target: torch.Tensor) -> ApflTrainStepOutputs: + self.model: APFLModule + self.local_model: nn.Module + self.local_optimizer: Optimizer + + def is_start_of_local_training(self, step: int) -> bool: + return step == 0 + + def update_after_step(self, step: int) -> None: + if self.is_start_of_local_training(step) and self.model.adaptive_alpha: + self.model.update_alpha() + + def split_optimizer(self, global_optimizer: Optimizer) -> Tuple[Optimizer, Optimizer]: + global_optimizer.param_groups.clear() + global_optimizer.state.clear() + local_optimizer = copy.deepcopy(global_optimizer) + + global_optimizer.add_param_group({"params": [p for p in self.model.global_model.parameters()]}) + local_optimizer.add_param_group({"params": [p for p in self.model.local_model.parameters()]}) + return global_optimizer, local_optimizer + + def setup_client(self, config: Config) -> None: + """ + Set dataloaders, optimizers, parameter exchangers and other attributes derived from these. 
+ """ + super().setup_client(config) + global_optimizer, local_optimizer = self.split_optimizer(self.optimizer) + self.optimizer = global_optimizer + self.local_optimizer = local_optimizer + + def train_step(self, input: torch.Tensor, target: torch.Tensor) -> Tuple[Losses, torch.Tensor]: # Mechanics of training loop follow from original implementation # https://github.com/MLOPTPSU/FedTorch/blob/main/fedtorch/comms/trainings/federated/apfl.py # Forward pass on global model and update global parameters - self.global_optimizer.zero_grad() - global_pred = self.model(input, personal=False)["global"] + self.optimizer.zero_grad() + global_pred = self.model.global_forward(input) global_loss = self.criterion(global_pred, target) global_loss.backward() - self.global_optimizer.step() + self.optimizer.step() # Make sure gradients are zero prior to foward passes of global and local model # to generate personalized predictions # NOTE: We zero the global optimizer grads because they are used (after the backward calculation below) # to update the scalar alpha (see update_alpha() where .grad is called.) - self.global_optimizer.zero_grad() + self.optimizer.zero_grad() self.local_optimizer.zero_grad() # Personal predictions are generated as a convex combination of the output # of local and global models - pred_dict = self.model(input, personal=True) - personal_pred, local_pred = pred_dict["personal"], pred_dict["local"] + personal_pred = self.predict(input) # Parameters of local model are updated to minimize loss of personalized model - personal_loss = self.criterion(personal_pred, target) - personal_loss.backward() + losses = self.compute_loss(personal_pred, target) + losses.backward.backward() self.local_optimizer.step() - with torch.no_grad(): - local_loss = self.criterion(local_pred, target) - - return local_loss, global_loss, personal_loss, local_pred, global_pred, personal_pred - - def train_by_steps( - self, steps: int, global_meter: MetricMeter, local_meter: MetricMeter, personal_meter: MetricMeter - ) -> Dict[str, Scalar]: - self.model.train() - loss_dict = {"personal": 0.0, "local": 0.0, "global": 0.0} - global_meter.clear() - local_meter.clear() - personal_meter.clear() - - train_iterator = iter(self.train_loader) - - for step in range(steps): - try: - input, target = next(train_iterator) - except StopIteration: - # StopIteration is thrown if dataset ends - # reinitialize data loader - train_iterator = iter(self.train_loader) - input, target = next(train_iterator) - - input, target = input.to(self.device), target.to(self.device) - local_loss, global_loss, personal_loss, local_preds, global_preds, personal_preds = self.train_step( - input, target - ) - - # Only update alpha if it is the first step of local training and adaptive alpha is true - if step == 0 and self.model.adaptive_alpha: - self.model.update_alpha() - - loss_dict["local"] += local_loss.item() - loss_dict["global"] += global_loss.item() - loss_dict["personal"] += personal_loss.item() - - global_meter.update(global_preds, target) - local_meter.update(local_preds, target) - personal_meter.update(personal_preds, target) - - loss_dict = {key: val / steps for key, val in loss_dict.items()} - global_metrics = global_meter.compute() - local_metrics = local_meter.compute() - personal_metrics = personal_meter.compute() - metrics: Dict[str, Scalar] = {**global_metrics, **local_metrics, **personal_metrics} - log(INFO, f"Performed {steps} Steps of Local training") - self._handle_logging(loss_dict, metrics) - - # return final training metrics - 
return metrics - - def train_by_epochs( - self, epochs: int, global_meter: MetricMeter, local_meter: MetricMeter, personal_meter: MetricMeter - ) -> Dict[str, Scalar]: - self.model.train() - for epoch in range(epochs): - loss_dict = {"personal": 0.0, "local": 0.0, "global": 0.0} - global_meter.clear() - local_meter.clear() - personal_meter.clear() - - for step, (input, target) in enumerate(self.train_loader): - input, target = input.to(self.device), target.to(self.device) - local_loss, global_loss, personal_loss, local_preds, global_preds, personal_preds = self.train_step( - input, target - ) - - # Only update alpha if it is the first epoch and first step of training - # and adaptive alpha is true - if self.is_start_of_local_training(epoch, step) and self.model.adaptive_alpha: - self.model.update_alpha() - - loss_dict["local"] += local_loss.item() - loss_dict["global"] += global_loss.item() - loss_dict["personal"] += personal_loss.item() - - global_meter.update(global_preds, target) - local_meter.update(local_preds, target) - personal_meter.update(personal_preds, target) - - loss_dict = {key: val / len(self.train_loader) for key, val in loss_dict.items()} - - global_metrics = global_meter.compute() - local_metrics = local_meter.compute() - personal_metrics = personal_meter.compute() - metrics: Dict[str, Scalar] = {**global_metrics, **local_metrics, **personal_metrics} - log(INFO, f"Performed {epochs} Epochs of Local training") - self._handle_logging(loss_dict, metrics) - - return metrics - - def validate( - self, global_meter: MetricMeter, local_meter: MetricMeter, personal_meter: MetricMeter - ) -> Tuple[float, Dict[str, Scalar]]: - self.model.eval() - loss_dict = {"global": 0.0, "personal": 0.0, "local": 0.0} - global_meter.clear() - local_meter.clear() - personal_meter.clear() - - with torch.no_grad(): - for input, target in self.val_loader: - input, target = input.to(self.device), target.to(self.device) - - global_pred = self.model(input, personal=False)["global"] - global_loss = self.criterion(global_pred, target) - - pred_dict = self.model(input, personal=True) - personal_pred, local_pred = pred_dict["personal"], pred_dict["local"] - personal_loss = self.criterion(personal_pred, target) - local_loss = self.criterion(local_pred, target) - - loss_dict["global"] += global_loss.item() - loss_dict["personal"] += personal_loss.item() - loss_dict["local"] += local_loss.item() - - global_meter.update(global_pred, target) - local_meter.update(local_pred, target) - personal_meter.update(personal_pred, target) - - loss_dict = {key: val / len(self.val_loader) for key, val in loss_dict.items()} - global_metrics = global_meter.compute() - local_metrics = local_meter.compute() - personal_metrics = personal_meter.compute() - metrics: Dict[str, Scalar] = {**global_metrics, **local_metrics, **personal_metrics} - self._handle_logging(loss_dict, metrics, is_validation=True) - self._maybe_checkpoint(loss_dict["personal"]) - return loss_dict["personal"], metrics + return losses, personal_pred + + def get_parameter_exchanger(self, config: Config) -> FixedLayerExchanger: + return FixedLayerExchanger(self.model.layers_to_exchange()) diff --git a/fl4health/clients/basic_client.py b/fl4health/clients/basic_client.py index 7e0f0dbff..baac79d38 100644 --- a/fl4health/clients/basic_client.py +++ b/fl4health/clients/basic_client.py @@ -178,11 +178,12 @@ def train_by_epochs( for local_epoch in range(epochs): self.train_metric_meter.clear() self.train_loss_meter.clear() - for input, target in self.train_loader: + 
for step, (input, target) in enumerate(self.train_loader): input, target = input.to(self.device), target.to(self.device) losses, preds = self.train_step(input, target) self.train_loss_meter.update(losses) self.train_metric_meter.update(preds, target) + self.update_after_step(step) metrics = self.train_metric_meter.compute() losses = self.train_loss_meter.compute() loss_dict = losses.as_dict() @@ -202,7 +203,7 @@ def train_by_steps( self.train_loss_meter.clear() self.train_metric_meter.clear() - for _ in range(steps): + for step in range(steps): try: input, target = next(train_iterator) except StopIteration: @@ -215,6 +216,7 @@ def train_by_steps( losses, preds = self.train_step(input, target) self.train_loss_meter.update(losses) self.train_metric_meter.update(preds, target) + self.update_after_step(step) losses = self.train_loss_meter.compute() loss_dict = losses.as_dict() @@ -327,3 +329,6 @@ def get_model(self, config: Config) -> nn.Module: def update_after_train(self, local_steps: int) -> None: pass + + def update_after_step(self, step: int) -> None: + pass diff --git a/fl4health/model_bases/apfl_base.py b/fl4health/model_bases/apfl_base.py index dfdc69fdd..67ab8087d 100644 --- a/fl4health/model_bases/apfl_base.py +++ b/fl4health/model_bases/apfl_base.py @@ -1,5 +1,5 @@ import copy -from typing import Dict, List +from typing import List import torch import torch.nn as nn @@ -21,17 +21,18 @@ def __init__( self.alpha = alpha self.alpha_lr = alpha_lr - def forward(self, input: torch.Tensor, personal: bool) -> Dict[str, torch.Tensor]: - if not personal: - global_logits = self.global_model(input) - return {"global": global_logits} + def global_forward(self, input: torch.Tensor) -> torch.Tensor: + return self.global_model(input) - global_logits = self.global_model(input) - local_logits = self.local_model(input) + def local_forward(self, input: torch.Tensor) -> torch.Tensor: + return self.local_model(input) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + global_logits = self.global_forward(input) + local_logits = self.local_forward(input) personal_logits = self.alpha * local_logits + (1.0 - self.alpha) * global_logits - results = {"personal": personal_logits, "local": local_logits} - return results + return personal_logits def update_alpha(self) -> None: # Updates to mixture parameter follow original implementation diff --git a/research/flamby/fed_heart_disease/apfl/client.py b/research/flamby/fed_heart_disease/apfl/client.py index 4deaf45cc..c26f7df15 100644 --- a/research/flamby/fed_heart_disease/apfl/client.py +++ b/research/flamby/fed_heart_disease/apfl/client.py @@ -1,57 +1,72 @@ import argparse +import os from logging import INFO -from typing import Sequence +from pathlib import Path +from typing import Optional, Sequence, Tuple import flwr as fl import torch from flamby.datasets.fed_heart_disease import BATCH_SIZE, LR, NUM_CLIENTS, Baseline, BaselineLoss from flwr.common.logger import log from flwr.common.typing import Config +from torch.nn.modules.loss import _Loss +from torch.optim import Optimizer from torch.utils.data import DataLoader +from fl4health.checkpointing.checkpointer import BestMetricTorchCheckpointer, TorchCheckpointer +from fl4health.clients.apfl_client import ApflClient from fl4health.model_bases.apfl_base import APFLModule -from fl4health.parameter_exchange.layer_exchanger import FixedLayerExchanger -from fl4health.utils.metrics import Accuracy, Metric -from research.flamby.flamby_clients.flamby_apfl_client import FlambyApflClient +from 
fl4health.utils.losses import LossMeterType +from fl4health.utils.metrics import Accuracy, Metric, MetricMeterType from research.flamby.flamby_data_utils import construct_fed_heard_disease_train_val_datasets -class FedHeartDiseaseApflClient(FlambyApflClient): +class FedHeartDiseaseApflClient(ApflClient): def __init__( self, - learning_rate: float, - alpha_learning_rate: float, + data_path: Path, metrics: Sequence[Metric], device: torch.device, client_number: int, - checkpoint_stub: str, - dataset_dir: str, - run_name: str = "", + learning_rate: float, + alpha_learning_rate: float, + loss_meter_type: LossMeterType = LossMeterType.AVERAGE, + metric_meter_type: MetricMeterType = MetricMeterType.ACCUMULATION, + checkpointer: Optional[TorchCheckpointer] = None, + use_wandb_reporter: bool = False, ) -> None: - assert 0 <= client_number < NUM_CLIENTS super().__init__( - learning_rate, alpha_learning_rate, metrics, device, client_number, checkpoint_stub, dataset_dir, run_name + data_path=data_path, + metrics=metrics, + device=device, + loss_meter_type=loss_meter_type, + metric_meter_type=metric_meter_type, + use_wandb_reporter=use_wandb_reporter, + checkpointer=checkpointer, ) + assert 0 <= client_number < NUM_CLIENTS - def setup_client(self, config: Config) -> None: + self.learning_rate = learning_rate + self.alpha_learning_rate = alpha_learning_rate + self.client_number = client_number + + def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: train_dataset, validation_dataset = construct_fed_heard_disease_train_val_datasets( - self.client_number, self.dataset_dir + self.client_number, str(self.data_path) ) + train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True) + val_loader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=False) + return train_loader, val_loader - self.train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True) - self.val_loader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=False) - - self.num_examples = {"train_set": len(train_dataset), "validation_set": len(validation_dataset)} - - self.criterion = BaselineLoss() + def get_model(self, config: Config) -> APFLModule: + model: APFLModule = APFLModule(Baseline(), alpha_lr=self.alpha_learning_rate).to(self.device) + return model - self.model: APFLModule = APFLModule(Baseline(), alpha_lr=self.alpha_learning_rate).to(self.device) - self.local_optimizer = torch.optim.AdamW(self.model.local_model.parameters(), lr=self.learning_rate) - self.global_optimizer = torch.optim.AdamW(self.model.global_model.parameters(), lr=self.learning_rate) + def get_optimizer(self, config: Config) -> Optimizer: + return torch.optim.AdamW(self.model.parameters(), lr=self.learning_rate) - self.parameter_exchanger = FixedLayerExchanger(self.model.layers_to_exchange()) - - super().setup_client(config) + def get_criterion(self, config: Config) -> _Loss: + return BaselineLoss() if __name__ == "__main__": @@ -104,15 +119,18 @@ def setup_client(self, config: Config) -> None: log(INFO, f"Learning Rate: {args.learning_rate}") log(INFO, f"Alpha Learning Rate: {args.alpha_learning_rate}") + checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) + checkpoint_name = f"client_{args.client_number}_best_model.pkl" + checkpointer = BestMetricTorchCheckpointer(checkpoint_dir, checkpoint_name, maximize=False) + client = FedHeartDiseaseApflClient( - args.learning_rate, - args.alpha_learning_rate, - [Accuracy("FedHeartDisease_accuracy")], - DEVICE, - args.client_number, - 
args.artifact_dir, - args.dataset_dir, - args.run_name, + data_path=args.dataset_dir, + metrics=[Accuracy("FedHeartDisease_accuracy")], + device=DEVICE, + client_number=args.client_number, + learning_rate=args.learning_rate, + alpha_learning_rate=args.alpha_learning_rate, + checkpointer=checkpointer, ) fl.client.start_numpy_client(server_address=args.server_address, client=client) diff --git a/research/flamby/fed_isic2019/apfl/client.py b/research/flamby/fed_isic2019/apfl/client.py index 65ebcef87..545810377 100644 --- a/research/flamby/fed_isic2019/apfl/client.py +++ b/research/flamby/fed_isic2019/apfl/client.py @@ -1,60 +1,76 @@ import argparse +import os from logging import INFO -from typing import Sequence +from pathlib import Path +from typing import Optional, Sequence, Tuple import flwr as fl import torch +import torch.nn as nn from flamby.datasets.fed_isic2019 import BATCH_SIZE, LR, NUM_CLIENTS, BaselineLoss from flwr.common.logger import log from flwr.common.typing import Config +from torch.nn.modules.loss import _Loss +from torch.optim import Optimizer from torch.utils.data import DataLoader +from fl4health.checkpointing.checkpointer import BestMetricTorchCheckpointer, TorchCheckpointer +from fl4health.clients.apfl_client import ApflClient from fl4health.model_bases.apfl_base import APFLModule -from fl4health.parameter_exchange.layer_exchanger import FixedLayerExchanger -from fl4health.utils.metrics import BalancedAccuracy, Metric +from fl4health.utils.losses import LossMeterType +from fl4health.utils.metrics import BalancedAccuracy, Metric, MetricMeterType from research.flamby.fed_isic2019.apfl.apfl_model import APFLEfficientNet -from research.flamby.flamby_clients.flamby_apfl_client import FlambyApflClient from research.flamby.flamby_data_utils import construct_fedisic_train_val_datasets -class FedIsic2019ApflClient(FlambyApflClient): +class FedIsic2019ApflClient(ApflClient): def __init__( self, - learning_rate: float, - alpha_learning_rate: float, + data_path: Path, metrics: Sequence[Metric], device: torch.device, client_number: int, - checkpoint_stub: str, - dataset_dir: str, - run_name: str = "", + learning_rate: float, + alpha_learning_rate: float, + loss_meter_type: LossMeterType = LossMeterType.AVERAGE, + metric_meter_type: MetricMeterType = MetricMeterType.ACCUMULATION, + checkpointer: Optional[TorchCheckpointer] = None, + use_wandb_reporter: bool = False, ) -> None: - assert 0 <= client_number < NUM_CLIENTS super().__init__( - learning_rate, alpha_learning_rate, metrics, device, client_number, checkpoint_stub, dataset_dir, run_name + data_path=data_path, + metrics=metrics, + device=device, + loss_meter_type=loss_meter_type, + metric_meter_type=metric_meter_type, + use_wandb_reporter=use_wandb_reporter, + checkpointer=checkpointer, ) + assert 0 <= client_number < NUM_CLIENTS - def setup_client(self, config: Config) -> None: - train_dataset, validation_dataset = construct_fedisic_train_val_datasets(self.client_number, self.dataset_dir) - - self.train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True) - self.val_loader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=False) + self.learning_rate = learning_rate + self.alpha_learning_rate = alpha_learning_rate + self.client_number = client_number - self.num_examples = {"train_set": len(train_dataset), "validation_set": len(validation_dataset)} + def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + train_dataset, validation_dataset = 
construct_fedisic_train_val_datasets( + self.client_number, str(self.data_path) + ) + train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True) + val_loader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=False) + return train_loader, val_loader - # NOTE: The class weights specified by alpha in this baseline loss are precomputed based on the weights of - # the pool dataset. This is a bit of cheating but FLamby does it in their paper. - self.criterion = BaselineLoss() + def get_criterion(self, config: Config) -> _Loss: + return BaselineLoss() - self.model: APFLModule = APFLModule( + def get_model(self, config: Config) -> nn.Module: + model: APFLModule = APFLModule( APFLEfficientNet(frozen_blocks=13, turn_off_bn_tracking=False), alpha_lr=self.alpha_learning_rate ).to(self.device) - self.local_optimizer = torch.optim.AdamW(self.model.local_model.parameters(), lr=self.learning_rate) - self.global_optimizer = torch.optim.AdamW(self.model.global_model.parameters(), lr=self.learning_rate) + return model - self.parameter_exchanger = FixedLayerExchanger(self.model.layers_to_exchange()) - - super().setup_client(config) + def get_optimizer(self, config: Config) -> Optimizer: + return torch.optim.AdamW(self.model.parameters(), lr=self.learning_rate) if __name__ == "__main__": @@ -107,15 +123,18 @@ def setup_client(self, config: Config) -> None: log(INFO, f"Learning Rate: {args.learning_rate}") log(INFO, f"Alpha Learning Rate: {args.alpha_learning_rate}") + checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) + checkpoint_name = f"client_{args.client_number}_best_model.pkl" + checkpointer = BestMetricTorchCheckpointer(checkpoint_dir, checkpoint_name, maximize=False) + client = FedIsic2019ApflClient( - args.learning_rate, - args.alpha_learning_rate, - [BalancedAccuracy("FedIsic2019_balanced_accuracy")], - DEVICE, - args.client_number, - args.artifact_dir, - args.dataset_dir, - args.run_name, + data_path=Path(args.dataset_dir), + metrics=[BalancedAccuracy("FedIsic2019_balanced_accuracy")], + device=DEVICE, + client_number=args.client_number, + learning_rate=args.learning_rate, + alpha_learning_rate=args.alpha_learning_rate, + checkpointer=checkpointer, ) fl.client.start_numpy_client(server_address=args.server_address, client=client) diff --git a/research/flamby/fed_ixi/apfl/client.py b/research/flamby/fed_ixi/apfl/client.py index 078a2f140..5e4136e11 100644 --- a/research/flamby/fed_ixi/apfl/client.py +++ b/research/flamby/fed_ixi/apfl/client.py @@ -1,56 +1,74 @@ import argparse +import os from logging import INFO -from typing import Sequence +from pathlib import Path +from typing import Optional, Sequence, Tuple import flwr as fl import torch +import torch.nn as nn from flamby.datasets.fed_ixi import BATCH_SIZE, LR, NUM_CLIENTS, BaselineLoss from flwr.common.logger import log from flwr.common.typing import Config +from torch.nn.modules.loss import _Loss +from torch.optim import Optimizer from torch.utils.data import DataLoader +from fl4health.checkpointing.checkpointer import BestMetricTorchCheckpointer, TorchCheckpointer +from fl4health.clients.apfl_client import ApflClient from fl4health.model_bases.apfl_base import APFLModule -from fl4health.parameter_exchange.layer_exchanger import FixedLayerExchanger -from fl4health.utils.metrics import BinarySoftDiceCoefficient, Metric +from fl4health.utils.losses import LossMeterType +from fl4health.utils.metrics import BinarySoftDiceCoefficient, Metric, MetricMeterType from research.flamby.fed_ixi.apfl.apfl_model 
import APFLUNet -from research.flamby.flamby_clients.flamby_apfl_client import FlambyApflClient from research.flamby.flamby_data_utils import construct_fed_ixi_train_val_datasets -class FedIxiApflClient(FlambyApflClient): +class FedIxiApflClient(ApflClient): def __init__( self, - learning_rate: float, - alpha_learning_rate: float, + data_path: Path, metrics: Sequence[Metric], device: torch.device, client_number: int, - checkpoint_stub: str, - dataset_dir: str, - run_name: str = "", + learning_rate: float, + alpha_learning_rate: float, + loss_meter_type: LossMeterType = LossMeterType.AVERAGE, + metric_meter_type: MetricMeterType = MetricMeterType.ACCUMULATION, + checkpointer: Optional[TorchCheckpointer] = None, + use_wandb_reporter: bool = False, ) -> None: - assert 0 <= client_number < NUM_CLIENTS super().__init__( - learning_rate, alpha_learning_rate, metrics, device, client_number, checkpoint_stub, dataset_dir, run_name + data_path=data_path, + metrics=metrics, + device=device, + loss_meter_type=loss_meter_type, + metric_meter_type=metric_meter_type, + use_wandb_reporter=use_wandb_reporter, + checkpointer=checkpointer, ) + assert 0 <= client_number < NUM_CLIENTS - def setup_client(self, config: Config) -> None: - train_dataset, validation_dataset = construct_fed_ixi_train_val_datasets(self.client_number, self.dataset_dir) - - self.train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True) - self.val_loader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=False) - - self.num_examples = {"train_set": len(train_dataset), "validation_set": len(validation_dataset)} + self.learning_rate = learning_rate + self.alpha_learning_rate = alpha_learning_rate + self.client_number = client_number - self.criterion = BaselineLoss() + def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + train_dataset, validation_dataset = construct_fed_ixi_train_val_datasets( + self.client_number, str(self.data_path) + ) + train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True) + val_loader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=False) + return train_loader, val_loader - self.model: APFLModule = APFLModule(APFLUNet(), alpha_lr=self.alpha_learning_rate).to(self.device) - self.local_optimizer = torch.optim.AdamW(self.model.local_model.parameters(), lr=self.learning_rate) - self.global_optimizer = torch.optim.AdamW(self.model.global_model.parameters(), lr=self.learning_rate) + def get_model(self, config: Config) -> nn.Module: + model: APFLModule = APFLModule(APFLUNet(), alpha_lr=self.alpha_learning_rate).to(self.device) + return model - self.parameter_exchanger = FixedLayerExchanger(self.model.layers_to_exchange()) + def get_optimizer(self, config: Config) -> Optimizer: + return torch.optim.AdamW(self.model.local_model.parameters(), lr=self.learning_rate) - super().setup_client(config) + def get_criterion(self, config: Config) -> _Loss: + return BaselineLoss() if __name__ == "__main__": @@ -103,15 +121,18 @@ def setup_client(self, config: Config) -> None: log(INFO, f"Learning Rate: {args.learning_rate}") log(INFO, f"Alpha Learning Rate: {args.alpha_learning_rate}") + checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) + checkpoint_name = f"client_{args.client_number}_best_model.pkl" + checkpointer = BestMetricTorchCheckpointer(checkpoint_dir, checkpoint_name, maximize=False) + client = FedIxiApflClient( - args.learning_rate, - args.alpha_learning_rate, - [BinarySoftDiceCoefficient("FedIXI_dice")], - DEVICE, -
args.client_number, - args.artifact_dir, - args.dataset_dir, - args.run_name, + data_path=Path(args.dataset_dir), + metrics=[BinarySoftDiceCoefficient("FedIXI_dice")], + device=DEVICE, + client_number=args.client_number, + learning_rate=args.learning_rate, + alpha_learning_rate=args.alpha_learning_rate, + checkpointer=checkpointer, ) fl.client.start_numpy_client(server_address=args.server_address, client=client) diff --git a/research/flamby/flamby_clients/__init__.py b/research/flamby/flamby_clients/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/research/flamby/flamby_clients/flamby_apfl_client.py b/research/flamby/flamby_clients/flamby_apfl_client.py deleted file mode 100644 index d856ba34f..000000000 --- a/research/flamby/flamby_clients/flamby_apfl_client.py +++ /dev/null @@ -1,70 +0,0 @@ -import os -from logging import INFO -from pathlib import Path -from typing import Dict, Sequence, Tuple - -import torch -from flwr.common.logger import log -from flwr.common.typing import Config, NDArrays, Scalar - -from fl4health.checkpointing.checkpointer import BestMetricTorchCheckpointer -from fl4health.clients.apfl_client import ApflClient -from fl4health.utils.metrics import Metric, MetricAccumulationMeter - - -class FlambyApflClient(ApflClient): - def __init__( - self, - learning_rate: float, - alpha_learning_rate: float, - metrics: Sequence[Metric], - device: torch.device, - client_number: int, - checkpoint_stub: str, - dataset_dir: str, - run_name: str = "", - ) -> None: - super().__init__(data_path=Path(""), metrics=metrics, device=device) - self.client_number = client_number - log(INFO, f"Client Name: {self.client_name}, Client Number: {self.client_number}") - checkpoint_dir = os.path.join(checkpoint_stub, run_name) - checkpoint_name = f"client_{self.client_number}_best_model.pkl" - self.learning_rate = learning_rate - self.alpha_learning_rate = alpha_learning_rate - self.checkpointer = BestMetricTorchCheckpointer(checkpoint_dir, checkpoint_name, maximize=False) - self.dataset_dir = dataset_dir - - def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, Dict[str, Scalar]]: - if not self.initialized: - self.setup_client(config) - - global_meter = MetricAccumulationMeter(self.metrics, "train_global") - local_meter = MetricAccumulationMeter(self.metrics, "train_local") - personal_meter = MetricAccumulationMeter(self.metrics, "train_personal") - self.set_parameters(parameters, config) - local_steps = self.narrow_config_type(config, "local_steps", int) - metric_values = self.train_by_steps(local_steps, global_meter, local_meter, personal_meter) - # FitRes should contain local parameters, number of examples on client, and a dictionary holding metrics - # calculation results. - return ( - self.get_parameters(config), - self.num_examples["train_set"], - metric_values, - ) - - def evaluate(self, parameters: NDArrays, config: Config) -> Tuple[float, int, Dict[str, Scalar]]: - if not self.initialized: - self.setup_client(config) - - self.set_parameters(parameters, config) - global_meter = MetricAccumulationMeter(self.metrics, "val_global") - local_meter = MetricAccumulationMeter(self.metrics, "val_local") - personal_meter = MetricAccumulationMeter(self.metrics, "val_personal") - loss, metric_values = self.validate(global_meter, local_meter, personal_meter) - # EvaluateRes should return the loss, number of examples on client, and a dictionary holding metrics - # calculation results. 
- return ( - loss, - self.num_examples["validation_set"], - metric_values, - ) diff --git a/tests/parameter_exchange/test_apfl_exchange.py b/tests/parameter_exchange/test_apfl_exchange.py index 4cd9c2a1a..886daf7f4 100644 --- a/tests/parameter_exchange/test_apfl_exchange.py +++ b/tests/parameter_exchange/test_apfl_exchange.py @@ -46,20 +46,23 @@ def test_apfl_layer_exchange() -> None: assert np.array_equal(layer_parameters, model_state_dict[layer_name]) input = torch.ones((3, 1, 10, 10)) - # APFL returns a dictionary of tensors. In the case of personal predictions, it produces a convex combination of - # the dual toy model outputs, which have dimension 3 under the key personal and a prediction from the local model - # under the key local - apfl_output_dict = model(input, personal=True) - assert "local" in apfl_output_dict - personal_shape = apfl_output_dict["personal"].shape + # APFL returns the personal prediction which are a combination of the logits of local and global models + personal_shape = model(input).shape # Batch size assert personal_shape[0] == 3 # Output size assert personal_shape[1] == 3 - # Make sure that the APFL module still correctly functions when making predictions using only the global model. It - # should produce a dictionary with key "global" - global_shape = model(input, personal=False)["global"].shape + + # We can get the global preds with the global_forward method + global_shape = model.global_forward(input).shape # Batch size assert global_shape[0] == 3 # Output size assert global_shape[1] == 3 + + # We can get the local preds with the local_forward method + local_shape = model.local_forward(input).shape + # Batch size + assert local_shape[0] == 3 + # Output size + assert local_shape[1] == 3 From 41f4d7d12a656b8aa36b86a62e549bb12d7fa7c0 Mon Sep 17 00:00:00 2001 From: John Jewell Date: Fri, 29 Sep 2023 11:36:49 -0400 Subject: [PATCH 02/12] Add comments. Add tests to make sure optimizers are being set properly. Make sure update_after_step is being passed the actual_step not just local step in each epoch. --- fl4health/clients/apfl_client.py | 15 ++++++++++++--- fl4health/clients/basic_client.py | 3 ++- tests/clients/fixtures.py | 3 +++ tests/clients/test_apfl_client.py | 30 ++++++++++++++++++++++++++++++ 4 files changed, 47 insertions(+), 4 deletions(-) create mode 100644 tests/clients/test_apfl_client.py diff --git a/fl4health/clients/apfl_client.py b/fl4health/clients/apfl_client.py index 1d1e97aa8..98ca0a360 100644 --- a/fl4health/clients/apfl_client.py +++ b/fl4health/clients/apfl_client.py @@ -3,7 +3,6 @@ from typing import Optional, Sequence, Tuple import torch -import torch.nn as nn from flwr.common.typing import Config from torch.optim import Optimizer @@ -29,9 +28,12 @@ def __init__( super().__init__( data_path, metrics, device, loss_meter_type, metric_meter_type, use_wandb_reporter, checkpointer ) - + # Apfl Module which holds both local and global models + # and gives the ability to get personal, local and global predictions self.model: APFLModule - self.local_model: nn.Module + + # local_optimizer is used on the local model + # Usual self.optimizer is used for global model self.local_optimizer: Optimizer def is_start_of_local_training(self, step: int) -> bool: @@ -42,6 +44,10 @@ def update_after_step(self, step: int) -> None: self.model.update_alpha() def split_optimizer(self, global_optimizer: Optimizer) -> Tuple[Optimizer, Optimizer]: + """ + The optimizer from get_optimizer is for the entire APFLModule. 
We need one optimizer + for the local model and one optimizer for the global model. + """ global_optimizer.param_groups.clear() global_optimizer.state.clear() local_optimizer = copy.deepcopy(global_optimizer) @@ -55,6 +61,9 @@ def setup_client(self, config: Config) -> None: Set dataloaders, optimizers, parameter exchangers and other attributes derived from these. """ super().setup_client(config) + + # Split optimizer from get_optimizer into two distinct optimizers + # One for local model and one for global model global_optimizer, local_optimizer = self.split_optimizer(self.optimizer) self.optimizer = global_optimizer self.local_optimizer = local_optimizer diff --git a/fl4health/clients/basic_client.py b/fl4health/clients/basic_client.py index baac79d38..55076d47c 100644 --- a/fl4health/clients/basic_client.py +++ b/fl4health/clients/basic_client.py @@ -183,7 +183,8 @@ def train_by_epochs( losses, preds = self.train_step(input, target) self.train_loss_meter.update(losses) self.train_metric_meter.update(preds, target) - self.update_after_step(step) + actual_step = int(local_epoch * len(self.train_loader) + step) + self.update_after_step(actual_step) metrics = self.train_metric_meter.compute() losses = self.train_loss_meter.compute() loss_dict = losses.as_dict() diff --git a/tests/clients/fixtures.py b/tests/clients/fixtures.py index 26df857ce..cafc8763b 100644 --- a/tests/clients/fixtures.py +++ b/tests/clients/fixtures.py @@ -5,6 +5,7 @@ import torch.nn as nn from torch.utils.data import DataLoader, TensorDataset +from fl4health.clients.apfl_client import ApflClient from fl4health.clients.evaluate_client import EvaluateClient from fl4health.clients.fed_prox_client import FedProxClient from fl4health.clients.instance_level_privacy_client import InstanceLevelPrivacyClient @@ -34,6 +35,8 @@ def get_client(type: type, model: nn.Module) -> NumpyFlClient: client = DPScaffoldClient(data_path=Path(""), metrics=[Accuracy()], device=torch.device("cpu")) client.noise_multiplier = 1.0 client.clipping_bound = 5.0 + elif type == ApflClient: + client = ApflClient(data_path=Path(""), metrics=[Accuracy()], device=torch.device("cpu")) else: raise ValueError(f"{str(type)} is not a valid client type") diff --git a/tests/clients/test_apfl_client.py b/tests/clients/test_apfl_client.py new file mode 100644 index 000000000..f152ba855 --- /dev/null +++ b/tests/clients/test_apfl_client.py @@ -0,0 +1,30 @@ +import torch +import pytest +from fl4health.clients.apfl_client import ApflClient +from fl4health.model_bases.apfl_base import APFLModule +from tests.test_utils.models_for_test import SmallCnn +from tests.clients.fixtures import get_client + +@pytest.mark.parametrize("type,model", [(ApflClient, APFLModule(SmallCnn()))]) +def test_split_optimizer(get_client: ApflClient) -> None: + apfl_client = get_client + + global_optimizer, local_optimizer = apfl_client.split_optimizer(apfl_client.optimizer) + + # Check that global_optimizer and local_optimizer dont reference the same object + assert global_optimizer is not local_optimizer + + # Check that the param_groups are equivalent since the local and global models are exact copies + # at the start + global_param_groups = global_optimizer.param_groups + local_param_groups = local_optimizer.param_groups + for global_group, local_group in zip(global_param_groups, local_param_groups): + for (global_key, global_vals), (local_key, local_vals) in zip(global_group.items(), local_group.items()): + assert local_key == global_key + assert type(local_vals) == type(global_vals) + # 
Either Parameter Group or float representing lr + if isinstance(global_vals, list): + for global_val, local_val in zip(global_vals, local_vals): + assert torch.equal(global_val, local_val) + else: + assert global_vals == local_vals \ No newline at end of file From 059fbe4655da78af05f36b1b6a20098f51495aec Mon Sep 17 00:00:00 2001 From: John Jewell Date: Fri, 29 Sep 2023 12:13:45 -0400 Subject: [PATCH 03/12] fix pre-commit issues --- tests/clients/test_apfl_client.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/tests/clients/test_apfl_client.py b/tests/clients/test_apfl_client.py index f152ba855..266eef70c 100644 --- a/tests/clients/test_apfl_client.py +++ b/tests/clients/test_apfl_client.py @@ -1,30 +1,32 @@ -import torch import pytest +import torch + from fl4health.clients.apfl_client import ApflClient from fl4health.model_bases.apfl_base import APFLModule +from tests.clients.fixtures import get_client # noqa from tests.test_utils.models_for_test import SmallCnn -from tests.clients.fixtures import get_client + @pytest.mark.parametrize("type,model", [(ApflClient, APFLModule(SmallCnn()))]) -def test_split_optimizer(get_client: ApflClient) -> None: +def test_split_optimizer(get_client: ApflClient) -> None: # noqa apfl_client = get_client global_optimizer, local_optimizer = apfl_client.split_optimizer(apfl_client.optimizer) - - # Check that global_optimizer and local_optimizer dont reference the same object + + # Check that global_optimizer and local_optimizer dont reference the same object assert global_optimizer is not local_optimizer - + # Check that the param_groups are equivalent since the local and global models are exact copies - # at the start + # at the start global_param_groups = global_optimizer.param_groups local_param_groups = local_optimizer.param_groups for global_group, local_group in zip(global_param_groups, local_param_groups): for (global_key, global_vals), (local_key, local_vals) in zip(global_group.items(), local_group.items()): assert local_key == global_key assert type(local_vals) == type(global_vals) - # Either Parameter Group or float representing lr + # Either Parameter Group or float representing lr if isinstance(global_vals, list): - for global_val, local_val in zip(global_vals, local_vals): + for global_val, local_val in zip(global_vals, local_vals): assert torch.equal(global_val, local_val) - else: - assert global_vals == local_vals \ No newline at end of file + else: + assert global_vals == local_vals From d50ae1eb23a57673481fd7322694a0d0d11ed0b9 Mon Sep 17 00:00:00 2001 From: John Jewell Date: Thu, 12 Oct 2023 14:51:48 -0400 Subject: [PATCH 04/12] Remove the need to define fedprox fit function --- fl4health/clients/basic_client.py | 8 +++---- fl4health/clients/fed_prox_client.py | 33 +++++----------------------- fl4health/clients/scaffold_client.py | 4 ++-- 3 files changed, 12 insertions(+), 33 deletions(-) diff --git a/fl4health/clients/basic_client.py b/fl4health/clients/basic_client.py index 55076d47c..2e495e980 100644 --- a/fl4health/clients/basic_client.py +++ b/fl4health/clients/basic_client.py @@ -87,15 +87,15 @@ def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, Dict self.set_parameters(parameters, config) if local_epochs is not None: - _, metrics = self.train_by_epochs(local_epochs, current_server_round) + loss_dict, metrics = self.train_by_epochs(local_epochs, current_server_round) local_steps = self.num_train_samples * local_epochs # total steps over training round elif 
local_steps is not None: - _, metrics = self.train_by_steps(local_steps, current_server_round) + loss_dict, metrics = self.train_by_steps(local_steps, current_server_round) else: raise ValueError("Must specify either local_epochs or local_steps in the Config.") # Update after train round (Used by Scaffold and DP-Scaffold Client to update control variates) - self.update_after_train(local_steps) + self.update_after_train(local_steps, loss_dict) # FitRes should contain local parameters, number of examples on client, and a dictionary holding metrics # calculation results. @@ -328,7 +328,7 @@ def get_model(self, config: Config) -> nn.Module: """ raise NotImplementedError - def update_after_train(self, local_steps: int) -> None: + def update_after_train(self, local_steps: int, loss_dict: Dict[str, float]) -> None: pass def update_after_step(self, step: int) -> None: diff --git a/fl4health/clients/fed_prox_client.py b/fl4health/clients/fed_prox_client.py index 0be0b2ebd..6713b355f 100644 --- a/fl4health/clients/fed_prox_client.py +++ b/fl4health/clients/fed_prox_client.py @@ -1,8 +1,8 @@ from pathlib import Path -from typing import Dict, List, Optional, Sequence, Tuple +from typing import Dict, List, Optional, Sequence import torch -from flwr.common.typing import Config, NDArrays, Scalar +from flwr.common.typing import Config, NDArrays from fl4health.checkpointing.checkpointer import TorchCheckpointer from fl4health.clients.basic_client import BasicClient @@ -91,31 +91,6 @@ def set_parameters(self, parameters: NDArrays, config: Config) -> None: initial_layer_weights.detach().clone() for initial_layer_weights in self.model.parameters() ] - def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, Dict[str, Scalar]]: - local_epochs, local_steps, current_server_round = self.process_config(config) - - if not self.initialized: - self.setup_client(config) - - self.set_parameters(parameters, config) - - if local_epochs is not None: - loss_dict, metrics = self.train_by_epochs(local_epochs, current_server_round) - else: - assert isinstance(local_steps, int) - loss_dict, metrics = self.train_by_steps(local_steps, current_server_round) - - # Store current loss which is the vanilla loss without the proximal term added in - self.current_loss = loss_dict["checkpoint"] - - # FitRes should contain local parameters, number of examples on client, and a dictionary holding metrics - # calculation results. 
- return ( - self.get_parameters(config), - self.num_train_samples, - metrics, - ) - def compute_loss(self, preds: torch.Tensor, target: torch.Tensor) -> Losses: loss = self.criterion(preds, target) proximal_loss = self.get_proximal_loss() @@ -125,3 +100,7 @@ def compute_loss(self, preds: torch.Tensor, target: torch.Tensor) -> Losses: def get_parameter_exchanger(self, config: Config) -> ParameterExchanger: return ParameterExchangerWithPacking(ParameterPackerFedProx()) + + def update_after_train(self, local_steps: int, loss_dict: Dict[str, float]) -> None: + # Store current loss which is the vanilla loss without the proximal term added in + self.current_loss = loss_dict["checkpoint"] diff --git a/fl4health/clients/scaffold_client.py b/fl4health/clients/scaffold_client.py index 8ecba0311..f004edfd7 100644 --- a/fl4health/clients/scaffold_client.py +++ b/fl4health/clients/scaffold_client.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Optional, Sequence, Tuple +from typing import Dict, Optional, Sequence, Tuple import numpy as np import torch @@ -189,7 +189,7 @@ def get_parameter_exchanger(self, config: Config) -> ParameterExchanger: parameter_exchanger = ParameterExchangerWithPacking(ParameterPackerWithControlVariates(model_size)) return parameter_exchanger - def update_after_train(self, local_steps: int) -> None: + def update_after_train(self, local_steps: int, loss_dict: Dict[str, float]) -> None: self.update_control_variates(local_steps) From f462d69c915f7f19a2d97d808e685d1d0ad39dd6 Mon Sep 17 00:00:00 2001 From: John Jewell Date: Fri, 13 Oct 2023 14:58:56 -0400 Subject: [PATCH 05/12] Add MetricMeterManager to handle one or more MetricMeters. Define custom compute loss for APFL. Return dict with personal, global and local preds from ApflClient.predict. Make some changes to BasicClient for this all to work. 
Added some comments --- fl4health/clients/apfl_client.py | 26 ++++++--- fl4health/clients/basic_client.py | 45 ++++++++++------ fl4health/clients/fed_prox_client.py | 4 +- fl4health/clients/scaffold_client.py | 7 ++- fl4health/model_bases/apfl_base.py | 9 ++-- fl4health/utils/metrics.py | 54 ++++++++++++++++++- .../parameter_exchange/test_apfl_exchange.py | 12 ++--- 7 files changed, 119 insertions(+), 38 deletions(-) diff --git a/fl4health/clients/apfl_client.py b/fl4health/clients/apfl_client.py index 98ca0a360..ebcc8a157 100644 --- a/fl4health/clients/apfl_client.py +++ b/fl4health/clients/apfl_client.py @@ -1,6 +1,6 @@ import copy from pathlib import Path -from typing import Optional, Sequence, Tuple +from typing import Dict, Optional, Sequence, Tuple, Union import torch from flwr.common.typing import Config @@ -68,7 +68,11 @@ def setup_client(self, config: Config) -> None: self.optimizer = global_optimizer self.local_optimizer = local_optimizer - def train_step(self, input: torch.Tensor, target: torch.Tensor) -> Tuple[Losses, torch.Tensor]: + def train_step( + self, input: torch.Tensor, target: torch.Tensor + ) -> Union[Tuple[Losses, torch.Tensor], Tuple[Losses, Dict[str, torch.Tensor]]]: + # Return preds value of torch.Tensor containing personal, global and local predictions + # Mechanics of training loop follow from original implementation # https://github.com/MLOPTPSU/FedTorch/blob/main/fedtorch/comms/trainings/federated/apfl.py @@ -88,14 +92,24 @@ def train_step(self, input: torch.Tensor, target: torch.Tensor) -> Tuple[Losses, # Personal predictions are generated as a convex combination of the output # of local and global models - personal_pred = self.predict(input) - + preds = self.predict(input) + assert isinstance(preds, dict) # Parameters of local model are updated to minimize loss of personalized model - losses = self.compute_loss(personal_pred, target) + losses = self.compute_loss(preds, target) losses.backward.backward() self.local_optimizer.step() - return losses, personal_pred + # Return dictionairy of predictions where key is used to name respective MetricMeters + return losses, preds def get_parameter_exchanger(self, config: Config) -> FixedLayerExchanger: return FixedLayerExchanger(self.model.layers_to_exchange()) + + def compute_loss(self, preds: Union[torch.Tensor, Dict[str, torch.Tensor]], target: torch.Tensor) -> Losses: + assert isinstance(preds, dict) + personal_loss = self.criterion(preds["personal"], target) + global_loss = self.criterion(preds["global"], target) + local_loss = self.criterion(preds["local"], target) + additional_losses = {"global": global_loss, "local": local_loss} + losses = Losses(checkpoint=personal_loss, backward=personal_loss, additional_losses=additional_losses) + return losses diff --git a/fl4health/clients/basic_client.py b/fl4health/clients/basic_client.py index 2e495e980..92486502b 100644 --- a/fl4health/clients/basic_client.py +++ b/fl4health/clients/basic_client.py @@ -16,7 +16,7 @@ from fl4health.parameter_exchange.parameter_exchanger_base import ParameterExchanger from fl4health.reporting.fl_wanb import ClientWandBReporter from fl4health.utils.losses import Losses, LossMeter, LossMeterType -from fl4health.utils.metrics import Metric, MetricMeter, MetricMeterType +from fl4health.utils.metrics import Metric, MetricMeterManager, MetricMeterType class BasicClient(NumpyFlClient): @@ -42,8 +42,8 @@ def __init__( self.checkpointer = checkpointer self.train_loss_meter = LossMeter.get_meter_by_type(loss_meter_type) self.val_loss_meter = 
LossMeter.get_meter_by_type(loss_meter_type) - self.train_metric_meter = MetricMeter.get_meter_by_type(self.metrics, metric_meter_type, "train_meter") - self.val_metric_meter = MetricMeter.get_meter_by_type(self.metrics, metric_meter_type, "val_meter") + self.train_metric_meter_mngr = MetricMeterManager(self.metrics, metric_meter_type, "train_meter") + self.val_metric_meter_mngr = MetricMeterManager(self.metrics, metric_meter_type, "val_meter") self.model: nn.Module self.optimizer: torch.optim.Optimizer @@ -140,10 +140,14 @@ def _handle_logging( f"Client {metric_prefix} Losses: {loss_string} \n" f"Client {metric_prefix} Metrics: {metric_string}", ) - def train_step(self, input: torch.Tensor, target: torch.Tensor) -> Tuple[Losses, torch.Tensor]: + def train_step( + self, input: torch.Tensor, target: torch.Tensor + ) -> Tuple[Losses, Union[torch.Tensor, Dict[str, torch.Tensor]]]: """ Given input and target, generate predictions, compute loss, optionally update metrics if they exist. Assumes self.model is in train model already. + The preds value that is returned is torch.Tensor for normal cases when there is only a single prediction. + In cases where there are multiple prediction types (ie APFL), a Dict of torch.Tensor is returned. """ # Clear gradients from optimizer if they exist self.optimizer.zero_grad() @@ -158,10 +162,14 @@ def train_step(self, input: torch.Tensor, target: torch.Tensor) -> Tuple[Losses, return losses, preds - def val_step(self, input: torch.Tensor, target: torch.Tensor) -> Tuple[Losses, torch.Tensor]: + def val_step( + self, input: torch.Tensor, target: torch.Tensor + ) -> Tuple[Losses, Union[torch.Tensor, Dict[str, torch.Tensor]]]: """ Given input and target, compute loss, update loss and metrics Assumes self.model is in eval mode already. + The preds value that is returned is torch.Tensor for normal cases when there is only a single prediction. + In cases where there are multiple prediction types (ie APFL), a Dict of torch.Tensor is returned. 
""" # Get preds and compute loss @@ -176,16 +184,16 @@ def train_by_epochs( ) -> Tuple[Dict[str, float], Dict[str, Scalar]]: self.model.train() for local_epoch in range(epochs): - self.train_metric_meter.clear() + self.train_metric_meter_mngr.clear() self.train_loss_meter.clear() for step, (input, target) in enumerate(self.train_loader): input, target = input.to(self.device), target.to(self.device) losses, preds = self.train_step(input, target) self.train_loss_meter.update(losses) - self.train_metric_meter.update(preds, target) + self.train_metric_meter_mngr.update(preds, target) actual_step = int(local_epoch * len(self.train_loader) + step) self.update_after_step(actual_step) - metrics = self.train_metric_meter.compute() + metrics = self.train_metric_meter_mngr.compute() losses = self.train_loss_meter.compute() loss_dict = losses.as_dict() @@ -203,7 +211,7 @@ def train_by_steps( train_iterator = iter(self.train_loader) self.train_loss_meter.clear() - self.train_metric_meter.clear() + self.train_metric_meter_mngr.clear() for step in range(steps): try: input, target = next(train_iterator) @@ -216,12 +224,12 @@ def train_by_steps( input, target = input.to(self.device), target.to(self.device) losses, preds = self.train_step(input, target) self.train_loss_meter.update(losses) - self.train_metric_meter.update(preds, target) + self.train_metric_meter_mngr.update(preds, target) self.update_after_step(step) losses = self.train_loss_meter.compute() loss_dict = losses.as_dict() - metrics = self.train_metric_meter.compute() + metrics = self.train_metric_meter_mngr.compute() self._handle_logging(loss_dict, metrics, current_round=current_round) @@ -229,19 +237,19 @@ def train_by_steps( def validate(self) -> Tuple[float, Dict[str, Scalar]]: self.model.eval() - self.val_metric_meter.clear() + self.val_metric_meter_mngr.clear() self.val_loss_meter.clear() with torch.no_grad(): for input, target in self.val_loader: input, target = input.to(self.device), target.to(self.device) losses, preds = self.val_step(input, target) self.val_loss_meter.update(losses) - self.val_metric_meter.update(preds, target) + self.val_metric_meter_mngr.update(preds, target) # Compute losses and metrics over validation set losses = self.val_loss_meter.compute() loss_dict = losses.as_dict() - metrics = self.val_metric_meter.compute() + metrics = self.val_metric_meter_mngr.compute() self._handle_logging(loss_dict, metrics, is_validation=True) # Checkpoint based on loss which is output of user defined compute_loss method @@ -288,16 +296,19 @@ def get_parameter_exchanger(self, config: Config) -> ParameterExchanger: """ return FullParameterExchanger() - def predict(self, input: torch.Tensor) -> torch.Tensor: + def predict(self, input: torch.Tensor) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: """ - Return predictions when given input. User can override for more complex logic. + Return predictions when given input. Returns torch tensor for non personalized model with single output. + Returns dict of torch tensor for personalized model with multiple outputs (ie APFL) + User can override for more complex logic. """ return self.model(input) - def compute_loss(self, preds: torch.Tensor, target: torch.Tensor) -> Losses: + def compute_loss(self, preds: Union[torch.Tensor, Dict[str, torch.Tensor]], target: torch.Tensor) -> Losses: """ Computes loss given preds and torch and the user defined criterion. Optionally includes dictionairy of loss components if you wish to train the total loss as well as sub losses if they exist. 
+ Input can be a single torch.Tensor or a Dict of torch tensors for each prediction type. """ loss = self.criterion(preds, target) losses = Losses(checkpoint=loss, backward=loss) diff --git a/fl4health/clients/fed_prox_client.py b/fl4health/clients/fed_prox_client.py index 6713b355f..913098409 100644 --- a/fl4health/clients/fed_prox_client.py +++ b/fl4health/clients/fed_prox_client.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Dict, List, Optional, Sequence +from typing import Dict, List, Optional, Sequence, Union import torch from flwr.common.typing import Config, NDArrays @@ -91,7 +91,7 @@ def set_parameters(self, parameters: NDArrays, config: Config) -> None: initial_layer_weights.detach().clone() for initial_layer_weights in self.model.parameters() ] - def compute_loss(self, preds: torch.Tensor, target: torch.Tensor) -> Losses: + def compute_loss(self, preds: Union[torch.Tensor, Dict[str, torch.Tensor]], target: torch.Tensor) -> Losses: loss = self.criterion(preds, target) proximal_loss = self.get_proximal_loss() total_loss = loss + proximal_loss diff --git a/fl4health/clients/scaffold_client.py b/fl4health/clients/scaffold_client.py index f004edfd7..099a20bff 100644 --- a/fl4health/clients/scaffold_client.py +++ b/fl4health/clients/scaffold_client.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Dict, Optional, Sequence, Tuple +from typing import Dict, Optional, Sequence, Tuple, Union import numpy as np import torch @@ -169,7 +169,9 @@ def compute_updated_control_variates( ] return updated_client_control_variates - def train_step(self, input: torch.Tensor, target: torch.Tensor) -> Tuple[Losses, torch.Tensor]: + def train_step( + self, input: torch.Tensor, target: torch.Tensor + ) -> Tuple[Losses, Union[torch.Tensor, Dict[str, torch.Tensor]]]: # Clear gradients from optimizer if they exist self.optimizer.zero_grad() @@ -181,6 +183,7 @@ def train_step(self, input: torch.Tensor, target: torch.Tensor) -> Tuple[Losses, losses.backward.backward() self.modify_grad() self.optimizer.step() + return losses, preds def get_parameter_exchanger(self, config: Config) -> ParameterExchanger: diff --git a/fl4health/model_bases/apfl_base.py b/fl4health/model_bases/apfl_base.py index 67ab8087d..5813e1949 100644 --- a/fl4health/model_bases/apfl_base.py +++ b/fl4health/model_bases/apfl_base.py @@ -1,5 +1,5 @@ import copy -from typing import List +from typing import Dict, List import torch import torch.nn as nn @@ -27,12 +27,13 @@ def global_forward(self, input: torch.Tensor) -> torch.Tensor: def local_forward(self, input: torch.Tensor) -> torch.Tensor: return self.local_model(input) - def forward(self, input: torch.Tensor) -> torch.Tensor: + def forward(self, input: torch.Tensor) -> Dict[str, torch.Tensor]: + # Forward return dictionairy because APFL has multiple different prediction types global_logits = self.global_forward(input) local_logits = self.local_forward(input) personal_logits = self.alpha * local_logits + (1.0 - self.alpha) * global_logits - - return personal_logits + preds = {"personal": personal_logits, "global": global_logits, "local": local_logits} + return preds def update_alpha(self) -> None: # Updates to mixture parameter follow original implementation diff --git a/fl4health/utils/metrics.py b/fl4health/utils/metrics.py index a2ace7534..cf740b38a 100644 --- a/fl4health/utils/metrics.py +++ b/fl4health/utils/metrics.py @@ -1,8 +1,9 @@ from __future__ import annotations +import copy from abc import ABC, abstractmethod from enum import Enum -from typing 
import Dict, List, Sequence, Tuple +from typing import Dict, List, Sequence, Tuple, Union import numpy as np import torch @@ -228,3 +229,54 @@ def compute(self) -> Dict[str, Scalar]: def clear(self) -> None: self.metric_values_history = [[] for _ in range(len(self.metrics))] self.counts = [] + + +class MetricMeterManager: + """ + Class to manage one or metric meters. + """ + + def __init__(self, metrics: Sequence[Metric], metric_meter_type: MetricMeterType, name: str): + self.metrics = metrics + self.metric_meter_type = metric_meter_type + self.name = name + self.meters: Optional[Sequence[MetricMeter]] = None + + def update(self, preds: Union[torch.Tensor, Dict[str, torch.Tensor]], target: torch.Tensor) -> None: + # Meters are initialized in the update so we know the number and name of meters + # If preds is a torch tensor, this is the standard case where the meter manager has a single meter + # If preds is a dict, there is multiple predictions (ie for APFL), so the amount of meters is equal + # to the amount of different prediction types (ie global, local, personal) + if self.meters is None: + + if isinstance(preds, torch.Tensor): + self.meters = [ + MetricMeter.get_meter_by_type(copy.deepcopy(self.metrics), self.metric_meter_type, self.name) + ] + else: + self.meters = [ + MetricMeter.get_meter_by_type( + copy.deepcopy(self.metrics), self.metric_meter_type, f"{self.name} {key}" + ) + for key in preds.keys() + ] + preds_list = [preds] if isinstance(preds, torch.Tensor) else preds.values() + for meter, preds in zip(self.meters, preds_list): + meter.update(preds, target) + + def compute(self) -> Dict[str, Scalar]: + assert self.meters is not None + all_results = {} + for meter in self.meters: + result = meter.compute() + all_results.update(result) + + return all_results + + def clear(self) -> None: + # If meters is none, no need to clear + if self.meters is None: + return + + for meter in self.meters: + meter.clear() diff --git a/tests/parameter_exchange/test_apfl_exchange.py b/tests/parameter_exchange/test_apfl_exchange.py index 886daf7f4..6c7e6091d 100644 --- a/tests/parameter_exchange/test_apfl_exchange.py +++ b/tests/parameter_exchange/test_apfl_exchange.py @@ -46,22 +46,22 @@ def test_apfl_layer_exchange() -> None: assert np.array_equal(layer_parameters, model_state_dict[layer_name]) input = torch.ones((3, 1, 10, 10)) - # APFL returns the personal prediction which are a combination of the logits of local and global models - personal_shape = model(input).shape + # APFL returns a dict with personal, global and local predicitons + # Assert return values of each prediction type is correct dim + preds = model(input) + personal_shape = preds["personal"].shape # Batch size assert personal_shape[0] == 3 # Output size assert personal_shape[1] == 3 - # We can get the global preds with the global_forward method - global_shape = model.global_forward(input).shape + global_shape = preds["global"].shape # Batch size assert global_shape[0] == 3 # Output size assert global_shape[1] == 3 - # We can get the local preds with the local_forward method - local_shape = model.local_forward(input).shape + local_shape = preds["local"].shape # Batch size assert local_shape[0] == 3 # Output size From 61693e7a371f5a3b4a0ab9237a87cc0dcf339b2d Mon Sep 17 00:00:00 2001 From: John Jewell Date: Fri, 13 Oct 2023 15:39:11 -0400 Subject: [PATCH 06/12] Have get optimizer return optimizer or dict of optimizer. Update apfl examlples accordingly. 
Remove no longer relevant split optimizer test --- examples/apfl_example/client.py | 8 ++-- fl4health/clients/apfl_client.py | 48 +++++++++++-------- fl4health/clients/basic_client.py | 7 ++- .../flamby/fed_heart_disease/apfl/client.py | 8 ++-- research/flamby/fed_isic2019/apfl/client.py | 8 ++-- research/flamby/fed_ixi/apfl/client.py | 8 ++-- tests/clients/test_apfl_client.py | 32 ------------- 7 files changed, 52 insertions(+), 67 deletions(-) delete mode 100644 tests/clients/test_apfl_client.py diff --git a/examples/apfl_example/client.py b/examples/apfl_example/client.py index dda1d9c9c..e84c27f80 100644 --- a/examples/apfl_example/client.py +++ b/examples/apfl_example/client.py @@ -1,6 +1,6 @@ import argparse from pathlib import Path -from typing import Tuple +from typing import Dict, Tuple import flwr as fl import torch @@ -28,8 +28,10 @@ def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: def get_model(self, config: Config) -> nn.Module: return APFLModule(MnistNetWithBnAndFrozen()).to(self.device) - def get_optimizer(self, config: Config) -> Optimizer: - return torch.optim.AdamW(self.model.parameters(), lr=0.01) + def get_optimizer(self, config: Config) -> Dict[str, Optimizer]: + local_optimizer = torch.optim.AdamW(self.model.local_model.parameters(), lr=0.01) + global_optimizer = torch.optim.AdamW(self.model.global_model.parameters(), lr=0.01) + return {"local": local_optimizer, "global": global_optimizer} def get_criterion(self, config: Config) -> _Loss: return torch.nn.CrossEntropyLoss() diff --git a/fl4health/clients/apfl_client.py b/fl4health/clients/apfl_client.py index ebcc8a157..bad436d26 100644 --- a/fl4health/clients/apfl_client.py +++ b/fl4health/clients/apfl_client.py @@ -1,4 +1,3 @@ -import copy from pathlib import Path from typing import Dict, Optional, Sequence, Tuple, Union @@ -10,6 +9,7 @@ from fl4health.clients.basic_client import BasicClient from fl4health.model_bases.apfl_base import APFLModule from fl4health.parameter_exchange.layer_exchanger import FixedLayerExchanger +from fl4health.reporting.fl_wanb import ClientWandBReporter from fl4health.utils.losses import Losses, LossMeterType from fl4health.utils.metrics import Metric, MetricMeterType @@ -43,35 +43,41 @@ def update_after_step(self, step: int) -> None: if self.is_start_of_local_training(step) and self.model.adaptive_alpha: self.model.update_alpha() - def split_optimizer(self, global_optimizer: Optimizer) -> Tuple[Optimizer, Optimizer]: - """ - The optimizer from get_optimizer is for the entire APFLModule. We need one optimizer - for the local model and one optimizer for the global model. - """ - global_optimizer.param_groups.clear() - global_optimizer.state.clear() - local_optimizer = copy.deepcopy(global_optimizer) - - global_optimizer.add_param_group({"params": [p for p in self.model.global_model.parameters()]}) - local_optimizer.add_param_group({"params": [p for p in self.model.local_model.parameters()]}) - return global_optimizer, local_optimizer - def setup_client(self, config: Config) -> None: """ Set dataloaders, optimizers, parameter exchangers and other attributes derived from these. """ - super().setup_client(config) + model = self.get_model(config) + assert isinstance(model, APFLModule) + self.model = model + train_loader, val_loader = self.get_data_loaders(config) + self.train_loader = train_loader + self.val_loader = val_loader + + # The following lines are type ignored because torch datasets are not "Sized" + # IE __len__ is considered optionally defined. 
In practice, it is almost always defined + # and as such, we will make that assumption. + self.num_train_samples = len(self.train_loader.dataset) # type: ignore + self.num_val_samples = len(self.val_loader.dataset) # type: ignore + + optimizer_dict = self.get_optimizer(config) + assert isinstance(optimizer_dict, dict) + self.optimizer = optimizer_dict["global"] + self.local_optimizer = optimizer_dict["local"] + + self.learning_rate = self.optimizer.defaults["lr"] + self.criterion = self.get_criterion(config) + self.parameter_exchanger = self.get_parameter_exchanger(config) + + if self.use_wandb_reporter: + self.wandb_reporter = ClientWandBReporter.from_config(self.client_name, config) - # Split optimizer from get_optimizer into two distinct optimizers - # One for local model and one for global model - global_optimizer, local_optimizer = self.split_optimizer(self.optimizer) - self.optimizer = global_optimizer - self.local_optimizer = local_optimizer + self.initialized = True def train_step( self, input: torch.Tensor, target: torch.Tensor ) -> Union[Tuple[Losses, torch.Tensor], Tuple[Losses, Dict[str, torch.Tensor]]]: - # Return preds value of torch.Tensor containing personal, global and local predictions + # Return preds value thats Dict of torch.Tensor containing personal, global and local predictions # Mechanics of training loop follow from original implementation # https://github.com/MLOPTPSU/FedTorch/blob/main/fedtorch/comms/trainings/federated/apfl.py diff --git a/fl4health/clients/basic_client.py b/fl4health/clients/basic_client.py index 92486502b..1365f5216 100644 --- a/fl4health/clients/basic_client.py +++ b/fl4health/clients/basic_client.py @@ -280,7 +280,10 @@ def setup_client(self, config: Config) -> None: self.num_train_samples = len(self.train_loader.dataset) # type: ignore self.num_val_samples = len(self.val_loader.dataset) # type: ignore - self.optimizer = self.get_optimizer(config) + optimizer = self.get_optimizer(config) + assert isinstance(optimizer, Optimizer) + self.optimizer = optimizer + self.learning_rate = self.optimizer.defaults["lr"] self.criterion = self.get_criterion(config) self.parameter_exchanger = self.get_parameter_exchanger(config) @@ -327,7 +330,7 @@ def get_criterion(self, config: Config) -> _Loss: """ raise NotImplementedError - def get_optimizer(self, config: Config) -> Optimizer: + def get_optimizer(self, config: Config) -> Union[Optimizer, Dict[str, Optimizer]]: """ Method to be defined by user that returns the PyTorch optimizer used to train models locally """ diff --git a/research/flamby/fed_heart_disease/apfl/client.py b/research/flamby/fed_heart_disease/apfl/client.py index c26f7df15..eee82e0b9 100644 --- a/research/flamby/fed_heart_disease/apfl/client.py +++ b/research/flamby/fed_heart_disease/apfl/client.py @@ -2,7 +2,7 @@ import os from logging import INFO from pathlib import Path -from typing import Optional, Sequence, Tuple +from typing import Dict, Optional, Sequence, Tuple import flwr as fl import torch @@ -62,8 +62,10 @@ def get_model(self, config: Config) -> APFLModule: model: APFLModule = APFLModule(Baseline(), alpha_lr=self.alpha_learning_rate).to(self.device) return model - def get_optimizer(self, config: Config) -> Optimizer: - return torch.optim.AdamW(self.model.parameters(), lr=self.learning_rate) + def get_optimizer(self, config: Config) -> Dict[str, Optimizer]: + local_optimizer = torch.optim.AdamW(self.model.local_model.parameters(), lr=self.learning_rate) + global_optimizer = 
torch.optim.AdamW(self.model.global_model.parameters(), lr=self.learning_rate) + return {"local": local_optimizer, "global": global_optimizer} def get_criterion(self, config: Config) -> _Loss: return BaselineLoss() diff --git a/research/flamby/fed_isic2019/apfl/client.py b/research/flamby/fed_isic2019/apfl/client.py index 545810377..ec204a77b 100644 --- a/research/flamby/fed_isic2019/apfl/client.py +++ b/research/flamby/fed_isic2019/apfl/client.py @@ -2,7 +2,7 @@ import os from logging import INFO from pathlib import Path -from typing import Optional, Sequence, Tuple +from typing import Dict, Optional, Sequence, Tuple import flwr as fl import torch @@ -69,8 +69,10 @@ def get_model(self, config: Config) -> nn.Module: ).to(self.device) return model - def get_optimizer(self, config: Config) -> Optimizer: - return torch.optim.AdamW(self.model.parameters(), lr=self.learning_rate) + def get_optimizer(self, config: Config) -> Dict[str, Optimizer]: + local_optimizer: Optimizer = torch.optim.AdamW(self.model.local_model.parameters(), lr=self.learning_rate) + global_optimizer: Optimizer = torch.optim.AdamW(self.model.global_model.parameters(), lr=self.learning_rate) + return {"local": local_optimizer, "global": global_optimizer} if __name__ == "__main__": diff --git a/research/flamby/fed_ixi/apfl/client.py b/research/flamby/fed_ixi/apfl/client.py index 5e4136e11..66a375e75 100644 --- a/research/flamby/fed_ixi/apfl/client.py +++ b/research/flamby/fed_ixi/apfl/client.py @@ -2,7 +2,7 @@ import os from logging import INFO from pathlib import Path -from typing import Optional, Sequence, Tuple +from typing import Dict, Optional, Sequence, Tuple import flwr as fl import torch @@ -64,8 +64,10 @@ def get_model(self, config: Config) -> nn.Module: model: APFLModule = APFLModule(APFLUNet(), alpha_lr=self.alpha_learning_rate).to(self.device) return model - def get_optiizer(self, config: Config) -> Optimizer: - return torch.optim.AdamW(self.model.local_model.parameters(), lr=self.learning_rate) + def get_optimizer(self, config: Config) -> Dict[str, Optimizer]: + local_optimizer = torch.optim.AdamW(self.model.local_model.parameters(), lr=self.learning_rate) + global_optimizer = torch.optim.AdamW(self.model.global_model.parameters(), lr=self.learning_rate) + return {"local": local_optimizer, "global": global_optimizer} def get_criterion(self, config: Config) -> _Loss: return BaselineLoss() diff --git a/tests/clients/test_apfl_client.py b/tests/clients/test_apfl_client.py deleted file mode 100644 index 266eef70c..000000000 --- a/tests/clients/test_apfl_client.py +++ /dev/null @@ -1,32 +0,0 @@ -import pytest -import torch - -from fl4health.clients.apfl_client import ApflClient -from fl4health.model_bases.apfl_base import APFLModule -from tests.clients.fixtures import get_client # noqa -from tests.test_utils.models_for_test import SmallCnn - - -@pytest.mark.parametrize("type,model", [(ApflClient, APFLModule(SmallCnn()))]) -def test_split_optimizer(get_client: ApflClient) -> None: # noqa - apfl_client = get_client - - global_optimizer, local_optimizer = apfl_client.split_optimizer(apfl_client.optimizer) - - # Check that global_optimizer and local_optimizer dont reference the same object - assert global_optimizer is not local_optimizer - - # Check that the param_groups are equivalent since the local and global models are exact copies - # at the start - global_param_groups = global_optimizer.param_groups - local_param_groups = local_optimizer.param_groups - for global_group, local_group in zip(global_param_groups, 
local_param_groups): - for (global_key, global_vals), (local_key, local_vals) in zip(global_group.items(), local_group.items()): - assert local_key == global_key - assert type(local_vals) == type(global_vals) - # Either Parameter Group or float representing lr - if isinstance(global_vals, list): - for global_val, local_val in zip(global_vals, local_vals): - assert torch.equal(global_val, local_val) - else: - assert global_vals == local_vals From 6b75b7429757d1ee0aa173050e6ad6602c35d98d Mon Sep 17 00:00:00 2001 From: John Jewell Date: Fri, 13 Oct 2023 15:49:31 -0400 Subject: [PATCH 07/12] Slight typing change --- fl4health/clients/apfl_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fl4health/clients/apfl_client.py b/fl4health/clients/apfl_client.py index bad436d26..ebd3f25f6 100644 --- a/fl4health/clients/apfl_client.py +++ b/fl4health/clients/apfl_client.py @@ -76,7 +76,7 @@ def setup_client(self, config: Config) -> None: def train_step( self, input: torch.Tensor, target: torch.Tensor - ) -> Union[Tuple[Losses, torch.Tensor], Tuple[Losses, Dict[str, torch.Tensor]]]: + ) -> Tuple[Losses, Union[torch.Tensor, Dict[str, torch.Tensor]]]: # Return preds value thats Dict of torch.Tensor containing personal, global and local predictions # Mechanics of training loop follow from original implementation From 6cf8386c62b20fc99dc7a8f93bd12efdc4293a76 Mon Sep 17 00:00:00 2001 From: John Jewell Date: Mon, 16 Oct 2023 11:47:29 -0400 Subject: [PATCH 08/12] Initialize MetricMeterManager with mapping between prediciton key in Client constructor --- fl4health/clients/apfl_client.py | 41 ++++++++++++++++++++----- fl4health/clients/basic_client.py | 31 +++++++++++-------- fl4health/clients/fed_prox_client.py | 6 ++-- fl4health/clients/scaffold_client.py | 6 ++-- fl4health/utils/metrics.py | 46 ++++++---------------------- 5 files changed, 67 insertions(+), 63 deletions(-) diff --git a/fl4health/clients/apfl_client.py b/fl4health/clients/apfl_client.py index a934c64c1..00ab867b4 100644 --- a/fl4health/clients/apfl_client.py +++ b/fl4health/clients/apfl_client.py @@ -4,14 +4,15 @@ import torch from flwr.common.typing import Config from torch.optim import Optimizer +from torch.utils.data import DataLoader from fl4health.checkpointing.checkpointer import TorchCheckpointer from fl4health.clients.basic_client import BasicClient from fl4health.model_bases.apfl_base import APFLModule from fl4health.parameter_exchange.layer_exchanger import FixedLayerExchanger from fl4health.reporting.fl_wanb import ClientWandBReporter -from fl4health.utils.losses import Losses, LossMeterType -from fl4health.utils.metrics import Metric, MetricMeterType +from fl4health.utils.losses import Losses, LossMeter, LossMeterType +from fl4health.utils.metrics import Metric, MetricMeter, MetricMeterManager, MetricMeterType class ApflClient(BasicClient): @@ -24,7 +25,36 @@ def __init__( metric_meter_type: MetricMeterType = MetricMeterType.AVERAGE, checkpointer: Optional[TorchCheckpointer] = None, ) -> None: - super().__init__(data_path, metrics, device, loss_meter_type, metric_meter_type, checkpointer) + super(BasicClient, self).__init__(data_path, device) + + self.metrics = metrics + self.checkpointer = checkpointer + self.train_loss_meter = LossMeter.get_meter_by_type(loss_meter_type) + self.val_loss_meter = LossMeter.get_meter_by_type(loss_meter_type) + + train_key_to_meter_map = { + "personal": MetricMeter.get_meter_by_type(self.metrics, metric_meter_type, "train meter - personal"), + "global": 
MetricMeter.get_meter_by_type(self.metrics, metric_meter_type, "train meter - global"), + "local": MetricMeter.get_meter_by_type(self.metrics, metric_meter_type, "train meter - local"), + } + self.train_metric_meter_mngr = MetricMeterManager(train_key_to_meter_map) + val_key_to_meter_map = { + "personal": MetricMeter.get_meter_by_type(self.metrics, metric_meter_type, "val meter - personal"), + "global": MetricMeter.get_meter_by_type(self.metrics, metric_meter_type, "val meter - global"), + "local": MetricMeter.get_meter_by_type(self.metrics, metric_meter_type, "val meter - local"), + } + self.val_metric_meter_mngr = MetricMeterManager(val_key_to_meter_map) + + self.optimizer: torch.optim.Optimizer + + self.train_loader: DataLoader + self.val_loader: DataLoader + self.num_train_samples: int + self.num_val_samples: int + self.learning_rate: float + + # Need to track total_steps across rounds for WANDB reporting + self.total_steps: int = 0 # Apfl Module which holds both local and global models # and gives the ability to get personal, local and global predictions self.model: APFLModule @@ -70,9 +100,7 @@ def setup_client(self, config: Config) -> None: self.initialized = True - def train_step( - self, input: torch.Tensor, target: torch.Tensor - ) -> Tuple[Losses, Union[torch.Tensor, Dict[str, torch.Tensor]]]: + def train_step(self, input: torch.Tensor, target: torch.Tensor) -> Tuple[Losses, Dict[str, torch.Tensor]]: # Return preds value thats Dict of torch.Tensor containing personal, global and local predictions # Mechanics of training loop follow from original implementation @@ -95,7 +123,6 @@ def train_step( # Personal predictions are generated as a convex combination of the output # of local and global models preds = self.predict(input) - assert isinstance(preds, dict) # Parameters of local model are updated to minimize loss of personalized model losses = self.compute_loss(preds, target) losses.backward.backward() diff --git a/fl4health/clients/basic_client.py b/fl4health/clients/basic_client.py index 76f36c401..69f7c4136 100644 --- a/fl4health/clients/basic_client.py +++ b/fl4health/clients/basic_client.py @@ -16,7 +16,7 @@ from fl4health.parameter_exchange.parameter_exchanger_base import ParameterExchanger from fl4health.reporting.fl_wanb import ClientWandBReporter from fl4health.utils.losses import Losses, LossMeter, LossMeterType -from fl4health.utils.metrics import Metric, MetricMeterManager, MetricMeterType +from fl4health.utils.metrics import Metric, MetricMeter, MetricMeterManager, MetricMeterType class BasicClient(NumpyFlClient): @@ -40,8 +40,15 @@ def __init__( self.checkpointer = checkpointer self.train_loss_meter = LossMeter.get_meter_by_type(loss_meter_type) self.val_loss_meter = LossMeter.get_meter_by_type(loss_meter_type) - self.train_metric_meter_mngr = MetricMeterManager(self.metrics, metric_meter_type, "train_meter") - self.val_metric_meter_mngr = MetricMeterManager(self.metrics, metric_meter_type, "val_meter") + + train_key_to_meter_map = { + "prediction": MetricMeter.get_meter_by_type(self.metrics, metric_meter_type, "train_meter") + } + self.train_metric_meter_mngr = MetricMeterManager(train_key_to_meter_map) + val_key_to_meter_map = { + "prediction": MetricMeter.get_meter_by_type(self.metrics, metric_meter_type, "val_meter") + } + self.val_metric_meter_mngr = MetricMeterManager(val_key_to_meter_map) self.model: nn.Module self.optimizer: torch.optim.Optimizer @@ -162,9 +169,7 @@ def _handle_reporting( reporting_dict.update(metric_dict) 
self.wandb_reporter.report_metrics(reporting_dict) - def train_step( - self, input: torch.Tensor, target: torch.Tensor - ) -> Tuple[Losses, Union[torch.Tensor, Dict[str, torch.Tensor]]]: + def train_step(self, input: torch.Tensor, target: torch.Tensor) -> Tuple[Losses, Dict[str, torch.Tensor]]: """ Given input and target, generate predictions, compute loss, optionally update metrics if they exist. Assumes self.model is in train model already. @@ -184,9 +189,7 @@ def train_step( return losses, preds - def val_step( - self, input: torch.Tensor, target: torch.Tensor - ) -> Tuple[Losses, Union[torch.Tensor, Dict[str, torch.Tensor]]]: + def val_step(self, input: torch.Tensor, target: torch.Tensor) -> Tuple[Losses, Dict[str, torch.Tensor]]: """ Given input and target, compute loss, update loss and metrics Assumes self.model is in eval mode already. @@ -327,21 +330,23 @@ def get_parameter_exchanger(self, config: Config) -> ParameterExchanger: """ return FullParameterExchanger() - def predict(self, input: torch.Tensor) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + def predict(self, input: torch.Tensor) -> Dict[str, torch.Tensor]: """ Return predictions when given input. Returns torch tensor for non personalized model with single output. Returns dict of torch tensor for personalized model with multiple outputs (ie APFL) User can override for more complex logic. """ - return self.model(input) + preds = self.model(input) + preds = preds if isinstance(preds, dict) else {"prediction": preds} + return preds - def compute_loss(self, preds: Union[torch.Tensor, Dict[str, torch.Tensor]], target: torch.Tensor) -> Losses: + def compute_loss(self, preds: Dict[str, torch.Tensor], target: torch.Tensor) -> Losses: """ Computes loss given preds and torch and the user defined criterion. Optionally includes dictionairy of loss components if you wish to train the total loss as well as sub losses if they exist. Input can be a single torch.Tensor or a Dict of torch tensors for each prediction type. 
""" - loss = self.criterion(preds, target) + loss = self.criterion(preds["prediction"], target) losses = Losses(checkpoint=loss, backward=loss) return losses diff --git a/fl4health/clients/fed_prox_client.py b/fl4health/clients/fed_prox_client.py index 308b53e4f..cee75891e 100644 --- a/fl4health/clients/fed_prox_client.py +++ b/fl4health/clients/fed_prox_client.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Dict, List, Optional, Sequence, Union +from typing import Dict, List, Optional, Sequence import torch from flwr.common.typing import Config, NDArrays @@ -89,8 +89,8 @@ def set_parameters(self, parameters: NDArrays, config: Config) -> None: initial_layer_weights.detach().clone() for initial_layer_weights in self.model.parameters() ] - def compute_loss(self, preds: Union[torch.Tensor, Dict[str, torch.Tensor]], target: torch.Tensor) -> Losses: - loss = self.criterion(preds, target) + def compute_loss(self, preds: Dict[str, torch.Tensor], target: torch.Tensor) -> Losses: + loss = self.criterion(preds["prediction"], target) proximal_loss = self.get_proximal_loss() total_loss = loss + proximal_loss losses = Losses(checkpoint=loss, backward=total_loss, additional_losses={"proximal_loss": proximal_loss}) diff --git a/fl4health/clients/scaffold_client.py b/fl4health/clients/scaffold_client.py index 8853a5d75..15dfe4458 100644 --- a/fl4health/clients/scaffold_client.py +++ b/fl4health/clients/scaffold_client.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Dict, Optional, Sequence, Tuple, Union +from typing import Dict, Optional, Sequence, Tuple import numpy as np import torch @@ -167,9 +167,7 @@ def compute_updated_control_variates( ] return updated_client_control_variates - def train_step( - self, input: torch.Tensor, target: torch.Tensor - ) -> Tuple[Losses, Union[torch.Tensor, Dict[str, torch.Tensor]]]: + def train_step(self, input: torch.Tensor, target: torch.Tensor) -> Tuple[Losses, Dict[str, torch.Tensor]]: # Clear gradients from optimizer if they exist self.optimizer.zero_grad() diff --git a/fl4health/utils/metrics.py b/fl4health/utils/metrics.py index cf740b38a..8864a92da 100644 --- a/fl4health/utils/metrics.py +++ b/fl4health/utils/metrics.py @@ -1,9 +1,8 @@ from __future__ import annotations -import copy from abc import ABC, abstractmethod from enum import Enum -from typing import Dict, List, Sequence, Tuple, Union +from typing import Dict, List, Sequence, Tuple import numpy as np import torch @@ -236,47 +235,22 @@ class MetricMeterManager: Class to manage one or metric meters. 
""" - def __init__(self, metrics: Sequence[Metric], metric_meter_type: MetricMeterType, name: str): - self.metrics = metrics - self.metric_meter_type = metric_meter_type - self.name = name - self.meters: Optional[Sequence[MetricMeter]] = None - - def update(self, preds: Union[torch.Tensor, Dict[str, torch.Tensor]], target: torch.Tensor) -> None: - # Meters are initialized in the update so we know the number and name of meters - # If preds is a torch tensor, this is the standard case where the meter manager has a single meter - # If preds is a dict, there is multiple predictions (ie for APFL), so the amount of meters is equal - # to the amount of different prediction types (ie global, local, personal) - if self.meters is None: - - if isinstance(preds, torch.Tensor): - self.meters = [ - MetricMeter.get_meter_by_type(copy.deepcopy(self.metrics), self.metric_meter_type, self.name) - ] - else: - self.meters = [ - MetricMeter.get_meter_by_type( - copy.deepcopy(self.metrics), self.metric_meter_type, f"{self.name} {key}" - ) - for key in preds.keys() - ] - preds_list = [preds] if isinstance(preds, torch.Tensor) else preds.values() - for meter, preds in zip(self.meters, preds_list): - meter.update(preds, target) + def __init__(self, key_to_meter_map: Dict[str, MetricMeter]): + self.key_to_meter_map = key_to_meter_map + + def update(self, preds: Dict[str, torch.Tensor], target: torch.Tensor) -> None: + for map_key, pred_key in zip(sorted(self.key_to_meter_map.keys()), sorted(preds.keys())): + assert map_key == pred_key + self.key_to_meter_map[map_key].update(preds[pred_key], target) def compute(self) -> Dict[str, Scalar]: - assert self.meters is not None all_results = {} - for meter in self.meters: + for meter in self.key_to_meter_map.values(): result = meter.compute() all_results.update(result) return all_results def clear(self) -> None: - # If meters is none, no need to clear - if self.meters is None: - return - - for meter in self.meters: + for meter in self.key_to_meter_map.values(): meter.clear() From e15e2f8a3dac6f637c956cdbabf3265a1af6bd41 Mon Sep 17 00:00:00 2001 From: John Jewell Date: Mon, 16 Oct 2023 12:07:21 -0400 Subject: [PATCH 09/12] Add some comments as suggested --- fl4health/clients/apfl_client.py | 5 +++++ fl4health/clients/basic_client.py | 23 ++++++++++++++++------- fl4health/clients/fed_prox_client.py | 4 ++++ fl4health/clients/scaffold_client.py | 4 ++++ 4 files changed, 29 insertions(+), 7 deletions(-) diff --git a/fl4health/clients/apfl_client.py b/fl4health/clients/apfl_client.py index 00ab867b4..7f7a4a0a1 100644 --- a/fl4health/clients/apfl_client.py +++ b/fl4health/clients/apfl_client.py @@ -32,6 +32,7 @@ def __init__( self.train_loss_meter = LossMeter.get_meter_by_type(loss_meter_type) self.val_loss_meter = LossMeter.get_meter_by_type(loss_meter_type) + # Define mapping from prediction key to meter to pass to MetricMeterManager constructor for train and val train_key_to_meter_map = { "personal": MetricMeter.get_meter_by_type(self.metrics, metric_meter_type, "train meter - personal"), "global": MetricMeter.get_meter_by_type(self.metrics, metric_meter_type, "train meter - global"), @@ -67,6 +68,10 @@ def is_start_of_local_training(self, step: int) -> bool: return step == 0 def update_after_step(self, step: int) -> None: + """ + Called after local train step on client. step is an integer that represents + the local training step that was most recently completed. 
+ """ if self.is_start_of_local_training(step) and self.model.adaptive_alpha: self.model.update_alpha() diff --git a/fl4health/clients/basic_client.py b/fl4health/clients/basic_client.py index 69f7c4136..845f66445 100644 --- a/fl4health/clients/basic_client.py +++ b/fl4health/clients/basic_client.py @@ -41,6 +41,7 @@ def __init__( self.train_loss_meter = LossMeter.get_meter_by_type(loss_meter_type) self.val_loss_meter = LossMeter.get_meter_by_type(loss_meter_type) + # Define mapping from prediction key to meter to pass to MetricMeterManager constructor for train and val train_key_to_meter_map = { "prediction": MetricMeter.get_meter_by_type(self.metrics, metric_meter_type, "train_meter") } @@ -173,8 +174,6 @@ def train_step(self, input: torch.Tensor, target: torch.Tensor) -> Tuple[Losses, """ Given input and target, generate predictions, compute loss, optionally update metrics if they exist. Assumes self.model is in train model already. - The preds value that is returned is torch.Tensor for normal cases when there is only a single prediction. - In cases where there are multiple prediction types (ie APFL), a Dict of torch.Tensor is returned. """ # Clear gradients from optimizer if they exist self.optimizer.zero_grad() @@ -193,8 +192,6 @@ def val_step(self, input: torch.Tensor, target: torch.Tensor) -> Tuple[Losses, D """ Given input and target, compute loss, update loss and metrics Assumes self.model is in eval mode already. - The preds value that is returned is torch.Tensor for normal cases when there is only a single prediction. - In cases where there are multiple prediction types (ie APFL), a Dict of torch.Tensor is returned. """ # Get preds and compute loss @@ -332,8 +329,9 @@ def get_parameter_exchanger(self, config: Config) -> ParameterExchanger: def predict(self, input: torch.Tensor) -> Dict[str, torch.Tensor]: """ - Return predictions when given input. Returns torch tensor for non personalized model with single output. - Returns dict of torch tensor for personalized model with multiple outputs (ie APFL) + Return dict of str and torch.Tensor contaiing predictions when given input. + In the default case, the dict has a single item with key prediction. + In more complicated approaches such as APFL, the dict has as many items as prediction types User can override for more complex logic. """ preds = self.model(input) @@ -344,7 +342,10 @@ def compute_loss(self, preds: Dict[str, torch.Tensor], target: torch.Tensor) -> """ Computes loss given preds and torch and the user defined criterion. Optionally includes dictionairy of loss components if you wish to train the total loss as well as sub losses if they exist. - Input can be a single torch.Tensor or a Dict of torch tensors for each prediction type. + Predicitons are a dictionairy of str and torch.Tensor. In the base case we have one set of prediction + stored in the prediction key of the dict. + For more complicated loss computations (additional loss components or multiple prediction types) + this method should be overridden. """ loss = self.criterion(preds["prediction"], target) losses = Losses(checkpoint=loss, backward=loss) @@ -376,7 +377,15 @@ def get_model(self, config: Config) -> nn.Module: raise NotImplementedError def update_after_train(self, local_steps: int, loss_dict: Dict[str, float]) -> None: + """ + Called after training with the number of local_steps performed over the FL round and + the corresponding loss dictionairy. + """ pass def update_after_step(self, step: int) -> None: + """ + Called after local train step on client. 
step is an integer that represents + the local training step that was most recently completed. + """ pass diff --git a/fl4health/clients/fed_prox_client.py b/fl4health/clients/fed_prox_client.py index cee75891e..b56882f8e 100644 --- a/fl4health/clients/fed_prox_client.py +++ b/fl4health/clients/fed_prox_client.py @@ -100,5 +100,9 @@ def get_parameter_exchanger(self, config: Config) -> ParameterExchanger: return ParameterExchangerWithPacking(ParameterPackerFedProx()) def update_after_train(self, local_steps: int, loss_dict: Dict[str, float]) -> None: + """ + Called after training with the number of local_steps performed over the FL round and + the corresponding loss dictionairy. + """ # Store current loss which is the vanilla loss without the proximal term added in self.current_loss = loss_dict["checkpoint"] diff --git a/fl4health/clients/scaffold_client.py b/fl4health/clients/scaffold_client.py index 15dfe4458..7cdd16e4e 100644 --- a/fl4health/clients/scaffold_client.py +++ b/fl4health/clients/scaffold_client.py @@ -189,6 +189,10 @@ def get_parameter_exchanger(self, config: Config) -> ParameterExchanger: return parameter_exchanger def update_after_train(self, local_steps: int, loss_dict: Dict[str, float]) -> None: + """ + Called after training with the number of local_steps performed over the FL round and + the corresponding loss dictionairy. + """ self.update_control_variates(local_steps) From 898c274bc2cf10a00dc47edf14c5a9936a70fc57 Mon Sep 17 00:00:00 2001 From: John Jewell Date: Tue, 17 Oct 2023 16:26:29 -0400 Subject: [PATCH 10/12] Fix unwanted changes from resolving merge conflict --- fl4health/clients/apfl_client.py | 40 +++++++++++++++++++++++-------- fl4health/clients/basic_client.py | 1 + 2 files changed, 31 insertions(+), 10 deletions(-) diff --git a/fl4health/clients/apfl_client.py b/fl4health/clients/apfl_client.py index 1573c316c..5d58f93e5 100644 --- a/fl4health/clients/apfl_client.py +++ b/fl4health/clients/apfl_client.py @@ -11,8 +11,8 @@ from fl4health.model_bases.apfl_base import ApflModule from fl4health.parameter_exchange.layer_exchanger import FixedLayerExchanger from fl4health.reporting.fl_wanb import ClientWandBReporter -from fl4health.utils.losses import Losses, LossMeterType -from fl4health.utils.metrics import Metric, MetricMeterType +from fl4health.utils.losses import Losses, LossMeter, LossMeterType +from fl4health.utils.metrics import Metric, MetricMeter, MetricMeterManager, MetricMeterType class ApflClient(BasicClient): @@ -26,23 +26,37 @@ def __init__( checkpointer: Optional[TorchCheckpointer] = None, ) -> None: super(BasicClient, self).__init__(data_path, device) - self.metrics = metrics + self.checkpointer = checkpointer + self.train_loss_meter = LossMeter.get_meter_by_type(loss_meter_type) + self.val_loss_meter = LossMeter.get_meter_by_type(loss_meter_type) + + # Define mapping from prediction key to meter to pass to MetricMeterManager constructor for train and val + train_key_to_meter_map = { + "local": MetricMeter.get_meter_by_type(self.metrics, metric_meter_type, "train_meter_local"), + "global": MetricMeter.get_meter_by_type(self.metrics, metric_meter_type, "train_meter_global"), + "personal": MetricMeter.get_meter_by_type(self.metrics, metric_meter_type, "train_meter_personal"), + } + self.train_metric_meter_mngr = MetricMeterManager(train_key_to_meter_map) + val_key_to_meter_map = { + "local": MetricMeter.get_meter_by_type(self.metrics, metric_meter_type, "val_meter_local"), + "global": MetricMeter.get_meter_by_type(self.metrics, metric_meter_type, 
"val_meter_global"), + "personal": MetricMeter.get_meter_by_type(self.metrics, metric_meter_type, "val_meter_personal"), + } + self.val_metric_meter_mngr = MetricMeterManager(val_key_to_meter_map) + self.train_loader: DataLoader self.val_loader: DataLoader self.num_train_samples: int self.num_val_samples: int + + self.model: ApflModule self.learning_rate: float + self.optimizer: torch.optim.Optimizer + self.local_optimizer: torch.optim.Optimizer # Need to track total_steps across rounds for WANDB reporting self.total_steps: int = 0 - # Apfl Module which holds both local and global models - # and gives the ability to get personal, local and global predictions - self.model: ApflModule - - # local_optimizer is used on the local model - # Usual self.optimizer is used for global model - self.local_optimizer: Optimizer def is_start_of_local_training(self, step: int) -> bool: return step == 0 @@ -127,3 +141,9 @@ def compute_loss(self, preds: Union[torch.Tensor, Dict[str, torch.Tensor]], targ additional_losses = {"global": global_loss, "local": local_loss} losses = Losses(checkpoint=personal_loss, backward=personal_loss, additional_losses=additional_losses) return losses + + def get_optimizer(self, config: Config) -> Dict[str, Optimizer]: + """ + Returns a dictionairy with global and local optimizers with string keys 'global' and 'local' respectively. + """ + raise NotImplementedError diff --git a/fl4health/clients/basic_client.py b/fl4health/clients/basic_client.py index 698dae3e9..e50a1d358 100644 --- a/fl4health/clients/basic_client.py +++ b/fl4health/clients/basic_client.py @@ -364,6 +364,7 @@ def get_criterion(self, config: Config) -> _Loss: def get_optimizer(self, config: Config) -> Union[Optimizer, Dict[str, Optimizer]]: """ Method to be defined by user that returns the PyTorch optimizer used to train models locally + Return value can be a single torch optimizer or a dictionary of string and torch optimizer. 
""" raise NotImplementedError From 2498ee5922672ad8635ee779c9a6e20779136515 Mon Sep 17 00:00:00 2001 From: John Jewell Date: Wed, 18 Oct 2023 16:32:01 -0400 Subject: [PATCH 11/12] Fix APFL example --- examples/apfl_example/client.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/examples/apfl_example/client.py b/examples/apfl_example/client.py index 1081d4258..cd73440e8 100644 --- a/examples/apfl_example/client.py +++ b/examples/apfl_example/client.py @@ -21,10 +21,6 @@ class MnistApflClient(ApflClient): def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: batch_size = self.narrow_config_type(config, "batch_size", int) - self.model: ApflModule = ApflModule(MnistNetWithBnAndFrozen()).to(self.device) - self.criterion = torch.nn.CrossEntropyLoss() - self.local_optimizer = torch.optim.AdamW(self.model.local_model.parameters(), lr=0.01) - self.global_optimizer = torch.optim.AdamW(self.model.global_model.parameters(), lr=0.01) sampler = DirichletLabelBasedSampler(list(range(10)), sample_percentage=0.75) train_loader, val_loader, _ = load_mnist_data(self.data_path, batch_size, sampler) return train_loader, val_loader From 9959c896eba8b13846f86a313be23b34619dc1cb Mon Sep 17 00:00:00 2001 From: John Jewell Date: Wed, 18 Oct 2023 18:57:57 -0400 Subject: [PATCH 12/12] Updates based on Davids suggestions --- fl4health/clients/apfl_client.py | 37 ++++----------------- fl4health/clients/basic_client.py | 26 +++++++++++---- fl4health/utils/metrics.py | 7 ++-- research/flamby/fed_isic2019/apfl/client.py | 2 ++ 4 files changed, 31 insertions(+), 41 deletions(-) diff --git a/fl4health/clients/apfl_client.py b/fl4health/clients/apfl_client.py index 5d58f93e5..a967910e4 100644 --- a/fl4health/clients/apfl_client.py +++ b/fl4health/clients/apfl_client.py @@ -10,7 +10,6 @@ from fl4health.clients.basic_client import BasicClient from fl4health.model_bases.apfl_base import ApflModule from fl4health.parameter_exchange.layer_exchanger import FixedLayerExchanger -from fl4health.reporting.fl_wanb import ClientWandBReporter from fl4health.utils.losses import Losses, LossMeter, LossMeterType from fl4health.utils.metrics import Metric, MetricMeter, MetricMeterManager, MetricMeterType @@ -69,36 +68,6 @@ def update_after_step(self, step: int) -> None: if self.is_start_of_local_training(step) and self.model.adaptive_alpha: self.model.update_alpha() - def setup_client(self, config: Config) -> None: - """ - Set dataloaders, optimizers, parameter exchangers and other attributes derived from these. - """ - model = self.get_model(config) - assert isinstance(model, ApflModule) - self.model = model - train_loader, val_loader = self.get_data_loaders(config) - self.train_loader = train_loader - self.val_loader = val_loader - - # The following lines are type ignored because torch datasets are not "Sized" - # IE __len__ is considered optionally defined. In practice, it is almost always defined - # and as such, we will make that assumption. 
- self.num_train_samples = len(self.train_loader.dataset) # type: ignore - self.num_val_samples = len(self.val_loader.dataset) # type: ignore - - optimizer_dict = self.get_optimizer(config) - assert isinstance(optimizer_dict, dict) - self.optimizer = optimizer_dict["global"] - self.local_optimizer = optimizer_dict["local"] - - self.learning_rate = self.optimizer.defaults["lr"] - self.criterion = self.get_criterion(config) - self.parameter_exchanger = self.get_parameter_exchanger(config) - - self.wandb_reporter = ClientWandBReporter.from_config(self.client_name, config) - - self.initialized = True - def train_step(self, input: torch.Tensor, target: torch.Tensor) -> Tuple[Losses, Dict[str, torch.Tensor]]: # Return preds value thats Dict of torch.Tensor containing personal, global and local predictions @@ -142,6 +111,12 @@ def compute_loss(self, preds: Union[torch.Tensor, Dict[str, torch.Tensor]], targ losses = Losses(checkpoint=personal_loss, backward=personal_loss, additional_losses=additional_losses) return losses + def set_optimizer(self, config: Config) -> None: + optimizer_dict = self.get_optimizer(config) + assert isinstance(optimizer_dict, dict) + self.optimizer = optimizer_dict["global"] + self.local_optimizer = optimizer_dict["local"] + def get_optimizer(self, config: Config) -> Dict[str, Optimizer]: """ Returns a dictionairy with global and local optimizers with string keys 'global' and 'local' respectively. diff --git a/fl4health/clients/basic_client.py b/fl4health/clients/basic_client.py index e50a1d358..ec04c146c 100644 --- a/fl4health/clients/basic_client.py +++ b/fl4health/clients/basic_client.py @@ -306,10 +306,7 @@ def setup_client(self, config: Config) -> None: self.num_train_samples = len(self.train_loader.dataset) # type: ignore self.num_val_samples = len(self.val_loader.dataset) # type: ignore - optimizer = self.get_optimizer(config) - assert isinstance(optimizer, Optimizer) - self.optimizer = optimizer - + self.set_optimizer(config) self.learning_rate = self.optimizer.defaults["lr"] self.criterion = self.get_criterion(config) self.parameter_exchanger = self.get_parameter_exchanger(config) @@ -326,14 +323,19 @@ def get_parameter_exchanger(self, config: Config) -> ParameterExchanger: def predict(self, input: torch.Tensor) -> Dict[str, torch.Tensor]: """ - Return dict of str and torch.Tensor contaiing predictions when given input. + Return dict of str and torch.Tensor containing predictions when given input. In the default case, the dict has a single item with key prediction. In more complicated approaches such as APFL, the dict has as many items as prediction types User can override for more complex logic. """ preds = self.model(input) - preds = preds if isinstance(preds, dict) else {"prediction": preds} - return preds + + if isinstance(preds, dict): + return preds + elif isinstance(preds, torch.Tensor): + return {"prediction": preds} + else: + raise ValueError("Model forward did not return a tensor or dictionary or tensors") def compute_loss(self, preds: Dict[str, torch.Tensor], target: torch.Tensor) -> Losses: """ @@ -348,6 +350,16 @@ def compute_loss(self, preds: Dict[str, torch.Tensor], target: torch.Tensor) -> losses = Losses(checkpoint=loss, backward=loss) return losses + def set_optimizer(self, config: Config) -> None: + """ + Method called in the the setup_client method to set optimizer attribute returned by used-defined get_optimizer. + In the simplest case, get_optimizer returns an optimizer. 
For more advanced use cases where a dictionairy of + string and optimizer are returned (ie APFL), the use must override this method. + """ + optimizer = self.get_optimizer(config) + assert not isinstance(optimizer, dict) + self.optimizer = optimizer + def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: """ User defined method that returns a PyTorch Train DataLoader diff --git a/fl4health/utils/metrics.py b/fl4health/utils/metrics.py index 8864a92da..c3b59782d 100644 --- a/fl4health/utils/metrics.py +++ b/fl4health/utils/metrics.py @@ -239,9 +239,10 @@ def __init__(self, key_to_meter_map: Dict[str, MetricMeter]): self.key_to_meter_map = key_to_meter_map def update(self, preds: Dict[str, torch.Tensor], target: torch.Tensor) -> None: - for map_key, pred_key in zip(sorted(self.key_to_meter_map.keys()), sorted(preds.keys())): - assert map_key == pred_key - self.key_to_meter_map[map_key].update(preds[pred_key], target) + # Assert that set of preds keys and map keys are the same + assert set(preds.keys()) == set(self.key_to_meter_map.keys()) + for pred_key in preds.keys(): + self.key_to_meter_map[pred_key].update(preds[pred_key], target) def compute(self) -> Dict[str, Scalar]: all_results = {} diff --git a/research/flamby/fed_isic2019/apfl/client.py b/research/flamby/fed_isic2019/apfl/client.py index 2514b289f..08e1ed2ff 100644 --- a/research/flamby/fed_isic2019/apfl/client.py +++ b/research/flamby/fed_isic2019/apfl/client.py @@ -59,6 +59,8 @@ def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: return train_loader, val_loader def get_criterion(self, config: Config) -> _Loss: + # NOTE: The class weights specified by alpha in this baseline loss are precomputed based on the weights of + # the pool dataset. This is a bit of cheating but FLamby does it in their paper. return BaselineLoss() def get_model(self, config: Config) -> nn.Module: