Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move spec resolution into mapper #22229

Merged
merged 1 commit into from
Jun 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -685,6 +685,7 @@ def inner(fn: Callable[..., Any]) -> AssetsDefinition:
ins=ins or {},
fn=fn,
op_name=op_name,
group_name=group_name,
)
else:
in_out_mapper = InOutMapper.from_asset_outs(
Expand All @@ -696,6 +697,7 @@ def inner(fn: Callable[..., Any]) -> AssetsDefinition:
asset_deps=asset_deps,
deps_directly_passed_to_multi_asset=upstream_asset_deps,
can_subset=can_subset,
group_name=group_name,
)

check.invariant(
Expand All @@ -721,43 +723,6 @@ def inner(fn: Callable[..., Any]) -> AssetsDefinition:
code_version=code_version,
)(fn)

if specs:
resolved_specs = specs
else:
resolved_specs = []
input_deps_by_key = {
key: AssetDep(
asset=key, partition_mapping=in_out_mapper.partition_mappings.get(key)
)
for key in in_out_mapper.asset_keys_by_input_names.values()
}
input_deps = list(input_deps_by_key.values())
for output_name, asset_out in asset_out_map.items():
key = in_out_mapper.asset_keys_by_output_name[output_name]
if internal_asset_deps:
deps = [
input_deps_by_key.get(
dep_key,
AssetDep(
asset=dep_key,
partition_mapping=in_out_mapper.partition_mappings.get(key),
),
)
for dep_key in internal_asset_deps.get(output_name, [])
]
else:
deps = input_deps

resolved_specs.append(asset_out.to_spec(key, deps=deps))

if group_name:
check.invariant(
all(spec.group_name is None for spec in resolved_specs),
"Cannot set group_name parameter on multi_asset if one or more of the"
" AssetSpecs/AssetOuts supplied to this multi_asset have a group_name defined.",
)
resolved_specs = [spec._replace(group_name=group_name) for spec in resolved_specs]

return AssetsDefinition.dagster_internal_init(
keys_by_input_name=in_out_mapper.asset_keys_by_input_names,
keys_by_output_name=in_out_mapper.asset_keys_by_output_name,
Expand All @@ -770,7 +735,7 @@ def inner(fn: Callable[..., Any]) -> AssetsDefinition:
check_specs_by_output_name=in_out_mapper.check_specs_by_output_name,
selected_asset_check_keys=None, # no subselection in decorator
is_subset=False,
specs=resolved_specs,
specs=in_out_mapper.resolved_specs,
)

return inner
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,9 @@ def __init__(
internal_deps: Mapping[AssetKey, Set[AssetKey]],
can_subset: bool,
deps_directly_passed_to_multi_asset: Optional[Iterable[AssetDep]],
specs_directly_passed_to_multi_asset: Optional[Sequence[AssetSpec]],
spec_resolver: Callable[["InOutMapper"], Sequence[AssetSpec]],
op_name: str,
group_name: Optional[str] = None,
) -> None:
self.directly_passed_asset_ins = directly_passed_asset_ins
self._passed_input_tuples_by_asset_key = input_tuples_by_asset_key
Expand All @@ -200,8 +201,9 @@ def __init__(
self.internal_deps = internal_deps
self.can_subset = can_subset
self.deps_directly_passed_to_multi_asset = deps_directly_passed_to_multi_asset
self.specs_directly_passed_to_multi_asset = specs_directly_passed_to_multi_asset
self.spec_resolver = spec_resolver
self.op_name = op_name
self.group_name = group_name

@staticmethod
def from_specs(
Expand All @@ -212,6 +214,7 @@ def from_specs(
ins: Mapping[str, AssetIn],
fn: Callable[..., Any],
op_name: str,
group_name: Optional[str],
):
output_tuples_by_asset_key = {}
for asset_spec in specs:
Expand Down Expand Up @@ -266,8 +269,9 @@ def from_specs(
can_subset=can_subset,
# when specs are used deps are never passed to multi-asset
deps_directly_passed_to_multi_asset=None,
specs_directly_passed_to_multi_asset=specs,
spec_resolver=lambda _: specs,
op_name=op_name,
group_name=group_name,
)

@staticmethod
Expand All @@ -281,6 +285,7 @@ def from_asset_outs(
check_specs: Sequence[AssetCheckSpec],
can_subset: bool,
op_name: str,
group_name: Optional[str],
):
inputs_tuples_by_asset_key = build_asset_ins(
fn,
Expand Down Expand Up @@ -326,6 +331,34 @@ def from_asset_outs(
keys_by_output_name = make_keys_by_output_name(output_tuples_by_asset_key)
internal_deps = {keys_by_output_name[name]: asset_deps[name] for name in asset_deps}

def _spec_resolver(in_out_mapper: "InOutMapper") -> Sequence[AssetSpec]:
resolved_specs = []
input_deps_by_key = {
key: AssetDep(
asset=key, partition_mapping=in_out_mapper.partition_mappings.get(key)
)
for key in in_out_mapper.asset_keys_by_input_names.values()
}
input_deps = list(input_deps_by_key.values())
for output_name, asset_out in asset_out_map.items():
key = in_out_mapper.asset_keys_by_output_name[output_name]
if asset_deps:
deps = [
input_deps_by_key.get(
dep_key,
AssetDep(
asset=dep_key,
partition_mapping=in_out_mapper.partition_mappings.get(key),
),
)
for dep_key in asset_deps.get(output_name, [])
]
else:
deps = input_deps

resolved_specs.append(asset_out.to_spec(key, deps=deps))
return resolved_specs
Comment on lines +334 to +360
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We move this strange-ish callback into a more standard up the stack.


return InOutMapper(
directly_passed_asset_ins=ins,
input_tuples_by_asset_key=inputs_tuples_by_asset_key,
Expand All @@ -334,8 +367,9 @@ def from_asset_outs(
internal_deps=internal_deps,
can_subset=can_subset,
deps_directly_passed_to_multi_asset=deps_directly_passed_to_multi_asset,
specs_directly_passed_to_multi_asset=None,
spec_resolver=_spec_resolver,
op_name=op_name,
group_name=group_name,
)

@cached_property
Expand Down Expand Up @@ -440,6 +474,20 @@ def partition_mappings(self) -> Mapping[AssetKey, PartitionMapping]:
asset_name=self.op_name,
)

@cached_property
def resolved_specs(self) -> Sequence[AssetSpec]:
specs = self.spec_resolver(self)
if not self.group_name:
return specs

check.invariant(
all(spec.group_name is None for spec in specs),
"Cannot set group_name parameter on multi_asset if one or more of the"
" AssetSpecs/AssetOuts supplied to this multi_asset have a group_name defined.",
)

return [spec._replace(group_name=self.group_name) for spec in specs]


def validate_and_assign_output_names_to_check_specs(
check_specs: Optional[Sequence[AssetCheckSpec]], valid_asset_keys: Sequence[AssetKey]
Expand Down