From 51870dad5efea3e65efb390e893b10e02a597424 Mon Sep 17 00:00:00 2001 From: alangenfeld Date: Fri, 23 Jun 2023 16:52:21 -0500 Subject: [PATCH] [prototype] return AssetMaterialization --- .../dagster/_core/definitions/events.py | 19 +++++- .../dagster/_core/execution/context/system.py | 18 +++++ .../_core/execution/plan/compute_generator.py | 16 +++-- .../_core/execution/plan/execute_step.py | 38 ++++++++++- .../asset_defs_tests/test_assets.py | 67 +++++++++++++++++++ 5 files changed, 150 insertions(+), 8 deletions(-) diff --git a/python_modules/dagster/dagster/_core/definitions/events.py b/python_modules/dagster/dagster/_core/definitions/events.py index d70de99d5343e..8a0daed4dc1cf 100644 --- a/python_modules/dagster/dagster/_core/definitions/events.py +++ b/python_modules/dagster/dagster/_core/definitions/events.py @@ -20,6 +20,7 @@ import dagster._seven as seven from dagster._annotations import PublicAttr, public from dagster._core.definitions.data_version import DataVersion +from dagster._core.errors import DagsterInvariantViolationError from dagster._core.storage.tags import MULTIDIMENSIONAL_PARTITION_PREFIX, SYSTEM_TAG_PREFIX from dagster._serdes import whitelist_for_serdes from dagster._serdes.serdes import NamedTupleSerializer @@ -476,7 +477,7 @@ class AssetMaterialization( Args: asset_key (Union[str, List[str], AssetKey]): A key to identify the materialized asset across - job runs + job runs. Optional in cases when the key can be inferred from the current context. description (Optional[str]): A longer human-readable description of the materialized value. partition (Optional[str]): The name of the partition that was materialized. @@ -490,18 +491,32 @@ class AssetMaterialization( def __new__( cls, - asset_key: CoercibleToAssetKey, + asset_key: Optional[CoercibleToAssetKey] = None, description: Optional[str] = None, metadata: Optional[Mapping[str, RawMetadataValue]] = None, partition: Optional[str] = None, tags: Optional[Mapping[str, str]] = None, ): from dagster._core.definitions.multi_dimensional_partitions import MultiPartitionKey + from dagster._core.execution.context.compute import get_execution_context if isinstance(asset_key, AssetKey): check.inst_param(asset_key, "asset_key", AssetKey) elif isinstance(asset_key, str): asset_key = AssetKey(parse_asset_key_string(asset_key)) + elif asset_key is None: + current_ctx = get_execution_context() + if current_ctx is None: + raise DagsterInvariantViolationError( + "Could not infer asset_key, not currently in the context of an execution." + ) + keys = current_ctx.selected_asset_keys + if len(keys) != 1: + raise DagsterInvariantViolationError( + f"Could not infer asset_key, there are {len(keys)} in the current execution" + " context. Specify the appropriate asset_key." + ) + asset_key = next(iter(keys)) else: check.sequence_param(asset_key, "asset_key", of_type=str) asset_key = AssetKey(asset_key) diff --git a/python_modules/dagster/dagster/_core/execution/context/system.py b/python_modules/dagster/dagster/_core/execution/context/system.py index c62d4acce47e7..2538a69d137f3 100644 --- a/python_modules/dagster/dagster/_core/execution/context/system.py +++ b/python_modules/dagster/dagster/_core/execution/context/system.py @@ -521,6 +521,7 @@ def __init__( self._output_metadata: Dict[str, Any] = {} self._seen_outputs: Dict[str, Union[str, Set[str]]] = {} + self._seen_user_asset_mats = {} self._input_asset_records: Dict[AssetKey, Optional["EventLogRecord"]] = {} self._is_external_input_asset_records_loaded = False @@ -846,6 +847,23 @@ def is_sda_step(self) -> bool: return True return False + def asset_key_for_output(self, output_name: str) -> AssetKey: + # note: duped on AssetExecutionContext + asset_output_info = self.job_def.asset_layer.asset_info_for_output( + node_handle=self.node_handle, output_name=output_name + ) + if asset_output_info is None: + check.failed(f"Output '{output_name}' has no asset") + else: + return asset_output_info.key + + def observe_user_asset_mat(self, asset_key, event): + # will need to store N events for partitions + self._seen_user_asset_mats[asset_key] = event + + def get_observed_user_asset_mat(self, asset_key): + return self._seen_user_asset_mats.get(asset_key) + def set_data_version(self, asset_key: AssetKey, data_version: "DataVersion") -> None: self._data_version_cache[asset_key] = data_version diff --git a/python_modules/dagster/dagster/_core/execution/plan/compute_generator.py b/python_modules/dagster/dagster/_core/execution/plan/compute_generator.py index 67a2f2509d982..322e58174a58f 100644 --- a/python_modules/dagster/dagster/_core/execution/plan/compute_generator.py +++ b/python_modules/dagster/dagster/_core/execution/plan/compute_generator.py @@ -18,6 +18,7 @@ from typing_extensions import get_args +import dagster._check as check from dagster._config.pythonic_config import Config from dagster._core.definitions import ( AssetMaterialization, @@ -195,17 +196,22 @@ def validate_and_coerce_op_result_to_iterator( # this happens when a user explicitly returns a generator in the op for event in result: yield event - elif isinstance(result, (AssetMaterialization, ExpectationResult)): + + # [A] yield it here... + elif isinstance(result, AssetMaterialization): + yield result + + elif isinstance(result, (ExpectationResult)): raise DagsterInvariantViolationError( f"Error in {context.describe_op()}: If you are " - "returning an AssetMaterialization " - "or an ExpectationResult from " - f"{context.op_def.node_type_str} you must yield them " + "returning an ExpectationResult from " + f"{context.op_def.node_type_str} you must yield it " "directly, or log them using the OpExecutionContext.log_event method to avoid " "ambiguity with an implied result from returning a " "value. Check out the docs on logging events here: " "https://docs.dagster.io/concepts/ops-jobs-graphs/op-events#op-events-and-exceptions" ) + elif result is not None and not output_defs: raise DagsterInvariantViolationError( f"Error in {context.describe_op()}: Unexpectedly returned output of type" @@ -285,3 +291,5 @@ def validate_and_coerce_op_result_to_iterator( "https://docs.dagster.io/concepts/ops-jobs-graphs/graphs#with-conditional-branching" ) yield Output(output_name=output_def.name, value=element) + else: + check.failed("do we ever hit this unhandled case?") diff --git a/python_modules/dagster/dagster/_core/execution/plan/execute_step.py b/python_modules/dagster/dagster/_core/execution/plan/execute_step.py index f83fc5c268dc2..6d06e42503ccf 100644 --- a/python_modules/dagster/dagster/_core/execution/plan/execute_step.py +++ b/python_modules/dagster/dagster/_core/execution/plan/execute_step.py @@ -88,10 +88,16 @@ def _step_output_error_checked_user_event_sequence( output_names = list([output_def.name for output_def in step.step_outputs]) for user_event in user_event_sequence: - if not isinstance(user_event, (Output, DynamicOutput)): + if not isinstance(user_event, (Output, DynamicOutput, AssetMaterialization)): yield user_event continue + # [A] ... to swallow here + if isinstance(user_event, AssetMaterialization): + step_context.observe_user_asset_mat(user_event.asset_key, user_event) + # defer yielding til post resolve to apply tags + continue + # do additional processing on Outputs output = user_event if not step.has_step_output(cast(str, output.output_name)): @@ -165,6 +171,15 @@ def _step_output_error_checked_user_event_sequence( f'Emitting implicit Nothing for output "{step_output_def.name}" on {op_label}' ) yield Output(output_name=step_output_def.name, value=None) + + if step_context.is_sda_step and step_context.get_observed_user_asset_mat( + step_context.asset_key_for_output(step_output_def.name) + ): + # think its fine to omit log + # step_context.log.info( + # f"Emitting implicit Nothing for materialized asset {op_label}" + # ) + yield Output(output_name=step_output_def.name, value=None) elif not step_output_def.is_dynamic: raise DagsterStepOutputNotFoundError( ( @@ -376,6 +391,7 @@ def core_dagster_event_sequence_for_step( yield evt # for now, I'm ignoring AssetMaterializations yielded manually, but we might want # to do something with these in the above path eventually + # ^ wat? elif isinstance(user_event, AssetMaterialization): yield DagsterEvent.asset_materialization(step_context, user_event) elif isinstance(user_event, AssetObservation): @@ -490,7 +506,11 @@ def _get_output_asset_materializations( if backfill_id: tags[BACKFILL_ID_TAG] = backfill_id - if asset_partitions: + user_event = step_context.get_observed_user_asset_mat(asset_key) + if asset_partitions and user_event: + # this will be a bit involved + check.failed("unhandled") + elif asset_partitions: for partition in asset_partitions: with warnings.catch_warnings(): warnings.simplefilter("ignore", category=DeprecationWarning) @@ -507,6 +527,20 @@ def _get_output_asset_materializations( metadata=all_metadata, tags=tags, ) + elif user_event: + if tags: + yield AssetMaterialization( + **{ # dirty mergey + **user_event._asdict(), + "tags": { + **(user_event.tags if user_event.tags else {}), + **tags, + }, + }, + ) + else: + yield user_event + else: with warnings.catch_warnings(): warnings.simplefilter("ignore", category=DeprecationWarning) diff --git a/python_modules/dagster/dagster_tests/asset_defs_tests/test_assets.py b/python_modules/dagster/dagster_tests/asset_defs_tests/test_assets.py index 33ad559409713..83114da5f18a8 100644 --- a/python_modules/dagster/dagster_tests/asset_defs_tests/test_assets.py +++ b/python_modules/dagster/dagster_tests/asset_defs_tests/test_assets.py @@ -35,11 +35,13 @@ from dagster._check import CheckError from dagster._core.definitions import AssetIn, SourceAsset, asset, multi_asset from dagster._core.definitions.auto_materialize_policy import AutoMaterializePolicy +from dagster._core.definitions.events import AssetMaterialization from dagster._core.errors import ( DagsterInvalidDefinitionError, DagsterInvalidInvocationError, DagsterInvalidPropertyError, ) +from dagster._core.execution.context.compute import AssetExecutionContext from dagster._core.instance import DagsterInstance from dagster._core.storage.mem_io_manager import InMemoryIOManager from dagster._core.test_utils import instance_for_test @@ -1476,3 +1478,68 @@ def blah(context): defs = Definitions(assets=[blah]) defs.get_implicit_global_asset_job_def().execute_in_process() assert executed["yes"] + + +def test_return_materialization(): + # + # status quo - use add add_output_metadata + # + @asset + def add(context: AssetExecutionContext): + context.add_output_metadata( + metadata={"one": 1}, + ) + + asset_job = define_asset_job("bar", [add]).resolve([add], []) + + result = asset_job.execute_in_process() + assert result.success + + mats = result.asset_materializations_for_node(add.node_def.name) + assert len(mats) == 1 + # working with core metadata repr values sucks, ie IntMetadataValue + assert "one" in mats[0].metadata + assert mats[0].tags + + # + # side quest: may want to update this pattern to work as well + # + @asset + def logged(context: AssetExecutionContext): + context.log_event( + AssetMaterialization( + metadata={"one": 1}, + ) + ) + + asset_job = define_asset_job("bar", [logged]).resolve([logged], []) + + result = asset_job.execute_in_process() + assert result.success + + mats = result.asset_materializations_for_node(logged.node_def.name) + # should we change this? currently get implicit materialization for output + logged event + assert len(mats) == 2 + assert "one" in mats[0].metadata + # assert mats[0].tags # fails + # assert "one" in mats[1].metadata # fails + assert mats[1].tags + + # + # main exploration + # + @asset + def ret_untyped(context: AssetExecutionContext): + return AssetMaterialization( + metadata={"one": 1}, + ) + + asset_job = define_asset_job("bar", [ret_untyped]).resolve([ret_untyped], []) + + result = asset_job.execute_in_process() + assert result.success + + mats = result.asset_materializations_for_node(ret_untyped.node_def.name) + assert len(mats) == 1, mats + assert "one" in mats[0].metadata + assert mats[0].tags