diff --git a/datasets/flwr_datasets/partitioner/__init__.py b/datasets/flwr_datasets/partitioner/__init__.py
index 73d048ddf3f..0d1edbfcb04 100644
--- a/datasets/flwr_datasets/partitioner/__init__.py
+++ b/datasets/flwr_datasets/partitioner/__init__.py
@@ -18,6 +18,7 @@
 from .dirichlet_partitioner import DirichletPartitioner
 from .exponential_partitioner import ExponentialPartitioner
 from .iid_partitioner import IidPartitioner
+from .inner_dirichlet_partitioner import InnerDirichletPartitioner
 from .linear_partitioner import LinearPartitioner
 from .natural_id_partitioner import NaturalIdPartitioner
 from .partitioner import Partitioner
@@ -32,6 +33,7 @@
     "DirichletPartitioner",
     "SizePartitioner",
     "LinearPartitioner",
+    "InnerDirichletPartitioner",
     "SquarePartitioner",
     "ShardPartitioner",
     "ExponentialPartitioner",
diff --git a/datasets/flwr_datasets/partitioner/inner_dirichlet_partitioner.py b/datasets/flwr_datasets/partitioner/inner_dirichlet_partitioner.py
new file mode 100644
index 00000000000..c25a9b059d1
--- /dev/null
+++ b/datasets/flwr_datasets/partitioner/inner_dirichlet_partitioner.py
@@ -0,0 +1,308 @@
+# Copyright 2023 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""InnerDirichlet partitioner."""
+import warnings
+from typing import Dict, List, Optional, Union
+
+import numpy as np
+
+import datasets
+from flwr_datasets.common.typing import NDArrayFloat, NDArrayInt
+from flwr_datasets.partitioner.partitioner import Partitioner
+
+
+class InnerDirichletPartitioner(Partitioner):  # pylint: disable=R0902
+    """Partitioner based on the Dirichlet distribution.
+
+    Each partition is created based on the Dirichlet distribution, where the
+    probabilities correspond to the fractions of samples of specific classes.
+    This process is iterative (sample-by-sample assignment): first, the
+    partition ID to which the sample will be assigned is chosen uniformly at
+    random, and then the class is decided based on the Dirichlet probabilities
+    (note that when a class gets exhausted - no more samples exist to sample
+    from - the probability of sampling this class is set to zero and the
+    remaining probabilities are renormalized).
+
+    Implementation based on: Federated Learning Based on Dynamic Regularization
+    (https://arxiv.org/abs/2111.04263).
+
+    Parameters
+    ----------
+    partition_sizes : Union[List[int], NDArrayInt]
+        The sizes of all partitions.
+    partition_by : str
+        Column name of the labels (targets) based on which Dirichlet sampling works.
+    alpha : Union[int, float, List[float], NDArrayFloat]
+        Concentration parameter of the Dirichlet distribution (a single value for
+        a symmetric Dirichlet distribution, or a list/NDArray of length equal to
+        the number of unique classes).
+    shuffle : bool
+        Whether to randomize the order of samples. Shuffling is applied after the
+        samples are assigned to nodes.
+    seed : int
+        Seed used for dataset shuffling. It has no effect if `shuffle` is False.
+
+    Examples
+    --------
+    >>> from flwr_datasets import FederatedDataset
+    >>> from flwr_datasets.partitioner import InnerDirichletPartitioner
+    >>>
+    >>> partitioner = InnerDirichletPartitioner(
+    >>>     partition_sizes=[6_000] * 10, partition_by="label", alpha=0.5
+    >>> )
+    >>> fds = FederatedDataset(dataset="mnist", partitioners={"train": partitioner})
+    >>> partition = fds.load_partition(0)
+    >>> print(partition[0])  # Print the first example
+    """
+
+    def __init__(  # pylint: disable=R0913
+        self,
+        partition_sizes: Union[List[int], NDArrayInt],
+        partition_by: str,
+        alpha: Union[int, float, List[float], NDArrayFloat],
+        shuffle: bool = True,
+        seed: Optional[int] = 42,
+    ) -> None:
+        super().__init__()
+        # Attributes based on the constructor
+        self._partition_sizes = _instantiate_partition_sizes(partition_sizes)
+        self._initial_alpha = alpha
+        self._alpha: Optional[NDArrayFloat] = None
+        self._partition_by = partition_by
+        self._shuffle = shuffle
+        self._seed = seed
+
+        # Utility attributes
+        self._initialized_alpha = False
+        self._rng = np.random.default_rng(seed=self._seed)  # NumPy random generator
+        # The attributes below are determined during the first call to load_partition
+        self._unique_classes: Optional[Union[List[int], List[str]]] = None
+        self._num_unique_classes: Optional[int] = None
+        self._num_partitions = len(self._partition_sizes)
+
+        self._node_id_to_indices: Dict[int, List[int]] = {}
+        self._node_id_to_indices_determined = False
+
+    def load_partition(self, node_id: int) -> datasets.Dataset:
+        """Load a partition based on the partition index.
+
+        Parameters
+        ----------
+        node_id : int
+            The index that corresponds to the requested partition.
+
+        Returns
+        -------
+        dataset_partition : Dataset
+            A single partition of a dataset.
+        """
+        # The partitioning is done lazily - only when the first partition is
+        # requested. Only the first call creates the indices assignments for all
+        # the partition indices.
+        self._check_num_partitions_correctness_if_needed()
+        self._check_partition_sizes_correctness_if_needed()
+        self._check_the_sum_of_partition_sizes()
+        self._determine_num_unique_classes_if_needed()
+        self._alpha = self._initialize_alpha_if_needed(self._initial_alpha)
+        self._determine_node_id_to_indices_if_needed()
+        return self.dataset.select(self._node_id_to_indices[node_id])
+
+    def _initialize_alpha_if_needed(
+        self, alpha: Union[int, float, List[float], NDArrayFloat]
+    ) -> NDArrayFloat:
+        """Convert alpha to the format used in the code, an NDArrayFloat.
+
+        The alpha given in the constructor can come in several formats for user
+        convenience. The format into which it is transformed here is used
+        throughout the code for computation.
+
+        Parameters
+        ----------
+        alpha : Union[int, float, List[float], NDArrayFloat]
+            Concentration parameter of the Dirichlet distribution.
+
+        Returns
+        -------
+        alpha : NDArrayFloat
+            Concentration parameter in a format ready to be used in computation.
+        """
+        if self._initialized_alpha:
+            assert self._alpha is not None
+            return self._alpha
+        if isinstance(alpha, int):
+            assert self._num_unique_classes is not None
+            alpha = np.array([float(alpha)], dtype=float).repeat(
+                self._num_unique_classes
+            )
+        elif isinstance(alpha, float):
+            assert self._num_unique_classes is not None
+            alpha = np.array([alpha], dtype=float).repeat(self._num_unique_classes)
+        elif isinstance(alpha, List):
+            if len(alpha) != self._num_unique_classes:
+                raise ValueError(
+                    "When passing alpha as a List, its length needs to be equal "
+                    "to the number of unique classes."
+                )
+            alpha = np.asarray(alpha)
+        elif isinstance(alpha, np.ndarray):
+            # pylint: disable=R1720
+            if alpha.ndim == 1 and alpha.shape[0] != self._num_unique_classes:
+                raise ValueError(
+                    "When passing alpha as an NDArray, its length needs to be "
+                    "equal to the number of unique classes."
+                )
+            elif alpha.ndim == 2:
+                alpha = alpha.flatten()
+                if alpha.shape[0] != self._num_unique_classes:
+                    raise ValueError(
+                        "When passing alpha as an NDArray, its length needs to "
+                        "be equal to the number of unique classes."
+                    )
+        else:
+            raise ValueError("The given alpha format is not supported.")
+        if not (alpha > 0).all():
+            raise ValueError(
+                f"Alpha values should be strictly greater than zero. "
+                f"Instead, the given alpha was converted to {alpha}."
+            )
+        return alpha
+
+    def _determine_node_id_to_indices_if_needed(self) -> None:  # pylint: disable=R0914
+        """Create an assignment of indices to the partition indices."""
+        if self._node_id_to_indices_determined:
+            return
+
+        # Create class priors for the whole partitioning process
+        assert self._alpha is not None
+        class_priors = self._rng.dirichlet(
+            alpha=self._alpha, size=self._num_partitions
+        )
+        targets = np.asarray(self.dataset[self._partition_by])
+        # List representing indices of each class
+        assert self._num_unique_classes is not None
+        idx_list = [np.where(targets == i)[0] for i in range(self._num_unique_classes)]
+        class_sizes = [len(idx_list[i]) for i in range(self._num_unique_classes)]
+
+        client_indices = [
+            np.zeros(self._partition_sizes[cid]).astype(np.int64)
+            for cid in range(self._num_partitions)
+        ]
+
+        # Map node id to the number of samples left to allocate for that node
+        node_id_to_left_to_allocate = dict(
+            zip(range(self._num_partitions), self._partition_sizes)
+        )
+
+        not_full_node_ids = list(range(self._num_partitions))
+        while np.sum(list(node_id_to_left_to_allocate.values())) != 0:
+            # Choose a node
+            current_node_id = self._rng.choice(not_full_node_ids)
+            # If the current node is full, sample a different node
+            if node_id_to_left_to_allocate[current_node_id] == 0:
+                # When the node is full, exclude it from the sampling list
+                not_full_node_ids.pop(not_full_node_ids.index(current_node_id))
+                continue
+            node_id_to_left_to_allocate[current_node_id] -= 1
+            # Access the label distribution of the chosen node
+            current_probabilities = class_priors[current_node_id]
+            while True:
+                curr_class = self._rng.choice(
+                    list(range(self._num_unique_classes)), p=current_probabilities
+                )
+                # Redraw the class label if there are no samples left to be
+                # allocated from that class
+                if class_sizes[curr_class] == 0:
+                    # Class got exhausted, set its probabilities to 0
+                    class_priors[:, curr_class] = 0
+                    # Renormalize such that the probabilities sum to 1
+                    row_sums = class_priors.sum(axis=1, keepdims=True)
+                    class_priors = class_priors / row_sums
+                    # Adjust current_probabilities (it won't sum up to 1 otherwise)
+                    current_probabilities = class_priors[current_node_id]
+                    continue
+                class_sizes[curr_class] -= 1
+                # Store the sample index at the empty array cell
+                index = node_id_to_left_to_allocate[current_node_id]
+                client_indices[current_node_id][index] = idx_list[curr_class][
+                    class_sizes[curr_class]
+                ]
+                break
+
+        node_id_to_indices = {
+            cid: client_indices[cid].tolist() for cid in range(self._num_partitions)
+        }
+        # Shuffle the indices if shuffle is True.
+        # Without shuffling, the order of samples within a partition reflects
+        # the assignment process instead of being random.
+        if self._shuffle:
+            for indices in node_id_to_indices.values():
+                # In-place shuffling
+                self._rng.shuffle(indices)
+        self._node_id_to_indices = node_id_to_indices
+        self._node_id_to_indices_determined = True
+
+    def _check_num_partitions_correctness_if_needed(self) -> None:
+        """Test num_partitions when the dataset is given (in load_partition)."""
+        if not self._node_id_to_indices_determined:
+            if self._num_partitions > self.dataset.num_rows:
+                raise ValueError(
+                    "The number of partitions needs to be smaller than or equal "
+                    "to the number of samples in the dataset."
+                )
+
+    def _check_partition_sizes_correctness_if_needed(self) -> None:
+        """Test partition_sizes when the dataset is given (in load_partition)."""
+        if not self._node_id_to_indices_determined:
+            if sum(self._partition_sizes) > self.dataset.num_rows:
+                raise ValueError(
+                    "The sum of the `partition_sizes` needs to be smaller than or "
+                    "equal to the number of samples in the dataset."
+                )
+
+    def _check_num_partitions_greater_than_zero(self) -> None:
+        """Test that the number of partitions is greater than zero."""
+        if not self._num_partitions > 0:
+            raise ValueError("The number of partitions needs to be greater than zero.")
+
+    def _determine_num_unique_classes_if_needed(self) -> None:
+        self._unique_classes = self.dataset.unique(self._partition_by)
+        assert self._unique_classes is not None
+        self._num_unique_classes = len(self._unique_classes)
+
+    def _check_the_sum_of_partition_sizes(self) -> None:
+        if np.sum(self._partition_sizes) != len(self.dataset):
+            warnings.warn(
+                "The sum of the partition_sizes is not equal to the whole "
+                "dataset size. Make sure that is the desired behavior.",
+                stacklevel=1,
+            )
+
+
+def _instantiate_partition_sizes(
+    partition_sizes: Union[List[int], NDArrayInt]
+) -> NDArrayInt:
+    """Transform a list into an ndarray of ints if needed."""
+    if isinstance(partition_sizes, List):
+        partition_sizes = np.asarray(partition_sizes)
+    elif isinstance(partition_sizes, np.ndarray):
+        pass
+    else:
+        raise ValueError(
+            f"The type of partition_sizes is incorrect. Given: "
+            f"{type(partition_sizes)}"
+        )
+
+    if not all(partition_sizes >= 0):
+        raise ValueError("The partition sizes must be greater than or equal to zero.")
+    return partition_sizes
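The assignment loop above can be tried in isolation. The following toy sketch (not part of the PR; plain NumPy with made-up sizes and labels) mirrors the same inner-Dirichlet procedure: draw per-partition class priors once, repeatedly pick a non-full partition uniformly at random, sample its next class, and zero out plus renormalize exhausted classes.

```python
import numpy as np

rng = np.random.default_rng(42)
num_partitions, num_classes = 3, 3
partition_sizes = [4, 4, 4]
targets = np.array([0, 1, 2] * 4)  # 12 samples, 4 per class

# One class distribution per partition, drawn once from Dirichlet(alpha)
class_priors = rng.dirichlet(alpha=[0.5] * num_classes, size=num_partitions)
idx_list = [list(np.where(targets == c)[0]) for c in range(num_classes)]
assignment = {pid: [] for pid in range(num_partitions)}
left = dict(enumerate(partition_sizes))

while sum(left.values()) > 0:
    # 1) Pick a partition uniformly at random among the non-full ones
    pid = int(rng.choice([p for p, n in left.items() if n > 0]))
    # 2) Sample a class from that partition's Dirichlet-based distribution
    cls = int(rng.choice(num_classes, p=class_priors[pid]))
    if not idx_list[cls]:
        # 3) Class exhausted: zero its probability everywhere and renormalize
        class_priors[:, cls] = 0
        class_priors /= class_priors.sum(axis=1, keepdims=True)
        continue
    assignment[pid].append(idx_list[cls].pop())
    left[pid] -= 1

print({pid: sorted(idxs) for pid, idxs in assignment.items()})
```

Because the total of `partition_sizes` equals the dataset size here, the loop always terminates; the partitioner's `_check_the_sum_of_partition_sizes` warns when that is not the case.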
diff --git a/datasets/flwr_datasets/partitioner/inner_dirichlet_partitioner_test.py b/datasets/flwr_datasets/partitioner/inner_dirichlet_partitioner_test.py
new file mode 100644
index 00000000000..0c5fb502870
--- /dev/null
+++ b/datasets/flwr_datasets/partitioner/inner_dirichlet_partitioner_test.py
@@ -0,0 +1,106 @@
+# Copyright 2023 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Test InnerDirichletPartitioner."""
+# pylint: disable=W0212
+import unittest
+from typing import List, Tuple, Union
+
+from datasets import Dataset
+from flwr_datasets.common.typing import NDArrayFloat, NDArrayInt
+from flwr_datasets.partitioner.inner_dirichlet_partitioner import (
+    InnerDirichletPartitioner,
+)
+
+
+def _dummy_setup(
+    num_rows: int,
+    partition_by: str,
+    partition_sizes: Union[List[int], NDArrayInt],
+    alpha: Union[float, List[float], NDArrayFloat],
+) -> Tuple[Dataset, InnerDirichletPartitioner]:
+    """Create a dummy dataset and partitioner for testing."""
+    data = {
+        partition_by: [i % 3 for i in range(num_rows)],
+        "features": list(range(num_rows)),
+    }
+    dataset = Dataset.from_dict(data)
+    partitioner = InnerDirichletPartitioner(
+        partition_sizes=partition_sizes,
+        alpha=alpha,
+        partition_by=partition_by,
+    )
+    partitioner.dataset = dataset
+    return dataset, partitioner
+
+
+class TestInnerDirichletPartitionerSuccess(unittest.TestCase):
+    """Test InnerDirichletPartitioner used with no exceptions."""
+
+    def test_correct_num_of_partitions(self) -> None:
+        """Test the correct number of partitions."""
+        num_rows = 113
+        partition_by = "labels"
+        alpha = 1.0
+        partition_sizes = [20, 20, 30, 43]
+
+        _, partitioner = _dummy_setup(num_rows, partition_by, partition_sizes, alpha)
+        _ = partitioner.load_partition(0)
+        self.assertEqual(
+            len(partitioner._node_id_to_indices.keys()), len(partition_sizes)
+        )
+
+    def test_correct_partition_sizes(self) -> None:
+        """Test correct partition sizes."""
+        num_rows = 113
+        partition_by = "labels"
+        alpha = 1.0
+        partition_sizes = [20, 20, 30, 43]
+
+        _, partitioner = _dummy_setup(num_rows, partition_by, partition_sizes, alpha)
+        _ = partitioner.load_partition(0)
+        sizes_created = [
+            len(indices) for indices in partitioner._node_id_to_indices.values()
+        ]
+        self.assertEqual(sorted(sizes_created), partition_sizes)
+
+
+class TestInnerDirichletPartitionerFailure(unittest.TestCase):
+    """Test InnerDirichletPartitioner failures (exceptions) by incorrect usage."""
+
+    def test_incorrect_shape_of_alpha(self) -> None:
+        """Test an alpha shape not equal to the number of unique classes."""
+        num_rows = 113
+        partition_by = "labels"
+        alpha = [1.0, 1.0]
+        partition_sizes = [20, 20, 30, 43]
+
+        _, partitioner = _dummy_setup(num_rows, partition_by, partition_sizes, alpha)
+        with self.assertRaises(ValueError):
+            _ = partitioner.load_partition(0)
+
+    def test_too_big_sum_of_partition_sizes(self) -> None:
+        """Test a sum of partition_sizes greater than the size of the dataset."""
+        num_rows = 113
+        partition_by = "labels"
+        alpha = 1.0
+        partition_sizes = [60, 60, 30, 43]
+
+        _, partitioner = _dummy_setup(num_rows, partition_by, partition_sizes, alpha)
+        with self.assertRaises(ValueError):
+            _ = partitioner.load_partition(0)
+
+
+if __name__ == "__main__":
+    unittest.main()
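Beyond the unit tests, a quick way to eyeball the heterogeneity the partitioner induces is to count labels per partition. A hedged snippet reusing the `mnist` setup from the class docstring (it downloads the dataset via Hugging Face Datasets):

```python
from collections import Counter

from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import InnerDirichletPartitioner

partitioner = InnerDirichletPartitioner(
    partition_sizes=[6_000] * 10, partition_by="label", alpha=0.5
)
fds = FederatedDataset(dataset="mnist", partitioners={"train": partitioner})
for node_id in range(3):
    partition = fds.load_partition(node_id)
    # The label histogram reveals the Dirichlet-induced skew per partition
    print(node_id, Counter(partition["label"]))
```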
diff --git a/src/py/flwr/cli/new/new.py b/src/py/flwr/cli/new/new.py
index d5db6091344..04e28efdf76 100644
--- a/src/py/flwr/cli/new/new.py
+++ b/src/py/flwr/cli/new/new.py
@@ -28,6 +28,7 @@
 class MlFramework(str, Enum):
     """Available frameworks."""
 
+    NUMPY = "NumPy"
     PYTORCH = "PyTorch"
     TENSORFLOW = "TensorFlow"
diff --git a/src/py/flwr/cli/new/templates/app/code/client.numpy.py.tpl b/src/py/flwr/cli/new/templates/app/code/client.numpy.py.tpl
new file mode 100644
index 00000000000..cf24457c8d2
--- /dev/null
+++ b/src/py/flwr/cli/new/templates/app/code/client.numpy.py.tpl
@@ -0,0 +1,24 @@
+"""$project_name: A Flower / NumPy app."""
+
+import flwr as fl
+import numpy as np
+
+
+# Flower client, adapted from the PyTorch quickstart example
+class FlowerClient(fl.client.NumPyClient):
+    def get_parameters(self, config):
+        return [np.ones((1, 1))]
+
+    def fit(self, parameters, config):
+        return ([np.ones((1, 1))], 1, {})
+
+    def evaluate(self, parameters, config):
+        return float(0.0), 1, {"accuracy": float(1.0)}
+
+
+def client_fn(cid: str):
+    return FlowerClient().to_client()
+
+
+# ClientApp for Flower-Next
+app = fl.client.ClientApp(client_fn=client_fn)
diff --git a/src/py/flwr/cli/new/templates/app/code/server.numpy.py.tpl b/src/py/flwr/cli/new/templates/app/code/server.numpy.py.tpl
new file mode 100644
index 00000000000..03f95ae35cf
--- /dev/null
+++ b/src/py/flwr/cli/new/templates/app/code/server.numpy.py.tpl
@@ -0,0 +1,12 @@
+"""$project_name: A Flower / NumPy app."""
+
+import flwr as fl
+
+# Configure the strategy
+strategy = fl.server.strategy.FedAvg()
+
+# Flower ServerApp
+app = fl.server.ServerApp(
+    config=fl.server.ServerConfig(num_rounds=1),
+    strategy=strategy,
+)
diff --git a/src/py/flwr/cli/new/templates/app/flower.toml.tpl b/src/py/flwr/cli/new/templates/app/flower.toml.tpl
index 4dd7117bc3a..e171783527a 100644
--- a/src/py/flwr/cli/new/templates/app/flower.toml.tpl
+++ b/src/py/flwr/cli/new/templates/app/flower.toml.tpl
@@ -1,10 +1,10 @@
-[flower]
+[project]
 name = "$project_name"
 version = "1.0.0"
 description = ""
 license = "Apache-2.0"
 authors = ["The Flower Authors "]
 
-[components]
+[flower.components]
 serverapp = "$project_name.server:app"
 clientapp = "$project_name.client:app"
diff --git a/src/py/flwr/cli/new/templates/app/requirements.numpy.txt.tpl b/src/py/flwr/cli/new/templates/app/requirements.numpy.txt.tpl
new file mode 100644
index 00000000000..dfb385079b2
--- /dev/null
+++ b/src/py/flwr/cli/new/templates/app/requirements.numpy.txt.tpl
@@ -0,0 +1,2 @@
+flwr>=1.8, <2.0
+numpy >= 1.21.0
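The generated NumPy client is deliberately trivial, so it can be smoke-tested without a running server. A sketch, assuming the rendered template is importable as a `client` module (hypothetical module name):

```python
import numpy as np

from client import FlowerClient  # hypothetical: the rendered client.numpy template

client = FlowerClient()
params = client.get_parameters(config={})
new_params, num_examples, _ = client.fit(params, config={})
loss, n, metrics = client.evaluate(new_params, config={})
assert np.array_equal(new_params[0], np.ones((1, 1)))
print(loss, n, metrics)  # 0.0 1 {'accuracy': 1.0}
```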
diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py
index 93d654379cf..43781776f78 100644
--- a/src/py/flwr/client/app.py
+++ b/src/py/flwr/client/app.py
@@ -20,7 +20,9 @@
 import time
 from logging import DEBUG, INFO, WARN
 from pathlib import Path
-from typing import Callable, ContextManager, Optional, Tuple, Union
+from typing import Callable, ContextManager, Optional, Tuple, Type, Union
+
+from grpc import RpcError
 
 from flwr.client.client import Client
 from flwr.client.client_app import ClientApp
@@ -36,6 +38,7 @@
 )
 from flwr.common.exit_handlers import register_exit_handlers
 from flwr.common.logger import log, warn_deprecated_feature, warn_experimental_feature
+from flwr.common.retry_invoker import RetryInvoker, exponential
 
 from .client_app import load_client_app
 from .grpc_client.connection import grpc_connection
@@ -104,6 +107,8 @@ def _load() -> ClientApp:
         transport="rest" if args.rest else "grpc-rere",
         root_certificates=root_certificates,
         insecure=args.insecure,
+        max_retries=args.max_retries,
+        max_wait_time=args.max_wait_time,
     )
 
     register_exit_handlers(event_type=EventType.RUN_CLIENT_APP_LEAVE)
@@ -141,6 +146,22 @@ def _parse_args_run_client_app() -> argparse.ArgumentParser:
         default="0.0.0.0:9092",
         help="Server address",
     )
+    parser.add_argument(
+        "--max-retries",
+        type=int,
+        default=None,
+        help="The maximum number of times the client will try to connect to the "
+        "server before giving up in case of a connection error. By default, "
+        "it is set to None, meaning there is no limit to the number of tries.",
+    )
+    parser.add_argument(
+        "--max-wait-time",
+        type=float,
+        default=None,
+        help="The maximum duration before the client stops trying to "
+        "connect to the server in case of connection error. By default, it "
+        "is set to None, meaning there is no limit to the total time.",
+    )
     parser.add_argument(
         "--dir",
         default="",
@@ -180,6 +201,8 @@ def start_client(
     root_certificates: Optional[Union[bytes, str]] = None,
     insecure: Optional[bool] = None,
     transport: Optional[str] = None,
+    max_retries: Optional[int] = None,
+    max_wait_time: Optional[float] = None,
 ) -> None:
     """Start a Flower client node which connects to a Flower server.
@@ -213,6 +236,14 @@ class `flwr.client.Client` (default: None)
         - 'grpc-bidi': gRPC, bidirectional streaming
         - 'grpc-rere': gRPC, request-response (experimental)
         - 'rest': HTTP (experimental)
+    max_retries : Optional[int] (default: None)
+        The maximum number of times the client will try to connect to the
+        server before giving up in case of a connection error. If set to None,
+        there is no limit to the number of tries.
+    max_wait_time : Optional[float] (default: None)
+        The maximum duration before the client stops trying to
+        connect to the server in case of connection error.
+        If set to None, there is no limit to the total time.
 
     Examples
     --------
@@ -254,6 +285,8 @@ class `flwr.client.Client` (default: None)
         root_certificates=root_certificates,
         insecure=insecure,
         transport=transport,
+        max_retries=max_retries,
+        max_wait_time=max_wait_time,
     )
     event(EventType.START_CLIENT_LEAVE)
@@ -272,6 +305,8 @@ def _start_client_internal(
     root_certificates: Optional[Union[bytes, str]] = None,
     insecure: Optional[bool] = None,
     transport: Optional[str] = None,
+    max_retries: Optional[int] = None,
+    max_wait_time: Optional[float] = None,
 ) -> None:
     """Start a Flower client node which connects to a Flower server.
@@ -299,7 +334,7 @@ class `flwr.client.Client` (default: None)
         The PEM-encoded root certificates as a byte string or a path string.
         If provided, a secure connection using the certificates will be
         established to an SSL-enabled Flower server.
-    insecure : bool (default: True)
+    insecure : Optional[bool] (default: None)
         Starts an insecure gRPC connection when True. Enables HTTPS connection
         when False, using system certificates if `root_certificates` is None.
     transport : Optional[str] (default: None)
@@ -307,6 +342,14 @@ class `flwr.client.Client` (default: None)
         - 'grpc-bidi': gRPC, bidirectional streaming
         - 'grpc-rere': gRPC, request-response (experimental)
         - 'rest': HTTP (experimental)
+    max_retries : Optional[int] (default: None)
+        The maximum number of times the client will try to connect to the
+        server before giving up in case of a connection error. If set to None,
+        there is no limit to the number of tries.
+    max_wait_time : Optional[float] (default: None)
+        The maximum duration before the client stops trying to
+        connect to the server in case of connection error.
+        If set to None, there is no limit to the total time.
     """
     if insecure is None:
         insecure = root_certificates is None
@@ -338,7 +381,45 @@ def _load_client_app() -> ClientApp:
     # Both `client` and `client_fn` must not be used directly
 
     # Initialize connection context manager
-    connection, address = _init_connection(transport, server_address)
+    connection, address, connection_error_type = _init_connection(
+        transport, server_address
+    )
+
+    retry_invoker = RetryInvoker(
+        wait_factory=exponential,
+        recoverable_exceptions=connection_error_type,
+        max_tries=max_retries,
+        max_time=max_wait_time,
+        on_giveup=lambda retry_state: (
+            log(
+                WARN,
+                "Giving up reconnection after %.2f seconds and %s tries.",
+                retry_state.elapsed_time,
+                retry_state.tries,
+            )
+            if retry_state.tries > 1
+            else None
+        ),
+        on_success=lambda retry_state: (
+            log(
+                INFO,
+                "Connection successful after %.2f seconds and %s tries.",
+                retry_state.elapsed_time,
+                retry_state.tries,
+            )
+            if retry_state.tries > 1
+            else None
+        ),
+        on_backoff=lambda retry_state: (
+            log(WARN, "Connection attempt failed, retrying...")
+            if retry_state.tries == 1
+            else log(
+                DEBUG,
+                "Connection attempt failed, retrying in %.2f seconds",
+                retry_state.actual_wait,
+            )
+        ),
+    )
 
     node_state = NodeState()
@@ -347,6 +428,7 @@
         with connection(
             address,
             insecure,
+            retry_invoker,
             grpc_max_message_length,
             root_certificates,
         ) as conn:
@@ -509,7 +591,7 @@ def start_numpy_client(
 
 def _init_connection(transport: Optional[str], server_address: str) -> Tuple[
     Callable[
-        [str, bool, int, Union[bytes, str, None]],
+        [str, bool, RetryInvoker, int, Union[bytes, str, None]],
         ContextManager[
             Tuple[
                 Callable[[], Optional[Message]],
@@ -520,6 +602,7 @@ def _init_connection(transport: Optional[str], server_address: str) -> Tuple[
         ],
     ],
     str,
+    Type[Exception],
 ]:
     # Parse IP address
     parsed_address = parse_address(server_address)
@@ -535,6 +618,8 @@ def _init_connection(transport: Optional[str], server_address: str) -> Tuple[
     # Use either gRPC bidirectional streaming or REST request/response
     if transport == TRANSPORT_TYPE_REST:
         try:
+            from requests.exceptions import ConnectionError as RequestsConnectionError
+
             from .rest_client.connection import http_request_response
         except ModuleNotFoundError:
             sys.exit(MISSING_EXTRA_REST)
@@ -543,14 +628,14 @@ def _init_connection(transport: Optional[str], server_address: str) -> Tuple[
                 "When using the REST API, please provide `https://` or "
                 "`http://` before the server address (e.g. `http://127.0.0.1:8080`)"
             )
-        connection = http_request_response
+        connection, error_type = http_request_response, RequestsConnectionError
     elif transport == TRANSPORT_TYPE_GRPC_RERE:
-        connection = grpc_request_response
+        connection, error_type = grpc_request_response, RpcError
    elif transport == TRANSPORT_TYPE_GRPC_BIDI:
-        connection = grpc_connection
+        connection, error_type = grpc_connection, RpcError
     else:
         raise ValueError(
             f"Unknown transport type: {transport} (possible: {TRANSPORT_TYPES})"
         )
-    return connection, address
+    return connection, address, error_type
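For reference, the retry machinery wired in above can also be exercised standalone. A hedged sketch with a toy exception, assuming `RetryInvoker.invoke(fn, *args, **kwargs)` calls `fn` with retries, as the usage in this diff suggests:

```python
from flwr.common.retry_invoker import RetryInvoker, exponential

attempts = {"n": 0}

def flaky() -> str:
    # Fails twice, then succeeds (stand-in for a transient connection error)
    attempts["n"] += 1
    if attempts["n"] < 3:
        raise ConnectionError("transient failure")
    return "ok"

invoker = RetryInvoker(
    wait_factory=exponential,               # exponential backoff between tries
    recoverable_exceptions=ConnectionError, # only retry on this exception type
    max_tries=5,                            # give up after five attempts
    max_time=None,                          # no overall deadline
)
print(invoker.invoke(flaky))  # "ok" on the third attempt
```

This mirrors how `_start_client_internal` maps `--max-retries` to `max_tries` and `--max-wait-time` to `max_time`, with the recoverable exception type chosen per transport (`RpcError` for gRPC, `requests` `ConnectionError` for REST).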
diff --git a/src/py/flwr/client/grpc_client/connection.py b/src/py/flwr/client/grpc_client/connection.py
index 3561626dcb3..ddbb5336b2a 100644
--- a/src/py/flwr/client/grpc_client/connection.py
+++ b/src/py/flwr/client/grpc_client/connection.py
@@ -39,6 +39,7 @@
 )
 from flwr.common.grpc import create_channel
 from flwr.common.logger import log
+from flwr.common.retry_invoker import RetryInvoker
 from flwr.proto.transport_pb2 import (  # pylint: disable=E0611
     ClientMessage,
     Reason,
@@ -62,6 +63,7 @@ def on_channel_state_change(channel_connectivity: str) -> None:
 def grpc_connection(  # pylint: disable=R0915
     server_address: str,
     insecure: bool,
+    retry_invoker: RetryInvoker,  # pylint: disable=unused-argument
     max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
     root_certificates: Optional[Union[bytes, str]] = None,
 ) -> Iterator[
@@ -80,6 +82,11 @@
         The IPv4 or IPv6 address of the server. If the Flower server runs on the
         same machine on port 8080, then `server_address` would be
         `"0.0.0.0:8080"` or `"[::]:8080"`.
+    insecure : bool
+        Starts an insecure gRPC connection when True. Enables HTTPS connection
+        when False, using system certificates if `root_certificates` is None.
+    retry_invoker : RetryInvoker
+        Unused argument present for compatibility.
     max_message_length : int
         The maximum length of gRPC messages that can be exchanged with the
         Flower server. The default should be sufficient for most models. Users who train
diff --git a/src/py/flwr/client/grpc_client/connection_test.py b/src/py/flwr/client/grpc_client/connection_test.py
index 30bff068b60..28e03979fd6 100644
--- a/src/py/flwr/client/grpc_client/connection_test.py
+++ b/src/py/flwr/client/grpc_client/connection_test.py
@@ -26,6 +26,7 @@
 from flwr.common import ConfigsRecord, Message, Metadata, RecordSet
 from flwr.common import recordset_compat as compat
 from flwr.common.constant import MESSAGE_TYPE_GET_PROPERTIES
+from flwr.common.retry_invoker import RetryInvoker, exponential
 from flwr.common.typing import Code, GetPropertiesRes, Status
 from flwr.proto.transport_pb2 import (  # pylint: disable=E0611
     ClientMessage,
@@ -127,7 +128,16 @@ def test_integration_connection() -> None:
     def run_client() -> int:
         messages_received: int = 0
 
-        with grpc_connection(server_address=f"[::]:{port}", insecure=True) as conn:
+        with grpc_connection(
+            server_address=f"[::]:{port}",
+            insecure=True,
+            retry_invoker=RetryInvoker(
+                wait_factory=exponential,
+                recoverable_exceptions=grpc.RpcError,
+                max_tries=1,
+                max_time=None,
+            ),
+        ) as conn:
             receive, send, _, _ = conn
 
             # Setup processing loop
diff --git a/src/py/flwr/client/grpc_rere_client/connection.py b/src/py/flwr/client/grpc_rere_client/connection.py
index 00b7a864c5d..e6e22998b94 100644
--- a/src/py/flwr/client/grpc_rere_client/connection.py
+++ b/src/py/flwr/client/grpc_rere_client/connection.py
@@ -27,6 +27,7 @@
 from flwr.common.grpc import create_channel
 from flwr.common.logger import log, warn_experimental_feature
 from flwr.common.message import Message, Metadata
+from flwr.common.retry_invoker import RetryInvoker
 from flwr.common.serde import message_from_taskins, message_to_taskres
 from flwr.proto.fleet_pb2 import (  # pylint: disable=E0611
     CreateNodeRequest,
@@ -51,6 +52,7 @@ def on_channel_state_change(channel_connectivity: str) -> None:
 def grpc_request_response(
     server_address: str,
     insecure: bool,
+    retry_invoker: RetryInvoker,
     max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,  # pylint: disable=W0613
     root_certificates: Optional[Union[bytes, str]] = None,
 ) -> Iterator[
@@ -72,6 +74,13 @@
         The IPv6 address of the server with `http://` or `https://`. If the Flower
         server runs on the same machine on port 8080, then `server_address` would be
         `"http://[::]:8080"`.
+    insecure : bool
+        Starts an insecure gRPC connection when True. Enables HTTPS connection
+        when False, using system certificates if `root_certificates` is None.
+    retry_invoker : RetryInvoker
+        `RetryInvoker` object that will try to reconnect the client to the server
+        after gRPC errors. If None, the client will only try to
+        reconnect once after a failure.
     max_message_length : int
         Ignored, only present to preserve API-compatibility.
     root_certificates : Optional[Union[bytes, str]] (default: None)
@@ -113,7 +122,8 @@
     def create_node() -> None:
         """Set create_node."""
         create_node_request = CreateNodeRequest()
-        create_node_response = stub.CreateNode(
+        create_node_response = retry_invoker.invoke(
+            stub.CreateNode,
             request=create_node_request,
         )
         node_store[KEY_NODE] = create_node_response.node
@@ -127,7 +137,7 @@ def delete_node() -> None:
         node: Node = cast(Node, node_store[KEY_NODE])
 
         delete_node_request = DeleteNodeRequest(node=node)
-        stub.DeleteNode(request=delete_node_request)
+        retry_invoker.invoke(stub.DeleteNode, request=delete_node_request)
 
         del node_store[KEY_NODE]
@@ -141,7 +151,7 @@ def receive() -> Optional[Message]:
 
         # Request instructions (task) from server
         request = PullTaskInsRequest(node=node)
-        response = stub.PullTaskIns(request=request)
+        response = retry_invoker.invoke(stub.PullTaskIns, request=request)
 
         # Get the current TaskIns
         task_ins: Optional[TaskIns] = get_task_ins(response)
@@ -185,7 +195,7 @@ def send(message: Message) -> None:
 
         # Serialize ProtoBuf to bytes
         request = PushTaskResRequest(task_res_list=[task_res])
-        _ = stub.PushTaskRes(request)
+        _ = retry_invoker.invoke(stub.PushTaskRes, request)
 
         state[KEY_METADATA] = None
diff --git a/src/py/flwr/client/rest_client/connection.py b/src/py/flwr/client/rest_client/connection.py
index c637475551e..d2cc71ba3b3 100644
--- a/src/py/flwr/client/rest_client/connection.py
+++ b/src/py/flwr/client/rest_client/connection.py
@@ -27,6 +27,7 @@
 from flwr.common.constant import MISSING_EXTRA_REST
 from flwr.common.logger import log
 from flwr.common.message import Message, Metadata
+from flwr.common.retry_invoker import RetryInvoker
 from flwr.common.serde import message_from_taskins, message_to_taskres
 from flwr.proto.fleet_pb2 import (  # pylint: disable=E0611
     CreateNodeRequest,
@@ -61,6 +62,7 @@
 def http_request_response(
     server_address: str,
     insecure: bool,  # pylint: disable=unused-argument
+    retry_invoker: RetryInvoker,
     max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,  # pylint: disable=W0613
     root_certificates: Optional[
         Union[bytes, str]
@@ -84,6 +86,12 @@
         The IPv6 address of the server with `http://` or `https://`. If the Flower
         server runs on the same machine on port 8080, then `server_address` would be
         `"http://[::]:8080"`.
+    insecure : bool
+        Unused argument present for compatibility.
+    retry_invoker : RetryInvoker
+        `RetryInvoker` object that will try to reconnect the client to the server
+        after REST connection errors. If None, the client will only try to
+        reconnect once after a failure.
     max_message_length : int
         Ignored, only present to preserve API-compatibility.
     root_certificates : Optional[Union[bytes, str]] (default: None)
@@ -134,7 +142,8 @@ def create_node() -> None:
         create_node_req_proto = CreateNodeRequest()
         create_node_req_bytes: bytes = create_node_req_proto.SerializeToString()
 
-        res = requests.post(
+        res = retry_invoker.invoke(
+            requests.post,
             url=f"{base_url}/{PATH_CREATE_NODE}",
             headers={
                 "Accept": "application/protobuf",
@@ -177,7 +186,8 @@ def delete_node() -> None:
         node: Node = cast(Node, node_store[KEY_NODE])
         delete_node_req_proto = DeleteNodeRequest(node=node)
         delete_node_req_req_bytes: bytes = delete_node_req_proto.SerializeToString()
-        res = requests.post(
+        res = retry_invoker.invoke(
+            requests.post,
             url=f"{base_url}/{PATH_DELETE_NODE}",
             headers={
                 "Accept": "application/protobuf",
@@ -218,7 +228,8 @@ def receive() -> Optional[Message]:
         pull_task_ins_req_bytes: bytes = pull_task_ins_req_proto.SerializeToString()
 
         # Request instructions (task) from server
-        res = requests.post(
+        res = retry_invoker.invoke(
+            requests.post,
             url=f"{base_url}/{PATH_PULL_TASK_INS}",
             headers={
                 "Accept": "application/protobuf",
@@ -298,7 +309,8 @@ def send(message: Message) -> None:
         )
 
         # Send ClientMessage to server
-        res = requests.post(
+        res = retry_invoker.invoke(
+            requests.post,
             url=f"{base_url}/{PATH_PUSH_TASK_RES}",
             headers={
                 "Accept": "application/protobuf",
diff --git a/src/py/flwr/common/__init__.py b/src/py/flwr/common/__init__.py
index 0b2d3f17c47..9f9ff7ebc68 100644
--- a/src/py/flwr/common/__init__.py
+++ b/src/py/flwr/common/__init__.py
@@ -15,6 +15,8 @@
 """Common components shared between server and client."""
 
+from .constant import MessageType as MessageType
+from .constant import MessageTypeLegacy as MessageTypeLegacy
 from .context import Context as Context
 from .date import now as now
 from .grpc import GRPC_MAX_MESSAGE_LENGTH
@@ -83,6 +85,8 @@
     "GRPC_MAX_MESSAGE_LENGTH",
     "log",
     "Message",
+    "MessageType",
+    "MessageTypeLegacy",
     "Metadata",
     "Metrics",
     "MetricsAggregationFn",
diff --git a/src/py/flwr/common/constant.py b/src/py/flwr/common/constant.py
index 2946a594e68..d3f429586a0 100644
--- a/src/py/flwr/common/constant.py
+++ b/src/py/flwr/common/constant.py
@@ -38,10 +38,32 @@
 
 MESSAGE_TYPE_GET_PROPERTIES = "get_properties"
 MESSAGE_TYPE_GET_PARAMETERS = "get_parameters"
-MESSAGE_TYPE_FIT = "fit"
+MESSAGE_TYPE_FIT = "train"
 MESSAGE_TYPE_EVALUATE = "evaluate"
 
 
+class MessageType:
+    """Message type."""
+
+    TRAIN = "train"
+    EVALUATE = "evaluate"
+
+    def __new__(cls) -> MessageType:
+        """Prevent instantiation."""
+        raise TypeError(f"{cls.__name__} cannot be instantiated.")
+
+
+class MessageTypeLegacy:
+    """Legacy message type."""
+
+    GET_PROPERTIES = "get_properties"
+    GET_PARAMETERS = "get_parameters"
+
+    def __new__(cls) -> MessageTypeLegacy:
+        """Prevent instantiation."""
+        raise TypeError(f"{cls.__name__} cannot be instantiated.")
+
+
 class SType:
     """Serialisation type."""
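The new constants are re-exported via `flwr.common`, so user code can replace the deprecated `MESSAGE_TYPE_*` strings with the class attributes. A small example, with values taken from the class definitions above:

```python
from flwr.common import MessageType, MessageTypeLegacy

assert MessageType.TRAIN == "train"
assert MessageType.EVALUATE == "evaluate"
assert MessageTypeLegacy.GET_PROPERTIES == "get_properties"

# Both classes are namespaces only; instantiation is blocked by __new__
try:
    MessageType()
except TypeError as err:
    print(err)  # MessageType cannot be instantiated.
```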
diff --git a/src/py/flwr/server/app.py b/src/py/flwr/server/app.py
index 01b1f622212..e04cfb37e11 100644
--- a/src/py/flwr/server/app.py
+++ b/src/py/flwr/server/app.py
@@ -365,7 +365,7 @@ def run_superlink() -> None:
             client_app_attr=args.client_app,
             backend_name=args.backend,
             backend_config_json_stream=args.backend_config,
-            working_dir=args.dir,
+            app_dir=args.app_dir,
             state_factory=state_factory,
             f_stop=f_stop,
         )
@@ -441,7 +441,7 @@ def _run_fleet_api_vce(
     client_app_attr: str,
     backend_name: str,
     backend_config_json_stream: str,
-    working_dir: str,
+    app_dir: str,
     state_factory: StateFactory,
     f_stop: asyncio.Event,
 ) -> None:
@@ -453,7 +453,7 @@ def _run_fleet_api_vce(
         backend_name=backend_name,
         backend_config_json_stream=backend_config_json_stream,
         state_factory=state_factory,
-        working_dir=working_dir,
+        app_dir=app_dir,
         f_stop=f_stop,
     )
@@ -705,7 +705,7 @@ def _add_args_fleet_api(parser: argparse.ArgumentParser) -> None:
         "`flwr.common.typing.ConfigsRecordValues`. ",
     )
     parser.add_argument(
-        "--dir",
+        "--app-dir",
         default="",
         help="Add specified directory to the PYTHONPATH and load "
         "ClientApp from there."
         " Default: current working directory.",
     )
diff --git a/src/py/flwr/server/superlink/fleet/vce/vce_api.py b/src/py/flwr/server/superlink/fleet/vce/vce_api.py
index d42379960a6..225c6155774 100644
--- a/src/py/flwr/server/superlink/fleet/vce/vce_api.py
+++ b/src/py/flwr/server/superlink/fleet/vce/vce_api.py
@@ -221,7 +221,7 @@ async def run(
 def start_vce(
     backend_name: str,
     backend_config_json_stream: str,
-    working_dir: str,
+    app_dir: str,
     f_stop: asyncio.Event,
     client_app: Optional[ClientApp] = None,
     client_app_attr: Optional[str] = None,
@@ -297,7 +297,7 @@ def start_vce(
 
     def backend_fn() -> Backend:
         """Instantiate a Backend."""
-        return backend_type(backend_config, work_dir=working_dir)
+        return backend_type(backend_config, work_dir=app_dir)
 
     log(INFO, "client_app_attr = %s", client_app_attr)
diff --git a/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py b/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py
index 16cb45c1262..72e9c07ebcc 100644
--- a/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py
+++ b/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py
@@ -119,21 +119,21 @@ def register_messages_into_state(
     return expected_results
 
 
-def _autoresolve_working_dir(rel_client_app_dir: str = "backend") -> str:
-    """Correctly resolve working directory."""
+def _autoresolve_app_dir(rel_client_app_dir: str = "backend") -> str:
+    """Correctly resolve working directory for the app."""
     file_path = Path(__file__)
-    working_dir = Path.cwd()
-    rel_workdir = file_path.relative_to(working_dir)
+    app_dir = Path.cwd()
+    rel_app_dir = file_path.relative_to(app_dir)
     # Subtract the last element and append "backend/test" (where the client module is)
-    return str(rel_workdir.parent / rel_client_app_dir)
+    return str(rel_app_dir.parent / rel_client_app_dir)
 
 
 # pylint: disable=too-many-arguments
 def start_and_shutdown(
     backend: str = "ray",
     client_app_attr: str = "raybackend_test:client_app",
-    working_dir: str = "",
+    app_dir: str = "",
     num_supernodes: Optional[int] = None,
     state_factory: Optional[StateFactory] = None,
     nodes_mapping: Optional[NodeToPartitionMapping] = None,
@@ -157,8 +157,8 @@ def start_and_shutdown(
     termination_th.start()
 
     # Resolve working directory if not passed
-    if not working_dir:
-        working_dir = _autoresolve_working_dir()
+    if not app_dir:
+        app_dir = _autoresolve_app_dir()
 
     start_vce(
         num_supernodes=num_supernodes,
@@ -166,7 +166,7 @@ def start_and_shutdown(
         backend_name=backend,
         backend_config_json_stream=backend_config,
         state_factory=state_factory,
-        working_dir=working_dir,
+        app_dir=app_dir,
         f_stop=f_stop,
         existing_nodes_mapping=nodes_mapping,
     )
diff --git a/src/py/flwr/simulation/run_simulation.py b/src/py/flwr/simulation/run_simulation.py
index cb68221ea58..31884f2edc6 100644
--- a/src/py/flwr/simulation/run_simulation.py
+++ b/src/py/flwr/simulation/run_simulation.py
@@ -54,7 +54,7 @@ def run_simulation_from_cli() -> None:
         num_supernodes=args.num_supernodes,
         backend_name=args.backend,
         backend_config=backend_config_dict,
-        working_dir=args.dir,
+        app_dir=args.app_dir,
         driver_api_address=args.driver_api_address,
         enable_tf_gpu_growth=args.enable_tf_gpu_growth,
         verbose_logging=args.verbose,
@@ -125,7 +125,7 @@ def run_serverapp_th(
     server_app_attr: Optional[str],
     server_app: Optional[ServerApp],
     driver: Driver,
-    server_app_dir: str,
+    app_dir: str,
     f_stop: asyncio.Event,
     enable_tf_gpu_growth: bool,
     delay_launch: int = 3,
@@ -163,7 +163,7 @@ def server_th_with_start_checks(  # type: ignore
             "server_app_attr": server_app_attr,
             "loaded_server_app": server_app,
             "driver": driver,
-            "server_app_dir": server_app_dir,
+            "server_app_dir": app_dir,
         },
     )
     sleep(delay_launch)
@@ -177,7 +177,7 @@ def _main_loop(
     backend_name: str,
     backend_config_stream: str,
     driver_api_address: str,
-    working_dir: str,
+    app_dir: str,
     enable_tf_gpu_growth: bool,
     client_app: Optional[ClientApp] = None,
     client_app_attr: Optional[str] = None,
@@ -214,7 +214,7 @@ def _main_loop(
             server_app_attr=server_app_attr,
             server_app=server_app,
             driver=driver,
-            server_app_dir=working_dir,
+            app_dir=app_dir,
             f_stop=f_stop,
             enable_tf_gpu_growth=enable_tf_gpu_growth,
         )
@@ -227,7 +227,7 @@ def _main_loop(
             client_app=client_app,
             backend_name=backend_name,
             backend_config_json_stream=backend_config_stream,
-            working_dir=working_dir,
+            app_dir=app_dir,
             state_factory=state_factory,
             f_stop=f_stop,
         )
@@ -260,7 +260,7 @@ def _run_simulation(
     backend_config: Optional[Dict[str, ConfigsRecordValues]] = None,
     client_app_attr: Optional[str] = None,
     server_app_attr: Optional[str] = None,
-    working_dir: str = "",
+    app_dir: str = "",
     driver_api_address: str = "0.0.0.0:9091",
     enable_tf_gpu_growth: bool = False,
     verbose_logging: bool = False,
@@ -297,7 +297,7 @@ def _run_simulation(
         A path to a `ServerApp` module to be loaded: For example: `server:app` or
         `project.package.module:wrapper.app`."
 
-    working_dir : str
+    app_dir : str
         Add specified directory to the PYTHONPATH and load `ClientApp` from there.
         (Default: current working directory.)
@@ -340,7 +340,7 @@ def _run_simulation(
         backend_name,
         backend_config_stream,
         driver_api_address,
-        working_dir,
+        app_dir,
         enable_tf_gpu_growth,
         client_app,
         client_app_attr,
@@ -378,15 +378,21 @@ def _parse_args_run_simulation() -> argparse.ArgumentParser:
     parser = argparse.ArgumentParser(
         description="Start a Flower simulation",
     )
+    parser.add_argument(
+        "--server-app",
+        required=True,
+        help="For example: `server:app` or `project.package.module:wrapper.app`",
+    )
     parser.add_argument(
         "--client-app",
         required=True,
         help="For example: `client:app` or `project.package.module:wrapper.app`",
     )
     parser.add_argument(
-        "--server-app",
+        "--num-supernodes",
+        type=int,
         required=True,
-        help="For example: `server:app` or `project.package.module:wrapper.app`",
+        help="Number of simulated SuperNodes.",
     )
     parser.add_argument(
         "--driver-api-address",
         default="0.0.0.0:9091",
         type=str,
         help="Driver API server address. For example: `0.0.0.0:9091`.",
     )
-    parser.add_argument(
-        "--num-supernodes",
-        type=int,
-        required=True,
-        help="Number of simulated SuperNodes.",
-    )
     parser.add_argument(
         "--backend",
         default="ray",
         type=str,
         help="Simulation backend that executes the ClientApp.",
     )
+    parser.add_argument(
+        "--backend-config",
+        type=str,
+        default='{"client_resources": {"num_cpus":2, "num_gpus":0.0}, "tensorflow": 0}',
+        help='A JSON formatted stream, e.g \'{"<keyA>":<value>, "<keyB>":<value>}\' to '
+        "configure a backend. Values supported in <value> are those included by "
+        "`flwr.common.typing.ConfigsRecordValues`. ",
+    )
     parser.add_argument(
         "--enable-tf-gpu-growth",
         action="store_true",
@@ -417,26 +425,17 @@ def _parse_args_run_simulation() -> argparse.ArgumentParser:
         "the TensorFlow documentation: https://www.tensorflow.org/api/stable.",
     )
     parser.add_argument(
-        "--backend-config",
-        type=str,
-        default='{"client_resources": {"num_cpus":2, "num_gpus":0.0}, "tensorflow": 0}',
-        help='A JSON formatted stream, e.g \'{"<keyA>":<value>, "<keyB>":<value>}\' to '
-        "configure a backend. Values supported in <value> are those included by "
-        "`flwr.common.typing.ConfigsRecordValues`. ",
+        "--verbose",
+        action="store_true",
+        help="When unset, only INFO, WARNING and ERROR log messages will be shown. "
+        "If set, DEBUG-level logs will be displayed. ",
     )
     parser.add_argument(
-        "--dir",
+        "--app-dir",
         default="",
         help="Add specified directory to the PYTHONPATH and load "
         "ClientApp and ServerApp from there."
         " Default: current working directory.",
     )
-    parser.add_argument(
-        "--verbose",
-        action="store_true",
-        help="When unset, only INFO, WARNING and ERROR log messages will be shown. "
-        "If set, DEBUG-level logs will be displayed. ",
-    )
-
     return parser
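Since `--backend-config` expects a JSON string, it can be convenient to build it programmatically. A small sketch reproducing the default shown above; the name of the simulation entry point (`flower-simulation`) is an assumption, not established by this diff:

```python
import json

# Mirrors the CLI default for --backend-config
backend_config = json.dumps(
    {"client_resources": {"num_cpus": 2, "num_gpus": 0.0}, "tensorflow": 0}
)
print(backend_config)
# Hypothetical invocation (entry-point name assumed):
#   flower-simulation --server-app server:app --client-app client:app \
#       --num-supernodes 10 --app-dir . --backend-config '<the JSON above>'
```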