Skip to content

Commit

Permalink
Move spec resolution into mapper (#22229)
Browse files Browse the repository at this point in the history
## Summary & Motivation

We convert back and forth from non-spec code paths and then reconstruct the specs that *would* have been used to construct the `multi_asset`, which is a bit strange.

This moves that logic into the factory machinery.

## How I Tested These Changes

BK
  • Loading branch information
schrockn authored and salazarm committed Jun 10, 2024
1 parent 4cb1fd6 commit ff36c0c
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -685,6 +685,7 @@ def inner(fn: Callable[..., Any]) -> AssetsDefinition:
ins=ins or {},
fn=fn,
op_name=op_name,
group_name=group_name,
)
else:
in_out_mapper = InOutMapper.from_asset_outs(
Expand All @@ -696,6 +697,7 @@ def inner(fn: Callable[..., Any]) -> AssetsDefinition:
asset_deps=asset_deps,
deps_directly_passed_to_multi_asset=upstream_asset_deps,
can_subset=can_subset,
group_name=group_name,
)

check.invariant(
Expand All @@ -721,43 +723,6 @@ def inner(fn: Callable[..., Any]) -> AssetsDefinition:
code_version=code_version,
)(fn)

if specs:
resolved_specs = specs
else:
resolved_specs = []
input_deps_by_key = {
key: AssetDep(
asset=key, partition_mapping=in_out_mapper.partition_mappings.get(key)
)
for key in in_out_mapper.asset_keys_by_input_names.values()
}
input_deps = list(input_deps_by_key.values())
for output_name, asset_out in asset_out_map.items():
key = in_out_mapper.asset_keys_by_output_name[output_name]
if internal_asset_deps:
deps = [
input_deps_by_key.get(
dep_key,
AssetDep(
asset=dep_key,
partition_mapping=in_out_mapper.partition_mappings.get(key),
),
)
for dep_key in internal_asset_deps.get(output_name, [])
]
else:
deps = input_deps

resolved_specs.append(asset_out.to_spec(key, deps=deps))

if group_name:
check.invariant(
all(spec.group_name is None for spec in resolved_specs),
"Cannot set group_name parameter on multi_asset if one or more of the"
" AssetSpecs/AssetOuts supplied to this multi_asset have a group_name defined.",
)
resolved_specs = [spec._replace(group_name=group_name) for spec in resolved_specs]

return AssetsDefinition.dagster_internal_init(
keys_by_input_name=in_out_mapper.asset_keys_by_input_names,
keys_by_output_name=in_out_mapper.asset_keys_by_output_name,
Expand All @@ -770,7 +735,7 @@ def inner(fn: Callable[..., Any]) -> AssetsDefinition:
check_specs_by_output_name=in_out_mapper.check_specs_by_output_name,
selected_asset_check_keys=None, # no subselection in decorator
is_subset=False,
specs=resolved_specs,
specs=in_out_mapper.resolved_specs,
)

return inner
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,9 @@ def __init__(
internal_deps: Mapping[AssetKey, Set[AssetKey]],
can_subset: bool,
deps_directly_passed_to_multi_asset: Optional[Iterable[AssetDep]],
specs_directly_passed_to_multi_asset: Optional[Sequence[AssetSpec]],
spec_resolver: Callable[["InOutMapper"], Sequence[AssetSpec]],
op_name: str,
group_name: Optional[str] = None,
) -> None:
self.directly_passed_asset_ins = directly_passed_asset_ins
self._passed_input_tuples_by_asset_key = input_tuples_by_asset_key
Expand All @@ -200,8 +201,9 @@ def __init__(
self.internal_deps = internal_deps
self.can_subset = can_subset
self.deps_directly_passed_to_multi_asset = deps_directly_passed_to_multi_asset
self.specs_directly_passed_to_multi_asset = specs_directly_passed_to_multi_asset
self.spec_resolver = spec_resolver
self.op_name = op_name
self.group_name = group_name

@staticmethod
def from_specs(
Expand All @@ -212,6 +214,7 @@ def from_specs(
ins: Mapping[str, AssetIn],
fn: Callable[..., Any],
op_name: str,
group_name: Optional[str],
):
output_tuples_by_asset_key = {}
for asset_spec in specs:
Expand Down Expand Up @@ -266,8 +269,9 @@ def from_specs(
can_subset=can_subset,
# when specs are used deps are never passed to multi-asset
deps_directly_passed_to_multi_asset=None,
specs_directly_passed_to_multi_asset=specs,
spec_resolver=lambda _: specs,
op_name=op_name,
group_name=group_name,
)

@staticmethod
Expand All @@ -281,6 +285,7 @@ def from_asset_outs(
check_specs: Sequence[AssetCheckSpec],
can_subset: bool,
op_name: str,
group_name: Optional[str],
):
inputs_tuples_by_asset_key = build_asset_ins(
fn,
Expand Down Expand Up @@ -326,6 +331,34 @@ def from_asset_outs(
keys_by_output_name = make_keys_by_output_name(output_tuples_by_asset_key)
internal_deps = {keys_by_output_name[name]: asset_deps[name] for name in asset_deps}

def _spec_resolver(in_out_mapper: "InOutMapper") -> Sequence[AssetSpec]:
resolved_specs = []
input_deps_by_key = {
key: AssetDep(
asset=key, partition_mapping=in_out_mapper.partition_mappings.get(key)
)
for key in in_out_mapper.asset_keys_by_input_names.values()
}
input_deps = list(input_deps_by_key.values())
for output_name, asset_out in asset_out_map.items():
key = in_out_mapper.asset_keys_by_output_name[output_name]
if asset_deps:
deps = [
input_deps_by_key.get(
dep_key,
AssetDep(
asset=dep_key,
partition_mapping=in_out_mapper.partition_mappings.get(key),
),
)
for dep_key in asset_deps.get(output_name, [])
]
else:
deps = input_deps

resolved_specs.append(asset_out.to_spec(key, deps=deps))
return resolved_specs

return InOutMapper(
directly_passed_asset_ins=ins,
input_tuples_by_asset_key=inputs_tuples_by_asset_key,
Expand All @@ -334,8 +367,9 @@ def from_asset_outs(
internal_deps=internal_deps,
can_subset=can_subset,
deps_directly_passed_to_multi_asset=deps_directly_passed_to_multi_asset,
specs_directly_passed_to_multi_asset=None,
spec_resolver=_spec_resolver,
op_name=op_name,
group_name=group_name,
)

@cached_property
Expand Down Expand Up @@ -440,6 +474,20 @@ def partition_mappings(self) -> Mapping[AssetKey, PartitionMapping]:
asset_name=self.op_name,
)

@cached_property
def resolved_specs(self) -> Sequence[AssetSpec]:
specs = self.spec_resolver(self)
if not self.group_name:
return specs

check.invariant(
all(spec.group_name is None for spec in specs),
"Cannot set group_name parameter on multi_asset if one or more of the"
" AssetSpecs/AssetOuts supplied to this multi_asset have a group_name defined.",
)

return [spec._replace(group_name=self.group_name) for spec in specs]


def validate_and_assign_output_names_to_check_specs(
check_specs: Optional[Sequence[AssetCheckSpec]], valid_asset_keys: Sequence[AssetKey]
Expand Down

0 comments on commit ff36c0c

Please sign in to comment.