Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(datasets): Refactor TensorFlowModelDataset to DataSet #186

Merged
7 changes: 7 additions & 0 deletions kedro-datasets/RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@
## Bug fixes and other changes
* Relaxed `delta-spark` upper bound to allow compatibility with Spark 3.1.x and 3.2.x.

# Release 1.2.1:

## Major features and improvements:

## Bug fixes and other changes
* Renamed `TensorFlowModelDataset` to `TensorFlowModelDataSet` to be consistent with all other plugins in kedro-datasets.

# Release 1.2.0:

## Major features and improvements:
Expand Down
8 changes: 4 additions & 4 deletions kedro-datasets/kedro_datasets/tensorflow/README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# TensorFlowModelDataset
# TensorFlowModelDataSet

``TensorflowModelDataset`` loads and saves TensorFlow models.
The underlying functionality is supported by, and passes input arguments to TensorFlow 2.X load_model and save_model methods. Only TF2 is currently supported for saving and loading, V1 requires HDF5 and serialises differently.
Expand All @@ -8,9 +8,9 @@ The underlying functionality is supported by, and passes input arguments to Tens
import numpy as np
import tensorflow as tf

from kedro_datasets.tensorflow import TensorFlowModelDataset
from kedro_datasets.tensorflow import TensorFlowModelDataSet

data_set = TensorFlowModelDataset("tf_model_dirname")
data_set = TensorFlowModelDataSet("tf_model_dirname")

model = tf.keras.Model()
predictions = model.predict([...])
Expand All @@ -25,7 +25,7 @@ np.testing.assert_allclose(predictions, new_predictions, rtol=1e-6, atol=1e-6)
#### Example catalog.yml:
```yaml
example_tensorflow_data:
type: tensorflow.TensorFlowModelDataset
type: tensorflow.TensorFlowModelDataSet
filepath: data/08_reporting/tf_model_dirname
load_args:
tf_device: "/CPU:0" # optional
Expand Down
4 changes: 2 additions & 2 deletions kedro-datasets/kedro_datasets/tensorflow/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""Provides I/O for TensorFlow Models."""

__all__ = ["TensorFlowModelDataset"]
__all__ = ["TensorFlowModelDataSet"]

from contextlib import suppress

with suppress(ImportError):
from .tensorflow_model_dataset import TensorFlowModelDataset
from .tensorflow_model_dataset import TensorFlowModelDataSet
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""``TensorflowModelDataset`` is a data set implementation which can save and load
"""``TensorFlowModelDataSet`` is a data set implementation which can save and load
TensorFlow models.
"""
import copy
Expand All @@ -19,8 +19,8 @@
TEMPORARY_H5_FILE = "tmp_tensorflow_model.h5"


class TensorFlowModelDataset(AbstractVersionedDataSet[tf.keras.Model, tf.keras.Model]):
"""``TensorflowModelDataset`` loads and saves TensorFlow models.
class TensorFlowModelDataSet(AbstractVersionedDataSet[tf.keras.Model, tf.keras.Model]):
"""``TensorFlowModelDataSet`` loads and saves TensorFlow models.
The underlying functionality is supported by, and passes input arguments through to,
TensorFlow 2.X load_model and save_model methods.

Expand All @@ -31,7 +31,7 @@ class TensorFlowModelDataset(AbstractVersionedDataSet[tf.keras.Model, tf.keras.M
.. code-block:: yaml

tensorflow_model:
type: tensorflow.TensorFlowModelDataset
type: tensorflow.TensorFlowModelDataSet
filepath: data/06_models/tensorflow_model.h5
load_args:
compile: False
Expand All @@ -45,11 +45,11 @@ class TensorFlowModelDataset(AbstractVersionedDataSet[tf.keras.Model, tf.keras.M
data_catalog.html#use-the-data-catalog-with-the-code-api>`_:
::

>>> from kedro_datasets.tensorflow import TensorFlowModelDataset
>>> from kedro_datasets.tensorflow import TensorFlowModelDataSet
>>> import tensorflow as tf
>>> import numpy as np
>>>
>>> data_set = TensorFlowModelDataset("data/06_models/tensorflow_model.h5")
>>> data_set = TensorFlowModelDataSet("data/06_models/tensorflow_model.h5")
>>> model = tf.keras.Model()
>>> predictions = model.predict([...])
>>>
Expand All @@ -73,7 +73,7 @@ def __init__(
credentials: Dict[str, Any] = None,
fs_args: Dict[str, Any] = None,
) -> None:
"""Creates a new instance of ``TensorFlowModelDataset``.
"""Creates a new instance of ``TensorFlowModelDataSet``.

Args:
filepath: Filepath in POSIX format to a TensorFlow model directory prefixed with a
Expand Down
2 changes: 1 addition & 1 deletion kedro-datasets/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def _collect_requirements(requires):
}
svmlight_require = {"svmlight.SVMLightDataSet": ["scikit-learn~=1.0.2", "scipy~=1.7.3"]}
tensorflow_require = {
"tensorflow.TensorflowModelDataset": [
"tensorflow.TensorFlowModelDataSet": [
# currently only TensorFlow V2 supported for saving and loading.
# V1 requires HDF5 and serialises differently
"tensorflow~=2.0; platform_system != 'Darwin' or platform_machine != 'arm64'",
Expand Down
30 changes: 15 additions & 15 deletions kedro-datasets/tests/tensorflow/test_tensorflow_model_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from s3fs import S3FileSystem


# In this test module, we wrap tensorflow and TensorFlowModelDataset imports into a module-scoped
# In this test module, we wrap tensorflow and TensorFlowModelDataSet imports into a module-scoped
# fixtures to avoid them being evaluated immediately when a new test process is spawned.
# Specifically:
# - ParallelRunner spawns a new subprocess.
Expand All @@ -34,9 +34,9 @@ def tf():

@pytest.fixture(scope="module")
def tensorflow_model_dataset():
from kedro_datasets.tensorflow import TensorFlowModelDataset
from kedro_datasets.tensorflow import TensorFlowModelDataSet

return TensorFlowModelDataset
return TensorFlowModelDataSet


@pytest.fixture
Expand Down Expand Up @@ -134,7 +134,7 @@ def call(self, inputs, training=None, mask=None): # pragma: no cover
return model


class TestTensorFlowModelDataset:
class TestTensorFlowModelDataSet:
"""No versioning passed to creator"""

def test_save_and_load(self, tf_model_dataset, dummy_tf_base_model, dummy_x_test):
Expand All @@ -152,7 +152,7 @@ def test_save_and_load(self, tf_model_dataset, dummy_tf_base_model, dummy_x_test
def test_load_missing_model(self, tf_model_dataset):
"""Test error message when trying to load missing model."""
pattern = (
r"Failed while loading data from data set TensorFlowModelDataset\(.*\)"
r"Failed while loading data from data set TensorFlowModelDataSet\(.*\)"
)
with pytest.raises(DataSetError, match=pattern):
tf_model_dataset.load()
Expand All @@ -166,7 +166,7 @@ def test_exists(self, tf_model_dataset, dummy_tf_base_model):
def test_hdf5_save_format(
self, dummy_tf_base_model, dummy_x_test, filepath, tensorflow_model_dataset
):
"""Test TensorflowModelDataset can save TF graph models in HDF5 format"""
"""Test TensorFlowModelDataSet can save TF graph models in HDF5 format"""
hdf5_dataset = tensorflow_model_dataset(
filepath=filepath, save_args={"save_format": "h5"}
)
Expand All @@ -187,7 +187,7 @@ def test_unused_subclass_model_hdf5_save_format(
filepath,
tensorflow_model_dataset,
):
"""Test TensorflowModelDataset cannot save subclassed user models in HDF5 format
"""Test TensorFlowModelDataSet cannot save subclassed user models in HDF5 format

Subclassed model

Expand Down Expand Up @@ -277,8 +277,8 @@ def test_save_and_overwrite_existing_model(
assert len(dummy_tf_base_model_new.layers) == len(reloaded.layers)


class TestTensorFlowModelDatasetVersioned:
"""Test suite with versioning argument passed into TensorFlowModelDataset creator"""
class TestTensorFlowModelDataSetVersioned:
"""Test suite with versioning argument passed into TensorFlowModelDataSet creator"""

@pytest.mark.parametrize(
"load_version,save_version",
Expand Down Expand Up @@ -320,7 +320,7 @@ def test_hdf5_save_format(
load_version,
save_version,
):
"""Test versioned TensorflowModelDataset can save TF graph models in
"""Test versioned TensorFlowModelDataSet can save TF graph models in
HDF5 format"""
hdf5_dataset = tensorflow_model_dataset(
filepath=filepath,
Expand All @@ -340,7 +340,7 @@ def test_prevent_overwrite(self, dummy_tf_base_model, versioned_tf_model_dataset
corresponding file for a given save version already exists."""
versioned_tf_model_dataset.save(dummy_tf_base_model)
pattern = (
r"Save path \'.+\' for TensorFlowModelDataset\(.+\) must "
r"Save path \'.+\' for TensorFlowModelDataSet\(.+\) must "
r"not exist if versioning is enabled\."
)
with pytest.raises(DataSetError, match=pattern):
Expand All @@ -362,7 +362,7 @@ def test_save_version_warning(
the subsequent load path."""
pattern = (
rf"Save version '{save_version}' did not match load version '{load_version}' "
rf"for TensorFlowModelDataset\(.+\)"
rf"for TensorFlowModelDataSet\(.+\)"
)
with pytest.warns(UserWarning, match=pattern):
versioned_tf_model_dataset.save(dummy_tf_base_model)
Expand All @@ -383,7 +383,7 @@ def test_exists(self, versioned_tf_model_dataset, dummy_tf_base_model):

def test_no_versions(self, versioned_tf_model_dataset):
"""Check the error if no versions are available for load."""
pattern = r"Did not find any versions for TensorFlowModelDataset\(.+\)"
pattern = r"Did not find any versions for TensorFlowModelDataSet\(.+\)"
with pytest.raises(DataSetError, match=pattern):
versioned_tf_model_dataset.load()

Expand All @@ -408,7 +408,7 @@ def test_versioning_existing_dataset(
self, tf_model_dataset, versioned_tf_model_dataset, dummy_tf_base_model
):
"""Check behavior when attempting to save a versioned dataset on top of an
already existing (non-versioned) dataset. Note: because TensorFlowModelDataset
already existing (non-versioned) dataset. Note: because TensorFlowModelDataSet
saves to a directory even if non-versioned, an error is not expected."""
tf_model_dataset.save(dummy_tf_base_model)
assert tf_model_dataset.exists()
Expand All @@ -425,7 +425,7 @@ def test_save_and_load_with_device(
load_version,
save_version,
):
"""Test versioned TensorflowModelDataset can load models using an explicit tf_device"""
"""Test versioned TensorFlowModelDataSet can load models using an explicit tf_device"""
hdf5_dataset = tensorflow_model_dataset(
filepath=filepath,
load_args={"tf_device": "/CPU:0"},
Expand Down