Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fds add num_partitions property to partitioners #3095

Merged
merged 5 commits into from
Mar 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions datasets/flwr_datasets/partitioner/dirichlet_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,13 @@ def load_partition(self, node_id: int) -> datasets.Dataset:
self._determine_node_id_to_indices_if_needed()
return self.dataset.select(self._node_id_to_indices[node_id])

@property
def num_partitions(self) -> int:
"""Total number of partitions."""
self._check_num_partitions_correctness_if_needed()
self._determine_node_id_to_indices_if_needed()
return self._num_partitions

def _initialize_alpha(
self, alpha: Union[int, float, List[float], NDArrayFloat]
) -> NDArrayFloat:
Expand Down
5 changes: 5 additions & 0 deletions datasets/flwr_datasets/partitioner/iid_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,8 @@ def load_partition(self, node_id: int) -> datasets.Dataset:
return self.dataset.shard(
num_shards=self._num_partitions, index=node_id, contiguous=True
)

@property
def num_partitions(self) -> int:
"""Total number of partitions."""
return self._num_partitions
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,17 @@ def load_partition(self, node_id: int) -> datasets.Dataset:
self._determine_node_id_to_indices_if_needed()
return self.dataset.select(self._node_id_to_indices[node_id])

@property
def num_partitions(self) -> int:
"""Total number of partitions."""
self._check_num_partitions_correctness_if_needed()
self._check_partition_sizes_correctness_if_needed()
self._check_the_sum_of_partition_sizes()
self._determine_num_unique_classes_if_needed()
self._alpha = self._initialize_alpha_if_needed(self._initial_alpha)
self._determine_node_id_to_indices_if_needed()
return self._num_partitions

def _initialize_alpha_if_needed(
self, alpha: Union[int, float, List[float], NDArrayFloat]
) -> NDArrayFloat:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,13 @@ def load_partition(self, node_id: int) -> datasets.Dataset:
lambda row: row[self._partition_by] == self._node_id_to_natural_id[node_id]
)

@property
def num_partitions(self) -> int:
"""Total number of partitions."""
if len(self._node_id_to_natural_id) == 0:
self._create_int_node_id_to_natural_id()
return len(self._node_id_to_natural_id)

@property
def node_id_to_natural_id(self) -> Dict[int, str]:
"""Node id to corresponding natural id present.
Expand Down
5 changes: 5 additions & 0 deletions datasets/flwr_datasets/partitioner/partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,8 @@ def is_dataset_assigned(self) -> bool:
True if a dataset is assigned, otherwise False.
"""
return self._dataset is not None

@property
@abstractmethod
def num_partitions(self) -> int:
"""Total number of partitions."""
9 changes: 9 additions & 0 deletions datasets/flwr_datasets/partitioner/shard_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,15 @@ def load_partition(self, node_id: int) -> datasets.Dataset:
self._determine_node_id_to_indices_if_needed()
return self.dataset.select(self._node_id_to_indices[node_id])

@property
def num_partitions(self) -> int:
"""Total number of partitions."""
self._check_num_partitions_correctness_if_needed()
self._check_possibility_of_partitions_creation()
self._sort_dataset_if_needed()
self._determine_node_id_to_indices_if_needed()
return self._num_partitions

def _determine_node_id_to_indices_if_needed(self) -> None: # pylint: disable=R0914
"""Assign sample indices to each node id.

Expand Down
6 changes: 6 additions & 0 deletions datasets/flwr_datasets/partitioner/size_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,12 @@ def load_partition(self, node_id: int) -> datasets.Dataset:
self._determine_node_id_to_indices_if_needed()
return self.dataset.select(self._node_id_to_indices[node_id])

@property
def num_partitions(self) -> int:
"""Total number of partitions."""
self._determine_node_id_to_indices_if_needed()
return self._num_partitions

@property
def node_id_to_size(self) -> Dict[int, int]:
"""Node id to the number of samples."""
Expand Down
Loading