-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Consolidating model and state checkpointing on the client and server sides. #298
base: main
Are you sure you want to change the base?
Conversation
…ckpointer_consolidation
…ckpointer_consolidation
…ckpointer_consolidation
…ckpointer_consolidation
…ckpointer_consolidation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So as I understand it, we've defined two base types for checkpointing. Module (Torch) checkpointers and state checkpointers. Furthermore we defined client and server checkpoint module base classes which pack together the module checkpointers (is the term module being reused here?) and the state checkpointers. I think as a long term goal for a future PR it would be great to unify the state and module checkpointers. However I understand that needing to know which checkpointer to use for loading the state complicates that. Although I know this wasn't part of this PR, I was looking to the packing exchangers and parameter exchangers and am a little confused as to why so much structure and subclassing is needed. Additionally with the server checkpointing modules, there seems to be a large number of subclasses where the primary difference is just the argument passed to the parent class. Is it possible to get rid of these?
self, | ||
model: nn.Module | None = None, | ||
parameter_exchanger: ExchangerType | None = None, | ||
model_checkpointers: CheckpointModuleInput = None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So it seems the on thing preventing us from merging the model checkpointing and state checkpointing is the load state method. I'm wondering if there is any way around this so that we can fold all these classes and subclasses into one.
raise ValueError("Attempting to load state, but no state checkpointer is specified") | ||
|
||
|
||
class PackingServerCheckpointAndAndStateModule(BaseServerCheckpointAndStateModule): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If the only difference between all these subclasses is the parameter exchanger, why the need for the wrapper classes? Why not just let the user initialize the base class with a parameter exchanger?
""" | ||
if parameter_exchanger is not None: | ||
assert isinstance( | ||
parameter_exchanger, FullParameterExchangerWithPacking |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why the need for all the structure and classes with the parameter and packing exchangers? From what I can tell we're trying to allow users to pass additional information when pushing and pulling the parameters. They define a packer, which is passed to the packing exchanger subclass which now is also passed here. There just seems to be a lot of classes involved. Can we just have all the packer classes be subclasses of the parameter exchanger? Is there a way to reduce the number of classes and subclasses across the state checkpointers and module checkpointers?
@@ -34,31 +40,37 @@ def __init__( | |||
pre_aggregation (CheckpointModuleInput, optional): If defined, this checkpointer (or sequence of | |||
checkpointers) is used to checkpoint models based on their validation metrics/losses **BEFORE** | |||
server-side aggregation. Defaults to None. | |||
post_aggregation (CheckpointModuleInput, optional], optional): If defined, this checkpointer (or sequence | |||
post_aggregation (CheckpointModuleInput, optional): If defined, this checkpointer (or sequence |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there an instance where the post-aggregation client model is different from the global model?
|
||
CheckpointModuleInput = Optional[Union[TorchCheckpointer, Sequence[TorchCheckpointer]]] | ||
CheckpointModuleInput = Union[TorchModuleCheckpointer, Sequence[TorchModuleCheckpointer]] | None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this type needed? Might it be clearer to just have the more explicit typing. If not maybe we can change this type name to something that communicates that there can be multiple checkpointers
SparseCooParameterPacker, | ||
) | ||
|
||
CheckpointModuleInput = Union[TorchModuleCheckpointer, Sequence[TorchModuleCheckpointer]] | None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same comment here with the typing, is it better to be more explicit by removing this type definition?
PR Type
Refactor
Short Description
Clickup Ticket(s): Work towards this ticket: https://app.clickup.com/t/8689xkbwm, though more work is likely necessary.
This PR is starting the work towards consolidating the model and state checkpointing on both the server and client side. On the server side, we're pulling the hydration responsibilities into the model checkpointing containers. That way the server need not "know" the process required to inject parameters into a model.
It also means that there aren't several connected but disparate pieces that need to be provided to the servers/clients, you instantiate these models and send them to the server to make things work.
Tests Added
Migrated existing tests to work with this new system.