diff --git a/docs/source/data/how_to_create_a_custom_dataset.md b/docs/source/data/how_to_create_a_custom_dataset.md index 01ad199f55..7f39987dd7 100644 --- a/docs/source/data/how_to_create_a_custom_dataset.md +++ b/docs/source/data/how_to_create_a_custom_dataset.md @@ -4,7 +4,7 @@ ## AbstractDataset -If you are a contributor and would like to submit a new dataset, you must extend the {py:class}`~kedro.io.AbstractDataset` interface or {py:class}`~kedro.io.AbstractVersionedDataset` interface if you plan to support versioning. It requires subclasses to override the `_load` and `_save` and provides `load` and `save` methods that enrich the corresponding private methods with uniform error handling. It also requires subclasses to override `_describe`, which is used in logging the internal information about the instances of your custom `AbstractDataset` implementation. +If you are a contributor and would like to submit a new dataset, you must extend the {py:class}`~kedro.io.AbstractDataset` interface or {py:class}`~kedro.io.AbstractVersionedDataset` interface if you plan to support versioning. It requires subclasses to implement the `load` and `save` methods while providing wrappers that enrich the corresponding methods with uniform error handling. It also requires subclasses to override `_describe`, which is used in logging the internal information about the instances of your custom `AbstractDataset` implementation. ## Scenario @@ -31,8 +31,8 @@ Consult the [Pillow documentation](https://pillow.readthedocs.io/en/stable/insta At the minimum, a valid Kedro dataset needs to subclass the base {py:class}`~kedro.io.AbstractDataset` and provide an implementation for the following abstract methods: -* `_load` -* `_save` +* `load` +* `save` * `_describe` `AbstractDataset` is generically typed with an input data type for saving data, and an output data type for loading data. 
@@ -70,7 +70,7 @@ class ImageDataset(AbstractDataset[np.ndarray, np.ndarray]): """ self._filepath = filepath - def _load(self) -> np.ndarray: + def load(self) -> np.ndarray: """Loads data from the image file. Returns: @@ -78,7 +78,7 @@ class ImageDataset(AbstractDataset[np.ndarray, np.ndarray]): """ ... - def _save(self, data: np.ndarray) -> None: + def save(self, data: np.ndarray) -> None: """Saves image data to the specified filepath""" ... @@ -96,11 +96,11 @@ src/kedro_pokemon/datasets └── image_dataset.py ``` -## Implement the `_load` method with `fsspec` +## Implement the `load` method with `fsspec` Many of the built-in Kedro datasets rely on [fsspec](https://filesystem-spec.readthedocs.io/en/latest/) as a consistent interface to different data sources, as described earlier in the section about the [Data Catalog](../data/data_catalog.md#dataset-filepath). In this example, it's particularly convenient to use `fsspec` in conjunction with `Pillow` to read image data, since it allows the dataset to work flexibly with different image locations and formats. -Here is the implementation of the `_load` method using `fsspec` and `Pillow` to read the data of a single image into a `numpy` array: +Here is the implementation of the `load` method using `fsspec` and `Pillow` to read the data of a single image into a `numpy` array:
Click to expand @@ -130,7 +130,7 @@ class ImageDataset(AbstractDataset[np.ndarray, np.ndarray]): self._filepath = PurePosixPath(path) self._fs = fsspec.filesystem(self._protocol) - def _load(self) -> np.ndarray: + def load(self) -> np.ndarray: """Loads data from the image file. Returns: @@ -168,14 +168,14 @@ In [2]: from PIL import Image In [3]: Image.fromarray(image).show() ``` -## Implement the `_save` method with `fsspec` +## Implement the `save` method with `fsspec` -Similarly, we can implement the `_save` method as follows: +Similarly, we can implement the `save` method as follows: ```python class ImageDataset(AbstractDataset[np.ndarray, np.ndarray]): - def _save(self, data: np.ndarray) -> None: + def save(self, data: np.ndarray) -> None: """Saves image data to the specified filepath.""" # using get_filepath_str ensures that the protocol and path are appended correctly for different filesystems save_path = get_filepath_str(self._filepath, self._protocol) @@ -243,7 +243,7 @@ class ImageDataset(AbstractDataset[np.ndarray, np.ndarray]): self._filepath = PurePosixPath(path) self._fs = fsspec.filesystem(self._protocol) - def _load(self) -> np.ndarray: + def load(self) -> np.ndarray: """Loads data from the image file. 
Returns: @@ -254,7 +254,7 @@ class ImageDataset(AbstractDataset[np.ndarray, np.ndarray]): image = Image.open(f).convert("RGBA") return np.asarray(image) - def _save(self, data: np.ndarray) -> None: + def save(self, data: np.ndarray) -> None: """Saves image data to the specified filepath.""" save_path = get_filepath_str(self._filepath, self._protocol) with self._fs.open(save_path, mode="wb") as f: @@ -312,7 +312,7 @@ To add versioning support to the new dataset we need to extend the {py:class}`~kedro.io.AbstractVersionedDataset` to: * Accept a `version` keyword argument as part of the constructor -* Adapt the `_load` and `_save` method to use the versioned data path obtained from `_get_load_path` and `_get_save_path` respectively +* Adapt the `load` and `save` methods to use the versioned data path obtained from `_get_load_path` and `_get_save_path` respectively The following amends the full implementation of our basic `ImageDataset`. It now loads and saves data to and from a versioned subfolder (`data/01_raw/pokemon-images-and-types/images/images/pikachu.png/<version>/pikachu.png` with `version` being a datetime-formatted string `YYYY-MM-DDThh.mm.ss.sssZ` by default): @@ -359,7 +359,7 @@ class ImageDataset(AbstractVersionedDataset[np.ndarray, np.ndarray]): glob_function=self._fs.glob, ) - def _load(self) -> np.ndarray: + def load(self) -> np.ndarray: """Loads data from the image file. 
Returns: @@ -370,7 +370,7 @@ class ImageDataset(AbstractVersionedDataset[np.ndarray, np.ndarray]): image = Image.open(f).convert("RGBA") return np.asarray(image) - def _save(self, data: np.ndarray) -> None: + def save(self, data: np.ndarray) -> None: """Saves image data to the specified filepath.""" save_path = get_filepath_str(self._get_save_path(), self._protocol) with self._fs.open(save_path, mode="wb") as f: @@ -435,7 +435,7 @@ The difference between the original `ImageDataset` and the versioned `ImageDatas + glob_function=self._fs.glob, + ) + - def _load(self) -> np.ndarray: + def load(self) -> np.ndarray: """Loads data from the image file. Returns: @@ -447,7 +447,7 @@ The difference between the original `ImageDataset` and the versioned `ImageDatas image = Image.open(f).convert("RGBA") return np.asarray(image) - def _save(self, data: np.ndarray) -> None: + def save(self, data: np.ndarray) -> None: """Saves image data to the specified filepath.""" - save_path = get_filepath_str(self._filepath, self._protocol) + save_path = get_filepath_str(self._get_save_path(), self._protocol) diff --git a/kedro/io/cached_dataset.py b/kedro/io/cached_dataset.py index 5f8d96dc36..85d9341db5 100644 --- a/kedro/io/cached_dataset.py +++ b/kedro/io/cached_dataset.py @@ -103,7 +103,7 @@ def __repr__(self) -> str: } return self._pretty_repr(object_description) - def _load(self) -> Any: + def load(self) -> Any: data = self._cache.load() if self._cache.exists() else self._dataset.load() if not self._cache.exists(): @@ -111,7 +111,7 @@ def _load(self) -> Any: return data - def _save(self, data: Any) -> None: + def save(self, data: Any) -> None: self._dataset.save(data) self._cache.save(data) diff --git a/kedro/io/memory_dataset.py b/kedro/io/memory_dataset.py index 56ad92b7f2..1b4bb8a371 100644 --- a/kedro/io/memory_dataset.py +++ b/kedro/io/memory_dataset.py @@ -59,7 +59,7 @@ def __init__( if data is not _EMPTY: self.save.__wrapped__(self, data) # type: ignore[attr-defined] - def 
_load(self) -> Any: + def load(self) -> Any: if self._data is _EMPTY: raise DatasetError("Data for MemoryDataset has not been saved yet.") @@ -67,7 +67,7 @@ def _load(self) -> Any: data = _copy_with_mode(self._data, copy_mode=copy_mode) return data - def _save(self, data: Any) -> None: + def save(self, data: Any) -> None: copy_mode = self._copy_mode or _infer_copy_mode(data) self._data = _copy_with_mode(data, copy_mode=copy_mode) diff --git a/kedro/io/shared_memory_dataset.py b/kedro/io/shared_memory_dataset.py index e2bd63bf7e..a7e28d0256 100644 --- a/kedro/io/shared_memory_dataset.py +++ b/kedro/io/shared_memory_dataset.py @@ -36,10 +36,10 @@ def __getattr__(self, name: str) -> Any: raise AttributeError() return getattr(self.shared_memory_dataset, name) # pragma: no cover - def _load(self) -> Any: + def load(self) -> Any: return self.shared_memory_dataset.load() - def _save(self, data: Any) -> None: + def save(self, data: Any) -> None: """Calls save method of a shared MemoryDataset in SyncManager.""" try: self.shared_memory_dataset.save(data)