Skip to content

Commit

Permalink
Merge pull request #61 from VectorInstitute/fix-wandb
Browse files Browse the repository at this point in the history
Fix wandb
  • Loading branch information
jewelltaylor authored Oct 13, 2023
2 parents a3a851e + 26edf8d commit 03849f7
Show file tree
Hide file tree
Showing 27 changed files with 66 additions and 56 deletions.
7 changes: 5 additions & 2 deletions examples/fedopt_example/metrics.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import json
from typing import Dict, List
from typing import Dict, List, TypeVar

import numpy as np
from flwr.common.typing import Metrics
from sklearn.metrics import confusion_matrix
from torch import Tensor

from examples.fedopt_example.client_data import LabelEncoder

T = TypeVar("T", np.ndarray, Tensor)


class Outcome:
def __init__(self, class_name: str) -> None:
Expand Down Expand Up @@ -107,7 +110,7 @@ def summarize(self) -> str:
log_string = f"{log_string}\naverage_f1:{str(sum_f1/n_topics)}"
return log_string

def update_performance(self, predictions: np.ndarray, labels: np.ndarray) -> None:
def update_performance(self, predictions: T, labels: T) -> None:
confusion = confusion_matrix(labels, predictions, labels=range(self.n_classes))
for i in range(self.n_classes):
true_class = self.label_to_class[i]
Expand Down
2 changes: 1 addition & 1 deletion examples/fedprox_example/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ local_epochs: 1 # The number of epochs to complete for client
batch_size: 128 # The batch size for client training

reporting_config:
enabled: False
enabled: True
project_name: FL4Health # Name of the project under which everything should be logged
run_name: "FedProx Server" # Name of the run on the server-side, each client will also have its own run name
group_name: "FedProx Experiment" # Group under which each of the FL run logging will be stored
Expand Down
40 changes: 34 additions & 6 deletions fl4health/clients/basic_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from logging import INFO
from pathlib import Path
from typing import Dict, Optional, Sequence, Tuple, Union
from typing import Any, Dict, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
Expand Down Expand Up @@ -33,12 +33,10 @@ def __init__(
device: torch.device,
loss_meter_type: LossMeterType = LossMeterType.AVERAGE,
metric_meter_type: MetricMeterType = MetricMeterType.AVERAGE,
use_wandb_reporter: bool = False,
checkpointer: Optional[TorchCheckpointer] = None,
) -> None:
super().__init__(data_path, device)
self.metrics = metrics
self.use_wandb_reporter = use_wandb_reporter
self.checkpointer = checkpointer
self.train_loss_meter = LossMeter.get_meter_by_type(loss_meter_type)
self.val_loss_meter = LossMeter.get_meter_by_type(loss_meter_type)
Expand All @@ -54,6 +52,9 @@ def __init__(
self.num_val_samples: int
self.learning_rate: float

# Need to track total_steps across rounds for WANDB reporting
self.total_steps: int = 0

def set_parameters(self, parameters: NDArrays, config: Config) -> None:
# Set the model weights and initialize the correct weights with the parameter exchanger.
super().set_parameters(parameters, config)
Expand Down Expand Up @@ -140,6 +141,27 @@ def _handle_logging(
f"Client {metric_prefix} Losses: {loss_string} \n" f"Client {metric_prefix} Metrics: {metric_string}",
)

def _handle_reporting(
    self,
    loss_dict: Dict[str, float],
    metric_dict: Dict[str, Scalar],
    current_round: Optional[int] = None,
) -> None:
    """Send the round's losses and metrics to W&B, if a reporter is configured."""
    if self.wandb_reporter is None:
        # Reporting is disabled for this client; nothing to do.
        return

    # current_round is only None during local fine-tuning, when train_by_epochs
    # or train_by_steps is called directly; report those under round 0.
    round_to_report = 0 if current_round is None else current_round

    reporting_dict: Dict[str, Any] = {
        "server_round": round_to_report,
        "step": self.total_steps,
        **loss_dict,
        **metric_dict,
    }
    self.wandb_reporter.report_metrics(reporting_dict)

def train_step(self, input: torch.Tensor, target: torch.Tensor) -> Tuple[Losses, torch.Tensor]:
"""
Given input and target, generate predictions, compute loss, optionally update metrics if they exist.
Expand Down Expand Up @@ -183,11 +205,14 @@ def train_by_epochs(
losses, preds = self.train_step(input, target)
self.train_loss_meter.update(losses)
self.train_metric_meter.update(preds, target)
self.total_steps += 1
metrics = self.train_metric_meter.compute()
losses = self.train_loss_meter.compute()
loss_dict = losses.as_dict()

self._handle_logging(loss_dict, metrics, current_epoch=local_epoch, current_round=current_round)
# Log results and maybe report via WANDB
self._handle_logging(loss_dict, metrics, current_round=current_round, current_epoch=local_epoch)
self._handle_reporting(loss_dict, metrics, current_round=current_round)

# Return final training metrics
return loss_dict, metrics
Expand Down Expand Up @@ -216,11 +241,15 @@ def train_by_steps(
self.train_loss_meter.update(losses)
self.train_metric_meter.update(preds, target)

self.total_steps += 1

losses = self.train_loss_meter.compute()
loss_dict = losses.as_dict()
metrics = self.train_metric_meter.compute()

# Log results and maybe report via WANDB
self._handle_logging(loss_dict, metrics, current_round=current_round)
self._handle_reporting(loss_dict, metrics, current_round=current_round)

return loss_dict, metrics

Expand Down Expand Up @@ -274,8 +303,7 @@ def setup_client(self, config: Config) -> None:
self.criterion = self.get_criterion(config)
self.parameter_exchanger = self.get_parameter_exchanger(config)

if self.use_wandb_reporter:
self.wandb_reporter = ClientWandBReporter.from_config(self.client_name, config)
self.wandb_reporter = ClientWandBReporter.from_config(self.client_name, config)

super().setup_client(config)

Expand Down
2 changes: 0 additions & 2 deletions fl4health/clients/clipping_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ def __init__(
device: torch.device,
loss_meter_type: LossMeterType = LossMeterType.AVERAGE,
metric_meter_type: MetricMeterType = MetricMeterType.AVERAGE,
use_wandb_reporter: bool = False,
checkpointer: Optional[TorchCheckpointer] = None,
) -> None:
super().__init__(
Expand All @@ -39,7 +38,6 @@ def __init__(
device=device,
loss_meter_type=loss_meter_type,
metric_meter_type=metric_meter_type,
use_wandb_reporter=use_wandb_reporter,
checkpointer=checkpointer,
)
self.parameter_exchanger: ParameterExchangerWithPacking[float]
Expand Down
2 changes: 0 additions & 2 deletions fl4health/clients/fed_prox_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ def __init__(
device: torch.device,
loss_meter_type: LossMeterType = LossMeterType.AVERAGE,
metric_meter_type: MetricMeterType = MetricMeterType.AVERAGE,
use_wandb_reporter: bool = False,
checkpointer: Optional[TorchCheckpointer] = None,
) -> None:
super().__init__(
Expand All @@ -36,7 +35,6 @@ def __init__(
device=device,
loss_meter_type=loss_meter_type,
metric_meter_type=metric_meter_type,
use_wandb_reporter=use_wandb_reporter,
checkpointer=checkpointer,
)
self.initial_tensors: List[torch.Tensor]
Expand Down
2 changes: 0 additions & 2 deletions fl4health/clients/instance_level_privacy_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ def __init__(
device: torch.device,
loss_meter_type: LossMeterType = LossMeterType.AVERAGE,
metric_meter_type: MetricMeterType = MetricMeterType.AVERAGE,
use_wandb_reporter: bool = False,
checkpointer: Optional[TorchCheckpointer] = None,
) -> None:
super().__init__(
Expand All @@ -35,7 +34,6 @@ def __init__(
device=device,
loss_meter_type=loss_meter_type,
metric_meter_type=metric_meter_type,
use_wandb_reporter=use_wandb_reporter,
checkpointer=checkpointer,
)
self.clipping_bound: float
Expand Down
6 changes: 1 addition & 5 deletions fl4health/clients/numpy_fl_client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import random
import string
from pathlib import Path
from typing import Any, Dict, Optional, Type, TypeVar
from typing import Optional, Type, TypeVar

import torch
import torch.nn as nn
Expand Down Expand Up @@ -32,10 +32,6 @@ def __init__(self, data_path: Path, device: torch.device) -> None:
def generate_hash(self, length: int = 8) -> str:
    """Return a random lowercase ASCII string of the given length.

    Used to produce a (probabilistically) unique client identifier, e.g. for
    run naming. NOTE: uses the `random` module, so it is not suitable for
    security-sensitive purposes.
    """
    # random.choices draws all characters in a single call instead of a
    # Python-level loop with an unused index variable.
    return "".join(random.choices(string.ascii_lowercase, k=length))

def _maybe_log_metrics(self, to_log: Dict[str, Any]) -> None:
if self.wandb_reporter:
self.wandb_reporter.report_metrics(to_log)

def _maybe_checkpoint(self, comparison_metric: float) -> None:
if self.checkpointer:
self.checkpointer.maybe_checkpoint(self.model, comparison_metric)
Expand Down
5 changes: 0 additions & 5 deletions fl4health/clients/scaffold_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ def __init__(
device: torch.device,
loss_meter_type: LossMeterType = LossMeterType.AVERAGE,
metric_meter_type: MetricMeterType = MetricMeterType.AVERAGE,
use_wandb_reporter: bool = False,
checkpointer: Optional[TorchCheckpointer] = None,
) -> None:
super().__init__(
Expand All @@ -40,7 +39,6 @@ def __init__(
device=device,
loss_meter_type=loss_meter_type,
metric_meter_type=metric_meter_type,
use_wandb_reporter=use_wandb_reporter,
checkpointer=checkpointer,
)
self.learning_rate: float # eta_l in paper
Expand Down Expand Up @@ -207,7 +205,6 @@ def __init__(
device: torch.device,
loss_meter_type: LossMeterType = LossMeterType.AVERAGE,
metric_meter_type: MetricMeterType = MetricMeterType.AVERAGE,
use_wandb_reporter: bool = False,
checkpointer: Optional[TorchCheckpointer] = None,
) -> None:
ScaffoldClient.__init__(
Expand All @@ -217,7 +214,6 @@ def __init__(
device=device,
loss_meter_type=loss_meter_type,
metric_meter_type=metric_meter_type,
use_wandb_reporter=use_wandb_reporter,
checkpointer=checkpointer,
)

Expand All @@ -228,6 +224,5 @@ def __init__(
device=device,
loss_meter_type=loss_meter_type,
metric_meter_type=metric_meter_type,
use_wandb_reporter=use_wandb_reporter,
checkpointer=checkpointer,
)
2 changes: 1 addition & 1 deletion fl4health/reporting/fl_wanb.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def add_client_model_type(self, client_name: str, model_type: str) -> None:

@classmethod
def from_config(cls, client_name: str, config: Dict[str, Any]) -> Optional["ClientWandBReporter"]:
    """Build a ClientWandBReporter from a config dict, or None if reporting is off.

    Returns None when "reporting_enabled" is absent or falsy, so callers can
    safely pass configs that never mention reporting.
    """
    # dict.get collapses the "key present AND truthy" check into one lookup.
    if config.get("reporting_enabled", False):
        return ClientWandBReporter(client_name, config["project_name"], config["group_name"], config["entity"])
    return None
4 changes: 4 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ torcheval
torchinfo
torchtext
torchvision
types-protobuf
types-PyYAML
types-requests
types-setuptools
types-six
types-tabulate
wandb
2 changes: 0 additions & 2 deletions research/flamby/fed_heart_disease/fedadam/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,13 @@ def __init__(
loss_meter_type: LossMeterType = LossMeterType.AVERAGE,
metric_meter_type: MetricMeterType = MetricMeterType.ACCUMULATION,
checkpointer: Optional[TorchCheckpointer] = None,
use_wandb_reporter: bool = False,
) -> None:
super().__init__(
data_path=data_path,
metrics=metrics,
device=device,
loss_meter_type=loss_meter_type,
metric_meter_type=metric_meter_type,
use_wandb_reporter=use_wandb_reporter,
checkpointer=checkpointer,
)
self.client_number = client_number
Expand Down
2 changes: 0 additions & 2 deletions research/flamby/fed_heart_disease/fedavg/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,13 @@ def __init__(
loss_meter_type: LossMeterType = LossMeterType.AVERAGE,
metric_meter_type: MetricMeterType = MetricMeterType.ACCUMULATION,
checkpointer: Optional[TorchCheckpointer] = None,
use_wandb_reporter: bool = False,
) -> None:
super().__init__(
data_path=data_path,
metrics=metrics,
device=device,
loss_meter_type=loss_meter_type,
metric_meter_type=metric_meter_type,
use_wandb_reporter=use_wandb_reporter,
checkpointer=checkpointer,
)
self.client_number = client_number
Expand Down
2 changes: 0 additions & 2 deletions research/flamby/fed_heart_disease/fedprox/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,13 @@ def __init__(
loss_meter_type: LossMeterType = LossMeterType.AVERAGE,
metric_meter_type: MetricMeterType = MetricMeterType.ACCUMULATION,
checkpointer: Optional[TorchCheckpointer] = None,
use_wandb_reporter: bool = False,
) -> None:
super().__init__(
data_path=data_path,
metrics=metrics,
device=device,
loss_meter_type=loss_meter_type,
metric_meter_type=metric_meter_type,
use_wandb_reporter=use_wandb_reporter,
checkpointer=checkpointer,
)
self.client_number = client_number
Expand Down
2 changes: 0 additions & 2 deletions research/flamby/fed_heart_disease/fenda/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,13 @@ def __init__(
loss_meter_type: LossMeterType = LossMeterType.AVERAGE,
metric_meter_type: MetricMeterType = MetricMeterType.ACCUMULATION,
checkpointer: Optional[TorchCheckpointer] = None,
use_wandb_reporter: bool = False,
) -> None:
super().__init__(
data_path=data_path,
metrics=metrics,
device=device,
loss_meter_type=loss_meter_type,
metric_meter_type=metric_meter_type,
use_wandb_reporter=use_wandb_reporter,
checkpointer=checkpointer,
)
self.client_number = client_number
Expand Down
2 changes: 0 additions & 2 deletions research/flamby/fed_heart_disease/scaffold/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,13 @@ def __init__(
loss_meter_type: LossMeterType = LossMeterType.AVERAGE,
metric_meter_type: MetricMeterType = MetricMeterType.ACCUMULATION,
checkpointer: Optional[TorchCheckpointer] = None,
use_wandb_reporter: bool = False,
) -> None:
super().__init__(
data_path=data_path,
metrics=metrics,
device=device,
loss_meter_type=loss_meter_type,
metric_meter_type=metric_meter_type,
use_wandb_reporter=use_wandb_reporter,
checkpointer=checkpointer,
)
self.client_number = client_number
Expand Down
2 changes: 0 additions & 2 deletions research/flamby/fed_isic2019/fedadam/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,13 @@ def __init__(
loss_meter_type: LossMeterType = LossMeterType.AVERAGE,
metric_meter_type: MetricMeterType = MetricMeterType.ACCUMULATION,
checkpointer: Optional[TorchCheckpointer] = None,
use_wandb_reporter: bool = False,
) -> None:
super().__init__(
data_path=data_path,
metrics=metrics,
device=device,
loss_meter_type=loss_meter_type,
metric_meter_type=metric_meter_type,
use_wandb_reporter=use_wandb_reporter,
checkpointer=checkpointer,
)
self.client_number = client_number
Expand Down
2 changes: 0 additions & 2 deletions research/flamby/fed_isic2019/fedavg/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,13 @@ def __init__(
loss_meter_type: LossMeterType = LossMeterType.AVERAGE,
metric_meter_type: MetricMeterType = MetricMeterType.ACCUMULATION,
checkpointer: Optional[TorchCheckpointer] = None,
use_wandb_reporter: bool = False,
) -> None:
super().__init__(
data_path=data_path,
metrics=metrics,
device=device,
loss_meter_type=loss_meter_type,
metric_meter_type=metric_meter_type,
use_wandb_reporter=use_wandb_reporter,
checkpointer=checkpointer,
)
self.client_number = client_number
Expand Down
2 changes: 0 additions & 2 deletions research/flamby/fed_isic2019/fedprox/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,13 @@ def __init__(
loss_meter_type: LossMeterType = LossMeterType.AVERAGE,
metric_meter_type: MetricMeterType = MetricMeterType.ACCUMULATION,
checkpointer: Optional[TorchCheckpointer] = None,
use_wandb_reporter: bool = False,
) -> None:
super().__init__(
data_path=data_path,
metrics=metrics,
device=device,
loss_meter_type=loss_meter_type,
metric_meter_type=metric_meter_type,
use_wandb_reporter=use_wandb_reporter,
checkpointer=checkpointer,
)
self.client_number = client_number
Expand Down
Loading

0 comments on commit 03849f7

Please sign in to comment.