From 915552be38254d47e05f981a045debef994932a7 Mon Sep 17 00:00:00 2001 From: John Jewell Date: Wed, 27 Sep 2023 15:43:22 -0400 Subject: [PATCH 01/13] First pass at adding back in wandb integration --- examples/fedprox_example/client.py | 2 +- examples/fedprox_example/config.yaml | 4 ++-- fl4health/clients/basic_client.py | 35 ++++++++++++++++++++++++++-- fl4health/clients/numpy_fl_client.py | 2 +- 4 files changed, 37 insertions(+), 6 deletions(-) diff --git a/examples/fedprox_example/client.py b/examples/fedprox_example/client.py index 4b22940e7..2754dc8c6 100644 --- a/examples/fedprox_example/client.py +++ b/examples/fedprox_example/client.py @@ -52,7 +52,7 @@ def get_criterion(self, config: Config) -> _Loss: log(INFO, f"Device to be used: {DEVICE}") log(INFO, f"Server Address: {args.server_address}") - client = MnistFedProxClient(data_path, [Accuracy()], DEVICE) + client = MnistFedProxClient(data_path, [Accuracy()], DEVICE, use_wandb_reporter=True) fl.client.start_numpy_client(server_address=args.server_address, client=client) # Shutdown the client gracefully diff --git a/examples/fedprox_example/config.yaml b/examples/fedprox_example/config.yaml index bb61d3a1d..65b2c6d62 100644 --- a/examples/fedprox_example/config.yaml +++ b/examples/fedprox_example/config.yaml @@ -16,10 +16,10 @@ local_epochs: 1 # The number of epochs to complete for client batch_size: 128 # The batch size for client training reporting_config: - enabled: False + enabled: True project_name: FL4Health # Name of the project under which everything should be logged run_name: "FedProx Server" # Name of the run on the server-side, each client will also have it's own run name group_name: "FedProx Experiment" # Group under which each of the FL run logging will be stored - entity: "your_entity_here" # WandB user name + entity: "False" # WandB user name notes: "Testing WB reporting" tags: ["Test", "FedProx"] diff --git a/fl4health/clients/basic_client.py b/fl4health/clients/basic_client.py index 7e0f0dbff..572e7b1ba 100644 --- a/fl4health/clients/basic_client.py +++ b/fl4health/clients/basic_client.py @@ -1,6 +1,6 @@ from logging import INFO from pathlib import Path -from typing import Dict, Optional, Sequence, Tuple, Union +from typing import Any, Dict, Optional, Sequence, Tuple, Union import torch import torch.nn as nn @@ -54,6 +54,9 @@ def __init__( self.num_val_samples: int self.learning_rate: float + self.total_steps: Optional[int] = None + self.total_epochs: Optional[int] = None + def set_parameters(self, parameters: NDArrays, config: Config) -> None: # Set the model weights and initialize the correct weights with the parameter exchanger. super().set_parameters(parameters, config) @@ -67,9 +70,11 @@ def process_config(self, config: Config) -> Tuple[Union[int, None], Union[int, N if ("local_epochs" in config) and ("local_steps" in config): raise ValueError("Config cannot contain both local_epochs and local_steps. Please specify only one.") elif "local_epochs" in config: + self.total_epochs = 0 if self.total_epochs is None else self.total_epochs local_epochs = self.narrow_config_type(config, "local_epochs", int) local_steps = None elif "local_steps" in config: + self.total_steps = 0 if self.total_steps is None else self.total_steps local_steps = self.narrow_config_type(config, "local_steps", int) local_epochs = None else: @@ -140,6 +145,24 @@ def _handle_logging( f"Client {metric_prefix} Losses: {loss_string} \n" f"Client {metric_prefix} Metrics: {metric_string}", ) + def _handle_reporting( + self, + loss_dict: Dict[str, float], + metric_dict: Dict[str, Scalar], + current_round: Optional[int] = None, + ) -> None: + current_round = current_round if current_round is not None else 0 + reporting_dict: Dict[str, Any] = {"server_round": current_round} + reporting_dict = ( + {**reporting_dict, "epoch": self.total_epochs} if self.total_epochs is not None else reporting_dict + ) + reporting_dict = ( + {**reporting_dict, "step": self.total_steps} if self.total_steps is not None else reporting_dict + ) + reporting_dict.update(loss_dict) + reporting_dict.update(metric_dict) + self._maybe_report_metrics(reporting_dict) + def train_step(self, input: torch.Tensor, target: torch.Tensor) -> Tuple[Losses, torch.Tensor]: """ Given input and target, generate predictions, compute loss, optionally update metrics if they exist. @@ -187,7 +210,11 @@ def train_by_epochs( losses = self.train_loss_meter.compute() loss_dict = losses.as_dict() - self._handle_logging(loss_dict, metrics, current_epoch=local_epoch, current_round=current_round) + self._handle_logging(loss_dict, metrics, current_round=current_round, current_epoch=local_epoch) + self._handle_reporting(loss_dict, metrics, current_round=current_round) + + assert self.total_epochs is not None + self.total_epochs += 1 # Return final training metrics return loss_dict, metrics @@ -216,11 +243,15 @@ def train_by_steps( self.train_loss_meter.update(losses) self.train_metric_meter.update(preds, target) + assert self.total_steps is not None + self.total_steps += 1 + losses = self.train_loss_meter.compute() loss_dict = losses.as_dict() metrics = self.train_metric_meter.compute() self._handle_logging(loss_dict, metrics, current_round=current_round) + self._handle_reporting(loss_dict, metrics, current_round=current_round) return loss_dict, metrics diff --git a/fl4health/clients/numpy_fl_client.py b/fl4health/clients/numpy_fl_client.py index f5a559131..99da6faa6 100644 --- a/fl4health/clients/numpy_fl_client.py +++ b/fl4health/clients/numpy_fl_client.py @@ -32,7 +32,7 @@ def __init__(self, data_path: Path, device: torch.device) -> None: def generate_hash(self, length: int = 8) -> str: return "".join(random.choice(string.ascii_lowercase) for i in range(length)) - def _maybe_log_metrics(self, to_log: Dict[str, Any]) -> None: + def _maybe_report_metrics(self, to_log: Dict[str, Any]) -> None: if self.wandb_reporter: self.wandb_reporter.report_metrics(to_log) From f4ff3435ac78d6368bd446a7da607d345b15a48f Mon Sep 17 00:00:00 2001 From: John Jewell Date: Wed, 27 Sep 2023 15:53:19 -0400 Subject: [PATCH 02/13] Fix config value I accidentally changed --- examples/fedprox_example/config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/fedprox_example/config.yaml b/examples/fedprox_example/config.yaml index 65b2c6d62..971025f81 100644 --- a/examples/fedprox_example/config.yaml +++ b/examples/fedprox_example/config.yaml @@ -20,6 +20,6 @@ reporting_config: project_name: FL4Health # Name of the project under which everything should be logged run_name: "FedProx Server" # Name of the run on the server-side, each client will also have it's own run name group_name: "FedProx Experiment" # Group under which each of the FL run logging will be stored - entity: "False" # WandB user name + entity: "your_entity_here" # WandB user name notes: "Testing WB reporting" tags: ["Test", "FedProx"] From e600c4b28bedab88c68637c700b36baa8aff098b Mon Sep 17 00:00:00 2001 From: John Jewell Date: Wed, 27 Sep 2023 16:27:23 -0400 Subject: [PATCH 03/13] Add some comments --- fl4health/clients/basic_client.py | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/fl4health/clients/basic_client.py b/fl4health/clients/basic_client.py index 572e7b1ba..e420a4772 100644 --- a/fl4health/clients/basic_client.py +++ b/fl4health/clients/basic_client.py @@ -54,6 +54,9 @@ def __init__( self.num_val_samples: int self.learning_rate: float + # Need to track total_steps or total_epochs across rounds for WANDB reporting + # Only one will be initialized to 0 in the first call to fit in process_config + # And subsequently incremented for each epoch or step in all of the following rounds self.total_steps: Optional[int] = None self.total_epochs: Optional[int] = None @@ -70,10 +73,12 @@ def process_config(self, config: Config) -> Tuple[Union[int, None], Union[int, N if ("local_epochs" in config) and ("local_steps" in config): raise ValueError("Config cannot contain both local_epochs and local_steps. Please specify only one.") elif "local_epochs" in config: + # Initialize total_epochs attribute with 0 if None self.total_epochs = 0 if self.total_epochs is None else self.total_epochs local_epochs = self.narrow_config_type(config, "local_epochs", int) local_steps = None elif "local_steps" in config: + # Initialize total_steps attribute with 0 if None self.total_steps = 0 if self.total_steps is None else self.total_steps local_steps = self.narrow_config_type(config, "local_steps", int) local_epochs = None @@ -151,7 +156,13 @@ def _handle_reporting( metric_dict: Dict[str, Scalar], current_round: Optional[int] = None, ) -> None: + + # If no current_round is passed or current_round is None, set current_round to 0 + # This situation only arises when we do local finetuning and call train_by_epochs or train_by_steps explicitly current_round = current_round if current_round is not None else 0 + + # We enforce that only one of self.total_epochs and self.total_steps is defined + # So only one of the two following lines will update the reporting dict reporting_dict: Dict[str, Any] = {"server_round": current_round} reporting_dict = ( {**reporting_dict, "epoch": self.total_epochs} if self.total_epochs is not None else reporting_dict @@ -210,12 +221,14 @@ def train_by_epochs( losses = self.train_loss_meter.compute() loss_dict = losses.as_dict() - self._handle_logging(loss_dict, metrics, current_round=current_round, current_epoch=local_epoch) - self._handle_reporting(loss_dict, metrics, current_round=current_round) - + # Ensure total_epochs is not None and increment assert self.total_epochs is not None self.total_epochs += 1 + # Log results and maybe report via WANDB + self._handle_logging(loss_dict, metrics, current_round=current_round, current_epoch=local_epoch) + self._handle_reporting(loss_dict, metrics, current_round=current_round) + # Return final training metrics return loss_dict, metrics @@ -243,13 +256,15 @@ def train_by_steps( self.train_loss_meter.update(losses) self.train_metric_meter.update(preds, target) + # Ensure total_steps is not None and increment assert self.total_steps is not None self.total_steps += 1 losses = self.train_loss_meter.compute() loss_dict = losses.as_dict() metrics = self.train_metric_meter.compute() - + + # Log results and maybe report via WANDB self._handle_logging(loss_dict, metrics, current_round=current_round) self._handle_reporting(loss_dict, metrics, current_round=current_round) From ea36a3824fc7ed2cb426da9496dcd09f9cf00881 Mon Sep 17 00:00:00 2001 From: John Jewell Date: Wed, 27 Sep 2023 16:28:10 -0400 Subject: [PATCH 04/13] Fix pre-commit --- fl4health/clients/basic_client.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/fl4health/clients/basic_client.py b/fl4health/clients/basic_client.py index e420a4772..328320386 100644 --- a/fl4health/clients/basic_client.py +++ b/fl4health/clients/basic_client.py @@ -161,7 +161,7 @@ def _handle_reporting( # This situation only arises when we do local finetuning and call train_by_epochs or train_by_steps explicitly current_round = current_round if current_round is not None else 0 - # We enforce that only one of self.total_epochs and self.total_steps is defined + # We enforce that only one of self.total_epochs and self.total_steps is defined # So only one of the two following lines will update the reporting dict reporting_dict: Dict[str, Any] = {"server_round": current_round} reporting_dict = ( @@ -256,14 +256,14 @@ def train_by_steps( self.train_loss_meter.update(losses) self.train_metric_meter.update(preds, target) - # Ensure total_steps is not None and increment + # Ensure total_steps is not None and increment assert self.total_steps is not None self.total_steps += 1 losses = self.train_loss_meter.compute() loss_dict = losses.as_dict() metrics = self.train_metric_meter.compute() - + # Log results and maybe report via WANDB self._handle_logging(loss_dict, metrics, current_round=current_round) self._handle_reporting(loss_dict, metrics, current_round=current_round) From c9a023ab5d1f19dce69e969d263f81b47b112cc8 Mon Sep 17 00:00:00 2001 From: John Jewell Date: Wed, 11 Oct 2023 16:54:11 -0400 Subject: [PATCH 05/13] Server passes bool to let client know if reporting is enabled instead of setting it in client constructor --- examples/fedprox_example/client.py | 2 +- fl4health/clients/basic_client.py | 5 +---- fl4health/clients/clipping_client.py | 2 -- fl4health/clients/fed_prox_client.py | 2 -- fl4health/clients/instance_level_privacy_client.py | 2 -- fl4health/clients/scaffold_client.py | 5 ----- fl4health/reporting/fl_wanb.py | 2 +- research/flamby/fed_heart_disease/fedadam/client.py | 2 -- research/flamby/fed_heart_disease/fedavg/client.py | 2 -- research/flamby/fed_heart_disease/fedprox/client.py | 2 -- research/flamby/fed_heart_disease/fenda/client.py | 2 -- research/flamby/fed_heart_disease/scaffold/client.py | 2 -- research/flamby/fed_isic2019/fedadam/client.py | 2 -- research/flamby/fed_isic2019/fedavg/client.py | 2 -- research/flamby/fed_isic2019/fedprox/client.py | 2 -- research/flamby/fed_isic2019/fenda/client.py | 2 -- research/flamby/fed_isic2019/scaffold/client.py | 2 -- research/flamby/fed_ixi/fedadam/client.py | 2 -- research/flamby/fed_ixi/fedavg/client.py | 2 -- research/flamby/fed_ixi/fedprox/client.py | 2 -- research/flamby/fed_ixi/fenda/client.py | 2 -- research/flamby/fed_ixi/scaffold/client.py | 2 -- 22 files changed, 3 insertions(+), 47 deletions(-) diff --git a/examples/fedprox_example/client.py b/examples/fedprox_example/client.py index 2754dc8c6..4b22940e7 100644 --- a/examples/fedprox_example/client.py +++ b/examples/fedprox_example/client.py @@ -52,7 +52,7 @@ def get_criterion(self, config: Config) -> _Loss: log(INFO, f"Device to be used: {DEVICE}") log(INFO, f"Server Address: {args.server_address}") - client = MnistFedProxClient(data_path, [Accuracy()], DEVICE, use_wandb_reporter=True) + client = MnistFedProxClient(data_path, [Accuracy()], DEVICE) fl.client.start_numpy_client(server_address=args.server_address, client=client) # Shutdown the client gracefully diff --git a/fl4health/clients/basic_client.py b/fl4health/clients/basic_client.py index 328320386..faca23bc4 100644 --- a/fl4health/clients/basic_client.py +++ b/fl4health/clients/basic_client.py @@ -33,12 +33,10 @@ def __init__( device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, metric_meter_type: MetricMeterType = MetricMeterType.AVERAGE, - use_wandb_reporter: bool = False, checkpointer: Optional[TorchCheckpointer] = None, ) -> None: super().__init__(data_path, device) self.metrics = metrics - self.use_wandb_reporter = use_wandb_reporter self.checkpointer = checkpointer self.train_loss_meter = LossMeter.get_meter_by_type(loss_meter_type) self.val_loss_meter = LossMeter.get_meter_by_type(loss_meter_type) @@ -320,8 +318,7 @@ def setup_client(self, config: Config) -> None: self.criterion = self.get_criterion(config) self.parameter_exchanger = self.get_parameter_exchanger(config) - if self.use_wandb_reporter: - self.wandb_reporter = ClientWandBReporter.from_config(self.client_name, config) + self.wandb_reporter = ClientWandBReporter.from_config(self.client_name, config) super().setup_client(config) diff --git a/fl4health/clients/clipping_client.py b/fl4health/clients/clipping_client.py index 3e5371686..fd7d21d79 100644 --- a/fl4health/clients/clipping_client.py +++ b/fl4health/clients/clipping_client.py @@ -30,7 +30,6 @@ def __init__( device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, metric_meter_type: MetricMeterType = MetricMeterType.AVERAGE, - use_wandb_reporter: bool = False, checkpointer: Optional[TorchCheckpointer] = None, ) -> None: super().__init__( @@ -39,7 +38,6 @@ def __init__( device=device, loss_meter_type=loss_meter_type, metric_meter_type=metric_meter_type, - use_wandb_reporter=use_wandb_reporter, checkpointer=checkpointer, ) self.parameter_exchanger: ParameterExchangerWithPacking[float] diff --git a/fl4health/clients/fed_prox_client.py b/fl4health/clients/fed_prox_client.py index 0be0b2ebd..34c4c0020 100644 --- a/fl4health/clients/fed_prox_client.py +++ b/fl4health/clients/fed_prox_client.py @@ -27,7 +27,6 @@ def __init__( device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, metric_meter_type: MetricMeterType = MetricMeterType.AVERAGE, - use_wandb_reporter: bool = False, checkpointer: Optional[TorchCheckpointer] = None, ) -> None: super().__init__( @@ -36,7 +35,6 @@ def __init__( device=device, loss_meter_type=loss_meter_type, metric_meter_type=metric_meter_type, - use_wandb_reporter=use_wandb_reporter, checkpointer=checkpointer, ) self.initial_tensors: List[torch.Tensor] diff --git a/fl4health/clients/instance_level_privacy_client.py b/fl4health/clients/instance_level_privacy_client.py index 222a73548..f3c3508f9 100644 --- a/fl4health/clients/instance_level_privacy_client.py +++ b/fl4health/clients/instance_level_privacy_client.py @@ -26,7 +26,6 @@ def __init__( device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, metric_meter_type: MetricMeterType = MetricMeterType.AVERAGE, - use_wandb_reporter: bool = False, checkpointer: Optional[TorchCheckpointer] = None, ) -> None: super().__init__( @@ -35,7 +34,6 @@ def __init__( device=device, loss_meter_type=loss_meter_type, metric_meter_type=metric_meter_type, - use_wandb_reporter=use_wandb_reporter, checkpointer=checkpointer, ) self.clipping_bound: float diff --git a/fl4health/clients/scaffold_client.py b/fl4health/clients/scaffold_client.py index 8ecba0311..4208c4fef 100644 --- a/fl4health/clients/scaffold_client.py +++ b/fl4health/clients/scaffold_client.py @@ -31,7 +31,6 @@ def __init__( device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, metric_meter_type: MetricMeterType = MetricMeterType.AVERAGE, - use_wandb_reporter: bool = False, checkpointer: Optional[TorchCheckpointer] = None, ) -> None: super().__init__( @@ -40,7 +39,6 @@ def __init__( device=device, loss_meter_type=loss_meter_type, metric_meter_type=metric_meter_type, - use_wandb_reporter=use_wandb_reporter, checkpointer=checkpointer, ) self.learning_rate: float # eta_l in paper @@ -207,7 +205,6 @@ def __init__( device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, metric_meter_type: MetricMeterType = MetricMeterType.AVERAGE, - use_wandb_reporter: bool = False, checkpointer: Optional[TorchCheckpointer] = None, ) -> None: ScaffoldClient.__init__( @@ -217,7 +214,6 @@ def __init__( device=device, loss_meter_type=loss_meter_type, metric_meter_type=metric_meter_type, - use_wandb_reporter=use_wandb_reporter, checkpointer=checkpointer, ) @@ -228,6 +224,5 @@ def __init__( device=device, loss_meter_type=loss_meter_type, metric_meter_type=metric_meter_type, - use_wandb_reporter=use_wandb_reporter, checkpointer=checkpointer, ) diff --git a/fl4health/reporting/fl_wanb.py b/fl4health/reporting/fl_wanb.py index 18ed7be36..e3b2b0953 100644 --- a/fl4health/reporting/fl_wanb.py +++ b/fl4health/reporting/fl_wanb.py @@ -161,7 +161,7 @@ def add_client_model_type(self, client_name: str, model_type: str) -> None: @classmethod def from_config(cls, client_name: str, config: Dict[str, Any]) -> Optional["ClientWandBReporter"]: - if config["reporting_enabled"]: + if "reporting_enabled" in config and config["reporting_enabled"]: return ClientWandBReporter(client_name, config["project_name"], config["group_name"], config["entity"]) else: return None diff --git a/research/flamby/fed_heart_disease/fedadam/client.py b/research/flamby/fed_heart_disease/fedadam/client.py index 688c95b2d..dc4a0f510 100644 --- a/research/flamby/fed_heart_disease/fedadam/client.py +++ b/research/flamby/fed_heart_disease/fedadam/client.py @@ -32,7 +32,6 @@ def __init__( loss_meter_type: LossMeterType = LossMeterType.AVERAGE, metric_meter_type: MetricMeterType = MetricMeterType.ACCUMULATION, checkpointer: Optional[TorchCheckpointer] = None, - use_wandb_reporter: bool = False, ) -> None: super().__init__( data_path=data_path, @@ -40,7 +39,6 @@ def __init__( device=device, loss_meter_type=loss_meter_type, metric_meter_type=metric_meter_type, - use_wandb_reporter=use_wandb_reporter, checkpointer=checkpointer, ) self.client_number = client_number diff --git a/research/flamby/fed_heart_disease/fedavg/client.py b/research/flamby/fed_heart_disease/fedavg/client.py index 3a498f2e9..160b359a9 100644 --- a/research/flamby/fed_heart_disease/fedavg/client.py +++ b/research/flamby/fed_heart_disease/fedavg/client.py @@ -32,7 +32,6 @@ def __init__( loss_meter_type: LossMeterType = LossMeterType.AVERAGE, metric_meter_type: MetricMeterType = MetricMeterType.ACCUMULATION, checkpointer: Optional[TorchCheckpointer] = None, - use_wandb_reporter: bool = False, ) -> None: super().__init__( data_path=data_path, @@ -40,7 +39,6 @@ def __init__( device=device, loss_meter_type=loss_meter_type, metric_meter_type=metric_meter_type, - use_wandb_reporter=use_wandb_reporter, checkpointer=checkpointer, ) self.client_number = client_number diff --git a/research/flamby/fed_heart_disease/fedprox/client.py b/research/flamby/fed_heart_disease/fedprox/client.py index e4b019317..431d8b0f5 100644 --- a/research/flamby/fed_heart_disease/fedprox/client.py +++ b/research/flamby/fed_heart_disease/fedprox/client.py @@ -32,7 +32,6 @@ def __init__( loss_meter_type: LossMeterType = LossMeterType.AVERAGE, metric_meter_type: MetricMeterType = MetricMeterType.ACCUMULATION, checkpointer: Optional[TorchCheckpointer] = None, - use_wandb_reporter: bool = False, ) -> None: super().__init__( data_path=data_path, @@ -40,7 +39,6 @@ def __init__( device=device, loss_meter_type=loss_meter_type, metric_meter_type=metric_meter_type, - use_wandb_reporter=use_wandb_reporter, checkpointer=checkpointer, ) self.client_number = client_number diff --git a/research/flamby/fed_heart_disease/fenda/client.py b/research/flamby/fed_heart_disease/fenda/client.py index 6eb6aa4c5..12cd57da5 100644 --- a/research/flamby/fed_heart_disease/fenda/client.py +++ b/research/flamby/fed_heart_disease/fenda/client.py @@ -33,7 +33,6 @@ def __init__( loss_meter_type: LossMeterType = LossMeterType.AVERAGE, metric_meter_type: MetricMeterType = MetricMeterType.ACCUMULATION, checkpointer: Optional[TorchCheckpointer] = None, - use_wandb_reporter: bool = False, ) -> None: super().__init__( data_path=data_path, @@ -41,7 +40,6 @@ def __init__( device=device, loss_meter_type=loss_meter_type, metric_meter_type=metric_meter_type, - use_wandb_reporter=use_wandb_reporter, checkpointer=checkpointer, ) self.client_number = client_number diff --git a/research/flamby/fed_heart_disease/scaffold/client.py b/research/flamby/fed_heart_disease/scaffold/client.py index 5a686e533..dba721134 100644 --- a/research/flamby/fed_heart_disease/scaffold/client.py +++ b/research/flamby/fed_heart_disease/scaffold/client.py @@ -32,7 +32,6 @@ def __init__( loss_meter_type: LossMeterType = LossMeterType.AVERAGE, metric_meter_type: MetricMeterType = MetricMeterType.ACCUMULATION, checkpointer: Optional[TorchCheckpointer] = None, - use_wandb_reporter: bool = False, ) -> None: super().__init__( data_path=data_path, @@ -40,7 +39,6 @@ def __init__( device=device, loss_meter_type=loss_meter_type, metric_meter_type=metric_meter_type, - use_wandb_reporter=use_wandb_reporter, checkpointer=checkpointer, ) self.client_number = client_number diff --git a/research/flamby/fed_isic2019/fedadam/client.py b/research/flamby/fed_isic2019/fedadam/client.py index 402de4709..e1b3c902e 100644 --- a/research/flamby/fed_isic2019/fedadam/client.py +++ b/research/flamby/fed_isic2019/fedadam/client.py @@ -33,7 +33,6 @@ def __init__( loss_meter_type: LossMeterType = LossMeterType.AVERAGE, metric_meter_type: MetricMeterType = MetricMeterType.ACCUMULATION, checkpointer: Optional[TorchCheckpointer] = None, - use_wandb_reporter: bool = False, ) -> None: super().__init__( data_path=data_path, @@ -41,7 +40,6 @@ def __init__( device=device, loss_meter_type=loss_meter_type, metric_meter_type=metric_meter_type, - use_wandb_reporter=use_wandb_reporter, checkpointer=checkpointer, ) self.client_number = client_number diff --git a/research/flamby/fed_isic2019/fedavg/client.py b/research/flamby/fed_isic2019/fedavg/client.py index e1eccb0cd..fd66d9f47 100644 --- a/research/flamby/fed_isic2019/fedavg/client.py +++ b/research/flamby/fed_isic2019/fedavg/client.py @@ -32,7 +32,6 @@ def __init__( loss_meter_type: LossMeterType = LossMeterType.AVERAGE, metric_meter_type: MetricMeterType = MetricMeterType.ACCUMULATION, checkpointer: Optional[TorchCheckpointer] = None, - use_wandb_reporter: bool = False, ) -> None: super().__init__( data_path=data_path, @@ -40,7 +39,6 @@ def __init__( device=device, loss_meter_type=loss_meter_type, metric_meter_type=metric_meter_type, - use_wandb_reporter=use_wandb_reporter, checkpointer=checkpointer, ) self.client_number = client_number diff --git a/research/flamby/fed_isic2019/fedprox/client.py b/research/flamby/fed_isic2019/fedprox/client.py index 34d493986..856479793 100644 --- a/research/flamby/fed_isic2019/fedprox/client.py +++ b/research/flamby/fed_isic2019/fedprox/client.py @@ -32,7 +32,6 @@ def __init__( loss_meter_type: LossMeterType = LossMeterType.AVERAGE, metric_meter_type: MetricMeterType = MetricMeterType.ACCUMULATION, checkpointer: Optional[TorchCheckpointer] = None, - use_wandb_reporter: bool = False, ) -> None: super().__init__( data_path=data_path, @@ -40,7 +39,6 @@ def __init__( device=device, loss_meter_type=loss_meter_type, metric_meter_type=metric_meter_type, - use_wandb_reporter=use_wandb_reporter, checkpointer=checkpointer, ) self.client_number = client_number diff --git a/research/flamby/fed_isic2019/fenda/client.py b/research/flamby/fed_isic2019/fenda/client.py index 88907a0b7..449d6fdc0 100644 --- a/research/flamby/fed_isic2019/fenda/client.py +++ b/research/flamby/fed_isic2019/fenda/client.py @@ -33,7 +33,6 @@ def __init__( loss_meter_type: LossMeterType = LossMeterType.AVERAGE, metric_meter_type: MetricMeterType = MetricMeterType.ACCUMULATION, checkpointer: Optional[TorchCheckpointer] = None, - use_wandb_reporter: bool = False, ) -> None: super().__init__( data_path=data_path, @@ -41,7 +40,6 @@ def __init__( device=device, loss_meter_type=loss_meter_type, metric_meter_type=metric_meter_type, - use_wandb_reporter=use_wandb_reporter, checkpointer=checkpointer, ) self.client_number = client_number diff --git a/research/flamby/fed_isic2019/scaffold/client.py b/research/flamby/fed_isic2019/scaffold/client.py index 89d237e1c..18ab9939a 100644 --- a/research/flamby/fed_isic2019/scaffold/client.py +++ b/research/flamby/fed_isic2019/scaffold/client.py @@ -32,7 +32,6 @@ def __init__( loss_meter_type: LossMeterType = LossMeterType.AVERAGE, metric_meter_type: MetricMeterType = MetricMeterType.ACCUMULATION, checkpointer: Optional[TorchCheckpointer] = None, - use_wandb_reporter: bool = False, ) -> None: super().__init__( data_path=data_path, @@ -40,7 +39,6 @@ def __init__( device=device, loss_meter_type=loss_meter_type, metric_meter_type=metric_meter_type, - use_wandb_reporter=use_wandb_reporter, checkpointer=checkpointer, ) self.client_number = client_number diff --git a/research/flamby/fed_ixi/fedadam/client.py b/research/flamby/fed_ixi/fedadam/client.py index 6549d0055..d4ef9a900 100644 --- a/research/flamby/fed_ixi/fedadam/client.py +++ b/research/flamby/fed_ixi/fedadam/client.py @@ -33,7 +33,6 @@ def __init__( loss_meter_type: LossMeterType = LossMeterType.AVERAGE, metric_meter_type: MetricMeterType = MetricMeterType.ACCUMULATION, checkpointer: Optional[TorchCheckpointer] = None, - use_wandb_reporter: bool = False, ) -> None: super().__init__( data_path=data_path, @@ -41,7 +40,6 @@ def __init__( device=device, loss_meter_type=loss_meter_type, metric_meter_type=metric_meter_type, - use_wandb_reporter=use_wandb_reporter, checkpointer=checkpointer, ) self.client_number = client_number diff --git a/research/flamby/fed_ixi/fedavg/client.py b/research/flamby/fed_ixi/fedavg/client.py index 3e36ba707..e9e77d762 100644 --- a/research/flamby/fed_ixi/fedavg/client.py +++ b/research/flamby/fed_ixi/fedavg/client.py @@ -32,7 +32,6 @@ def __init__( loss_meter_type: LossMeterType = LossMeterType.AVERAGE, metric_meter_type: MetricMeterType = MetricMeterType.ACCUMULATION, checkpointer: Optional[TorchCheckpointer] = None, - use_wandb_reporter: bool = False, ) -> None: super().__init__( data_path=data_path, @@ -40,7 +39,6 @@ def __init__( device=device, loss_meter_type=loss_meter_type, metric_meter_type=metric_meter_type, - use_wandb_reporter=use_wandb_reporter, checkpointer=checkpointer, ) self.client_number = client_number diff --git a/research/flamby/fed_ixi/fedprox/client.py b/research/flamby/fed_ixi/fedprox/client.py index 089cbd324..4c71c45f9 100644 --- a/research/flamby/fed_ixi/fedprox/client.py +++ b/research/flamby/fed_ixi/fedprox/client.py @@ -32,7 +32,6 @@ def __init__( loss_meter_type: LossMeterType = LossMeterType.AVERAGE, metric_meter_type: MetricMeterType = MetricMeterType.ACCUMULATION, checkpointer: Optional[TorchCheckpointer] = None, - use_wandb_reporter: bool = False, ) -> None: super().__init__( data_path=data_path, @@ -40,7 +39,6 @@ def __init__( device=device, loss_meter_type=loss_meter_type, metric_meter_type=metric_meter_type, - use_wandb_reporter=use_wandb_reporter, checkpointer=checkpointer, ) self.client_number = client_number diff --git a/research/flamby/fed_ixi/fenda/client.py b/research/flamby/fed_ixi/fenda/client.py index 0658e8c8e..373e22505 100644 --- a/research/flamby/fed_ixi/fenda/client.py +++ b/research/flamby/fed_ixi/fenda/client.py @@ -33,7 +33,6 @@ def __init__( loss_meter_type: LossMeterType = LossMeterType.AVERAGE, metric_meter_type: MetricMeterType = MetricMeterType.ACCUMULATION, checkpointer: Optional[TorchCheckpointer] = None, - use_wandb_reporter: bool = False, ) -> None: super().__init__( data_path=data_path, @@ -41,7 +40,6 @@ def __init__( device=device, loss_meter_type=loss_meter_type, metric_meter_type=metric_meter_type, - use_wandb_reporter=use_wandb_reporter, checkpointer=checkpointer, ) self.client_number = client_number diff --git a/research/flamby/fed_ixi/scaffold/client.py b/research/flamby/fed_ixi/scaffold/client.py index da8005f74..5f874f767 100644 --- a/research/flamby/fed_ixi/scaffold/client.py +++ b/research/flamby/fed_ixi/scaffold/client.py @@ -32,7 +32,6 @@ def __init__( loss_meter_type: LossMeterType = LossMeterType.AVERAGE, metric_meter_type: MetricMeterType = MetricMeterType.ACCUMULATION, checkpointer: Optional[TorchCheckpointer] = None, - use_wandb_reporter: bool = False, ) -> None: super().__init__( data_path=data_path, @@ -40,7 +39,6 @@ def __init__( device=device, loss_meter_type=loss_meter_type, metric_meter_type=metric_meter_type, - use_wandb_reporter=use_wandb_reporter, checkpointer=checkpointer, ) self.client_number = client_number From 31d94bf2464f27115b4672a5fae1a25a9c95e7fb Mon Sep 17 00:00:00 2001 From: John Jewell Date: Wed, 11 Oct 2023 18:12:22 -0400 Subject: [PATCH 06/13] Change reporting to use steps only, regardless if client is uses local_steps or local_epochs --- fl4health/clients/basic_client.py | 27 ++++----------------------- requirements.txt | 4 ++++ 2 files changed, 8 insertions(+), 23 deletions(-) diff --git a/fl4health/clients/basic_client.py b/fl4health/clients/basic_client.py index faca23bc4..aa503f28f 100644 --- a/fl4health/clients/basic_client.py +++ b/fl4health/clients/basic_client.py @@ -52,11 +52,8 @@ def __init__( self.num_val_samples: int self.learning_rate: float - # Need to track total_steps or total_epochs across rounds for WANDB reporting - # Only one will be initialized to 0 in the first call to fit in process_config - # And subsequently incremented for each epoch or step in all of the following rounds - self.total_steps: Optional[int] = None - self.total_epochs: Optional[int] = None + # Need to track total_steps across rounds for WANDB reporting + self.total_steps: int = 0 def set_parameters(self, parameters: NDArrays, config: Config) -> None: # Set the model weights and initialize the correct weights with the parameter exchanger. @@ -71,13 +68,9 @@ def process_config(self, config: Config) -> Tuple[Union[int, None], Union[int, N if ("local_epochs" in config) and ("local_steps" in config): raise ValueError("Config cannot contain both local_epochs and local_steps. Please specify only one.") elif "local_epochs" in config: - # Initialize total_epochs attribute with 0 if None - self.total_epochs = 0 if self.total_epochs is None else self.total_epochs local_epochs = self.narrow_config_type(config, "local_epochs", int) local_steps = None elif "local_steps" in config: - # Initialize total_steps attribute with 0 if None - self.total_steps = 0 if self.total_steps is None else self.total_steps local_steps = self.narrow_config_type(config, "local_steps", int) local_epochs = None else: @@ -159,15 +152,8 @@ def _handle_reporting( # This situation only arises when we do local finetuning and call train_by_epochs or train_by_steps explicitly current_round = current_round if current_round is not None else 0 - # We enforce that only one of self.total_epochs and self.total_steps is defined - # So only one of the two following lines will update the reporting dict reporting_dict: Dict[str, Any] = {"server_round": current_round} - reporting_dict = ( - {**reporting_dict, "epoch": self.total_epochs} if self.total_epochs is not None else reporting_dict - ) - reporting_dict = ( - {**reporting_dict, "step": self.total_steps} if self.total_steps is not None else reporting_dict - ) + reporting_dict.update({"step": self.total_steps}) reporting_dict.update(loss_dict) reporting_dict.update(metric_dict) self._maybe_report_metrics(reporting_dict) @@ -215,14 +201,11 @@ def train_by_epochs( losses, preds = self.train_step(input, target) self.train_loss_meter.update(losses) self.train_metric_meter.update(preds, target) + self.total_steps += 1 metrics = self.train_metric_meter.compute() losses = self.train_loss_meter.compute() loss_dict = losses.as_dict() - # Ensure total_epochs is not None and increment - assert self.total_epochs is not None - self.total_epochs += 1 - # Log results and maybe report via WANDB self._handle_logging(loss_dict, metrics, current_round=current_round, current_epoch=local_epoch) self._handle_reporting(loss_dict, metrics, current_round=current_round) @@ -254,8 +237,6 @@ def train_by_steps( self.train_loss_meter.update(losses) self.train_metric_meter.update(preds, target) - # Ensure total_steps is not None and increment - assert self.total_steps is not None self.total_steps += 1 losses = self.train_loss_meter.compute() diff --git a/requirements.txt b/requirements.txt index 379c14220..03bd5dd18 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,5 +21,9 @@ torchinfo torchtext torchvision types-requests +types-PyYAML +types-protobuf +types-six +types-tabulate types-setuptools wandb From 69d99a95e574ecfa6eb3e0d5e5ca87db82713f72 Mon Sep 17 00:00:00 2001 From: John Jewell Date: Wed, 11 Oct 2023 18:19:33 -0400 Subject: [PATCH 07/13] Fix requirements ordering causing failed pre-commit checks --- requirements.txt | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/requirements.txt b/requirements.txt index 03bd5dd18..b812ad815 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,10 +20,10 @@ torcheval torchinfo torchtext torchvision -types-requests -types-PyYAML types-protobuf +types-PyYAML +types-requests +types-setuptools types-six types-tabulate -types-setuptools wandb From f8942556e2db5014fdbe48399c0e068d4ee7e1fc Mon Sep 17 00:00:00 2001 From: John Jewell Date: Wed, 11 Oct 2023 18:35:31 -0400 Subject: [PATCH 08/13] fix precommit pt 2 --- examples/fedopt_example/metrics.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/fedopt_example/metrics.py b/examples/fedopt_example/metrics.py index 5c21708b6..366841707 100644 --- a/examples/fedopt_example/metrics.py +++ b/examples/fedopt_example/metrics.py @@ -1,7 +1,7 @@ import json from typing import Dict, List -import numpy as np +import torch from flwr.common.typing import Metrics from sklearn.metrics import confusion_matrix @@ -107,7 +107,7 @@ def summarize(self) -> str: log_string = f"{log_string}\naverage_f1:{str(sum_f1/n_topics)}" return log_string - def update_performance(self, predictions: np.ndarray, labels: np.ndarray) -> None: + def update_performance(self, predictions: torch.Tensor, labels: torch.Tensor) -> None: confusion = confusion_matrix(labels, predictions, labels=range(self.n_classes)) for i in range(self.n_classes): true_class = self.label_to_class[i] From 028422b3cc4efdce5bf2e37afcbc3ef98bfd4e24 Mon Sep 17 00:00:00 2001 From: John Jewell Date: Thu, 12 Oct 2023 10:48:36 -0400 Subject: [PATCH 09/13] Move handling reporting to basic client --- fl4health/clients/basic_client.py | 6 +++++- fl4health/clients/numpy_fl_client.py | 6 +----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/fl4health/clients/basic_client.py b/fl4health/clients/basic_client.py index aa503f28f..5c681723a 100644 --- a/fl4health/clients/basic_client.py +++ b/fl4health/clients/basic_client.py @@ -148,6 +148,10 @@ def _handle_reporting( current_round: Optional[int] = None, ) -> None: + # If reporter is None we do not report to wandb and return + if self.wandb_reporter is None: + return + # If no current_round is passed or current_round is None, set current_round to 0 # This situation only arises when we do local finetuning and call train_by_epochs or train_by_steps explicitly current_round = current_round if current_round is not None else 0 @@ -156,7 +160,7 @@ def _handle_reporting( reporting_dict.update({"step": self.total_steps}) reporting_dict.update(loss_dict) reporting_dict.update(metric_dict) - self._maybe_report_metrics(reporting_dict) + self.wandb_reporter.report_metrics(reporting_dict) def train_step(self, input: torch.Tensor, target: torch.Tensor) -> Tuple[Losses, torch.Tensor]: """ diff --git a/fl4health/clients/numpy_fl_client.py b/fl4health/clients/numpy_fl_client.py index 99da6faa6..fb5893798 100644 --- a/fl4health/clients/numpy_fl_client.py +++ b/fl4health/clients/numpy_fl_client.py @@ -1,7 +1,7 @@ import random import string from pathlib import Path -from typing import Any, Dict, Optional, Type, TypeVar +from typing import Optional, Type, TypeVar import torch import torch.nn as nn @@ -32,10 +32,6 @@ def __init__(self, data_path: Path, device: torch.device) -> None: def generate_hash(self, length: int = 8) -> str: return "".join(random.choice(string.ascii_lowercase) for i in range(length)) - def _maybe_report_metrics(self, to_log: Dict[str, Any]) -> None: - if self.wandb_reporter: - self.wandb_reporter.report_metrics(to_log) - def _maybe_checkpoint(self, comparison_metric: float) -> None: if self.checkpointer: self.checkpointer.maybe_checkpoint(self.model, comparison_metric) From 3d179d430ba22e3a20eb656931eea85d90d685bb Mon Sep 17 00:00:00 2001 From: John Jewell Date: Thu, 12 Oct 2023 12:41:16 -0400 Subject: [PATCH 10/13] Add tests to server and client reporting classes --- tests/reporting/__init__.py | 0 tests/reporting/test_wandb_reporter.py | 17 +++++++++++++++++ 2 files changed, 17 insertions(+) create mode 100644 tests/reporting/__init__.py create mode 100644 tests/reporting/test_wandb_reporter.py diff --git a/tests/reporting/__init__.py b/tests/reporting/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/reporting/test_wandb_reporter.py b/tests/reporting/test_wandb_reporter.py new file mode 100644 index 000000000..3735fb88d --- /dev/null +++ b/tests/reporting/test_wandb_reporter.py @@ -0,0 +1,17 @@ +from pathlib import Path +from unittest import mock +from fl4health.reporting.fl_wanb import ClientWandBReporter, ServerWandBReporter + +def test_server_wandb_reporter(tmp_path: Path) -> None: + with mock.patch.object(ServerWandBReporter, "__init__", lambda a, b, c, d, e, f, g, h : None): + reporter = ServerWandBReporter("", "", "", "", None, None, {}) + log_dir = str(tmp_path.joinpath("fl_wandb_logs")) + reporter._maybe_create_local_log_directory(log_dir) + assert log_dir in list(map(lambda x : str(x), tmp_path.iterdir())) + +def test_client_wandb_reporter(tmp_path: Path) -> None: + with mock.patch.object(ClientWandBReporter, "__init__", lambda a, b, c, d, e: None): + reporter = ClientWandBReporter("", "", "", "") + log_dir = str(tmp_path.joinpath("fl_wandb_logs")) + reporter._maybe_create_local_log_directory(log_dir) + assert log_dir in list(map(lambda x : str(x), tmp_path.iterdir())) \ No newline at end of file From aefaa798ed2202aa23207888a2856f142ecb36f4 Mon Sep 17 00:00:00 2001 From: John Jewell Date: Thu, 12 Oct 2023 12:52:01 -0400 Subject: [PATCH 11/13] Fix pre-commit issues --- tests/reporting/test_wandb_reporter.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/tests/reporting/test_wandb_reporter.py b/tests/reporting/test_wandb_reporter.py index 3735fb88d..bda170621 100644 --- a/tests/reporting/test_wandb_reporter.py +++ b/tests/reporting/test_wandb_reporter.py @@ -1,17 +1,20 @@ -from pathlib import Path +from pathlib import Path from unittest import mock + from fl4health.reporting.fl_wanb import ClientWandBReporter, ServerWandBReporter -def test_server_wandb_reporter(tmp_path: Path) -> None: - with mock.patch.object(ServerWandBReporter, "__init__", lambda a, b, c, d, e, f, g, h : None): + +def test_server_wandb_reporter(tmp_path: Path) -> None: + with mock.patch.object(ServerWandBReporter, "__init__", lambda a, b, c, d, e, f, g, h: None): reporter = ServerWandBReporter("", "", "", "", None, None, {}) log_dir = str(tmp_path.joinpath("fl_wandb_logs")) reporter._maybe_create_local_log_directory(log_dir) - assert log_dir in list(map(lambda x : str(x), tmp_path.iterdir())) + assert log_dir in list(map(lambda x: str(x), tmp_path.iterdir())) + -def test_client_wandb_reporter(tmp_path: Path) -> None: +def test_client_wandb_reporter(tmp_path: Path) -> None: with mock.patch.object(ClientWandBReporter, "__init__", lambda a, b, c, d, e: None): reporter = ClientWandBReporter("", "", "", "") log_dir = str(tmp_path.joinpath("fl_wandb_logs")) reporter._maybe_create_local_log_directory(log_dir) - assert log_dir in list(map(lambda x : str(x), tmp_path.iterdir())) \ No newline at end of file + assert log_dir in list(map(lambda x: str(x), tmp_path.iterdir())) From c331b601ebe309de941eb66115fbde3b5b77e20f Mon Sep 17 00:00:00 2001 From: John Jewell Date: Thu, 12 Oct 2023 13:16:25 -0400 Subject: [PATCH 12/13] Initially changed predictions and labels to torch.tensor so we pass pre-commit checks. But confusion matrix expects ndarray, so just added the torch to numpy conversion so nothing breaks --- examples/fedopt_example/metrics.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/fedopt_example/metrics.py b/examples/fedopt_example/metrics.py index 366841707..2b6fae8d0 100644 --- a/examples/fedopt_example/metrics.py +++ b/examples/fedopt_example/metrics.py @@ -108,6 +108,7 @@ def summarize(self) -> str: return log_string def update_performance(self, predictions: torch.Tensor, labels: torch.Tensor) -> None: + predictions, labels = predictions.numpy(), labels.numpy() confusion = confusion_matrix(labels, predictions, labels=range(self.n_classes)) for i in range(self.n_classes): true_class = self.label_to_class[i] From 26edf8dc47ff01ca4badcb40c76eef0e11128890 Mon Sep 17 00:00:00 2001 From: John Jewell Date: Fri, 13 Oct 2023 10:22:18 -0400 Subject: [PATCH 13/13] Fix pre-commit issues --- examples/fedopt_example/metrics.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/examples/fedopt_example/metrics.py b/examples/fedopt_example/metrics.py index 2b6fae8d0..a29833ce1 100644 --- a/examples/fedopt_example/metrics.py +++ b/examples/fedopt_example/metrics.py @@ -1,12 +1,15 @@ import json -from typing import Dict, List +from typing import Dict, List, TypeVar -import torch +import numpy as np from flwr.common.typing import Metrics from sklearn.metrics import confusion_matrix +from torch import Tensor from examples.fedopt_example.client_data import LabelEncoder +T = TypeVar("T", np.ndarray, Tensor) + class Outcome: def __init__(self, class_name: str) -> None: @@ -107,8 +110,7 @@ def summarize(self) -> str: log_string = f"{log_string}\naverage_f1:{str(sum_f1/n_topics)}" return log_string - def update_performance(self, predictions: torch.Tensor, labels: torch.Tensor) -> None: - predictions, labels = predictions.numpy(), labels.numpy() + def update_performance(self, predictions: T, labels: T) -> None: confusion = confusion_matrix(labels, predictions, labels=range(self.n_classes)) for i in range(self.n_classes): true_class = self.label_to_class[i]