Skip to content

Commit

Permalink
Move spec resolution into mapper
Browse files Browse the repository at this point in the history
  • Loading branch information
schrockn committed Jun 2, 2024
1 parent 9a30fe6 commit 4a40d8f
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -690,6 +690,7 @@ def inner(fn: Callable[..., Any]) -> AssetsDefinition:
ins=ins or {},
fn=fn,
op_name=op_name,
group_name=group_name,
)
else:
in_out_mapper = InOutMapper.non_spec_based_create(
Expand All @@ -701,6 +702,7 @@ def inner(fn: Callable[..., Any]) -> AssetsDefinition:
asset_deps=asset_deps,
deps_directly_passed_to_multi_asset=upstream_asset_deps,
can_subset=can_subset,
group_name=group_name,
)

check.invariant(
Expand All @@ -726,43 +728,6 @@ def inner(fn: Callable[..., Any]) -> AssetsDefinition:
code_version=code_version,
)(fn)

if specs:
resolved_specs = specs
else:
resolved_specs = []
input_deps_by_key = {
key: AssetDep(
asset=key, partition_mapping=in_out_mapper.partition_mappings.get(key)
)
for key in in_out_mapper.asset_keys_by_input_names.values()
}
input_deps = list(input_deps_by_key.values())
for output_name, asset_out in asset_out_map.items():
key = in_out_mapper.asset_keys_by_output_name[output_name]
if internal_asset_deps:
deps = [
input_deps_by_key.get(
dep_key,
AssetDep(
asset=dep_key,
partition_mapping=in_out_mapper.partition_mappings.get(key),
),
)
for dep_key in internal_asset_deps.get(output_name, [])
]
else:
deps = input_deps

resolved_specs.append(asset_out.to_spec(key, deps=deps))

if group_name:
check.invariant(
all(spec.group_name is None for spec in resolved_specs),
"Cannot set group_name parameter on multi_asset if one or more of the"
" AssetSpecs/AssetOuts supplied to this multi_asset have a group_name defined.",
)
resolved_specs = [spec._replace(group_name=group_name) for spec in resolved_specs]

return AssetsDefinition.dagster_internal_init(
keys_by_input_name=in_out_mapper.asset_keys_by_input_names,
keys_by_output_name=in_out_mapper.asset_keys_by_output_name,
Expand All @@ -775,7 +740,7 @@ def inner(fn: Callable[..., Any]) -> AssetsDefinition:
check_specs_by_output_name=in_out_mapper.check_specs_by_output_name,
selected_asset_check_keys=None, # no subselection in decorator
is_subset=False,
specs=resolved_specs,
specs=in_out_mapper.resolved_specs,
)

return inner
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,9 @@ def __init__(
internal_deps: Mapping[AssetKey, Set[AssetKey]],
can_subset: bool,
deps_directly_passed_to_multi_asset: Optional[Iterable[AssetDep]],
specs_directly_passed_to_multi_asset: Optional[Sequence[AssetSpec]],
spec_resolver: Callable[["InOutMapper"], Sequence[AssetSpec]],
op_name: str,
group_name: Optional[str] = None,
) -> None:
self.directly_passed_asset_ins = directly_passed_asset_ins
self._passed_input_tuples_by_asset_key = input_tuples_by_asset_key
Expand All @@ -223,8 +224,9 @@ def __init__(
self.internal_deps = internal_deps
self.can_subset = can_subset
self.deps_directly_passed_to_multi_asset = deps_directly_passed_to_multi_asset
self.specs_directly_passed_to_multi_asset = specs_directly_passed_to_multi_asset
self.spec_resolver = spec_resolver
self.op_name = op_name
self.group_name = group_name

@staticmethod
def from_specs(
Expand All @@ -235,6 +237,7 @@ def from_specs(
ins: Mapping[str, AssetIn],
fn: Callable[..., Any],
op_name: str,
group_name: Optional[str],
):
output_tuples_by_asset_key = {}
for asset_spec in specs:
Expand Down Expand Up @@ -289,8 +292,9 @@ def from_specs(
can_subset=can_subset,
# when specs are used deps are never passed to multi-asset
deps_directly_passed_to_multi_asset=None,
specs_directly_passed_to_multi_asset=specs,
spec_resolver=lambda _: specs,
op_name=op_name,
group_name=group_name,
)

@staticmethod
Expand All @@ -304,6 +308,7 @@ def non_spec_based_create(
check_specs: Sequence[AssetCheckSpec],
can_subset: bool,
op_name: str,
group_name: Optional[str],
):
inputs_tuples_by_asset_key = build_asset_ins(
fn,
Expand Down Expand Up @@ -349,6 +354,34 @@ def non_spec_based_create(
keys_by_output_name = make_keys_by_output_name(output_tuples_by_asset_key)
internal_deps = {keys_by_output_name[name]: asset_deps[name] for name in asset_deps}

def _spec_resolver(in_out_mapper: "InOutMapper") -> Sequence[AssetSpec]:
resolved_specs = []
input_deps_by_key = {
key: AssetDep(
asset=key, partition_mapping=in_out_mapper.partition_mappings.get(key)
)
for key in in_out_mapper.asset_keys_by_input_names.values()
}
input_deps = list(input_deps_by_key.values())
for output_name, asset_out in asset_out_map.items():
key = in_out_mapper.asset_keys_by_output_name[output_name]
if asset_deps:
deps = [
input_deps_by_key.get(
dep_key,
AssetDep(
asset=dep_key,
partition_mapping=in_out_mapper.partition_mappings.get(key),
),
)
for dep_key in asset_deps.get(output_name, [])
]
else:
deps = input_deps

resolved_specs.append(asset_out.to_spec(key, deps=deps))
return resolved_specs

return InOutMapper(
directly_passed_asset_ins=ins,
input_tuples_by_asset_key=inputs_tuples_by_asset_key,
Expand All @@ -357,8 +390,9 @@ def non_spec_based_create(
internal_deps=internal_deps,
can_subset=can_subset,
deps_directly_passed_to_multi_asset=deps_directly_passed_to_multi_asset,
specs_directly_passed_to_multi_asset=None,
spec_resolver=_spec_resolver,
op_name=op_name,
group_name=group_name,
)

@cached_property
Expand Down Expand Up @@ -467,6 +501,20 @@ def partition_mappings(self) -> Mapping[AssetKey, PartitionMapping]:
asset_name=self.op_name,
)

@cached_property
def resolved_specs(self) -> Sequence[AssetSpec]:
specs = self.spec_resolver(self)
if not self.group_name:
return specs

check.invariant(
all(spec.group_name is None for spec in specs),
"Cannot set group_name parameter on multi_asset if one or more of the"
" AssetSpecs/AssetOuts supplied to this multi_asset have a group_name defined.",
)

return [spec._replace(group_name=self.group_name) for spec in specs]


def validate_and_assign_output_names_to_check_specs(
check_specs: Optional[Sequence[AssetCheckSpec]], valid_asset_keys: Sequence[AssetKey]
Expand Down

0 comments on commit 4a40d8f

Please sign in to comment.