
Create Fixed Requirements File for FLamby, Update Dynamic Weight Exchanger and FedOpt Example #68

Merged
emersodb merged 7 commits into main on Nov 18, 2023

Conversation

emersodb (Collaborator) commented Nov 1, 2023

This PR addresses three small tickets: Ticket 1, Ticket 2, Ticket 3.

The first ticket is just to create a pinned FLamby requirements file that can be used to build the FLamby/FL4Health environment for running the FLamby experiments. This is in response to a new MONAI release that breaks the FLamby code for FedIsic (FLamby has not pinned the MONAI version in their repo). I've tested the pinned file and it appears to work.

The second ticket adds a dedicated dynamic weight exchange client that handles the get- and set-parameters functionality required to perform dynamic weight exchange.
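Roughly, the shape of such a client is the following (a hypothetical sketch only; the class name, exchanger methods, and signatures here are assumptions rather than the exact FL4Health API):

from typing import Any, Dict, List

import numpy as np
import torch.nn as nn

NDArrays = List[np.ndarray]


class DynamicWeightExchangeClient:
    """Hypothetical sketch: a client that delegates which weights to exchange
    to a parameter exchanger rather than always sending the full model."""

    def __init__(self, model: nn.Module, parameter_exchanger: Any) -> None:
        self.model = model
        self.parameter_exchanger = parameter_exchanger

    def get_parameters(self, config: Dict[str, Any]) -> NDArrays:
        # The exchanger decides which weights (for example, those that drifted
        # the most this round) should be pushed to the server.
        return self.parameter_exchanger.push_parameters(self.model, config=config)

    def set_parameters(self, parameters: NDArrays, config: Dict[str, Any]) -> None:
        # The exchanger maps the (possibly partial) payload back into the model.
        self.parameter_exchanger.pull_parameters(parameters, self.model, config=config)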

The third ticket migrates the FedOpt example to our new BasicClient structure, which removes a bunch of code duplication and simplifies the code. This was one of our last holdout examples from before the refactor. In performing the migration, I ditched the news classification dataset that had been used, as it wasn't a great example task, and instead moved everything over to the AG News dataset that Yuchong had worked with. So there are some minor changes that relate to that move as well.

@@ -2,14 +2,15 @@
from abc import ABC, abstractmethod
emersodb (Collaborator Author) commented:

The changes in this file simply relate to moving from the previous news classification dataset to the AG News dataset.

loss_meter_type: LossMeterType = LossMeterType.AVERAGE,
checkpointer: Optional[TorchCheckpointer] = None,
) -> None:
super(BasicClient, self).__init__(data_path, device)
emersodb (Collaborator Author) commented:

Note: This is a slight hack that likely won't be necessary with additional refactors to the metrics manager classes. Since I'm not explicitly defining a set of metrics for my MetricMeter class, I skip over the __init__ function of the BasicClient class and go directly to the NumpyFl parent's __init__.
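The pattern in question looks roughly like this (a schematic sketch with placeholder classes, not the actual FL4Health hierarchy):

class NumpyFlClient:  # grandparent: holds only the minimal client state
    def __init__(self, data_path: str, device: str) -> None:
        self.data_path = data_path
        self.device = device


class BasicClient(NumpyFlClient):  # parent: also expects metrics to be defined
    def __init__(self, data_path: str, device: str, metrics: list) -> None:
        super().__init__(data_path, device)
        self.metrics = metrics


class FedOptClient(BasicClient):
    def __init__(self, data_path: str, device: str) -> None:
        # Deliberately bypass BasicClient.__init__ (which requires a metrics
        # definition) and call the grandparent constructor directly.
        super(BasicClient, self).__init__(data_path, device)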

Reviewer (Collaborator) commented:

Just for future reference: for #69, I slightly refactored how we handle accumulating and computing metrics, so I decided to base that work off this branch and update it so that we call the BasicClient constructor, saving us from having to redefine a lot of attributes.

train_loader, validation_loader, num_examples, weight_matrix = construct_dataloaders(
self.data_path, vocabulary, label_encoder, sequence_length, batch_size
train_loader, validation_loader, _, weight_matrix = construct_dataloaders(
self.data_path, self.vocabulary, self.label_encoder, sequence_length, self.batch_size
emersodb (Collaborator Author) commented:

Note: self.vocabulary and self.label_encoder are initialized in setup_client before the call to super().setup_client() to ensure they are available when the parent needs them.
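Schematically, the ordering being described is something like this (placeholder parent class and data; not the exact example code):

class ParentClient:
    def setup_client(self, config: dict) -> None:
        # In the real BasicClient this constructs data loaders, the model, etc.
        self.train_loader, self.val_loader = self.get_data_loaders(config)


class AgNewsClient(ParentClient):
    def setup_client(self, config: dict) -> None:
        # Initialize these before delegating upward, because the parent's
        # setup_client ends up calling get_data_loaders, which needs them.
        self.vocabulary = {"<unk>": 0, "the": 1}  # placeholder vocabulary
        self.label_encoder = {"World": 0, "Sports": 1, "Business": 2, "Sci/Tech": 3}
        super().setup_client(config)

    def get_data_loaders(self, config: dict):
        assert self.vocabulary is not None and self.label_encoder is not None
        return [], []  # placeholder loaders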

Reviewer (Collaborator) commented:

Maybe add a comment for future reference, for when people are looking at the code?

emersodb (Collaborator Author) commented:

Good call.

@@ -1,12 +1,12 @@
# Parameters that describe server
emersodb (Collaborator Author) commented:

Just some changes to shorten the run time.

@@ -75,54 +75,69 @@ def compute_metrics(self) -> Metrics:
return metrics


class ClientMetrics:
class CustomMetricMeter(MetricMeter):
emersodb (Collaborator Author) commented:

Defining our own MetricMeter, as we're going to use a custom class to accrue metric information.
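As a generic sketch of what such a meter does (the method names here are assumptions, not the actual MetricMeter interface):

from typing import Dict, List

import numpy as np


class CustomMetricMeter:
    """Sketch: accumulate raw predictions and targets across batches, then
    compute metrics once over everything that was accrued."""

    def __init__(self) -> None:
        self.predictions: List[np.ndarray] = []
        self.targets: List[np.ndarray] = []

    def update(self, logits: np.ndarray, targets: np.ndarray) -> None:
        self.predictions.append(logits)
        self.targets.append(targets)

    def compute(self) -> Dict[str, float]:
        preds = np.concatenate(self.predictions).argmax(axis=1)
        targets = np.concatenate(self.targets)
        return {"accuracy": float((preds == targets).mean())}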

Reviewer (Collaborator) commented:

Just for future reference, in #69 I based off this branch and made some slight updates to the code in this file. First, by having metrics store their own state, we eliminate the need for MetricMeters. I updated this to be a "Compound Metric", which stores the information relevant to computing one or more related metrics. All in all, it does not change very much though.

@@ -53,20 +53,8 @@ def metric_aggregation(all_client_metrics: List[Tuple[int, Metrics]]) -> Metrics
return server_metrics.compute_metrics()


def fit_metrics_aggregation_fn(all_client_metrics: List[Tuple[int, Metrics]]) -> Metrics:
emersodb (Collaborator Author) commented:

These functions were redundant; we might as well just use the metric_aggregation function rather than having both of these "encapsulations".

@@ -18,9 +18,9 @@ def __init__(self, vocab_size: int, vocab_dimension: int = 128, lstm_dimension:
bidirectional=True,
)
self.drop = nn.Dropout(p=0.3)
self.fc = nn.Linear(2 * lstm_dimension, 41)
self.fc = nn.Linear(2 * lstm_dimension, 4)
emersodb (Collaborator Author) commented:

Moving from the previous news classification task, which had 41 labels, to the AG News task, which only has 4.


def forward(self, x: torch.Tensor, hidden: torch.Tensor) -> torch.Tensor:
def forward(self, x: torch.Tensor, hidden: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
emersodb (Collaborator Author) commented:

Not really sure why this type annotation had passed mypy previously, but hidden should be a tuple of tensors...
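For context, PyTorch's nn.LSTM takes and returns its recurrent state as an (h, c) pair of tensors, which is why the annotation should be a tuple:

import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=8, hidden_size=16, batch_first=True, bidirectional=True)
x = torch.randn(2, 5, 8)  # (batch, sequence, features)

# The state is a tuple (h_0, c_0); bidirectional doubles the first dimension.
h0 = torch.zeros(2, 2, 16)
c0 = torch.zeros(2, 2, 16)
output, (h_n, c_n) = lstm(x, (h0, c0))
print(output.shape)  # torch.Size([2, 5, 32])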

# stores the values of the new model parameters at the beginning of each training round.
self._align_model_parameters(self.initial_model, self.model)

def _align_model_parameters(self, initial_model: nn.Module, target_model: nn.Module) -> None:
emersodb (Collaborator Author) commented:

@yc7z: I feel like, at some point, we had a good reason for this function, but I couldn't quite see why we needed it anymore. It was only called on line 132 above, and it appears that it was setting the initial_model values before training, to use as a comparison against the trained model after local training. However, if initial_model has the same architecture as self.model, which is set using the server-passed weights in set_parameters, I don't see why we need the for loops in this function. So I put
self.initial_model.load_state_dict(self.model.state_dict(), strict=True)
at the end of set_parameters in the new refactor. Let me know if that makes sense to you or if I'm missing something.
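In sketch form, the proposed flow is (simplified; the real set_parameters also goes through the parameter exchanger):

from typing import List

import numpy as np
import torch
import torch.nn as nn


def set_parameters(model: nn.Module, initial_model: nn.Module, parameters: List[np.ndarray]) -> None:
    # Load the server-provided weights into the training model.
    state_dict = {
        key: torch.tensor(value) for key, value in zip(model.state_dict().keys(), parameters)
    }
    model.load_state_dict(state_dict, strict=True)
    # Snapshot the freshly loaded weights into initial_model so drift can be
    # measured after local training, replacing the old per-parameter copy loop.
    initial_model.load_state_dict(model.state_dict(), strict=True)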

Reviewer (Collaborator) commented:

Yes, this makes sense. Originally we had this method because I was using self.pull_parameters() to update self.initial_model, which was incorrect since the parameters pulled would not be the full model weights. But I think we could have simply used load_state_dict instead back then.

@@ -53,13 +53,6 @@ def construct_config(
}


def construct_eval_config(
emersodb (Collaborator Author) commented:

This wasn't called anywhere, so I removed it.

def __init__(
self,
name: str = "F1 score",
average: Optional[str] = "weighted",
emersodb (Collaborator Author) commented:

Just adding some additional functionality to allow different averaging approaches for the F1 calculation.
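If the underlying computation is scikit-learn's f1_score (an assumption here), the averaging options map onto its average argument; a standalone sketch of the idea, with the surrounding metric interface assumed:

from typing import Optional

import numpy as np
from sklearn.metrics import f1_score


class F1:
    """Sketch of an F1 metric with a configurable averaging strategy."""

    def __init__(self, name: str = "F1 score", average: Optional[str] = "weighted") -> None:
        self.name = name
        # One of "micro", "macro", "weighted", etc.; per-class scores
        # (average=None) are not handled in this sketch.
        self.average = average

    def __call__(self, logits: np.ndarray, targets: np.ndarray) -> float:
        preds = logits.argmax(axis=1)
        return float(f1_score(targets, preds, average=self.average))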

jewelltaylor mentioned this pull request on Nov 4, 2023
# Initial model parameters to be used in calculating weight shifts during training
self.initial_model: nn.Module
# Parameter exchanger to be used in server-client exchange of dynamic layers.
self.parameter_exchanger: NormDriftParameterExchanger
Reviewer (Collaborator) commented:

One thought I had is that the dynamic weight exchanger could be more general than the norm drift parameter exchanger it's constrained to using. We may want to define other parameter exchangers that select a certain set of weights to exchange each FL round according to some criterion. However, lacking a concrete alternative parameter exchanger, perhaps we can save it for future work.
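One way to picture that generalization (purely illustrative; none of these class names exist in the library):

from abc import ABC, abstractmethod
from typing import List, Tuple

import numpy as np
import torch
import torch.nn as nn


class SelectiveParameterExchanger(ABC):
    """Illustrative base: each FL round, pick a subset of weights to exchange
    according to some criterion."""

    @abstractmethod
    def select_parameters(
        self, model: nn.Module, initial_model: nn.Module
    ) -> Tuple[List[np.ndarray], List[str]]:
        """Return the selected weights along with the names identifying them."""
        ...


class NormDriftExchanger(SelectiveParameterExchanger):
    """One concrete criterion: exchange weights whose drift since the start of
    the round exceeds a norm threshold."""

    def __init__(self, norm_threshold: float) -> None:
        self.norm_threshold = norm_threshold

    def select_parameters(
        self, model: nn.Module, initial_model: nn.Module
    ) -> Tuple[List[np.ndarray], List[str]]:
        selected, names = [], []
        initial_params = dict(initial_model.named_parameters())
        with torch.no_grad():
            for name, param in model.named_parameters():
                drift = torch.linalg.norm(param - initial_params[name])  # 2-norm drift
                if drift > self.norm_threshold:
                    selected.append(param.detach().cpu().numpy())
                    names.append(name)
        return selected, names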

emersodb (Collaborator Author) commented:

Yes, we have a ticket in the backlog to tackle making this more generic! Just haven't gotten to it yet.

filter_by_percentage = self.narrow_config_type(config, "filter_by_percentage", bool)
norm_threshold = self.narrow_config_type(config, "norm_threshold", float)
exchange_percentage = self.narrow_config_type(config, "exchange_percentage", float)
parameter_exchanger = NormDriftParameterExchanger(
Reviewer (Collaborator) commented:

We may also want the user to be able to configure which norm we use, but default to the 2-norm, since it is probably what people will want most of the time.

emersodb (Collaborator Author) commented:

Yup definitely!

jewelltaylor (Collaborator) left a review comment:

For the most part, this looks good to me! I left a few very minor comments for us to discuss further. Some just relate to how I build on this PR in my metric refactor PR and address some of the comments you left here. I don't imagine it will involve substantial changes though, so I am going to go ahead and approve, but I will keep an eye out for the subsequent discussion/changes :)

emersodb merged commit 9c29924 into main on Nov 18, 2023
2 checks passed
emersodb deleted the dbe/created_requirements_file_for_flamby branch on November 18, 2023 at 17:44