feat(datasets): Add option to async load and save in PartitionedDatasets #696

Open · wants to merge 11 commits into `main`
kedro-datasets/RELEASE.md: 7 additions & 0 deletions
@@ -2,12 +2,19 @@

## Major features and improvements

- Added async functionality for saving data in `PartitionedDataset` via the `use_async` argument.
- Added the following new core datasets:

| Type | Description | Location |
| ------------------- | ------------------------------------------------------------- | --------------------- |
| `ibis.TableDataset` | A dataset for loading and saving files using Ibis's backends. | `kedro_datasets.ibis` |

## Community contributions

Many thanks to the following Kedroids for contributing PRs to this release:

- [Puneet Saini](https://github.com/puneeter)

# Release 5.0.0

## Major features and improvements
kedro-datasets/kedro_datasets/partitions/partitioned_dataset.py: 41 additions & 0 deletions
@@ -4,6 +4,7 @@

from __future__ import annotations

import asyncio
import operator
from collections.abc import Callable
from copy import deepcopy
@@ -153,6 +154,7 @@ def __init__( # noqa: PLR0913
fs_args: dict[str, Any] | None = None,
overwrite: bool = False,
metadata: dict[str, Any] | None = None,
use_async: bool = False,
) -> None:
"""Creates a new instance of ``PartitionedDataset``.

@@ -193,6 +195,8 @@ def __init__( # noqa: PLR0913
overwrite: If True, any existing partitions will be removed.
metadata: Any arbitrary metadata.
This is ignored by Kedro, but may be consumed by users or external plugins.
use_async: If True, the dataset will be saved asynchronously.
Defaults to False.

Raises:
DatasetError: If versioning is enabled for the underlying dataset.
@@ -207,6 +211,7 @@ def __init__( # noqa: PLR0913
self._protocol = infer_storage_options(self._path)["protocol"]
self._partition_cache: Cache = Cache(maxsize=1)
self.metadata = metadata
self._use_async = use_async

dataset = dataset if isinstance(dataset, dict) else {"type": dataset}
self._dataset_type, self._dataset_config = parse_dataset_definition(dataset)
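
For context, a minimal usage sketch of the new flag. This is not part of the diff: the partition path, the `pandas.CSVDataset` type and the `filename_suffix` argument are illustrative choices.

```python
import pandas as pd

from kedro_datasets.partitions import PartitionedDataset

# Hypothetical partition root; any fsspec-compatible path should work.
dataset = PartitionedDataset(
    path="data/02_intermediate/reviews",
    dataset="pandas.CSVDataset",
    filename_suffix=".csv",
    use_async=True,  # opt in to the asynchronous save path added in this PR
)

# Values may be concrete data or callables that produce the data lazily.
dataset.save(
    {
        "2024-01": pd.DataFrame({"score": [1, 2, 3]}),
        "2024-02": lambda: pd.DataFrame({"score": [4, 5]}),
    }
)
```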
@@ -302,6 +307,12 @@ def load(self) -> dict[str, Callable[[], Any]]:
return partitions

def save(self, data: dict[str, Any]) -> None:
if self._use_async:
asyncio.run(self._async_save(data))
Comment on lines 309 to +311
Member commented:
If I understand correctly, asyncio.run creates a new event loop, so if there's already an event loop running (for example, in a Jupyter notebook), calling this will raise an error.

This is essentially the red/blue function problem... Most of Kedro is synchronous anyway AFAIK, but I think this might set an API expectation that could be difficult to satisfy cleanly.

@merelcht @ElenaKhaustova do you have more thoughts?

Member replied:

> If I understand correctly, asyncio.run creates a new event loop, so if there's already an event loop running (for example, in a Jupyter notebook), calling this will raise an error.

That is indeed correct. Maybe it's alright, though, to say async saving doesn't work in interactive environments?

> but I think this might set an API expectation that could be difficult to satisfy cleanly

@astrojuanlu can you elaborate on what you mean by this?
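
(For illustration, a minimal sketch of the failure mode discussed in this thread, plus one possible guard. The helper name `_run_coroutine_safely` and the thread-based fallback are hypothetical, not something proposed in this PR.)

```python
import asyncio
import threading


def _run_coroutine_safely(coro) -> None:
    """Illustrative helper: asyncio.run() raises RuntimeError when the calling
    thread already has a running event loop (e.g. inside a Jupyter notebook)."""
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop is running in this thread, so asyncio.run() is safe.
        asyncio.run(coro)
    else:
        # A loop is already running; asyncio.run() would fail with
        # "asyncio.run() cannot be called from a running event loop".
        # One workaround is to run the coroutine in a worker thread,
        # which gets its own event loop via asyncio.run().
        thread = threading.Thread(target=asyncio.run, args=(coro,))
        thread.start()
        thread.join()
```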

else:
self._sync_save(data)

def _sync_save(self, data: dict[str, Any]) -> None:
if self._overwrite and self._filesystem.exists(self._normalized_path):
self._filesystem.rm(self._normalized_path, recursive=True)

@@ -316,6 +327,36 @@ def save(self, data: dict[str, Any]) -> None:
dataset.save(partition_data)
self._invalidate_caches()

async def _async_save(self, data: dict[str, Any]) -> None:
if self._overwrite and await self._filesystem_exists(self._normalized_path):
await self._filesystem_rm(self._normalized_path, recursive=True)

async def save_partition(partition_id: str, partition_data: Any) -> None:
kwargs = deepcopy(self._dataset_config)
partition = self._partition_to_path(partition_id)
kwargs[self._filepath_arg] = self._join_protocol(partition)
dataset = self._dataset_type(**kwargs) # type: ignore
if callable(partition_data):
partition_data = partition_data() # noqa: PLW2901
await self._dataset_save(dataset, partition_data)

await asyncio.gather(
*[
save_partition(partition_id, partition_data)
for partition_id, partition_data in sorted(data.items())
]
)
self._invalidate_caches()

async def _filesystem_exists(self, path: str) -> bool:
return self._filesystem.exists(path)

async def _filesystem_rm(self, path: str, recursive: bool) -> None:
self._filesystem.rm(path, recursive=recursive)

async def _dataset_save(self, dataset: AbstractDataset, data: Any) -> None:
dataset.save(data)

def _describe(self) -> dict[str, Any]:
clean_dataset_config = (
{k: v for k, v in self._dataset_config.items() if k != CREDENTIALS_KEY}
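
One observation on the helpers in this diff: `_filesystem_exists`, `_filesystem_rm` and `_dataset_save` are declared `async` but call blocking code directly without awaiting anything, so the coroutines gathered in `_async_save` cannot actually overlap. A hedged sketch of an alternative, assuming Python 3.9+ and that thread-based offloading is acceptable; the helpers are shown as free functions for brevity and this is not part of the PR.

```python
import asyncio
from typing import Any


# Assumed variants of the helpers above: asyncio.to_thread (Python 3.9+)
# hands each blocking call to a worker thread, so awaited saves can overlap
# I/O-bound work when run under asyncio.gather.
async def _dataset_save(dataset: Any, data: Any) -> None:
    await asyncio.to_thread(dataset.save, data)


async def _filesystem_exists(filesystem: Any, path: str) -> bool:
    return await asyncio.to_thread(filesystem.exists, path)


async def _filesystem_rm(filesystem: Any, path: str, recursive: bool) -> None:
    await asyncio.to_thread(filesystem.rm, path, recursive=recursive)
```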