Merge pull request #56 from VectorInstitute/dbe/server_checkpointing_in_base

Add Server-side Checkpointing to Server Base
emersodb authored Sep 26, 2023
2 parents 069bc16 + 89335aa commit a3a851e
Showing 31 changed files with 323 additions and 210 deletions.
3 changes: 3 additions & 0 deletions examples/basic_example/config.yaml
@@ -5,3 +5,6 @@ n_server_rounds: 3 # The number of rounds to run FL
n_clients: 2 # The number of clients in the FL experiment
local_epochs: 3 # The number of epochs to complete for client
batch_size: 32 # The batch size for client training

# checkpointing
checkpoint_path: "examples/basic_example"
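
For reference, the new `checkpoint_path` key is consumed on the server side when constructing the checkpointer. A minimal sketch of that wiring, mirroring the example server in the next file (the direct `load_config` call on the YAML path is assumed here for illustration):

```python
from fl4health.checkpointing.checkpointer import BestMetricTorchCheckpointer
from fl4health.utils.config import load_config

# Load the YAML config above; checkpoint_path points at the directory where the
# best server-side model will be written.
config = load_config("examples/basic_example/config.yaml")

# maximize=False because the example checkpoints on the aggregated evaluation
# loss, which should be minimized.
checkpointer = BestMetricTorchCheckpointer(config["checkpoint_path"], "best_model.pkl", maximize=False)
```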
23 changes: 17 additions & 6 deletions examples/basic_example/server.py
@@ -3,20 +3,23 @@
from typing import Any, Dict

import flwr as fl
import torch.nn as nn
from flwr.common.parameter import ndarrays_to_parameters
from flwr.common.typing import Config, Parameters
from flwr.server.client_manager import SimpleClientManager
from flwr.server.strategy import FedAvg

from examples.models.cnn_model import Net
from examples.simple_metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn
from fl4health.checkpointing.checkpointer import BestMetricTorchCheckpointer
from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger
from fl4health.server.base_server import FlServerWithCheckpointing
from fl4health.utils.config import load_config


def get_initial_model_parameters() -> Parameters:
def get_initial_model_parameters(model: nn.Module) -> Parameters:
    # Initializing the model parameters on the server side.
    # Currently uses the Pytorch default initialization for the model parameters.
    initial_model = Net()
    return ndarrays_to_parameters([val.cpu().numpy() for _, val in initial_model.state_dict().items()])
    return ndarrays_to_parameters([val.cpu().numpy() for _, val in model.state_dict().items()])


def fit_config(
@@ -35,6 +38,12 @@ def main(config: Dict[str, Any]) -> None:
        config["batch_size"],
    )

    # Initializing the model on the server side
    model = Net()
    # To facilitate checkpointing
    parameter_exchanger = FullParameterExchanger()
    checkpointer = BestMetricTorchCheckpointer(config["checkpoint_path"], "best_model.pkl", maximize=False)

    # Server performs simple FedAveraging as its server-side optimization strategy
    strategy = FedAvg(
        min_fit_clients=config["n_clients"],
@@ -46,13 +55,15 @@ def main(config: Dict[str, Any]) -> None:
        on_evaluate_config_fn=fit_config_fn,
        fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
        evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
        initial_parameters=get_initial_model_parameters(),
        initial_parameters=get_initial_model_parameters(model),
    )

    server = FlServerWithCheckpointing(SimpleClientManager(), model, parameter_exchanger, None, strategy, checkpointer)

    fl.server.start_server(
        server=server,
        server_address="0.0.0.0:8080",
        config=fl.server.ServerConfig(num_rounds=config["n_server_rounds"]),
        strategy=strategy,
    )


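Once training completes, the best model produced by the server should sit at `examples/basic_example/best_model.pkl`. A hedged sketch of loading it back for inference, assuming the checkpointer persists the whole `nn.Module` with `torch.save` (the exact serialization used by `BestMetricTorchCheckpointer` is not shown in this diff):

```python
import os

import torch

# Hypothetical reload of the checkpoint written during federated training.
# On newer PyTorch releases you may need torch.load(..., weights_only=False)
# to unpickle a full module.
checkpoint_file = os.path.join("examples/basic_example", "best_model.pkl")
model = torch.load(checkpoint_file)
model.eval()  # switch to inference mode before evaluating on held-out data
```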
73 changes: 70 additions & 3 deletions fl4health/server/base_server.py
@@ -1,13 +1,17 @@
from logging import INFO
from typing import List, Optional
from logging import INFO, WARNING
from typing import Dict, Generic, List, Optional, Tuple, TypeVar

import torch.nn as nn
from flwr.common.logger import log
from flwr.common.parameter import parameters_to_ndarrays
from flwr.common.typing import Scalar
from flwr.server.client_manager import ClientManager
from flwr.server.history import History
from flwr.server.server import Server
from flwr.server.server import EvaluateResultsAndFailures, Server
from flwr.server.strategy import Strategy

from fl4health.checkpointing.checkpointer import TorchCheckpointer
from fl4health.parameter_exchange.parameter_exchanger_base import ParameterExchanger
from fl4health.reporting.fl_wanb import ServerWandBReporter
from fl4health.server.polling import poll_clients
from fl4health.strategies.strategy_with_poll import StrategyWithPolling
@@ -40,6 +44,29 @@ def shutdown(self) -> None:
        if self.wandb_reporter:
            self.wandb_reporter.shutdown_reporter()

    def _hydrate_model_for_checkpointing(self) -> nn.Module:
        # This function is used for converting server parameters into a torch model that can be checkpointed
        raise NotImplementedError()

    def _maybe_checkpoint(self, checkpoint_metric: float, server_round: int) -> None:
        if self.checkpointer:
            try:
                model = self._hydrate_model_for_checkpointing()
                self.checkpointer.maybe_checkpoint(model, checkpoint_metric)
            except NotImplementedError:
                # Checkpointer is defined but there is no server-side model hydration to produce a model from the
                # server state. This is not a deal breaker, but may be unintended behavior and the user will be warned
                if server_round == 1:
                    # just log message on the first round
                    log(
                        WARNING,
                        "Server model hydration is not defined but checkpointer is defined. Not checkpointing "
                        "model. Please ensure that this is intended",
                    )
        elif server_round == 1:
            # No checkpointer, just log message on the first round
            log(INFO, "No checkpointer present. Models will not be checkpointed on server-side.")

    def poll_clients_for_sample_counts(self, timeout: Optional[float]) -> List[int]:
        # Poll clients for sample counts, if you want to use this functionality your strategy needs to inherit from
        # the StrategyWithPolling ABC and implement a configure_poll function
@@ -56,3 +83,43 @@ def poll_clients_for_sample_counts(self, timeout: Optional[float]) -> List[int]:
        log(INFO, f"Polling complete: Retrieved {len(sample_counts)} sample counts")

        return sample_counts

    def evaluate_round(
        self,
        server_round: int,
        timeout: Optional[float],
    ) -> Optional[Tuple[Optional[float], Dict[str, Scalar], EvaluateResultsAndFailures]]:
        # By default the checkpointing works off of the aggregated evaluation loss from each of the clients
        # NOTE: parameter aggregation occurs **before** evaluation, so the parameters held by the server have been
        # updated prior to this function being called.
        eval_round_results = super().evaluate_round(server_round, timeout)
        if eval_round_results:
            loss_aggregated, metrics_aggregated, (results, failures) = eval_round_results
            if loss_aggregated:
                self._maybe_checkpoint(loss_aggregated, server_round)

        return eval_round_results


ExchangerType = TypeVar("ExchangerType", bound=ParameterExchanger)


class FlServerWithCheckpointing(FlServer, Generic[ExchangerType]):
    def __init__(
        self,
        client_manager: ClientManager,
        model: nn.Module,
        parameter_exchanger: ExchangerType,
        wandb_reporter: Optional[ServerWandBReporter] = None,
        strategy: Optional[Strategy] = None,
        checkpointer: Optional[TorchCheckpointer] = None,
    ) -> None:
        super().__init__(client_manager, strategy, wandb_reporter, checkpointer)
        self.server_model = model
        # To facilitate model rehydration from server-side state for checkpointing
        self.parameter_exchanger = parameter_exchanger

    def _hydrate_model_for_checkpointing(self) -> nn.Module:
        model_ndarrays = parameters_to_ndarrays(self.parameters)
        self.parameter_exchanger.pull_parameters(model_ndarrays, self.server_model)
        return self.server_model
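
To make the checkpointing contract concrete: `evaluate_round` forwards the aggregated evaluation loss to `_maybe_checkpoint`, which hydrates a `torch` model from the server parameters and hands it to the checkpointer. A minimal, hypothetical best-metric checkpointer with the same `maybe_checkpoint` shape as the calls above might look like this (illustrative only, not the library's `BestMetricTorchCheckpointer`):

```python
import os
from typing import Optional

import torch
import torch.nn as nn


class MinimalBestMetricCheckpointer:
    """Illustration of the expected behavior: save only when the metric improves."""

    def __init__(self, checkpoint_dir: str, checkpoint_name: str, maximize: bool = False) -> None:
        self.checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name)
        self.maximize = maximize
        self.best_metric: Optional[float] = None

    def maybe_checkpoint(self, model: nn.Module, metric: float) -> None:
        # Save the model whenever the comparison metric beats the best seen so far.
        improved = (
            self.best_metric is None
            or (self.maximize and metric > self.best_metric)
            or (not self.maximize and metric < self.best_metric)
        )
        if improved:
            self.best_metric = metric
            torch.save(model, self.checkpoint_path)
```

In `FlServerWithCheckpointing`, the object passed in is `self.server_model` after `pull_parameters` has copied the aggregated Flower parameters into it, so whatever gets saved is an ordinary `nn.Module`.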
6 changes: 3 additions & 3 deletions research/flamby/fed_heart_disease/apfl/server.py
@@ -30,8 +30,8 @@ def main(config: Dict[str, Any], server_address: str) -> None:
)

client_manager = SimpleClientManager()
client_model = APFLModule(Baseline())
summarize_model_info(client_model)
model = APFLModule(Baseline())
summarize_model_info(model)

# Server performs simple FedAveraging as its server-side optimization strategy
strategy = FedAvg(
@@ -44,7 +44,7 @@ def main(config: Dict[str, Any], server_address: str) -> None:
on_evaluate_config_fn=fit_config_fn,
fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
initial_parameters=get_initial_model_parameters(client_model),
initial_parameters=get_initial_model_parameters(model),
)

server = PersonalServer(client_manager, strategy)
8 changes: 4 additions & 4 deletions research/flamby/fed_heart_disease/fedadam/server.py
@@ -37,8 +37,8 @@ def main(
checkpointer = BestMetricTorchCheckpointer(checkpoint_dir, checkpoint_name)

client_manager = SimpleClientManager()
client_model = Baseline()
summarize_model_info(client_model)
model = Baseline()
summarize_model_info(model)

strategy = FedAdam(
min_fit_clients=config["n_clients"],
@@ -50,11 +50,11 @@ def main(
on_evaluate_config_fn=fit_config_fn,
fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
initial_parameters=get_initial_model_parameters(client_model),
initial_parameters=get_initial_model_parameters(model),
eta=server_learning_rate,
)

server = FullExchangeServer(client_manager, client_model, strategy, checkpointer)
server = FullExchangeServer(client_manager, model, strategy, checkpointer=checkpointer)

fl.server.start_server(
server=server,
8 changes: 4 additions & 4 deletions research/flamby/fed_heart_disease/fedavg/server.py
@@ -35,8 +35,8 @@ def main(config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_
checkpointer = BestMetricTorchCheckpointer(checkpoint_dir, checkpoint_name)

client_manager = SimpleClientManager()
client_model = Baseline()
summarize_model_info(client_model)
model = Baseline()
summarize_model_info(model)

# Server performs simple FedAveraging as its server-side optimization strategy
strategy = FedAvg(
@@ -49,10 +49,10 @@ def main(config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_
on_evaluate_config_fn=fit_config_fn,
fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
initial_parameters=get_initial_model_parameters(client_model),
initial_parameters=get_initial_model_parameters(model),
)

server = FullExchangeServer(client_manager, client_model, strategy, checkpointer=checkpointer)
server = FullExchangeServer(client_manager, model, strategy, checkpointer=checkpointer)

fl.server.start_server(
server=server,
8 changes: 4 additions & 4 deletions research/flamby/fed_heart_disease/fedprox/server.py
@@ -35,8 +35,8 @@ def main(config: Dict[str, Any], server_address: str, mu: float, checkpoint_stub
checkpointer = BestMetricTorchCheckpointer(checkpoint_dir, checkpoint_name)

client_manager = SimpleClientManager()
client_model = Baseline()
summarize_model_info(client_model)
model = Baseline()
summarize_model_info(model)

# Server performs simple FedAveraging as its server-side optimization strategy
strategy = FedProx(
@@ -49,11 +49,11 @@ def main(config: Dict[str, Any], server_address: str, mu: float, checkpoint_stub
on_evaluate_config_fn=fit_config_fn,
fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
initial_parameters=get_initial_model_parameters(client_model),
initial_parameters=get_initial_model_parameters(model),
proximal_weight=mu,
)

server = FedProxServer(client_manager, client_model, strategy, checkpointer)
server = FedProxServer(client_manager, model, strategy, checkpointer)

fl.server.start_server(
server=server,
6 changes: 3 additions & 3 deletions research/flamby/fed_heart_disease/fenda/server.py
@@ -29,8 +29,8 @@ def main(config: Dict[str, Any], server_address: str) -> None:
)

client_manager = SimpleClientManager()
client_model = FedHeartDiseaseFendaModel()
summarize_model_info(client_model)
model = FedHeartDiseaseFendaModel()
summarize_model_info(model)

# Server performs simple FedAveraging as its server-side optimization strategy
strategy = FedAvg(
@@ -43,7 +43,7 @@ def main(config: Dict[str, Any], server_address: str) -> None:
on_evaluate_config_fn=fit_config_fn,
fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
initial_parameters=get_initial_model_parameters(client_model),
initial_parameters=get_initial_model_parameters(model),
)

server = PersonalServer(client_manager, strategy)
8 changes: 4 additions & 4 deletions research/flamby/fed_heart_disease/scaffold/server.py
@@ -37,10 +37,10 @@ def main(
checkpointer = BestMetricTorchCheckpointer(checkpoint_dir, checkpoint_name)

client_manager = FixedSamplingByFractionClientManager()
client_model = Baseline()
summarize_model_info(client_model)
model = Baseline()
summarize_model_info(model)

initial_parameters, initial_control_variates = get_initial_model_info_with_control_variates(client_model)
initial_parameters, initial_control_variates = get_initial_model_info_with_control_variates(model)

strategy = Scaffold(
fraction_fit=1.0,
@@ -57,7 +57,7 @@ def main(
learning_rate=server_learning_rate,
)

server = ScaffoldServer(client_manager, client_model, strategy, checkpointer)
server = ScaffoldServer(client_manager, model, strategy, checkpointer)

fl.server.start_server(
server=server,
6 changes: 3 additions & 3 deletions research/flamby/fed_isic2019/apfl/server.py
@@ -30,8 +30,8 @@ def main(config: Dict[str, Any], server_address: str) -> None:
)

client_manager = SimpleClientManager()
client_model = APFLModule(APFLEfficientNet(frozen_blocks=None, turn_off_bn_tracking=False))
summarize_model_info(client_model)
model = APFLModule(APFLEfficientNet(frozen_blocks=None, turn_off_bn_tracking=False))
summarize_model_info(model)

# Server performs simple FedAveraging as its server-side optimization strategy
strategy = FedAvg(
@@ -44,7 +44,7 @@ def main(config: Dict[str, Any], server_address: str) -> None:
on_evaluate_config_fn=fit_config_fn,
fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
initial_parameters=get_initial_model_parameters(client_model),
initial_parameters=get_initial_model_parameters(model),
)

server = PersonalServer(client_manager, strategy)
6 changes: 3 additions & 3 deletions research/flamby/fed_isic2019/fedadam/server.py
@@ -36,7 +36,7 @@ def main(
checkpointer = BestMetricTorchCheckpointer(checkpoint_dir, checkpoint_name)

client_manager = SimpleClientManager()
client_model = FedAdamEfficientNet()
model = FedAdamEfficientNet()

strategy = FedAdam(
min_fit_clients=config["n_clients"],
@@ -48,11 +48,11 @@ def main(
on_evaluate_config_fn=fit_config_fn,
fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
initial_parameters=get_initial_model_parameters(client_model),
initial_parameters=get_initial_model_parameters(model),
eta=server_learning_rate,
)

server = FullExchangeServer(client_manager, client_model, strategy, checkpointer)
server = FullExchangeServer(client_manager, model, strategy, checkpointer=checkpointer)

fl.server.start_server(
server=server,
6 changes: 3 additions & 3 deletions research/flamby/fed_isic2019/fedavg/server.py
@@ -34,7 +34,7 @@ def main(config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_
checkpointer = BestMetricTorchCheckpointer(checkpoint_dir, checkpoint_name)

client_manager = SimpleClientManager()
client_model = Baseline()
model = Baseline()

# Server performs simple FedAveraging as its server-side optimization strategy
strategy = FedAvg(
@@ -47,10 +47,10 @@ def main(config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_
on_evaluate_config_fn=fit_config_fn,
fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
initial_parameters=get_initial_model_parameters(client_model),
initial_parameters=get_initial_model_parameters(model),
)

server = FullExchangeServer(client_manager, client_model, strategy, checkpointer=checkpointer)
server = FullExchangeServer(client_manager, model, strategy, checkpointer=checkpointer)

fl.server.start_server(
server=server,
6 changes: 3 additions & 3 deletions research/flamby/fed_isic2019/fedprox/server.py
@@ -34,7 +34,7 @@ def main(config: Dict[str, Any], server_address: str, mu: float, checkpoint_stub
checkpointer = BestMetricTorchCheckpointer(checkpoint_dir, checkpoint_name)

client_manager = SimpleClientManager()
client_model = Baseline()
model = Baseline()

# Server performs simple FedAveraging as its server-side optimization strategy
strategy = FedProx(
@@ -47,11 +47,11 @@ def main(config: Dict[str, Any], server_address: str, mu: float, checkpoint_stub
on_evaluate_config_fn=fit_config_fn,
fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
initial_parameters=get_initial_model_parameters(client_model),
initial_parameters=get_initial_model_parameters(model),
proximal_weight=mu,
)

server = FedProxServer(client_manager, client_model, strategy, checkpointer)
server = FedProxServer(client_manager, model, strategy, checkpointer)

fl.server.start_server(
server=server,