Create Fixed Requirements File for FLamby, Update Dynamic Weight Exchanger and FedOpt Example #68
Conversation
…he readme, as well as creating a dynamic weight exchange client.
…he AG news dataset.
@@ -2,14 +2,15 @@
from abc import ABC, abstractmethod
The changes in this file simply relate to moving from the previous news classification dataset to the AG News dataset.
loss_meter_type: LossMeterType = LossMeterType.AVERAGE,
checkpointer: Optional[TorchCheckpointer] = None,
) -> None:
    super(BasicClient, self).__init__(data_path, device)
Note: This is a slight hack that likely won't be necessary with additional refactors to the metrics manager classes. Since I'm not explicitly defining a set of metrics for my `MetricMeter` class, I am skipping over the `__init__` function of the `BasicClient` class and going directly to the `NumpyFl` parent's `__init__`.
Just for future reference, for #69 I slightly refactored how we handle accumulating and computing metrics, so I decided to base off this branch and update it so we call the BasicClient constructor, and save us from having to redefine a lot of attributes.
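Roughly, the grandparent-constructor pattern being discussed looks like this (class names and signatures are simplified stand-ins, not the repo's actual API):

```python
from typing import Any


class NumPyFlClient:
    def __init__(self, data_path: str, device: str) -> None:
        self.data_path = data_path
        self.device = device


class BasicClient(NumPyFlClient):
    def __init__(self, data_path: str, metrics: Any, device: str) -> None:
        super().__init__(data_path, device)
        # BasicClient expects an explicit metrics definition.
        self.metrics = metrics


class CustomClient(BasicClient):
    def __init__(self, data_path: str, device: str) -> None:
        # Skip BasicClient.__init__ (which requires metrics) and invoke the
        # grandparent constructor directly via super(BasicClient, self).
        super(BasicClient, self).__init__(data_path, device)
```

With the #69 refactor described above, the subclass would call the `BasicClient` constructor normally instead, avoiding this skip.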
train_loader, validation_loader, num_examples, weight_matrix = construct_dataloaders(
    self.data_path, vocabulary, label_encoder, sequence_length, batch_size
train_loader, validation_loader, _, weight_matrix = construct_dataloaders(
    self.data_path, self.vocabulary, self.label_encoder, sequence_length, self.batch_size
Note: `self.vocabulary` and `self.label_encoder` are initialized in `setup_client` before the call to `super().setup_client()` to ensure their availability.
Maybe add a comment, for future reference, for when people are looking at the code?
Good call.
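For reference, the ordering constraint could be sketched like this (all class names and values are illustrative stand-ins):

```python
class TextClient:
    """Illustrative parent whose setup relies on subclass-provided attributes."""

    def setup_client(self, config: dict) -> None:
        # The parent's setup assumes vocabulary/label_encoder already exist.
        assert hasattr(self, "vocabulary") and hasattr(self, "label_encoder")
        self.batch_size = config["batch_size"]


class AgNewsClient(TextClient):
    def setup_client(self, config: dict) -> None:
        # Initialize vocabulary and label_encoder *before* delegating upward
        # so they are available to the parent's setup (values are stand-ins).
        self.vocabulary = {"<unk>": 0, "news": 1}
        self.label_encoder = {"World": 0, "Sports": 1, "Business": 2, "Sci/Tech": 3}
        super().setup_client(config)
```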
@@ -1,12 +1,12 @@
# Parameters that describe server
Just some changes to shorten the run time.
@@ -75,54 +75,69 @@ def compute_metrics(self) -> Metrics:
    return metrics

class ClientMetrics:
class CustomMetricMeter(MetricMeter):
Defining our own MetricMeter, as we're going to use a custom class to accrue metric information.
Just for future reference, in #69 I based off this branch and made some slight updates to the code in this file. First, by having metrics store their own state, we eliminate the need for having MetricMeters. I updated this to be a "Compound Metric" which stores information that is relevant to compute one or more related metrics. All in all, it does not change very much though.
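The "metrics store their own state" idea can be sketched roughly as follows (illustrative, not the actual #69 code):

```python
from typing import Sequence


class Accuracy:
    """Sketch of a metric that accumulates its own state across batches,
    removing the need for a separate MetricMeter (names are illustrative)."""

    def __init__(self) -> None:
        self.correct = 0
        self.total = 0

    def update(self, preds: Sequence[int], targets: Sequence[int]) -> None:
        # Accumulate per-batch counts rather than storing raw predictions.
        self.correct += sum(int(p == t) for p, t in zip(preds, targets))
        self.total += len(targets)

    def compute(self) -> float:
        return self.correct / self.total

    def clear(self) -> None:
        # Reset state between FL rounds.
        self.correct = 0
        self.total = 0
```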
@@ -53,20 +53,8 @@ def metric_aggregation(all_client_metrics: List[Tuple[int, Metrics]]) -> Metrics
    return server_metrics.compute_metrics()

def fit_metrics_aggregation_fn(all_client_metrics: List[Tuple[int, Metrics]]) -> Metrics:
These functions were redundant; we might as well just use the `metric_aggregation` function rather than having both of these "encapsulations".
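A single weighted aggregation path might look roughly like this (a sketch; `Metrics` is simplified to a plain dict of floats here):

```python
from collections import defaultdict
from typing import Dict, List, Tuple


def metric_aggregation(all_client_metrics: List[Tuple[int, Dict[str, float]]]) -> Dict[str, float]:
    # Weight each client's metrics by its example count, then normalize.
    totals: Dict[str, float] = defaultdict(float)
    total_examples = 0
    for num_examples, metrics in all_client_metrics:
        total_examples += num_examples
        for name, value in metrics.items():
            totals[name] += num_examples * value
    return {name: value / total_examples for name, value in totals.items()}
```

The same function can then be passed wherever a fit- or evaluate-side aggregation callable is expected, instead of wrapping it twice.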
@@ -18,9 +18,9 @@ def __init__(self, vocab_size: int, vocab_dimension: int = 128, lstm_dimension:
    bidirectional=True,
)
self.drop = nn.Dropout(p=0.3)
self.fc = nn.Linear(2 * lstm_dimension, 41)
self.fc = nn.Linear(2 * lstm_dimension, 4)
Moving from the previous news classification task, which had 41 labels, to the AG News task, which only has 4.
def forward(self, x: torch.Tensor, hidden: torch.Tensor) -> torch.Tensor:
def forward(self, x: torch.Tensor, hidden: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
Not really sure why this type annotation had passed mypy previously, but hidden should be a tuple of tensors...
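For reference, a simplified sketch of the model around these two changes (defaults and structure are assumptions, not the exact repo code). `nn.LSTM` hidden state is an `(h_0, c_0)` tuple, hence the corrected annotation, and the bidirectional LSTM doubles the feature dimension feeding the 4-class head:

```python
from typing import Tuple

import torch
import torch.nn as nn


class NewsClassifier(nn.Module):
    # Sketch approximating the example model; names and defaults are assumptions.
    def __init__(self, vocab_size: int, vocab_dimension: int = 128, lstm_dimension: int = 64) -> None:
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, vocab_dimension)
        self.lstm = nn.LSTM(vocab_dimension, lstm_dimension, batch_first=True, bidirectional=True)
        self.drop = nn.Dropout(p=0.3)
        # AG News has 4 classes, so the head maps 2 * lstm_dimension -> 4.
        self.fc = nn.Linear(2 * lstm_dimension, 4)

    def forward(self, x: torch.Tensor, hidden: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        # hidden is the (h_0, c_0) pair, each of shape
        # (num_layers * num_directions, batch, lstm_dimension).
        embedded = self.embedding(x)
        output, _ = self.lstm(embedded, hidden)
        # Classify from the last timestep's bidirectional features.
        return self.fc(self.drop(output[:, -1, :]))
```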
# stores the values of the new model parameters at the beginning of each training round.
self._align_model_parameters(self.initial_model, self.model)

def _align_model_parameters(self, initial_model: nn.Module, target_model: nn.Module) -> None:
@yc7z: I feel like, at some point, we had a good reason for this function, but I couldn't quite see why we needed it anymore. It was only called on line 132 above and it appears that it was setting the `initial_model` values before training to use as a comparison to the trained model after local training. However, if `initial_model` is the same architecture as `self.model`, which is set using the server-passed weights in `set_parameters`, I don't see why we need the for loops in this function. So I put `self.initial_model.load_state_dict(self.model.state_dict(), strict=True)` at the end of `set_parameters` in the new refactor. Let me know if that makes sense to you or if I'm missing something.
Yes, this makes sense. Originally we had this method because I was using `self.pull_parameters()` to update `self.initial_model`, which was incorrect since the parameters pulled would not be the full model weights. But I think we could have simply used `load_state_dict` instead back then.
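A tiny sketch of the `load_state_dict` approach (with stand-in `nn.Linear` models in place of the actual architecture):

```python
import torch
import torch.nn as nn

# Two models with an identical architecture; copying the full weights with
# load_state_dict replaces the manual per-parameter alignment loop.
model = nn.Linear(4, 2)
initial_model = nn.Linear(4, 2)
initial_model.load_state_dict(model.state_dict(), strict=True)
```

`strict=True` makes the copy fail loudly if the two state dicts' keys ever diverge, which is a useful safety check here.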
@@ -53,13 +53,6 @@ def construct_config(
}

def construct_eval_config(
This wasn't called anywhere, so I removed it.
def __init__(
    self,
    name: str = "F1 score",
    average: Optional[str] = "weighted",
Just adding additional functionality to use different averaging approaches for the F1 calculations.
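To illustrate what the averaging options mean, here is a hand-rolled sketch of per-class F1 combined by either strategy (not the repo's implementation, which presumably delegates to a library):

```python
from collections import Counter
from typing import Optional, Sequence


def f1_with_average(y_true: Sequence[int], y_pred: Sequence[int], average: Optional[str] = "weighted") -> float:
    # "macro" weights every class equally; "weighted" weights each class's
    # F1 by its support (count) in y_true.
    classes = sorted(set(y_true) | set(y_pred))
    support = Counter(y_true)
    f1s = {}
    for c in classes:
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1s[c] = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    if average == "macro":
        return sum(f1s.values()) / len(f1s)
    total = sum(support.values())
    return sum(f1s[c] * support[c] / total for c in classes)
```

On imbalanced labels the two strategies can differ noticeably, which is why exposing the choice is useful.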
…he custom metrics meter class.
# Initial model parameters to be used in calculating weight shifts during training
self.initial_model: nn.Module
# Parameter exchanger to be used in server-client exchange of dynamic layers.
self.parameter_exchanger: NormDriftParameterExchanger
One thought I had is that the dynamic weight exchanger could be more general than the `NormDriftParameterExchanger` it's constrained to using. We may want to define other parameter exchangers that select a certain set of weights to exchange each FL round according to some criterion. However, lacking a concrete alternative parameter exchanger, perhaps we can save it for future work.
Yes, we have a ticket in the backlog to tackle making this more generic! Just haven't gotten to it yet.
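For the record, one way such a drift-based selection criterion could be sketched (hypothetical helper, not the actual `NormDriftParameterExchanger` API):

```python
from typing import Dict, List

import numpy as np


def select_layers_by_drift(
    initial: Dict[str, np.ndarray],
    current: Dict[str, np.ndarray],
    norm_threshold: float,
) -> List[str]:
    # Exchange only layers whose parameters drifted (in 2-norm) beyond a
    # threshold since the round started. A more generic exchanger could
    # swap in other criteria at this point.
    selected = []
    for name, weights in current.items():
        drift = float(np.linalg.norm(weights.ravel() - initial[name].ravel()))
        if drift > norm_threshold:
            selected.append(name)
    return selected
```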
filter_by_percentage = self.narrow_config_type(config, "filter_by_percentage", bool)
norm_threshold = self.narrow_config_type(config, "norm_threshold", float)
exchange_percentage = self.narrow_config_type(config, "exchange_percentage", float)
parameter_exchanger = NormDriftParameterExchanger(
We may also want the user to be able to configure what norm we use but default to the 2-norm since it is probably what people are looking to use most of the time.
Yup definitely!
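A tiny sketch of what a configurable norm might look like (hypothetical helper, not the repo's API):

```python
import numpy as np


def drift_norm(delta: np.ndarray, ord: float = 2) -> float:
    # Default to the 2-norm, but let callers pick e.g. ord=1 or ord=np.inf
    # via config if they want a different drift measure.
    return float(np.linalg.norm(delta.ravel(), ord=ord))
```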
For the most part, this looks good to me! I left a few very minor comments for us to discuss further, some in relation to how I built on this PR for my metric refactor PR and addressed some of the comments you had left here. I don't imagine it will involve substantial changes, though, so I am going to go ahead and approve, but I'll keep an eye out for the subsequent discussion/changes :)
This PR addresses three small tickets: Ticket 1, Ticket 2, Ticket 3.
The first ticket is just to create a pinned FLamby requirements file that can be used to create the FLamby/FL4Health environment to run the FLamby experiments. This is in response to a new release of MONAI that breaks the FLamby code for FedIsic (they have not pinned the MONAI version in their repo). I've tested it and it appears to work.
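For reference, one illustrative way to generate such a pinned file from a known-good environment (the filename and approach are stand-ins, not necessarily what this PR does):

```python
import subprocess
import sys

# Snapshot the exact package versions of the current (working) environment
# into a pinned requirements file, so a new upstream release (e.g. MONAI)
# can't silently break it. Filename is illustrative.
frozen = subprocess.run(
    [sys.executable, "-m", "pip", "freeze"],
    check=True, capture_output=True, text=True,
).stdout
with open("requirements_flamby_fixed.txt", "w") as f:
    f.write(frozen)
```

Recreating the environment later is then `pip install -r requirements_flamby_fixed.txt`.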
The second ticket considers adding a specific dynamic weight exchanger client which handles the get and set parameters functionality required to perform dynamic weight exchange.
The third ticket migrates the FedOpt example to our new BasicClient structure, which removes a bunch of code duplication and simplifies the code. This was one of our last holdout examples from before the refactor. In performing the migration, I ditched the news classification dataset that had been used, as it wasn't a great example task, and instead moved everything over to the AG News dataset that Yuchong had worked with. So there are some minor changes that relate to that move as well.