chore(tests): Implements nlp.datasets tests and fixes wrong type annotations.
Showing 12 changed files with 569 additions and 7 deletions.
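
The test modules below exercise the `archai.nlp.datasets.hf` loading and processing utilities. As orientation, the basic flow they cover is load-then-encode; the following is a minimal sketch assembled only from calls that appear in the tests themselves, not from the library's documented API:

```python
from transformers import AutoTokenizer

from archai.nlp.datasets.hf.loaders import encode_dataset, load_dataset

# Load a 10-sample slice of WikiText-2 from the Hugging Face Hub.
dataset = load_dataset(
    dataset_name="wikitext",
    dataset_config_name="wikitext-2-raw-v1",
    n_samples=10,
)

# GPT-2 ships without a pad token, so the tests reuse EOS for padding.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

encoded = encode_dataset(dataset, tokenizer)
```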
@@ -0,0 +1,69 @@
```python
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from transformers import AutoTokenizer

from archai.nlp.datasets.hf.loaders import (
    DatasetDict,
    DownloadMode,
    IterableDatasetDict,
    _should_refresh_cache,
    encode_dataset,
    load_dataset,
)


def test_should_refresh_cache():
    # Test that the function returns FORCE_REDOWNLOAD when refresh is True
    assert _should_refresh_cache(True) == DownloadMode.FORCE_REDOWNLOAD

    # Test that the function returns REUSE_DATASET_IF_EXISTS when refresh is False
    assert _should_refresh_cache(False) == DownloadMode.REUSE_DATASET_IF_EXISTS


def test_load_dataset():
    dataset_name = "wikitext"
    dataset_config_name = "wikitext-2-raw-v1"

    # Assert loading dataset from Hugging Face Hub
    dataset = load_dataset(
        dataset_name=dataset_name,
        dataset_config_name=dataset_config_name,
        dataset_refresh_cache=True,
    )
    assert isinstance(dataset, (DatasetDict, IterableDatasetDict))

    # Assert that subsampling works
    n_samples = 10
    dataset = load_dataset(
        dataset_name=dataset_name,
        dataset_config_name=dataset_config_name,
        dataset_refresh_cache=True,
        n_samples=n_samples,
    )
    assert all(len(split) == n_samples for split in dataset.values())


def test_encode_dataset():
    dataset = load_dataset(
        dataset_name="wikitext", dataset_config_name="wikitext-2-raw-v1", dataset_refresh_cache=True, n_samples=10
    )
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token

    # Assert that dataset can be encoded
    encoded_dataset = encode_dataset(dataset, tokenizer)
    assert isinstance(encoded_dataset, (DatasetDict, IterableDatasetDict))

    # Assert that dataset can be encoded with a custom mapping function
    def custom_mapping_fn(example, tokenizer, mapping_column_name=None):
        example["text2"] = example["text"]
        return example

    encoded_dataset = encode_dataset(dataset, tokenizer, mapping_fn=custom_mapping_fn)
    assert isinstance(encoded_dataset, (DatasetDict, IterableDatasetDict))

    # Assert that dataset can be encoded with multiprocessing
    num_proc = 4
    encoded_dataset = encode_dataset(dataset, tokenizer, num_proc=num_proc)
    assert isinstance(encoded_dataset, (DatasetDict, IterableDatasetDict))
```
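
The `test_should_refresh_cache` assertions pin the helper down to a boolean-to-`DownloadMode` mapping (`DownloadMode` is presumably re-exported from Hugging Face's `datasets` package). Inferred from those two assertions alone, the helper plausibly reduces to the sketch below, which is an assumption rather than the actual Archai implementation:

```python
from datasets import DownloadMode


def _should_refresh_cache(refresh: bool) -> DownloadMode:
    # Assumed shape, inferred from the test: force a fresh download when a
    # refresh is requested, otherwise reuse any cached copy of the dataset.
    if refresh:
        return DownloadMode.FORCE_REDOWNLOAD
    return DownloadMode.REUSE_DATASET_IF_EXISTS
```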
@@ -0,0 +1,162 @@
```python
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import pytest
from transformers import AutoTokenizer

from archai.nlp.datasets.hf.loaders import (
    DatasetDict,
    IterableDatasetDict,
    load_dataset,
)
from archai.nlp.datasets.hf.processors import (
    map_dataset_to_dict,
    merge_datasets,
    resize_dataset,
    shuffle_dataset,
    tokenize_contiguous_dataset,
    tokenize_dataset,
    tokenize_nsp_dataset,
)


@pytest.fixture
def dataset():
    return load_dataset(
        dataset_name="wikitext",
        dataset_config_name="wikitext-2-raw-v1",
        dataset_refresh_cache=True,
    )


@pytest.fixture
def dataset_train_set():
    return load_dataset(
        dataset_name="wikitext",
        dataset_config_name="wikitext-2-raw-v1",
        dataset_refresh_cache=True,
        dataset_split="train",
    )


@pytest.fixture
def dataset_val_set():
    return load_dataset(
        dataset_name="wikitext",
        dataset_config_name="wikitext-2-raw-v1",
        dataset_refresh_cache=True,
        dataset_split="validation",
    )


@pytest.fixture
def iterable_dataset():
    return load_dataset(
        dataset_name="wikitext",
        dataset_config_name="wikitext-2-raw-v1",
        dataset_refresh_cache=True,
        dataset_stream=True,
    )


@pytest.fixture
def tokenizer():
    tokenizer = AutoTokenizer.from_pretrained("gpt2", model_max_length=8)
    tokenizer.pad_token = tokenizer.eos_token
    return tokenizer


def test_map_dataset_to_dict(dataset):
    # Assert mapping a single dataset to a dictionary
    dataset = dataset["test"]
    dataset_dict = map_dataset_to_dict(dataset, "test")
    assert isinstance(dataset_dict, (DatasetDict, IterableDatasetDict))

    # Assert mapping multiple datasets to a dictionary
    datasets = [dataset for _ in range(3)]
    dataset_dict = map_dataset_to_dict(datasets, ["test", "test", "test"])
    assert isinstance(dataset_dict, (DatasetDict, IterableDatasetDict))


def test_merge_datasets(dataset, dataset_train_set, dataset_val_set, iterable_dataset):
    # Assert that datasets can be merged
    datasets = [dataset for _ in range(3)]
    merged_dataset = merge_datasets(datasets)
    assert isinstance(merged_dataset, (DatasetDict, IterableDatasetDict))
    assert len(merged_dataset) == 3
    assert len(list(merged_dataset.values())[0]) == len(list(dataset.values())[0]) * 3

    # Assert that datasets with different splits cannot be merged
    datasets = [dataset_train_set, dataset_val_set]
    with pytest.raises(AssertionError):
        merged_dataset = merge_datasets(datasets)

    # Assert that datasets of different types cannot be merged
    datasets = [dataset, iterable_dataset]
    with pytest.raises(AssertionError):
        merged_dataset = merge_datasets(datasets)


def test_resize_dataset(dataset):
    dataset = dataset["train"]

    # Assert resizing dataset to a smaller size
    resized_dataset = resize_dataset(dataset, 10)
    assert len(resized_dataset) == 10

    # Assert resizing dataset to a larger size
    resized_dataset = resize_dataset(dataset, 10000)
    assert len(resized_dataset) == 10000

    # Assert resizing dataset to the same size
    resized_dataset = resize_dataset(dataset, len(dataset))
    assert len(resized_dataset) == len(dataset)


def test_shuffle_dataset(dataset):
    dataset = dataset["train"]

    # Assert shuffling dataset with a positive seed
    shuffled_dataset = shuffle_dataset(dataset, 42)
    assert len(shuffled_dataset) == len(dataset)
    assert isinstance(shuffled_dataset, type(dataset))

    # Assert shuffling dataset with a negative seed
    shuffled_dataset = shuffle_dataset(dataset, -1)
    assert len(shuffled_dataset) == len(dataset)
    assert isinstance(shuffled_dataset, type(dataset))


def test_tokenize_dataset(tokenizer):
    # Assert that examples can be tokenized
    examples = {"text": ["Hello, this is a test.", "This is another test."]}
    expected_output = {
        "input_ids": [[15496, 11, 428, 318, 257, 1332, 13, 50256], [1212, 318, 1194, 1332, 13, 50256, 50256, 50256]],
        "attention_mask": [[1, 1, 1, 1, 1, 1, 1, 0], [1, 1, 1, 1, 1, 0, 0, 0]],
    }

    output = tokenize_dataset(examples, tokenizer=tokenizer)
    assert output == expected_output


def test_tokenize_contiguous_dataset(tokenizer):
    # Assert that examples can be contiguously tokenized
    examples = {"text": ["This is a test example.", "This is another test example.", "And yet another test example."]}
    output = tokenize_contiguous_dataset(examples, tokenizer=tokenizer, model_max_length=8)
    assert len(output["input_ids"][0]) == 8
    assert len(output["input_ids"][1]) == 8
    with pytest.raises(IndexError):
        assert len(output["input_ids"][2]) == 8


def test_tokenize_nsp_dataset(tokenizer):
    # Assert that a single example can be tokenized
    examples = {"text": ["This is a single example."]}
    tokenized_examples = tokenize_nsp_dataset(examples, tokenizer=tokenizer)
    assert tokenized_examples["next_sentence_label"][0] in [0, 1]

    # Assert that multiple examples can be tokenized
    examples = {"text": ["This is the first example.", "This is the second example."]}
    tokenized_examples = tokenize_nsp_dataset(examples, tokenizer=tokenizer)
    assert tokenized_examples["next_sentence_label"][0] in [0, 1]
    assert tokenized_examples["next_sentence_label"][1] in [0, 1]
```
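
`test_tokenize_contiguous_dataset` is the subtlest case above: three short sentences yield exactly two blocks of `model_max_length=8` token ids, and asking for a third block raises `IndexError`. That implies the function concatenates token ids across examples and cuts them into fixed-length blocks, discarding the incomplete tail. A minimal sketch of that presumed behavior (`pack_contiguous` is a hypothetical name, not the Archai function):

```python
from typing import Dict, List

from transformers import PreTrainedTokenizerBase


def pack_contiguous(
    texts: List[str], tokenizer: PreTrainedTokenizerBase, block_size: int
) -> Dict[str, List[List[int]]]:
    # Assumed behavior, inferred from the test: concatenate token ids across
    # all examples, then slice into fixed-size blocks; the trailing remainder
    # that does not fill a whole block is dropped.
    ids: List[int] = []
    for text in texts:
        ids.extend(tokenizer(text)["input_ids"])
    n_blocks = len(ids) // block_size
    return {"input_ids": [ids[i * block_size : (i + 1) * block_size] for i in range(n_blocks)]}
```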
tests/nlp/datasets/hf/tokenizer_utils/test_hf_token_config.py (59 additions, 0 deletions)
@@ -0,0 +1,59 @@
```python
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import Dict, List

import pytest

from archai.nlp.datasets.hf.tokenizer_utils.token_config import TokenConfig


@pytest.fixture
def token_config():
    return TokenConfig(
        bos_token="<BOS>",
        eos_token="<EOS>",
        unk_token="<UNK>",
        sep_token="<SEP>",
        pad_token="<PAD>",
        cls_token="<CLS>",
        mask_token="<MASK>",
    )


def test_token_config_special_tokens(token_config):
    # Assert that the special tokens are set correctly
    special_tokens = token_config.special_tokens
    assert isinstance(special_tokens, List)
    assert len(special_tokens) == 7
    assert "<BOS>" in special_tokens
    assert "<EOS>" in special_tokens
    assert "<UNK>" in special_tokens
    assert "<SEP>" in special_tokens
    assert "<PAD>" in special_tokens
    assert "<CLS>" in special_tokens
    assert "<MASK>" in special_tokens


def test_token_config_to_dict(token_config):
    # Assert that the token config is converted to a dictionary correctly
    token_dict = token_config.to_dict()
    assert isinstance(token_dict, Dict)
    assert token_dict["bos_token"] == "<BOS>"
    assert token_dict["eos_token"] == "<EOS>"
    assert token_dict["unk_token"] == "<UNK>"
    assert token_dict["sep_token"] == "<SEP>"
    assert token_dict["pad_token"] == "<PAD>"
    assert token_dict["cls_token"] == "<CLS>"
    assert token_dict["mask_token"] == "<MASK>"


def test_token_config_from_file(token_config, tmp_path):
    token_config_path = tmp_path / "token_config.json"
    token_config.save(str(token_config_path))

    # Assert that the token config is loaded correctly from a file
    loaded_token_config = TokenConfig.from_file(str(token_config_path))
    assert isinstance(loaded_token_config, TokenConfig)
    assert loaded_token_config.bos_token == "<BOS>"
    assert loaded_token_config.eos_token == "<EOS>"
```
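
The assertions above outline `TokenConfig`'s public surface: keyword construction, a `special_tokens` list, `to_dict()`, `save()`, and `from_file()`. A round-trip usage sketch restricted to those calls, assuming (the fixture always sets all seven tokens, so this is unverified) that unspecified special tokens fall back to defaults:

```python
from archai.nlp.datasets.hf.tokenizer_utils.token_config import TokenConfig

# Assumption: special tokens not passed here take default values.
config = TokenConfig(bos_token="<BOS>", eos_token="<EOS>")
config.save("token_config.json")

# Reload the configuration and confirm the round trip preserved the tokens.
restored = TokenConfig.from_file("token_config.json")
assert restored.bos_token == "<BOS>"
assert restored.eos_token == "<EOS>"
```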