From 4a40d8ffe61f5c070629a8bf8854a8204b68f7de Mon Sep 17 00:00:00 2001 From: Nick Schrock Date: Sun, 2 Jun 2024 16:40:38 -0400 Subject: [PATCH] Move spec resolution into mapper --- .../definitions/decorators/asset_decorator.py | 41 +------------- .../decorators/assets_definition_factory.py | 56 +++++++++++++++++-- 2 files changed, 55 insertions(+), 42 deletions(-) diff --git a/python_modules/dagster/dagster/_core/definitions/decorators/asset_decorator.py b/python_modules/dagster/dagster/_core/definitions/decorators/asset_decorator.py index f2753bb1435de..a9ba3fbcd584e 100644 --- a/python_modules/dagster/dagster/_core/definitions/decorators/asset_decorator.py +++ b/python_modules/dagster/dagster/_core/definitions/decorators/asset_decorator.py @@ -690,6 +690,7 @@ def inner(fn: Callable[..., Any]) -> AssetsDefinition: ins=ins or {}, fn=fn, op_name=op_name, + group_name=group_name, ) else: in_out_mapper = InOutMapper.non_spec_based_create( @@ -701,6 +702,7 @@ def inner(fn: Callable[..., Any]) -> AssetsDefinition: asset_deps=asset_deps, deps_directly_passed_to_multi_asset=upstream_asset_deps, can_subset=can_subset, + group_name=group_name, ) check.invariant( @@ -726,43 +728,6 @@ def inner(fn: Callable[..., Any]) -> AssetsDefinition: code_version=code_version, )(fn) - if specs: - resolved_specs = specs - else: - resolved_specs = [] - input_deps_by_key = { - key: AssetDep( - asset=key, partition_mapping=in_out_mapper.partition_mappings.get(key) - ) - for key in in_out_mapper.asset_keys_by_input_names.values() - } - input_deps = list(input_deps_by_key.values()) - for output_name, asset_out in asset_out_map.items(): - key = in_out_mapper.asset_keys_by_output_name[output_name] - if internal_asset_deps: - deps = [ - input_deps_by_key.get( - dep_key, - AssetDep( - asset=dep_key, - partition_mapping=in_out_mapper.partition_mappings.get(key), - ), - ) - for dep_key in internal_asset_deps.get(output_name, []) - ] - else: - deps = input_deps - - resolved_specs.append(asset_out.to_spec(key, deps=deps)) - - if group_name: - check.invariant( - all(spec.group_name is None for spec in resolved_specs), - "Cannot set group_name parameter on multi_asset if one or more of the" - " AssetSpecs/AssetOuts supplied to this multi_asset have a group_name defined.", - ) - resolved_specs = [spec._replace(group_name=group_name) for spec in resolved_specs] - return AssetsDefinition.dagster_internal_init( keys_by_input_name=in_out_mapper.asset_keys_by_input_names, keys_by_output_name=in_out_mapper.asset_keys_by_output_name, @@ -775,7 +740,7 @@ def inner(fn: Callable[..., Any]) -> AssetsDefinition: check_specs_by_output_name=in_out_mapper.check_specs_by_output_name, selected_asset_check_keys=None, # no subselection in decorator is_subset=False, - specs=resolved_specs, + specs=in_out_mapper.resolved_specs, ) return inner diff --git a/python_modules/dagster/dagster/_core/definitions/decorators/assets_definition_factory.py b/python_modules/dagster/dagster/_core/definitions/decorators/assets_definition_factory.py index 8b7d1cb043cd9..ce85e827a4a48 100644 --- a/python_modules/dagster/dagster/_core/definitions/decorators/assets_definition_factory.py +++ b/python_modules/dagster/dagster/_core/definitions/decorators/assets_definition_factory.py @@ -213,8 +213,9 @@ def __init__( internal_deps: Mapping[AssetKey, Set[AssetKey]], can_subset: bool, deps_directly_passed_to_multi_asset: Optional[Iterable[AssetDep]], - specs_directly_passed_to_multi_asset: Optional[Sequence[AssetSpec]], + spec_resolver: Callable[["InOutMapper"], Sequence[AssetSpec]], op_name: str, + group_name: Optional[str] = None, ) -> None: self.directly_passed_asset_ins = directly_passed_asset_ins self._passed_input_tuples_by_asset_key = input_tuples_by_asset_key @@ -223,8 +224,9 @@ def __init__( self.internal_deps = internal_deps self.can_subset = can_subset self.deps_directly_passed_to_multi_asset = deps_directly_passed_to_multi_asset - self.specs_directly_passed_to_multi_asset = specs_directly_passed_to_multi_asset + self.spec_resolver = spec_resolver self.op_name = op_name + self.group_name = group_name @staticmethod def from_specs( @@ -235,6 +237,7 @@ def from_specs( ins: Mapping[str, AssetIn], fn: Callable[..., Any], op_name: str, + group_name: Optional[str], ): output_tuples_by_asset_key = {} for asset_spec in specs: @@ -289,8 +292,9 @@ def from_specs( can_subset=can_subset, # when specs are used deps are never passed to multi-asset deps_directly_passed_to_multi_asset=None, - specs_directly_passed_to_multi_asset=specs, + spec_resolver=lambda _: specs, op_name=op_name, + group_name=group_name, ) @staticmethod @@ -304,6 +308,7 @@ def non_spec_based_create( check_specs: Sequence[AssetCheckSpec], can_subset: bool, op_name: str, + group_name: Optional[str], ): inputs_tuples_by_asset_key = build_asset_ins( fn, @@ -349,6 +354,34 @@ def non_spec_based_create( keys_by_output_name = make_keys_by_output_name(output_tuples_by_asset_key) internal_deps = {keys_by_output_name[name]: asset_deps[name] for name in asset_deps} + def _spec_resolver(in_out_mapper: "InOutMapper") -> Sequence[AssetSpec]: + resolved_specs = [] + input_deps_by_key = { + key: AssetDep( + asset=key, partition_mapping=in_out_mapper.partition_mappings.get(key) + ) + for key in in_out_mapper.asset_keys_by_input_names.values() + } + input_deps = list(input_deps_by_key.values()) + for output_name, asset_out in asset_out_map.items(): + key = in_out_mapper.asset_keys_by_output_name[output_name] + if asset_deps: + deps = [ + input_deps_by_key.get( + dep_key, + AssetDep( + asset=dep_key, + partition_mapping=in_out_mapper.partition_mappings.get(key), + ), + ) + for dep_key in asset_deps.get(output_name, []) + ] + else: + deps = input_deps + + resolved_specs.append(asset_out.to_spec(key, deps=deps)) + return resolved_specs + return InOutMapper( directly_passed_asset_ins=ins, input_tuples_by_asset_key=inputs_tuples_by_asset_key, @@ -357,8 +390,9 @@ def non_spec_based_create( internal_deps=internal_deps, can_subset=can_subset, deps_directly_passed_to_multi_asset=deps_directly_passed_to_multi_asset, - specs_directly_passed_to_multi_asset=None, + spec_resolver=_spec_resolver, op_name=op_name, + group_name=group_name, ) @cached_property @@ -467,6 +501,20 @@ def partition_mappings(self) -> Mapping[AssetKey, PartitionMapping]: asset_name=self.op_name, ) + @cached_property + def resolved_specs(self) -> Sequence[AssetSpec]: + specs = self.spec_resolver(self) + if not self.group_name: + return specs + + check.invariant( + all(spec.group_name is None for spec in specs), + "Cannot set group_name parameter on multi_asset if one or more of the" + " AssetSpecs/AssetOuts supplied to this multi_asset have a group_name defined.", + ) + + return [spec._replace(group_name=self.group_name) for spec in specs] + def validate_and_assign_output_names_to_check_specs( check_specs: Optional[Sequence[AssetCheckSpec]], valid_asset_keys: Sequence[AssetKey]