Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add datasets tests for FDS #2964

Merged
merged 23 commits into from
Apr 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit — hold Shift and click to select a range
be4e63a
Add mocking of categories, strings, sentences
adam-narozniak Jan 25, 2024
bf0a986
Add image mocking, DatasetDict creation
adam-narozniak Jan 26, 2024
bbd5844
Clarify the second class test set
adam-narozniak Jan 29, 2024
8863747
Add load_mocked_dataset function
adam-narozniak Jan 29, 2024
f3c655c
Integrate mock testing with real datasets testing
adam-narozniak Jan 29, 2024
b8060e2
Update the partitions size check
adam-narozniak Jan 30, 2024
7e4360c
Add audio mocking speech_commands
adam-narozniak Feb 15, 2024
5765ea2
Make functions private
adam-narozniak Feb 15, 2024
834708d
Fix formatting errors
adam-narozniak Feb 16, 2024
9249c37
Remove __main__
adam-narozniak Feb 16, 2024
5c8c684
Fix missing reference to mock speech commands
adam-narozniak Feb 16, 2024
9d4084d
Update datasets equality function
adam-narozniak Feb 16, 2024
b4203cd
Update mocking unique labels in cifar100
adam-narozniak Mar 6, 2024
6b134d6
Add seed to _generate_artificial_strings method to fix tests
adam-narozniak Mar 8, 2024
b08f66b
Improve docs
adam-narozniak Mar 8, 2024
93ec136
Apply suggestions from code review
adam-narozniak Mar 10, 2024
16a0bff
Merge branch 'main' into fds-add-tests
danieljanes Mar 11, 2024
e789d48
Add _test suffix to mock_utils.py
adam-narozniak Mar 12, 2024
e3a9617
Update the reference to mock_utils
adam-narozniak Mar 13, 2024
64c8458
Merge remote-tracking branch 'origin/main' into fds-add-tests
adam-narozniak Apr 19, 2024
be79f8f
Fix types
adam-narozniak Apr 22, 2024
f9b6881
Fix test for the case there is no test set in the dataset
adam-narozniak Apr 22, 2024
b4664be
Merge branch 'main' into fds-add-tests
danieljanes Apr 25, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 98 additions & 16 deletions datasets/flwr_datasets/federated_dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,29 +20,70 @@
from typing import Dict, Union
from unittest.mock import Mock, patch

import numpy as np
import pytest
from parameterized import parameterized, parameterized_class

import datasets
from datasets import Dataset, DatasetDict, concatenate_datasets
from flwr_datasets.federated_dataset import FederatedDataset
from flwr_datasets.mock_utils_test import _load_mocked_dataset
from flwr_datasets.partitioner import IidPartitioner, Partitioner

# Datasets that are NOT downloaded in tests: for these names, `setUp` patches
# `datasets.load_dataset` to return an artificial, locally generated dataset.
mocked_datasets = ["cifar100", "svhn", "sentiment140", "speech_commands"]


@parameterized_class(
("dataset_name", "test_split", "subset"),
[
{"dataset_name": "mnist", "test_split": "test"},
{"dataset_name": "cifar10", "test_split": "test"},
{"dataset_name": "fashion_mnist", "test_split": "test"},
{"dataset_name": "sasha/dog-food", "test_split": "test"},
{"dataset_name": "zh-plus/tiny-imagenet", "test_split": "valid"},
]
# Downloaded
# #Image datasets
("mnist", "test", ""),
("cifar10", "test", ""),
("fashion_mnist", "test", ""),
("sasha/dog-food", "test", ""),
("zh-plus/tiny-imagenet", "valid", ""),
# Text
("scikit-learn/adult-census-income", None, ""),
# Mocked
# #Image
("cifar100", "test", ""),
# Note: there's also the extra split and full_numbers subset
("svhn", "test", "cropped_digits"),
# Text
("sentiment140", "test", ""), # aka twitter
# Audio
("speech_commands", "test", "v0.01"),
],
)
class RealDatasetsFederatedDatasetsTrainTest(unittest.TestCase):
"""Test Real Dataset (MNIST, CIFAR10) in FederatedDatasets."""
class BaseFederatedDatasetsTest(unittest.TestCase):
"""Test Real/Mocked Datasets used in FederatedDatasets.

The setUp method mocks the dataset download via datasets.load_dataset if it is in
the `mocked_datasets` list.
"""

dataset_name = ""
test_split = ""
subset = ""

def setUp(self) -> None:
"""Mock the dataset download prior to each method if needed.

If the `dataset_name` is in the `mocked_datasets` list, then the dataset
download is mocked.
"""
if self.dataset_name in mocked_datasets:
self.patcher = patch("datasets.load_dataset")
self.mock_load_dataset = self.patcher.start()
self.mock_load_dataset.return_value = _load_mocked_dataset(
self.dataset_name, [200, 100], ["train", self.test_split], self.subset
)

def tearDown(self) -> None:
"""Clean up after the dataset mocking."""
if self.dataset_name in mocked_datasets:
patch.stopall()

@parameterized.expand( # type: ignore
[
Expand All @@ -61,14 +102,25 @@ def test_load_partition_size(self, _: str, train_num_partitions: int) -> None:
dataset_fds = FederatedDataset(
dataset=self.dataset_name, partitioners={"train": train_num_partitions}
)
dataset_partition0 = dataset_fds.load_partition(0, "train")
# Compute the actual partition sizes
partition_sizes = []
for node_id in range(train_num_partitions):
partition_sizes.append(len(dataset_fds.load_partition(node_id, "train")))

# Create the expected sizes of partitions
dataset = datasets.load_dataset(self.dataset_name)
self.assertEqual(
len(dataset_partition0), len(dataset["train"]) // train_num_partitions
)
full_train_length = len(dataset["train"])
expected_sizes = []
default_partition_size = full_train_length // train_num_partitions
mod = full_train_length % train_num_partitions
for i in range(train_num_partitions):
expected_sizes.append(default_partition_size + (1 if i < mod else 0))
self.assertEqual(partition_sizes, expected_sizes)

def test_load_split(self) -> None:
"""Test if the load_split works with the correct split name."""
if self.test_split is None:
return
dataset_fds = FederatedDataset(
dataset=self.dataset_name, partitioners={"train": 100}
)
Expand All @@ -78,6 +130,8 @@ def test_load_split(self) -> None:

def test_multiple_partitioners(self) -> None:
"""Test if the dataset works when multiple partitioners are specified."""
if self.test_split is None:
return
num_train_partitions = 100
num_test_partitions = 100
dataset_fds = FederatedDataset(
Expand All @@ -97,7 +151,7 @@ def test_multiple_partitioners(self) -> None:

def test_no_need_for_split_keyword_if_one_partitioner(self) -> None:
"""Test if partitions got with and without split args are the same."""
fds = FederatedDataset(dataset="mnist", partitioners={"train": 10})
fds = FederatedDataset(dataset=self.dataset_name, partitioners={"train": 10})
partition_loaded_with_no_split_arg = fds.load_partition(0)
partition_loaded_with_verbose_split_arg = fds.load_partition(0, "train")
self.assertTrue(
Expand All @@ -109,6 +163,8 @@ def test_no_need_for_split_keyword_if_one_partitioner(self) -> None:

def test_resplit_dataset_into_one(self) -> None:
"""Test resplit into a single dataset."""
if self.test_split is None:
return
dataset = datasets.load_dataset(self.dataset_name)
dataset_length = sum([len(ds) for ds in dataset.values()])
fds = FederatedDataset(
Expand All @@ -122,6 +178,8 @@ def test_resplit_dataset_into_one(self) -> None:
# pylint: disable=protected-access
def test_resplit_dataset_to_change_names(self) -> None:
"""Test resplitter to change the names of the partitions."""
if self.test_split is None:
return
fds = FederatedDataset(
dataset=self.dataset_name,
partitioners={"new_train": 100},
Expand All @@ -138,6 +196,8 @@ def test_resplit_dataset_to_change_names(self) -> None:

def test_resplit_dataset_by_callable(self) -> None:
"""Test resplitter to change the names of the partitions."""
if self.test_split is None:
return

def resplit(dataset: DatasetDict) -> DatasetDict:
return DatasetDict(
Expand All @@ -157,8 +217,13 @@ def resplit(dataset: DatasetDict) -> DatasetDict:
self.assertEqual(len(full), dataset_length)


class ArtificialDatasetTest(unittest.TestCase):
"""Test using small artificial dataset, mocked load_dataset."""
class ShufflingResplittingOnArtificialDatasetTest(unittest.TestCase):
"""Test shuffling and resplitting using small artificial dataset.

The purpose of this class is to ensure the order of samples remains as expected.

The load_dataset method is mocked and the artificial dataset is returned.
"""

# pylint: disable=no-self-use
def _dummy_setup(self, train_rows: int = 10, test_rows: int = 5) -> DatasetDict:
Expand Down Expand Up @@ -360,9 +425,26 @@ def datasets_are_equal(ds1: Dataset, ds2: Dataset) -> bool:

# Iterate over each row and check for equality
for row1, row2 in zip(ds1, ds2):
if row1 != row2:
# Ensure all keys are the same in both rows
if set(row1.keys()) != set(row2.keys()):
return False

# Compare values for each key
for key in row1:
if key == "audio":
# Special handling for 'audio' key
if not all(
[
np.array_equal(row1[key]["array"], row2[key]["array"]),
row1[key]["path"] == row2[key]["path"],
row1[key]["sampling_rate"] == row2[key]["sampling_rate"],
]
):
return False
elif row1[key] != row2[key]:
# Direct comparison for other keys
return False

return True


Expand Down
Loading