Skip to content

Commit

Permalink
Extract InOutMapper to begin refactoring AssetsDefinition constructor
Browse files Browse the repository at this point in the history
  • Loading branch information
schrockn committed Jun 2, 2024
1 parent b3a0338 commit b2757bf
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from dagster._core.definitions.asset_dep import AssetDep, CoercibleToAssetDep
from dagster._core.definitions.auto_materialize_policy import AutoMaterializePolicy
from dagster._core.definitions.config import ConfigMapping
from dagster._core.definitions.decorators.assets_definition_factory import InOutMapper
from dagster._core.definitions.freshness_policy import FreshnessPolicy
from dagster._core.definitions.metadata import ArbitraryMetadataMapping, RawMetadataMapping
from dagster._core.definitions.resource_annotation import (
Expand Down Expand Up @@ -748,19 +749,17 @@ def inner(fn: Callable[..., Any]) -> AssetsDefinition:
f" {list(valid_asset_deps)[:20]}",
)

in_out_mapper = InOutMapper.from_asset_ins_and_asset_outs(
asset_ins=asset_ins, asset_outs=output_tuples_by_asset_key
)

arg_resource_keys = {arg.name for arg in get_resource_args(fn)}
check.param_invariant(
len(bare_required_resource_keys or []) == 0 or len(arg_resource_keys) == 0,
"Cannot specify resource requirements in both @multi_asset decorator and as"
" arguments to the decorated function",
)

asset_outs_by_output_name: Mapping[str, Out] = dict(output_tuples_by_asset_key.values())
keys_by_output_name = {
output_name: asset_key
for asset_key, (output_name, _) in output_tuples_by_asset_key.items()
}

check_specs_by_output_name = _validate_and_assign_output_names_to_check_specs(
check_specs, list(output_tuples_by_asset_key.keys())
)
Expand All @@ -769,14 +768,14 @@ def inner(fn: Callable[..., Any]) -> AssetsDefinition:
for output_name in check_specs_by_output_name.keys()
}
overlapping_output_names = (
asset_outs_by_output_name.keys() & check_outs_by_output_name.keys()
in_out_mapper.asset_outs_by_output_name.keys() & check_outs_by_output_name.keys()
)
check.invariant(
len(overlapping_output_names) == 0,
f"Check output names overlap with asset output names: {overlapping_output_names}",
)
combined_outs_by_output_name: Mapping[str, Out] = {
**asset_outs_by_output_name,
**in_out_mapper.asset_outs_by_output_name,
**check_outs_by_output_name,
}

Expand All @@ -787,7 +786,9 @@ def inner(fn: Callable[..., Any]) -> AssetsDefinition:
if spec.deps is not None
}
else:
internal_deps = {keys_by_output_name[name]: asset_deps[name] for name in asset_deps}
internal_deps = {
in_out_mapper.keys_by_output_name[name]: asset_deps[name] for name in asset_deps
}

# when a subsettable multi-asset is defined, it is possible that it will need to be
# broken into two separate parts, one which depends on the other. in order to represent
Expand Down Expand Up @@ -842,7 +843,7 @@ def inner(fn: Callable[..., Any]) -> AssetsDefinition:
}
input_deps = list(input_deps_by_key.values())
for output_name, asset_out in asset_out_map.items():
key = keys_by_output_name[output_name]
key = in_out_mapper.keys_by_output_name[output_name]
if internal_asset_deps:
deps = [
input_deps_by_key.get(
Expand All @@ -866,7 +867,7 @@ def inner(fn: Callable[..., Any]) -> AssetsDefinition:

return AssetsDefinition.dagster_internal_init(
keys_by_input_name=keys_by_input_name,
keys_by_output_name=keys_by_output_name,
keys_by_output_name=in_out_mapper.keys_by_output_name,
node_def=op,
partitions_def=partitions_def,
can_subset=can_subset,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from functools import cached_property
from typing import Mapping, NamedTuple, Tuple

from dagster._core.definitions.asset_key import AssetKey
from dagster._core.definitions.input import In
from dagster._core.definitions.output import Out


class InMapping(NamedTuple):
    """Pairs an input name with the ``In`` definition that belongs to it.

    ``InOutMapper`` keeps one of these per ``AssetKey`` in its ``in_mappings``.
    """

    # Name of the input this entry describes.
    input_name: str
    # The ``In`` definition associated with ``input_name``.
    input: In


class OutMapping(NamedTuple):
    """Pairs an output name with the ``Out`` definition that belongs to it.

    ``InOutMapper`` keeps one of these per ``AssetKey`` in its ``out_mappings``.
    """

    # Name of the output this entry describes.
    output_name: str
    # The ``Out`` definition associated with ``output_name``.
    output: Out


class InOutMapper:
    """Holds the inputs and outputs of an asset definition, keyed by asset key.

    The mapper stores one ``InMapping`` per input asset key and one
    ``OutMapping`` per output asset key, and exposes derived lookups keyed by
    output name.

    The derived lookups are ``cached_property`` values: they are computed on
    first access and then reused, so ``out_mappings`` should not be mutated
    after the mapper has been constructed.
    """

    def __init__(
        self,
        in_mappings: Mapping[AssetKey, InMapping],
        out_mappings: Mapping[AssetKey, OutMapping],
    ) -> None:
        """Store the already-built mappings.

        Args:
            in_mappings: ``InMapping`` entries keyed by the asset key of each input.
            out_mappings: ``OutMapping`` entries keyed by the asset key of each output.
        """
        self.in_mappings = in_mappings
        self.out_mappings = out_mappings

    @classmethod
    def from_asset_ins_and_asset_outs(
        cls,
        asset_ins: Mapping[AssetKey, Tuple[str, In]],
        asset_outs: Mapping[AssetKey, Tuple[str, Out]],
    ) -> "InOutMapper":
        """Build a mapper from plain ``(name, definition)`` tuples.

        Args:
            asset_ins: ``(input_name, In)`` tuples keyed by asset key.
            asset_outs: ``(output_name, Out)`` tuples keyed by asset key.

        Returns:
            A new mapper whose entries are wrapped in ``InMapping`` /
            ``OutMapping`` so the fields can be read by name.
        """
        in_mappings = {
            asset_key: InMapping(input_name, in_)
            for asset_key, (input_name, in_) in asset_ins.items()
        }
        out_mappings = {
            asset_key: OutMapping(output_name, out_)
            for asset_key, (output_name, out_) in asset_outs.items()
        }
        # Instantiate through ``cls`` so the constructor does not hard-code the class name.
        return cls(in_mappings, out_mappings)

    @cached_property
    def asset_outs_by_output_name(self) -> Mapping[str, Out]:
        """``Out`` definitions keyed by output name."""
        # Read the named fields rather than relying on OutMapping's positional
        # (name, out) tuple layout, matching how keys_by_output_name is built.
        return {
            out_mapping.output_name: out_mapping.output
            for out_mapping in self.out_mappings.values()
        }

    @cached_property
    def keys_by_output_name(self) -> Mapping[str, AssetKey]:
        """Asset keys keyed by the name of the output that produces them."""
        return {
            out_mapping.output_name: asset_key
            for asset_key, out_mapping in self.out_mappings.items()
        }

0 comments on commit b2757bf

Please sign in to comment.