
Add early stop module #301

Open

wants to merge 22 commits into main
Conversation

@sanaAyrml (Collaborator) commented Dec 5, 2024

PR Type

[Feature]

Short Description

Clickup Ticket(s): https://app.clickup.com/t/8688wzkuk, https://app.clickup.com/t/860qxm622

Integrated an early stopping module as a plug-in for all clients. After a specified number of training steps, the module computes the evaluation loss. If the loss improves on previous evaluations, it saves a snapshot of the model's key attributes, so these attributes can be restored when the stopping criterion is met.
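A rough, self-contained sketch of the intended flow (class and method names, and the in-memory snapshot, are illustrative assumptions, not necessarily the API this PR introduces):

# Minimal sketch of the early-stopping loop; names are illustrative only.
class EarlyStopper:
    def __init__(self, patience: int) -> None:
        self.patience = patience
        self.count_down = patience
        self.best_score: float | None = None
        self.snapshot: dict | None = None

    def should_stop(self, val_loss: float, state: dict) -> bool:
        if self.best_score is None or val_loss < self.best_score:
            self.best_score = val_loss
            self.snapshot = dict(state)  # save the best attributes
            self.count_down = self.patience
            return False
        self.count_down -= 1
        return self.count_down <= 0

stopper = EarlyStopper(patience=2)
for val_loss in [0.9, 0.8, 0.85, 0.84, 0.86]:  # toy evaluation losses
    if stopper.should_stop(val_loss, state={"val_loss": val_loss}):
        break
print(stopper.snapshot)  # {'val_loss': 0.8} -- restore from the best point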

Tests Added

Added a series of tests for the snapshot modules to ensure snapshots are saved and loaded as intended.

@sanaAyrml sanaAyrml marked this pull request as draft January 2, 2025 19:09
@@ -11,7 +11,7 @@
 from flwr.common.typing import Config, NDArrays, Scalar
 from torch.nn.modules.loss import _Loss
 from torch.optim import Optimizer
-from torch.optim.lr_scheduler import _LRScheduler
+from torch.optim.lr_scheduler import LRScheduler
@sanaAyrml (Collaborator, Author):

This change was necessary due to pre-commit errors. I think _LRScheduler has been deprecated in newer versions of PyTorch in favor of LRScheduler.

@sanaAyrml sanaAyrml changed the title Sa early stop Add early stop module Jan 9, 2025
@sanaAyrml sanaAyrml marked this pull request as ready for review January 9, 2025 10:00
@emersodb emersodb (Collaborator) left a comment:

There are a few things we definitely want to make sure we think about carefully. Specifically, I just want to make sure we aren't going to have a lot of additional memory overhead with the way we're doing snapshotting.

@@ -159,6 +159,9 @@ def get_parameters(self, config: Config) -> NDArrays:
             # Need all parameters even if normally exchanging partial
             return FullParameterExchanger().push_parameters(self.model, config=config)
         else:
+            if hasattr(self, "early_stopper") and self.early_stopper.patience == 0:
@emersodb (Collaborator):

Maybe we can encapsulate this in a class method with an informative name?
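For instance (a sketch; the method name is hypothetical, not from this PR):

def _should_return_early_stopped_model(self) -> bool:
    """True when an attached early stopper wants the restored best model sent."""
    return hasattr(self, "early_stopper") and self.early_stopper.patience == 0

The branch above then reads as a single, self-documenting call.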

@emersodb (Collaborator):

This may be a naive question, but is there a reason we're doing a hasattr check for the early stopper rather than making it an optional attribute that defaults to None in the __init__ function? That is,

self.early_stopper: EarlyStopper | None = None

If we do that, then here we can just check that it isn't None, and the set_early_stopper method can default to just logging the "not activated" message and leaving it None. This would eliminate the need for the try-except below as well.
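A sketch of this alternative (set_early_stopper exists in the PR; the placeholder EarlyStopper and the other names are assumptions):

from __future__ import annotations

from logging import INFO

from flwr.common.logger import log  # assumption: the log helper the clients already use

class EarlyStopper:  # stand-in for the PR's EarlyStopper
    def should_stop(self, val_loss: float) -> bool: ...

class BasicClient:
    def __init__(self) -> None:
        # Early stopping is off unless a subclass attaches a stopper.
        self.early_stopper: EarlyStopper | None = None

    def set_early_stopper(self) -> None:
        # Default no-op: log the "not activated" message and leave it None.
        log(INFO, "Early stopping not implemented for this client.")

    def _maybe_stop_early(self, val_loss: float) -> bool:
        # No hasattr or try/except needed: just a None check.
        return self.early_stopper is not None and self.early_stopper.should_stop(val_loss)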

        except NotImplementedError:
            log(
                INFO,
                """Early stopping not implemented for this client.
@emersodb (Collaborator):

Super minor, but we've been avoiding """ logging strings since they preserve whitespace. Instead, we could rely on implicit string concatenation:

log(
    INFO,
    "Early stopping not implemented for this client. "
    "Override set_early_stopper to activate early stopping.",
)

class EarlyStopper:
    def __init__(
        self,
        client: BasicClient,
@emersodb (Collaborator):

We talked about looking into this, so forgive me if we already did and I forgot the conclusion, but have we made sure that this stores only a reference to the client object rather than a copy? Just want to make sure we're not suddenly doubling the memory footprint of each client by doing this.
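For reference, plain attribute assignment in Python binds a reference rather than copying, which is easy to verify (a toy illustration, not code from the PR):

class Holder:
    def __init__(self, obj: object) -> None:
        self.obj = obj  # binds a reference; nothing is copied

big = {"weights": list(range(1_000_000))}
holder = Holder(big)
assert holder.obj is big  # same object in memory, no duplication

So the constructor itself should be fine; the risk would be any deepcopy performed later on the referenced attributes.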

        snapshot_dir: Path | None = None,
    ) -> None:
        """
        Early stopping class is an plugin for the client that allows to stop local training based on the validation
@emersodb (Collaborator):

"an plugin" -> "a plugin" 🙂

        if self.best_score is None or val_loss < self.best_score:
            self.best_score = val_loss

            self.count_down = self.patience
@emersodb (Collaborator):

Say we have patience = 0 and we do our first check. The best_score will be None, so we'll get in here and self.count_down = self.patience = 0. The next time we get in here, count_down will get decremented to -1, so we still won't stop. I know it works, but it feels a bit weird and confusing to read. It feels like there is a cleaner way to do "infinite" patience while also making sure count_down never goes negative. Perhaps we could make patience optional?
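One possible shape for this, treating patience=None as "never stop" (a sketch, not the PR's implementation):

class EarlyStopper:
    def __init__(self, patience: int | None = None) -> None:
        # None means infinite patience: early stopping never triggers.
        self.patience = patience
        self.count_down = patience
        self.best_score: float | None = None

    def should_stop(self, val_loss: float) -> bool:
        if self.best_score is None or val_loss < self.best_score:
            self.best_score = val_loss
            self.count_down = self.patience  # reset on improvement
            return False
        if self.patience is None:
            return False  # infinite patience
        self.count_down = max(self.count_down - 1, 0)  # never goes negative
        return self.count_down == 0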

T = TypeVar("T")


class Snapshotter(ABC, Generic[T]):
@emersodb (Collaborator):

Nice use of the Generic!
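For readers unfamiliar with the pattern, a concrete subclass pins T down so callers get precise types back (the abstract method and subclass names here are hypothetical, not the PR's):

from abc import ABC, abstractmethod
from typing import Generic, TypeVar

import torch.nn as nn

T = TypeVar("T")

class Snapshotter(ABC, Generic[T]):
    @abstractmethod
    def wrap_attribute(self, name: str) -> dict[str, T]: ...

class ModelSnapshotter(Snapshotter[nn.Module]):
    def __init__(self, client: object) -> None:
        self.client = client

    def wrap_attribute(self, name: str) -> dict[str, nn.Module]:
        # T = nn.Module here, so type checkers know the dict values are modules.
        return {name: getattr(self.client, name)}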

        Returns:
            dict[str, T]: Wrapped attribute as a dictionary.
        """
        attribute = copy.deepcopy(getattr(self.client, name))
@emersodb (Collaborator):

Is there a reason we're deep copying here? I think this will force a duplicate of the attribute (for example, the model) to be created, and I'm a little worried that will double our memory footprint. If the goal is to keep things in memory (i.e., not checkpoint to a file), I think maybe we should force checkpointing to a file to avoid the memory overhead.
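To make the trade-off concrete (a sketch; torch.save stands in for whatever checkpointer is used here):

import copy

import torch

model = torch.nn.Linear(4096, 4096)

# In-memory snapshot: deepcopy materializes a second full set of tensors,
# roughly doubling the model's RAM footprint while it is held.
snapshot = copy.deepcopy(model.state_dict())

# On-disk snapshot: nothing extra stays resident once the call returns.
torch.save(model.state_dict(), "/tmp/snapshot.pt")
restored = torch.load("/tmp/snapshot.pt")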

        self.snapshot_ckpt = self.checkpointer.load_checkpoint(f"temp_{self.client.client_name}.pt")

        for attr in attrs:
            snapshotter_function, expected_type = self.default_snapshot_attrs[attr]
@emersodb (Collaborator):

Just a suggestion, but I think we can call this just snapshotter rather than snapshotter_function

@@ -0,0 +1,244 @@
+import copy
@emersodb (Collaborator):

Overall, I think these snapshotters look good. I just want to make sure we're not duplicating a bunch of objects in memory. If we're not careful, these objects will make the clients much heavier than they would be without early stopping, which we want to avoid.
