Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

A few fixes to our Examples #300

Merged
merged 7 commits into from
Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/ae_examples/cvae_dim_example/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from fl4health.clients.basic_client import BasicClient
from fl4health.preprocessing.autoencoders.dim_reduction import CvaeFixedConditionProcessor
from fl4health.utils.config import narrow_dict_type
from fl4health.utils.load_data import load_mnist_data
from fl4health.utils.load_data import ToNumpy, load_mnist_data
from fl4health.utils.metrics import Accuracy, Metric
from fl4health.utils.random import set_all_random_seeds
from fl4health.utils.sampler import DirichletLabelBasedSampler
Expand All @@ -30,7 +30,7 @@ def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]:
batch_size = narrow_dict_type(config, "batch_size", int)
cvae_model_path = Path(narrow_dict_type(config, "cvae_model_path", str))
sampler = DirichletLabelBasedSampler(list(range(10)), sample_percentage=0.75, beta=100)
transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(torch.flatten)])
transform = transforms.Compose([ToNumpy(), transforms.ToTensor(), transforms.Lambda(torch.flatten)])
# CvaeFixedConditionProcessor is added to the data transform pipeline to encode the data samples
train_loader, val_loader, _ = load_mnist_data(
data_dir=self.data_path,
Expand Down
1 change: 1 addition & 0 deletions examples/ae_examples/cvae_dim_example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def main(config: dict[str, Any]) -> None:
fl_config=config,
strategy=strategy,
checkpoint_and_state_module=checkpoint_and_state_module,
accept_failures=False,
)

fl.server.start_server(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from fl4health.preprocessing.autoencoders.loss import VaeLoss
from fl4health.utils.config import narrow_dict_type
from fl4health.utils.dataset_converter import AutoEncoderDatasetConverter
from fl4health.utils.load_data import load_mnist_data
from fl4health.utils.load_data import ToNumpy, load_mnist_data
from fl4health.utils.metrics import Metric
from fl4health.utils.random import set_all_random_seeds
from fl4health.utils.sampler import DirichletLabelBasedSampler
Expand Down Expand Up @@ -60,7 +60,7 @@ def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]:
batch_size = narrow_dict_type(config, "batch_size", int)
sampler = DirichletLabelBasedSampler(list(range(10)), sample_percentage=0.75, beta=100)
# To make sure pixels stay in the range [0.0, 1.0].
transform = transforms.Compose([transforms.ToTensor()])
transform = transforms.Compose([ToNumpy(), transforms.ToTensor()])
# To train an autoencoder-based model we need to set the data converter.
train_loader, val_loader, _ = load_mnist_data(
data_dir=self.data_path,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def main(config: dict[str, Any]) -> None:
fl_config=config,
strategy=strategy,
checkpoint_and_state_module=checkpoint_and_state_module,
accept_failures=False,
)

fl.server.start_server(
Expand Down
4 changes: 2 additions & 2 deletions examples/ae_examples/cvae_examples/mlp_cvae_example/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from fl4health.preprocessing.autoencoders.loss import VaeLoss
from fl4health.utils.config import narrow_dict_type
from fl4health.utils.dataset_converter import AutoEncoderDatasetConverter
from fl4health.utils.load_data import load_mnist_data
from fl4health.utils.load_data import ToNumpy, load_mnist_data
from fl4health.utils.metrics import Metric
from fl4health.utils.random import set_all_random_seeds
from fl4health.utils.sampler import DirichletLabelBasedSampler
Expand Down Expand Up @@ -50,7 +50,7 @@ def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]:
# ToTensor transform is used to make sure pixels stay in the range [0.0, 1.0].
# Flattening the image data to match the input shape of the model.
flatten_transform = transforms.Lambda(lambda x: torch.flatten(x))
transform = transforms.Compose([transforms.ToTensor(), flatten_transform])
transform = transforms.Compose([ToNumpy(), transforms.ToTensor(), flatten_transform])
train_loader, val_loader, _ = load_mnist_data(
data_dir=self.data_path,
batch_size=batch_size,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def main(config: dict[str, Any]) -> None:
fl_config=config,
strategy=strategy,
checkpoint_and_state_module=checkpoint_and_state_module,
accept_failures=False,
)

fl.server.start_server(
Expand Down
4 changes: 2 additions & 2 deletions examples/ae_examples/fedprox_vae_example/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from fl4health.preprocessing.autoencoders.loss import VaeLoss
from fl4health.utils.config import narrow_dict_type
from fl4health.utils.dataset_converter import AutoEncoderDatasetConverter
from fl4health.utils.load_data import load_mnist_data
from fl4health.utils.load_data import ToNumpy, load_mnist_data
from fl4health.utils.sampler import DirichletLabelBasedSampler


Expand All @@ -25,7 +25,7 @@ def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]:
batch_size = narrow_dict_type(config, "batch_size", int)
sampler = DirichletLabelBasedSampler(list(range(10)), sample_percentage=0.75, beta=100)
# Flattening the input images to use an MLP-based variational autoencoder.
transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(torch.flatten)])
transform = transforms.Compose([ToNumpy(), transforms.ToTensor(), transforms.Lambda(torch.flatten)])
# Create and pass the autoencoder data converter to the data loader.
self.autoencoder_converter = AutoEncoderDatasetConverter(condition=None)
train_loader, val_loader, _ = load_mnist_data(
Expand Down
2 changes: 1 addition & 1 deletion examples/ae_examples/fedprox_vae_example/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ batch_size: 32 # The batch size for client training

# FedProx variables
adaptive_proximal_weight: False # Whether to use adaptive proximal weight or not
proximal_weight : 0.1 # The proximal weight
initial_proximal_weight : 0.1 # The proximal weight

# Checkpointing
checkpoint_path: "examples/ae_examples/fedprox_vae_example"
Expand Down
3 changes: 2 additions & 1 deletion examples/ae_examples/fedprox_vae_example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def main(config: dict[str, Any]) -> None:
fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
initial_parameters=get_all_model_parameters(model),
adapt_loss_weight=config["adapt_proximal_weight"],
adapt_loss_weight=config["adaptive_proximal_weight"],
initial_loss_weight=config["initial_proximal_weight"],
)

Expand All @@ -75,6 +75,7 @@ def main(config: dict[str, Any]) -> None:
fl_config=config,
strategy=strategy,
checkpoint_and_state_module=checkpoint_and_state_module,
accept_failures=False,
)

fl.server.start_server(
Expand Down
8 changes: 7 additions & 1 deletion examples/apfl_example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,13 @@ def main(config: dict[str, Any]) -> None:
)

client_manager = SimpleClientManager()
server = FlServer(client_manager=client_manager, fl_config=config, strategy=strategy, reporters=[JsonReporter()])
server = FlServer(
client_manager=client_manager,
fl_config=config,
strategy=strategy,
reporters=[JsonReporter()],
accept_failures=False,
)

fl.server.start_server(
server=server,
Expand Down
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I moved these to our examples/assets folder instead of the top level assets folder.

File renamed without changes.
1 change: 1 addition & 0 deletions examples/basic_example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def main(config: dict[str, Any]) -> None:
fl_config=config,
strategy=strategy,
checkpoint_and_state_module=checkpoint_and_state_module,
accept_failures=False,
)

fl.server.start_server(
Expand Down
2 changes: 1 addition & 1 deletion examples/ditto_example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def main(config: dict[str, Any]) -> None:
)

client_manager = SimpleClientManager()
server = DittoServer(client_manager=client_manager, fl_config=config, strategy=strategy)
server = DittoServer(client_manager=client_manager, fl_config=config, strategy=strategy, accept_failures=False)

fl.server.start_server(
server=server,
Expand Down
1 change: 1 addition & 0 deletions examples/dp_fed_examples/client_level_dp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def main(config: dict[str, Any]) -> None:
strategy=strategy,
server_noise_multiplier=config["server_noise_multiplier"],
num_server_rounds=config["n_server_rounds"],
accept_failures=False,
)

fl.server.start_server(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def main(config: dict[str, Any]) -> None:
clipping_noise_multiplier=config["clipping_bit_noise_multiplier"],
beta=config["server_momentum"],
weighted_aggregation=config["weighted_averaging"],
accept_failures=False,
)

server = ClientLevelDPFedAvgServer(
Expand Down
1 change: 1 addition & 0 deletions examples/dp_fed_examples/instance_level_dp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def main(config: dict[str, Any]) -> None:
batch_size=config["batch_size"],
num_server_rounds=config["n_server_rounds"],
checkpoint_and_state_module=checkpoint_and_state_module,
accept_failures=False,
)

fl.server.start_server(
Expand Down
1 change: 1 addition & 0 deletions examples/dp_scaffold_example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def main(config: dict[str, Any]) -> None:
num_server_rounds=config["n_server_rounds"],
strategy=strategy,
warm_start=True,
accept_failures=False,
)

fl.server.start_server(
Expand Down
2 changes: 1 addition & 1 deletion examples/dynamic_layer_exchange_example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def main(config: dict[str, Any]) -> None:
)

client_manager = SimpleClientManager()
server = FlServer(client_manager=client_manager, fl_config=config, strategy=strategy)
server = FlServer(client_manager=client_manager, fl_config=config, strategy=strategy, accept_failures=False)

fl.server.start_server(
server=server,
Expand Down
1 change: 1 addition & 0 deletions examples/feature_alignment_example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def main(config: dict[str, Any]) -> None:
initialize_parameters=get_initial_model_parameters,
strategy=strategy,
tabular_features_source_of_truth=tab_feature_info_encoder_hospital1,
accept_failures=False,
)

fl.server.start_server(
Expand Down
2 changes: 1 addition & 1 deletion examples/fedbn_example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def main(config: dict[str, Any], server_address: str, dataset_name: str) -> None
)

client_manager = SimpleClientManager()
server = FlServer(client_manager=client_manager, fl_config=config, strategy=strategy)
server = FlServer(client_manager=client_manager, fl_config=config, strategy=strategy, accept_failures=False)

fl.server.start_server(
server=server,
Expand Down
8 changes: 7 additions & 1 deletion examples/feddg_ga_example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,13 @@ def main(config: dict[str, Any]) -> None:
# will return the same sampling until it is told to reset, which in FedDgGaStrategy
# is done right before fit_round.
client_manager = FixedSamplingClientManager()
server = FlServer(client_manager=client_manager, fl_config=config, strategy=strategy, reporters=[JsonReporter()])
server = FlServer(
client_manager=client_manager,
fl_config=config,
strategy=strategy,
reporters=[JsonReporter()],
accept_failures=False,
)

fl.server.start_server(
server=server,
Expand Down
1 change: 1 addition & 0 deletions examples/federated_eval_example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def main(config: dict[str, Any], server_checkpoint_path: Path | None) -> None:
evaluate_config=evaluate_config,
evaluate_metrics_aggregation_fn=uniform_evaluate_metrics_aggregation_fn,
min_available_clients=config["n_clients"],
accept_failures=False,
)

fl.server.start_server(
Expand Down
3 changes: 2 additions & 1 deletion examples/fedopt_example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,10 +147,11 @@ def main(config: dict[str, Any]) -> None:
on_evaluate_config_fn=fit_config_fn,
# Server side weight initialization
initial_parameters=get_all_model_parameters(initial_model),
accept_failures=False,
)

client_manager = SimpleClientManager()
server = FlServer(client_manager=client_manager, fl_config=config, strategy=strategy)
server = FlServer(client_manager=client_manager, fl_config=config, strategy=strategy, accept_failures=False)

fl.server.start_server(
server_address=config["server_address"],
Expand Down
3 changes: 2 additions & 1 deletion examples/fedpca_examples/dim_reduction/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from fl4health.clients.basic_client import BasicClient
from fl4health.preprocessing.pca_preprocessor import PcaPreprocessor
from fl4health.utils.config import narrow_dict_type
from fl4health.utils.load_data import get_train_and_val_mnist_datasets
from fl4health.utils.load_data import ToNumpy, get_train_and_val_mnist_datasets
from fl4health.utils.metrics import Accuracy
from fl4health.utils.random import set_all_random_seeds
from fl4health.utils.sampler import DirichletLabelBasedSampler
Expand All @@ -31,6 +31,7 @@ def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]:
# Get training and validation datasets.
transform = transforms.Compose(
[
ToNumpy(),
transforms.ToTensor(),
transforms.Normalize((0.5), (0.5)),
]
Expand Down
1 change: 1 addition & 0 deletions examples/fedpca_examples/dim_reduction/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def main(config: dict[str, Any]) -> None:
fl_config=config,
strategy=strategy,
checkpoint_and_state_module=checkpoint_and_state_module,
accept_failures=False,
)

fl.server.start_server(
Expand Down
2 changes: 1 addition & 1 deletion examples/fedper_example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def main(config: dict[str, Any]) -> None:
)

client_manager = SimpleClientManager()
server = FlServer(client_manager=client_manager, fl_config=config, strategy=strategy)
server = FlServer(client_manager=client_manager, fl_config=config, strategy=strategy, accept_failures=False)

fl.server.start_server(
server=server,
Expand Down
1 change: 1 addition & 0 deletions examples/fedpm_example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def main(config: dict[str, Any]) -> None:
initial_parameters=get_all_model_parameters(initial_model),
# Perform Bayesian aggregation.
bayesian_aggregation=True,
accept_failures=False,
)

client_manager = SimpleClientManager()
Expand Down
4 changes: 3 additions & 1 deletion examples/fedprox_example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,9 @@ def main(config: dict[str, Any], server_address: str) -> None:
reporters = [wandb_reporter, json_reporter]
else:
reporters = [json_reporter]
server = FedProxServer(client_manager=client_manager, fl_config=config, strategy=strategy, reporters=reporters)
server = FedProxServer(
client_manager=client_manager, fl_config=config, strategy=strategy, reporters=reporters, accept_failures=False
)

fl.server.start_server(
server=server,
Expand Down
2 changes: 1 addition & 1 deletion examples/fedrep_example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def main(config: dict[str, Any]) -> None:
)

client_manager = SimpleClientManager()
server = FlServer(client_manager=client_manager, fl_config=config, strategy=strategy)
server = FlServer(client_manager=client_manager, fl_config=config, strategy=strategy, accept_failures=False)

fl.server.start_server(
server=server,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from fl4health.model_bases.fedsimclr_base import FedSimClrModel
from fl4health.utils.config import narrow_dict_type
from fl4health.utils.dataset import TensorDataset
from fl4health.utils.load_data import get_cifar10_data_and_target_tensors, split_data_and_targets
from fl4health.utils.load_data import ToNumpy, get_cifar10_data_and_target_tensors, split_data_and_targets
from fl4health.utils.metrics import Accuracy


Expand All @@ -26,6 +26,7 @@ def get_finetune_dataset(data_dir: Path, batch_size: int) -> tuple[DataLoader, D

input_transform = transforms.Compose(
[
ToNumpy(),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def main(config: dict[str, Any]) -> None:
fl_config=config,
strategy=strategy,
checkpoint_and_state_module=checkpoint_and_state_module,
accept_failures=False,
)

fl.server.start_server(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ def get_transforms() -> tuple[Callable, Callable]:

target_transform = transforms.Compose(
[
ToNumpy(),
transforms.ToPILImage(),
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomApply([color_jitter], p=0.8),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def main(config: dict[str, Any]) -> None:
fl_config=config,
strategy=strategy,
checkpoint_and_state_module=checkpoint_and_state_module,
accept_failures=False,
)

fl.server.start_server(
Expand Down
2 changes: 1 addition & 1 deletion examples/fenda_ditto_example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def main(config: dict[str, Any]) -> None:
)

client_manager = SimpleClientManager()
server = FlServer(client_manager=client_manager, fl_config=config, strategy=strategy)
server = FlServer(client_manager=client_manager, fl_config=config, strategy=strategy, accept_failures=False)

fl.server.start_server(
server=server,
Expand Down
2 changes: 1 addition & 1 deletion examples/fenda_example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def main(config: dict[str, Any]) -> None:
)

client_manager = SimpleClientManager()
server = FlServer(client_manager=client_manager, fl_config=config, strategy=strategy)
server = FlServer(client_manager=client_manager, fl_config=config, strategy=strategy, accept_failures=False)

fl.server.start_server(
server=server,
Expand Down
2 changes: 1 addition & 1 deletion examples/fl_plus_local_ft_example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def main(config: dict[str, Any]) -> None:
)

client_manager = SimpleClientManager()
server = FlServer(client_manager=client_manager, fl_config=config, strategy=strategy)
server = FlServer(client_manager=client_manager, fl_config=config, strategy=strategy, accept_failures=False)

fl.server.start_server(
server_address="0.0.0.0:8080",
Expand Down
2 changes: 1 addition & 1 deletion examples/flash_example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def main(config: dict[str, Any]) -> None:
)

client_manager = SimpleClientManager()
server = FlServer(client_manager=client_manager, fl_config=config, strategy=strategy)
server = FlServer(client_manager=client_manager, fl_config=config, strategy=strategy, accept_failures=False)

fl.server.start_server(
server=server,
Expand Down
Loading