Skip to content

Commit

Permalink
Move op creation into builder
Browse files Browse the repository at this point in the history
  • Loading branch information
schrockn committed Jun 9, 2024
1 parent a59180d commit cb2b629
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 39 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,16 @@
from typing import AbstractSet, Any, Callable, Iterable, Mapping, Optional, Sequence, Set, Union
from typing import (
TYPE_CHECKING,
AbstractSet,
Any,
Callable,
Dict,
Iterable,
Mapping,
Optional,
Sequence,
Set,
Union,
)

from typing_extensions import TypeAlias

Expand All @@ -14,6 +26,7 @@
from dagster._core.definitions.output import Out
from dagster._core.definitions.policy import RetryPolicy
from dagster._core.definitions.source_asset import SourceAsset
from dagster._core.definitions.utils import DEFAULT_OUTPUT
from dagster._core.errors import DagsterInvalidDefinitionError
from dagster._core.execution.build_resources import wrap_resources_for_execution
from dagster._core.storage.tags import COMPUTE_KIND_TAG
Expand All @@ -24,12 +37,16 @@
DecoratorAssetsDefinitionBuilder,
DecoratorAssetsDefinitionBuilderArgs,
NamedIn,
NamedOut,
build_named_ins,
compute_required_resource_keys,
get_function_params_without_context_or_config_or_resources,
)
from .op_decorator import _Op

if TYPE_CHECKING:
from dagster._core.definitions.base_asset_graph import AssetKeyOrCheckKey

AssetCheckFunctionReturn: TypeAlias = AssetCheckResult
AssetCheckFunction: TypeAlias = Callable[..., AssetCheckFunctionReturn]

Expand Down Expand Up @@ -208,7 +225,7 @@ def inner(fn: AssetCheckFunction) -> AssetChecksDefinition:

builder_args = DecoratorAssetsDefinitionBuilderArgs(
decorator_name="@asset_check",
name=name,
name=spec.get_python_identifier(),
description=description,
required_resource_keys=required_resource_keys or set(),
config_schema=config_schema,
Expand All @@ -232,33 +249,44 @@ def inner(fn: AssetCheckFunction) -> AssetChecksDefinition:
asset_out_map={},
)

named_outs_by_key: Dict[AssetKeyOrCheckKey, NamedOut] = {
spec.key: NamedOut(output_name=DEFAULT_OUTPUT, output=Out(dagster_type=None))
}

builder = DecoratorAssetsDefinitionBuilder(
named_ins_by_asset_key=named_in_by_asset_key,
named_outs_by_asset_key={},
named_outs_by_asset_graph_key=named_outs_by_key,
internal_deps={},
op_name=spec.get_python_identifier(),
args=builder_args,
fn=fn,
)

op_required_resource_keys = builder.required_resource_keys

out = Out(dagster_type=None)

op_def = _Op(
name=spec.get_python_identifier(),
ins=dict(named_in_by_asset_key.values()),
out=out,
# Any resource requirements specified as arguments will be identified as
# part of the Op definition instantiation
required_resource_keys=op_required_resource_keys,
tags={
**({COMPUTE_KIND_TAG: compute_kind} if compute_kind else {}),
**(op_tags or {}),
},
config_schema=config_schema,
retry_policy=retry_policy,
)(fn)
# op_required_resource_keys = builder.required_resource_keys

# out = Out(dagster_type=None)

# old_op_def = _Op(
# name=spec.get_python_identifier(),
# ins=dict(named_in_by_asset_key.values()),
# out=out,
# # Any resource requirements specified as arguments will be identified as
# # part of the Op definition instantiation
# required_resource_keys=op_required_resource_keys,
# tags={
# **({COMPUTE_KIND_TAG: compute_kind} if compute_kind else {}),
# **(op_tags or {}),
# },
# config_schema=config_schema,
# retry_policy=retry_policy,
# )(fn)

op_def = builder.create_op_definition()

# check.invariant(
# [od.name for od in old_op_def.output_defs] == [od.name for od in op_def.output_defs],
# f"Comparing {old_op_def.output_defs} to {op_def.output_defs}",
# )

return AssetChecksDefinition.create(
keys_by_input_name={
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
get_partition_mappings_from_deps,
)
from dagster._core.definitions.backfill_policy import BackfillPolicy
from dagster._core.definitions.base_asset_graph import AssetKeyOrCheckKey
from dagster._core.definitions.input import In
from dagster._core.definitions.op_definition import OpDefinition
from dagster._core.definitions.output import Out
Expand Down Expand Up @@ -152,7 +153,7 @@ def build_named_outs(asset_outs: Mapping[str, AssetOut]) -> Mapping[AssetKey, "N

def build_subsettable_named_ins(
asset_ins: Mapping[AssetKey, Tuple[str, In]],
asset_outs: Mapping[AssetKey, Tuple[str, Out]],
asset_outs: Mapping[AssetKeyOrCheckKey, Tuple[str, Out]],
internal_upstream_deps: Iterable[AbstractSet[AssetKey]],
) -> Mapping[AssetKey, "NamedIn"]:
"""Creates a mapping from AssetKey to (name of input, In object) for any asset key that is not
Expand All @@ -166,7 +167,8 @@ def build_subsettable_named_ins(
# set of asset keys which are upstream of another asset, and are not currently inputs
potential_deps = set().union(*internal_upstream_deps).difference(set(asset_ins.keys()))
return {
key: NamedIn(f"{ASSET_SUBSET_INPUT_PREFIX}{name}", In(Nothing))
# every key in potential_deps is expected to be an AssetKey for now; check.inst enforces that
check.inst(key, AssetKey): NamedIn(f"{ASSET_SUBSET_INPUT_PREFIX}{name}", In(Nothing))
for key, (name, _) in asset_outs.items()
if key in potential_deps
}
Expand Down Expand Up @@ -236,13 +238,13 @@ def __init__(
self,
*,
named_ins_by_asset_key: Mapping[AssetKey, NamedIn],
named_outs_by_asset_key: Mapping[AssetKey, NamedOut],
named_outs_by_asset_graph_key: Mapping[AssetKeyOrCheckKey, NamedOut],
internal_deps: Mapping[AssetKey, Set[AssetKey]],
op_name: str,
args: DecoratorAssetsDefinitionBuilderArgs,
fn: Callable[..., Any],
) -> None:
self.named_outs_by_asset_key = named_outs_by_asset_key
self.named_outs_by_asset_graph_key = named_outs_by_asset_graph_key
self.internal_deps = internal_deps
self.op_name = op_name
self.args = args
Expand All @@ -254,7 +256,7 @@ def __init__(
**named_ins_by_asset_key,
**build_subsettable_named_ins(
named_ins_by_asset_key,
named_outs_by_asset_key,
named_outs_by_asset_graph_key,
self.internal_deps.values(),
),
}
Expand Down Expand Up @@ -317,10 +319,10 @@ def from_multi_asset_specs(
) -> "DecoratorAssetsDefinitionBuilder":
check.param_invariant(passed_args.specs, "passed_args", "Must use specs in this codepath")

named_outs_by_asset_key: Mapping[AssetKey, NamedOut] = {}
named_outs_by_asset_graph_key: Mapping[AssetKeyOrCheckKey, NamedOut] = {}
for asset_spec in asset_specs:
output_name = asset_spec.key.to_python_identifier()
named_outs_by_asset_key[asset_spec.key] = NamedOut(
named_outs_by_asset_graph_key[asset_spec.key] = NamedOut(
output_name,
Out(
Nothing,
Expand All @@ -334,9 +336,12 @@ def from_multi_asset_specs(
upstream_keys = set()
for spec in asset_specs:
for dep in spec.deps:
if dep.asset_key not in named_outs_by_asset_key:
if dep.asset_key not in named_outs_by_asset_graph_key:
upstream_keys.add(dep.asset_key)
if dep.asset_key in named_outs_by_asset_key and dep.partition_mapping is not None:
if (
dep.asset_key in named_outs_by_asset_graph_key
and dep.partition_mapping is not None
):
# self-dependent asset also needs to be considered an upstream_key
upstream_keys.add(dep.asset_key)

Expand All @@ -357,9 +362,13 @@ def from_multi_asset_specs(
if spec.deps is not None
}

_validate_check_specs_target_relevant_asset_keys(
passed_args.check_specs, [spec.key for spec in asset_specs]
)

return DecoratorAssetsDefinitionBuilder(
named_ins_by_asset_key=named_ins_by_asset_key,
named_outs_by_asset_key=named_outs_by_asset_key,
named_outs_by_asset_graph_key=named_outs_by_asset_graph_key,
internal_deps=internal_deps,
op_name=op_name,
args=passed_args,
Expand Down Expand Up @@ -420,9 +429,13 @@ def from_asset_outs_in_asset_centric_decorator(
keys_by_output_name = make_keys_by_output_name(named_outs_by_asset_key)
internal_deps = {keys_by_output_name[name]: asset_deps[name] for name in asset_deps}

_validate_check_specs_target_relevant_asset_keys(
passed_args.check_specs, list(named_outs_by_asset_key.keys())
)

return DecoratorAssetsDefinitionBuilder(
named_ins_by_asset_key=named_ins_by_asset_key,
named_outs_by_asset_key=named_outs_by_asset_key,
named_outs_by_asset_graph_key=named_outs_by_asset_key, # type: ignore
internal_deps=internal_deps,
op_name=op_name,
args=passed_args,
Expand All @@ -435,7 +448,7 @@ def group_name(self) -> Optional[str]:

@cached_property
def outs_by_output_name(self) -> Mapping[str, Out]:
return dict(self.named_outs_by_asset_key.values())
return dict(self.named_outs_by_asset_graph_key.values())

@cached_property
def asset_keys_by_input_name(self) -> Mapping[str, AssetKey]:
Expand All @@ -447,13 +460,16 @@ def asset_keys_by_input_name(self) -> Mapping[str, AssetKey]:
@cached_property
def asset_keys_by_output_name(self) -> Mapping[str, AssetKey]:
return {
out_mapping.output_name: asset_key
for asset_key, out_mapping in self.named_outs_by_asset_key.items()
out_mapping.output_name: kry
for kry, out_mapping in self.named_outs_by_asset_graph_key.items()
if isinstance(kry, AssetKey)
}

@cached_property
def asset_keys(self) -> Set[AssetKey]:
return set(self.named_outs_by_asset_key.keys())
return {
key for key in self.named_outs_by_asset_graph_key.keys() if isinstance(key, AssetKey)
}

@cached_property
def check_specs_by_output_name(self) -> Mapping[str, AssetCheckSpec]:
Expand All @@ -470,6 +486,9 @@ def check_outs_by_output_name(self) -> Mapping[str, Out]:

@cached_property
def combined_outs_by_output_name(self) -> Mapping[str, Out]:
if self.args.decorator_name == "@asset_check":
return self.outs_by_output_name

return {
**self.outs_by_output_name,
**self.check_outs_by_output_name,
Expand Down Expand Up @@ -516,7 +535,7 @@ def required_resource_keys(self) -> AbstractSet[str]:
decorator_name=self.args.decorator_name,
)

def _create_op_definition(self) -> OpDefinition:
def create_op_definition(self) -> OpDefinition:
return _Op(
name=self.op_name,
description=self.args.description,
Expand All @@ -536,7 +555,7 @@ def create_assets_definition(self) -> AssetsDefinition:
return AssetsDefinition.dagster_internal_init(
keys_by_input_name=self.asset_keys_by_input_names,
keys_by_output_name=self.asset_keys_by_output_name,
node_def=self._create_op_definition(),
node_def=self.create_op_definition(),
partitions_def=self.args.partitions_def,
can_subset=self.args.can_subset,
resource_defs=self.args.assets_def_resource_defs,
Expand Down Expand Up @@ -595,7 +614,7 @@ def _synthesize_specs(self) -> Sequence[AssetSpec]:
def validate_and_assign_output_names_to_check_specs(
check_specs: Optional[Sequence[AssetCheckSpec]], valid_asset_keys: Sequence[AssetKey]
) -> Mapping[str, AssetCheckSpec]:
_validate_check_specs_target_relevant_asset_keys(check_specs, valid_asset_keys)
# _validate_check_specs_target_relevant_asset_keys(check_specs, valid_asset_keys)
return _assign_output_names_to_check_specs(check_specs)


Expand Down

0 comments on commit cb2b629

Please sign in to comment.