From 9a2c13b0af21d1a0eb24f164b5a0dc12c6b4ff33 Mon Sep 17 00:00:00 2001 From: Adam Narozniak Date: Mon, 15 Jan 2024 15:59:21 +0100 Subject: [PATCH 01/13] Add dirichlet partitioner --- .../partitioner/dirichlet_partitioner.py | 264 ++++++++++++++++++ .../partitioner/dirichlet_partitioner_test.py | 168 +++++++++++ 2 files changed, 432 insertions(+) create mode 100644 datasets/flwr_datasets/partitioner/dirichlet_partitioner.py create mode 100644 datasets/flwr_datasets/partitioner/dirichlet_partitioner_test.py diff --git a/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py b/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py new file mode 100644 index 00000000000..de53b0928af --- /dev/null +++ b/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py @@ -0,0 +1,264 @@ +# Copyright 2023 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Dirichlet partitioner class that works with Hugging Face Datasets.""" +# pylint: disable=R0912 +from typing import Dict, List, Optional, Union + +import numpy as np + +import datasets +from flwr_datasets.common.typing import NDArrayFloat +from flwr_datasets.partitioner.partitioner import Partitioner + + +class DirichletPartitioner(Partitioner): # pylint: disable=R0902 + """Partitioner based on Dirichlet distribution. 
+ + The balancing (not mentioned in paper but implemented in the code) is controlled by + `self_balancing` parameter. + + Implementation based on Bayesian Nonparametric Federated Learning of Neural Networks + https://arxiv.org/abs/1905.12022 + + Parameters + ---------- + num_partitions : int + The total number of partitions that the data will be divided into. + alpha : Union[float, List[float], NDArrayFloat] + Concentration parameter to the Dirichlet distribution + partition_by : str + Column name of the labels (targets) based on which Dirichlet sampling works. + min_partition_size : int + The minimum number of samples that each partitions will have (the sampling + process is repeated if any partition is too small). + self_balancing : bool + Weather assign further samples to a partition after the number of samples + exceeded the average number of samples per partition. (True in the original + paper's code although not mentioned in paper itself). + shuffle: bool + Whether to randomize the order of samples. Shuffling applied after the + samples assignment to nodes. + seed: int + Seed used for dataset shuffling. It has no effect if `shuffle` is False. 
+ """ + + def __init__( # pylint: disable=R0913 + self, + num_partitions: int, + alpha: Union[float, List[float], NDArrayFloat], + partition_by: str, + min_partition_size: Optional[int] = None, + self_balancing: bool = True, + shuffle: bool = True, + seed: Optional[int] = 42, + ) -> None: + super().__init__() + # Attributes based on the constructor + self._num_partitions = num_partitions + self._check_num_partitions_greater_than_zero() + self._alpha: NDArrayFloat = self._initialize_alpha(alpha) + self._partition_by = partition_by + if min_partition_size is None: + # Note that zero might make problems with the training + min_partition_size = 0 + self._min_partition_size: int = min_partition_size + self._self_balancing = self_balancing + self._shuffle = shuffle + self._seed = seed + self._rng = np.random.default_rng(seed=self._seed) # NumPy random generator + + # Utility attributes + # The attributes below are determined during the first call to load_partition + self._num_unique_classes: Optional[int] = None + self._avg_num_of_samples_per_node: Optional[float] = None + self._unique_classes: Optional[Union[List[int], List[str]]] = None + self._node_id_to_indices: Dict[int, List[int]] = {} + self._node_id_to_indices_determined = False + + def load_partition(self, node_id: int) -> datasets.Dataset: + """Load a partition based on the partition index. + + Parameters + ---------- + node_id : int + the index that corresponds to the requested partition + + Returns + ------- + dataset_partition : Dataset + single partition of a dataset + """ + # The partitioning is done lazily - only when the first partition is + # requested. Only the first call creates the indices assignments for all the + # partition indices. 
+ self._check_num_partitions_correctness_if_needed() + self._determine_node_id_to_indices_if_needed() + return self.dataset.select(self._node_id_to_indices[node_id]) + + def _initialize_alpha( + self, alpha: Union[float, List[float], NDArrayFloat] + ) -> NDArrayFloat: + """Convert alpha to the used format in the code a NDArrayFloat. + + The alpha can be provided in constructor can be in different format for user + convenience. The format into which it's transformed here is used throughout the + code for computation. + + Parameters + ---------- + alpha : Union[float, List[float], NDArrayFloat] + Concentration parameter to the Dirichlet distribution + + Returns + ------- + alpha : NDArrayFloat + Concentration parameter in a format ready to used in computation. + """ + if isinstance(alpha, float): + alpha = np.array([alpha], dtype=float).repeat(self._num_partitions) + elif isinstance(alpha, List): + if len(alpha) != self._num_partitions: + raise ValueError( + "The alpha parameter needs to be of length of equal to the " + "num_partitions." + ) + alpha = np.asarray(alpha) + elif isinstance(alpha, np.ndarray): + # pylint: disable=R1720 + if alpha.ndim == 1 and alpha.shape[0] != self._num_partitions: + raise ValueError( + "The alpha parameter needs to be of length of equal to" + "the num_partitions." + ) + elif alpha.ndim == 2: + alpha = alpha.flatten() + if alpha.shape[0] != self._num_partitions: + raise ValueError( + "The alpha parameter needs to be of length of equal to " + "the num_partitions." + ) + else: + raise ValueError("The given alpha format is not supported.") + if not (alpha > 0).all(): + raise ValueError( + f"Alpha values should be strictly greater than zero. 
" + f"Instead it'd be converted to {alpha}" + ) + return alpha + + def _determine_node_id_to_indices_if_needed(self) -> None: # pylint: disable=R0914 + """Create an assignment of indices to the partition indices.""" + if self._node_id_to_indices_determined: + return + + # Generate information needed for Dirichlet partitioning + self._unique_classes = self.dataset.unique(self._partition_by) + assert self._unique_classes is not None + self._num_unique_classes = len(self._unique_classes) + # This is needed only if self._self_balancing is True (the default option) + self._avg_num_of_samples_per_node = self.dataset.num_rows / self._num_partitions + + # Change targets list data type to numpy + targets = np.array(self.dataset[self._partition_by]) + + # Repeat the sampling procedure based on the Dirichlet distribution until the + # min_partition_size is reached. + while True: + # Prepare data structure to store indices assigned to node ids + node_id_to_indices: Dict[int, List[int]] = {} + for nid in range(self._num_partitions): + node_id_to_indices[nid] = [] + + # Iterated over all unique labels (they are not necessarily of type int) + for k in self._unique_classes: + # Access all the indices associated with class k + indices_representing_class_k = np.nonzero(targets == k)[0] + # Determine division (the fractions) of the data representing class k + # among the partitions + class_k_division_proportions = self._rng.dirichlet(self._alpha) + nid_to_proportion_of_k_samples = {} + for nid in range(self._num_partitions): + nid_to_proportion_of_k_samples[nid] = class_k_division_proportions[ + nid + ] + # Balancing (not mentioned in the paper but implemented) + # Do not assign additional samples to the node if it already has more + # than the average numbers of samples per partition. Note that it might + # especially affect classes that are later in the order. This is the + # reason for more sparse division that the alpha might suggest. 
+ if self._self_balancing: + assert self._avg_num_of_samples_per_node is not None + for nid in nid_to_proportion_of_k_samples.copy(): + if ( + len(node_id_to_indices[nid]) + > self._avg_num_of_samples_per_node + ): + nid_to_proportion_of_k_samples[nid] = 0 + + # Normalize the proportions such that they sum up to 1 + sum_proportions = sum(nid_to_proportion_of_k_samples.values()) + for nid, prop in nid_to_proportion_of_k_samples.copy().items(): + nid_to_proportion_of_k_samples[nid] = prop / sum_proportions + + # Determine the split indices + cumsum_division_fractions = np.cumsum( + list(nid_to_proportion_of_k_samples.values()) + ) + cumsum_division_numbers = cumsum_division_fractions * len( + indices_representing_class_k + ) + # [:-1] is because the np.split requires the division indices but the + # last element represents the sum = total number of samples + indices_on_which_split = cumsum_division_numbers.astype(int)[:-1] + + split_indices = np.split( + indices_representing_class_k, indices_on_which_split + ) + + # Append new indices (coming from class k) to the existing indices + for nid, indices in node_id_to_indices.items(): + indices.extend(split_indices[nid].tolist()) + + # Determine if the indices assignment meets the min_partition_size + # If it does not meet the requirement, repeat the Dirichlet sampling process + # Otherwise break the while loop + min_sample_size_on_client = min( + len(indices) for indices in node_id_to_indices.values() + ) + if min_sample_size_on_client >= self._min_partition_size: + break + + # Shuffle the indices not to have the datasets with targets in sequences like + # [00000, 11111, ...]) if the shuffle is True + if self._shuffle: + for indices in node_id_to_indices.values(): + # In place shuffling + self._rng.shuffle(indices) + self._node_id_to_indices = node_id_to_indices + self._node_id_to_indices_determined = True + + def _check_num_partitions_correctness_if_needed(self) -> None: + """Test num_partitions when the dataset is given
(in load_partition).""" + if not self._node_id_to_indices_determined: + if self._num_partitions > self.dataset.num_rows: + raise ValueError( + "The number of partitions needs to be smaller than the number of " + "samples in the dataset." + ) + + def _check_num_partitions_greater_than_zero(self) -> None: + """Test num_partition left sides correctness.""" + if not self._num_partitions > 0: + raise ValueError("The number of partitions needs to be greater than zero.") diff --git a/datasets/flwr_datasets/partitioner/dirichlet_partitioner_test.py b/datasets/flwr_datasets/partitioner/dirichlet_partitioner_test.py new file mode 100644 index 00000000000..d8cfb3cb854 --- /dev/null +++ b/datasets/flwr_datasets/partitioner/dirichlet_partitioner_test.py @@ -0,0 +1,168 @@ +# Copyright 2023 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Test DirichletPartitioner.""" +# pylint: disable=W0212 +import unittest +from typing import Tuple, Union + +import numpy as np +from numpy.typing import NDArray +from parameterized import parameterized + +from datasets import Dataset +from flwr_datasets.partitioner.dirichlet_partitioner import DirichletPartitioner + + +def _dummy_setup( + num_partitions: int, + alpha: Union[float, NDArray[np.float_]], + num_rows: int, + partition_by: str, + self_balancing: bool = True, +) -> Tuple[Dataset, DirichletPartitioner]: + """Create a dummy dataset and partitioner for testing.""" + data = { + partition_by: [i % 3 for i in range(num_rows)], + "features": list(range(num_rows)), + } + dataset = Dataset.from_dict(data) + partitioner = DirichletPartitioner( + num_partitions=num_partitions, + alpha=alpha, + partition_by=partition_by, + self_balancing=self_balancing, + ) + partitioner.dataset = dataset + return dataset, partitioner + + +class TestDirichletPartitionerSuccess(unittest.TestCase): + """Test DirichletPartitioner used with no exceptions.""" + + @parameterized.expand( # type: ignore + [ + # num_partitions, alpha, num_rows, partition_by + (3, 0.5, 100, "labels"), + (5, 1.0, 150, "labels"), + ] + ) + def test_valid_initialization( + self, num_partitions: int, alpha: float, num_rows: int, partition_by: str + ) -> None: + """Test if alpha is correctly scaled based on the given num_partitions.""" + _, partitioner = _dummy_setup(num_partitions, alpha, num_rows, partition_by) + self.assertEqual( + ( + partitioner._num_partitions, + len(partitioner._alpha), + partitioner._partition_by, + ), + (num_partitions, num_partitions, partition_by), + ) + + def test_min_partition_size_requirement(self) -> None: + """Test if partitions are created with min partition size required.""" + _, partitioner = _dummy_setup(3, 0.5, 100, "labels") + partition_list = [partitioner.load_partition(node_id) for node_id in
[0, 1, 2]] + self.assertTrue( + all(len(p) > partitioner._min_partition_size for p in partition_list) + ) + + def test_alpha_in_ndarray_initialization(self) -> None: + """Test alpha does not change when in NDArrayFloat format.""" + _, partitioner = _dummy_setup(3, np.array([1.0, 1.0, 1.0]), 100, "labels") + self.assertTrue(np.all(partitioner._alpha == np.array([1.0, 1.0, 1.0]))) + + def test__determine_node_id_to_indices(self) -> None: + """Test the determine_node_id_to_indices matches the flag after the call.""" + num_partitions, alpha, num_rows, partition_by = 3, 0.5, 100, "labels" + _, partitioner = _dummy_setup(num_partitions, alpha, num_rows, partition_by) + partitioner._determine_node_id_to_indices_if_needed() + self.assertTrue( + partitioner._node_id_to_indices_determined + and len(partitioner._node_id_to_indices) == num_partitions + ) + + +class TestDirichletPartitionerFailure(unittest.TestCase): + """Test DirichletPartitioner failures (exceptions) by incorrect usage.""" + + @parameterized.expand([(-2,), (-1,), (3,), (4,), (100,)]) # type: ignore + def test_load_invalid_partition_index(self, partition_id): + """Test if raises when the load_partition is above the num_partitions.""" + _, partitioner = _dummy_setup(3, 0.5, 100, "labels") + with self.assertRaises(KeyError): + partitioner.load_partition(partition_id) + + @parameterized.expand( # type: ignore + [ + # alpha, num_partitions + (-0.5, 1), + (-0.5, 2), + (-0.5, 3), + (-0.5, 10), + ([0.5, 0.5, -0.5], 3), + ([-0.5, 0.5, -0.5], 3), + ([-0.5, 0.5, 0.5], 3), + ([-0.5, -0.5, -0.5], 3), + ([0.5, 0.5, -0.5, -0.5, 0.5], 5), + (np.array([0.5, 0.5, -0.5]), 3), + (np.array([-0.5, 0.5, -0.5]), 3), + (np.array([-0.5, 0.5, 0.5]), 3), + (np.array([-0.5, -0.5, -0.5]), 3), + (np.array([0.5, 0.5, -0.5, -0.5, 0.5]), 5), + ] + ) + def test_negative_values_in_alpha(self, alpha, num_partitions): + """Test if giving the negative value of alpha raises error.""" + num_rows, partition_by = 100, "labels" + with
self.assertRaises(ValueError): + _, _ = _dummy_setup(num_partitions, alpha, num_rows, partition_by) + + @parameterized.expand( # type: ignore + [ + # alpha, num_partitions + # alpha greater than the num_partitions + ([0.5, 0.5], 1), + ([0.5, 0.5, 0.5], 2), + (np.array([0.5, 0.5]), 1), + (np.array([0.5, 0.5, 0.5]), 2), + (np.array([0.5, 0.5, 0.5, 0.5]), 3), + ] + ) + def test_incorrect_alpha_shape(self, alpha, num_partitions): + """Test alpha list len not matching the num_partitions.""" + with self.assertRaises(ValueError): + DirichletPartitioner( + num_partitions=num_partitions, alpha=alpha, partition_by="labels" + ) + + @parameterized.expand( # type: ignore + [(0,), (-1,), (11,), (100,)] + ) # num_partitions, + def test_invalid_num_partitions(self, num_partitions): + """Test if 0 is invalid num_partitions.""" + with self.assertRaises(ValueError): + _, partitioner = _dummy_setup( + num_partitions=num_partitions, + alpha=1.0, + num_rows=10, + partition_by="labels", + ) + partitioner.load_partition(0) + + +if __name__ == "__main__": + unittest.main() From 36a6a503761db43767a466c14e18f56a310744e7 Mon Sep 17 00:00:00 2001 From: Adam Narozniak Date: Tue, 16 Jan 2024 11:13:25 +0100 Subject: [PATCH 02/13] Make DirichletPartitioner available on the package level --- datasets/flwr_datasets/partitioner/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/datasets/flwr_datasets/partitioner/__init__.py b/datasets/flwr_datasets/partitioner/__init__.py index 5e7c86718f6..6a85f8a1174 100644 --- a/datasets/flwr_datasets/partitioner/__init__.py +++ b/datasets/flwr_datasets/partitioner/__init__.py @@ -15,6 +15,7 @@ """Flower Datasets Partitioner package.""" +from .dirichlet_partitioner import DirichletPartitioner from .exponential_partitioner import ExponentialPartitioner from .iid_partitioner import IidPartitioner from .linear_partitioner import LinearPartitioner @@ -27,6 +28,7 @@ "IidPartitioner", "Partitioner", "NaturalIdPartitioner", + "DirichletPartitioner",
"SizePartitioner", "LinearPartitioner", "SquarePartitioner", From 2f31ef288e6f4fab975375b0df2e03de2cd16474 Mon Sep 17 00:00:00 2001 From: Adam Narozniak Date: Tue, 16 Jan 2024 12:49:26 +0100 Subject: [PATCH 03/13] Improve documentation --- .../partitioner/dirichlet_partitioner.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py b/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py index de53b0928af..ab118ba27a8 100644 --- a/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py +++ b/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py @@ -26,11 +26,20 @@ class DirichletPartitioner(Partitioner): # pylint: disable=R0902 """Partitioner based on Dirichlet distribution. - The balancing (not mentioned in paper but implemented in the code) is controlled by - `self_balancing` parameter. - Implementation based on Bayesian Nonparametric Federated Learning of Neural Networks - https://arxiv.org/abs/1905.12022 + https://arxiv.org/abs/1905.12022. + + The algorithm sequentially divides the data with each label. The fractions of the + data with each label are drawn from Dirichlet distribution and adjusted in case of + balancing. The data is assigned. In case the `min_partition_size` is not satisfied + the algorithm is run again (the fractions will change since it is a random process + even though the alpha stays the same). + + The notion of balancing is explicitly introduced here (not mentioned in paper but + implemented in the code). It is a mechanism that excludes the node from + assigning new samples to it if the current number of samples on that node exceeds + the average number that the node would get in case of even data distribution. + It is controlled by the `self_balancing` parameter.
Parameters ---------- From bd49e8b9e2b0197e5792d4591c7008c968929974 Mon Sep 17 00:00:00 2001 From: Adam Narozniak Date: Tue, 16 Jan 2024 17:21:19 +0100 Subject: [PATCH 04/13] Add example --- .../partitioner/dirichlet_partitioner.py | 28 ++++++++++++++----- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py b/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py index ab118ba27a8..ba05c3af251 100644 --- a/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py +++ b/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py @@ -45,10 +45,10 @@ class DirichletPartitioner(Partitioner): # pylint: disable=R0902 ---------- num_partitions : int The total number of partitions that the data will be divided into. - alpha : Union[float, List[float], NDArrayFloat] - Concentration parameter to the Dirichlet distribution partition_by : str Column name of the labels (targets) based on which Dirichlet sampling works. + alpha : Union[float, List[float], NDArrayFloat] + Concentration parameter to the Dirichlet distribution min_partition_size : int The minimum number of samples that each partitions will have (the sampling process is repeated if any partition is too small). @@ -61,14 +61,31 @@ class DirichletPartitioner(Partitioner): # pylint: disable=R0902 samples assignment to nodes. seed: int Seed used for dataset shuffling. It has no effect if `shuffle` is False. 
+ + Examples + -------- + >>> from flwr_datasets import FederatedDataset + >>> from flwr_datasets.partitioner import DirichletPartitioner + >>> + >>> partitioner = DirichletPartitioner(num_partitions=10, partition_by="label", + >>> alpha=0.5, min_partition_size=10, + >>> self_balancing=True) + >>> fds = FederatedDataset(dataset="mnist", partitioners={"train": partitioner}) + >>> partition = fds.load_partition(0) + >>> print(partition[0]) # Print the first example + {'image': <PIL.PngImagePlugin.PngImageFile image mode=L size=28x28 at 0x127B92170>, + 'label': 4} + >>> partition_sizes = [len(fds.load_partition(node_id)) for node_id in range(10)] + >>> print(sorted(partition_sizes)) + [2134, 2615, 3646, 6011, 6170, 6386, 6715, 7653, 8435, 10235] """ def __init__( # pylint: disable=R0913 self, num_partitions: int, - alpha: Union[float, List[float], NDArrayFloat], partition_by: str, - min_partition_size: Optional[int] = None, + alpha: Union[float, List[float], NDArrayFloat], + min_partition_size: int = 10, self_balancing: bool = True, shuffle: bool = True, seed: Optional[int] = 42, @@ -79,9 +96,6 @@ def __init__( # pylint: disable=R0913 self._check_num_partitions_greater_than_zero() self._alpha: NDArrayFloat = self._initialize_alpha(alpha) self._partition_by = partition_by - if min_partition_size is None: - # Note that zero might make problems with the training - min_partition_size = 0 self._min_partition_size: int = min_partition_size self._self_balancing = self_balancing self._shuffle = shuffle From 54b4cc8ad0e6b91983ef8b3d5267c5039f9b2924 Mon Sep 17 00:00:00 2001 From: Adam Narozniak <51029327+adam-narozniak@users.noreply.github.com> Date: Thu, 18 Jan 2024 10:05:06 +0100 Subject: [PATCH 05/13] Apply suggestions from code review Co-authored-by: Javier --- datasets/flwr_datasets/partitioner/dirichlet_partitioner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py b/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py index ba05c3af251..f7192a53e08
100644 --- a/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py +++ b/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py @@ -1,4 +1,4 @@ -# Copyright 2023 Flower Labs GmbH. All Rights Reserved. +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -53,7 +53,7 @@ class DirichletPartitioner(Partitioner): # pylint: disable=R0902 The minimum number of samples that each partitions will have (the sampling process is repeated if any partition is too small). self_balancing : bool - Weather assign further samples to a partition after the number of samples + Whether assign further samples to a partition after the number of samples exceeded the average number of samples per partition. (True in the original paper's code although not mentioned in paper itself). shuffle: bool From 726480f524c33d8394918cb7c463c8dc49af2c9e Mon Sep 17 00:00:00 2001 From: Adam Narozniak Date: Thu, 25 Jan 2024 09:56:58 +0100 Subject: [PATCH 06/13] Enable alpha as int --- .../partitioner/dirichlet_partitioner.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py b/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py index f7192a53e08..75f5f03e7b2 100644 --- a/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py +++ b/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py @@ -47,7 +47,7 @@ class DirichletPartitioner(Partitioner): # pylint: disable=R0902 The total number of partitions that the data will be divided into. partition_by : str Column name of the labels (targets) based on which Dirichlet sampling works. 
- alpha : Union[float, List[float], NDArrayFloat] + alpha : Union[int, float, List[float], NDArrayFloat] Concentration parameter to the Dirichlet distribution min_partition_size : int The minimum number of samples that each partitions will have (the sampling @@ -84,7 +84,7 @@ def __init__( # pylint: disable=R0913 self, num_partitions: int, partition_by: str, - alpha: Union[float, List[float], NDArrayFloat], + alpha: Union[int, float, List[float], NDArrayFloat], min_partition_size: int = 10, self_balancing: bool = True, shuffle: bool = True, @@ -131,7 +131,7 @@ def load_partition(self, node_id: int) -> datasets.Dataset: return self.dataset.select(self._node_id_to_indices[node_id]) def _initialize_alpha( - self, alpha: Union[float, List[float], NDArrayFloat] + self, alpha: Union[int, float, List[float], NDArrayFloat] ) -> NDArrayFloat: """Convert alpha to the used format in the code a NDArrayFloat. @@ -141,7 +141,7 @@ def _initialize_alpha( Parameters ---------- - alpha : Union[float, List[float], NDArrayFloat] + alpha : Union[int, float, List[float], NDArrayFloat] Concentration parameter to the Dirichlet distribution Returns @@ -149,7 +149,9 @@ def _initialize_alpha( alpha : NDArrayFloat Concentration parameter in a format ready to used in computation. 
""" - if isinstance(alpha, float): + if isinstance(alpha, int): + alpha = np.array([float(alpha)], dtype=float).repeat(self._num_partitions) + elif isinstance(alpha, float): alpha = np.array([alpha], dtype=float).repeat(self._num_partitions) elif isinstance(alpha, List): if len(alpha) != self._num_partitions: From a7722dd68b42d18cd5d59e96315c1438f9690a18 Mon Sep 17 00:00:00 2001 From: Adam Narozniak Date: Thu, 25 Jan 2024 09:57:24 +0100 Subject: [PATCH 07/13] Remove unused _num_unique_classes --- datasets/flwr_datasets/partitioner/dirichlet_partitioner.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py b/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py index 75f5f03e7b2..dfd8533b479 100644 --- a/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py +++ b/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py @@ -104,7 +104,6 @@ def __init__( # pylint: disable=R0913 # Utility attributes # The attributes below are determined during the first call to load_partition - self._num_unique_classes: Optional[int] = None self._avg_num_of_samples_per_node: Optional[float] = None self._unique_classes: Optional[Union[List[int], List[str]]] = None self._node_id_to_indices: Dict[int, List[int]] = {} @@ -191,7 +190,6 @@ def _determine_node_id_to_indices_if_needed(self) -> None: # pylint: disable=R0 # Generate information needed for Dirichlet partitioning self._unique_classes = self.dataset.unique(self._partition_by) assert self._unique_classes is not None - self._num_unique_classes = len(self._unique_classes) # This is needed only if self._self_balancing is True (the default option) self._avg_num_of_samples_per_node = self.dataset.num_rows / self._num_partitions From d4fc8d88498d3e848cede03bdb68cdb4591872d6 Mon Sep 17 00:00:00 2001 From: Adam Narozniak Date: Thu, 25 Jan 2024 10:14:00 +0100 Subject: [PATCH 08/13] Add warning on the repeated sampling process --- 
.../flwr_datasets/partitioner/dirichlet_partitioner.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py b/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py index dfd8533b479..e324cc99f04 100644 --- a/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py +++ b/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py @@ -21,7 +21,7 @@ import datasets from flwr_datasets.common.typing import NDArrayFloat from flwr_datasets.partitioner.partitioner import Partitioner - +import warnings class DirichletPartitioner(Partitioner): # pylint: disable=R0902 """Partitioner based on Dirichlet distribution. @@ -198,6 +198,7 @@ def _determine_node_id_to_indices_if_needed(self) -> None: # pylint: disable=R0 # Repeat the sampling procedure based on the Dirichlet distribution until the # min_partition_size is reached. + sampling_try = 0 while True: # Prepare data structure to store indices assigned to node ids node_id_to_indices: Dict[int, List[int]] = {} @@ -262,6 +263,13 @@ def _determine_node_id_to_indices_if_needed(self) -> None: # pylint: disable=R0 ) if min_sample_size_on_client >= self._min_partition_size: break + warnings.warn(f"The specified min_partition_size of the create the " + f"partitions was not satisfied as the direct result of the " + f"{sampling_try} st/nd/rd/th sampling from the Dirichlet " + f"distribution. The probability sampling from the Dirichlet " + f"distribution will be repeated. Note: It is not a desired " + f"behavior. 
It is recommended to adjust the alpha or " + f"min_partition_size instead.") # Shuffle the indices not to have the datasets with targets in sequences like # [00000, 11111, ...]) if the shuffle is True From d47cd9f767e9bf588fbeb03b94c7520100ee1d11 Mon Sep 17 00:00:00 2001 From: Adam Narozniak Date: Thu, 25 Jan 2024 10:16:06 +0100 Subject: [PATCH 09/13] Fix tries counter --- datasets/flwr_datasets/partitioner/dirichlet_partitioner.py | 1 + 1 file changed, 1 insertion(+) diff --git a/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py b/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py index e324cc99f04..fe84d225c7e 100644 --- a/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py +++ b/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py @@ -270,6 +270,7 @@ def _determine_node_id_to_indices_if_needed(self) -> None: # pylint: disable=R0 f"distribution will be repeated. Note: It is not a desired " f"behavior. It is recommended to adjust the alpha or " f"min_partition_size instead.") + sampling_try += 1 # Shuffle the indices not to have the datasets with targets in sequences like # [00000, 11111, ...]) if the shuffle is True From 506ed27aba84ac94b851c5131514a1224b46da1b Mon Sep 17 00:00:00 2001 From: Adam Narozniak Date: Thu, 25 Jan 2024 10:21:34 +0100 Subject: [PATCH 10/13] Fix formatting --- .../partitioner/dirichlet_partitioner.py | 24 +++++++++++-------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py b/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py index fe84d225c7e..9474e061539 100644 --- a/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py +++ b/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py @@ -13,7 +13,7 @@ # limitations under the License. 
# ============================================================================== """Dirichlet partitioner class that works with Hugging Face Datasets.""" -# pylint: disable=R0912 +import warnings from typing import Dict, List, Optional, Union import numpy as np @@ -21,9 +21,10 @@ import datasets from flwr_datasets.common.typing import NDArrayFloat from flwr_datasets.partitioner.partitioner import Partitioner -import warnings -class DirichletPartitioner(Partitioner): # pylint: disable=R0902 + +# pylint: disable=R0902, R0912 +class DirichletPartitioner(Partitioner): """Partitioner based on Dirichlet distribution. Implementation based on Bayesian Nonparametric Federated Learning of Neural Networks @@ -263,13 +264,16 @@ def _determine_node_id_to_indices_if_needed(self) -> None: # pylint: disable=R0 ) if min_sample_size_on_client >= self._min_partition_size: break - warnings.warn(f"The specified min_partition_size of the create the " - f"partitions was not satisfied as the direct result of the " - f"{sampling_try} st/nd/rd/th sampling from the Dirichlet " - f"distribution. The probability sampling from the Dirichlet " - f"distribution will be repeated. Note: It is not a desired " - f"behavior. It is recommended to adjust the alpha or " - f"min_partition_size instead.") + warnings.warn( + f"The specified min_partition_size of the create the " + f"partitions was not satisfied as the direct result of the " + f"{sampling_try} st/nd/rd/th sampling from the Dirichlet " + f"distribution. The probability sampling from the Dirichlet " + f"distribution will be repeated. Note: It is not a desired " + f"behavior. 
It is recommended to adjust the alpha or " + f"min_partition_size instead.", + stacklevel=1, + ) sampling_try += 1 # Shuffle the indices not to have the datasets with targets in sequences like From f3c284d5172efeaea43f9565e1b68117f2b45d8b Mon Sep 17 00:00:00 2001 From: Adam Narozniak <51029327+adam-narozniak@users.noreply.github.com> Date: Fri, 16 Feb 2024 11:05:32 +0100 Subject: [PATCH 11/13] Update datasets/flwr_datasets/partitioner/dirichlet_partitioner.py Co-authored-by: Javier --- .../partitioner/dirichlet_partitioner.py | 26 +++++++++++++++---- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py b/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py index 9474e061539..78b2ee6c710 100644 --- a/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py +++ b/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py @@ -264,14 +264,30 @@ def _determine_node_id_to_indices_if_needed(self) -> None: # pylint: disable=R0 ) if min_sample_size_on_client >= self._min_partition_size: break + sample_sizes = [len(indices) for indices in node_id_to_indices.values()] + alpha_not_met = [ + self._alpha[i] + for i, ss in enumerate(sample_sizes) + if ss == min(sample_sizes) + ] + mssg_list_alphas = ( + ( + " Generating partitions by sampling from a list of very" + "dispair alpha values can be hard to achieve. " + f"Try reducing the range between maximum ({max(self._alpha)}) and " + f"minimum alpha ({min(self._alpha)}) values." + ) + if len(self._alpha.flatten().tolist()) > 0 + else "" + ) warnings.warn( - f"The specified min_partition_size of the create the " - f"partitions was not satisfied as the direct result of the " - f"{sampling_try} st/nd/rd/th sampling from the Dirichlet " + f"The specified min_partition_size ({self._min_partition_size}) was " + f"not satisfied for alpha ({alpha_not_met}) after " + f"{sampling_try} attempts at sampling from the Dirichlet " f"distribution. 
The probability sampling from the Dirichlet " - f"distribution will be repeated. Note: It is not a desired " + f"distribution will be repeated. Note: This is not a desired " f"behavior. It is recommended to adjust the alpha or " - f"min_partition_size instead.", + f"min_partition_size instead. {mssg_list_alphas}", stacklevel=1, ) sampling_try += 1 From ae4031d9a877479cab403c6f2b697ce1caa25e53 Mon Sep 17 00:00:00 2001 From: Adam Narozniak Date: Fri, 16 Feb 2024 11:09:36 +0100 Subject: [PATCH 12/13] Raise error when 10 sampling tries is reached --- .../partitioner/dirichlet_partitioner.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py b/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py index 78b2ee6c710..aad90697f7b 100644 --- a/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py +++ b/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py @@ -13,6 +13,8 @@ # limitations under the License. # ============================================================================== """Dirichlet partitioner class that works with Hugging Face Datasets.""" + + import warnings from typing import Dict, List, Optional, Union @@ -272,10 +274,10 @@ def _determine_node_id_to_indices_if_needed(self) -> None: # pylint: disable=R0 ] mssg_list_alphas = ( ( - " Generating partitions by sampling from a list of very" - "dispair alpha values can be hard to achieve. " - f"Try reducing the range between maximum ({max(self._alpha)}) and " - f"minimum alpha ({min(self._alpha)}) values." + "Generating partitions by sampling from a list of very wide range " + "of alpha values can be hard to achieve. Try reducing the range " + f"between maximum ({max(self._alpha)}) and minimum alpha " + f"({min(self._alpha)}) values or increasing all the values." 
) if len(self._alpha.flatten().tolist()) > 0 else "" @@ -290,6 +292,11 @@ def _determine_node_id_to_indices_if_needed(self) -> None: # pylint: disable=R0 f"min_partition_size instead. {mssg_list_alphas}", stacklevel=1, ) + if sampling_try == 10: + raise ValueError( + "The max number of attempts (10) was reached. " + "Please update the values of alpha and try again." + ) sampling_try += 1 # Shuffle the indices not to have the datasets with targets in sequences like From 5a3f45ce5981debe389a1e623e7d7d91fe9722ef Mon Sep 17 00:00:00 2001 From: Adam Narozniak <51029327+adam-narozniak@users.noreply.github.com> Date: Tue, 20 Feb 2024 11:39:36 +0100 Subject: [PATCH 13/13] Apply suggestions from code review Co-authored-by: Javier --- .../flwr_datasets/partitioner/dirichlet_partitioner.py | 10 +++++----- .../partitioner/dirichlet_partitioner_test.py | 4 +++- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py b/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py index aad90697f7b..5f1df71991b 100644 --- a/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py +++ b/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py @@ -158,7 +158,7 @@ def _initialize_alpha( elif isinstance(alpha, List): if len(alpha) != self._num_partitions: raise ValueError( - "The alpha parameter needs to be of length of equal to the " + "If passing alpha as a List, it needs to be of length of equal to " "num_partitions." ) alpha = np.asarray(alpha) @@ -166,15 +166,15 @@ def _initialize_alpha( # pylint: disable=R1720 if alpha.ndim == 1 and alpha.shape[0] != self._num_partitions: raise ValueError( - "The alpha parameter needs to be of length of equal to" - "the num_partitions." + "If passing alpha as an NDArray, its length needs to be of length " + "equal to num_partitions." 
) elif alpha.ndim == 2: alpha = alpha.flatten() if alpha.shape[0] != self._num_partitions: raise ValueError( - "The alpha parameter needs to be of length of equal to " - "the num_partitions." + "If passing alpha as an NDArray, its size needs to be of length" + " equal to num_partitions." ) else: raise ValueError("The given alpha format is not supported.") diff --git a/datasets/flwr_datasets/partitioner/dirichlet_partitioner_test.py b/datasets/flwr_datasets/partitioner/dirichlet_partitioner_test.py index d8cfb3cb854..c123f84effb 100644 --- a/datasets/flwr_datasets/partitioner/dirichlet_partitioner_test.py +++ b/datasets/flwr_datasets/partitioner/dirichlet_partitioner_test.py @@ -1,4 +1,4 @@ -# Copyright 2023 Flower Labs GmbH. All Rights Reserved. +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,6 +13,8 @@ # limitations under the License. # ============================================================================== """Test DirichletPartitioner.""" + + # pylint: disable=W0212 import unittest from typing import Tuple, Union