Skip to content

Commit

Permalink
[external-assets] Hoist resolution of input asset keys to RepositoryD…
Browse files Browse the repository at this point in the history
…ataBuilder
  • Loading branch information
smackesey committed Mar 3, 2024
1 parent dc40710 commit c2031ac
Show file tree
Hide file tree
Showing 8 changed files with 117 additions and 150 deletions.
17 changes: 5 additions & 12 deletions python_modules/dagster/dagster/_core/definitions/asset_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@
from dagster._core.definitions.assets import AssetsDefinition, SourceAsset
from dagster._core.definitions.job_definition import JobDefinition
from dagster._core.definitions.partition_mapping import PartitionMapping
from dagster._core.definitions.resolved_asset_deps import ResolvedAssetDependencies
from dagster._core.execution.context.output import OutputContext

from .partition import PartitionedConfig, PartitionsDefinition
Expand Down Expand Up @@ -411,7 +410,6 @@ def from_graph_and_assets_node_mapping(
asset_checks_defs_by_node_handle: Mapping[NodeHandle, "AssetChecksDefinition"],
observable_source_assets_by_node_handle: Mapping[NodeHandle, "SourceAsset"],
source_assets: Sequence["SourceAsset"],
resolved_asset_deps: "ResolvedAssetDependencies",
) -> "AssetLayer":
"""Generate asset info from a GraphDefinition and a mapping from nodes in that graph to the
corresponding AssetsDefinition objects.
Expand Down Expand Up @@ -452,25 +450,20 @@ def from_graph_and_assets_node_mapping(

for node_handle, assets_def in assets_defs_by_outer_node_handle.items():
for key in assets_def.keys:
asset_deps[key] = resolved_asset_deps.get_resolved_upstream_asset_keys(
assets_def, key
)
asset_deps[key] = assets_def.asset_deps[key]

for input_name in assets_def.node_keys_by_input_name.keys():
resolved_asset_key = resolved_asset_deps.get_resolved_asset_key_for_input(
assets_def, input_name
)
for input_name, input_asset_key in assets_def.node_keys_by_input_name.items():
input_handle = NodeInputHandle(node_handle, input_name)
asset_key_by_input[input_handle] = resolved_asset_key
asset_key_by_input[input_handle] = input_asset_key
# resolve graph input to list of op inputs that consume it
node_input_handles = assets_def.node_def.resolve_input_to_destinations(input_handle)
for node_input_handle in node_input_handles:
asset_key_by_input[node_input_handle] = resolved_asset_key
asset_key_by_input[node_input_handle] = input_asset_key

partition_mapping = assets_def.get_partition_mapping_for_input(input_name)
if partition_mapping is not None:
partition_mappings_by_asset_dep[
(node_handle, resolved_asset_key)
(node_handle, input_asset_key)
] = partition_mapping

for output_name, asset_key in assets_def.node_keys_by_output_name.items():
Expand Down
16 changes: 4 additions & 12 deletions python_modules/dagster/dagster/_core/definitions/assets_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
from .job_definition import JobDefinition, default_job_io_manager
from .metadata import RawMetadataValue
from .partition import PartitionedConfig, PartitionsDefinition
from .resolved_asset_deps import ResolvedAssetDependencies
from .resource_definition import ResourceDefinition
from .resource_requirement import ensure_requirements_satisfied
from .source_asset import SourceAsset
Expand Down Expand Up @@ -210,18 +209,17 @@ def asset2(asset1):
# figure out what partitions (if any) exist for this job
partitions_def = partitions_def or build_job_partitions_from_assets(assets)

resolved_asset_deps = ResolvedAssetDependencies(assets, resolved_source_assets)
deps, assets_defs_by_node_handle, asset_checks_defs_by_node_handle = build_node_deps(
assets, asset_checks, resolved_asset_deps
assets, asset_checks
)

# attempt to resolve cycles using multi-asset subsetting
if _has_cycles(deps):
assets = _attempt_resolve_cycles(assets, resolved_source_assets)
resolved_asset_deps = ResolvedAssetDependencies(assets, resolved_source_assets)

deps, assets_defs_by_node_handle, asset_checks_defs_by_node_handle = build_node_deps(
assets, asset_checks, resolved_asset_deps
assets,
asset_checks,
)

if len(assets) > 0 or len(asset_checks) > 0:
Expand Down Expand Up @@ -257,7 +255,6 @@ def asset2(asset1):
graph_def=graph,
asset_checks_defs_by_node_handle=asset_checks_defs_by_node_handle,
source_assets=resolved_source_assets,
resolved_asset_deps=resolved_asset_deps,
assets_defs_by_outer_node_handle=assets_defs_by_node_handle,
observable_source_assets_by_node_handle=observable_source_assets_by_node_handle,
)
Expand Down Expand Up @@ -363,7 +360,6 @@ def _get_blocking_asset_check_output_handles_by_asset_key(
def build_node_deps(
assets_defs: Iterable[AssetsDefinition],
asset_checks_defs: Sequence[AssetChecksDefinition],
resolved_asset_deps: ResolvedAssetDependencies,
) -> Tuple[
DependencyMapping[NodeInvocation],
Mapping[NodeHandle, AssetsDefinition],
Expand Down Expand Up @@ -413,11 +409,7 @@ def build_node_deps(
deps[node_key] = {}

# connect each input of this AssetsDefinition to the proper upstream node
for input_name in assets_def.input_names:
upstream_asset_key = resolved_asset_deps.get_resolved_asset_key_for_input(
assets_def, input_name
)

for input_name, upstream_asset_key in assets_def.keys_by_input_name.items():
# ignore self-deps
if upstream_asset_key in assets_def.keys:
continue
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from dagster._core.definitions.partitioned_schedule import (
UnresolvedPartitionedAssetScheduleDefinition,
)
from dagster._core.definitions.resolved_asset_deps import ResolvedAssetDependencies
from dagster._core.definitions.resource_definition import ResourceDefinition
from dagster._core.definitions.schedule_definition import ScheduleDefinition
from dagster._core.definitions.sensor_definition import SensorDefinition
Expand Down Expand Up @@ -233,6 +234,19 @@ def build_caching_repository_data_from_list(
else:
check.failed(f"Unexpected repository entry {definition}")

# Resolve all asset dependencies. An asset dependency is resolved when it's key is an AssetKey
# not subject to any further manipulation.
resolved_deps = ResolvedAssetDependencies(assets_defs, [])
assets_defs = [
ad.with_attributes(
input_asset_key_replacements={
raw_key: resolved_deps.get_resolved_asset_key_for_input(ad, input_name)
for input_name, raw_key in ad.keys_by_input_name.items()
}
)
for ad in assets_defs
]

if assets_defs or source_assets or asset_checks_defs:
for job_def in get_base_asset_jobs(
assets=assets_defs,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,21 +119,14 @@ def __new__(
def generate_asset_dep_graph(
assets_defs: Iterable["AssetsDefinition"], source_assets: Iterable["SourceAsset"]
) -> DependencyGraph[AssetKey]:
from dagster._core.definitions.resolved_asset_deps import ResolvedAssetDependencies

resolved_asset_deps = ResolvedAssetDependencies(assets_defs, source_assets)

upstream: Dict[AssetKey, Set[AssetKey]] = {}
downstream: Dict[AssetKey, Set[AssetKey]] = {}
for assets_def in assets_defs:
for asset_key in assets_def.keys:
upstream[asset_key] = set()
downstream[asset_key] = downstream.get(asset_key, set())
# for each asset upstream of this one, set that as upstream, and this downstream of it
upstream_asset_keys = resolved_asset_deps.get_resolved_upstream_asset_keys(
assets_def, asset_key
)
for upstream_key in upstream_asset_keys:
for upstream_key in assets_def.asset_deps[asset_key]:
upstream[asset_key].add(upstream_key)
downstream[upstream_key] = downstream.get(upstream_key, set()) | {asset_key}
return {"upstream": upstream, "downstream": downstream}
Expand Down
24 changes: 24 additions & 0 deletions python_modules/dagster/dagster/_core/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,16 @@
fs_io_manager,
)
from dagster._config import Array, Field
from dagster._core.definitions.asset_selection import CoercibleToAssetSelection
from dagster._core.definitions.assets import AssetsDefinition
from dagster._core.definitions.decorators import op
from dagster._core.definitions.decorators.graph_decorator import graph
from dagster._core.definitions.definitions_class import Definitions
from dagster._core.definitions.graph_definition import GraphDefinition
from dagster._core.definitions.job_definition import JobDefinition
from dagster._core.definitions.node_definition import NodeDefinition
from dagster._core.definitions.source_asset import SourceAsset
from dagster._core.definitions.unresolved_asset_job_definition import define_asset_job
from dagster._core.errors import DagsterUserCodeUnreachableError
from dagster._core.events import DagsterEvent
from dagster._core.host_representation.origin import (
Expand Down Expand Up @@ -718,3 +724,21 @@ def ensure_dagster_tests_import() -> None:
dagster_package_root / "dagster_tests"
).exists(), "Could not find dagster_tests where expected"
sys.path.append(dagster_package_root.as_posix())


def resolve_asset_job(
assets: Sequence[Union[AssetsDefinition, SourceAsset]],
*,
selection: Optional[CoercibleToAssetSelection] = None,
name: str = "asset_job",
resources: Mapping[str, object] = {},
**kwargs: Any,
) -> JobDefinition:
assets_defs = [a for a in assets if isinstance(a, AssetsDefinition)]
source_assets = [a for a in assets if isinstance(a, SourceAsset)]
selection = selection or assets_defs
return Definitions(
assets=[*assets_defs, *source_assets],
jobs=[define_asset_job(name, selection, **kwargs)],
resources=resources,
).get_job_def(name)
Loading

0 comments on commit c2031ac

Please sign in to comment.