Skip to content

Commit

Permalink
Merge branch 'main' into develop
Browse files Browse the repository at this point in the history
  • Loading branch information
Kedro committed Sep 17, 2024
2 parents 3fccaaf + 6bf29f9 commit 4f57625
Show file tree
Hide file tree
Showing 14 changed files with 178 additions and 71 deletions.
1 change: 1 addition & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Upcoming Release

## Major features and improvements
* Implemented `Protocol` abstraction for the current `DataCatalog` and adding new catalog implementations.
* Refactored `kedro run` and `kedro catalog` commands.
* Moved pattern resolution logic from `DataCatalog` to a separate component - `CatalogConfigResolver`. Updated `DataCatalog` to use `CatalogConfigResolver` internally.
* Made packaged Kedro projects return `session.run()` output to be used when running it in the interactive environment.
Expand Down
2 changes: 2 additions & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@
"kedro.io.catalog_config_resolver.CatalogConfigResolver",
"kedro.io.core.AbstractDataset",
"kedro.io.core.AbstractVersionedDataset",
"kedro.io.core.CatalogProtocol",
"kedro.io.core.DatasetError",
"kedro.io.core.Version",
"kedro.io.data_catalog.DataCatalog",
Expand Down Expand Up @@ -170,6 +171,7 @@
"None. Update D from mapping/iterable E and F.",
"Patterns",
"CatalogConfigResolver",
"CatalogProtocol",
),
"py:data": (
"typing.Any",
Expand Down
20 changes: 10 additions & 10 deletions kedro/framework/context/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from kedro.config import AbstractConfigLoader, MissingConfigException
from kedro.framework.project import settings
from kedro.io import DataCatalog # noqa: TCH001
from kedro.io import CatalogProtocol, DataCatalog # noqa: TCH001
from kedro.pipeline.transcoding import _transcode_split

if TYPE_CHECKING:
Expand Down Expand Up @@ -123,7 +123,7 @@ def _convert_paths_to_absolute_posix(
return conf_dictionary


def _validate_transcoded_datasets(catalog: DataCatalog) -> None:
def _validate_transcoded_datasets(catalog: CatalogProtocol) -> None:
"""Validates transcoded datasets are correctly named
Args:
Expand Down Expand Up @@ -178,13 +178,13 @@ class KedroContext:
)

@property
def catalog(self) -> DataCatalog:
"""Read-only property referring to Kedro's ``DataCatalog`` for this context.
def catalog(self) -> CatalogProtocol:
"""Read-only property referring to Kedro's catalog` for this context.
Returns:
DataCatalog defined in `catalog.yml`.
catalog defined in `catalog.yml`.
Raises:
KedroContextError: Incorrect ``DataCatalog`` registered for the project.
KedroContextError: Incorrect catalog registered for the project.
"""
return self._get_catalog()
Expand Down Expand Up @@ -213,13 +213,13 @@ def _get_catalog(
self,
save_version: str | None = None,
load_versions: dict[str, str] | None = None,
) -> DataCatalog:
"""A hook for changing the creation of a DataCatalog instance.
) -> CatalogProtocol:
"""A hook for changing the creation of a catalog instance.
Returns:
DataCatalog defined in `catalog.yml`.
catalog defined in `catalog.yml`.
Raises:
KedroContextError: Incorrect ``DataCatalog`` registered for the project.
KedroContextError: Incorrect catalog registered for the project.
"""
# '**/catalog*' reads modular pipeline configs
Expand Down
28 changes: 14 additions & 14 deletions kedro/framework/hooks/specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

if TYPE_CHECKING:
from kedro.framework.context import KedroContext
from kedro.io import DataCatalog
from kedro.io import CatalogProtocol
from kedro.pipeline import Pipeline
from kedro.pipeline.node import Node

Expand All @@ -22,7 +22,7 @@ class DataCatalogSpecs:
@hook_spec
def after_catalog_created( # noqa: PLR0913
self,
catalog: DataCatalog,
catalog: CatalogProtocol,
conf_catalog: dict[str, Any],
conf_creds: dict[str, Any],
feed_dict: dict[str, Any],
Expand Down Expand Up @@ -53,7 +53,7 @@ class NodeSpecs:
def before_node_run(
self,
node: Node,
catalog: DataCatalog,
catalog: CatalogProtocol,
inputs: dict[str, Any],
is_async: bool,
session_id: str,
Expand All @@ -63,7 +63,7 @@ def before_node_run(
Args:
node: The ``Node`` to run.
catalog: A ``DataCatalog`` containing the node's inputs and outputs.
catalog: An implemented instance of ``CatalogProtocol`` containing the node's inputs and outputs.
inputs: The dictionary of inputs dataset.
The keys are dataset names and the values are the actual loaded input data,
not the dataset instance.
Expand All @@ -81,7 +81,7 @@ def before_node_run(
def after_node_run( # noqa: PLR0913
self,
node: Node,
catalog: DataCatalog,
catalog: CatalogProtocol,
inputs: dict[str, Any],
outputs: dict[str, Any],
is_async: bool,
Expand All @@ -93,7 +93,7 @@ def after_node_run( # noqa: PLR0913
Args:
node: The ``Node`` that ran.
catalog: A ``DataCatalog`` containing the node's inputs and outputs.
catalog: An implemented instance of ``CatalogProtocol`` containing the node's inputs and outputs.
inputs: The dictionary of inputs dataset.
The keys are dataset names and the values are the actual loaded input data,
not the dataset instance.
Expand All @@ -110,7 +110,7 @@ def on_node_error( # noqa: PLR0913
self,
error: Exception,
node: Node,
catalog: DataCatalog,
catalog: CatalogProtocol,
inputs: dict[str, Any],
is_async: bool,
session_id: str,
Expand All @@ -122,7 +122,7 @@ def on_node_error( # noqa: PLR0913
Args:
error: The uncaught exception thrown during the node run.
node: The ``Node`` to run.
catalog: A ``DataCatalog`` containing the node's inputs and outputs.
catalog: An implemented instance of ``CatalogProtocol`` containing the node's inputs and outputs.
inputs: The dictionary of inputs dataset.
The keys are dataset names and the values are the actual loaded input data,
not the dataset instance.
Expand All @@ -137,7 +137,7 @@ class PipelineSpecs:

@hook_spec
def before_pipeline_run(
self, run_params: dict[str, Any], pipeline: Pipeline, catalog: DataCatalog
self, run_params: dict[str, Any], pipeline: Pipeline, catalog: CatalogProtocol
) -> None:
"""Hook to be invoked before a pipeline runs.
Expand All @@ -164,7 +164,7 @@ def before_pipeline_run(
}
pipeline: The ``Pipeline`` that will be run.
catalog: The ``DataCatalog`` to be used during the run.
catalog: An implemented instance of ``CatalogProtocol`` to be used during the run.
"""
pass

Expand All @@ -174,7 +174,7 @@ def after_pipeline_run(
run_params: dict[str, Any],
run_result: dict[str, Any],
pipeline: Pipeline,
catalog: DataCatalog,
catalog: CatalogProtocol,
) -> None:
"""Hook to be invoked after a pipeline runs.
Expand Down Expand Up @@ -202,7 +202,7 @@ def after_pipeline_run(
run_result: The output of ``Pipeline`` run.
pipeline: The ``Pipeline`` that was run.
catalog: The ``DataCatalog`` used during the run.
catalog: An implemented instance of ``CatalogProtocol`` used during the run.
"""
pass

Expand All @@ -212,7 +212,7 @@ def on_pipeline_error(
error: Exception,
run_params: dict[str, Any],
pipeline: Pipeline,
catalog: DataCatalog,
catalog: CatalogProtocol,
) -> None:
"""Hook to be invoked if a pipeline run throws an uncaught Exception.
The signature of this error hook should match the signature of ``before_pipeline_run``
Expand Down Expand Up @@ -242,7 +242,7 @@ def on_pipeline_error(
}
pipeline: The ``Pipeline`` that will was run.
catalog: The ``DataCatalog`` used during the run.
catalog: An implemented instance of ``CatalogProtocol`` used during the run.
"""
pass

Expand Down
25 changes: 23 additions & 2 deletions kedro/framework/project/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from dynaconf import LazySettings
from dynaconf.validator import ValidationError, Validator

from kedro.io import CatalogProtocol
from kedro.pipeline import Pipeline, pipeline

if TYPE_CHECKING:
Expand Down Expand Up @@ -59,6 +60,25 @@ def validate(
)


class _ImplementsCatalogProtocolValidator(Validator):
"""A validator to check if the supplied setting value is a subclass of the default class"""

def validate(
self, settings: dynaconf.base.Settings, *args: Any, **kwargs: Any
) -> None:
super().validate(settings, *args, **kwargs)

protocol = CatalogProtocol
for name in self.names:
setting_value = getattr(settings, name)
if not isinstance(setting_value(), protocol):
raise ValidationError(
f"Invalid value '{setting_value.__module__}.{setting_value.__qualname__}' "
f"received for setting '{name}'. It must implement "
f"'{protocol.__module__}.{protocol.__qualname__}'."
)


class _HasSharedParentClassValidator(Validator):
"""A validator to check that the parent of the default class is an ancestor of
the settings value."""
Expand Down Expand Up @@ -115,8 +135,9 @@ class _ProjectSettings(LazySettings):
_CONFIG_LOADER_ARGS = Validator(
"CONFIG_LOADER_ARGS", default={"base_env": "base", "default_run_env": "local"}
)
_DATA_CATALOG_CLASS = _IsSubclassValidator(
"DATA_CATALOG_CLASS", default=_get_default_class("kedro.io.DataCatalog")
_DATA_CATALOG_CLASS = _ImplementsCatalogProtocolValidator(
"DATA_CATALOG_CLASS",
default=_get_default_class("kedro.io.DataCatalog"),
)

def __init__(self, *args: Any, **kwargs: Any):
Expand Down
2 changes: 2 additions & 0 deletions kedro/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .core import (
AbstractDataset,
AbstractVersionedDataset,
CatalogProtocol,
DatasetAlreadyExistsError,
DatasetError,
DatasetNotFoundError,
Expand All @@ -23,6 +24,7 @@
"AbstractDataset",
"AbstractVersionedDataset",
"CachedDataset",
"CatalogProtocol",
"DataCatalog",
"CatalogConfigResolver",
"DatasetAlreadyExistsError",
Expand Down
79 changes: 78 additions & 1 deletion kedro/io/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,15 @@
from glob import iglob
from operator import attrgetter
from pathlib import Path, PurePath, PurePosixPath
from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar
from typing import (
TYPE_CHECKING,
Any,
Callable,
Generic,
Protocol,
TypeVar,
runtime_checkable,
)
from urllib.parse import urlsplit

from cachetools import Cache, cachedmethod
Expand All @@ -29,6 +37,8 @@
if TYPE_CHECKING:
import os

from kedro.io.catalog_config_resolver import CatalogConfigResolver, Patterns

VERSION_FORMAT = "%Y-%m-%dT%H.%M.%S.%fZ"
VERSIONED_FLAG_KEY = "versioned"
VERSION_KEY = "version"
Expand Down Expand Up @@ -871,3 +881,70 @@ def validate_on_forbidden_chars(**kwargs: Any) -> None:
raise DatasetError(
f"Neither white-space nor semicolon are allowed in '{key}'."
)


_C = TypeVar("_C")


@runtime_checkable
class CatalogProtocol(Protocol[_C]):
_datasets: dict[str, AbstractDataset]

def __contains__(self, ds_name: str) -> bool:
"""Check if a dataset is in the catalog."""
...

@property
def config_resolver(self) -> CatalogConfigResolver:
"""Return a copy of the datasets dictionary."""
...

@classmethod
def from_config(cls, catalog: dict[str, dict[str, Any]] | None) -> _C:
"""Create a catalog instance from configuration."""
...

def _get_dataset(
self,
dataset_name: str,
version: Any = None,
suggest: bool = True,
) -> AbstractDataset:
"""Retrieve a dataset by its name."""
...

def list(self, regex_search: str | None = None) -> list[str]:
"""List all dataset names registered in the catalog."""
...

def save(self, name: str, data: Any) -> None:
"""Save data to a registered dataset."""
...

def load(self, name: str, version: str | None = None) -> Any:
"""Load data from a registered dataset."""
...

def add(self, ds_name: str, dataset: Any, replace: bool = False) -> None:
"""Add a new dataset to the catalog."""
...

def add_feed_dict(self, datasets: dict[str, Any], replace: bool = False) -> None:
"""Add datasets to the catalog using the data provided through the `feed_dict`."""
...

def exists(self, name: str) -> bool:
"""Checks whether registered data set exists by calling its `exists()` method."""
...

def release(self, name: str) -> None:
"""Release any cached data associated with a dataset."""
...

def confirm(self, name: str) -> None:
"""Confirm a dataset by its name."""
...

def shallow_copy(self, extra_dataset_patterns: Patterns | None = None) -> _C:
"""Returns a shallow copy of the current object."""
...
Loading

0 comments on commit 4f57625

Please sign in to comment.