Adding new constraint logic that will be used with V2 flag #846

Merged (13 commits, Dec 6, 2024)
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -4,6 +4,10 @@

- Introduced use_materialization_v2 flag for gating materialization revamps. ([844](https://github.com/databricks/dbt-databricks/pull/844))

### Under the Hood

- Implement new constraint logic for use_materialization_v2 flag ([846](https://github.com/databricks/dbt-databricks/pull/846/files))

## dbt-databricks 1.9.0 (TBD)

### Features
46 changes: 45 additions & 1 deletion dbt/adapters/databricks/column.py
@@ -1,14 +1,19 @@
from dataclasses import dataclass
from typing import Any, ClassVar
from typing import Optional

from dbt.adapters.spark.column import SparkColumn
from dbt_common.contracts.constraints import ColumnLevelConstraint
from dbt_common.contracts.constraints import ConstraintType
from dbt.adapters.databricks import constraints


@dataclass
class DatabricksColumn(SparkColumn):
table_comment: Optional[str] = None
comment: Optional[str] = None
not_null: Optional[bool] = None
constraints: Optional[list[ColumnLevelConstraint]] = None

TYPE_LABELS: ClassVar[dict[str, str]] = {
"LONG": "BIGINT",
@@ -27,5 +32,44 @@ def create(cls, name: str, label_or_dtype: str) -> "DatabricksColumn":
def data_type(self) -> str:
return self.translate_type(self.dtype)

    def add_constraint(self, constraint: ColumnLevelConstraint) -> None:
        # On first constraint add, initialize constraint details
        if self.constraints is None:
            self.constraints = []
            self.not_null = False
        if constraint.type == ConstraintType.not_null:
            self.not_null = True
        else:
            self.constraints.append(constraint)

    def enrich(self, model_column: dict[str, Any]) -> "DatabricksColumn":
        """Create a copy that incorporates model column metadata, including constraints."""
        data_type = model_column.get("data_type") or self.dtype
        enriched_column = DatabricksColumn.create(self.name, data_type)
        if model_column.get("description"):
            enriched_column.comment = model_column["description"]

        if model_column.get("constraints"):
            for constraint in model_column["constraints"]:
                parsed_constraint = constraints.parse_column_constraint(constraint)
                enriched_column.add_constraint(parsed_constraint)

        return enriched_column

    def render_for_create(self) -> str:
        """Renders the column for building a create statement."""
        column_str = f"{self.name} {self.dtype}"
        if self.not_null:
            column_str += " NOT NULL"
        if self.comment:
            comment = self.comment.replace("'", "\\'")
            column_str += f" COMMENT '{comment}'"
        for constraint in self.constraints or []:
            c = constraints.process_column_constraint(constraint)
            # process_column_constraint returns None for skipped constraints,
            # so guard against appending the literal string "None".
            if c:
                column_str += f" {c}"
        return column_str

def __repr__(self) -> str:
return "<DatabricksColumn {} ({})>".format(self.name, self.data_type)
233 changes: 233 additions & 0 deletions dbt/adapters/databricks/constraints.py
@@ -0,0 +1,233 @@
from typing import Any, Callable, Optional, TypeVar
from dbt_common.contracts.constraints import (
ColumnLevelConstraint,
ModelLevelConstraint,
ConstraintType,
)
from dbt_common.exceptions import DbtValidationError
from dbt_common.events.functions import warn_or_error
from dbt.adapters.base import ConstraintSupport
from dbt.adapters.events.types import ConstraintNotSupported, ConstraintNotEnforced
from dbt.adapters.databricks.logging import logger
from functools import partial


# Support constants
CONSTRAINT_SUPPORT = {
ConstraintType.check: ConstraintSupport.ENFORCED,
ConstraintType.not_null: ConstraintSupport.ENFORCED,
ConstraintType.unique: ConstraintSupport.NOT_SUPPORTED,
ConstraintType.primary_key: ConstraintSupport.NOT_ENFORCED,
ConstraintType.foreign_key: ConstraintSupport.NOT_ENFORCED,
}

Reviewer: It seems ConstraintType.custom is also supported?

Author: Yes, I'm on the fence on that one. I don't think it hurts anything to say that custom is supported, but it is not expected by the dbt framework (their code, like mine, skips validation on custom and doesn't include it in their constraint support map). The odd case is that we don't really know whether Databricks will support or enforce a custom constraint, because that depends on what the user does with it. Custom is basically a catch-all: if there is some new feature that's not in the adapter yet, the user can manually specify the full expression for it.

SUPPORTED_FOR_COLUMN = {
ConstraintType.custom,
ConstraintType.primary_key,
ConstraintType.foreign_key,
}

# Types
"""Generic type variable for constraints."""
T = TypeVar("T", bound=ColumnLevelConstraint)

"""Function type for checking constraint support."""
SupportFunc = Callable[[T], bool]

"""Function type for rendering constraints."""
RenderFunc = Callable[[T], str]


# Base support and enforcement
def is_supported(constraint: ColumnLevelConstraint) -> bool:
if constraint.type == ConstraintType.custom:
return True
if constraint.type in CONSTRAINT_SUPPORT:
return CONSTRAINT_SUPPORT[constraint.type] != ConstraintSupport.NOT_SUPPORTED
return False


def is_enforced(constraint: ColumnLevelConstraint) -> bool:
return constraint.type in CONSTRAINT_SUPPORT and CONSTRAINT_SUPPORT[constraint.type] not in [
ConstraintSupport.NOT_ENFORCED,
ConstraintSupport.NOT_SUPPORTED,
]
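The two predicates drive different warnings downstream: an unsupported constraint is skipped entirely, while a supported-but-unenforced one still renders. A minimal sketch of the matrix semantics, assuming `dbt_common` is importable:

```python
# Hypothetical check: primary keys are accepted by Databricks but are
# informational only; unique constraints are not supported at all.
from dbt_common.contracts.constraints import ColumnLevelConstraint, ConstraintType
from dbt.adapters.databricks import constraints

pk = ColumnLevelConstraint(type=ConstraintType.primary_key)
assert constraints.is_supported(pk)
assert not constraints.is_enforced(pk)

unique = ColumnLevelConstraint(type=ConstraintType.unique)
assert not constraints.is_supported(unique)
```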


# Core parsing, validation, and processing
def parse_constraint(klass: type[T], raw_constraint: dict[str, Any]) -> T:
    try:
        klass.validate(raw_constraint)
        return klass.from_dict(raw_constraint)
    except Exception as e:
        raise DbtValidationError(f"Could not parse constraint: {raw_constraint}") from e


def process_constraint(
constraint: T,
support_func: SupportFunc,
render_funcs: dict[ConstraintType, RenderFunc],
) -> Optional[str]:
if validate_constraint(constraint, support_func):
return render_constraint(constraint, render_funcs)

return None


def validate_constraint(constraint: T, support_func: SupportFunc) -> bool:
# Custom constraints are always supported
if constraint.type == ConstraintType.custom:
return True

supported = support_func(constraint)

if constraint.warn_unsupported and not supported:
warn_or_error(
ConstraintNotSupported(constraint=constraint.type.value, adapter="DatabricksAdapter")
)
elif constraint.warn_unenforced and not is_enforced(constraint):
warn_or_error(
ConstraintNotEnforced(constraint=constraint.type.value, adapter="DatabricksAdapter")
)

return supported


def render_constraint(
    constraint: T, render_funcs: dict[ConstraintType, RenderFunc]
) -> Optional[str]:
    rendered_constraint = ""

    if constraint.type in render_funcs:
        if constraint.name:
            rendered_constraint = f"CONSTRAINT {constraint.name} "
        rendered_constraint += render_funcs[constraint.type](constraint)

    rendered_constraint = rendered_constraint.strip()

    return rendered_constraint if rendered_constraint != "" else None

Reviewer: Does constraint always have a name? Do we need to validate and throw or skip? It seems Databricks always requires the SQL to specify CONSTRAINT with a name?

Author: No. If a constraint doesn't have a name, then (a) you don't need to specify CONSTRAINT before the constraint definition (that keyword basically tells DBSQL that the next token is a name rather than a constraint type), and (b) Databricks will generate one for you.
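To make that naming behavior concrete, a hedged sketch with invented constraint values:

```python
# Hypothetical sketch: a named constraint gets the CONSTRAINT prefix;
# an unnamed one renders bare and Databricks generates a name for it.
from dbt_common.contracts.constraints import ColumnLevelConstraint, ConstraintType
from dbt.adapters.databricks import constraints

render_map = constraints.model_constraint_map()

named = ColumnLevelConstraint(
    type=ConstraintType.check, name="id_positive", expression="id > 0"
)
unnamed = ColumnLevelConstraint(type=ConstraintType.check, expression="id > 0")

print(constraints.render_constraint(named, render_map))    # CONSTRAINT id_positive CHECK (id > 0)
print(constraints.render_constraint(unnamed, render_map))  # CHECK (id > 0)
```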


def supported_for(constraint: T, support_func: SupportFunc, warning: str) -> bool:
if is_supported(constraint) and not support_func(constraint):
logger.warning(warning.format(type=constraint.type))

return is_supported(constraint) and support_func(constraint)


# Shared render functions
def render_error(constraint: ColumnLevelConstraint, missing: list[list[str]]) -> DbtValidationError:
fields = " or ".join(["(" + ", ".join([f"'{e}'" for e in x]) + ")" for x in missing])
constraint_type = constraint.type.value
constraint_name = "" if not constraint.name else f"{constraint.name} "
return DbtValidationError(
f"{constraint_type} constraint {constraint_name}is missing required field(s): {fields}"
)


def render_custom(constraint: ColumnLevelConstraint) -> str:
assert constraint.type == ConstraintType.custom
if constraint.expression:
return constraint.expression
raise render_error(constraint, [["expression"]])
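Per the author's note above, custom is a catch-all where the user supplies the full expression. A hypothetical end-to-end sketch (invented name and expression):

```python
# Hypothetical sketch: custom constraints skip validation and render the
# user-supplied expression verbatim.
from dbt.adapters.databricks import constraints

raw = {"type": "custom", "name": "my_rule", "expression": "CHECK (id > 0)"}
parsed = constraints.parse_column_constraint(raw)

print(constraints.process_column_constraint(parsed))
# CONSTRAINT my_rule CHECK (id > 0)
```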


# ColumnLevelConstraint specialization
def parse_column_constraint(raw_constraint: dict[str, Any]) -> ColumnLevelConstraint:
return parse_constraint(ColumnLevelConstraint, raw_constraint)


COLUMN_WARNING = (
"While constraint of type {type} is supported for models, it is not supported for columns."
)

supported_for_columns = partial(
supported_for, support_func=lambda x: x.type in SUPPORTED_FOR_COLUMN, warning=COLUMN_WARNING
)


def column_constraint_map() -> dict[ConstraintType, RenderFunc]:
return {
ConstraintType.primary_key: render_primary_key_for_column,
ConstraintType.foreign_key: render_foreign_key_for_column,
ConstraintType.custom: render_custom,
}


def process_column_constraint(constraint: ColumnLevelConstraint) -> Optional[str]:
return process_constraint(constraint, supported_for_columns, column_constraint_map())


def render_primary_key_for_column(constraint: ColumnLevelConstraint) -> str:
    assert constraint.type == ConstraintType.primary_key
    rendered = "PRIMARY KEY"
    if constraint.expression:
        rendered += f" {constraint.expression}"
    return rendered

Reviewer: Does expression always exist? If it's empty, do we want to just return PRIMARY KEY without column info?

Author: Yes, because this is a column-level constraint, so it is perfectly valid to just state that the column is a PRIMARY KEY. In contrast, a model-level constraint needs to know which columns to treat as the primary key. The support for expression is there in case users want to add a constraint option such as DEFERRABLE.
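Reflecting that discussion, a sketch of both paths (hypothetical values):

```python
# Hypothetical sketch: a bare column-level PRIMARY KEY is valid as-is;
# the optional expression carries options such as DEFERRABLE.
from dbt_common.contracts.constraints import ColumnLevelConstraint, ConstraintType
from dbt.adapters.databricks.constraints import render_primary_key_for_column

bare = ColumnLevelConstraint(type=ConstraintType.primary_key)
with_options = ColumnLevelConstraint(
    type=ConstraintType.primary_key, expression="DEFERRABLE"
)

print(render_primary_key_for_column(bare))          # PRIMARY KEY
print(render_primary_key_for_column(with_options))  # PRIMARY KEY DEFERRABLE
```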


def render_foreign_key_for_column(constraint: ColumnLevelConstraint) -> str:
assert constraint.type == ConstraintType.foreign_key
rendered = "FOREIGN KEY"
if constraint.expression:
return rendered + f" {constraint.expression}"
elif constraint.to and constraint.to_columns:
return rendered + f" REFERENCES {constraint.to} ({', '.join(constraint.to_columns)})"
raise render_error(constraint, [["expression"], ["to", "to_columns"]])
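A sketch of the two column-level foreign-key paths (table and column names invented):

```python
# Hypothetical sketch: a column-level FK can be declared via to/to_columns
# or via a raw expression; both render to the same SQL here.
from dbt_common.contracts.constraints import ColumnLevelConstraint, ConstraintType
from dbt.adapters.databricks.constraints import render_foreign_key_for_column

by_fields = ColumnLevelConstraint(
    type=ConstraintType.foreign_key,
    to="main.dim.customers",
    to_columns=["customer_id"],
)
by_expression = ColumnLevelConstraint(
    type=ConstraintType.foreign_key,
    expression="REFERENCES main.dim.customers (customer_id)",
)

print(render_foreign_key_for_column(by_fields))
print(render_foreign_key_for_column(by_expression))
# Both: FOREIGN KEY REFERENCES main.dim.customers (customer_id)
```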


# ModelLevelConstraint specialization
def parse_model_constraint(raw_constraint: dict[str, Any]) -> ModelLevelConstraint:
return parse_constraint(ModelLevelConstraint, raw_constraint)


MODEL_WARNING = (
"While constraint of type {type} is supported for columns, it is not supported for models."
)


supported_for_models = partial(
supported_for, support_func=lambda x: x.type != ConstraintType.not_null, warning=MODEL_WARNING
)


def model_constraint_map() -> dict[ConstraintType, RenderFunc]:
return {
ConstraintType.primary_key: render_primary_key_for_model,
ConstraintType.foreign_key: render_foreign_key_for_model,
ConstraintType.custom: render_custom,
ConstraintType.check: render_check,
}


def process_model_constraint(constraint: ModelLevelConstraint) -> Optional[str]:
return process_constraint(constraint, supported_for_models, model_constraint_map())


def render_primary_key_for_model(constraint: ModelLevelConstraint) -> str:
prefix = render_primary_key_for_column(constraint)
if constraint.expression:
return prefix
if constraint.columns:
return f"{prefix} ({', '.join(constraint.columns)})"
raise render_error(constraint, [["columns"], ["expression"]])


def render_foreign_key_for_model(constraint: ModelLevelConstraint) -> str:
assert constraint.type == ConstraintType.foreign_key
rendered = "FOREIGN KEY"
if constraint.columns and constraint.to and constraint.to_columns:
columns = ", ".join(constraint.columns)
to_columns = ", ".join(constraint.to_columns)
return rendered + f" ({columns}) REFERENCES {constraint.to} ({to_columns})"
elif constraint.expression:
return rendered + " " + constraint.expression
raise render_error(constraint, [["expression"], ["columns", "to", "to_columns"]])


def render_check(constraint: ColumnLevelConstraint) -> str:
assert constraint.type == ConstraintType.check
if constraint.expression:
return f"CHECK ({constraint.expression})"

raise render_error(constraint, [["expression"]])
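Tying the model-level pieces together, a hypothetical parse-then-process sketch (invented relation names):

```python
# Hypothetical sketch: parse a raw model-level FK and process it; rendering
# succeeds, though a "not enforced" warning may be emitted along the way.
from dbt.adapters.databricks import constraints

fk = constraints.parse_model_constraint(
    {
        "type": "foreign_key",
        "columns": ["customer_id"],
        "to": "main.dim.customers",
        "to_columns": ["id"],
    }
)

print(constraints.process_model_constraint(fk))
# FOREIGN KEY (customer_id) REFERENCES main.dim.customers (id)
```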
21 changes: 21 additions & 0 deletions dbt/adapters/databricks/impl.py
@@ -31,6 +31,7 @@
from dbt.adapters.contracts.connection import Connection
from dbt.adapters.contracts.relation import RelationConfig
from dbt.adapters.contracts.relation import RelationType
from dbt.adapters.databricks import constraints
from dbt.adapters.databricks.behaviors.columns import (
GetColumnsBehavior,
GetColumnsByDescribe,
@@ -197,6 +198,8 @@ class DatabricksAdapter(SparkAdapter):
}
)

CONSTRAINT_SUPPORT = constraints.CONSTRAINT_SUPPORT

get_column_behavior: GetColumnsBehavior

def __init__(self, config: Any, mp_context: SpawnContext) -> None:
@@ -781,6 +784,24 @@ def get_config_from_model(self, model: RelationConfig) -> DatabricksRelationConf
def generate_unique_temporary_table_suffix(self, suffix_initial: str = "__dbt_tmp") -> str:
return f"{suffix_initial}_{str(uuid4())}"

@available
@staticmethod
def get_enriched_columns(
existing_columns: list[DatabricksColumn], model_columns: dict[str, dict[str, Any]]
) -> list[DatabricksColumn]:
"""Returns a list of columns that have been updated with features for table create."""
enriched_columns = []

for column in existing_columns:
if column.name in model_columns:
column_info = model_columns[column.name]
enriched_column = column.enrich(column_info)
enriched_columns.append(enriched_column)
else:
enriched_columns.append(column)

return enriched_columns
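In the materialization macros this is reached as `adapter.get_enriched_columns(...)` (it is marked `@available`); a direct sketch with invented columns:

```python
# Hypothetical sketch: only columns named in the model's schema are
# enriched; everything else passes through unchanged.
from dbt.adapters.databricks.column import DatabricksColumn
from dbt.adapters.databricks.impl import DatabricksAdapter

existing = [
    DatabricksColumn(column="id", dtype="bigint"),
    DatabricksColumn(column="payload", dtype="string"),
]
model_columns = {
    "id": {"description": "Surrogate key", "constraints": [{"type": "not_null"}]},
}

enriched = DatabricksAdapter.get_enriched_columns(existing, model_columns)
# enriched[0].render_for_create() -> "id bigint NOT NULL COMMENT 'Surrogate key'"
# enriched[1] is unchanged
```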


@dataclass(frozen=True)
class RelationAPIBase(ABC, Generic[DatabricksRelationConfig]):