Skip to content

Commit

Permalink
partitin mappings property
Browse files Browse the repository at this point in the history
  • Loading branch information
schrockn committed Jun 2, 2024
1 parent c6c023f commit f483de3
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -676,6 +676,7 @@ def inner(fn: Callable[..., Any]) -> AssetsDefinition:
can_subset=can_subset,
ins=ins or {},
fn=fn,
op_name=op_name,
)
else:
in_out_mapper = InOutMapper.non_spec_based_create(
Expand All @@ -685,7 +686,7 @@ def inner(fn: Callable[..., Any]) -> AssetsDefinition:
check_specs=check_specs or [],
op_name=op_name,
asset_deps=asset_deps,
upstream_asset_deps=upstream_asset_deps or [],
deps_directly_passed_to_multi_asset=upstream_asset_deps,
can_subset=can_subset,
)

Expand Down Expand Up @@ -719,23 +720,14 @@ def inner(fn: Callable[..., Any]) -> AssetsDefinition:
code_version=code_version,
)(fn)

partition_mappings = {
in_out_mapper.keys_by_input_name[input_name]: asset_in.partition_mapping
for input_name, asset_in in (ins or {}).items()
if asset_in.partition_mapping is not None
}

if upstream_asset_deps:
partition_mappings = get_partition_mappings_from_deps(
partition_mappings=partition_mappings, deps=upstream_asset_deps, asset_name=op_name
)

if specs:
resolved_specs = specs
else:
resolved_specs = []
input_deps_by_key = {
key: AssetDep(asset=key, partition_mapping=partition_mappings.get(key))
key: AssetDep(
asset=key, partition_mapping=in_out_mapper.partition_mappings.get(key)
)
for key in in_out_mapper.keys_by_input_name.values()
}
input_deps = list(input_deps_by_key.values())
Expand All @@ -745,7 +737,10 @@ def inner(fn: Callable[..., Any]) -> AssetsDefinition:
deps = [
input_deps_by_key.get(
dep_key,
AssetDep(asset=dep_key, partition_mapping=partition_mappings.get(key)),
AssetDep(
asset=dep_key,
partition_mapping=in_out_mapper.partition_mappings.get(key),
),
)
for dep_key in internal_asset_deps.get(output_name, [])
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from dagster._core.definitions.assets import ASSET_SUBSET_INPUT_PREFIX
from dagster._core.definitions.input import In
from dagster._core.definitions.output import Out
from dagster._core.definitions.partition_mapping import PartitionMapping
from dagster._core.definitions.resource_annotation import (
get_resource_args,
)
Expand Down Expand Up @@ -163,6 +164,28 @@ def build_subsettable_asset_ins(
}


def get_partition_mappings_from_deps(
partition_mappings: Dict[AssetKey, PartitionMapping], deps: Iterable[AssetDep], asset_name: str
) -> Mapping[AssetKey, PartitionMapping]:
# Add PartitionMappings specified via AssetDeps to partition_mappings dictionary. Error on duplicates
for dep in deps:
if dep.partition_mapping is None:
continue
if partition_mappings.get(dep.asset_key, None) is None:
partition_mappings[dep.asset_key] = dep.partition_mapping
continue
if partition_mappings[dep.asset_key] == dep.partition_mapping:
continue
else:
raise DagsterInvalidDefinitionError(
f"Two different PartitionMappings for {dep.asset_key} provided for"
f" asset {asset_name}. Please use the same PartitionMapping for"
f" {dep.asset_key}."
)

return partition_mappings


class InMapping(NamedTuple):
input_name: str
input: In
Expand All @@ -183,17 +206,25 @@ class InOutMapper:
def __init__(
self,
*,
asset_ins: Mapping[AssetKey, Tuple[str, In]],
asset_outs: Mapping[AssetKey, Tuple[str, Out]],
directly_passed_asset_ins: Mapping[str, AssetIn],
input_tuples_by_asset_key: Mapping[AssetKey, Tuple[str, In]],
output_tables_by_asset_key: Mapping[AssetKey, Tuple[str, Out]],
check_specs: Sequence[AssetCheckSpec],
internal_deps: Mapping[AssetKey, Set[AssetKey]],
can_subset: bool,
deps_directly_passed_to_multi_asset: Optional[Iterable[AssetDep]],
specs_directly_passed_to_multi_asset: Optional[Sequence[AssetSpec]],
op_name: str,
) -> None:
self._passed_asset_ins = asset_ins
self.asset_outs = asset_outs
self.directly_passed_asset_ins = directly_passed_asset_ins
self._passed_input_tuples_by_asset_key = input_tuples_by_asset_key
self.output_tuples_by_asset_key = output_tables_by_asset_key
self.check_specs = check_specs
self.internal_deps = internal_deps
self.can_subset = can_subset
self.deps_directly_passed_to_multi_asset = deps_directly_passed_to_multi_asset
self.specs_directly_passed_to_multi_asset = specs_directly_passed_to_multi_asset
self.op_name = op_name

@staticmethod
def from_specs(
Expand All @@ -203,6 +234,7 @@ def from_specs(
can_subset: bool,
ins: Mapping[str, AssetIn],
fn: Callable[..., Any],
op_name: str,
):
output_tuples_by_asset_key = {}
for asset_spec in specs:
Expand Down Expand Up @@ -240,7 +272,7 @@ def from_specs(
" AssetSpec(s). Set the deps on the appropriate AssetSpec(s)."
)
remaining_upstream_keys = {key for key in upstream_keys if key not in loaded_upstreams}
asset_ins = build_asset_ins(fn, explicit_ins, deps=remaining_upstream_keys)
input_tuples_by_asset_key = build_asset_ins(fn, explicit_ins, deps=remaining_upstream_keys)

internal_deps = {
spec.key: {dep.asset_key for dep in spec.deps}
Expand All @@ -249,35 +281,44 @@ def from_specs(
}

return InOutMapper(
asset_ins=asset_ins,
asset_outs=output_tuples_by_asset_key,
directly_passed_asset_ins=ins,
input_tuples_by_asset_key=input_tuples_by_asset_key,
output_tables_by_asset_key=output_tuples_by_asset_key,
check_specs=check_specs,
internal_deps=internal_deps,
can_subset=can_subset,
# when specs are used deps are never passed to multi-asset
deps_directly_passed_to_multi_asset=None,
specs_directly_passed_to_multi_asset=specs,
op_name=op_name,
)

@staticmethod
def non_spec_based_create(
*,
asset_out_map: Mapping[str, AssetOut],
asset_deps: Mapping[str, Set[AssetKey]],
upstream_asset_deps: Iterable[AssetDep],
deps_directly_passed_to_multi_asset: Optional[Iterable[AssetDep]],
ins: Mapping[str, AssetIn],
op_name: str,
fn: Callable[..., Any],
check_specs: Sequence[AssetCheckSpec],
can_subset: bool,
op_name: str,
):
asset_ins = build_asset_ins(
inputs_tuples_by_asset_key = build_asset_ins(
fn,
ins or {},
deps=({dep.asset_key for dep in upstream_asset_deps} if upstream_asset_deps else set()),
deps=(
{dep.asset_key for dep in deps_directly_passed_to_multi_asset}
if deps_directly_passed_to_multi_asset
else set()
),
)
output_tuples_by_asset_key = build_asset_outs(asset_out_map)

# validate that the asset_ins are a subset of the upstream asset_deps.
upstream_internal_asset_keys = set().union(*asset_deps.values())
asset_in_keys = set(asset_ins.keys())
asset_in_keys = set(inputs_tuples_by_asset_key.keys())
if asset_deps and not asset_in_keys.issubset(upstream_internal_asset_keys):
invalid_asset_in_keys = asset_in_keys - upstream_internal_asset_keys
check.failed(
Expand Down Expand Up @@ -309,37 +350,43 @@ def non_spec_based_create(
internal_deps = {keys_by_output_name[name]: asset_deps[name] for name in asset_deps}

return InOutMapper(
asset_ins=asset_ins,
asset_outs=output_tuples_by_asset_key,
directly_passed_asset_ins=ins,
input_tuples_by_asset_key=inputs_tuples_by_asset_key,
output_tables_by_asset_key=output_tuples_by_asset_key,
check_specs=check_specs or [],
internal_deps=internal_deps,
can_subset=can_subset,
deps_directly_passed_to_multi_asset=deps_directly_passed_to_multi_asset,
specs_directly_passed_to_multi_asset=None,
op_name=op_name,
)

@cached_property
def asset_ins(self) -> Mapping[AssetKey, Tuple[str, In]]:
def input_tuples_by_asset_key(self) -> Mapping[AssetKey, Tuple[str, In]]:
if self.can_subset and self.internal_deps:
return {
**self._passed_asset_ins,
**self._passed_input_tuples_by_asset_key,
**build_subsettable_asset_ins(
self._passed_asset_ins, self.asset_outs, self.internal_deps.values()
self._passed_input_tuples_by_asset_key,
self.output_tuples_by_asset_key,
self.internal_deps.values(),
),
}
else:
return self._passed_asset_ins
return self._passed_input_tuples_by_asset_key

@cached_property
def in_mappings(self) -> Mapping[AssetKey, InMapping]:
return {
asset_key: InMapping(input_name, in_)
for asset_key, (input_name, in_) in self.asset_ins.items()
for asset_key, (input_name, in_) in self.input_tuples_by_asset_key.items()
}

@cached_property
def out_mappings(self) -> Mapping[AssetKey, OutMapping]:
return {
asset_key: OutMapping(output_name, out_)
for asset_key, (output_name, out_) in self.asset_outs.items()
for asset_key, (output_name, out_) in self.output_tuples_by_asset_key.items()
}

@cached_property
Expand Down Expand Up @@ -393,6 +440,23 @@ def overlapping_output_names(self) -> Set[str]:
self.check_outs_by_output_name.keys()
)

@cached_property
def partition_mappings(self) -> Mapping[AssetKey, PartitionMapping]:
partition_mappings = {
self.keys_by_input_name[input_name]: asset_in.partition_mapping
for input_name, asset_in in self.directly_passed_asset_ins.items()
if asset_in.partition_mapping is not None
}

if not self.deps_directly_passed_to_multi_asset:
return partition_mappings

return get_partition_mappings_from_deps(
partition_mappings=partition_mappings,
deps=self.deps_directly_passed_to_multi_asset,
asset_name=self.op_name,
)


def validate_and_assign_output_names_to_check_specs(
check_specs: Optional[Sequence[AssetCheckSpec]], valid_asset_keys: Sequence[AssetKey]
Expand Down

0 comments on commit f483de3

Please sign in to comment.