Skip to content

Commit

Permalink
Move check_specs_by_output_name into InOutMapper (#22223)
Browse files Browse the repository at this point in the history
## Summary & Motivation

A ton of the complexity in this codepath is bookkeeping: mapping checks and assets to their respective inputs and outputs in the underlying op. In this change, we move the logic that tracks which check specs correspond to which outputs into its own property on the mapper (`InOutMapper.check_specs_by_output_name`).

## How I Tested These Changes

BK (Buildkite CI run)
  • Loading branch information
schrockn authored and salazarm committed Jun 10, 2024
1 parent 09956c4 commit 146347b
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 53 deletions.
2 changes: 1 addition & 1 deletion python_modules/dagster/dagster/_core/definitions/assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,7 +597,7 @@ def _from_node(
check_specs: Optional[Sequence[AssetCheckSpec]] = None,
owners_by_output_name: Optional[Mapping[str, Sequence[str]]] = None,
) -> "AssetsDefinition":
from dagster._core.definitions.decorators.asset_decorator import (
from dagster._core.definitions.decorators.assets_definition_factory import (
_assign_output_names_to_check_specs,
_validate_check_specs_target_relevant_asset_keys,
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from collections import Counter
from inspect import Parameter
from typing import (
AbstractSet,
Expand All @@ -25,7 +24,10 @@
from dagster._core.definitions.asset_dep import AssetDep, CoercibleToAssetDep
from dagster._core.definitions.auto_materialize_policy import AutoMaterializePolicy
from dagster._core.definitions.config import ConfigMapping
from dagster._core.definitions.decorators.assets_definition_factory import InOutMapper
from dagster._core.definitions.decorators.assets_definition_factory import (
InOutMapper,
validate_and_assign_output_names_to_check_specs,
)
from dagster._core.definitions.freshness_policy import FreshnessPolicy
from dagster._core.definitions.metadata import ArbitraryMetadataMapping, RawMetadataMapping
from dagster._core.definitions.resource_annotation import get_resource_args
Expand Down Expand Up @@ -412,7 +414,7 @@ def __call__(self, fn: Callable[..., Any]) -> AssetsDefinition:
code_version=self.code_version,
)

check_specs_by_output_name = _validate_and_assign_output_names_to_check_specs(
check_specs_by_output_name = validate_and_assign_output_names_to_check_specs(
self.check_specs, [out_asset_key]
)
check_outs: Mapping[str, Out] = {
Expand Down Expand Up @@ -741,7 +743,9 @@ def inner(fn: Callable[..., Any]) -> AssetsDefinition:
)

in_out_mapper = InOutMapper.from_asset_ins_and_asset_outs(
asset_ins=asset_ins, asset_outs=output_tuples_by_asset_key
asset_ins=asset_ins,
asset_outs=output_tuples_by_asset_key,
check_specs=check_specs or [],
)

arg_resource_keys = {arg.name for arg in get_resource_args(fn)}
Expand All @@ -751,12 +755,9 @@ def inner(fn: Callable[..., Any]) -> AssetsDefinition:
" arguments to the decorated function",
)

check_specs_by_output_name = _validate_and_assign_output_names_to_check_specs(
check_specs, list(output_tuples_by_asset_key.keys())
)
check_outs_by_output_name: Mapping[str, Out] = {
output_name: Out(dagster_type=None, is_required=not can_subset)
for output_name in check_specs_by_output_name.keys()
for output_name in in_out_mapper.check_specs_by_output_name.keys()
}
overlapping_output_names = (
in_out_mapper.asset_outs_by_output_name.keys() & check_outs_by_output_name.keys()
Expand Down Expand Up @@ -865,7 +866,7 @@ def inner(fn: Callable[..., Any]) -> AssetsDefinition:
resource_defs=resource_defs,
backfill_policy=backfill_policy,
selected_asset_keys=None, # no subselection in decorator
check_specs_by_output_name=check_specs_by_output_name,
check_specs_by_output_name=in_out_mapper.check_specs_by_output_name,
selected_asset_check_keys=None, # no subselection in decorator
is_subset=False,
specs=resolved_specs,
Expand Down Expand Up @@ -1176,7 +1177,7 @@ def graph_asset_no_defaults(
if asset_in.partition_mapping
}

check_specs_by_output_name = _validate_and_assign_output_names_to_check_specs(
check_specs_by_output_name = validate_and_assign_output_names_to_check_specs(
check_specs, [out_asset_key]
)
check_outs_by_output_name: Mapping[str, GraphOut] = {
Expand Down Expand Up @@ -1278,7 +1279,7 @@ def inner(fn: Callable[..., Any]) -> AssetsDefinition:
}
asset_outs = build_asset_outs(outs)

check_specs_by_output_name = _validate_and_assign_output_names_to_check_specs(
check_specs_by_output_name = validate_and_assign_output_names_to_check_specs(
check_specs, list(asset_outs.keys())
)
check_outs_by_output_name: Mapping[str, GraphOut] = {
Expand Down Expand Up @@ -1428,39 +1429,3 @@ def make_asset_deps(deps: Optional[Iterable[CoercibleToAssetDep]]) -> Optional[I
dep_dict[asset_dep.asset_key] = asset_dep

return list(dep_dict.values())


def _assign_output_names_to_check_specs(
    check_specs: Optional[Sequence[AssetCheckSpec]],
) -> Mapping[str, AssetCheckSpec]:
    """Key each check spec by its Python identifier, which serves as its output name.

    Args:
        check_specs: Specs to index; ``None`` is treated like an empty sequence.

    Returns:
        A mapping from ``spec.get_python_identifier()`` to the spec, in input order.

    Raises:
        DagsterInvalidDefinitionError: if fewer distinct identifiers than specs were
            produced. The message lists every ``(asset_key, name)`` pair that occurs
            more than once, together with its count.
    """
    specs = check_specs or []

    checks_by_output_name = {}
    for spec in specs:
        checks_by_output_name[spec.get_python_identifier()] = spec

    # Every spec claimed its own identifier, so nothing collided.
    if len(checks_by_output_name) == len(specs):
        return checks_by_output_name

    occurrences = Counter((spec.asset_key, spec.name) for spec in specs)
    duplicates = {item: count for item, count in occurrences.items() if count > 1}
    raise DagsterInvalidDefinitionError(f"Duplicate check specs: {duplicates}")


def _validate_check_specs_target_relevant_asset_keys(
    check_specs: Optional[Sequence[AssetCheckSpec]], valid_asset_keys: Sequence[AssetKey]
) -> None:
    """Ensure every check spec targets one of the permitted asset keys.

    Args:
        check_specs: Specs to validate; ``None`` or an empty sequence is accepted.
        valid_asset_keys: The asset keys a spec is allowed to target.

    Raises:
        DagsterInvalidDefinitionError: on the first spec whose ``asset_key`` is not
            among ``valid_asset_keys``.
    """
    if not check_specs:
        return

    # Hash the keys once so each spec costs an O(1) lookup instead of a scan of the
    # whole sequence.
    allowed_keys = frozenset(valid_asset_keys)
    for spec in check_specs:
        if spec.asset_key not in allowed_keys:
            # Report the caller's original sequence so the message keeps its ordering.
            raise DagsterInvalidDefinitionError(
                f"Invalid asset key {spec.asset_key} in check spec {spec.name}. Must be one of"
                f" {valid_asset_keys}"
            )


def _validate_and_assign_output_names_to_check_specs(
    check_specs: Optional[Sequence[AssetCheckSpec]], valid_asset_keys: Sequence[AssetKey]
) -> Mapping[str, AssetCheckSpec]:
    """Validate the specs' asset keys, then key the specs by output name.

    Validation runs first, so a spec that targets an asset key outside
    ``valid_asset_keys`` is reported before any output names are assigned.
    """
    _validate_check_specs_target_relevant_asset_keys(
        check_specs=check_specs, valid_asset_keys=valid_asset_keys
    )
    checks_by_output_name = _assign_output_names_to_check_specs(check_specs)
    return checks_by_output_name
Original file line number Diff line number Diff line change
@@ -1,9 +1,20 @@
from collections import Counter
from functools import cached_property
from typing import Mapping, NamedTuple, Tuple
from typing import (
Mapping,
NamedTuple,
Optional,
Sequence,
Set,
Tuple,
)

from dagster._core.definitions.asset_key import AssetKey
from dagster._core.definitions.input import In
from dagster._core.definitions.output import Out
from dagster._core.errors import DagsterInvalidDefinitionError

from ..asset_check_spec import AssetCheckSpec


class InMapping(NamedTuple):
Expand All @@ -21,14 +32,17 @@ def __init__(
self,
in_mappings: Mapping[AssetKey, InMapping],
out_mappings: Mapping[AssetKey, OutMapping],
check_specs: Sequence[AssetCheckSpec],
) -> None:
self.in_mappings = in_mappings
self.out_mappings = out_mappings
self.check_specs = check_specs

@staticmethod
def from_asset_ins_and_asset_outs(
asset_ins: Mapping[AssetKey, Tuple[str, In]],
asset_outs: Mapping[AssetKey, Tuple[str, Out]],
check_specs: Sequence[AssetCheckSpec],
):
in_mappings = {
asset_key: InMapping(input_name, in_)
Expand All @@ -38,7 +52,7 @@ def from_asset_ins_and_asset_outs(
asset_key: OutMapping(output_name, out_)
for asset_key, (output_name, out_) in asset_outs.items()
}
return InOutMapper(in_mappings, out_mappings)
return InOutMapper(in_mappings, out_mappings, check_specs)

@cached_property
def asset_outs_by_output_name(self) -> Mapping[str, Out]:
Expand All @@ -50,3 +64,49 @@ def keys_by_output_name(self) -> Mapping[str, AssetKey]:
out_mapping.output_name: asset_key
for asset_key, out_mapping in self.out_mappings.items()
}

@cached_property
def asset_keys(self) -> Set[AssetKey]:
    """Return a new set holding every asset key that has an output mapping."""
    # Iterating a mapping yields its keys, so no explicit ``.keys()`` view is needed.
    return set(self.out_mappings)

@cached_property
def check_specs_by_output_name(self) -> Mapping[str, AssetCheckSpec]:
    """Map each check output name to the spec that produces it.

    The specs handed to this mapper are validated against the mapper's own asset
    keys and then keyed by output name. Because ``__init__`` only stores the specs,
    an invalid or duplicate spec is reported when this property is first read
    rather than when the mapper is constructed.
    """
    valid_asset_keys = list(self.asset_keys)
    return validate_and_assign_output_names_to_check_specs(self.check_specs, valid_asset_keys)


def validate_and_assign_output_names_to_check_specs(
    check_specs: Optional[Sequence[AssetCheckSpec]], valid_asset_keys: Sequence[AssetKey]
) -> Mapping[str, AssetCheckSpec]:
    """Validate check specs against the allowed asset keys and key them by output name.

    Args:
        check_specs: Specs to process; ``None`` is accepted.
        valid_asset_keys: Asset keys the specs are permitted to target.

    Returns:
        The specs keyed by output name.

    Validation happens before assignment, so a spec with an irrelevant asset key
    is reported ahead of any duplicate-spec error.
    """
    _validate_check_specs_target_relevant_asset_keys(
        check_specs=check_specs, valid_asset_keys=valid_asset_keys
    )
    specs_by_output_name = _assign_output_names_to_check_specs(check_specs)
    return specs_by_output_name


def _assign_output_names_to_check_specs(
    check_specs: Optional[Sequence[AssetCheckSpec]],
) -> Mapping[str, AssetCheckSpec]:
    """Build the output-name -> check-spec mapping for a collection of specs.

    Each spec's output name is its ``get_python_identifier()`` value.

    Args:
        check_specs: Specs to index; ``None`` or an empty sequence gives an empty mapping.

    Raises:
        DagsterInvalidDefinitionError: if two or more specs share an identifier. The
            message reports each ``(asset_key, name)`` pair seen more than once and
            how many times it was seen.
    """
    if not check_specs:
        return {}

    checks_by_output_name = dict((spec.get_python_identifier(), spec) for spec in check_specs)

    # A shorter mapping than the input means some identifiers overwrote each other.
    if len(checks_by_output_name) != len(check_specs):
        tally = Counter([(spec.asset_key, spec.name) for spec in check_specs])
        duplicates = {pair: seen for pair, seen in tally.items() if seen > 1}
        raise DagsterInvalidDefinitionError(f"Duplicate check specs: {duplicates}")

    return checks_by_output_name


def _validate_check_specs_target_relevant_asset_keys(
    check_specs: Optional[Sequence[AssetCheckSpec]], valid_asset_keys: Sequence[AssetKey]
) -> None:
    """Reject any check spec whose target asset key is not one of the valid keys.

    Args:
        check_specs: Specs to validate; ``None`` or an empty sequence passes trivially.
        valid_asset_keys: The asset keys that specs may target.

    Raises:
        DagsterInvalidDefinitionError: for the first spec, in input order, whose
            ``asset_key`` is not in ``valid_asset_keys``.
    """
    if not check_specs:
        return

    # Build the lookup set once; testing membership against the raw sequence would
    # rescan it for every spec.
    allowed_keys = frozenset(valid_asset_keys)
    for spec in check_specs:
        if spec.asset_key in allowed_keys:
            continue
        # The message shows the sequence as the caller passed it, not the derived set.
        raise DagsterInvalidDefinitionError(
            f"Invalid asset key {spec.asset_key} in check spec {spec.name}. Must be one of"
            f" {valid_asset_keys}"
        )
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@
_check as check,
define_asset_job,
)
from dagster._core.definitions.decorators.asset_decorator import (
_validate_and_assign_output_names_to_check_specs,
from dagster._core.definitions.decorators.assets_definition_factory import (
validate_and_assign_output_names_to_check_specs,
)
from dagster._core.definitions.metadata import TableMetadataSet
from dagster._utils.merger import merge_dicts
Expand Down Expand Up @@ -821,7 +821,7 @@ def get_asset_deps(

check_specs_by_output_name = cast(
Dict[str, AssetCheckSpec],
_validate_and_assign_output_names_to_check_specs(
validate_and_assign_output_names_to_check_specs(
list(check_specs_by_key.values()), list(asset_outs.keys())
),
)
Expand Down

0 comments on commit 146347b

Please sign in to comment.