Add early stop module #301
base: main
Conversation
@@ -11,7 +11,7 @@
 from flwr.common.typing import Config, NDArrays, Scalar
 from torch.nn.modules.loss import _Loss
 from torch.optim import Optimizer
-from torch.optim.lr_scheduler import _LRScheduler
+from torch.optim.lr_scheduler import LRScheduler
This change was necessary due to pre-commit errors. I think _LRScheduler has been deprecated in newer versions of PyTorch.
There are a few things we definitely want to make sure we think about carefully. Specifically, I just want to make sure we aren't going to have a lot of additional memory overhead with the way we're doing snapshotting.
@@ -159,6 +159,9 @@ def get_parameters(self, config: Config) -> NDArrays:
             # Need all parameters even if normally exchanging partial
             return FullParameterExchanger().push_parameters(self.model, config=config)
         else:
+            if hasattr(self, "early_stopper") and self.early_stopper.patience == 0:
Maybe we can encapsulate this in a class method with an informative name?
This is going to be a naive question, but is there a reason we are doing a hasattr for the early stopper rather than having it be an optional property that defaults to None in the init function? That is,

self.early_stopper: EarlyStopper | None = None

If we do that, then here we can just check whether it is None, and the set_early_stopper method can default to logging the "not activated" message and leaving it as None. This would eliminate the need for the try/except below as well. A rough sketch of the idea follows.
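To make the suggestion concrete, here is a minimal sketch of the optional-property pattern (the class and attribute names follow the diff above, but the method bodies are assumptions for illustration only):

```python
from logging import INFO

from flwr.common.logger import log


class BasicClient:
    def __init__(self) -> None:
        # Early stopper defaults to None, so later code can use a plain
        # None check instead of hasattr.
        self.early_stopper: "EarlyStopper | None" = None

    def set_early_stopper(self) -> None:
        # Default behavior: leave early stopping off and just log it,
        # rather than raising NotImplementedError.
        log(INFO, "Early stopping not activated for this client.")

    def should_return_best_parameters(self) -> bool:
        # Replaces the hasattr check from the diff, and gives the check
        # an informative name as suggested above.
        return self.early_stopper is not None and self.early_stopper.patience == 0
```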
except NotImplementedError:
    log(
        INFO,
        """Early stopping not implemented for this client.
Super minor, but we've been avoiding triple-quoted logging strings since they preserve whitespace. Rather, we could do

log(
    INFO,
    "Early stopping not implemented for this client. "
    "Override set_early_stopper to activate early stopping.",
)
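Adjacent string literals are concatenated by Python, so this still passes a single message string to log while avoiding the preserved whitespace of a triple-quoted literal.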
class EarlyStopper:
    def __init__(
        self,
        client: BasicClient,
We talked about looking into this, so forgive me if we already did and I forgot the conclusion, but have we made sure that this will be storing a reference to the client object? Just want to make sure we're not suddenly doubling the memory footprint of each client by doing this.
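For reference, a tiny illustration (not from the PR) of why plain attribute assignment by itself should not duplicate the client:

```python
class EarlyStopper:
    def __init__(self, client: object) -> None:
        # Assignment stores a reference to the existing client object;
        # nothing is copied at this point.
        self.client = client


client = object()
stopper = EarlyStopper(client)
assert stopper.client is client  # same object, so no second copy of the model, data, etc.
```

The memory risk comes from any later copying of the client's attributes (for example, deep copies during snapshotting), not from holding the reference itself.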
        snapshot_dir: Path | None = None,
    ) -> None:
        """
        Early stopping class is an plugin for the client that allows to stop local training based on the validation
"an plugin" -> "a plugin" 🙂
if self.best_score is None or val_loss < self.best_score:
    self.best_score = val_loss

    self.count_down = self.patience
Say that we have patience = 0 and we do our first check. The best_score will be None, so we'll get in here and self.count_down = self.patience = 0. The next time the check runs without an improvement, count_down will get decremented to -1, so we still won't stop. I know it works, but it feels a bit weird and confusing to read. It feels like there is a cleaner way to do "infinite" patience while also making sure count_down never goes negative. Perhaps we can have patience be optional? Something like the sketch below.
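One way that cleaner formulation could look, sketched under the assumption that patience=None should mean "never stop on patience"; the attribute names follow the diff, but the logic is illustrative only:

```python
class EarlyStopper:
    def __init__(self, patience: int | None = None) -> None:
        # None means infinite patience: never trigger early stopping on patience.
        self.patience = patience
        self.count_down = patience
        self.best_score: float | None = None

    def should_stop(self, val_loss: float) -> bool:
        if self.best_score is None or val_loss < self.best_score:
            # Improvement: record it and reset the countdown.
            self.best_score = val_loss
            self.count_down = self.patience
            return False
        if self.count_down is None:
            # Infinite patience: never stop.
            return False
        if self.count_down == 0:
            return True
        # No improvement: burn one unit of patience; never goes negative.
        self.count_down -= 1
        return False
```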
T = TypeVar("T")


class Snapshotter(ABC, Generic[T]):
Nice use of the Generic!
        Returns:
            dict[str, T]: Wrapped attribute as a dictionary.
        """
        attribute = copy.deepcopy(getattr(self.client, name))
Is there a reason we're deep copying here? I think this will force a duplicate of the attribute (for example, the model) to be created, and I'm a little worried that will double our memory footprint. If it's so that we can keep stuff in memory (i.e., not checkpoint to a file), I think we should force checkpointing to a file to avoid the memory overhead; see the sketch below.
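For illustration, a rough sketch of the file-based alternative; the use of torch.save/torch.load, the state_dict handling, and the file naming are assumptions rather than the snapshotter's actual API:

```python
from pathlib import Path

import torch


def snapshot_attribute_to_disk(client: object, name: str, snapshot_dir: Path) -> None:
    attribute = getattr(client, name)
    # Saving a state_dict (for modules, optimizers, schedulers) avoids holding
    # a second, deep-copied instance of the live object in memory.
    state = attribute.state_dict() if hasattr(attribute, "state_dict") else attribute
    torch.save(state, snapshot_dir / f"temp_{name}.pt")


def restore_attribute_from_disk(client: object, name: str, snapshot_dir: Path) -> None:
    state = torch.load(snapshot_dir / f"temp_{name}.pt")
    attribute = getattr(client, name)
    if hasattr(attribute, "load_state_dict"):
        attribute.load_state_dict(state)
    else:
        setattr(client, name, state)
```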
self.snapshot_ckpt = self.checkpointer.load_checkpoint(f"temp_{self.client.client_name}.pt")

for attr in attrs:
    snapshotter_function, expected_type = self.default_snapshot_attrs[attr]
Just a suggestion, but I think we can call this just snapshotter rather than snapshotter_function.
@@ -0,0 +1,244 @@
+import copy
Overall, I think these snapshotters look good. I just want to make sure we're sure that we're not duplicating a bunch of objects in memory. If we're not careful, these objects will make the clients much heavier than without early stopping, which we want to avoid.
PR Type
[Feature]
Short Description
Clickup Ticket(s): https://app.clickup.com/t/8688wzkuk , https://app.clickup.com/t/860qxm622
Integrated an early stopping module as a plug-in for all clients. After a specified number of training steps, the module computes the evaluation loss. If the loss improves compared to previous evaluations, it saves a snapshot of the client's key attributes, enabling the client to restore these attributes when the stopping criterion is met.
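As a rough illustration of that flow (the helper names train_step, compute_validation_loss, should_stop, and load_snapshot are placeholders, not necessarily the module's real API):

```python
def train_with_early_stopping(client, total_steps: int, check_interval: int) -> None:
    for step in range(total_steps):
        client.train_step()
        # Only check the stopping criterion every check_interval steps.
        if client.early_stopper is None or (step + 1) % check_interval != 0:
            continue
        val_loss = client.compute_validation_loss()
        # should_stop is assumed to snapshot the client's key attributes
        # internally whenever val_loss improves on the best score so far.
        if client.early_stopper.should_stop(val_loss):
            # Stopping criterion met: restore the best snapshot and stop training.
            client.early_stopper.load_snapshot()
            break
```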
Tests Added
Added a series of tests for the snapshot modules to ensure attributes are saved and loaded correctly, as intended.