A few fixes to our Examples #300
Conversation
…l servers, dropping some print statements
I moved these to our examples/assets folder instead of the top level assets folder.
@@ -6,10 +6,6 @@ The server has some custom metrics aggregation and uses Federated Averaging as i

As this is a warm-up training for consecutive runs with different Federated Learning (FL) algorithms, it is crucial to set a fixed seed for both clients and the server to ensure uniformity in random data points across these runs. Therefore, we make sure to set a fixed seed for these consecutive runs in both the `client.py` and `server.py` files. Additionally, it is important to establish a checkpointing strategy for the clients using their randomly generated unique client names. This allows us to load each client's warmed-up model from this example in further instances. In this particular scenario, we set the checkpointing strategy to save the latest model. This ensures that we can load the trained local model for each client from this example in subsequent runs as a warmed-up model.

### Weights and Biases Reporting
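The fixed-seed requirement described above can be sketched roughly as follows (a minimal sketch, not the repository's actual helper; in the real example `torch.manual_seed` would be set as well):

```python
import random

import numpy as np


def set_all_random_seeds(seed: int) -> None:
    """Fix the RNG seeds so consecutive FL runs see the same random data points."""
    # The real example would also call torch.manual_seed(seed); it is
    # omitted here to keep the sketch dependency-light.
    random.seed(seed)
    np.random.seed(seed)


# Both client.py and server.py would call this before any data loading,
# so the warm-up run and the follow-up run draw identical data points.
set_all_random_seeds(2023)
```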
Dropping this part, since it's not really necessary for this example
@@ -5,11 +5,3 @@ n_server_rounds: 2 # The number of rounds to run FL
n_clients: 3 # The number of clients in the FL experiment
local_epochs: 1 # The number of epochs to complete for client
batch_size: 128 # The batch size for client training
Dropping this part, since it's not really necessary for this example
@@ -6,10 +6,6 @@ The server has some custom metrics aggregation and uses FedProx as its server-si

After the warm-up training, clients can load their warmed-up models and continue training with the FedProx algorithm. To maintain consistency in the data loader between both runs, it is crucial to set a fixed seed for both clients and the server, ensuring uniformity in random data points across consecutive runs. Therefore, we ensure a fixed seed is set for these consecutive runs in both the `client.py` and `server.py` files. Additionally, to load the warmed-up models, it's important to provide the path to the pretrained models based on the client's unique name, ensuring that we can load the trained local model for each client from the previous example as a warmed-up model. Since the models in the two runs can differ, loading weights from the pretrained model requires a mapping between the pretrained model and the model used in FL training. This mapping is accomplished through the `weights_mapping.json` file, which contains the names of the pretrained model's layers and the corresponding names of the layers in the model used in FL training.

### Weights and Biases Reporting
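The layer-name remapping that `weights_mapping.json` enables can be sketched like this (the layer names and file contents below are hypothetical, purely for illustration; the real file maps the pretrained model's layer names to the FL model's layer names):

```python
import json


def remap_pretrained_weights(pretrained_state: dict, mapping: dict) -> dict:
    """Rename pretrained layers to the FL model's layer names before loading."""
    return {
        mapping[name]: value
        for name, value in pretrained_state.items()
        if name in mapping
    }


# Hypothetical weights_mapping.json contents: pretrained name -> FL model name.
mapping = json.loads('{"encoder.weight": "feature_extractor.weight"}')

pretrained_state = {"encoder.weight": [0.1, 0.2], "head.weight": [0.3]}
remapped = remap_pretrained_weights(pretrained_state, mapping)
# remapped == {"feature_extractor.weight": [0.1, 0.2]};
# layers without a mapping entry are skipped.
```

In the actual example the values would be tensors and the result would be passed to the FL model's `load_state_dict`.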
Dropping this part, since it's not really necessary for this example
@@ -14,11 +14,3 @@ proximal_weight_patience : 5 # The number of rounds to wait before increasing or
n_clients: 3 # The number of clients in the FL experiment
local_epochs: 1 # The number of epochs to complete for client
batch_size: 128 # The batch size for client training
Dropping this part, since it's not really necessary for this example
@@ -30,9 +30,11 @@ def __init__(
    metrics: Sequence[Metric],
    device: torch.device,
    checkpoint_dir: str,
    client_name: str,
Making this an argument you pass is easier to manage than leaving it up to the generated hash and trying to match things.
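The predictability this buys can be sketched roughly as follows (the file-name pattern and helper here are hypothetical, not the repository's actual checkpointing code):

```python
from pathlib import Path


def checkpoint_path(checkpoint_dir: str, client_name: str) -> Path:
    """Build a checkpoint path from an explicitly passed client name, so the
    warm-up run and later runs agree on the file without matching generated
    hashes or UUIDs."""
    return Path(checkpoint_dir) / f"client_{client_name}_latest.pkl"


# With an explicit name, the path is stable across runs.
path = checkpoint_path("examples/assets/checkpoints", "client_0")
```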
@@ -29,19 +28,17 @@ def __init__(
    data_path: Path,
    metrics: Sequence[Metric],
    device: torch.device,
    pretrained_model_dir: Path,
    pretrained_model_path: Path,
Moving to just specifying the model path rather than trying to infer it from the client name, which can no longer be fixed by the random seed (they use UUIDs under the hood).
@@ -31,7 +30,7 @@ def __init__(
    data_path: Path,
    metrics: Sequence[Metric],
    device: torch.device,
    pretrained_model_dir: Path,
    pretrained_model_path: Path,
Moving to just specifying the model path rather than trying to infer it from the client name, which can no longer be fixed by the random seed (they use UUIDs under the hood).
@@ -132,6 +132,8 @@ def configure_fit(
    if self.on_fit_config_fn is not None:
        # Custom fit config function provided
        config = self.on_fit_config_fn(server_round)
    else:
In a few examples, we don't specify a config, but we assume that `current_server_round` is always present. So this (and those below) ensure that it is in there.
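A standalone sketch of the patched control flow (the real code lives in a Flower strategy's `configure_fit`; this just mirrors the logic):

```python
def build_fit_config(server_round: int, on_fit_config_fn=None) -> dict:
    """Build the per-round fit config, guaranteeing current_server_round is set."""
    if on_fit_config_fn is not None:
        # Custom fit config function provided.
        config = on_fit_config_fn(server_round)
    else:
        # No config function: start from an empty config instead of skipping it,
        # so the key below is always present for the clients.
        config = {}
    config["current_server_round"] = server_round
    return config
```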
@@ -72,7 +72,7 @@ def __init__(
    transform: Callable | None = None,
    target_transform: Callable | None = None,
) -> None:
    assert targets is not None, "SslTensorDataset targets must be None"
    assert targets is None, "SslTensorDataset targets must be None"
This was a bug. See the error message.
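To see why the original condition was a bug: it contradicted its own error message, raising on exactly the input it was meant to allow. A minimal illustration (the helper name is hypothetical):

```python
def validate_ssl_targets(targets) -> None:
    # Corrected check: a self-supervised dataset derives targets from the data
    # itself, so callers must not pass explicit targets.
    assert targets is None, "SslTensorDataset targets must be None"


validate_ssl_targets(None)  # valid input passes
# With the old `targets is not None` condition, this valid call raised instead.
```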
For some reason this was in the smoke tests folder, but we don't run the example in the smokes.
I'll let other people who are more knowledgeable on the examples approve this. I'm familiar mainly with the nnunet example, and it seems like no major changes were made there.
Looks good to me!
PR Type
Fix
Short Description
Clickup Ticket(s): N/A
I went through all of our examples to make sure they run correctly. Generally everything still works great. There were just a few small bugs that I patched here.
Most of the changes in this PR are just me moving our example servers to not accept failures. That way we don't have zombie processes for anyone running the examples if something weird happens.
I also dropped a few print statements and moved a few files.
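The "not accept failures" change mirrors the behavior of Flower's `accept_failures` flag on server strategies; a rough standalone sketch of that behavior (function name and message are hypothetical):

```python
def aggregate_round(results: list, failures: list, accept_failures: bool = False) -> list:
    """With accept_failures=False, any client failure aborts the round
    immediately, so the server process exits instead of lingering as a
    zombie while waiting on dead clients."""
    if failures and not accept_failures:
        raise RuntimeError(f"{len(failures)} client(s) failed; shutting down the server")
    return results
```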
Tests Added
N/A