From b1ec250cb64041bdb091cc5543ed6a6f4843bad6 Mon Sep 17 00:00:00 2001 From: Nick Schrock Date: Thu, 6 Jun 2024 04:10:47 -0400 Subject: [PATCH] Move spec resolution into mapper (#22229) ## Summary & Motivation We convert back and forth from non-spec code paths and then reconstruct the specs that *would* have been used to construct the `multi_asset`, which is a bit strange. This moves that logic into the factory machinery. ## How I Tested These Changes BK --- .../definitions/decorators/asset_decorator.py | 41 +------------- .../decorators/assets_definition_factory.py | 56 +++++++++++++++++-- 2 files changed, 55 insertions(+), 42 deletions(-) diff --git a/python_modules/dagster/dagster/_core/definitions/decorators/asset_decorator.py b/python_modules/dagster/dagster/_core/definitions/decorators/asset_decorator.py index 01aeb5553f0a1..638b5ac1b41d1 100644 --- a/python_modules/dagster/dagster/_core/definitions/decorators/asset_decorator.py +++ b/python_modules/dagster/dagster/_core/definitions/decorators/asset_decorator.py @@ -685,6 +685,7 @@ def inner(fn: Callable[..., Any]) -> AssetsDefinition: ins=ins or {}, fn=fn, op_name=op_name, + group_name=group_name, ) else: in_out_mapper = InOutMapper.from_asset_outs( @@ -696,6 +697,7 @@ def inner(fn: Callable[..., Any]) -> AssetsDefinition: asset_deps=asset_deps, deps_directly_passed_to_multi_asset=upstream_asset_deps, can_subset=can_subset, + group_name=group_name, ) check.invariant( @@ -721,43 +723,6 @@ def inner(fn: Callable[..., Any]) -> AssetsDefinition: code_version=code_version, )(fn) - if specs: - resolved_specs = specs - else: - resolved_specs = [] - input_deps_by_key = { - key: AssetDep( - asset=key, partition_mapping=in_out_mapper.partition_mappings.get(key) - ) - for key in in_out_mapper.asset_keys_by_input_names.values() - } - input_deps = list(input_deps_by_key.values()) - for output_name, asset_out in asset_out_map.items(): - key = in_out_mapper.asset_keys_by_output_name[output_name] - if internal_asset_deps: - deps = [ - input_deps_by_key.get( - dep_key, - AssetDep( - asset=dep_key, - partition_mapping=in_out_mapper.partition_mappings.get(key), - ), - ) - for dep_key in internal_asset_deps.get(output_name, []) - ] - else: - deps = input_deps - - resolved_specs.append(asset_out.to_spec(key, deps=deps)) - - if group_name: - check.invariant( - all(spec.group_name is None for spec in resolved_specs), - "Cannot set group_name parameter on multi_asset if one or more of the" - " AssetSpecs/AssetOuts supplied to this multi_asset have a group_name defined.", - ) - resolved_specs = [spec._replace(group_name=group_name) for spec in resolved_specs] - return AssetsDefinition.dagster_internal_init( keys_by_input_name=in_out_mapper.asset_keys_by_input_names, keys_by_output_name=in_out_mapper.asset_keys_by_output_name, @@ -770,7 +735,7 @@ def inner(fn: Callable[..., Any]) -> AssetsDefinition: check_specs_by_output_name=in_out_mapper.check_specs_by_output_name, selected_asset_check_keys=None, # no subselection in decorator is_subset=False, - specs=resolved_specs, + specs=in_out_mapper.resolved_specs, ) return inner diff --git a/python_modules/dagster/dagster/_core/definitions/decorators/assets_definition_factory.py b/python_modules/dagster/dagster/_core/definitions/decorators/assets_definition_factory.py index a1cceddf5fb94..3ccd2f78c79c0 100644 --- a/python_modules/dagster/dagster/_core/definitions/decorators/assets_definition_factory.py +++ b/python_modules/dagster/dagster/_core/definitions/decorators/assets_definition_factory.py @@ -190,8 +190,9 @@ def __init__( internal_deps: Mapping[AssetKey, Set[AssetKey]], can_subset: bool, deps_directly_passed_to_multi_asset: Optional[Iterable[AssetDep]], - specs_directly_passed_to_multi_asset: Optional[Sequence[AssetSpec]], + spec_resolver: Callable[["InOutMapper"], Sequence[AssetSpec]], op_name: str, + group_name: Optional[str] = None, ) -> None: self.directly_passed_asset_ins = directly_passed_asset_ins self._passed_input_tuples_by_asset_key = input_tuples_by_asset_key @@ -200,8 +201,9 @@ def __init__( self.internal_deps = internal_deps self.can_subset = can_subset self.deps_directly_passed_to_multi_asset = deps_directly_passed_to_multi_asset - self.specs_directly_passed_to_multi_asset = specs_directly_passed_to_multi_asset + self.spec_resolver = spec_resolver self.op_name = op_name + self.group_name = group_name @staticmethod def from_specs( @@ -212,6 +214,7 @@ def from_specs( ins: Mapping[str, AssetIn], fn: Callable[..., Any], op_name: str, + group_name: Optional[str], ): output_tuples_by_asset_key = {} for asset_spec in specs: @@ -266,8 +269,9 @@ def from_specs( can_subset=can_subset, # when specs are used deps are never passed to multi-asset deps_directly_passed_to_multi_asset=None, - specs_directly_passed_to_multi_asset=specs, + spec_resolver=lambda _: specs, op_name=op_name, + group_name=group_name, ) @staticmethod @@ -281,6 +285,7 @@ def from_asset_outs( check_specs: Sequence[AssetCheckSpec], can_subset: bool, op_name: str, + group_name: Optional[str], ): inputs_tuples_by_asset_key = build_asset_ins( fn, @@ -326,6 +331,34 @@ def from_asset_outs( keys_by_output_name = make_keys_by_output_name(output_tuples_by_asset_key) internal_deps = {keys_by_output_name[name]: asset_deps[name] for name in asset_deps} + def _spec_resolver(in_out_mapper: "InOutMapper") -> Sequence[AssetSpec]: + resolved_specs = [] + input_deps_by_key = { + key: AssetDep( + asset=key, partition_mapping=in_out_mapper.partition_mappings.get(key) + ) + for key in in_out_mapper.asset_keys_by_input_names.values() + } + input_deps = list(input_deps_by_key.values()) + for output_name, asset_out in asset_out_map.items(): + key = in_out_mapper.asset_keys_by_output_name[output_name] + if asset_deps: + deps = [ + input_deps_by_key.get( + dep_key, + AssetDep( + asset=dep_key, + partition_mapping=in_out_mapper.partition_mappings.get(key), + ), + ) + for dep_key in asset_deps.get(output_name, []) + ] + else: + deps = input_deps + + resolved_specs.append(asset_out.to_spec(key, deps=deps)) + return resolved_specs + return InOutMapper( directly_passed_asset_ins=ins, input_tuples_by_asset_key=inputs_tuples_by_asset_key, @@ -334,8 +367,9 @@ def from_asset_outs( internal_deps=internal_deps, can_subset=can_subset, deps_directly_passed_to_multi_asset=deps_directly_passed_to_multi_asset, - specs_directly_passed_to_multi_asset=None, + spec_resolver=_spec_resolver, op_name=op_name, + group_name=group_name, ) @cached_property @@ -440,6 +474,20 @@ def partition_mappings(self) -> Mapping[AssetKey, PartitionMapping]: asset_name=self.op_name, ) + @cached_property + def resolved_specs(self) -> Sequence[AssetSpec]: + specs = self.spec_resolver(self) + if not self.group_name: + return specs + + check.invariant( + all(spec.group_name is None for spec in specs), + "Cannot set group_name parameter on multi_asset if one or more of the" + " AssetSpecs/AssetOuts supplied to this multi_asset have a group_name defined.", + ) + + return [spec._replace(group_name=self.group_name) for spec in specs] + def validate_and_assign_output_names_to_check_specs( check_specs: Optional[Sequence[AssetCheckSpec]], valid_asset_keys: Sequence[AssetKey]