-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Update apfl client #62
Conversation
…y. Make sure update_after_step is being passed the actual_step not just local step in each epoch.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think most of the changes look good, but I think we might want to generalize several of the functions to better admit the auxiliary pieces that APFL produces, rather than cutting them out. For example the auxiliary losses could be computed in a similar way to FedProx. I also think there are a few places we can expand the definitions of BasicClient functions (get_optimizer, predict) that would help APFL fit better into the flow, while minimally affecting our other implementations.
…tom compute loss for APFL. Return dict with personal, global and local preds from ApflClient.predict. Make some changes to BasicClient for this all to work. Added some comments
… examlples accordingly. Remove no longer relevant split optimizer test
@emersodb I made the changes to the MetricMeterManager and its initialization in the clients as we discussed this morning. Let me know if they look okay |
@@ -24,7 +25,36 @@ def __init__( | |||
metric_meter_type: MetricMeterType = MetricMeterType.AVERAGE, | |||
checkpointer: Optional[TorchCheckpointer] = None, | |||
) -> None: | |||
super().__init__(data_path, metrics, device, loss_meter_type, metric_meter_type, checkpointer) | |||
super(BasicClient, self).__init__(data_path, device) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I had to avoid using the the BasicClient constructor so that we can properly initialize the MetricMeterManager. In the follow up ticket where we have the user pass the MetricMeterManager, we can probably revert this back to using BasicClient constructor and get rid of the duplicate code. The follow up ticket has been added to clickup and I have assigned it to myself: https://app.clickup.com/t/86860zbdk
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Rather than having the user construct MetricMeterManagers, we could have them just pass a set of Meter objects and have the constructor run a function called setup_metric_meter_managers(meters: Sequence[Meters])
. By default it runs the setup in BasicClient and in APFL, we would just override the function. If you like the idea of constructing MetricMeterManagers though, that would work too.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, that is also another option to consider! I will see what ends up being the cleanest and go with that. Will focus on this in the next PR
"global": MetricMeter.get_meter_by_type(self.metrics, metric_meter_type, "val_meter_global"), | ||
"personal": MetricMeter.get_meter_by_type(self.metrics, metric_meter_type, "val_meter_personal"), | ||
} | ||
self.val_metric_meter_mngr = MetricMeterManager(val_key_to_meter_map) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We don't have to do it in this PR, since we're going to follow this up with more changes to fold in the MetricMeterManager
, but a clean way to do these constructions would be useful 🙂
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah i will reserve that for the next PR since we are having the user pass MetricMeterManager and they way we construct them is likely to change
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good to go. Great changes!
PR Type
[Feature| Fix | Documentation | Other() ]
Short Description
Clickup Story: This story is a follow up from the client refactor to have
ApflClient
inherit fromBasicClient
. I made the call not to record the losses and metrics of the global and local models, only the personal model. This is in line with FENDA and simplifies things a lot. I had to make some small changes to theApflModule
along the way.Just wanted to get this up quickly before going on vacation next week :) No rush reviewing it.
Tests Added