diff --git a/examples/flower-authentication/authexample/client_app.py b/examples/flower-authentication/authexample/client_app.py index 4dde41fed27..524e31d53ca 100644 --- a/examples/flower-authentication/authexample/client_app.py +++ b/examples/flower-authentication/authexample/client_app.py @@ -4,7 +4,14 @@ from flwr.client import ClientApp, NumPyClient from flwr.common import Context -from authexample.task import Net, get_weights, load_data_from_disk, set_weights, test, train +from authexample.task import ( + Net, + get_weights, + load_data_from_disk, + set_weights, + test, + train, +) # Define Flower Client diff --git a/examples/flower-authentication/authexample/task.py b/examples/flower-authentication/authexample/task.py index 7a6887ea88d..aaca1447c54 100644 --- a/examples/flower-authentication/authexample/task.py +++ b/examples/flower-authentication/authexample/task.py @@ -84,7 +84,7 @@ def load_data_to_disk(num_partitions: int = 2): dataset="uoft-cs/cifar10", partitioners={"train": partitioner}, ) - + for partition_id in range(num_partitions): partition = fds.load_partition(partition_id) partition_train_test = partition.train_test_split(test_size=0.2, seed=42) diff --git a/examples/flower-authentication/dataset.py b/examples/flower-authentication/dataset.py index 43ebe6dffbf..bd3f07bc5fd 100644 --- a/examples/flower-authentication/dataset.py +++ b/examples/flower-authentication/dataset.py @@ -3,14 +3,21 @@ if __name__ == "__main__": # Initialize argument parser - parser = argparse.ArgumentParser(description="Load CIFAR-10 dataset partitions to disk") - + parser = argparse.ArgumentParser( + description="Load CIFAR-10 dataset partitions to disk" + ) + # Add an optional positional argument for number of partitions - parser.add_argument("num_partitions", type=int, nargs="?", default=2, - help="Number of partitions to create (default: 2)") - + parser.add_argument( + "num_partitions", + type=int, + nargs="?", + default=2, + help="Number of partitions to create (default: 2)", + ) + # Parse the arguments args = parser.parse_args() - + # Call the function with the provided argument load_data_to_disk(args.num_partitions)