Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Consolidating model and state checkpointing on the client and server sides. #298

Open
wants to merge 12 commits into
base: main
Choose a base branch
from

Conversation

emersodb
Copy link
Collaborator

PR Type

Refactor

Short Description

Clickup Ticket(s): Work towards this ticket: https://app.clickup.com/t/8689xkbwm, though more work is likely necessary.

This PR is starting the work towards consolidating the model and state checkpointing on both the server and client side. On the server side, we're pulling the hydration responsibilities into the model checkpointing containers. That way the server need not "know" the process required to inject parameters into a model.

It also means that there aren't several connected but disparate pieces that need to be provided to the servers/clients, you instantiate these models and send them to the server to make things work.

Tests Added

Migrated existing tests to work with this new system.

@emersodb emersodb marked this pull request as ready for review November 28, 2024 19:55
Copy link
Collaborator

@scarere scarere left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So as I understand it, we've defined two base types for checkpointing. Module (Torch) checkpointers and state checkpointers. Furthermore we defined client and server checkpoint module base classes which pack together the module checkpointers (is the term module being reused here?) and the state checkpointers. I think as a long term goal for a future PR it would be great to unify the state and module checkpointers. However I understand that needing to know which checkpointer to use for loading the state complicates that. Although I know this wasn't part of this PR, I was looking to the packing exchangers and parameter exchangers and am a little confused as to why so much structure and subclassing is needed. Additionally with the server checkpointing modules, there seems to be a large number of subclasses where the primary difference is just the argument passed to the parent class. Is it possible to get rid of these?

self,
model: nn.Module | None = None,
parameter_exchanger: ExchangerType | None = None,
model_checkpointers: CheckpointModuleInput = None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So it seems the on thing preventing us from merging the model checkpointing and state checkpointing is the load state method. I'm wondering if there is any way around this so that we can fold all these classes and subclasses into one.

raise ValueError("Attempting to load state, but no state checkpointer is specified")


class PackingServerCheckpointAndAndStateModule(BaseServerCheckpointAndStateModule):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the only difference between all these subclasses is the parameter exchanger, why the need for the wrapper classes? Why not just let the user initialize the base class with a parameter exchanger?

"""
if parameter_exchanger is not None:
assert isinstance(
parameter_exchanger, FullParameterExchangerWithPacking
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why the need for all the structure and classes with the parameter and packing exchangers? From what I can tell we're trying to allow users to pass additional information when pushing and pulling the parameters. They define a packer, which is passed to the packing exchanger subclass which now is also passed here. There just seems to be a lot of classes involved. Can we just have all the packer classes be subclasses of the parameter exchanger? Is there a way to reduce the number of classes and subclasses across the state checkpointers and module checkpointers?

@@ -34,31 +40,37 @@ def __init__(
pre_aggregation (CheckpointModuleInput, optional): If defined, this checkpointer (or sequence of
checkpointers) is used to checkpoint models based on their validation metrics/losses **BEFORE**
server-side aggregation. Defaults to None.
post_aggregation (CheckpointModuleInput, optional], optional): If defined, this checkpointer (or sequence
post_aggregation (CheckpointModuleInput, optional): If defined, this checkpointer (or sequence
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there an instance where the post-aggregation client model is different from the global model?


CheckpointModuleInput = Optional[Union[TorchCheckpointer, Sequence[TorchCheckpointer]]]
CheckpointModuleInput = Union[TorchModuleCheckpointer, Sequence[TorchModuleCheckpointer]] | None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this type needed? Might it be clearer to just have the more explicit typing. If not maybe we can change this type name to something that communicates that there can be multiple checkpointers

SparseCooParameterPacker,
)

CheckpointModuleInput = Union[TorchModuleCheckpointer, Sequence[TorchModuleCheckpointer]] | None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment here with the typing, is it better to be more explicit by removing this type definition?

Base automatically changed from dbe/server_stores_config to main December 13, 2024 13:54
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants