-
Notifications
You must be signed in to change notification settings - Fork 89
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(datasets): Add Hugging Face datasets (#344)
* Add HuggingFace datasets Co-authored-by: Danny Farah <[email protected]> Co-authored-by: Kevin Koga <[email protected]> Co-authored-by: Mate Scharnitzky <[email protected]> Co-authored-by: Tomer Shor <[email protected]> Co-authored-by: Pierre-Yves Mousset <[email protected]> Co-authored-by: Bela Chupal <[email protected]> Co-authored-by: Khangjrakpam Arjun <[email protected]> Co-authored-by: Juan Luis Cano Rodríguez <[email protected]> Signed-off-by: Juan Luis Cano Rodríguez <[email protected]> * Apply suggestions from code review Signed-off-by: Juan Luis Cano Rodríguez <[email protected]> Co-authored-by: Joel <[email protected]> Co-authored-by: Nok Lam Chan <[email protected]> * Typo Signed-off-by: Juan Luis Cano Rodríguez <[email protected]> * Fix docstring Signed-off-by: Juan Luis Cano Rodríguez <[email protected]> * Add docstring for HFTransformerPipelineDataset Signed-off-by: Juan Luis Cano Rodríguez <[email protected]> * Use intersphinx for cross references in Hugging Face docstrings Signed-off-by: Juan Luis Cano Rodríguez <[email protected]> * Add docstring for HFDataset Signed-off-by: Juan Luis Cano Rodríguez <[email protected]> * Add missing test dependencies Signed-off-by: Juan Luis Cano Rodríguez <[email protected]> * Add tests for huggingface datasets Signed-off-by: Juan Luis Cano Rodríguez <[email protected]> * Fix HFDataset.save Signed-off-by: Juan Luis Cano Rodríguez <[email protected]> * Add test for HFDataset.list_datasets Signed-off-by: Juan Luis Cano Rodríguez <[email protected]> * Use new name Signed-off-by: Juan Luis Cano Rodríguez <[email protected]> * Consolidate imports Signed-off-by: Juan Luis Cano Rodríguez <[email protected]> --------- Signed-off-by: Juan Luis Cano Rodríguez <[email protected]> Co-authored-by: Danny Farah <[email protected]> Co-authored-by: Kevin Koga <[email protected]> Co-authored-by: Mate Scharnitzky <[email protected]> Co-authored-by: Tomer Shor <[email protected]> Co-authored-by: Pierre-Yves Mousset <[email 
protected]> Co-authored-by: Bela Chupal <[email protected]> Co-authored-by: Khangjrakpam Arjun <[email protected]> Co-authored-by: Joel <[email protected]> Co-authored-by: Nok Lam Chan <[email protected]>
- Loading branch information
1 parent
d81a4d0
commit f59e930
Showing
10 changed files
with
262 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
"""Provides interface to Hugging Face transformers and datasets.""" | ||
from typing import Any | ||
|
||
import lazy_loader as lazy | ||
|
||
# The bare annotations below declare the lazily exported names so static
# analysers (pylint in particular) can see them; see
# https://github.com/pylint-dev/pylint/issues/4300#issuecomment-1043601901
HFDataset: Any
HFTransformerPipelineDataset: Any

# Map each submodule to the attributes it exports; the submodules are named
# here but resolved through the ``__getattr__`` returned by ``lazy.attach``.
__getattr__, __dir__, __all__ = lazy.attach(
    __name__,
    submod_attrs=dict(
        hugging_face_dataset=["HFDataset"],
        transformer_pipeline_dataset=["HFTransformerPipelineDataset"],
    ),
)
56 changes: 56 additions & 0 deletions
56
kedro-datasets/kedro_datasets/huggingface/hugging_face_dataset.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
from __future__ import annotations | ||
|
||
from typing import Any | ||
|
||
from datasets import load_dataset | ||
from huggingface_hub import HfApi | ||
from kedro.io import AbstractVersionedDataset | ||
|
||
|
||
class HFDataset(AbstractVersionedDataset):
    """``HFDataset`` loads Hugging Face datasets
    using the `datasets <https://pypi.org/project/datasets>`_ library.

    Example usage for the :doc:`YAML API <kedro:data/data_catalog_yaml_examples>`:

    .. code-block:: yaml

        yelp_reviews:
          type: huggingface.HFDataset
          dataset_name: yelp_review_full

    Example usage for the :doc:`Python API <kedro:data/advanced_data_catalog_usage>`:

    .. code-block:: pycon

        >>> from kedro_datasets.huggingface import HFDataset
        >>> dataset = HFDataset(dataset_name="yelp_review_full")
        >>> yelp_review_full = dataset.load()
        >>> assert "train" in yelp_review_full
        >>> assert "test" in yelp_review_full
        >>> assert len(yelp_review_full["train"]) == 650000
    """

    def __init__(self, dataset_name: str):
        """Create a dataset bound to a Hugging Face dataset name.

        Args:
            dataset_name: Name passed verbatim to ``datasets.load_dataset``.
        """
        self.dataset_name = dataset_name

    def _load(self):
        """Return whatever ``datasets.load_dataset`` yields for ``dataset_name``."""
        return load_dataset(self.dataset_name)

    def _save(self, data: Any = None) -> None:
        """Saving is not supported.

        The ``data`` argument is accepted (and ignored) so that a
        ``save(data)`` call reaches the explicit ``NotImplementedError``
        below instead of failing on a signature mismatch.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("Not yet implemented")

    def _describe(self) -> dict[str, Any]:
        """Describe the dataset using the first Hub search hit for its name.

        Returns:
            A dict with the dataset name plus the tags and author of the
            first search result, or ``None`` for both when the search
            returns nothing.
        """
        api = HfApi()
        # Take only the first hit lazily: indexing ``list(...)[0]`` walks the
        # whole result set and raises ``IndexError`` on an empty search.
        dataset_info = next(iter(api.list_datasets(search=self.dataset_name)), None)
        if dataset_info is None:
            tags, author = None, None
        else:
            tags, author = dataset_info.tags, dataset_info.author
        return {
            "dataset_name": self.dataset_name,
            "dataset_tags": tags,
            "dataset_author": author,
        }

    @staticmethod
    def list_datasets():
        """Return the result of ``HfApi().list_datasets()`` as a list."""
        api = HfApi()
        return list(api.list_datasets())
71 changes: 71 additions & 0 deletions
71
kedro-datasets/kedro_datasets/huggingface/transformer_pipeline_dataset.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
from __future__ import annotations | ||
|
||
import typing as t | ||
from warnings import warn | ||
|
||
from kedro.io import AbstractDataset | ||
from transformers import Pipeline, pipeline | ||
|
||
|
||
class HFTransformerPipelineDataset(AbstractDataset):
    """``HFTransformerPipelineDataset`` loads pretrained Hugging Face transformers
    using the `transformers <https://pypi.org/project/transformers>`_ library.

    Example usage for the :doc:`YAML API <kedro:data/data_catalog_yaml_examples>`:

    .. code-block:: yaml

        summarizer_model:
          type: huggingface.HFTransformerPipelineDataset
          task: summarization

        fill_mask_model:
          type: huggingface.HFTransformerPipelineDataset
          task: fill-mask
          model_name: Twitter/twhin-bert-base

    Example usage for the :doc:`Python API <kedro:data/advanced_data_catalog_usage>`:

    .. code-block:: pycon

        >>> from kedro_datasets.huggingface import HFTransformerPipelineDataset
        >>> dataset = HFTransformerPipelineDataset(task="text-classification", model_name="papluca/xlm-roberta-base-language-detection")
        >>> detector = dataset.load()
        >>> assert detector("Ceci n'est pas une pipe")[0]["label"] == "fr"
    """

    def __init__(
        self,
        task: str | None = None,
        model_name: str | None = None,
        pipeline_kwargs: dict[str, t.Any] | None = None,
    ):
        """Configure the pipeline to build on load.

        Args:
            task: Task name passed as the first argument to
                ``transformers.pipeline``. A falsy value is stored as ``None``.
            model_name: Passed to ``transformers.pipeline`` as ``model``.
            pipeline_kwargs: Extra keyword arguments for
                ``transformers.pipeline``. The keys ``task`` and ``model``
                are reserved; they are dropped with a ``UserWarning``.

        Raises:
            ValueError: If both ``task`` and ``model_name`` are ``None``.
        """
        if task is None and model_name is None:
            raise ValueError("At least 'task' or 'model_name' are needed")
        self._task = task if task else None
        self._model_name = model_name
        # Work on a private copy so that stripping reserved keys below never
        # mutates the dictionary owned by the caller.
        self._pipeline_kwargs = dict(pipeline_kwargs) if pipeline_kwargs else {}

        if "task" in self._pipeline_kwargs or "model" in self._pipeline_kwargs:
            warn(
                "Specifying 'task' or 'model' in 'pipeline_kwargs' is not allowed",
                UserWarning,
                stacklevel=2,  # attribute the warning to the caller
            )
            self._pipeline_kwargs.pop("task", None)
            self._pipeline_kwargs.pop("model", None)

    def _load(self) -> Pipeline:
        """Build the pipeline from the stored task, model name and kwargs."""
        return pipeline(self._task, model=self._model_name, **self._pipeline_kwargs)

    def _save(self, pipeline: Pipeline) -> None:
        """Saving is not supported.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("Not yet implemented")

    def _describe(self) -> dict[str, t.Any]:
        """Return the stored task, model name and pipeline kwargs."""
        return {
            "task": self._task,
            "model_name": self._model_name,
            "pipeline_kwargs": self._pipeline_kwargs,
        }
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
""" | ||
This file contains the fixtures that are reusable by any tests within | ||
this directory. You don't need to import the fixtures as pytest will | ||
discover them automatically. More info here: | ||
https://docs.pytest.org/en/latest/fixture.html | ||
""" |
33 changes: 33 additions & 0 deletions
33
kedro-datasets/tests/huggingface/test_hugging_face_dataset.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
import pytest | ||
from huggingface_hub import HfApi | ||
|
||
from kedro_datasets.huggingface import HFDataset | ||
|
||
|
||
@pytest.fixture
def dataset_name():
    """Hugging Face dataset name shared by the tests in this module."""
    name = "yelp_review_full"
    return name
|
||
|
||
class TestHFDataset:
    """Unit tests for ``HFDataset`` with the Hugging Face calls patched out."""

    def test_simple_dataset_load(self, dataset_name, mocker):
        """``load`` forwards the dataset name to ``load_dataset`` and returns its result."""
        patch_target = "kedro_datasets.huggingface.hugging_face_dataset.load_dataset"
        load_dataset_mock = mocker.patch(patch_target)

        loaded = HFDataset(dataset_name=dataset_name).load()

        load_dataset_mock.assert_called_once_with(dataset_name)
        assert loaded is load_dataset_mock.return_value

    def test_list_datasets(self, mocker):
        """``list_datasets`` returns what ``HfApi.list_datasets`` reports."""
        fake_listing = ["dataset_1", "dataset_2"]
        mocker.patch.object(HfApi, "list_datasets", return_value=fake_listing)

        assert HFDataset.list_datasets() == fake_listing
65 changes: 65 additions & 0 deletions
65
kedro-datasets/tests/huggingface/test_transformer_pipeline_dataset.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
import pytest | ||
|
||
from kedro_datasets.huggingface import HFTransformerPipelineDataset | ||
|
||
|
||
@pytest.fixture
def task():
    """Pipeline task name used by the tests in this module."""
    task_name = "fill-mask"
    return task_name
|
||
|
||
@pytest.fixture
def model_name():
    """Model identifier used by the tests in this module."""
    identifier = "Twitter/twhin-bert-base"
    return identifier
|
||
|
||
class TestHFTransformerPipelineDataset:
    """Unit tests for ``HFTransformerPipelineDataset`` with ``pipeline`` patched out."""

    def test_simple_dataset_load(self, task, model_name, mocker):
        """``load`` calls ``pipeline`` with the task and model and returns its result."""
        mocked_pipeline = mocker.patch(
            "kedro_datasets.huggingface.transformer_pipeline_dataset.pipeline"
        )

        dataset = HFTransformerPipelineDataset(
            task=task,
            model_name=model_name,
        )
        model = dataset.load()

        mocked_pipeline.assert_called_once_with(task, model=model_name)
        assert model is mocked_pipeline.return_value

    def test_dataset_pipeline_kwargs_load(self, task, model_name, mocker):
        """Extra ``pipeline_kwargs`` are forwarded to ``pipeline`` on load."""
        mocked_pipeline = mocker.patch(
            "kedro_datasets.huggingface.transformer_pipeline_dataset.pipeline"
        )

        pipeline_kwargs = {"foo": True}
        dataset = HFTransformerPipelineDataset(
            task=task,
            model_name=model_name,
            pipeline_kwargs=pipeline_kwargs,
        )
        model = dataset.load()

        mocked_pipeline.assert_called_once_with(
            task, model=model_name, **pipeline_kwargs
        )
        assert model is mocked_pipeline.return_value

    def test_dataset_redundant_pipeline_kwargs(self, task, model_name, mocker):
        """Reserved keys trigger a warning and never reach ``pipeline``."""
        mocked_pipeline = mocker.patch(
            "kedro_datasets.huggingface.transformer_pipeline_dataset.pipeline"
        )

        pipeline_kwargs = {"task": "redundant", "model": "redundant", "foo": True}
        with pytest.warns(
            UserWarning,
            match="Specifying 'task' or 'model' in 'pipeline_kwargs' is not allowed",
        ):
            dataset = HFTransformerPipelineDataset(
                task=task,
                model_name=model_name,
                pipeline_kwargs=pipeline_kwargs,
            )

        # The warning alone is not enough: check the reserved keys were
        # dropped while the unrelated kwarg is still forwarded.
        dataset.load()
        mocked_pipeline.assert_called_once_with(task, model=model_name, foo=True)

    def test_dataset_incomplete_init(self):
        """Omitting both ``task`` and ``model_name`` is rejected."""
        with pytest.raises(
            ValueError, match="At least 'task' or 'model_name' are needed"
        ):
            HFTransformerPipelineDataset()