Skip to content

Commit

Permalink
Move check_specs_by_output_name into InOutMapper
Browse files Browse the repository at this point in the history
  • Loading branch information
schrockn committed Jun 2, 2024
1 parent b2757bf commit 96f3fc9
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 50 deletions.
2 changes: 1 addition & 1 deletion python_modules/dagster/dagster/_core/definitions/assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,7 +596,7 @@ def _from_node(
check_specs: Optional[Sequence[AssetCheckSpec]] = None,
owners_by_output_name: Optional[Mapping[str, Sequence[str]]] = None,
) -> "AssetsDefinition":
from dagster._core.definitions.decorators.asset_decorator import (
from dagster._core.definitions.decorators.assets_definition_factory import (
_assign_output_names_to_check_specs,
_validate_check_specs_target_relevant_asset_keys,
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from collections import Counter
from inspect import Parameter
from typing import (
AbstractSet,
Expand All @@ -25,7 +24,10 @@
from dagster._core.definitions.asset_dep import AssetDep, CoercibleToAssetDep
from dagster._core.definitions.auto_materialize_policy import AutoMaterializePolicy
from dagster._core.definitions.config import ConfigMapping
from dagster._core.definitions.decorators.assets_definition_factory import InOutMapper
from dagster._core.definitions.decorators.assets_definition_factory import (
InOutMapper,
validate_and_assign_output_names_to_check_specs,
)
from dagster._core.definitions.freshness_policy import FreshnessPolicy
from dagster._core.definitions.metadata import ArbitraryMetadataMapping, RawMetadataMapping
from dagster._core.definitions.resource_annotation import (
Expand Down Expand Up @@ -423,7 +425,7 @@ def __call__(self, fn: Callable[..., Any]) -> AssetsDefinition:
code_version=self.code_version,
)

check_specs_by_output_name = _validate_and_assign_output_names_to_check_specs(
check_specs_by_output_name = validate_and_assign_output_names_to_check_specs(
self.check_specs, [out_asset_key]
)
check_outs: Mapping[str, Out] = {
Expand Down Expand Up @@ -750,7 +752,9 @@ def inner(fn: Callable[..., Any]) -> AssetsDefinition:
)

in_out_mapper = InOutMapper.from_asset_ins_and_asset_outs(
asset_ins=asset_ins, asset_outs=output_tuples_by_asset_key
asset_ins=asset_ins,
asset_outs=output_tuples_by_asset_key,
check_specs=check_specs or [],
)

arg_resource_keys = {arg.name for arg in get_resource_args(fn)}
Expand All @@ -760,12 +764,9 @@ def inner(fn: Callable[..., Any]) -> AssetsDefinition:
" arguments to the decorated function",
)

check_specs_by_output_name = _validate_and_assign_output_names_to_check_specs(
check_specs, list(output_tuples_by_asset_key.keys())
)
check_outs_by_output_name: Mapping[str, Out] = {
output_name: Out(dagster_type=None, is_required=not can_subset)
for output_name in check_specs_by_output_name.keys()
for output_name in in_out_mapper.check_specs_by_output_name.keys()
}
overlapping_output_names = (
in_out_mapper.asset_outs_by_output_name.keys() & check_outs_by_output_name.keys()
Expand Down Expand Up @@ -874,7 +875,7 @@ def inner(fn: Callable[..., Any]) -> AssetsDefinition:
resource_defs=resource_defs,
backfill_policy=backfill_policy,
selected_asset_keys=None, # no subselection in decorator
check_specs_by_output_name=check_specs_by_output_name,
check_specs_by_output_name=in_out_mapper.check_specs_by_output_name,
selected_asset_check_keys=None, # no subselection in decorator
is_subset=False,
specs=resolved_specs,
Expand Down Expand Up @@ -1185,7 +1186,7 @@ def graph_asset_no_defaults(
if asset_in.partition_mapping
}

check_specs_by_output_name = _validate_and_assign_output_names_to_check_specs(
check_specs_by_output_name = validate_and_assign_output_names_to_check_specs(
check_specs, [out_asset_key]
)
check_outs_by_output_name: Mapping[str, GraphOut] = {
Expand Down Expand Up @@ -1287,7 +1288,7 @@ def inner(fn: Callable[..., Any]) -> AssetsDefinition:
}
asset_outs = build_asset_outs(outs)

check_specs_by_output_name = _validate_and_assign_output_names_to_check_specs(
check_specs_by_output_name = validate_and_assign_output_names_to_check_specs(
check_specs, list(asset_outs.keys())
)
check_outs_by_output_name: Mapping[str, GraphOut] = {
Expand Down Expand Up @@ -1430,39 +1431,3 @@ def make_asset_deps(deps: Optional[Iterable[CoercibleToAssetDep]]) -> Optional[I
dep_dict[asset_dep.asset_key] = asset_dep

return list(dep_dict.values())


def _assign_output_names_to_check_specs(
check_specs: Optional[Sequence[AssetCheckSpec]],
) -> Mapping[str, AssetCheckSpec]:
checks_by_output_name = {spec.get_python_identifier(): spec for spec in check_specs or []}
if check_specs and len(checks_by_output_name) != len(check_specs):
duplicates = {
item: count
for item, count in Counter(
[(spec.asset_key, spec.name) for spec in check_specs]
).items()
if count > 1
}

raise DagsterInvalidDefinitionError(f"Duplicate check specs: {duplicates}")

return checks_by_output_name


def _validate_check_specs_target_relevant_asset_keys(
check_specs: Optional[Sequence[AssetCheckSpec]], valid_asset_keys: Sequence[AssetKey]
) -> None:
for spec in check_specs or []:
if spec.asset_key not in valid_asset_keys:
raise DagsterInvalidDefinitionError(
f"Invalid asset key {spec.asset_key} in check spec {spec.name}. Must be one of"
f" {valid_asset_keys}"
)


def _validate_and_assign_output_names_to_check_specs(
check_specs: Optional[Sequence[AssetCheckSpec]], valid_asset_keys: Sequence[AssetKey]
) -> Mapping[str, AssetCheckSpec]:
_validate_check_specs_target_relevant_asset_keys(check_specs, valid_asset_keys)
return _assign_output_names_to_check_specs(check_specs)
Original file line number Diff line number Diff line change
@@ -1,9 +1,20 @@
from collections import Counter
from functools import cached_property
from typing import Mapping, NamedTuple, Tuple
from typing import (
Mapping,
NamedTuple,
Optional,
Sequence,
Set,
Tuple,
)

from dagster._core.definitions.asset_key import AssetKey
from dagster._core.definitions.input import In
from dagster._core.definitions.output import Out
from dagster._core.errors import DagsterInvalidDefinitionError

from ..asset_check_spec import AssetCheckSpec


class InMapping(NamedTuple):
Expand All @@ -21,14 +32,17 @@ def __init__(
self,
in_mappings: Mapping[AssetKey, InMapping],
out_mappings: Mapping[AssetKey, OutMapping],
check_specs: Sequence[AssetCheckSpec],
) -> None:
self.in_mappings = in_mappings
self.out_mappings = out_mappings
self.check_specs = check_specs

@staticmethod
def from_asset_ins_and_asset_outs(
asset_ins: Mapping[AssetKey, Tuple[str, In]],
asset_outs: Mapping[AssetKey, Tuple[str, Out]],
check_specs: Sequence[AssetCheckSpec],
):
in_mappings = {
asset_key: InMapping(input_name, in_)
Expand All @@ -38,7 +52,7 @@ def from_asset_ins_and_asset_outs(
asset_key: OutMapping(output_name, out_)
for asset_key, (output_name, out_) in asset_outs.items()
}
return InOutMapper(in_mappings, out_mappings)
return InOutMapper(in_mappings, out_mappings, check_specs)

@cached_property
def asset_outs_by_output_name(self) -> Mapping[str, Out]:
Expand All @@ -50,3 +64,49 @@ def keys_by_output_name(self) -> Mapping[str, AssetKey]:
out_mapping.output_name: asset_key
for asset_key, out_mapping in self.out_mappings.items()
}

@cached_property
def asset_keys(self) -> Set[AssetKey]:
return set(self.out_mappings.keys())

@cached_property
def check_specs_by_output_name(self) -> Mapping[str, AssetCheckSpec]:
return validate_and_assign_output_names_to_check_specs(
self.check_specs, list(self.asset_keys)
)


def validate_and_assign_output_names_to_check_specs(
check_specs: Optional[Sequence[AssetCheckSpec]], valid_asset_keys: Sequence[AssetKey]
) -> Mapping[str, AssetCheckSpec]:
_validate_check_specs_target_relevant_asset_keys(check_specs, valid_asset_keys)
return _assign_output_names_to_check_specs(check_specs)


def _assign_output_names_to_check_specs(
check_specs: Optional[Sequence[AssetCheckSpec]],
) -> Mapping[str, AssetCheckSpec]:
checks_by_output_name = {spec.get_python_identifier(): spec for spec in check_specs or []}
if check_specs and len(checks_by_output_name) != len(check_specs):
duplicates = {
item: count
for item, count in Counter(
[(spec.asset_key, spec.name) for spec in check_specs]
).items()
if count > 1
}

raise DagsterInvalidDefinitionError(f"Duplicate check specs: {duplicates}")

return checks_by_output_name


def _validate_check_specs_target_relevant_asset_keys(
check_specs: Optional[Sequence[AssetCheckSpec]], valid_asset_keys: Sequence[AssetKey]
) -> None:
for spec in check_specs or []:
if spec.asset_key not in valid_asset_keys:
raise DagsterInvalidDefinitionError(
f"Invalid asset key {spec.asset_key} in check spec {spec.name}. Must be one of"
f" {valid_asset_keys}"
)

0 comments on commit 96f3fc9

Please sign in to comment.