diff --git a/python_modules/dagster/dagster/_core/execution/context/compute.py b/python_modules/dagster/dagster/_core/execution/context/compute.py index bf226555434ac..49fda763b31e3 100644 --- a/python_modules/dagster/dagster/_core/execution/context/compute.py +++ b/python_modules/dagster/dagster/_core/execution/context/compute.py @@ -32,6 +32,7 @@ AssetKey, AssetMaterialization, AssetObservation, + CoercibleToAssetKey, ExpectationResult, UserEvent, ) @@ -1358,6 +1359,9 @@ def get() -> "AssetExecutionContext": def get_op_execution_context(self) -> "OpExecutionContext": return OpExecutionContext(self._step_execution_context) + def get_metadata_for_asset(self, key: CoercibleToAssetKey): + return self._step_execution_context._upstream_metadata.get(AssetKey.from_coercible(key), {}) # noqa: SLF001 + @contextmanager def enter_execution_context( diff --git a/python_modules/dagster/dagster/_core/execution/context/system.py b/python_modules/dagster/dagster/_core/execution/context/system.py index 60235f365eeb2..052e0d441f990 100644 --- a/python_modules/dagster/dagster/_core/execution/context/system.py +++ b/python_modules/dagster/dagster/_core/execution/context/system.py @@ -77,6 +77,7 @@ DataVersion, ) from dagster._core.definitions.dependency import NodeHandle + from dagster._core.definitions.metadata import MetadataValue from dagster._core.definitions.resource_definition import Resources from dagster._core.event_api import EventLogRecord from dagster._core.execution.plan.plan import ExecutionPlan @@ -571,6 +572,8 @@ def __init__( self._output_metadata: Dict[str, Any] = {} self._seen_outputs: Dict[str, Union[str, Set[str]]] = {} + self._upstream_metadata: Dict[AssetKey, Mapping[str, MetadataValue]] = {} + self._input_asset_version_info: Dict[AssetKey, Optional["InputAssetVersionInfo"]] = {} self._is_external_input_asset_version_info_loaded = False self._data_version_cache: Dict[AssetKey, "DataVersion"] = {} @@ -955,11 +958,11 @@ def is_external_input_asset_version_info_loaded(self) -> bool: def get_input_asset_version_info(self, key: AssetKey) -> Optional["InputAssetVersionInfo"]: if key not in self._input_asset_version_info: - self._fetch_input_asset_version_info(key) + self._fetch_input_asset_metadata_and_version_info(key) return self._input_asset_version_info[key] # "external" refers to records for inputs generated outside of this step - def fetch_external_input_asset_version_info(self) -> None: + def fetch_external_input_asset_version_info_and_metadata(self) -> None: output_keys = self.get_output_asset_keys() all_dep_keys: List[AssetKey] = [] @@ -973,10 +976,10 @@ def fetch_external_input_asset_version_info(self) -> None: self._input_asset_version_info = {} for key in all_dep_keys: - self._fetch_input_asset_version_info(key) + self._fetch_input_asset_metadata_and_version_info(key) self._is_external_input_asset_version_info_loaded = True - def _fetch_input_asset_version_info(self, key: AssetKey) -> None: + def _fetch_input_asset_metadata_and_version_info(self, key: AssetKey) -> None: from dagster._core.definitions.data_version import ( extract_data_version_from_entry, ) @@ -984,7 +987,11 @@ def _fetch_input_asset_version_info(self, key: AssetKey) -> None: event = self._get_input_asset_event(key) if event is None: self._input_asset_version_info[key] = None + self._upstream_metadata[key] = {} else: + self._upstream_metadata[key] = ( + event.asset_materialization.metadata if event.asset_materialization else {} + ) storage_id = event.storage_id # Input name will be none if this is an internal dep input_name = self.job_def.asset_layer.input_for_asset_key(self.node_handle, key) diff --git a/python_modules/dagster/dagster/_core/execution/plan/execute_step.py b/python_modules/dagster/dagster/_core/execution/plan/execute_step.py index 648af3f63dcf6..bd82f736aaead 100644 --- a/python_modules/dagster/dagster/_core/execution/plan/execute_step.py +++ b/python_modules/dagster/dagster/_core/execution/plan/execute_step.py @@ -431,7 +431,7 @@ def core_dagster_event_sequence_for_step( inputs = {} if step_context.is_sda_step: - step_context.fetch_external_input_asset_version_info() + step_context.fetch_external_input_asset_version_info_and_metadata() for step_input in step_context.step.step_inputs: input_def = step_context.op_def.input_def_named(step_input.name) diff --git a/python_modules/dagster/dagster_tests/core_tests/execution_tests/test_context.py b/python_modules/dagster/dagster_tests/core_tests/execution_tests/test_context.py index 6bd4da58fb3d2..4459283d42a66 100644 --- a/python_modules/dagster/dagster_tests/core_tests/execution_tests/test_context.py +++ b/python_modules/dagster/dagster_tests/core_tests/execution_tests/test_context.py @@ -9,6 +9,7 @@ DagsterInstance, Definitions, GraphDefinition, + MaterializeResult, OpExecutionContext, Output, asset, @@ -426,3 +427,31 @@ def a(context: AssetExecutionContext): assert context == AssetExecutionContext.get() assert materialize([a]).success + + +def test_upstream_metadata(): + # with output metadata + @asset + def upstream(context: AssetExecutionContext): + context.add_output_metadata({"foo": "bar"}) + + @asset + def downstream(context: AssetExecutionContext, upstream): + metadata = context.get_metadata_for_asset("upstream") + assert metadata["foo"].value == "bar" + + materialize([upstream, downstream]) + + +def test_upstream_metadata_materialize_result(): + # with asset materialization + @asset + def upstream(): + return MaterializeResult(metadata={"foo": "bar"}) + + @asset + def downstream(context: AssetExecutionContext, upstream): + metadata = context.get_metadata_for_asset("upstream") + assert metadata["foo"].value == "bar" + + materialize([upstream, downstream])