Skip to content

Commit

Permalink
feat(datasets): Add Hugging Face datasets (#344)
Browse files Browse the repository at this point in the history
* Add HuggingFace datasets

Co-authored-by: Danny Farah <[email protected]>
Co-authored-by: Kevin Koga <[email protected]>
Co-authored-by: Mate Scharnitzky <[email protected]>
Co-authored-by: Tomer Shor <[email protected]>
Co-authored-by: Pierre-Yves Mousset <[email protected]>
Co-authored-by: Bela Chupal <[email protected]>
Co-authored-by: Khangjrakpam Arjun <[email protected]>
Co-authored-by: Juan Luis Cano Rodríguez <[email protected]>
Signed-off-by: Juan Luis Cano Rodríguez <[email protected]>

* Apply suggestions from code review

Signed-off-by: Juan Luis Cano Rodríguez <[email protected]>

Co-authored-by: Joel <[email protected]>
Co-authored-by: Nok Lam Chan <[email protected]>

* Typo

Signed-off-by: Juan Luis Cano Rodríguez <[email protected]>

* Fix docstring

Signed-off-by: Juan Luis Cano Rodríguez <[email protected]>

* Add docstring for HFTransformerPipelineDataset

Signed-off-by: Juan Luis Cano Rodríguez <[email protected]>

* Use intersphinx for cross references in Hugging Face docstrings

Signed-off-by: Juan Luis Cano Rodríguez <[email protected]>

* Add docstring for HFDataset

Signed-off-by: Juan Luis Cano Rodríguez <[email protected]>

* Add missing test dependencies

Signed-off-by: Juan Luis Cano Rodríguez <[email protected]>

* Add tests for huggingface datasets

Signed-off-by: Juan Luis Cano Rodríguez <[email protected]>

* Fix HFDataset.save

Signed-off-by: Juan Luis Cano Rodríguez <[email protected]>

* Add test for HFDataset.list_datasets

Signed-off-by: Juan Luis Cano Rodríguez <[email protected]>

* Use new name

Signed-off-by: Juan Luis Cano Rodríguez <[email protected]>

* Consolidate imports

Signed-off-by: Juan Luis Cano Rodríguez <[email protected]>

---------

Signed-off-by: Juan Luis Cano Rodríguez <[email protected]>
Co-authored-by: Danny Farah <[email protected]>
Co-authored-by: Kevin Koga <[email protected]>
Co-authored-by: Mate Scharnitzky <[email protected]>
Co-authored-by: Tomer Shor <[email protected]>
Co-authored-by: Pierre-Yves Mousset <[email protected]>
Co-authored-by: Bela Chupal <[email protected]>
Co-authored-by: Khangjrakpam Arjun <[email protected]>
Co-authored-by: Joel <[email protected]>
Co-authored-by: Nok Lam Chan <[email protected]>
  • Loading branch information
10 people authored Nov 13, 2023
1 parent d81a4d0 commit f59e930
Show file tree
Hide file tree
Showing 10 changed files with 262 additions and 0 deletions.
4 changes: 4 additions & 0 deletions kedro-datasets/docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
extensions = [
"sphinx.ext.autodoc",
"sphinx.ext.autosummary",
"sphinx.ext.intersphinx",
"sphinx.ext.napoleon",
"sphinx_autodoc_typehints",
"sphinx.ext.doctest",
Expand Down Expand Up @@ -90,6 +91,9 @@
"kedro_docs_style_guide.md",
]

# Resolve ``kedro:``-prefixed cross-references used in docstrings against the
# published Kedro documentation.
intersphinx_mapping = {
    "kedro": ("https://docs.kedro.org/en/stable/", None),
}

type_targets = {
"py:class": (
Expand Down
2 changes: 2 additions & 0 deletions kedro-datasets/docs/source/kedro_datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ kedro_datasets
kedro_datasets.geopandas.GeoJSONDataSet
kedro_datasets.geopandas.GeoJSONDataset
kedro_datasets.holoviews.HoloviewsWriter
kedro_datasets.huggingface.HFDataset
kedro_datasets.huggingface.HFTransformerPipelineDataset
kedro_datasets.json.JSONDataSet
kedro_datasets.json.JSONDataset
kedro_datasets.matplotlib.MatplotlibWriter
Expand Down
16 changes: 16 additions & 0 deletions kedro-datasets/kedro_datasets/huggingface/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
"""Provides interface to Hugging Face transformers and datasets."""
from typing import Any

import lazy_loader as lazy

# https://github.com/pylint-dev/pylint/issues/4300#issuecomment-1043601901
# Bare annotations declare the public names up front so static analysis can
# see them even though they are only provided through ``lazy.attach`` below.
HFDataset: Any
HFTransformerPipelineDataset: Any

# ``lazy.attach`` supplies this module's ``__getattr__``, ``__dir__`` and
# ``__all__``. Each key is a submodule of this package and each value lists the
# names exposed from it; presumably the submodule import is deferred until the
# name is first accessed (see the lazy_loader documentation to confirm).
__getattr__, __dir__, __all__ = lazy.attach(
    __name__,
    submod_attrs={
        "hugging_face_dataset": ["HFDataset"],
        "transformer_pipeline_dataset": ["HFTransformerPipelineDataset"],
    },
)
56 changes: 56 additions & 0 deletions kedro-datasets/kedro_datasets/huggingface/hugging_face_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from __future__ import annotations

from typing import Any

from datasets import load_dataset
from huggingface_hub import HfApi
from kedro.io import AbstractVersionedDataset


class HFDataset(AbstractVersionedDataset):
    """``HFDataset`` loads Hugging Face datasets
    using the `datasets <https://pypi.org/project/datasets>`_ library.

    Example usage for the :doc:`YAML API <kedro:data/data_catalog_yaml_examples>`:

    .. code-block:: yaml

        yelp_reviews:
          type: huggingface.HFDataset
          dataset_name: yelp_review_full

    Example usage for the :doc:`Python API <kedro:data/advanced_data_catalog_usage>`:

    .. code-block:: pycon

        >>> from kedro_datasets.huggingface import HFDataset
        >>> dataset = HFDataset(dataset_name="yelp_review_full")
        >>> yelp_review_full = dataset.load()
        >>> assert "train" in yelp_review_full
        >>> assert "test" in yelp_review_full
        >>> assert len(yelp_review_full["train"]) == 650000

    """

    def __init__(self, dataset_name: str):
        """Create a dataset bound to a Hugging Face dataset name.

        Args:
            dataset_name: Name passed verbatim to ``datasets.load_dataset``.
        """
        self.dataset_name = dataset_name

    def _load(self):
        """Return the result of ``datasets.load_dataset`` for ``dataset_name``."""
        return load_dataset(self.dataset_name)

    def _save(self, data: Any = None) -> None:
        """Saving is not supported.

        ``data`` is accepted (and ignored) so the signature matches the
        ``_save(data)`` hook of the Kedro dataset base class; the default keeps
        argument-less calls working.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("Not yet implemented")

    def _describe(self) -> dict[str, Any]:
        """Describe the dataset using metadata looked up through ``HfApi``.

        Note that this queries ``HfApi.list_datasets`` with ``dataset_name`` as
        the search term and reports the *first* hit, which is not guaranteed to
        be an exact name match.

        Returns:
            A dict with ``dataset_name``, ``dataset_tags`` and
            ``dataset_author``. The last two are ``None`` when the search
            yields no result.
        """
        api = HfApi()
        # Only the first hit is needed: pull one item instead of materialising
        # every match, and tolerate an empty result instead of raising IndexError.
        dataset_info = next(iter(api.list_datasets(search=self.dataset_name)), None)
        found = dataset_info is not None
        return {
            "dataset_name": self.dataset_name,
            "dataset_tags": dataset_info.tags if found else None,
            "dataset_author": dataset_info.author if found else None,
        }

    @staticmethod
    def list_datasets():
        """Return everything yielded by ``HfApi().list_datasets()`` as a list."""
        api = HfApi()
        return list(api.list_datasets())
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from __future__ import annotations

import typing as t
from warnings import warn

from kedro.io import AbstractDataset
from transformers import Pipeline, pipeline


class HFTransformerPipelineDataset(AbstractDataset):
    """``HFTransformerPipelineDataset`` loads pretrained Hugging Face transformers
    using the `transformers <https://pypi.org/project/transformers>`_ library.

    Example usage for the :doc:`YAML API <kedro:data/data_catalog_yaml_examples>`:

    .. code-block:: yaml

        summarizer_model:
          type: huggingface.HFTransformerPipelineDataset
          task: summarization

        fill_mask_model:
          type: huggingface.HFTransformerPipelineDataset
          task: fill-mask
          model_name: Twitter/twhin-bert-base

    Example usage for the :doc:`Python API <kedro:data/advanced_data_catalog_usage>`:

    .. code-block:: pycon

        >>> from kedro_datasets.huggingface import HFTransformerPipelineDataset
        >>> dataset = HFTransformerPipelineDataset(task="text-classification", model_name="papluca/xlm-roberta-base-language-detection")
        >>> detector = dataset.load()
        >>> assert detector("Ceci n'est pas une pipe")[0]["label"] == "fr"

    """

    def __init__(
        self,
        task: str | None = None,
        model_name: str | None = None,
        pipeline_kwargs: dict[str, t.Any] | None = None,
    ):
        """Configure the pipeline to be built on ``load()``.

        Args:
            task: Task passed as the first argument to ``transformers.pipeline``.
            model_name: Value passed as ``model=`` to ``transformers.pipeline``.
            pipeline_kwargs: Extra keyword arguments forwarded to
                ``transformers.pipeline``. ``task`` and ``model`` keys are
                dropped with a ``UserWarning`` because they are set through the
                dedicated arguments.

        Raises:
            ValueError: If both ``task`` and ``model_name`` are ``None``.
        """
        if task is None and model_name is None:
            raise ValueError("At least 'task' or 'model_name' are needed")
        self._task = task if task else None
        self._model_name = model_name
        # Copy so that dropping the reserved keys below never mutates the
        # dict object owned by the caller.
        self._pipeline_kwargs = dict(pipeline_kwargs) if pipeline_kwargs else {}

        if "task" in self._pipeline_kwargs or "model" in self._pipeline_kwargs:
            warn(
                "Specifying 'task' or 'model' in 'pipeline_kwargs' is not allowed",
                UserWarning,
                stacklevel=2,  # attribute the warning to the code constructing the dataset
            )
            self._pipeline_kwargs.pop("task", None)
            self._pipeline_kwargs.pop("model", None)

    def _load(self) -> Pipeline:
        """Build and return the pipeline via ``transformers.pipeline``."""
        return pipeline(self._task, model=self._model_name, **self._pipeline_kwargs)

    def _save(self, pipeline: Pipeline) -> None:
        """Saving is not supported.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("Not yet implemented")

    def _describe(self) -> dict[str, t.Any]:
        """Return the task, model name and pipeline kwargs this dataset holds."""
        return {
            "task": self._task,
            "model_name": self._model_name,
            "pipeline_kwargs": self._pipeline_kwargs,
        }
9 changes: 9 additions & 0 deletions kedro-datasets/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ def _collect_requirements(requires):
"geopandas.GeoJSONDataSet": ["geopandas>=0.6.0, <1.0", "pyproj~=3.0"]
}
holoviews_require = {"holoviews.HoloviewsWriter": ["holoviews~=1.13.0"]}
# Requirements of the Hugging Face datasets, keyed by ``<module>.<ClassName>``.
huggingface_require = {
    "huggingface.HFDataset": ["datasets", "huggingface_hub"],
    "huggingface.HFTransformerPipelineDataset": ["transformers"],
}
matplotlib_require = {"matplotlib.MatplotlibWriter": ["matplotlib>=3.0.3, <4.0"]}
networkx_require = {"networkx.NetworkXDataSet": ["networkx~=2.4"]}
pandas_require = {
Expand Down Expand Up @@ -102,6 +106,7 @@ def _collect_requirements(requires):
"databricks": _collect_requirements(databricks_require),
"geopandas": _collect_requirements(geopandas_require),
"holoviews": _collect_requirements(holoviews_require),
"huggingface": _collect_requirements(huggingface_require),
"matplotlib": _collect_requirements(matplotlib_require),
"networkx": _collect_requirements(networkx_require),
"pandas": _collect_requirements(pandas_require),
Expand Down Expand Up @@ -224,6 +229,10 @@ def _collect_requirements(requires):
"triad>=0.6.7, <1.0",
"trufflehog~=2.1",
"xlsxwriter~=1.0",
# huggingface
"datasets",
"huggingface_hub",
"transformers",
]

setup(
Expand Down
Empty file.
6 changes: 6 additions & 0 deletions kedro-datasets/tests/huggingface/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""
This file contains the fixtures that are reusable by any tests within
this directory. You don't need to import the fixtures as pytest will
discover them automatically. More info here:
https://docs.pytest.org/en/latest/fixture.html
"""
33 changes: 33 additions & 0 deletions kedro-datasets/tests/huggingface/test_hugging_face_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import pytest
from huggingface_hub import HfApi

from kedro_datasets.huggingface import HFDataset


@pytest.fixture
def dataset_name():
    """Hugging Face dataset name shared by the tests in this module."""
    return "yelp_review_full"


class TestHFDataset:
    """Unit tests for ``HFDataset``."""

    def test_simple_dataset_load(self, dataset_name, mocker):
        """``load()`` forwards the name to ``load_dataset`` and returns its result."""
        # Patch the name where the dataset module looks it up.
        load_dataset_mock = mocker.patch(
            "kedro_datasets.huggingface.hugging_face_dataset.load_dataset"
        )

        loaded = HFDataset(dataset_name=dataset_name).load()

        load_dataset_mock.assert_called_once_with(dataset_name)
        assert loaded is load_dataset_mock.return_value

    def test_list_datasets(self, mocker):
        """``list_datasets()`` returns what ``HfApi.list_datasets`` provides."""
        available = ["dataset_1", "dataset_2"]
        mocker.patch.object(HfApi, "list_datasets", return_value=available)

        assert HFDataset.list_datasets() == available
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import pytest

from kedro_datasets.huggingface import HFTransformerPipelineDataset


@pytest.fixture
def task():
    """Pipeline task name shared by the tests in this module."""
    return "fill-mask"


@pytest.fixture
def model_name():
    """Model identifier shared by the tests in this module."""
    return "Twitter/twhin-bert-base"


class TestHFTransformerPipelineDataset:
    """Unit tests for ``HFTransformerPipelineDataset``.

    ``pipeline`` is patched where the dataset module looks it up, so the load
    tests only check how it is called.
    """

    def test_simple_dataset_load(self, task, model_name, mocker):
        """``load()`` calls ``pipeline`` with the task and model and returns its result."""
        mocked_pipeline = mocker.patch(
            "kedro_datasets.huggingface.transformer_pipeline_dataset.pipeline"
        )

        dataset = HFTransformerPipelineDataset(
            task=task,
            model_name=model_name,
        )
        model = dataset.load()

        mocked_pipeline.assert_called_once_with(task, model=model_name)
        assert model is mocked_pipeline.return_value

    def test_dataset_pipeline_kwargs_load(self, task, model_name, mocker):
        """Extra ``pipeline_kwargs`` are forwarded to ``pipeline`` as keyword arguments."""
        mocked_pipeline = mocker.patch(
            "kedro_datasets.huggingface.transformer_pipeline_dataset.pipeline"
        )

        pipeline_kwargs = {"foo": True}
        dataset = HFTransformerPipelineDataset(
            task=task,
            model_name=model_name,
            pipeline_kwargs=pipeline_kwargs,
        )
        model = dataset.load()

        mocked_pipeline.assert_called_once_with(
            task, model=model_name, **pipeline_kwargs
        )
        assert model is mocked_pipeline.return_value

    def test_dataset_redundant_pipeline_kwargs(self, task, model_name):
        """Reserved keys in ``pipeline_kwargs`` trigger a warning and are dropped."""
        pipeline_kwargs = {"task": "redundant"}
        with pytest.warns(
            UserWarning,
            match="Specifying 'task' or 'model' in 'pipeline_kwargs' is not allowed",
        ):
            dataset = HFTransformerPipelineDataset(
                task=task,
                model_name=model_name,
                pipeline_kwargs=pipeline_kwargs,
            )

        # The warning alone is not enough: the reserved key must not survive,
        # otherwise ``pipeline`` would receive ``task`` twice on load.
        assert dataset._describe()["pipeline_kwargs"] == {}

    def test_dataset_incomplete_init(self):
        """Constructing without ``task`` and ``model_name`` raises ``ValueError``."""
        with pytest.raises(
            ValueError, match="At least 'task' or 'model_name' are needed"
        ):
            HFTransformerPipelineDataset()

0 comments on commit f59e930

Please sign in to comment.