diff --git a/datasets/flwr_datasets/partitioner/__init__.py b/datasets/flwr_datasets/partitioner/__init__.py
index 73d048ddf3f..0d1edbfcb04 100644
--- a/datasets/flwr_datasets/partitioner/__init__.py
+++ b/datasets/flwr_datasets/partitioner/__init__.py
@@ -18,6 +18,7 @@
 from .dirichlet_partitioner import DirichletPartitioner
 from .exponential_partitioner import ExponentialPartitioner
 from .iid_partitioner import IidPartitioner
+from .inner_dirichlet_partitioner import InnerDirichletPartitioner
 from .linear_partitioner import LinearPartitioner
 from .natural_id_partitioner import NaturalIdPartitioner
 from .partitioner import Partitioner
@@ -32,6 +33,7 @@
     "DirichletPartitioner",
     "SizePartitioner",
     "LinearPartitioner",
+    "InnerDirichletPartitioner",
     "SquarePartitioner",
     "ShardPartitioner",
     "ExponentialPartitioner",
diff --git a/datasets/flwr_datasets/partitioner/inner_dirichlet_partitioner.py b/datasets/flwr_datasets/partitioner/inner_dirichlet_partitioner.py
new file mode 100644
index 00000000000..c25a9b059d1
--- /dev/null
+++ b/datasets/flwr_datasets/partitioner/inner_dirichlet_partitioner.py
@@ -0,0 +1,308 @@
+# Copyright 2023 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""InnerDirichlet partitioner."""
+import warnings
+from typing import Dict, List, Optional, Union
+
+import numpy as np
+
+import datasets
+from flwr_datasets.common.typing import NDArrayFloat, NDArrayInt
+from flwr_datasets.partitioner.partitioner import Partitioner
+
+
+class InnerDirichletPartitioner(Partitioner):  # pylint: disable=R0902
+    """Partitioner based on Dirichlet distribution.
+
+    Each partition is created based on the Dirichlet distribution, where the
+    probability corresponds to the fractions of samples of specific classes.
+    This process is iterative (sample-by-sample assignment), where first the
+    partition ID to which the sample will be assigned is chosen (at random, uniformly),
+    and then the class is drawn based on the Dirichlet probabilities (note that when
+    a class gets exhausted - no more samples exist to sample from - the probability of
+    sampling this class is set to zero and the remaining probabilities are renormalized).
+
+    Implementation based on: Federated Learning Based on Dynamic Regularization
+    (https://arxiv.org/abs/2111.04263).
+
+    Parameters
+    ----------
+    partition_sizes : Union[List[int], NDArrayInt]
+        The sizes of all partitions.
+    partition_by : str
+        Column name of the labels (targets) based on which Dirichlet sampling works.
+    alpha : Union[int, float, List[float], NDArrayFloat]
+        Concentration parameter to the Dirichlet distribution (a single value for
+        a symmetric Dirichlet distribution, or a list/NDArray of length equal to the
+        number of unique classes).
+    shuffle : bool
+        Whether to randomize the order of samples. Shuffling is applied after the
+        samples are assigned to nodes.
+    seed : int
+        Seed used for dataset shuffling. It has no effect if `shuffle` is False.
+
+    Examples
+    --------
+    >>> from flwr_datasets import FederatedDataset
+    >>> from flwr_datasets.partitioner import InnerDirichletPartitioner
+    >>>
+    >>> partitioner = InnerDirichletPartitioner(
+    >>>     partition_sizes=[6_000] * 10, partition_by="label", alpha=0.5
+    >>> )
+    >>> fds = FederatedDataset(dataset="mnist", partitioners={"train": partitioner})
+    >>> partition = fds.load_partition(0)
+    >>> print(partition[0])  # Print the first example
+    """
+
+    def __init__(  # pylint: disable=R0913
+        self,
+        partition_sizes: Union[List[int], NDArrayInt],
+        partition_by: str,
+        alpha: Union[int, float, List[float], NDArrayFloat],
+        shuffle: bool = True,
+        seed: Optional[int] = 42,
+    ) -> None:
+        super().__init__()
+        # Attributes based on the constructor
+        self._partition_sizes = _instantiate_partition_sizes(partition_sizes)
+        self._initial_alpha = alpha
+        self._alpha: Optional[NDArrayFloat] = None
+        self._partition_by = partition_by
+        self._shuffle = shuffle
+        self._seed = seed
+
+        # Utility attributes
+        self._initialized_alpha = False
+        self._rng = np.random.default_rng(seed=self._seed)  # NumPy random generator
+        # The attributes below are determined during the first call to load_partition
+        self._unique_classes: Optional[Union[List[int], List[str]]] = None
+        self._num_unique_classes: Optional[int] = None
+        self._num_partitions = len(self._partition_sizes)
+
+        # self._avg_num_of_samples_per_node: Optional[float] = None
+        self._node_id_to_indices: Dict[int, List[int]] = {}
+        self._node_id_to_indices_determined = False
+
+    def load_partition(self, node_id: int) -> datasets.Dataset:
+        """Load a partition based on the partition index.
+
+        Parameters
+        ----------
+        node_id : int
+            The index that corresponds to the requested partition.
+
+        Returns
+        -------
+        dataset_partition : Dataset
+            A single partition of a dataset.
+        """
+        # The partitioning is done lazily - only when the first partition is
+        # requested. Only the first call creates the indices assignments for all the
+        # partition indices.
+        self._check_num_partitions_correctness_if_needed()
+        self._check_partition_sizes_correctness_if_needed()
+        self._check_the_sum_of_partition_sizes()
+        self._determine_num_unique_classes_if_needed()
+        self._alpha = self._initialize_alpha_if_needed(self._initial_alpha)
+        self._determine_node_id_to_indices_if_needed()
+        return self.dataset.select(self._node_id_to_indices[node_id])
+
+    def _initialize_alpha_if_needed(
+        self, alpha: Union[int, float, List[float], NDArrayFloat]
+    ) -> NDArrayFloat:
+        """Convert alpha to the format used in the code (an NDArrayFloat).
+
+        The alpha provided in the constructor can be given in different formats for
+        user convenience. The format into which it is transformed here is used
+        throughout the code for computation.
+
+        Parameters
+        ----------
+        alpha : Union[int, float, List[float], NDArrayFloat]
+            Concentration parameter to the Dirichlet distribution
+
+        Returns
+        -------
+        alpha : NDArrayFloat
+            Concentration parameter in a format ready to be used in computation.
+ """ + if self._initialized_alpha: + assert self._alpha is not None + return self._alpha + if isinstance(alpha, int): + assert self._num_unique_classes is not None + alpha = np.array([float(alpha)], dtype=float).repeat( + self._num_unique_classes + ) + elif isinstance(alpha, float): + assert self._num_unique_classes is not None + alpha = np.array([alpha], dtype=float).repeat(self._num_unique_classes) + elif isinstance(alpha, List): + if len(alpha) != self._num_unique_classes: + raise ValueError( + "When passing alpha as a List, its length needs needs to be " + "of length equal to the number of unique classes." + ) + alpha = np.asarray(alpha) + elif isinstance(alpha, np.ndarray): + # pylint: disable=R1720 + if alpha.ndim == 1 and alpha.shape[0] != self._num_unique_classes: + raise ValueError( + "When passing alpha as an NDArray, its length needs needs to be " + "of length equal to the number of unique classes." + ) + elif alpha.ndim == 2: + alpha = alpha.flatten() + if alpha.shape[0] != self._num_unique_classes: + raise ValueError( + "When passing alpha as an NDArray, its length needs needs to be" + " of length equal to the number of unique classes." + ) + else: + raise ValueError("The given alpha format is not supported.") + if not (alpha > 0).all(): + raise ValueError( + f"Alpha values should be strictly greater than zero. " + f"Instead it'd be converted to {alpha}" + ) + return alpha + + def _determine_node_id_to_indices_if_needed(self) -> None: # pylint: disable=R0914 + """Create an assignment of indices to the partition indices.""" + if self._node_id_to_indices_determined: + return + + # Create class priors for the whole partitioning process + assert self._alpha is not None + class_priors = self._rng.dirichlet(alpha=self._alpha, size=self._num_partitions) + targets = np.asarray(self.dataset[self._partition_by]) + # List representing indices of each class + assert self._num_unique_classes is not None + idx_list = [np.where(targets == i)[0] for i in range(self._num_unique_classes)] + class_sizes = [len(idx_list[i]) for i in range(self._num_unique_classes)] + + client_indices = [ + np.zeros(self._partition_sizes[cid]).astype(np.int64) + for cid in range(self._num_partitions) + ] + + # Node id to number of sample left for allocation for that node id + node_id_to_left_to_allocate = dict( + zip(range(self._num_partitions), self._partition_sizes) + ) + + not_full_node_ids = list(range(self._num_partitions)) + while np.sum(list(node_id_to_left_to_allocate.values())) != 0: + # Choose a node + current_node_id = self._rng.choice(not_full_node_ids) + # If current node is full resample a client + if node_id_to_left_to_allocate[current_node_id] == 0: + # When the node is full, exclude it from the sampling nodes list + not_full_node_ids.pop(not_full_node_ids.index(current_node_id)) + continue + node_id_to_left_to_allocate[current_node_id] -= 1 + # Access the label distribution of the chosen client + current_probabilities = class_priors[current_node_id] + while True: + # curr_class = np.argmax(np.random.uniform() <= curr_prior) + curr_class = self._rng.choice( + list(range(self._num_unique_classes)), p=current_probabilities + ) + # Redraw class label if there are no samples left to be allocated from + # that class + if class_sizes[curr_class] == 0: + # Class got exhausted, set probabilities to 0 + class_priors[:, curr_class] = 0 + # Renormalize such that the probability sums to 1 + row_sums = class_priors.sum(axis=1, keepdims=True) + class_priors = class_priors / row_sums + # Adjust the 
+                    current_probabilities = class_priors[current_node_id]
+                    continue
+                class_sizes[curr_class] -= 1
+                # Store the sample index at the empty array cell
+                index = node_id_to_left_to_allocate[current_node_id]
+                client_indices[current_node_id][index] = idx_list[curr_class][
+                    class_sizes[curr_class]
+                ]
+                break
+
+        node_id_to_indices = {
+            cid: client_indices[cid].tolist() for cid in range(self._num_partitions)
+        }
+        # Shuffle the indices if shuffle is True.
+        # Note that the partitioning itself does not require shuffling, but without it
+        # the order within a partition exhibits runs of consecutive samples per class.
+        if self._shuffle:
+            for indices in node_id_to_indices.values():
+                # In-place shuffling
+                self._rng.shuffle(indices)
+        self._node_id_to_indices = node_id_to_indices
+        self._node_id_to_indices_determined = True
+
+    def _check_num_partitions_correctness_if_needed(self) -> None:
+        """Test num_partitions when the dataset is given (in load_partition)."""
+        if not self._node_id_to_indices_determined:
+            if self._num_partitions > self.dataset.num_rows:
+                raise ValueError(
+                    "The number of partitions needs to be smaller than or equal to "
+                    "the number of samples in the dataset."
+                )
+
+    def _check_partition_sizes_correctness_if_needed(self) -> None:
+        """Test partition_sizes when the dataset is given (in load_partition)."""
+        if not self._node_id_to_indices_determined:
+            if sum(self._partition_sizes) > self.dataset.num_rows:
+                raise ValueError(
+                    "The sum of the `partition_sizes` needs to be smaller than or "
+                    "equal to the number of samples in the dataset."
+                )
+
+    def _check_num_partitions_greater_than_zero(self) -> None:
+        """Test that the number of partitions is greater than zero."""
+        if not self._num_partitions > 0:
+            raise ValueError("The number of partitions needs to be greater than zero.")
+
+    def _determine_num_unique_classes_if_needed(self) -> None:
+        """Determine the number of unique classes in the dataset."""
+        self._unique_classes = self.dataset.unique(self._partition_by)
+        assert self._unique_classes is not None
+        self._num_unique_classes = len(self._unique_classes)
+
+    def _check_the_sum_of_partition_sizes(self) -> None:
+        """Warn if the sum of the partition sizes does not equal the dataset size."""
+        if np.sum(self._partition_sizes) != len(self.dataset):
+            warnings.warn(
+                "The sum of the partition_sizes does not equal the size of the whole "
+                "dataset. Make sure that is the desired behavior.",
+                stacklevel=1,
+            )
+
+
+def _instantiate_partition_sizes(
+    partition_sizes: Union[List[int], NDArrayInt]
+) -> NDArrayInt:
+    """Transform the list of sizes into an ndarray of ints if needed."""
+    if isinstance(partition_sizes, List):
+        partition_sizes = np.asarray(partition_sizes)
+    elif isinstance(partition_sizes, np.ndarray):
+        pass
+    else:
+        raise ValueError(
+            f"The type of partition_sizes is incorrect. Given: "
+            f"{type(partition_sizes)}"
+        )
+
+    if not all(partition_sizes >= 0):
+        raise ValueError("The partition sizes must be greater than or equal to zero.")
+    return partition_sizes
diff --git a/datasets/flwr_datasets/partitioner/inner_dirichlet_partitioner_test.py b/datasets/flwr_datasets/partitioner/inner_dirichlet_partitioner_test.py
new file mode 100644
index 00000000000..0c5fb502870
--- /dev/null
+++ b/datasets/flwr_datasets/partitioner/inner_dirichlet_partitioner_test.py
@@ -0,0 +1,106 @@
+# Copyright 2023 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Test InnerDirichletPartitioner."""
+# pylint: disable=W0212
+import unittest
+from typing import List, Tuple, Union
+
+from datasets import Dataset
+from flwr_datasets.common.typing import NDArrayFloat, NDArrayInt
+from flwr_datasets.partitioner.inner_dirichlet_partitioner import (
+    InnerDirichletPartitioner,
+)
+
+
+def _dummy_setup(
+    num_rows: int,
+    partition_by: str,
+    partition_sizes: Union[List[int], NDArrayInt],
+    alpha: Union[float, List[float], NDArrayFloat],
+) -> Tuple[Dataset, InnerDirichletPartitioner]:
+    """Create a dummy dataset and partitioner for testing."""
+    data = {
+        partition_by: [i % 3 for i in range(num_rows)],
+        "features": list(range(num_rows)),
+    }
+    dataset = Dataset.from_dict(data)
+    partitioner = InnerDirichletPartitioner(
+        partition_sizes=partition_sizes,
+        alpha=alpha,
+        partition_by=partition_by,
+    )
+    partitioner.dataset = dataset
+    return dataset, partitioner
+
+
+class TestInnerDirichletPartitionerSuccess(unittest.TestCase):
+    """Test InnerDirichletPartitioner used with no exceptions."""
+
+    def test_correct_num_of_partitions(self) -> None:
+        """Test correct number of partitions."""
+        num_rows = 113
+        partition_by = "labels"
+        alpha = 1.0
+        partition_sizes = [20, 20, 30, 43]
+
+        _, partitioner = _dummy_setup(num_rows, partition_by, partition_sizes, alpha)
+        _ = partitioner.load_partition(0)
+        self.assertEqual(
+            len(partitioner._node_id_to_indices.keys()), len(partition_sizes)
+        )
+
+    def test_correct_partition_sizes(self) -> None:
+        """Test correct partition sizes."""
+        num_rows = 113
+        partition_by = "labels"
+        alpha = 1.0
+        partition_sizes = [20, 20, 30, 43]
+
+        _, partitioner = _dummy_setup(num_rows, partition_by, partition_sizes, alpha)
+        _ = partitioner.load_partition(0)
+        sizes_created = [
+            len(indices) for indices in partitioner._node_id_to_indices.values()
+        ]
+        self.assertEqual(sorted(sizes_created), partition_sizes)
+
+
+class TestInnerDirichletPartitionerFailure(unittest.TestCase):
+    """Test InnerDirichletPartitioner failures (exceptions) by incorrect usage."""
+
+    def test_incorrect_shape_of_alpha(self) -> None:
+        """Test the alpha shape not equal to the number of unique classes."""
+        num_rows = 113
+        partition_by = "labels"
+        alpha = [1.0, 1.0]
+        partition_sizes = [20, 20, 30, 43]
+
+        _, partitioner = _dummy_setup(num_rows, partition_by, partition_sizes, alpha)
+        with self.assertRaises(ValueError):
+            _ = partitioner.load_partition(0)
+
+    def test_too_big_sum_of_partition_sizes(self) -> None:
+        """Test sum of partition_sizes greater than the size of the dataset."""
+        num_rows = 113
+        partition_by = "labels"
+        alpha = 1.0
+        partition_sizes = [60, 60, 30, 43]
+
+        _, partitioner = _dummy_setup(num_rows, partition_by, partition_sizes, alpha)
+        with self.assertRaises(ValueError):
+            _ = partitioner.load_partition(0)
+
+
+if __name__ == "__main__":
+    unittest.main()
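Beyond the unit tests above, a quick end-to-end sanity check is to look at the per-partition label counts, which makes the Dirichlet-induced label skew visible. The sketch below is illustrative only and not part of the diff; it is adapted from the class docstring example and assumes the MNIST dataset with its "label" column, as used there.

from collections import Counter

from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import InnerDirichletPartitioner

# Ten partitions of 6,000 samples each; a smaller alpha gives a stronger label skew.
partitioner = InnerDirichletPartitioner(
    partition_sizes=[6_000] * 10, partition_by="label", alpha=0.5
)
fds = FederatedDataset(dataset="mnist", partitioners={"train": partitioner})

# Count the labels in the first few partitions to inspect the induced heterogeneity.
for node_id in range(3):
    partition = fds.load_partition(node_id)
    counts = Counter(partition["label"])
    print(f"Partition {node_id}: {dict(sorted(counts.items()))}")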