From 2dfc0e198f117c72b0c45c1e4ac07756133a4e0a Mon Sep 17 00:00:00 2001 From: Gustavo Rosa Date: Fri, 6 Jan 2023 14:59:04 -0300 Subject: [PATCH] chore(tests): Implements nlp.datasets tests and fixes wrong type annotations. --- archai/nlp/datasets/hf/loaders.py | 2 +- archai/nlp/datasets/hf/processors.py | 14 +- archai/nlp/datasets/nvidia/lm_iterators.py | 4 +- docs/contributing/documentation.rst | 4 +- tests/nlp/datasets/hf/test_loaders.py | 69 ++++++++ tests/nlp/datasets/hf/test_processors.py | 162 ++++++++++++++++++ .../tokenizer_utils/test_hf_token_config.py | 59 +++++++ .../tokenizer_utils/test_hf_tokenizer_base.py | 49 ++++++ .../nlp/datasets/nvidia/test_corpus_utils.py | 37 ++++ .../nlp/datasets/nvidia/test_lm_iterators.py | 25 +++ .../test_nvidia_token_config.py | 53 ++++++ .../tokenizer_utils/test_nvidia_vocab_base.py | 98 +++++++++++ 12 files changed, 569 insertions(+), 7 deletions(-) create mode 100644 tests/nlp/datasets/hf/test_loaders.py create mode 100644 tests/nlp/datasets/hf/test_processors.py create mode 100644 tests/nlp/datasets/hf/tokenizer_utils/test_hf_token_config.py create mode 100644 tests/nlp/datasets/hf/tokenizer_utils/test_hf_tokenizer_base.py create mode 100644 tests/nlp/datasets/nvidia/test_corpus_utils.py create mode 100644 tests/nlp/datasets/nvidia/test_lm_iterators.py create mode 100644 tests/nlp/datasets/nvidia/tokenizer_utils/test_nvidia_token_config.py create mode 100644 tests/nlp/datasets/nvidia/tokenizer_utils/test_nvidia_vocab_base.py diff --git a/archai/nlp/datasets/hf/loaders.py b/archai/nlp/datasets/hf/loaders.py index 1bcd56b1d..9bdf2af39 100644 --- a/archai/nlp/datasets/hf/loaders.py +++ b/archai/nlp/datasets/hf/loaders.py @@ -107,7 +107,7 @@ def load_dataset( ) if not isinstance(dataset, (DatasetDict, IterableDatasetDict)): - dataset = map_dataset_to_dict(dataset, splits=dataset_split) + dataset = map_dataset_to_dict(dataset, dataset_split) n_samples_list = map_to_list(n_samples, len(dataset.items())) for split, n_samples in zip(dataset.keys(), n_samples_list): diff --git a/archai/nlp/datasets/hf/processors.py b/archai/nlp/datasets/hf/processors.py index 0386a601c..7579aa9bb 100644 --- a/archai/nlp/datasets/hf/processors.py +++ b/archai/nlp/datasets/hf/processors.py @@ -22,7 +22,7 @@ def map_dataset_to_dict( dataset: Union[Dataset, IterableDataset, List[Dataset], List[IterableDataset]], - splits: Optional[Union[str, List[str]]] = None, + splits: Union[str, List[str]], ) -> Union[DatasetDict, IterableDatasetDict]: """Map a dataset or list of datasets to a dictionary. @@ -155,7 +155,7 @@ def shuffle_dataset(dataset: Union[Dataset, IterableDataset], seed: int) -> Unio def tokenize_dataset( examples: List[str], tokenizer: Optional[Union[AutoTokenizer, ArchaiPreTrainedTokenizerFast]] = None, - mapping_column_name: Optional[Union[str, List[str]]] = "text", + mapping_column_name: Optional[List[str]] = None, truncate: Optional[Union[bool, str]] = True, padding: Optional[Union[bool, str]] = "max_length", **kwargs, @@ -174,6 +174,9 @@ def tokenize_dataset( """ + if mapping_column_name is None: + mapping_column_name = ["text"] + examples_mapping = tuple(examples[column_name] for column_name in mapping_column_name) return tokenizer(*examples_mapping, truncation=truncate, padding=padding) @@ -182,7 +185,7 @@ def tokenize_dataset( def tokenize_contiguous_dataset( examples: List[str], tokenizer: Optional[Union[AutoTokenizer, ArchaiPreTrainedTokenizerFast]] = None, - mapping_column_name: Optional[Union[str, List[str]]] = "text", + mapping_column_name: Optional[List[str]] = None, model_max_length: Optional[int] = 1024, **kwargs, ) -> Dict[str, Any]: @@ -221,7 +224,7 @@ def tokenize_contiguous_dataset( def tokenize_nsp_dataset( examples: List[str], tokenizer: Optional[Union[AutoTokenizer, ArchaiPreTrainedTokenizerFast]] = None, - mapping_column_name: Optional[Union[str, List[str]]] = "text", + mapping_column_name: Optional[List[str]] = None, truncate: Optional[Union[bool, str]] = True, padding: Optional[Union[bool, str]] = "max_length", **kwargs, @@ -241,6 +244,9 @@ def tokenize_nsp_dataset( """ + if mapping_column_name is None: + mapping_column_name = ["text"] + assert len(mapping_column_name) == 1, "`mapping_column_name` must have a single value." examples_mapping = examples[mapping_column_name[0]] diff --git a/archai/nlp/datasets/nvidia/lm_iterators.py b/archai/nlp/datasets/nvidia/lm_iterators.py index e8fc91db5..bd8ead1de 100644 --- a/archai/nlp/datasets/nvidia/lm_iterators.py +++ b/archai/nlp/datasets/nvidia/lm_iterators.py @@ -50,7 +50,9 @@ def __init__( # Divides cleanly the inputs into batches and trims the remaining elements n_step = input_ids.size(0) // bsz input_ids = input_ids[: n_step * bsz] - self.input_ids = input_ids.view(bsz, -1).contiguous().pin_memory() + self.input_ids = input_ids.view(bsz, -1).contiguous() + if device != "cpu": + self.input_ids = self.input_ids.pin_memory() # Creates warmup batches if memory is being used if mem_len and warmup: diff --git a/docs/contributing/documentation.rst b/docs/contributing/documentation.rst index 1bf64842c..2d2b0df3c 100644 --- a/docs/contributing/documentation.rst +++ b/docs/contributing/documentation.rst @@ -1,6 +1,8 @@ Documentation ============= +The Archai project welcomes contributions through the implementation of documentation files using Sphinx and RST. If you are interested in contributing to the project in this way, please follow these steps: + #. Ensure that Sphinx is installed. You can install it using ``pip install archai[docs]``. #. Check out the Archai codebase and create a new branch for your changes. This will allow for easy submission of your code as a pull request upon completion. @@ -29,7 +31,7 @@ Documentation .. tab:: Windows - .. code-block:: sh + .. code-block:: bat cd archai/docs .\make.bat html diff --git a/tests/nlp/datasets/hf/test_loaders.py b/tests/nlp/datasets/hf/test_loaders.py new file mode 100644 index 000000000..8364caaf7 --- /dev/null +++ b/tests/nlp/datasets/hf/test_loaders.py @@ -0,0 +1,69 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from transformers import AutoTokenizer + +from archai.nlp.datasets.hf.loaders import ( + DatasetDict, + DownloadMode, + IterableDatasetDict, + _should_refresh_cache, + encode_dataset, + load_dataset, +) + + +def test_should_refresh_cache(): + # Test that the function returns FORCE_REDOWNLOAD when refresh is True + assert _should_refresh_cache(True) == DownloadMode.FORCE_REDOWNLOAD + + # Test that the function returns REUSE_DATASET_IF_EXISTS when refresh is False + assert _should_refresh_cache(False) == DownloadMode.REUSE_DATASET_IF_EXISTS + + +def test_load_dataset(): + dataset_name = "wikitext" + dataset_config_name = "wikitext-2-raw-v1" + + # Assert loading dataset from Hugging Face Hub + dataset = load_dataset( + dataset_name=dataset_name, + dataset_config_name=dataset_config_name, + dataset_refresh_cache=True, + ) + assert isinstance(dataset, (DatasetDict, IterableDatasetDict)) + + # Assert that subsampling works + n_samples = 10 + dataset = dataset = load_dataset( + dataset_name=dataset_name, + dataset_config_name=dataset_config_name, + dataset_refresh_cache=True, + n_samples=n_samples, + ) + assert all(len(split) == n_samples for split in dataset.values()) + + +def test_encode_dataset(): + dataset = load_dataset( + dataset_name="wikitext", dataset_config_name="wikitext-2-raw-v1", dataset_refresh_cache=True, n_samples=10 + ) + tokenizer = AutoTokenizer.from_pretrained("gpt2") + tokenizer.pad_token = tokenizer.eos_token + + # Assert that dataaset can be encoded + encoded_dataset = encode_dataset(dataset, tokenizer) + assert isinstance(encoded_dataset, (DatasetDict, IterableDatasetDict)) + + # Assert that dataset can be encoded with custom mapping function + def custom_mapping_fn(example, tokenizer, mapping_column_name=None): + example["text2"] = example["text"] + return example + + encoded_dataset = encode_dataset(dataset, tokenizer, mapping_fn=custom_mapping_fn) + assert isinstance(encoded_dataset, (DatasetDict, IterableDatasetDict)) + + # Assert that dataset can be encoded with multiprocessing + num_proc = 4 + encoded_dataset = encode_dataset(dataset, tokenizer, num_proc=num_proc) + assert isinstance(encoded_dataset, (DatasetDict, IterableDatasetDict)) diff --git a/tests/nlp/datasets/hf/test_processors.py b/tests/nlp/datasets/hf/test_processors.py new file mode 100644 index 000000000..0ab46b126 --- /dev/null +++ b/tests/nlp/datasets/hf/test_processors.py @@ -0,0 +1,162 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import pytest +from transformers import AutoTokenizer + +from archai.nlp.datasets.hf.loaders import ( + DatasetDict, + IterableDatasetDict, + load_dataset, +) +from archai.nlp.datasets.hf.processors import ( + map_dataset_to_dict, + merge_datasets, + resize_dataset, + shuffle_dataset, + tokenize_contiguous_dataset, + tokenize_dataset, + tokenize_nsp_dataset, +) + + +@pytest.fixture +def dataset(): + return load_dataset( + dataset_name="wikitext", + dataset_config_name="wikitext-2-raw-v1", + dataset_refresh_cache=True, + ) + + +@pytest.fixture +def dataset_train_set(): + return load_dataset( + dataset_name="wikitext", + dataset_config_name="wikitext-2-raw-v1", + dataset_refresh_cache=True, + dataset_split="train", + ) + + +@pytest.fixture +def dataset_val_set(): + return load_dataset( + dataset_name="wikitext", + dataset_config_name="wikitext-2-raw-v1", + dataset_refresh_cache=True, + dataset_split="validation", + ) + + +@pytest.fixture +def iterable_dataset(): + return load_dataset( + dataset_name="wikitext", + dataset_config_name="wikitext-2-raw-v1", + dataset_refresh_cache=True, + dataset_stream=True, + ) + + +@pytest.fixture +def tokenizer(): + tokenizer = AutoTokenizer.from_pretrained("gpt2", model_max_length=8) + tokenizer.pad_token = tokenizer.eos_token + return tokenizer + + +def test_map_dataset_to_dict(dataset): + # Assert mapping single dataset to dictionary + dataset = dataset["test"] + dataset_dict = map_dataset_to_dict(dataset, "test") + assert isinstance(dataset_dict, (DatasetDict, IterableDatasetDict)) + + # Assert mapping multiple datasets to dictionary + datasets = [dataset for _ in range(3)] + dataset_dict = map_dataset_to_dict(datasets, ["test", "test", "test"]) + assert isinstance(dataset_dict, (DatasetDict, IterableDatasetDict)) + + +def test_merge_datasets(dataset, dataset_train_set, dataset_val_set, iterable_dataset): + # Assert that dataset can be merged + datasets = [dataset for _ in range(3)] + merged_dataset = merge_datasets(datasets) + assert isinstance(merged_dataset, (DatasetDict, IterableDatasetDict)) + assert len(merged_dataset) == 3 + assert len(list(merged_dataset.values())[0]) == len(list(dataset.values())[0]) * 3 + + # Assert that dataset can not be merged with different splits + datasets = [dataset_train_set, dataset_val_set] + with pytest.raises(AssertionError): + merged_dataset = merge_datasets(datasets) + + # Assert that dataset can not be merged with different types + datasets = [dataset, iterable_dataset] + with pytest.raises(AssertionError): + merged_dataset = merge_datasets(datasets) + + +def test_resize_dataset(dataset): + dataset = dataset["train"] + + # Assert resizing dataset to smaller size + resized_dataset = resize_dataset(dataset, 10) + assert len(resized_dataset) == 10 + + # Assert resizing dataset to larger size + resized_dataset = resize_dataset(dataset, 10000) + assert len(resized_dataset) == 10000 + + # Assert resizing dataset to same size + resized_dataset = resize_dataset(dataset, len(dataset)) + assert len(resized_dataset) == len(dataset) + + +def test_shuffle_dataset(dataset): + dataset = dataset["train"] + + # Assert shuffling dataset with positive seed + shuffled_dataset = shuffle_dataset(dataset, 42) + assert len(shuffled_dataset) == len(dataset) + assert isinstance(shuffled_dataset, type(dataset)) + + # Assert shuffling dataset with negative seed + shuffled_dataset = shuffle_dataset(dataset, -1) + assert len(shuffled_dataset) == len(dataset) + assert isinstance(shuffled_dataset, type(dataset)) + + +def test_tokenize_dataset(tokenizer): + # Assert that examples can be tokenized + examples = {"text": ["Hello, this is a test.", "This is another test."]} + expected_output = { + "input_ids": [[15496, 11, 428, 318, 257, 1332, 13, 50256], [1212, 318, 1194, 1332, 13, 50256, 50256, 50256]], + "attention_mask": [[1, 1, 1, 1, 1, 1, 1, 0], [1, 1, 1, 1, 1, 0, 0, 0]], + } + + output = tokenize_dataset(examples, tokenizer=tokenizer) + assert output == expected_output + + +def test_tokenize_contiguous_dataset(tokenizer): + # Assert that examples can be contiguously tokenized + examples = {"text": ["This is a test example.", "This is another test example.", "And yet another test example."]} + output = tokenize_contiguous_dataset(examples, tokenizer=tokenizer, model_max_length=8) + assert len(output["input_ids"][0]) == 8 + assert len(output["input_ids"][1]) == 8 + with pytest.raises(IndexError): + assert len(output["input_ids"][2]) == 8 + + +def test_tokenize_nsp_dataset(tokenizer): + # Assert that a single example can be tokenized + examples = {"text": ["This is a single example."]} + tokenized_examples = tokenize_nsp_dataset(examples, tokenizer=tokenizer) + assert tokenized_examples["next_sentence_label"][0] in [0, 1] + + # Assert that multiple examples can be tokenized + examples = {"text": ["This is the first example.", "This is the second example."]} + tokenized_examples = tokenize_nsp_dataset(examples, tokenizer=tokenizer) + assert tokenized_examples["next_sentence_label"][0] in [0, 1] + assert tokenized_examples["next_sentence_label"][1] in [0, 1] diff --git a/tests/nlp/datasets/hf/tokenizer_utils/test_hf_token_config.py b/tests/nlp/datasets/hf/tokenizer_utils/test_hf_token_config.py new file mode 100644 index 000000000..5754ad0d3 --- /dev/null +++ b/tests/nlp/datasets/hf/tokenizer_utils/test_hf_token_config.py @@ -0,0 +1,59 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from typing import Dict, List + +import pytest + +from archai.nlp.datasets.hf.tokenizer_utils.token_config import TokenConfig + + +@pytest.fixture +def token_config(): + return TokenConfig( + bos_token="", + eos_token="", + unk_token="", + sep_token="", + pad_token="", + cls_token="", + mask_token="", + ) + + +def test_token_config_special_tokens(token_config): + # Assert that the special tokens are set correctly + special_tokens = token_config.special_tokens + assert isinstance(special_tokens, List) + assert len(special_tokens) == 7 + assert "" in special_tokens + assert "" in special_tokens + assert "" in special_tokens + assert "" in special_tokens + assert "" in special_tokens + assert "" in special_tokens + assert "" in special_tokens + + +def test_token_config_to_dict(token_config): + # Assert that the token config is converted to a dictionary correctly + token_dict = token_config.to_dict() + assert isinstance(token_dict, Dict) + assert token_dict["bos_token"] == "" + assert token_dict["eos_token"] == "" + assert token_dict["unk_token"] == "" + assert token_dict["sep_token"] == "" + assert token_dict["pad_token"] == "" + assert token_dict["cls_token"] == "" + assert token_dict["mask_token"] == "" + + +def test_token_config_from_file(token_config, tmp_path): + token_config_path = tmp_path / "token_config.json" + token_config.save(str(token_config_path)) + + # Assert that the token config is loaded correctly from a file + loaded_token_config = TokenConfig.from_file(str(token_config_path)) + assert isinstance(loaded_token_config, TokenConfig) + assert loaded_token_config.bos_token == "" + assert loaded_token_config.eos_token == "" diff --git a/tests/nlp/datasets/hf/tokenizer_utils/test_hf_tokenizer_base.py b/tests/nlp/datasets/hf/tokenizer_utils/test_hf_tokenizer_base.py new file mode 100644 index 000000000..9d4b8d4d1 --- /dev/null +++ b/tests/nlp/datasets/hf/tokenizer_utils/test_hf_tokenizer_base.py @@ -0,0 +1,49 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import os + +import pytest +from tokenizers import Tokenizer +from tokenizers.models import BPE +from tokenizers.trainers import BpeTrainer + +from archai.nlp.datasets.hf.tokenizer_utils.token_config import TokenConfig +from archai.nlp.datasets.hf.tokenizer_utils.tokenizer_base import TokenizerBase + + +@pytest.fixture +def token_config(): + return TokenConfig( + bos_token="", + eos_token="", + unk_token="", + sep_token="", + pad_token="", + cls_token="", + mask_token="", + ) + + +@pytest.fixture +def tokenizer(): + return Tokenizer(BPE()) + + +@pytest.fixture +def trainer(): + return BpeTrainer() + + +def test_tokenizer_base(token_config, tokenizer, trainer): + # Assert that the tokenizer base is initialized correctly + tokenizer_base = TokenizerBase(token_config, tokenizer, trainer) + assert isinstance(tokenizer_base, TokenizerBase) + + # Assert that the tokenizer can be saved + tokenizer_base.save("tokenizer.json") + assert os.path.exists("tokenizer.json") + assert os.path.exists("token_config.json") + + os.remove("tokenizer.json") + os.remove("token_config.json") diff --git a/tests/nlp/datasets/nvidia/test_corpus_utils.py b/tests/nlp/datasets/nvidia/test_corpus_utils.py new file mode 100644 index 000000000..ad45f9464 --- /dev/null +++ b/tests/nlp/datasets/nvidia/test_corpus_utils.py @@ -0,0 +1,37 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import os +import tempfile + +import pytest + +from archai.nlp.datasets.nvidia.corpus_utils import create_dirs, get_dataset_dir_name + + +def test_get_dataset_dir_name(): + # Assert that the correct dataset directory name is returned for supported datasets + assert get_dataset_dir_name("wt2") == "wikitext-2" + assert get_dataset_dir_name("wt103") == "wikitext-103" + assert get_dataset_dir_name("lm1b") == "one-billion-words" + assert get_dataset_dir_name("olx_jobs") == "olx_jobs" + + # Assert that a RuntimeError is raised for unsupported datasets + with pytest.raises(RuntimeError): + get_dataset_dir_name("unsupported_dataset") + + +def test_create_dirs(): + experiment_name = "experiment" + output_dir = tempfile.mkdtemp() + dataroot = tempfile.mkdtemp() + dataset_name = "wt2" + + # Assert that the function creates the expected directories + dataset_dir, output_dir, pretrained_path, cache_dir = create_dirs( + dataroot, dataset_name, experiment_name=experiment_name, output_dir=output_dir + ) + assert os.path.isdir(dataset_dir) + assert os.path.isdir(output_dir) + assert os.path.isdir(cache_dir) + assert pretrained_path == "" diff --git a/tests/nlp/datasets/nvidia/test_lm_iterators.py b/tests/nlp/datasets/nvidia/test_lm_iterators.py new file mode 100644 index 000000000..fa90740b5 --- /dev/null +++ b/tests/nlp/datasets/nvidia/test_lm_iterators.py @@ -0,0 +1,25 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import pytest +import torch + +from archai.nlp.datasets.nvidia.lm_iterators import LMOrderedIterator + + +@pytest.fixture +def iterator(): + input_ids = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]) + return LMOrderedIterator(input_ids, bsz=1, bptt=5) + + +def test_lm_ordered_iterator(iterator): + # Assert that the iterator has the correct number of batches + assert iterator.n_batch == 4 + + # Assert that the batches can be iterated over + for inputs, labels, seq_len, warmup in iterator: + assert inputs.shape == (1, 5) + assert labels.shape == (1, 5) + assert seq_len == 5 + assert warmup is True diff --git a/tests/nlp/datasets/nvidia/tokenizer_utils/test_nvidia_token_config.py b/tests/nlp/datasets/nvidia/tokenizer_utils/test_nvidia_token_config.py new file mode 100644 index 000000000..1b390b712 --- /dev/null +++ b/tests/nlp/datasets/nvidia/tokenizer_utils/test_nvidia_token_config.py @@ -0,0 +1,53 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import pytest + +from archai.nlp.datasets.nvidia.tokenizer_utils.token_config import ( + SpecialTokenEnum, + TokenConfig, +) + + +@pytest.fixture +def token_config(): + return TokenConfig( + bos_token="", + eos_token="", + unk_token="", + pad_token="", + add_prefix_space=False, + add_prefix_new_line=True, + lower_case=True, + ) + + +def test_special_token_enum(): + # Assert that the correct values are assigned to the special tokens enumerator + assert SpecialTokenEnum.UNK.value == 0 + assert SpecialTokenEnum.BOS.value == 1 + assert SpecialTokenEnum.EOS.value == 2 + assert SpecialTokenEnum.PAD.value == 3 + assert SpecialTokenEnum.MASK.value == 4 + + +def test_token_config(token_config): + # Assert that the correct values are assigned to the special tokens + assert token_config.bos_token == "" + assert token_config.eos_token == "" + assert token_config.unk_token == "" + assert token_config.pad_token == "" + assert token_config.add_prefix_space is False + assert token_config.add_prefix_new_line is True + assert token_config.lower_case is True + + # Assert that the special tokens are added to the special token list + special_tokens = token_config.get_special_tokens() + assert special_tokens == ["", "", "", ""] + + # Assert that the special tokens names are returned correctly + assert token_config.special_token_name(SpecialTokenEnum.BOS) == "" + assert token_config.special_token_name(SpecialTokenEnum.EOS) == "" + assert token_config.special_token_name(SpecialTokenEnum.UNK) == "" + assert token_config.special_token_name(SpecialTokenEnum.PAD) == "" + assert token_config.special_token_name("invalid") is None diff --git a/tests/nlp/datasets/nvidia/tokenizer_utils/test_nvidia_vocab_base.py b/tests/nlp/datasets/nvidia/tokenizer_utils/test_nvidia_vocab_base.py new file mode 100644 index 000000000..106ed73d5 --- /dev/null +++ b/tests/nlp/datasets/nvidia/tokenizer_utils/test_nvidia_vocab_base.py @@ -0,0 +1,98 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import pytest +from overrides import overrides + +from archai.nlp.datasets.nvidia.tokenizer_utils.token_config import SpecialTokenEnum +from archai.nlp.datasets.nvidia.tokenizer_utils.vocab_base import VocabBase + + +@pytest.fixture +def vocab_base(): + class Vocab(VocabBase): + def __init__(self): + self.is_trained_value = False + + def __len__(self): + return 100 + + @overrides + def train(self, filepaths): + self.is_trained_value = True + + @overrides + def is_trained(self): + return self.is_trained_value + + @overrides + def load(self): + self.is_trained_value = True + + @overrides + def encode_text(self, text): + return [1, 2, 3] + + @overrides + def decode_text(self, ids): + return "decoded" + + @overrides + def special_token_id(self, sp): + if sp == SpecialTokenEnum.BOS: + return 1 + if sp == SpecialTokenEnum.EOS: + return 2 + if sp == SpecialTokenEnum.UNK: + return 3 + if sp == SpecialTokenEnum.PAD: + return 4 + return None + + @overrides + def token_to_id(self, t): + return 5 + + @overrides + def id_to_token(self, id): + return "token" + + return Vocab() + + +def test_vocab_base_len(vocab_base): + assert len(vocab_base) == 100 + + +def test_vocab_base_train(vocab_base): + vocab_base.train(["file1", "file2"]) + assert vocab_base.is_trained() is True + + +def test_vocab_base_load(vocab_base): + vocab_base.load() + assert vocab_base.is_trained() is True + + +def test_vocab_base_encode_text(vocab_base): + assert vocab_base.encode_text("test") == [1, 2, 3] + + +def test_vocab_base_decode_text(vocab_base): + assert vocab_base.decode_text([1, 2, 3]) == "decoded" + + +def test_vocab_base_special_token_id(vocab_base): + assert vocab_base.special_token_id(SpecialTokenEnum.BOS) == 1 + assert vocab_base.special_token_id(SpecialTokenEnum.EOS) == 2 + assert vocab_base.special_token_id(SpecialTokenEnum.UNK) == 3 + assert vocab_base.special_token_id(SpecialTokenEnum.PAD) == 4 + assert vocab_base.special_token_id("invalid") is None + + +def test_vocab_base_token_to_id(vocab_base): + assert vocab_base.token_to_id("test") == 5 + + +def test_vocab_base_id_to_token(vocab_base): + assert vocab_base.id_to_token(5) == "token"