Skip to content

Commit

Permalink
update precommit type changes
Browse files Browse the repository at this point in the history
  • Loading branch information
sanaAyrml committed Jan 8, 2025
1 parent 3142889 commit f4ad580
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 13 deletions.
5 changes: 4 additions & 1 deletion fl4health/clients/basic_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,9 @@ def get_parameters(self, config: Config) -> NDArrays:
# Need all parameters even if normally exchanging partial
return FullParameterExchanger().push_parameters(self.model, config=config)
else:
if hasattr(self, "early_stopper") and self.early_stopper.patience == 0:
log(INFO, "Loading save best model state before sending model to server.")
self.early_stopper.load_snapshot(["model"])
assert self.model is not None and self.parameter_exchanger is not None
return self.parameter_exchanger.push_parameters(self.model, config=config)

Expand Down Expand Up @@ -893,7 +896,7 @@ def setup_early_stopper(
self,
patience: int = -1,
interval_steps: int = 5,
snapshot_dir: Optional[Path] = None,
snapshot_dir: Path | None = None,
) -> None:
from fl4health.utils.early_stopper import EarlyStopper

Expand Down
33 changes: 25 additions & 8 deletions fl4health/utils/early_stopper.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from collections.abc import Callable
from logging import INFO
from pathlib import Path
from typing import Any, Callable, Optional, Type
from typing import Any

import torch.nn as nn
from flwr.common.logger import log
Expand Down Expand Up @@ -28,18 +29,33 @@ class EarlyStopper:
def __init__(
self,
client: BasicClient,
patience: int = -1,
patience: int = 0,
interval_steps: int = 5,
snapshot_dir: Optional[Path] = None,
snapshot_dir: Path | None = None,
) -> None:
"""
Early stopping class is an plugin for the client that allows to stop local training based on the validation
loss. At each training step this class saves the best state of the client and restores it if the client is
stopped. If the client starts to overfit, the early stopper will stop the training process and restore the best
state of the client before sending the model to the server.
Args:
client (BasicClient): The client to be monitored.
patience (int, optional): Number of steps to wait before stopping the training. If it is equal to 0 client
never stops, but still loads the best state before sending the model to the server. Defaults to 0.
interval_steps (int, optional): Determins how often the early stopper should check the validation loss.
Defaults to 5.
snapshot_dir (Path | None, optional): Rather than keeping best state in the memory we can checkpoint it to
the given directory. If it is not given, the best state is kept in the memory. Defaults to None.
"""

self.client = client

self.patience = patience
self.counte_down = patience
self.interval_steps = interval_steps

self.best_score: Optional[float] = None
self.best_score: float | None = None
self.snapshot_ckpt: dict[str, Any] = {}

self.default_snapshot_args: dict = {
Expand Down Expand Up @@ -70,7 +86,7 @@ def __init__(
self.checkpointer = PerRoundStateCheckpointer(snapshot_dir)

def add_default_snapshot_arg(
self, name: str, snapshot_class: Callable[[BasicClient], Snapshotter], input_type: Type[T]
self, name: str, snapshot_class: Callable[[BasicClient], Snapshotter], input_type: type[T]
) -> None:
self.default_snapshot_args.update({name: (snapshot_class(self.client), input_type)})

Expand All @@ -94,7 +110,7 @@ def save_snapshot(self) -> None:
f"Saving client temp best state to checkpoint at {self.checkpointer.checkpoint_dir}",
)

def load_snapshot(self) -> None:
def load_snapshot(self, args: list[str]) -> None:
"""
Load checkpoint dict consisting of client name, total steps, lr schedulers, metrics
reporter and optimizers state. Method can be overridden to augment loaded checkpointed state.
Expand All @@ -106,7 +122,8 @@ def load_snapshot(self) -> None:
if self.checkpointer.checkpoint_exists(f"temp_{self.client.client_name}.pt"):
self.snapshot_ckpt = self.checkpointer.load_checkpoint(f"temp_{self.client.client_name}.pt")

for arg, (snapshotter_function, expected_type) in self.default_snapshot_args.items():
for arg in args:
snapshotter_function, expected_type = self.default_snapshot_args[arg]
snapshotter_function.load(self.snapshot_ckpt, arg, expected_type)

def should_stop(self) -> bool:
Expand Down Expand Up @@ -139,7 +156,7 @@ def should_stop(self) -> bool:
# Reduce patience counter and check for early stopping
self.count_down -= 1
if self.count_down == 0:
self.load_snapshot()
self.load_snapshot(list(self.default_snapshot_args.keys()))
return True

return False
8 changes: 4 additions & 4 deletions fl4health/utils/snapshotter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import copy
from abc import ABC, abstractmethod
from typing import Any, Generic, Type, TypeVar
from typing import Any, Generic, TypeVar

import torch.nn as nn
from torch.optim import Optimizer
Expand All @@ -18,7 +18,7 @@ class Snapshotter(ABC, Generic[T]):
def __init__(self, client: BasicClient) -> None:
self.client = client

def dict_wrap_attr(self, name: str, expected_type: Type[T]) -> dict[str, T]:
def dict_wrap_attr(self, name: str, expected_type: type[T]) -> dict[str, T]:
attribute = copy.deepcopy(getattr(self.client, name))
if isinstance(attribute, expected_type):
return {"None": attribute}
Expand All @@ -30,11 +30,11 @@ def dict_wrap_attr(self, name: str, expected_type: Type[T]) -> dict[str, T]:
else:
raise ValueError(f"Uncompatible type of attribute {type(attribute)}")

def save(self, name: str, expected_type: Type[T]) -> Any:
def save(self, name: str, expected_type: type[T]) -> Any:
attribute = self.dict_wrap_attr(name, expected_type)
return self.save_attribute(attribute)

def load(self, ckpt: dict[str, Any], name: str, expected_type: Type[T]) -> None:
def load(self, ckpt: dict[str, Any], name: str, expected_type: type[T]) -> None:
attribute = self.dict_wrap_attr(name, expected_type)
self.load_attribute(ckpt[name], attribute)
if list(attribute.keys()) == ["None"]:
Expand Down

0 comments on commit f4ad580

Please sign in to comment.