Skip to content

Commit

Permalink
wip get output metadata from upstream assets
Browse files Browse the repository at this point in the history
  • Loading branch information
jamiedemaria committed Dec 28, 2023
1 parent d47008c commit dc32768
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
AssetKey,
AssetMaterialization,
AssetObservation,
CoercibleToAssetKey,
ExpectationResult,
UserEvent,
)
Expand Down Expand Up @@ -1358,6 +1359,9 @@ def get() -> "AssetExecutionContext":
def get_op_execution_context(self) -> "OpExecutionContext":
return OpExecutionContext(self._step_execution_context)

def get_metadata_for_asset(self, key: CoercibleToAssetKey):
return self._step_execution_context._upstream_metadata.get(AssetKey.from_coercible(key), {}) # noqa: SLF001


@contextmanager
def enter_execution_context(
Expand Down
15 changes: 11 additions & 4 deletions python_modules/dagster/dagster/_core/execution/context/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
DataVersion,
)
from dagster._core.definitions.dependency import NodeHandle
from dagster._core.definitions.metadata import MetadataValue
from dagster._core.definitions.resource_definition import Resources
from dagster._core.event_api import EventLogRecord
from dagster._core.execution.plan.plan import ExecutionPlan
Expand Down Expand Up @@ -571,6 +572,8 @@ def __init__(
self._output_metadata: Dict[str, Any] = {}
self._seen_outputs: Dict[str, Union[str, Set[str]]] = {}

self._upstream_metadata: Dict[AssetKey, Mapping[str, MetadataValue]] = {}

self._input_asset_version_info: Dict[AssetKey, Optional["InputAssetVersionInfo"]] = {}
self._is_external_input_asset_version_info_loaded = False
self._data_version_cache: Dict[AssetKey, "DataVersion"] = {}
Expand Down Expand Up @@ -955,11 +958,11 @@ def is_external_input_asset_version_info_loaded(self) -> bool:

def get_input_asset_version_info(self, key: AssetKey) -> Optional["InputAssetVersionInfo"]:
if key not in self._input_asset_version_info:
self._fetch_input_asset_version_info(key)
self._fetch_input_asset_metadata_and_version_info(key)
return self._input_asset_version_info[key]

# "external" refers to records for inputs generated outside of this step
def fetch_external_input_asset_version_info(self) -> None:
def fetch_external_input_asset_version_info_and_metadata(self) -> None:
output_keys = self.get_output_asset_keys()

all_dep_keys: List[AssetKey] = []
Expand All @@ -973,18 +976,22 @@ def fetch_external_input_asset_version_info(self) -> None:

self._input_asset_version_info = {}
for key in all_dep_keys:
self._fetch_input_asset_version_info(key)
self._fetch_input_asset_metadata_and_version_info(key)
self._is_external_input_asset_version_info_loaded = True

def _fetch_input_asset_version_info(self, key: AssetKey) -> None:
def _fetch_input_asset_metadata_and_version_info(self, key: AssetKey) -> None:
from dagster._core.definitions.data_version import (
extract_data_version_from_entry,
)

event = self._get_input_asset_event(key)
if event is None:
self._input_asset_version_info[key] = None
self._upstream_metadata[key] = {}
else:
self._upstream_metadata[key] = (
event.asset_materialization.metadata if event.asset_materialization else {}
)
storage_id = event.storage_id
# Input name will be none if this is an internal dep
input_name = self.job_def.asset_layer.input_for_asset_key(self.node_handle, key)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,7 @@ def core_dagster_event_sequence_for_step(
inputs = {}

if step_context.is_sda_step:
step_context.fetch_external_input_asset_version_info()
step_context.fetch_external_input_asset_version_info_and_metadata()

for step_input in step_context.step.step_inputs:
input_def = step_context.op_def.input_def_named(step_input.name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
DagsterInstance,
Definitions,
GraphDefinition,
MaterializeResult,
OpExecutionContext,
Output,
asset,
Expand Down Expand Up @@ -426,3 +427,31 @@ def a(context: AssetExecutionContext):
assert context == AssetExecutionContext.get()

assert materialize([a]).success


def test_upstream_metadata():
# with output metadata
@asset
def upstream(context: AssetExecutionContext):
context.add_output_metadata({"foo": "bar"})

@asset
def downstream(context: AssetExecutionContext, upstream):
metadata = context.get_metadata_for_asset("upstream")
assert metadata["foo"].value == "bar"

materialize([upstream, downstream])


def test_upstream_metadata_materialize_result():
# with asset materialization
@asset
def upstream():
return MaterializeResult(metadata={"foo": "bar"})

@asset
def downstream(context: AssetExecutionContext, upstream):
metadata = context.get_metadata_for_asset("upstream")
assert metadata["foo"].value == "bar"

materialize([upstream, downstream])

0 comments on commit dc32768

Please sign in to comment.