Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modified some of the reporting keys and overhauled the wandb reporter #288

Merged
merged 9 commits into from
Nov 18, 2024
61 changes: 39 additions & 22 deletions fl4health/clients/basic_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def __init__(
self.initial_weights: Optional[NDArrays] = None

self.total_steps: int = 0 # Need to track total_steps across rounds for WANDB reporting
self.total_epochs: int = 0
scarere marked this conversation as resolved.
Show resolved Hide resolved

# Attributes to be initialized in setup_client
self.parameter_exchanger: ParameterExchanger
Expand Down Expand Up @@ -221,9 +222,9 @@ def process_config(self, config: Config) -> Tuple[Union[int, None], Union[int, N
config (Config): The config from the server.

Returns:
Tuple[Union[int, None], Union[int, None], int, bool]: Returns the local_epochs, local_steps,
current_server_round and evaluate_after_fit. Ensures only one of local_epochs and local_steps
is defined in the config and sets the one that is not to None.
Tuple[Union[int, None], Union[int, None], int, bool, bool]: Returns the local_epochs, local_steps,
current_server_round, evaluate_after_fit and pack_losses_with_val_metrics. Ensures only one of
local_epochs and local_steps is defined in the config and sets the one that is not to None.

Raises:
ValueError: If the config contains both local_steps and local epochs or if local_steps, local_epochs or
Expand Down Expand Up @@ -307,15 +308,24 @@ def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, dict
# We perform a pre-aggregation checkpoint if applicable
self._maybe_checkpoint(validation_loss, validation_metrics, CheckpointMode.PRE_AGGREGATION)

# Notes on report values:
# - Train by steps: round metrics/losses are computed using all samples from the round
# - Train by epochs: round metrics/losses computed using only the samples from the final epoch of the round
# - fit_round_metrics: Computed at the end of the round on the samples directly
# - fit_round_losses: The average of the losses computed for each step.
# * (Hence likely higher than the final loss of the round.)
self.reports_manager.report(
{
"fit_metrics": metrics,
"fit_losses": loss_dict,
"fit_round_metrics": metrics, # Metrics computed
scarere marked this conversation as resolved.
Show resolved Hide resolved
"fit_round_losses": loss_dict,
"round": current_server_round,
"round_start": str(round_start_time),
"round_end": str(datetime.datetime.now()),
"fit_start": str(fit_start_time),
"fit_end": str(fit_end_time),
"fit_round_start": str(fit_start_time),
"fit_round_time_elapsed": str(fit_end_time - fit_start_time),
"fit_round_end": str(fit_end_time),
"fit_step": self.total_steps,
"fit_epoch": self.total_epochs,
},
current_server_round,
)
Expand Down Expand Up @@ -349,7 +359,7 @@ def evaluate(self, parameters: NDArrays, config: Config) -> Tuple[float, int, Di
self.setup_client(config)

start_time = datetime.datetime.now()
current_server_round = narrow_dict_type(config, "current_server_round", int)
local_epochs, local_steps, current_server_round, _, _ = self.process_config(config)
scarere marked this conversation as resolved.
Show resolved Hide resolved

pack_losses_with_val_metrics = set_pack_losses_with_val_metrics(config)

Expand All @@ -364,14 +374,19 @@ def evaluate(self, parameters: NDArrays, config: Config) -> Tuple[float, int, Di

self.reports_manager.report(
{
"eval_metrics": metrics,
"eval_loss": loss,
"eval_start": str(start_time),
"eval_time_elapsed": str(elapsed),
"eval_end": str(end_time),
"eval_round_metrics": metrics,
"eval_round_loss": loss,
"eval_round_start": str(start_time),
"eval_round_time_elapsed": str(elapsed),
"eval_round_end": str(end_time),
"fit_step": self.total_steps,
"fit_epoch": self.total_epochs,
"round": current_server_round,
},
current_server_round,
)
# Have to report 3 times to make sure they get reported regardless of the reporting level
scarere marked this conversation as resolved.
Show resolved Hide resolved
# This is admittidly cleugy and we should think of a better way to handle reporting levels
scarere marked this conversation as resolved.
Show resolved Hide resolved

# EvaluateRes should return the loss, number of examples on client, and a dictionary holding metrics
# calculation results.
Expand Down Expand Up @@ -607,7 +622,7 @@ def train_by_epochs(
# update before epoch hook
self.update_before_epoch(epoch=local_epoch)
# Update report data dict
report_data.update({"fit_epoch": local_epoch})
report_data.update({"fit_epoch": self.total_epochs})
for input, target in maybe_progress_bar(self.train_loader, self.progress_bar):
self.update_before_step(steps_this_round, current_round)
# Assume first dimension is batch size. Sampling iterators (such as Poisson batch sampling), can
Expand All @@ -623,20 +638,22 @@ def train_by_epochs(
self.update_metric_manager(preds, target, self.train_metric_manager)
self.update_after_step(steps_this_round, current_round)
self.update_lr_schedulers(epoch=local_epoch)
report_data.update({"fit_losses": losses.as_dict(), "fit_step": self.total_steps})
report_data.update({"fit_step_losses": losses.as_dict(), "fit_step": self.total_steps})
report_data.update(self.get_client_specific_reports())
self.reports_manager.report(report_data, current_round, local_epoch, self.total_steps)
self.reports_manager.report(report_data, current_round, self.total_epochs, self.total_steps)
self.total_steps += 1
steps_this_round += 1

# Log and report results
metrics = self.train_metric_manager.compute()
loss_dict = self.train_loss_meter.compute().as_dict()

# Log and report results
self._log_results(loss_dict, metrics, current_round, local_epoch)
report_data.update({"fit_metrics": metrics})
report_data.update({"fit_epoch_metrics": metrics, "fit_epoch_losses": loss_dict})
report_data.update(self.get_client_specific_reports())
self.reports_manager.report(report_data, current_round, local_epoch)
self.reports_manager.report(report_data, current_round, self.total_epochs)
self._log_results(loss_dict, metrics, current_round, local_epoch)

# Update internal epoch counter
self.total_epochs += 1

# Return final training metrics
return loss_dict, metrics
Expand Down Expand Up @@ -690,7 +707,7 @@ def train_by_steps(
self.update_metric_manager(preds, target, self.train_metric_manager)
self.update_after_step(step, current_round)
self.update_lr_schedulers(step=step)
report_data.update({"fit_losses": losses.as_dict(), "fit_step": self.total_steps})
report_data.update({"fit_step_losses": losses.as_dict(), "fit_step": self.total_steps})
report_data.update(self.get_client_specific_reports())
self.reports_manager.report(report_data, current_round, None, self.total_steps)
self.total_steps += 1
Expand Down
11 changes: 11 additions & 0 deletions fl4health/clients/nnunet_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
from nnunetv2.experiment_planning.plan_and_preprocess_api import extract_fingerprints, preprocess_dataset
from nnunetv2.paths import nnUNet_preprocessed, nnUNet_raw
from nnunetv2.training.dataloading.utils import unpack_dataset
from nnunetv2.training.loss.deep_supervision import DeepSupervisionWrapper
from nnunetv2.training.lr_scheduler.polylr import PolyLRScheduler
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
from nnunetv2.utilities.dataset_name_id_conversion import convert_id_to_dataset_name
Expand Down Expand Up @@ -246,9 +247,15 @@ def get_model(self, config: Config) -> nn.Module:
return self.nnunet_trainer.network

def get_criterion(self, config: Config) -> _Loss:
if isinstance(self.nnunet_trainer.loss, DeepSupervisionWrapper):
self.reports_manager.report({"Criterion": self.nnunet_trainer.loss.loss.__class__.__name__})
else:
self.reports_manager.report({"Criterion": self.nnunet_trainer.loss.__class__.__name__})

return Module2LossWrapper(self.nnunet_trainer.loss)

def get_optimizer(self, config: Config) -> Optimizer:
self.reports_manager.report({"Optimizer": self.nnunet_trainer.optimizer.__class__.__name__})
return self.nnunet_trainer.optimizer

def get_lr_scheduler(self, optimizer_key: str, config: Config) -> _LRScheduler:
Expand Down Expand Up @@ -289,6 +296,7 @@ def get_lr_scheduler(self, optimizer_key: str, config: Config) -> _LRScheduler:
# Create and return LR Scheduler Wrapper for the PolyLRScheduler so that it is
# compatible with Torch LRScheduler
# Create and return LR Scheduler. This is nnunet default for version 2.5.1
self.reports_manager.report({"LR Scheduler": "PolyLRScheduler"})
return PolyLRSchedulerWrapper(
self.optimizers[optimizer_key],
initial_lr=self.nnunet_trainer.initial_lr,
Expand Down Expand Up @@ -685,6 +693,9 @@ def get_client_specific_logs(
else:
return "", []

def get_client_specific_reports(self) -> Dict[str, Any]:
return {"learning_rate": float(self.optimizers["global"].param_groups[0]["lr"])}

@use_default_signal_handlers # Experiment planner spawns a process I think
def get_properties(self, config: Config) -> Dict[str, Scalar]:
"""
Expand Down
9 changes: 5 additions & 4 deletions fl4health/reporting/base_reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@ def report(
data (dict): The data to maybe report from the server or client.
round (int | None, optional): The current FL round. If None, this indicates that the method was called
outside of a round (e.g. for summary information). Defaults to None.
epoch (int | None, optional): The current epoch. If None then this method was not called at or within the
scope of an epoch. Defaults to None.
step (int | None, optional): The current step (total). If None then this method was called outside the
scope of a training or evaluation step (eg. at the end of an epoch or round) Defaults to None.
epoch (int | None, optional): The current epoch (In total across all rounds). If None then this method was
scarere marked this conversation as resolved.
Show resolved Hide resolved
not called at or within the scope of an epoch. Defaults to None.
step (int | None, optional): The current step (In total across all rounds and epochs). If None then this
method was called outside the scope of a training or evaluation step (eg. at the end of an epoch or
round) Defaults to None.
"""
raise NotImplementedError

Expand Down
Loading
Loading