From 9a2c13b0af21d1a0eb24f164b5a0dc12c6b4ff33 Mon Sep 17 00:00:00 2001
From: Adam Narozniak
Date: Mon, 15 Jan 2024 15:59:21 +0100
Subject: [PATCH] Add dirichlet partitioner

---
 .../partitioner/dirichlet_partitioner.py      | 264 ++++++++++++++++++
 .../partitioner/dirichlet_partitioner_test.py | 168 +++++++++++
 2 files changed, 432 insertions(+)
 create mode 100644 datasets/flwr_datasets/partitioner/dirichlet_partitioner.py
 create mode 100644 datasets/flwr_datasets/partitioner/dirichlet_partitioner_test.py

diff --git a/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py b/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py
new file mode 100644
index 00000000000..de53b0928af
--- /dev/null
+++ b/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py
@@ -0,0 +1,264 @@
+# Copyright 2023 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Dirichlet partitioner class that works with Hugging Face Datasets."""
+# pylint: disable=R0912
+from typing import Dict, List, Optional, Union
+
+import numpy as np
+
+import datasets
+from flwr_datasets.common.typing import NDArrayFloat
+from flwr_datasets.partitioner.partitioner import Partitioner
+
+
+class DirichletPartitioner(Partitioner):  # pylint: disable=R0902
+    """Partitioner based on the Dirichlet distribution.
+
+    Implementation based on Bayesian Nonparametric Federated Learning of
+    Neural Networks: https://arxiv.org/abs/1905.12022
+
+    The balancing (not mentioned in the paper but implemented in its code) is
+    controlled by the `self_balancing` parameter.
+
+    Parameters
+    ----------
+    num_partitions : int
+        The total number of partitions that the data will be divided into.
+    alpha : Union[float, List[float], NDArrayFloat]
+        Concentration parameter of the Dirichlet distribution.
+    partition_by : str
+        Column name of the labels (targets) based on which Dirichlet sampling works.
+    min_partition_size : int
+        The minimum number of samples that each partition will have (the sampling
+        process is repeated if any partition is too small).
+    self_balancing : bool
+        Whether to stop assigning further samples to a partition once its number
+        of samples exceeds the average number of samples per partition (True in
+        the original paper's code, although not mentioned in the paper itself).
+    shuffle : bool
+        Whether to randomize the order of samples. Shuffling is applied after the
+        samples are assigned to nodes.
+    seed : int
+        Seed used for dataset shuffling. It has no effect if `shuffle` is False.
+ """ + + def __init__( # pylint: disable=R0913 + self, + num_partitions: int, + alpha: Union[float, List[float], NDArrayFloat], + partition_by: str, + min_partition_size: Optional[int] = None, + self_balancing: bool = True, + shuffle: bool = True, + seed: Optional[int] = 42, + ) -> None: + super().__init__() + # Attributes based on the constructor + self._num_partitions = num_partitions + self._check_num_partitions_greater_than_zero() + self._alpha: NDArrayFloat = self._initialize_alpha(alpha) + self._partition_by = partition_by + if min_partition_size is None: + # Note that zero might make problems with the training + min_partition_size = 0 + self._min_partition_size: int = min_partition_size + self._self_balancing = self_balancing + self._shuffle = shuffle + self._seed = seed + self._rng = np.random.default_rng(seed=self._seed) # NumPy random generator + + # Utility attributes + # The attributes below are determined during the first call to load_partition + self._num_unique_classes: Optional[int] = None + self._avg_num_of_samples_per_node: Optional[float] = None + self._unique_classes: Optional[Union[List[int], List[str]]] = None + self._node_id_to_indices: Dict[int, List[int]] = {} + self._node_id_to_indices_determined = False + + def load_partition(self, node_id: int) -> datasets.Dataset: + """Load a partition based on the partition index. + + Parameters + ---------- + node_id : int + the index that corresponds to the requested partition + + Returns + ------- + dataset_partition : Dataset + single partition of a dataset + """ + # The partitioning is done lazily - only when the first partition is + # requested. Only the first call creates the indices assignments for all the + # partition indices. + self._check_num_partitions_correctness_if_needed() + self._determine_node_id_to_indices_if_needed() + return self.dataset.select(self._node_id_to_indices[node_id]) + + def _initialize_alpha( + self, alpha: Union[float, List[float], NDArrayFloat] + ) -> NDArrayFloat: + """Convert alpha to the used format in the code a NDArrayFloat. + + The alpha can be provided in constructor can be in different format for user + convenience. The format into which it's transformed here is used throughout the + code for computation. + + Parameters + ---------- + alpha : Union[float, List[float], NDArrayFloat] + Concentration parameter to the Dirichlet distribution + + Returns + ------- + alpha : NDArrayFloat + Concentration parameter in a format ready to used in computation. + """ + if isinstance(alpha, float): + alpha = np.array([alpha], dtype=float).repeat(self._num_partitions) + elif isinstance(alpha, List): + if len(alpha) != self._num_partitions: + raise ValueError( + "The alpha parameter needs to be of length of equal to the " + "num_partitions." + ) + alpha = np.asarray(alpha) + elif isinstance(alpha, np.ndarray): + # pylint: disable=R1720 + if alpha.ndim == 1 and alpha.shape[0] != self._num_partitions: + raise ValueError( + "The alpha parameter needs to be of length of equal to" + "the num_partitions." + ) + elif alpha.ndim == 2: + alpha = alpha.flatten() + if alpha.shape[0] != self._num_partitions: + raise ValueError( + "The alpha parameter needs to be of length of equal to " + "the num_partitions." + ) + else: + raise ValueError("The given alpha format is not supported.") + if not (alpha > 0).all(): + raise ValueError( + f"Alpha values should be strictly greater than zero. 
" + f"Instead it'd be converted to {alpha}" + ) + return alpha + + def _determine_node_id_to_indices_if_needed(self) -> None: # pylint: disable=R0914 + """Create an assignment of indices to the partition indices.""" + if self._node_id_to_indices_determined: + return + + # Generate information needed for Dirichlet partitioning + self._unique_classes = self.dataset.unique(self._partition_by) + assert self._unique_classes is not None + self._num_unique_classes = len(self._unique_classes) + # This is needed only if self._self_balancing is True (the default option) + self._avg_num_of_samples_per_node = self.dataset.num_rows / self._num_partitions + + # Change targets list data type to numpy + targets = np.array(self.dataset[self._partition_by]) + + # Repeat the sampling procedure based on the Dirichlet distribution until the + # min_partition_size is reached. + while True: + # Prepare data structure to store indices assigned to node ids + node_id_to_indices: Dict[int, List[int]] = {} + for nid in range(self._num_partitions): + node_id_to_indices[nid] = [] + + # Iterated over all unique labels (they are not necessarily of type int) + for k in self._unique_classes: + # Access all the indices associated with class k + indices_representing_class_k = np.nonzero(targets == k)[0] + # Determine division (the fractions) of the data representing class k + # among the partitions + class_k_division_proportions = self._rng.dirichlet(self._alpha) + nid_to_proportion_of_k_samples = {} + for nid in range(self._num_partitions): + nid_to_proportion_of_k_samples[nid] = class_k_division_proportions[ + nid + ] + # Balancing (not mentioned in the paper but implemented) + # Do not assign additional samples to the node if it already has more + # than the average numbers of samples per partition. Note that it might + # especially affect classes that are later in the order. This is the + # reason for more sparse division that the alpha might suggest. 
+                if self._self_balancing:
+                    assert self._avg_num_of_samples_per_node is not None
+                    for nid in nid_to_proportion_of_k_samples.copy():
+                        if (
+                            len(node_id_to_indices[nid])
+                            > self._avg_num_of_samples_per_node
+                        ):
+                            nid_to_proportion_of_k_samples[nid] = 0
+
+                # Normalize the proportions such that they sum up to 1
+                sum_proportions = sum(nid_to_proportion_of_k_samples.values())
+                for nid, prop in nid_to_proportion_of_k_samples.copy().items():
+                    nid_to_proportion_of_k_samples[nid] = prop / sum_proportions
+
+                # Determine the split indices
+                cumsum_division_fractions = np.cumsum(
+                    list(nid_to_proportion_of_k_samples.values())
+                )
+                cumsum_division_numbers = cumsum_division_fractions * len(
+                    indices_representing_class_k
+                )
+                # [:-1] is used because np.split requires the division indices, and
+                # the last element represents the sum = total number of samples
+                indices_on_which_split = cumsum_division_numbers.astype(int)[:-1]
+
+                split_indices = np.split(
+                    indices_representing_class_k, indices_on_which_split
+                )
+
+                # Append new indices (coming from class k) to the existing indices
+                for nid, indices in node_id_to_indices.items():
+                    indices.extend(split_indices[nid].tolist())
+
+            # Determine if the indices assignment meets the min_partition_size.
+            # If it does not meet the requirement, repeat the Dirichlet sampling
+            # process. Otherwise, break the while loop.
+            min_sample_size_on_client = min(
+                len(indices) for indices in node_id_to_indices.values()
+            )
+            if min_sample_size_on_client >= self._min_partition_size:
+                break
+
+        # Shuffle the indices so that the partitions do not end up with targets in
+        # sequences like [00000, 11111, ...] if shuffle is True
+        if self._shuffle:
+            for indices in node_id_to_indices.values():
+                # In-place shuffling
+                self._rng.shuffle(indices)
+        self._node_id_to_indices = node_id_to_indices
+        self._node_id_to_indices_determined = True
+
+    def _check_num_partitions_correctness_if_needed(self) -> None:
+        """Test num_partitions when the dataset is given (in load_partition)."""
+        if not self._node_id_to_indices_determined:
+            if self._num_partitions > self.dataset.num_rows:
+                raise ValueError(
+                    "The number of partitions needs to be smaller than the number "
+                    "of samples in the dataset."
+                )
+
+    def _check_num_partitions_greater_than_zero(self) -> None:
+        """Check that num_partitions is greater than zero."""
+        if not self._num_partitions > 0:
+            raise ValueError("The number of partitions needs to be greater than zero.")
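For intuition, the core of the assignment loop above is a single Dirichlet draw per class that decides how that class's indices are cut across partitions. The following standalone sketch (illustrative only; variable names are hypothetical, and the self-balancing and min_partition_size retry logic from the patch are omitted) reproduces just that step:

    import numpy as np

    rng = np.random.default_rng(42)
    num_partitions = 3
    alpha = np.full(num_partitions, 0.5)  # symmetric concentration parameter

    # Hypothetical targets: 100 samples with 3 classes, as in the tests below
    targets = np.array([i % 3 for i in range(100)])

    node_id_to_indices = {nid: [] for nid in range(num_partitions)}
    for k in np.unique(targets):
        class_k_indices = np.nonzero(targets == k)[0]
        # One Dirichlet draw decides how class k is shared among partitions
        proportions = rng.dirichlet(alpha)
        # Cumulative fractions become cut points into the class-k index array
        cut_points = (np.cumsum(proportions) * len(class_k_indices)).astype(int)[:-1]
        for nid, chunk in enumerate(np.split(class_k_indices, cut_points)):
            node_id_to_indices[nid].extend(chunk.tolist())

    print({nid: len(idx) for nid, idx in node_id_to_indices.items()})

Smaller alpha values concentrate each class in fewer partitions, while larger values approach a uniform split; this is the knob the `alpha` parameter exposes.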
diff --git a/datasets/flwr_datasets/partitioner/dirichlet_partitioner_test.py b/datasets/flwr_datasets/partitioner/dirichlet_partitioner_test.py
new file mode 100644
index 00000000000..d8cfb3cb854
--- /dev/null
+++ b/datasets/flwr_datasets/partitioner/dirichlet_partitioner_test.py
@@ -0,0 +1,168 @@
+# Copyright 2023 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Test DirichletPartitioner."""
+# pylint: disable=W0212
+import unittest
+from typing import Tuple, Union
+
+import numpy as np
+from numpy.typing import NDArray
+from parameterized import parameterized
+
+from datasets import Dataset
+from flwr_datasets.partitioner.dirichlet_partitioner import DirichletPartitioner
+
+
+def _dummy_setup(
+    num_partitions: int,
+    alpha: Union[float, NDArray[np.float_]],
+    num_rows: int,
+    partition_by: str,
+    self_balancing: bool = True,
+) -> Tuple[Dataset, DirichletPartitioner]:
+    """Create a dummy dataset and partitioner for testing."""
+    data = {
+        partition_by: [i % 3 for i in range(num_rows)],
+        "features": list(range(num_rows)),
+    }
+    dataset = Dataset.from_dict(data)
+    partitioner = DirichletPartitioner(
+        num_partitions=num_partitions,
+        alpha=alpha,
+        partition_by=partition_by,
+        self_balancing=self_balancing,
+    )
+    partitioner.dataset = dataset
+    return dataset, partitioner
+
+
+class TestDirichletPartitionerSuccess(unittest.TestCase):
+    """Test DirichletPartitioner used with no exceptions."""
+
+    @parameterized.expand(  # type: ignore
+        [
+            # num_partitions, alpha, num_rows, partition_by
+            (3, 0.5, 100, "labels"),
+            (5, 1.0, 150, "labels"),
+        ]
+    )
+    def test_valid_initialization(
+        self, num_partitions: int, alpha: float, num_rows: int, partition_by: str
+    ) -> None:
+        """Test if alpha is correctly scaled based on the given num_partitions."""
+        _, partitioner = _dummy_setup(num_partitions, alpha, num_rows, partition_by)
+        self.assertEqual(
+            (
+                partitioner._num_partitions,
+                len(partitioner._alpha),
+                partitioner._partition_by,
+            ),
+            (num_partitions, num_partitions, partition_by),
+        )
+
+    def test_min_partition_size_requirement(self) -> None:
+        """Test if partitions are created with the required min partition size."""
+        _, partitioner = _dummy_setup(3, 0.5, 100, "labels")
+        partition_list = [partitioner.load_partition(node_id) for node_id in [0, 1, 2]]
+        self.assertTrue(
+            all(len(p) > partitioner._min_partition_size for p in partition_list)
+        )
+
+    def test_alpha_in_ndarray_initialization(self) -> None:
+        """Test that alpha does not change when given in NDArrayFloat format."""
+        _, partitioner = _dummy_setup(3, np.array([1.0, 1.0, 1.0]), 100, "labels")
+        self.assertTrue(np.all(partitioner._alpha == np.array([1.0, 1.0, 1.0])))
+
+    def test__determine_node_id_to_indices(self) -> None:
+        """Test that the node assignment flag and mapping are set after the call."""
+        num_partitions, alpha, num_rows, partition_by = 3, 0.5, 100, "labels"
+        _, partitioner = _dummy_setup(num_partitions, alpha, num_rows, partition_by)
+        partitioner._determine_node_id_to_indices_if_needed()
+        self.assertTrue(
+            partitioner._node_id_to_indices_determined
+            and len(partitioner._node_id_to_indices) == num_partitions
+        )
+
+
+class TestDirichletPartitionerFailure(unittest.TestCase):
+    """Test DirichletPartitioner failures (exceptions) caused by incorrect usage."""
+
+    @parameterized.expand([(-2,), (-1,), (3,), (4,), (100,)])  # type: ignore
+    def test_load_invalid_partition_index(self, partition_id):
+        """Test that load_partition raises KeyError for out-of-range indices."""
+        _, partitioner = _dummy_setup(3, 0.5, 100, "labels")
+        with self.assertRaises(KeyError):
+            partitioner.load_partition(partition_id)
+
+    @parameterized.expand(  # type: ignore
+        [
+            # alpha, num_partitions
+            (-0.5, 1),
+            (-0.5, 2),
+            (-0.5, 3),
+            (-0.5, 10),
+            ([0.5, 0.5, -0.5], 3),
+            ([-0.5, 0.5, -0.5], 3),
+            ([-0.5, 0.5, 0.5], 3),
+            ([-0.5, -0.5, -0.5], 3),
+            ([0.5, 0.5, -0.5, -0.5, 0.5], 5),
+            (np.array([0.5, 0.5, -0.5]), 3),
+            (np.array([-0.5, 0.5, -0.5]), 3),
+            (np.array([-0.5, 0.5, 0.5]), 3),
+            (np.array([-0.5, -0.5, -0.5]), 3),
+            (np.array([0.5, 0.5, -0.5, -0.5, 0.5]), 5),
+        ]
+    )
+    def test_negative_values_in_alpha(self, alpha, num_partitions):
+        """Test that negative alpha values raise a ValueError."""
+        num_rows, partition_by = 100, "labels"
+        with self.assertRaises(ValueError):
+            _, _ = _dummy_setup(num_partitions, alpha, num_rows, partition_by)
+
+    @parameterized.expand(  # type: ignore
+        [
+            # alpha, num_partitions
+            # len(alpha) greater than num_partitions
+            ([0.5, 0.5], 1),
+            ([0.5, 0.5, 0.5], 2),
+            (np.array([0.5, 0.5]), 1),
+            (np.array([0.5, 0.5, 0.5]), 2),
+            (np.array([0.5, 0.5, 0.5, 0.5]), 3),
+        ]
+    )
+    def test_incorrect_alpha_shape(self, alpha, num_partitions):
+        """Test that an alpha length not matching num_partitions raises ValueError."""
+        with self.assertRaises(ValueError):
+            DirichletPartitioner(
+                num_partitions=num_partitions, alpha=alpha, partition_by="labels"
+            )
+
+    @parameterized.expand(  # type: ignore
+        [(0,), (-1,), (11,), (100,)]
+    )  # num_partitions
+    def test_invalid_num_partitions(self, num_partitions):
+        """Test that invalid num_partitions values raise a ValueError."""
+        with self.assertRaises(ValueError):
+            _, partitioner = _dummy_setup(
+                num_partitions=num_partitions,
+                alpha=1.0,
+                num_rows=10,
+                partition_by="labels",
+            )
+            partitioner.load_partition(0)
+
+
+if __name__ == "__main__":
+    unittest.main()
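For reference, the partitioner added by this patch can be exercised the same way `_dummy_setup` wires it up in the tests above. A minimal sketch, assuming a synthetic Hugging Face Dataset and direct assignment of `partitioner.dataset` as done in the tests (a convenience that may differ from how flwr_datasets integrates partitioners in normal use):

    import numpy as np
    from datasets import Dataset
    from flwr_datasets.partitioner.dirichlet_partitioner import DirichletPartitioner

    # Synthetic dataset with a "labels" column, mirroring _dummy_setup above
    data = {
        "labels": [i % 3 for i in range(100)],
        "features": list(range(100)),
    }
    dataset = Dataset.from_dict(data)

    partitioner = DirichletPartitioner(
        num_partitions=3, alpha=0.5, partition_by="labels"
    )
    partitioner.dataset = dataset  # assign the dataset before loading partitions

    # Partitioning happens lazily on the first load_partition call
    for node_id in range(3):
        partition = partitioner.load_partition(node_id)
        print(node_id, len(partition), np.bincount(partition["labels"]))

Each printed bincount shows how unevenly the three classes land in a given partition; re-running with a larger alpha should produce visibly more balanced per-class counts.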