refactor(datasets): deprecate "DataSet" type names (databricks)
Signed-off-by: Deepyaman Datta <[email protected]>
deepyaman committed Sep 10, 2023
1 parent 039150b commit 07cdf9e
Showing 4 changed files with 137 additions and 98 deletions.
1 change: 1 addition & 0 deletions kedro-datasets/docs/source/kedro_datasets.rst
@@ -18,6 +18,7 @@ kedro_datasets
 kedro_datasets.dask.ParquetDataSet
 kedro_datasets.dask.ParquetDataset
 kedro_datasets.databricks.ManagedTableDataSet
+kedro_datasets.databricks.ManagedTableDataset
 kedro_datasets.email.EmailMessageDataSet
 kedro_datasets.geopandas.GeoJSONDataSet
 kedro_datasets.holoviews.HoloviewsWriter
10 changes: 8 additions & 2 deletions kedro-datasets/kedro_datasets/databricks/__init__.py
@@ -1,11 +1,17 @@
 """Provides interface to Unity Catalog Tables."""
+from __future__ import annotations
+
 from typing import Any
 
 import lazy_loader as lazy
 
 # https://github.com/pylint-dev/pylint/issues/4300#issuecomment-1043601901
-ManagedTableDataSet: Any
+ManagedTableDataSet: type[ManagedTableDataset]
+ManagedTableDataset: Any
 
 __getattr__, __dir__, __all__ = lazy.attach(
-    __name__, submod_attrs={"managed_table_dataset": ["ManagedTableDataSet"]}
+    __name__,
+    submod_attrs={
+        "managed_table_dataset": ["ManagedTableDataSet", "ManagedTableDataset"]
+    },
 )
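For context: the `lazy.attach` call wires up PEP 562 module hooks, so the submodule is only imported on first attribute access, and the bare annotations above it exist purely for static checkers (see the linked pylint issue). A minimal, self-contained sketch of the same pattern, assuming only that `lazy-loader` is installed; the package and submodule names here are hypothetical, not part of the commit:

    # mypkg/__init__.py -- illustrative sketch only.
    import lazy_loader as lazy

    # Expose mypkg.heavy_submodule.Thing as mypkg.Thing, deferring the
    # (potentially slow) import of heavy_submodule until first access.
    __getattr__, __dir__, __all__ = lazy.attach(
        __name__, submod_attrs={"heavy_submodule": ["Thing"]}
    )

Note that with `from __future__ import annotations` in effect, the `type[ManagedTableDataset]` annotation is stored as a string and never evaluated at runtime, which is why the forward reference to a name that only exists via `__getattr__` is safe.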
120 changes: 69 additions & 51 deletions kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py
@@ -1,8 +1,9 @@
-"""``ManagedTableDataSet`` implementation to access managed delta tables
+"""``ManagedTableDataset`` implementation to access managed delta tables
 in Databricks.
 """
 import logging
 import re
+import warnings
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Union
 
@@ -12,8 +13,7 @@
 from pyspark.sql.types import StructType
 from pyspark.sql.utils import AnalysisException, ParseException
 
-from .._io import AbstractVersionedDataset as AbstractVersionedDataSet
-from .._io import DatasetError as DataSetError
+from kedro_datasets._io import AbstractVersionedDataset, DatasetError
 
 logger = logging.getLogger(__name__)
 
@@ -39,9 +39,9 @@ class ManagedTable:
     def __post_init__(self):
         """Run validation methods if declared.
         The validation method can be a simple check
-        that raises DataSetError.
+        that raises DatasetError.
         The validation is performed by calling a function named:
-        `validate_<field_name>(self, value) -> raises DataSetError`
+        `validate_<field_name>(self, value) -> raises DatasetError`
         """
         for name in self.__dataclass_fields__.keys():  # pylint: disable=no-member
             method = getattr(self, f"_validate_{name}", None)
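As an aside, `__post_init__` here implements a small convention-based dispatch: every dataclass field `<field>` is validated by a `_validate_<field>` method if one exists. A runnable sketch of the pattern, with illustrative names and `DatasetError` stubbed rather than imported from kedro:

    from dataclasses import dataclass, fields

    class DatasetError(Exception):
        """Stand-in for kedro's DatasetError, for illustration only."""

    @dataclass(frozen=True)
    class Table:
        name: str

        def __post_init__(self):
            # Call every _validate_<field> method that exists; each one
            # raises DatasetError on invalid input.
            for f in fields(self):
                method = getattr(self, f"_validate_{f.name}", None)
                if callable(method):
                    method()

        def _validate_name(self):
            if not self.name.isidentifier():
                raise DatasetError("name does not conform to naming")

    Table(name="names_and_ages")   # validates cleanly
    # Table(name="not valid!")     # would raise DatasetError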
@@ -52,42 +52,42 @@ def _validate_table(self):
         """Validates table name
         Raises:
-            DataSetError: If the table name does not conform to naming constraints.
+            DatasetError: If the table name does not conform to naming constraints.
         """
         if not re.fullmatch(self._NAMING_REGEX, self.table):
-            raise DataSetError("table does not conform to naming")
+            raise DatasetError("table does not conform to naming")
 
     def _validate_database(self):
         """Validates database name
         Raises:
-            DataSetError: If the dataset name does not conform to naming constraints.
+            DatasetError: If the dataset name does not conform to naming constraints.
         """
         if not re.fullmatch(self._NAMING_REGEX, self.database):
-            raise DataSetError("database does not conform to naming")
+            raise DatasetError("database does not conform to naming")
 
     def _validate_catalog(self):
         """Validates catalog name
         Raises:
-            DataSetError: If the catalog name does not conform to naming constraints.
+            DatasetError: If the catalog name does not conform to naming constraints.
         """
         if self.catalog:
             if not re.fullmatch(self._NAMING_REGEX, self.catalog):
-                raise DataSetError("catalog does not conform to naming")
+                raise DatasetError("catalog does not conform to naming")
 
     def _validate_write_mode(self):
         """Validates the write mode
         Raises:
-            DataSetError: If an invalid `write_mode` is passed.
+            DatasetError: If an invalid `write_mode` is passed.
         """
         if (
             self.write_mode is not None
             and self.write_mode not in self._VALID_WRITE_MODES
         ):
             valid_modes = ", ".join(self._VALID_WRITE_MODES)
-            raise DataSetError(
+            raise DatasetError(
                 f"Invalid `write_mode` provided: {self.write_mode}. "
                 f"`write_mode` must be one of: {valid_modes}"
             )
@@ -96,21 +96,21 @@ def _validate_dataframe_type(self):
         """Validates the dataframe type
         Raises:
-            DataSetError: If an invalid `dataframe_type` is passed
+            DatasetError: If an invalid `dataframe_type` is passed
         """
         if self.dataframe_type not in self._VALID_DATAFRAME_TYPES:
             valid_types = ", ".join(self._VALID_DATAFRAME_TYPES)
-            raise DataSetError(f"`dataframe_type` must be one of {valid_types}")
+            raise DatasetError(f"`dataframe_type` must be one of {valid_types}")
 
     def _validate_primary_key(self):
         """Validates the primary key of the table
         Raises:
-            DataSetError: If no `primary_key` is specified.
+            DatasetError: If no `primary_key` is specified.
         """
         if self.primary_key is None or len(self.primary_key) == 0:
             if self.write_mode == "upsert":
-                raise DataSetError(
+                raise DatasetError(
                     f"`primary_key` must be provided for"
                     f"`write_mode` {self.write_mode}"
                 )
@@ -139,12 +139,12 @@ def schema(self) -> StructType:
             if self.json_schema is not None:
                 schema = StructType.fromJson(self.json_schema)
         except (KeyError, ValueError) as exc:
-            raise DataSetError(exc) from exc
+            raise DatasetError(exc) from exc
         return schema
 
 
-class ManagedTableDataSet(AbstractVersionedDataSet):
-    """``ManagedTableDataSet`` loads and saves data into managed delta tables on Databricks.
+class ManagedTableDataset(AbstractVersionedDataset):
+    """``ManagedTableDataset`` loads and saves data into managed delta tables on Databricks.
     Load and save can be in Spark or Pandas dataframes, specified in dataframe_type.
     When saving data, you can specify one of three modes: overwrite(default), append,
     or upsert. Upsert requires you to specify the primary_column parameter which
@@ -160,13 +160,13 @@ class ManagedTableDataSet(AbstractVersionedDataSet):
     .. code-block:: yaml
 
         names_and_ages@spark:
-          type: databricks.ManagedTableDataSet
-          table: names_and_ages
+          type: databricks.ManagedTableDataset
+          table: names_and_ages
 
         names_and_ages@pandas:
-          type: databricks.ManagedTableDataSet
-          table: names_and_ages
-          dataframe_type: pandas
+          type: databricks.ManagedTableDataset
+          table: names_and_ages
+          dataframe_type: pandas
 
     Example usage for the
     `Python API <https://kedro.readthedocs.io/en/stable/data/\
@@ -176,15 +176,15 @@ class ManagedTableDataSet(AbstractVersionedDataSet):
         from pyspark.sql import SparkSession
         from pyspark.sql.types import (StructField, StringType,
-                                        IntegerType, StructType)
-        from kedro_datasets.databricks import ManagedTableDataSet
+                                       IntegerType, StructType)
+        from kedro_datasets.databricks import ManagedTableDataset
 
         schema = StructType([StructField("name", StringType(), True),
-                              StructField("age", IntegerType(), True)])
+                             StructField("age", IntegerType(), True)])
 
         data = [('Alex', 31), ('Bob', 12), ('Clarke', 65), ('Dave', 29)]
         spark_df = SparkSession.builder.getOrCreate().createDataFrame(data, schema)
-        data_set = ManagedTableDataSet(table="names_and_ages")
-        data_set.save(spark_df)
-        reloaded = data_set.load()
+        dataset = ManagedTableDataset(table="names_and_ages")
+        dataset.save(spark_df)
+        reloaded = dataset.load()
         reloaded.take(4)
     """

@@ -210,36 +210,36 @@ def __init__(  # pylint: disable=R0913
         partition_columns: List[str] = None,
         owner_group: str = None,
     ) -> None:
-        """Creates a new instance of ``ManagedTableDataSet``
+        """Creates a new instance of ``ManagedTableDataset``.
 
         Args:
             table: the name of the table
             catalog: the name of the catalog in Unity.
-            Defaults to None.
+                Defaults to None.
             database: the name of the database.
-            (also referred to as schema). Defaults to "default".
+                (also referred to as schema). Defaults to "default".
             write_mode: the mode to write the data into the table. If not
-            present, the data set is read-only.
-            Options are:["overwrite", "append", "upsert"].
-            "upsert" mode requires primary_key field to be populated.
-            Defaults to None.
+                present, the data set is read-only.
+                Options are:["overwrite", "append", "upsert"].
+                "upsert" mode requires primary_key field to be populated.
+                Defaults to None.
             dataframe_type: "pandas" or "spark" dataframe.
-            Defaults to "spark".
+                Defaults to "spark".
             primary_key: the primary key of the table.
-            Can be in the form of a list. Defaults to None.
+                Can be in the form of a list. Defaults to None.
             version: kedro.io.core.Version instance to load the data.
-            Defaults to None.
+                Defaults to None.
             schema: the schema of the table in JSON form.
-            Dataframes will be truncated to match the schema if provided.
-            Used by the hooks to create the table if the schema is provided
-            Defaults to None.
+                Dataframes will be truncated to match the schema if provided.
+                Used by the hooks to create the table if the schema is provided
+                Defaults to None.
             partition_columns: the columns to use for partitioning the table.
-            Used by the hooks. Defaults to None.
+                Used by the hooks. Defaults to None.
             owner_group: if table access control is enabled in your workspace,
-            specifying owner_group will transfer ownership of the table and database to
-            this owner. All databases should have the same owner_group. Defaults to None.
+                specifying owner_group will transfer ownership of the table and database to
+                this owner. All databases should have the same owner_group. Defaults to None.
 
         Raises:
-            DataSetError: Invalid configuration supplied (through ManagedTable validation)
+            DatasetError: Invalid configuration supplied (through ManagedTable validation)
         """
 
         self._table = ManagedTable(
@@ -332,7 +332,7 @@ def _save_upsert(self, update_data: DataFrame) -> None:
         update_columns = update_data.columns
 
         if set(update_columns) != set(base_columns):
-            raise DataSetError(
+            raise DatasetError(
                 f"Upsert requires tables to have identical columns. "
                 f"Delta table {self._table.full_table_location()} "
                 f"has columns: {base_columns}, whereas "
@@ -370,7 +370,7 @@ def _save(self, data: Union[DataFrame, pd.DataFrame]) -> None:
             data (Any): Spark or pandas dataframe to save to the table location
         """
         if self._table.write_mode is None:
-            raise DataSetError(
+            raise DatasetError(
                 "'save' can not be used in read-only mode. "
                 "Change 'write_mode' value to `overwrite`, `upsert` or `append`."
             )
@@ -394,7 +394,7 @@ def _save(self, data: Union[DataFrame, pd.DataFrame]) -> None:
             self._save_append(data)
 
     def _describe(self) -> Dict[str, str]:
-        """Returns a description of the instance of ManagedTableDataSet
+        """Returns a description of the instance of ManagedTableDataset
 
         Returns:
             Dict[str, str]: Dict with the details of the dataset
@@ -438,3 +438,21 @@ def _exists(self) -> bool:
         except (ParseException, AnalysisException) as exc:
             logger.warning("error occured while trying to find table: %s", exc)
             return False
+
+
+_DEPRECATED_CLASSES = {
+    "ManagedTableDataSet": ManagedTableDataset,
+}
+
+
+def __getattr__(name):
+    if name in _DEPRECATED_CLASSES:
+        alias = _DEPRECATED_CLASSES[name]
+        warnings.warn(
+            f"{repr(name)} has been renamed to {repr(alias.__name__)}, "
+            f"and the alias will be removed in Kedro-Datasets 2.0.0",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        return alias
+    raise AttributeError(f"module {repr(__name__)} has no attribute {repr(name)}")
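Taken together, the module-level `__getattr__` (PEP 562) keeps the old name importable while steering users to the new one. A quick sketch of the resulting behaviour from the caller's side, assuming a kedro-datasets build that includes this commit:

    import warnings

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        from kedro_datasets.databricks import ManagedTableDataSet  # deprecated name

    assert any(issubclass(w.category, DeprecationWarning) for w in caught)

    # The alias is the renamed class itself, so existing code keeps working:
    from kedro_datasets.databricks import ManagedTableDataset
    assert ManagedTableDataSet is ManagedTableDataset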
