diff --git a/CHANGELOG.md b/CHANGELOG.md
index a3cb378d..e4af63c5 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -4,6 +4,10 @@

 - Introduced use_materialization_v2 flag for gating materialization revamps. ([844](https://github.com/databricks/dbt-databricks/pull/844))

+### Under the Hood
+
+- Implement new constraint logic for use_materialization_v2 flag ([846](https://github.com/databricks/dbt-databricks/pull/846/files))
+
 ## dbt-databricks 1.9.0 (TBD)

 ### Features
diff --git a/dbt/adapters/databricks/column.py b/dbt/adapters/databricks/column.py
index df2cdb2d..4e15e912 100644
--- a/dbt/adapters/databricks/column.py
+++ b/dbt/adapters/databricks/column.py
@@ -1,6 +1,9 @@
 from dataclasses import dataclass
 from typing import Any, ClassVar, Optional

+from dbt_common.contracts.constraints import ColumnLevelConstraint, ConstraintType
+
+from dbt.adapters.databricks import constraints
 from dbt.adapters.databricks.utils import quote
 from dbt.adapters.spark.column import SparkColumn

@@ -9,6 +12,8 @@ class DatabricksColumn(SparkColumn):
     table_comment: Optional[str] = None
     comment: Optional[str] = None
+    not_null: Optional[bool] = None
+    constraints: Optional[list[ColumnLevelConstraint]] = None

     TYPE_LABELS: ClassVar[dict[str, str]] = {
         "LONG": "BIGINT",
@@ -27,6 +32,45 @@ def create(cls, name: str, label_or_dtype: str) -> "DatabricksColumn":
     def data_type(self) -> str:
         return self.translate_type(self.dtype)

+    def add_constraint(self, constraint: ColumnLevelConstraint) -> None:
+        # On first constraint add, initialize constraint details
+        if self.constraints is None:
+            self.constraints = []
+            self.not_null = False
+        if constraint.type == ConstraintType.not_null:
+            self.not_null = True
+        else:
+            self.constraints.append(constraint)
+
+    def enrich(self, model_column: dict[str, Any]) -> "DatabricksColumn":
+        """Create a copy that incorporates model column metadata, including constraints."""
+
+        data_type = model_column.get("data_type") or self.dtype
+        enriched_column = DatabricksColumn.create(self.name, data_type)
+        if model_column.get("description"):
+            enriched_column.comment = model_column["description"]
+
+        if model_column.get("constraints"):
+            for constraint in model_column["constraints"]:
+                parsed_constraint = constraints.parse_column_constraint(constraint)
+                enriched_column.add_constraint(parsed_constraint)
+
+        return enriched_column
+
+    def render_for_create(self) -> str:
+        """Renders the column for building a create statement."""
+        column_str = f"{self.name} {self.dtype}"
+        if self.not_null:
+            column_str += " NOT NULL"
+        if self.comment:
+            comment = self.comment.replace("'", "\\'")
+            column_str += f" COMMENT '{comment}'"
+        for constraint in self.constraints or []:
+            c = constraints.process_column_constraint(constraint)
+            if c:
+                column_str += f" {c}"
+        return column_str
+
     def __repr__(self) -> str:
         return "<DatabricksColumn {} ({})>".format(self.name, self.data_type)
diff --git a/dbt/adapters/databricks/constraints.py b/dbt/adapters/databricks/constraints.py
new file mode 100644
index 00000000..96380090
--- /dev/null
+++ b/dbt/adapters/databricks/constraints.py
@@ -0,0 +1,234 @@
+from functools import partial
+from typing import Any, Callable, Optional, TypeVar
+
+from dbt_common.contracts.constraints import (
+    ColumnLevelConstraint,
+    ConstraintType,
+    ModelLevelConstraint,
+)
+from dbt_common.events.functions import warn_or_error
+from dbt_common.exceptions import DbtValidationError
+
+from dbt.adapters.base import ConstraintSupport
+from dbt.adapters.databricks.logging import logger
+from dbt.adapters.events.types import ConstraintNotEnforced, ConstraintNotSupported
+
+# Support constants
+CONSTRAINT_SUPPORT = {
+    ConstraintType.check: ConstraintSupport.ENFORCED,
+    ConstraintType.not_null: ConstraintSupport.ENFORCED,
+    ConstraintType.unique: ConstraintSupport.NOT_SUPPORTED,
+    ConstraintType.primary_key: ConstraintSupport.NOT_ENFORCED,
+    ConstraintType.foreign_key: ConstraintSupport.NOT_ENFORCED,
+}
+
+SUPPORTED_FOR_COLUMN = {
+    ConstraintType.custom,
+    ConstraintType.primary_key,
+    ConstraintType.foreign_key,
+}
+
+# Types
+# Generic type variable for constraints.
+T = TypeVar("T", bound=ColumnLevelConstraint)
+
+# Function type for checking constraint support.
+SupportFunc = Callable[[T], bool]
+
+# Function type for rendering constraints.
+RenderFunc = Callable[[T], str]
+
+
+# Base support and enforcement
+def is_supported(constraint: ColumnLevelConstraint) -> bool:
+    if constraint.type == ConstraintType.custom:
+        return True
+    if constraint.type in CONSTRAINT_SUPPORT:
+        return CONSTRAINT_SUPPORT[constraint.type] != ConstraintSupport.NOT_SUPPORTED
+    return False
+
+
+def is_enforced(constraint: ColumnLevelConstraint) -> bool:
+    return constraint.type in CONSTRAINT_SUPPORT and CONSTRAINT_SUPPORT[constraint.type] not in [
+        ConstraintSupport.NOT_ENFORCED,
+        ConstraintSupport.NOT_SUPPORTED,
+    ]
+
+
+# Core parsing, validation, and processing
+def parse_constraint(klass: type[T], raw_constraint: dict[str, Any]) -> T:
+    try:
+        klass.validate(raw_constraint)
+        return klass.from_dict(raw_constraint)
+    except Exception:
+        raise DbtValidationError(f"Could not parse constraint: {raw_constraint}")
+
+
+def process_constraint(
+    constraint: T,
+    support_func: SupportFunc,
+    render_funcs: dict[ConstraintType, RenderFunc],
+) -> Optional[str]:
+    if validate_constraint(constraint, support_func):
+        return render_constraint(constraint, render_funcs)
+
+    return None
+
+
+def validate_constraint(constraint: T, support_func: SupportFunc) -> bool:
+    # Custom constraints are always supported
+    if constraint.type == ConstraintType.custom:
+        return True
+
+    supported = support_func(constraint)
+
+    if constraint.warn_unsupported and not supported:
+        warn_or_error(
+            ConstraintNotSupported(constraint=constraint.type.value, adapter="DatabricksAdapter")
+        )
+    elif constraint.warn_unenforced and not is_enforced(constraint):
+        warn_or_error(
+            ConstraintNotEnforced(constraint=constraint.type.value, adapter="DatabricksAdapter")
+        )
+
+    return supported
+
+
+def render_constraint(
+    constraint: T, render_funcs: dict[ConstraintType, RenderFunc]
+) -> Optional[str]:
+    rendered_constraint = ""
+
+    if constraint.type in render_funcs:
+        if constraint.name:
+            rendered_constraint = f"CONSTRAINT {constraint.name} "
+        rendered_constraint += render_funcs[constraint.type](constraint)
+
+    rendered_constraint = rendered_constraint.strip()
+
+    return rendered_constraint if rendered_constraint != "" else None
+
+
+def supported_for(constraint: T, support_func: SupportFunc, warning: str) -> bool:
+    if is_supported(constraint) and not support_func(constraint):
+        logger.warning(warning.format(type=constraint.type))
+
+    return is_supported(constraint) and support_func(constraint)
+
+
+# Shared render functions
+def render_error(constraint: ColumnLevelConstraint, missing: list[list[str]]) -> DbtValidationError:
+    fields = " or ".join(["(" + ", ".join([f"'{e}'" for e in x]) + ")" for x in missing])
+    constraint_type = constraint.type.value
+    constraint_name = "" if not constraint.name else f"{constraint.name} "
+    return DbtValidationError(
+        f"{constraint_type} constraint {constraint_name}is missing required field(s): {fields}"
+    )
+
+
+def render_custom(constraint: ColumnLevelConstraint) -> str:
+    assert constraint.type == ConstraintType.custom
+    if constraint.expression:
+        return constraint.expression
+    raise render_error(constraint, [["expression"]])
+
+
+# ColumnLevelConstraint specialization
+def parse_column_constraint(raw_constraint: dict[str, Any]) -> ColumnLevelConstraint:
+    return parse_constraint(ColumnLevelConstraint, raw_constraint)
+
+
+COLUMN_WARNING = (
+    "While constraint of type {type} is supported for models, it is not supported for columns."
+)
+
+supported_for_columns = partial(
+    supported_for, support_func=lambda x: x.type in SUPPORTED_FOR_COLUMN, warning=COLUMN_WARNING
+)
+
+
+def column_constraint_map() -> dict[ConstraintType, RenderFunc]:
+    return {
+        ConstraintType.primary_key: render_primary_key_for_column,
+        ConstraintType.foreign_key: render_foreign_key_for_column,
+        ConstraintType.custom: render_custom,
+    }
+
+
+def process_column_constraint(constraint: ColumnLevelConstraint) -> Optional[str]:
+    return process_constraint(constraint, supported_for_columns, column_constraint_map())
+
+
+def render_primary_key_for_column(constraint: ColumnLevelConstraint) -> str:
+    assert constraint.type == ConstraintType.primary_key
+    rendered = "PRIMARY KEY"
+    if constraint.expression:
+        rendered += f" {constraint.expression}"
+    return rendered
+
+
+def render_foreign_key_for_column(constraint: ColumnLevelConstraint) -> str:
+    assert constraint.type == ConstraintType.foreign_key
+    rendered = "FOREIGN KEY"
+    if constraint.expression:
+        return rendered + f" {constraint.expression}"
+    elif constraint.to and constraint.to_columns:
+        return rendered + f" REFERENCES {constraint.to} ({', '.join(constraint.to_columns)})"
+    raise render_error(constraint, [["expression"], ["to", "to_columns"]])
+
+
+# ModelLevelConstraint specialization
+def parse_model_constraint(raw_constraint: dict[str, Any]) -> ModelLevelConstraint:
+    return parse_constraint(ModelLevelConstraint, raw_constraint)
+
+
+MODEL_WARNING = (
+    "While constraint of type {type} is supported for columns, it is not supported for models."
+)
+
+
+supported_for_models = partial(
+    supported_for, support_func=lambda x: x.type != ConstraintType.not_null, warning=MODEL_WARNING
+)
+
+
+def model_constraint_map() -> dict[ConstraintType, RenderFunc]:
+    return {
+        ConstraintType.primary_key: render_primary_key_for_model,
+        ConstraintType.foreign_key: render_foreign_key_for_model,
+        ConstraintType.custom: render_custom,
+        ConstraintType.check: render_check,
+    }
+
+
+def process_model_constraint(constraint: ModelLevelConstraint) -> Optional[str]:
+    return process_constraint(constraint, supported_for_models, model_constraint_map())
+
+
+def render_primary_key_for_model(constraint: ModelLevelConstraint) -> str:
+    prefix = render_primary_key_for_column(constraint)
+    if constraint.expression:
+        return prefix
+    if constraint.columns:
+        return f"{prefix} ({', '.join(constraint.columns)})"
+    raise render_error(constraint, [["columns"], ["expression"]])
+
+
+def render_foreign_key_for_model(constraint: ModelLevelConstraint) -> str:
+    assert constraint.type == ConstraintType.foreign_key
+    rendered = "FOREIGN KEY"
+    if constraint.columns and constraint.to and constraint.to_columns:
+        columns = ", ".join(constraint.columns)
+        to_columns = ", ".join(constraint.to_columns)
+        return rendered + f" ({columns}) REFERENCES {constraint.to} ({to_columns})"
+    elif constraint.expression:
+        return rendered + " " + constraint.expression
+    raise render_error(constraint, [["expression"], ["columns", "to", "to_columns"]])
+
+
+def render_check(constraint: ColumnLevelConstraint) -> str:
+    assert constraint.type == ConstraintType.check
+    if constraint.expression:
+        return f"CHECK ({constraint.expression})"
+
+    raise render_error(constraint, [["expression"]])
diff --git a/dbt/adapters/databricks/impl.py b/dbt/adapters/databricks/impl.py
index 98fb578a..5e4b75d6 100644
--- a/dbt/adapters/databricks/impl.py
+++ b/dbt/adapters/databricks/impl.py
@@ -25,6 +25,7 @@
 from dbt.adapters.capability import Capability, CapabilityDict, CapabilitySupport, Support
 from dbt.adapters.contracts.connection import AdapterResponse, Connection
 from dbt.adapters.contracts.relation import RelationConfig, RelationType
+from dbt.adapters.databricks import constraints
 from dbt.adapters.databricks.behaviors.columns import (
     GetColumnsBehavior,
     GetColumnsByDescribe,
@@ -179,6 +180,8 @@ class DatabricksAdapter(SparkAdapter):
         }
     )

+    CONSTRAINT_SUPPORT = constraints.CONSTRAINT_SUPPORT
+
     get_column_behavior: GetColumnsBehavior

     def __init__(self, config: Any, mp_context: SpawnContext) -> None:
@@ -764,6 +767,24 @@ def get_config_from_model(self, model: RelationConfig) -> DatabricksRelationConf
     def generate_unique_temporary_table_suffix(self, suffix_initial: str = "__dbt_tmp") -> str:
         return f"{suffix_initial}_{str(uuid4())}"

+    @available
+    @staticmethod
+    def get_enriched_columns(
+        existing_columns: list[DatabricksColumn], model_columns: dict[str, dict[str, Any]]
+    ) -> list[DatabricksColumn]:
+        """Returns a list of columns that have been updated with features for table create."""
+        enriched_columns = []
+
+        for column in existing_columns:
+            if column.name in model_columns:
+                column_info = model_columns[column.name]
+                enriched_column = column.enrich(column_info)
+                enriched_columns.append(enriched_column)
+            else:
+                enriched_columns.append(column)
+
+        return enriched_columns
+

 @dataclass(frozen=True)
 class RelationAPIBase(ABC, Generic[DatabricksRelationConfig]):
diff --git a/dbt/adapters/databricks/relation.py b/dbt/adapters/databricks/relation.py
index bc1dc991..76506d55 100644
--- a/dbt/adapters/databricks/relation.py
+++ b/dbt/adapters/databricks/relation.py
@@ -2,6 +2,7 @@
 from dataclasses import dataclass, field
 from typing import Any, Optional, Type

+from dbt_common.contracts.constraints import ConstraintType, ModelLevelConstraint
 from dbt_common.dataclass_schema import StrEnum
 from dbt_common.exceptions import DbtRuntimeError
 from dbt_common.utils import filter_null_values
@@ -10,6 +11,7 @@
 from dbt.adapters.contracts.relation import (
     ComponentName,
 )
+from dbt.adapters.databricks.constraints import parse_model_constraint, process_model_constraint
 from dbt.adapters.databricks.utils import remove_undefined
 from dbt.adapters.spark.impl import KEY_TABLE_OWNER, KEY_TABLE_STATISTICS
 from dbt.adapters.utils import classproperty
@@ -58,7 +60,8 @@ class DatabricksRelation(BaseRelation):
     quote_policy: Policy = field(default_factory=lambda: DatabricksQuotePolicy())
     include_policy: Policy = field(default_factory=lambda: DatabricksIncludePolicy())
     quote_character: str = "`"
-
+    create_constraints: list[ModelLevelConstraint] = field(default_factory=list)
+    alter_constraints: list[ModelLevelConstraint] = field(default_factory=list)
     metadata: Optional[dict[str, Any]] = None

     @classmethod
@@ -148,6 +151,23 @@ def information_schema(self, view_name: Optional[str] = None) -> InformationSche
     def StreamingTable(cls) -> str:
         return str(DatabricksRelationType.StreamingTable)

+    def add_constraint(self, constraint: ModelLevelConstraint) -> None:
+        if constraint.type == ConstraintType.check:
+            self.alter_constraints.append(constraint)
+        else:
+            self.create_constraints.append(constraint)
+
+    def enrich(self, raw_constraints: list[dict[str, Any]]) -> "DatabricksRelation":
+        copy = self.incorporate()
+        for constraint in raw_constraints:
+            copy.add_constraint(parse_model_constraint(constraint))
+
+        return copy
+
+    def render_constraints_for_create(self) -> str:
+        processed = [process_model_constraint(c) for c in self.create_constraints]
+        return ", ".join(c for c in processed if c is not None)
+

 def is_hive_metastore(database: Optional[str]) -> bool:
     return database is None or database.lower() == "hive_metastore"
diff --git a/tests/unit/test_adapter.py b/tests/unit/test_adapter.py
index 78ae12cb..1eb4585b 100644
--- a/tests/unit/test_adapter.py
+++ b/tests/unit/test_adapter.py
@@ -412,6 +412,28 @@ def test_simple_catalog_relation(self):
         )
         assert relation.database == "test_catalog"

+    def expected_column(self, real_vals):
+        default_col = {
+            "table_database": None,
+            "table_schema": None,
+            "table_name": None,
+            "table_type": "table",
+            "table_owner": "root",
+            "table_comment": None,
+            "column": "col1",
+            "column_index": 0,
+            "dtype": None,
+            "numeric_scale": None,
+            "numeric_precision": None,
+            "char_size": None,
+            "comment": None,
+            "constraints": None,
+            "not_null": None,
+        }
+
+        default_col.update(real_vals)
+        return default_col
+
     def test_parse_relation(self):
         self.maxDiff = None
         rel_type = DatabricksRelation.get_relation_type.Table
@@ -479,69 +501,51 @@
         }

         assert len(rows) == 4
-        assert rows[0].to_column_dict(omit_none=False) == {
-            "table_database": None,
-            "table_schema": relation.schema,
-            "table_name": relation.name,
-            "table_type": rel_type,
-            "table_owner": "root",
-            "table_comment": None,
-            "column": "col1",
-            "column_index": 0,
-            "dtype": "decimal(22,0)",
-            "numeric_scale": None,
-            "numeric_precision": None,
-            "char_size": None,
-            "comment": "comment",
-        }
+        assert rows[0].to_column_dict(omit_none=False) == self.expected_column(
+            {
"table_schema": relation.schema, + "table_name": relation.name, + "table_type": rel_type, + "column": "col1", + "column_index": 0, + "dtype": "decimal(22,0)", + "comment": "comment", + } + ) - assert rows[1].to_column_dict(omit_none=False) == { - "table_database": None, - "table_schema": relation.schema, - "table_name": relation.name, - "table_type": rel_type, - "table_owner": "root", - "table_comment": None, - "column": "col2", - "column_index": 1, - "dtype": "string", - "numeric_scale": None, - "numeric_precision": None, - "char_size": None, - "comment": "comment", - } + assert rows[1].to_column_dict(omit_none=False) == self.expected_column( + { + "table_schema": relation.schema, + "table_name": relation.name, + "table_type": rel_type, + "column": "col2", + "column_index": 1, + "dtype": "string", + "comment": "comment", + } + ) - assert rows[2].to_column_dict(omit_none=False) == { - "table_database": None, - "table_schema": relation.schema, - "table_name": relation.name, - "table_type": rel_type, - "table_owner": "root", - "table_comment": None, - "column": "dt", - "column_index": 2, - "dtype": "date", - "numeric_scale": None, - "numeric_precision": None, - "char_size": None, - "comment": None, - } + assert rows[2].to_column_dict(omit_none=False) == self.expected_column( + { + "table_schema": relation.schema, + "table_name": relation.name, + "table_type": rel_type, + "column": "dt", + "column_index": 2, + "dtype": "date", + } + ) - assert rows[3].to_column_dict(omit_none=False) == { - "table_database": None, - "table_schema": relation.schema, - "table_name": relation.name, - "table_type": rel_type, - "table_owner": "root", - "table_comment": None, - "column": "struct_col", - "column_index": 3, - "dtype": "struct", - "numeric_scale": None, - "numeric_precision": None, - "char_size": None, - "comment": None, - } + assert rows[3].to_column_dict(omit_none=False) == self.expected_column( + { + "table_schema": relation.schema, + "table_name": relation.name, + "table_type": rel_type, + "column": "struct_col", + "column_index": 3, + "dtype": "struct", + } + ) def test_parse_relation_with_integer_owner(self): self.maxDiff = None @@ -632,29 +636,27 @@ def test_parse_relation_with_statistics(self): } assert len(rows) == 1 - assert rows[0].to_column_dict(omit_none=False) == { - "table_database": None, - "table_schema": relation.schema, - "table_name": relation.name, - "table_type": rel_type, - "table_owner": "root", - "table_comment": "Table model description", - "column": "col1", - "column_index": 0, - "comment": "comment", - "dtype": "decimal(22,0)", - "numeric_scale": None, - "numeric_precision": None, - "char_size": None, - "stats:bytes:description": "", - "stats:bytes:include": True, - "stats:bytes:label": "bytes", - "stats:bytes:value": 1109049927, - "stats:rows:description": "", - "stats:rows:include": True, - "stats:rows:label": "rows", - "stats:rows:value": 14093476, - } + assert rows[0].to_column_dict(omit_none=False) == self.expected_column( + { + "table_schema": relation.schema, + "table_name": relation.name, + "table_type": rel_type, + "table_owner": "root", + "table_comment": "Table model description", + "column": "col1", + "column_index": 0, + "comment": "comment", + "dtype": "decimal(22,0)", + "stats:bytes:description": "", + "stats:bytes:include": True, + "stats:bytes:label": "bytes", + "stats:bytes:value": 1109049927, + "stats:rows:description": "", + "stats:rows:include": True, + "stats:rows:label": "rows", + "stats:rows:value": 14093476, + } + ) def 
         config = self._get_config()
@@ -701,45 +703,35 @@ def test_parse_columns_from_information_with_table_type_and_delta_provider(self)
             relation, information
         )
         assert len(columns) == 4
-        assert columns[0].to_column_dict(omit_none=False) == {
-            "table_database": None,
-            "table_schema": relation.schema,
-            "table_name": relation.name,
-            "table_type": rel_type,
-            "table_owner": "root",
-            "table_comment": None,
-            "column": "col1",
-            "column_index": 0,
-            "dtype": "decimal(22,0)",
-            "numeric_scale": None,
-            "numeric_precision": None,
-            "char_size": None,
-            "stats:bytes:description": "",
-            "stats:bytes:include": True,
-            "stats:bytes:label": "bytes",
-            "stats:bytes:value": 123456789,
-            "comment": None,
-        }
+        assert columns[0].to_column_dict(omit_none=False) == self.expected_column(
+            {
+                "table_schema": relation.schema,
+                "table_name": relation.name,
+                "table_type": rel_type,
+                "column": "col1",
+                "column_index": 0,
+                "dtype": "decimal(22,0)",
+                "stats:bytes:description": "",
+                "stats:bytes:include": True,
+                "stats:bytes:label": "bytes",
+                "stats:bytes:value": 123456789,
+            }
+        )

-        assert columns[3].to_column_dict(omit_none=False) == {
-            "table_database": None,
-            "table_schema": relation.schema,
-            "table_name": relation.name,
-            "table_type": rel_type,
-            "table_owner": "root",
-            "table_comment": None,
-            "column": "struct_col",
-            "column_index": 3,
-            "dtype": "struct",
-            "comment": None,
-            "numeric_scale": None,
-            "numeric_precision": None,
-            "char_size": None,
-            "stats:bytes:description": "",
-            "stats:bytes:include": True,
-            "stats:bytes:label": "bytes",
-            "stats:bytes:value": 123456789,
-        }
+        assert columns[3].to_column_dict(omit_none=False) == self.expected_column(
+            {
+                "table_schema": relation.schema,
+                "table_name": relation.name,
+                "table_type": rel_type,
+                "column": "struct_col",
+                "column_index": 3,
+                "dtype": "struct",
+                "stats:bytes:description": "",
+                "stats:bytes:include": True,
+                "stats:bytes:label": "bytes",
+                "stats:bytes:value": 123456789,
+            }
+        )

     def test_parse_columns_from_information_with_view_type(self):
         self.maxDiff = None
@@ -786,37 +778,27 @@ def test_parse_columns_from_information_with_view_type(self):
             relation, information
         )
         assert len(columns) == 4
-        assert columns[1].to_column_dict(omit_none=False) == {
-            "table_database": None,
-            "table_schema": relation.schema,
-            "table_name": relation.name,
-            "table_type": rel_type,
-            "table_owner": "root",
-            "table_comment": None,
-            "column": "col2",
-            "column_index": 1,
-            "comment": None,
-            "dtype": "string",
-            "numeric_scale": None,
-            "numeric_precision": None,
-            "char_size": None,
-        }
+        assert columns[1].to_column_dict(omit_none=False) == self.expected_column(
+            {
+                "table_schema": relation.schema,
+                "table_name": relation.name,
+                "table_type": rel_type,
+                "column": "col2",
+                "column_index": 1,
+                "dtype": "string",
+            }
+        )

-        assert columns[3].to_column_dict(omit_none=False) == {
-            "table_database": None,
-            "table_schema": relation.schema,
-            "table_name": relation.name,
-            "table_type": rel_type,
-            "table_owner": "root",
-            "table_comment": None,
-            "column": "struct_col",
-            "column_index": 3,
-            "comment": None,
-            "dtype": "struct",
-            "numeric_scale": None,
-            "numeric_precision": None,
-            "char_size": None,
-        }
+        assert columns[3].to_column_dict(omit_none=False) == self.expected_column(
+            {
+                "table_schema": relation.schema,
+                "table_name": relation.name,
+                "table_type": rel_type,
+                "column": "struct_col",
+                "column_index": 3,
+                "dtype": "struct",
+            }
+        )

     def test_parse_columns_from_information_with_table_type_and_parquet_provider(self):
         self.maxDiff = None
@@ -852,53 +834,43 @@ def test_parse_columns_from_information_with_table_type_and_parquet_provider(sel
             relation, information
         )
         assert len(columns) == 4
-        assert columns[2].to_column_dict(omit_none=False) == {
-            "table_database": None,
-            "table_schema": relation.schema,
-            "table_name": relation.name,
-            "table_type": rel_type,
-            "table_owner": "root",
-            "table_comment": None,
-            "column": "dt",
-            "column_index": 2,
-            "comment": None,
-            "dtype": "date",
-            "numeric_scale": None,
-            "numeric_precision": None,
-            "char_size": None,
-            "stats:bytes:description": "",
-            "stats:bytes:include": True,
-            "stats:bytes:label": "bytes",
-            "stats:bytes:value": 1234567890,
-            "stats:rows:description": "",
-            "stats:rows:include": True,
-            "stats:rows:label": "rows",
-            "stats:rows:value": 12345678,
-        }
+        assert columns[2].to_column_dict(omit_none=False) == self.expected_column(
+            {
+                "table_schema": relation.schema,
+                "table_name": relation.name,
+                "table_type": rel_type,
+                "column": "dt",
+                "column_index": 2,
+                "dtype": "date",
+                "stats:bytes:description": "",
+                "stats:bytes:include": True,
+                "stats:bytes:label": "bytes",
+                "stats:bytes:value": 1234567890,
+                "stats:rows:description": "",
+                "stats:rows:include": True,
+                "stats:rows:label": "rows",
+                "stats:rows:value": 12345678,
+            }
+        )

-        assert columns[3].to_column_dict(omit_none=False) == {
-            "table_database": None,
-            "table_schema": relation.schema,
-            "table_name": relation.name,
-            "table_type": rel_type,
-            "table_owner": "root",
-            "table_comment": None,
-            "column": "struct_col",
-            "column_index": 3,
-            "comment": None,
-            "dtype": "struct",
-            "numeric_scale": None,
-            "numeric_precision": None,
-            "char_size": None,
-            "stats:bytes:description": "",
-            "stats:bytes:include": True,
-            "stats:bytes:label": "bytes",
-            "stats:bytes:value": 1234567890,
-            "stats:rows:description": "",
-            "stats:rows:include": True,
-            "stats:rows:label": "rows",
-            "stats:rows:value": 12345678,
-        }
+        assert columns[3].to_column_dict(omit_none=False) == self.expected_column(
+            {
+                "table_schema": relation.schema,
+                "table_name": relation.name,
+                "table_type": rel_type,
+                "column": "struct_col",
+                "column_index": 3,
+                "dtype": "struct",
+                "stats:bytes:description": "",
+                "stats:bytes:include": True,
+                "stats:bytes:label": "bytes",
+                "stats:bytes:value": 1234567890,
+                "stats:rows:description": "",
+                "stats:rows:include": True,
+                "stats:rows:label": "rows",
+                "stats:rows:value": 12345678,
+            }
+        )

     def test_describe_table_extended_2048_char_limit(self):
         """GIVEN a list of table_names whos total character length exceeds 2048 characters
diff --git a/tests/unit/test_column.py b/tests/unit/test_column.py
index f0aa6562..57ea44fd 100644
--- a/tests/unit/test_column.py
+++ b/tests/unit/test_column.py
@@ -1,4 +1,5 @@
 import pytest
+from dbt_common.contracts.constraints import ColumnLevelConstraint, ConstraintType

 from dbt.adapters.databricks.column import DatabricksColumn

@@ -28,6 +29,87 @@ def test_convert_table_stats_with_bytes_and_rows(self):
         }


+class TestAddConstraint:
+    @pytest.fixture
+    def column(self):
+        return DatabricksColumn("id", "LONG")
+
+    def test_add_constraint__not_null(self, column):
+        column.add_constraint(ColumnLevelConstraint(type=ConstraintType.not_null))
+        assert column.not_null is True
+        assert column.constraints == []
+
+    def test_add_constraint__other_constraint(self, column):
+        constraint = ColumnLevelConstraint(type=ConstraintType.custom)
+        column.add_constraint(constraint)
+        assert column.not_null is False
+        assert column.constraints == [constraint]
+
+
+class TestEnrich:
+    @pytest.fixture
+    def column(self):
+        return DatabricksColumn("id", "INT")
+
+    @pytest.fixture
+    def model_column(self):
+        return {
+            "data_type": "LONG",
+            "description": "this is a column",
+            "constraints": [
+                {"type": "not_null"},
+                {"type": "primary_key", "name": "foo"},
+            ],
+        }
+
+    def test_enrich(self, column, model_column):
+        enriched_column = column.enrich(model_column)
+        assert enriched_column.data_type == "bigint"
+        assert enriched_column.comment == "this is a column"
+        assert enriched_column.not_null is True
+        assert enriched_column.constraints == [
+            ColumnLevelConstraint(type=ConstraintType.primary_key, name="foo")
+        ]
+
+
+class TestRenderForCreate:
+    @pytest.fixture
+    def column(self):
+        return DatabricksColumn("id", "INT")
+
+    def test_render_for_create__base(self, column):
+        assert column.render_for_create() == "id INT"
+
+    def test_render_for_create__not_null(self, column):
+        column.not_null = True
+        assert column.render_for_create() == "id INT NOT NULL"
+
+    def test_render_for_create__comment(self, column):
+        column.comment = "this is a column"
+        assert column.render_for_create() == "id INT COMMENT 'this is a column'"
+
+    def test_render_for_create__constraints(self, column):
+        column.constraints = [
+            ColumnLevelConstraint(type=ConstraintType.primary_key),
+        ]
+        assert column.render_for_create() == "id INT PRIMARY KEY"
+
+    def test_render_for_create__everything(self, column):
+        column.not_null = True
+        column.comment = "this is a column"
+        column.constraints = [
+            ColumnLevelConstraint(type=ConstraintType.primary_key),
+            ColumnLevelConstraint(type=ConstraintType.custom, expression="foo"),
+        ]
+        assert column.render_for_create() == (
+            "id INT NOT NULL COMMENT 'this is a column' "
+            "PRIMARY KEY foo"
+        )
+
+    def test_render_for_create__escaping(self, column):
+        column.comment = "this is a 'column'"
+        assert column.render_for_create() == "id INT COMMENT 'this is a \\'column\\''"
+
+
 class TestColumnStatics:
     @pytest.mark.parametrize(
         "column, expected",
diff --git a/tests/unit/test_constraints.py b/tests/unit/test_constraints.py
new file mode 100644
index 00000000..0b6fc593
--- /dev/null
+++ b/tests/unit/test_constraints.py
@@ -0,0 +1,306 @@
+from unittest.mock import patch
+
+import pytest
+from dbt_common.contracts.constraints import (
+    ColumnLevelConstraint,
+    ConstraintType,
+    ModelLevelConstraint,
+)
+from dbt_common.exceptions import DbtValidationError
+
+from dbt.adapters.databricks import constraints
+
+
+class TestConstraintsSupported:
+    @pytest.mark.parametrize(
+        "constraint_type, supported",
+        [
+            (ConstraintType.not_null, True),
+            (ConstraintType.unique, False),
+            (ConstraintType.primary_key, True),
+            (ConstraintType.foreign_key, True),
+            (ConstraintType.check, True),
+            (ConstraintType.custom, True),
+            ("invalid", False),
+        ],
+    )
+    def test_supported__expected(self, constraint_type, supported):
+        constraint = ColumnLevelConstraint(type=constraint_type)
+        assert constraints.is_supported(constraint) == supported
+
+
+class TestConstraintsEnforced:
+    @pytest.mark.parametrize(
+        "constraint_type, enforced",
+        [
+            (ConstraintType.not_null, True),
+            (ConstraintType.unique, False),
+            (ConstraintType.primary_key, False),
+            (ConstraintType.foreign_key, False),
+            (ConstraintType.check, True),
+            (ConstraintType.custom, False),
+            ("invalid", False),
+        ],
+    )
+    def test_enforced__expected(self, constraint_type, enforced):
+        constraint = ColumnLevelConstraint(type=constraint_type)
+        assert constraints.is_enforced(constraint) == enforced
+
+
+class TestParseConstraint:
+    @pytest.mark.parametrize("klass", [ColumnLevelConstraint, ModelLevelConstraint])
+    def test_parse_constraint__valid_column(self, klass):
+        raw_constraint = {"type": "not_null"}
+        constraint = constraints.parse_constraint(klass, raw_constraint)
+        assert isinstance(constraint, klass)
+        assert constraint.type == ConstraintType.not_null
+
+    @pytest.mark.parametrize("klass", [ColumnLevelConstraint, ModelLevelConstraint])
+    def test_parse_constraint__invalid_column(self, klass):
+        raw_constraint = {"type": None}
+        with pytest.raises(DbtValidationError, match="Could not parse constraint"):
+            constraints.parse_constraint(klass, raw_constraint)
+
+
+@pytest.fixture
+def constraint():
+    return ColumnLevelConstraint(type=ConstraintType.check)
+
+
+class TestProcessConstraint:
+    @pytest.fixture
+    def success(self):
+        return "SUCCESS"
+
+    @pytest.fixture
+    def render_map(self, constraint, success):
+        return {constraint.type: lambda _: success}
+
+    def test_process_constraint__valid_constraint(self, constraint, render_map, success):
+        assert constraints.process_constraint(constraint, lambda _: True, render_map) == success
+
+    def test_process_constraint__invalid_constraint(self, constraint, render_map):
+        assert constraints.process_constraint(constraint, lambda _: False, render_map) is None
+
+
+class TestValidateConstraint:
+    @pytest.fixture
+    def pk_constraint(self):
+        return ColumnLevelConstraint(
+            type=ConstraintType.primary_key, warn_unsupported=True, warn_unenforced=True
+        )
+
+    def test_validate_constraint__custom(self):
+        constraint = ColumnLevelConstraint(type=ConstraintType.custom)
+        assert constraints.validate_constraint(constraint, lambda _: False) is True
+
+    def test_validate_constraint__supported(self, pk_constraint):
+        assert constraints.validate_constraint(pk_constraint, lambda _: True) is True
+
+    @patch("dbt.adapters.databricks.constraints.warn_or_error")
+    def test_validate_constraint__unsupported(self, _warn_or_error, pk_constraint):
+        with patch("dbt.adapters.databricks.constraints.ConstraintNotSupported") as mock_warn:
+            assert constraints.validate_constraint(pk_constraint, lambda _: False) is False
+            mock_warn.assert_called_with(
+                constraint=pk_constraint.type.value, adapter="DatabricksAdapter"
+            )
+
+    @patch("dbt.adapters.databricks.constraints.warn_or_error")
+    def test_validate_constraint__unenforced(self, _warn_or_error, pk_constraint):
+        with patch("dbt.adapters.databricks.constraints.ConstraintNotEnforced") as mock_warn:
+            assert constraints.validate_constraint(pk_constraint, lambda _: True) is True
+            mock_warn.assert_called_with(
+                constraint=pk_constraint.type.value, adapter="DatabricksAdapter"
+            )
+
+
+class TestRenderConstraint:
+    @pytest.fixture
+    def success(self):
+        return "CHECK (1 = 1)"
+
+    @pytest.fixture
+    def render_map(self, constraint, success):
+        return {constraint.type: lambda _: success}
+
+    def test_render_constraint__valid_constraint(self, constraint, render_map, success):
+        assert constraints.render_constraint(constraint, render_map) == success
+
+    def test_render_constraint__invalid_constraint(self, render_map):
+        assert (
+            constraints.render_constraint(
+                ColumnLevelConstraint(type=ConstraintType.not_null), render_map
+            )
+            is None
+        )
+
+    def test_render_constraint__with_name(self, constraint, render_map, success):
+        constraint.name = "my_constraint"
+        assert (
+            constraints.render_constraint(constraint, render_map)
+            == f"CONSTRAINT {constraint.name} {success}"
{constraint.name} {success}" + ) + + def test_render_constraint__excess_gets_trimmed(self, constraint): + constraint.name = "my_constraint" + render_map = {constraint.type: lambda _: "CHECK (1 = 1) "} + assert ( + constraints.render_constraint(constraint, render_map) + == f"CONSTRAINT {constraint.name} CHECK (1 = 1)" + ) + + +class TestSupportedFor: + @pytest.fixture + def warning(self): + return "Warning for {type}" + + def test_supported_for__supported(self, constraint, warning): + assert constraints.supported_for(constraint, lambda _: True, warning) is True + + def test_supported_for__not_supported_in_base(self, warning): + constraint = ColumnLevelConstraint(type=ConstraintType.unique) + assert constraints.supported_for(constraint, lambda _: True, warning) is False + + def test_supported_for__not_supported_in_context(self, constraint, warning): + with patch("dbt.adapters.databricks.constraints.logger.warning") as mock_warn: + assert constraints.supported_for(constraint, lambda _: False, warning) is False + mock_warn.assert_called_with(warning.format(type=constraint.type)) + + +class TestRenderError: + @pytest.mark.parametrize( + "constraint, missing, expected", + [ + ( + ColumnLevelConstraint(type=ConstraintType.check), + [["expression"]], + "check constraint is missing required field(s): ('expression')", + ), + ( + ColumnLevelConstraint(type=ConstraintType.foreign_key), + [["expression"], ["to", "to_columns"]], + ( + "foreign_key constraint is missing required field(s): " + "('expression') or ('to', 'to_columns')" + ), + ), + ( + ColumnLevelConstraint(type=ConstraintType.primary_key, name="my_pk"), + [["expression"]], + "primary_key constraint my_pk is missing required field(s): ('expression')", + ), + ], + ) + def test_render_error__expected(self, constraint, missing, expected): + assert constraints.render_error(constraint, missing).msg == expected + + +class TestRenderCustom: + def test_render_custom__valid(self): + constraint = ColumnLevelConstraint(type=ConstraintType.custom, expression="1 = 1") + assert constraints.render_custom(constraint) == "1 = 1" + + def test_render_custom__invalid(self): + constraint = ColumnLevelConstraint(type=ConstraintType.custom) + with pytest.raises(DbtValidationError, match="custom constraint is missing required field"): + constraints.render_custom(constraint) + + +class TestRenderPrimaryKeyForColumn: + def test_render_primary_key_for_column__valid(self): + constraint = ColumnLevelConstraint(type=ConstraintType.primary_key, expression="DEFERRABLE") + assert constraints.render_primary_key_for_column(constraint) == "PRIMARY KEY DEFERRABLE" + + def test_render_primary_key_for_column__no_expression(self): + constraint = ColumnLevelConstraint(type=ConstraintType.primary_key) + assert constraints.render_primary_key_for_column(constraint) == "PRIMARY KEY" + + +class TestRenderForeignKeyForColumn: + def test_render_foreign_key_for_column__valid_to(self): + constraint = ColumnLevelConstraint( + type=ConstraintType.foreign_key, + to="other_table", + to_columns=["other_id"], + ) + assert ( + constraints.render_foreign_key_for_column(constraint) + == "FOREIGN KEY REFERENCES other_table (other_id)" + ) + + def test_render_foreign_key_for_column__valid_expression(self): + constraint = ColumnLevelConstraint( + type=ConstraintType.foreign_key, expression="references other_table (other_id)" + ) + assert ( + constraints.render_foreign_key_for_column(constraint) + == "FOREIGN KEY references other_table (other_id)" + ) + + def 
+        constraint = ColumnLevelConstraint(type=ConstraintType.foreign_key)
+        with pytest.raises(
+            DbtValidationError, match="foreign_key constraint is missing required field"
+        ):
+            constraints.render_foreign_key_for_column(constraint)
+
+
+class TestRenderPrimaryKeyForModel:
+    def test_render_primary_key_for_model__valid_columns(self):
+        constraint = ModelLevelConstraint(type=ConstraintType.primary_key, columns=["id", "ts"])
+        assert constraints.render_primary_key_for_model(constraint) == "PRIMARY KEY (id, ts)"
+
+    def test_render_primary_key_for_model__expression(self):
+        constraint = ModelLevelConstraint(
+            type=ConstraintType.primary_key, expression="(id TIMESERIES)"
+        )
+        assert constraints.render_primary_key_for_model(constraint) == "PRIMARY KEY (id TIMESERIES)"
+
+    def test_render_primary_key_for_model__invalid(self):
+        constraint = ModelLevelConstraint(type=ConstraintType.primary_key)
+        with pytest.raises(
+            DbtValidationError, match="primary_key constraint is missing required field"
+        ):
+            constraints.render_primary_key_for_model(constraint)
+
+
+class TestRenderForeignKeyForModel:
+    def test_render_foreign_key_for_model__valid_columns(self):
+        constraint = ModelLevelConstraint(
+            type=ConstraintType.foreign_key,
+            columns=["id"],
+            to="other_table",
+            to_columns=["other_id"],
+        )
+        assert (
+            constraints.render_foreign_key_for_model(constraint)
+            == "FOREIGN KEY (id) REFERENCES other_table (other_id)"
+        )
+
+    def test_render_foreign_key_for_model__expression(self):
+        constraint = ModelLevelConstraint(
+            type=ConstraintType.foreign_key, expression="references other_table (other_id)"
+        )
+        assert (
+            constraints.render_foreign_key_for_model(constraint)
+            == "FOREIGN KEY references other_table (other_id)"
+        )
+
+    def test_render_foreign_key_for_model__invalid(self):
+        constraint = ModelLevelConstraint(type=ConstraintType.foreign_key)
+        with pytest.raises(
+            DbtValidationError, match="foreign_key constraint is missing required field"
+        ):
+            constraints.render_foreign_key_for_model(constraint)
+
+
+class TestRenderCheck:
+    def test_render_check__valid(self):
+        constraint = ModelLevelConstraint(type=ConstraintType.check, expression="id > 0")
+        assert constraints.render_check(constraint) == "CHECK (id > 0)"
+
+    def test_render_check__invalid(self):
+        constraint = ModelLevelConstraint(type=ConstraintType.check)
+        with pytest.raises(DbtValidationError, match="check constraint is missing required field"):
+            constraints.render_check(constraint)
diff --git a/tests/unit/test_relation.py b/tests/unit/test_relation.py
index 5bde2fa5..ab15cecf 100644
--- a/tests/unit/test_relation.py
+++ b/tests/unit/test_relation.py
@@ -1,4 +1,5 @@
 import pytest
+from dbt_common.contracts.constraints import ConstraintType, ModelLevelConstraint

 from dbt.adapters.databricks import relation
 from dbt.adapters.databricks.relation import DatabricksQuotePolicy, DatabricksRelation
@@ -201,3 +202,60 @@ def test_is_hive_metastore(self, database, expected):
     )
     def test_extract_identifiers(self, input, expected):
         assert relation.extract_identifiers(input) == expected
+
+
+class TestConstraints:
+    @pytest.fixture
+    def relation(self):
+        return DatabricksRelation.create()
+
+    @pytest.fixture
+    def custom_constraint(self):
+        return ModelLevelConstraint(type=ConstraintType.custom, expression="a > 1")
+
+    @pytest.fixture
+    def check_constraint(self):
+        return ModelLevelConstraint(type=ConstraintType.check, expression="a > 1")
+
+    @pytest.fixture
+    def pk_constraint(self):
+        return ModelLevelConstraint(type=ConstraintType.primary_key, columns=["a"])
ModelLevelConstraint(type=ConstraintType.primary_key, columns=["a"]) + + def test_add_constraint__check_is_an_alter_constraint(self, relation, check_constraint): + relation.add_constraint(check_constraint) + assert relation.alter_constraints == [check_constraint] + assert relation.create_constraints == [] + + def test_add_constraint__other_constraints_are_create_constraints( + self, relation, check_constraint, custom_constraint, pk_constraint + ): + relation.add_constraint(check_constraint) + relation.add_constraint(custom_constraint) + relation.add_constraint(pk_constraint) + assert relation.alter_constraints == [check_constraint] + assert relation.create_constraints == [custom_constraint, pk_constraint] + + def test_enrich_relation__returns_a_copy(self, relation): + enriched = relation.enrich([]) + assert id(enriched) != id(relation) + + def test_enrich_relation__adds_constraints(self, relation, check_constraint, custom_constraint): + enriched = relation.enrich( + [{"type": "check", "expression": "a > 1"}, {"type": "custom", "expression": "a > 1"}] + ) + assert enriched.alter_constraints == [check_constraint] + assert enriched.create_constraints == [custom_constraint] + + def test_render_constraints_for_create__no_constraints(self, relation): + assert relation.render_constraints_for_create() == "" + + def test_render_constraints_for_create__check_is_ignored(self, relation, check_constraint): + relation.add_constraint(check_constraint) + assert relation.render_constraints_for_create() == "" + + def test_render_constraints_for_create__with_constraints( + self, relation, custom_constraint, pk_constraint + ): + relation.add_constraint(custom_constraint) + relation.add_constraint(pk_constraint) + assert relation.render_constraints_for_create() == "a > 1, PRIMARY KEY (a)"