diff --git a/examples/fedopt_example/metrics.py b/examples/fedopt_example/metrics.py index 5c21708b6..a29833ce1 100644 --- a/examples/fedopt_example/metrics.py +++ b/examples/fedopt_example/metrics.py @@ -1,12 +1,15 @@ import json -from typing import Dict, List +from typing import Dict, List, TypeVar import numpy as np from flwr.common.typing import Metrics from sklearn.metrics import confusion_matrix +from torch import Tensor from examples.fedopt_example.client_data import LabelEncoder +T = TypeVar("T", np.ndarray, Tensor) + class Outcome: def __init__(self, class_name: str) -> None: @@ -107,7 +110,7 @@ def summarize(self) -> str: log_string = f"{log_string}\naverage_f1:{str(sum_f1/n_topics)}" return log_string - def update_performance(self, predictions: np.ndarray, labels: np.ndarray) -> None: + def update_performance(self, predictions: T, labels: T) -> None: confusion = confusion_matrix(labels, predictions, labels=range(self.n_classes)) for i in range(self.n_classes): true_class = self.label_to_class[i] diff --git a/examples/fedprox_example/config.yaml b/examples/fedprox_example/config.yaml index bb61d3a1d..971025f81 100644 --- a/examples/fedprox_example/config.yaml +++ b/examples/fedprox_example/config.yaml @@ -16,7 +16,7 @@ local_epochs: 1 # The number of epochs to complete for client batch_size: 128 # The batch size for client training reporting_config: - enabled: False + enabled: True project_name: FL4Health # Name of the project under which everything should be logged run_name: "FedProx Server" # Name of the run on the server-side, each client will also have it's own run name group_name: "FedProx Experiment" # Group under which each of the FL run logging will be stored diff --git a/fl4health/clients/basic_client.py b/fl4health/clients/basic_client.py index 7e0f0dbff..5c681723a 100644 --- a/fl4health/clients/basic_client.py +++ b/fl4health/clients/basic_client.py @@ -1,6 +1,6 @@ from logging import INFO from pathlib import Path -from typing import Dict, Optional, Sequence, Tuple, Union +from typing import Any, Dict, Optional, Sequence, Tuple, Union import torch import torch.nn as nn @@ -33,12 +33,10 @@ def __init__( device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, metric_meter_type: MetricMeterType = MetricMeterType.AVERAGE, - use_wandb_reporter: bool = False, checkpointer: Optional[TorchCheckpointer] = None, ) -> None: super().__init__(data_path, device) self.metrics = metrics - self.use_wandb_reporter = use_wandb_reporter self.checkpointer = checkpointer self.train_loss_meter = LossMeter.get_meter_by_type(loss_meter_type) self.val_loss_meter = LossMeter.get_meter_by_type(loss_meter_type) @@ -54,6 +52,9 @@ def __init__( self.num_val_samples: int self.learning_rate: float + # Need to track total_steps across rounds for WANDB reporting + self.total_steps: int = 0 + def set_parameters(self, parameters: NDArrays, config: Config) -> None: # Set the model weights and initialize the correct weights with the parameter exchanger. super().set_parameters(parameters, config) @@ -140,6 +141,27 @@ def _handle_logging( f"Client {metric_prefix} Losses: {loss_string} \n" f"Client {metric_prefix} Metrics: {metric_string}", ) + def _handle_reporting( + self, + loss_dict: Dict[str, float], + metric_dict: Dict[str, Scalar], + current_round: Optional[int] = None, + ) -> None: + + # If reporter is None we do not report to wandb and return + if self.wandb_reporter is None: + return + + # If no current_round is passed or current_round is None, set current_round to 0 + # This situation only arises when we do local finetuning and call train_by_epochs or train_by_steps explicitly + current_round = current_round if current_round is not None else 0 + + reporting_dict: Dict[str, Any] = {"server_round": current_round} + reporting_dict.update({"step": self.total_steps}) + reporting_dict.update(loss_dict) + reporting_dict.update(metric_dict) + self.wandb_reporter.report_metrics(reporting_dict) + def train_step(self, input: torch.Tensor, target: torch.Tensor) -> Tuple[Losses, torch.Tensor]: """ Given input and target, generate predictions, compute loss, optionally update metrics if they exist. @@ -183,11 +205,14 @@ def train_by_epochs( losses, preds = self.train_step(input, target) self.train_loss_meter.update(losses) self.train_metric_meter.update(preds, target) + self.total_steps += 1 metrics = self.train_metric_meter.compute() losses = self.train_loss_meter.compute() loss_dict = losses.as_dict() - self._handle_logging(loss_dict, metrics, current_epoch=local_epoch, current_round=current_round) + # Log results and maybe report via WANDB + self._handle_logging(loss_dict, metrics, current_round=current_round, current_epoch=local_epoch) + self._handle_reporting(loss_dict, metrics, current_round=current_round) # Return final training metrics return loss_dict, metrics @@ -216,11 +241,15 @@ def train_by_steps( self.train_loss_meter.update(losses) self.train_metric_meter.update(preds, target) + self.total_steps += 1 + losses = self.train_loss_meter.compute() loss_dict = losses.as_dict() metrics = self.train_metric_meter.compute() + # Log results and maybe report via WANDB self._handle_logging(loss_dict, metrics, current_round=current_round) + self._handle_reporting(loss_dict, metrics, current_round=current_round) return loss_dict, metrics @@ -274,8 +303,7 @@ def setup_client(self, config: Config) -> None: self.criterion = self.get_criterion(config) self.parameter_exchanger = self.get_parameter_exchanger(config) - if self.use_wandb_reporter: - self.wandb_reporter = ClientWandBReporter.from_config(self.client_name, config) + self.wandb_reporter = ClientWandBReporter.from_config(self.client_name, config) super().setup_client(config) diff --git a/fl4health/clients/clipping_client.py b/fl4health/clients/clipping_client.py index 3e5371686..fd7d21d79 100644 --- a/fl4health/clients/clipping_client.py +++ b/fl4health/clients/clipping_client.py @@ -30,7 +30,6 @@ def __init__( device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, metric_meter_type: MetricMeterType = MetricMeterType.AVERAGE, - use_wandb_reporter: bool = False, checkpointer: Optional[TorchCheckpointer] = None, ) -> None: super().__init__( @@ -39,7 +38,6 @@ def __init__( device=device, loss_meter_type=loss_meter_type, metric_meter_type=metric_meter_type, - use_wandb_reporter=use_wandb_reporter, checkpointer=checkpointer, ) self.parameter_exchanger: ParameterExchangerWithPacking[float] diff --git a/fl4health/clients/fed_prox_client.py b/fl4health/clients/fed_prox_client.py index 0be0b2ebd..34c4c0020 100644 --- a/fl4health/clients/fed_prox_client.py +++ b/fl4health/clients/fed_prox_client.py @@ -27,7 +27,6 @@ def __init__( device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, metric_meter_type: MetricMeterType = MetricMeterType.AVERAGE, - use_wandb_reporter: bool = False, checkpointer: Optional[TorchCheckpointer] = None, ) -> None: super().__init__( @@ -36,7 +35,6 @@ def __init__( device=device, loss_meter_type=loss_meter_type, metric_meter_type=metric_meter_type, - use_wandb_reporter=use_wandb_reporter, checkpointer=checkpointer, ) self.initial_tensors: List[torch.Tensor] diff --git a/fl4health/clients/instance_level_privacy_client.py b/fl4health/clients/instance_level_privacy_client.py index 222a73548..f3c3508f9 100644 --- a/fl4health/clients/instance_level_privacy_client.py +++ b/fl4health/clients/instance_level_privacy_client.py @@ -26,7 +26,6 @@ def __init__( device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, metric_meter_type: MetricMeterType = MetricMeterType.AVERAGE, - use_wandb_reporter: bool = False, checkpointer: Optional[TorchCheckpointer] = None, ) -> None: super().__init__( @@ -35,7 +34,6 @@ def __init__( device=device, loss_meter_type=loss_meter_type, metric_meter_type=metric_meter_type, - use_wandb_reporter=use_wandb_reporter, checkpointer=checkpointer, ) self.clipping_bound: float diff --git a/fl4health/clients/numpy_fl_client.py b/fl4health/clients/numpy_fl_client.py index f5a559131..fb5893798 100644 --- a/fl4health/clients/numpy_fl_client.py +++ b/fl4health/clients/numpy_fl_client.py @@ -1,7 +1,7 @@ import random import string from pathlib import Path -from typing import Any, Dict, Optional, Type, TypeVar +from typing import Optional, Type, TypeVar import torch import torch.nn as nn @@ -32,10 +32,6 @@ def __init__(self, data_path: Path, device: torch.device) -> None: def generate_hash(self, length: int = 8) -> str: return "".join(random.choice(string.ascii_lowercase) for i in range(length)) - def _maybe_log_metrics(self, to_log: Dict[str, Any]) -> None: - if self.wandb_reporter: - self.wandb_reporter.report_metrics(to_log) - def _maybe_checkpoint(self, comparison_metric: float) -> None: if self.checkpointer: self.checkpointer.maybe_checkpoint(self.model, comparison_metric) diff --git a/fl4health/clients/scaffold_client.py b/fl4health/clients/scaffold_client.py index 8ecba0311..4208c4fef 100644 --- a/fl4health/clients/scaffold_client.py +++ b/fl4health/clients/scaffold_client.py @@ -31,7 +31,6 @@ def __init__( device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, metric_meter_type: MetricMeterType = MetricMeterType.AVERAGE, - use_wandb_reporter: bool = False, checkpointer: Optional[TorchCheckpointer] = None, ) -> None: super().__init__( @@ -40,7 +39,6 @@ def __init__( device=device, loss_meter_type=loss_meter_type, metric_meter_type=metric_meter_type, - use_wandb_reporter=use_wandb_reporter, checkpointer=checkpointer, ) self.learning_rate: float # eta_l in paper @@ -207,7 +205,6 @@ def __init__( device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, metric_meter_type: MetricMeterType = MetricMeterType.AVERAGE, - use_wandb_reporter: bool = False, checkpointer: Optional[TorchCheckpointer] = None, ) -> None: ScaffoldClient.__init__( @@ -217,7 +214,6 @@ def __init__( device=device, loss_meter_type=loss_meter_type, metric_meter_type=metric_meter_type, - use_wandb_reporter=use_wandb_reporter, checkpointer=checkpointer, ) @@ -228,6 +224,5 @@ def __init__( device=device, loss_meter_type=loss_meter_type, metric_meter_type=metric_meter_type, - use_wandb_reporter=use_wandb_reporter, checkpointer=checkpointer, ) diff --git a/fl4health/reporting/fl_wanb.py b/fl4health/reporting/fl_wanb.py index 18ed7be36..e3b2b0953 100644 --- a/fl4health/reporting/fl_wanb.py +++ b/fl4health/reporting/fl_wanb.py @@ -161,7 +161,7 @@ def add_client_model_type(self, client_name: str, model_type: str) -> None: @classmethod def from_config(cls, client_name: str, config: Dict[str, Any]) -> Optional["ClientWandBReporter"]: - if config["reporting_enabled"]: + if "reporting_enabled" in config and config["reporting_enabled"]: return ClientWandBReporter(client_name, config["project_name"], config["group_name"], config["entity"]) else: return None diff --git a/requirements.txt b/requirements.txt index 379c14220..b812ad815 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,6 +20,10 @@ torcheval torchinfo torchtext torchvision +types-protobuf +types-PyYAML types-requests types-setuptools +types-six +types-tabulate wandb diff --git a/research/flamby/fed_heart_disease/fedadam/client.py b/research/flamby/fed_heart_disease/fedadam/client.py index 688c95b2d..dc4a0f510 100644 --- a/research/flamby/fed_heart_disease/fedadam/client.py +++ b/research/flamby/fed_heart_disease/fedadam/client.py @@ -32,7 +32,6 @@ def __init__( loss_meter_type: LossMeterType = LossMeterType.AVERAGE, metric_meter_type: MetricMeterType = MetricMeterType.ACCUMULATION, checkpointer: Optional[TorchCheckpointer] = None, - use_wandb_reporter: bool = False, ) -> None: super().__init__( data_path=data_path, @@ -40,7 +39,6 @@ def __init__( device=device, loss_meter_type=loss_meter_type, metric_meter_type=metric_meter_type, - use_wandb_reporter=use_wandb_reporter, checkpointer=checkpointer, ) self.client_number = client_number diff --git a/research/flamby/fed_heart_disease/fedavg/client.py b/research/flamby/fed_heart_disease/fedavg/client.py index 3a498f2e9..160b359a9 100644 --- a/research/flamby/fed_heart_disease/fedavg/client.py +++ b/research/flamby/fed_heart_disease/fedavg/client.py @@ -32,7 +32,6 @@ def __init__( loss_meter_type: LossMeterType = LossMeterType.AVERAGE, metric_meter_type: MetricMeterType = MetricMeterType.ACCUMULATION, checkpointer: Optional[TorchCheckpointer] = None, - use_wandb_reporter: bool = False, ) -> None: super().__init__( data_path=data_path, @@ -40,7 +39,6 @@ def __init__( device=device, loss_meter_type=loss_meter_type, metric_meter_type=metric_meter_type, - use_wandb_reporter=use_wandb_reporter, checkpointer=checkpointer, ) self.client_number = client_number diff --git a/research/flamby/fed_heart_disease/fedprox/client.py b/research/flamby/fed_heart_disease/fedprox/client.py index e4b019317..431d8b0f5 100644 --- a/research/flamby/fed_heart_disease/fedprox/client.py +++ b/research/flamby/fed_heart_disease/fedprox/client.py @@ -32,7 +32,6 @@ def __init__( loss_meter_type: LossMeterType = LossMeterType.AVERAGE, metric_meter_type: MetricMeterType = MetricMeterType.ACCUMULATION, checkpointer: Optional[TorchCheckpointer] = None, - use_wandb_reporter: bool = False, ) -> None: super().__init__( data_path=data_path, @@ -40,7 +39,6 @@ def __init__( device=device, loss_meter_type=loss_meter_type, metric_meter_type=metric_meter_type, - use_wandb_reporter=use_wandb_reporter, checkpointer=checkpointer, ) self.client_number = client_number diff --git a/research/flamby/fed_heart_disease/fenda/client.py b/research/flamby/fed_heart_disease/fenda/client.py index 6eb6aa4c5..12cd57da5 100644 --- a/research/flamby/fed_heart_disease/fenda/client.py +++ b/research/flamby/fed_heart_disease/fenda/client.py @@ -33,7 +33,6 @@ def __init__( loss_meter_type: LossMeterType = LossMeterType.AVERAGE, metric_meter_type: MetricMeterType = MetricMeterType.ACCUMULATION, checkpointer: Optional[TorchCheckpointer] = None, - use_wandb_reporter: bool = False, ) -> None: super().__init__( data_path=data_path, @@ -41,7 +40,6 @@ def __init__( device=device, loss_meter_type=loss_meter_type, metric_meter_type=metric_meter_type, - use_wandb_reporter=use_wandb_reporter, checkpointer=checkpointer, ) self.client_number = client_number diff --git a/research/flamby/fed_heart_disease/scaffold/client.py b/research/flamby/fed_heart_disease/scaffold/client.py index 5a686e533..dba721134 100644 --- a/research/flamby/fed_heart_disease/scaffold/client.py +++ b/research/flamby/fed_heart_disease/scaffold/client.py @@ -32,7 +32,6 @@ def __init__( loss_meter_type: LossMeterType = LossMeterType.AVERAGE, metric_meter_type: MetricMeterType = MetricMeterType.ACCUMULATION, checkpointer: Optional[TorchCheckpointer] = None, - use_wandb_reporter: bool = False, ) -> None: super().__init__( data_path=data_path, @@ -40,7 +39,6 @@ def __init__( device=device, loss_meter_type=loss_meter_type, metric_meter_type=metric_meter_type, - use_wandb_reporter=use_wandb_reporter, checkpointer=checkpointer, ) self.client_number = client_number diff --git a/research/flamby/fed_isic2019/fedadam/client.py b/research/flamby/fed_isic2019/fedadam/client.py index 402de4709..e1b3c902e 100644 --- a/research/flamby/fed_isic2019/fedadam/client.py +++ b/research/flamby/fed_isic2019/fedadam/client.py @@ -33,7 +33,6 @@ def __init__( loss_meter_type: LossMeterType = LossMeterType.AVERAGE, metric_meter_type: MetricMeterType = MetricMeterType.ACCUMULATION, checkpointer: Optional[TorchCheckpointer] = None, - use_wandb_reporter: bool = False, ) -> None: super().__init__( data_path=data_path, @@ -41,7 +40,6 @@ def __init__( device=device, loss_meter_type=loss_meter_type, metric_meter_type=metric_meter_type, - use_wandb_reporter=use_wandb_reporter, checkpointer=checkpointer, ) self.client_number = client_number diff --git a/research/flamby/fed_isic2019/fedavg/client.py b/research/flamby/fed_isic2019/fedavg/client.py index e1eccb0cd..fd66d9f47 100644 --- a/research/flamby/fed_isic2019/fedavg/client.py +++ b/research/flamby/fed_isic2019/fedavg/client.py @@ -32,7 +32,6 @@ def __init__( loss_meter_type: LossMeterType = LossMeterType.AVERAGE, metric_meter_type: MetricMeterType = MetricMeterType.ACCUMULATION, checkpointer: Optional[TorchCheckpointer] = None, - use_wandb_reporter: bool = False, ) -> None: super().__init__( data_path=data_path, @@ -40,7 +39,6 @@ def __init__( device=device, loss_meter_type=loss_meter_type, metric_meter_type=metric_meter_type, - use_wandb_reporter=use_wandb_reporter, checkpointer=checkpointer, ) self.client_number = client_number diff --git a/research/flamby/fed_isic2019/fedprox/client.py b/research/flamby/fed_isic2019/fedprox/client.py index 34d493986..856479793 100644 --- a/research/flamby/fed_isic2019/fedprox/client.py +++ b/research/flamby/fed_isic2019/fedprox/client.py @@ -32,7 +32,6 @@ def __init__( loss_meter_type: LossMeterType = LossMeterType.AVERAGE, metric_meter_type: MetricMeterType = MetricMeterType.ACCUMULATION, checkpointer: Optional[TorchCheckpointer] = None, - use_wandb_reporter: bool = False, ) -> None: super().__init__( data_path=data_path, @@ -40,7 +39,6 @@ def __init__( device=device, loss_meter_type=loss_meter_type, metric_meter_type=metric_meter_type, - use_wandb_reporter=use_wandb_reporter, checkpointer=checkpointer, ) self.client_number = client_number diff --git a/research/flamby/fed_isic2019/fenda/client.py b/research/flamby/fed_isic2019/fenda/client.py index 88907a0b7..449d6fdc0 100644 --- a/research/flamby/fed_isic2019/fenda/client.py +++ b/research/flamby/fed_isic2019/fenda/client.py @@ -33,7 +33,6 @@ def __init__( loss_meter_type: LossMeterType = LossMeterType.AVERAGE, metric_meter_type: MetricMeterType = MetricMeterType.ACCUMULATION, checkpointer: Optional[TorchCheckpointer] = None, - use_wandb_reporter: bool = False, ) -> None: super().__init__( data_path=data_path, @@ -41,7 +40,6 @@ def __init__( device=device, loss_meter_type=loss_meter_type, metric_meter_type=metric_meter_type, - use_wandb_reporter=use_wandb_reporter, checkpointer=checkpointer, ) self.client_number = client_number diff --git a/research/flamby/fed_isic2019/scaffold/client.py b/research/flamby/fed_isic2019/scaffold/client.py index 89d237e1c..18ab9939a 100644 --- a/research/flamby/fed_isic2019/scaffold/client.py +++ b/research/flamby/fed_isic2019/scaffold/client.py @@ -32,7 +32,6 @@ def __init__( loss_meter_type: LossMeterType = LossMeterType.AVERAGE, metric_meter_type: MetricMeterType = MetricMeterType.ACCUMULATION, checkpointer: Optional[TorchCheckpointer] = None, - use_wandb_reporter: bool = False, ) -> None: super().__init__( data_path=data_path, @@ -40,7 +39,6 @@ def __init__( device=device, loss_meter_type=loss_meter_type, metric_meter_type=metric_meter_type, - use_wandb_reporter=use_wandb_reporter, checkpointer=checkpointer, ) self.client_number = client_number diff --git a/research/flamby/fed_ixi/fedadam/client.py b/research/flamby/fed_ixi/fedadam/client.py index 6549d0055..d4ef9a900 100644 --- a/research/flamby/fed_ixi/fedadam/client.py +++ b/research/flamby/fed_ixi/fedadam/client.py @@ -33,7 +33,6 @@ def __init__( loss_meter_type: LossMeterType = LossMeterType.AVERAGE, metric_meter_type: MetricMeterType = MetricMeterType.ACCUMULATION, checkpointer: Optional[TorchCheckpointer] = None, - use_wandb_reporter: bool = False, ) -> None: super().__init__( data_path=data_path, @@ -41,7 +40,6 @@ def __init__( device=device, loss_meter_type=loss_meter_type, metric_meter_type=metric_meter_type, - use_wandb_reporter=use_wandb_reporter, checkpointer=checkpointer, ) self.client_number = client_number diff --git a/research/flamby/fed_ixi/fedavg/client.py b/research/flamby/fed_ixi/fedavg/client.py index 3e36ba707..e9e77d762 100644 --- a/research/flamby/fed_ixi/fedavg/client.py +++ b/research/flamby/fed_ixi/fedavg/client.py @@ -32,7 +32,6 @@ def __init__( loss_meter_type: LossMeterType = LossMeterType.AVERAGE, metric_meter_type: MetricMeterType = MetricMeterType.ACCUMULATION, checkpointer: Optional[TorchCheckpointer] = None, - use_wandb_reporter: bool = False, ) -> None: super().__init__( data_path=data_path, @@ -40,7 +39,6 @@ def __init__( device=device, loss_meter_type=loss_meter_type, metric_meter_type=metric_meter_type, - use_wandb_reporter=use_wandb_reporter, checkpointer=checkpointer, ) self.client_number = client_number diff --git a/research/flamby/fed_ixi/fedprox/client.py b/research/flamby/fed_ixi/fedprox/client.py index 089cbd324..4c71c45f9 100644 --- a/research/flamby/fed_ixi/fedprox/client.py +++ b/research/flamby/fed_ixi/fedprox/client.py @@ -32,7 +32,6 @@ def __init__( loss_meter_type: LossMeterType = LossMeterType.AVERAGE, metric_meter_type: MetricMeterType = MetricMeterType.ACCUMULATION, checkpointer: Optional[TorchCheckpointer] = None, - use_wandb_reporter: bool = False, ) -> None: super().__init__( data_path=data_path, @@ -40,7 +39,6 @@ def __init__( device=device, loss_meter_type=loss_meter_type, metric_meter_type=metric_meter_type, - use_wandb_reporter=use_wandb_reporter, checkpointer=checkpointer, ) self.client_number = client_number diff --git a/research/flamby/fed_ixi/fenda/client.py b/research/flamby/fed_ixi/fenda/client.py index 0658e8c8e..373e22505 100644 --- a/research/flamby/fed_ixi/fenda/client.py +++ b/research/flamby/fed_ixi/fenda/client.py @@ -33,7 +33,6 @@ def __init__( loss_meter_type: LossMeterType = LossMeterType.AVERAGE, metric_meter_type: MetricMeterType = MetricMeterType.ACCUMULATION, checkpointer: Optional[TorchCheckpointer] = None, - use_wandb_reporter: bool = False, ) -> None: super().__init__( data_path=data_path, @@ -41,7 +40,6 @@ def __init__( device=device, loss_meter_type=loss_meter_type, metric_meter_type=metric_meter_type, - use_wandb_reporter=use_wandb_reporter, checkpointer=checkpointer, ) self.client_number = client_number diff --git a/research/flamby/fed_ixi/scaffold/client.py b/research/flamby/fed_ixi/scaffold/client.py index da8005f74..5f874f767 100644 --- a/research/flamby/fed_ixi/scaffold/client.py +++ b/research/flamby/fed_ixi/scaffold/client.py @@ -32,7 +32,6 @@ def __init__( loss_meter_type: LossMeterType = LossMeterType.AVERAGE, metric_meter_type: MetricMeterType = MetricMeterType.ACCUMULATION, checkpointer: Optional[TorchCheckpointer] = None, - use_wandb_reporter: bool = False, ) -> None: super().__init__( data_path=data_path, @@ -40,7 +39,6 @@ def __init__( device=device, loss_meter_type=loss_meter_type, metric_meter_type=metric_meter_type, - use_wandb_reporter=use_wandb_reporter, checkpointer=checkpointer, ) self.client_number = client_number diff --git a/tests/reporting/__init__.py b/tests/reporting/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/reporting/test_wandb_reporter.py b/tests/reporting/test_wandb_reporter.py new file mode 100644 index 000000000..bda170621 --- /dev/null +++ b/tests/reporting/test_wandb_reporter.py @@ -0,0 +1,20 @@ +from pathlib import Path +from unittest import mock + +from fl4health.reporting.fl_wanb import ClientWandBReporter, ServerWandBReporter + + +def test_server_wandb_reporter(tmp_path: Path) -> None: + with mock.patch.object(ServerWandBReporter, "__init__", lambda a, b, c, d, e, f, g, h: None): + reporter = ServerWandBReporter("", "", "", "", None, None, {}) + log_dir = str(tmp_path.joinpath("fl_wandb_logs")) + reporter._maybe_create_local_log_directory(log_dir) + assert log_dir in list(map(lambda x: str(x), tmp_path.iterdir())) + + +def test_client_wandb_reporter(tmp_path: Path) -> None: + with mock.patch.object(ClientWandBReporter, "__init__", lambda a, b, c, d, e: None): + reporter = ClientWandBReporter("", "", "", "") + log_dir = str(tmp_path.joinpath("fl_wandb_logs")) + reporter._maybe_create_local_log_directory(log_dir) + assert log_dir in list(map(lambda x: str(x), tmp_path.iterdir()))