diff --git a/python_modules/dagster/dagster/__init__.py b/python_modules/dagster/dagster/__init__.py index 7d9659c5ba35c..7ebea7d7a558c 100644 --- a/python_modules/dagster/dagster/__init__.py +++ b/python_modules/dagster/dagster/__init__.py @@ -330,7 +330,9 @@ make_values_resource as make_values_resource, resource as resource, ) -from dagster._core.definitions.result import MaterializeResult as MaterializeResult +from dagster._core.definitions.result import ( + MaterializeResult as MaterializeResult, +) from dagster._core.definitions.run_config import RunConfig as RunConfig from dagster._core.definitions.run_request import ( AddDynamicPartitionsRequest as AddDynamicPartitionsRequest, diff --git a/python_modules/dagster/dagster/_core/definitions/assets.py b/python_modules/dagster/dagster/_core/definitions/assets.py index 270f62a3a632c..a730aa688117d 100644 --- a/python_modules/dagster/dagster/_core/definitions/assets.py +++ b/python_modules/dagster/dagster/_core/definitions/assets.py @@ -359,7 +359,7 @@ def dagster_internal_init( is_subset=is_subset, ) - def __call__(self, *args: object, **kwargs: object) -> object: + def __call__(self, *args: object, **kwargs: object) -> Any: from .composition import is_in_composition from .graph_definition import GraphDefinition diff --git a/python_modules/dagster/dagster/_core/definitions/external_asset.py b/python_modules/dagster/dagster/_core/definitions/external_asset.py index b44b681c77896..dea5cf03085bc 100644 --- a/python_modules/dagster/dagster/_core/definitions/external_asset.py +++ b/python_modules/dagster/dagster/_core/definitions/external_asset.py @@ -8,7 +8,9 @@ ) from dagster._core.definitions.assets import AssetsDefinition from dagster._core.definitions.decorators.asset_decorator import asset, multi_asset +from dagster._core.definitions.events import Output from dagster._core.definitions.source_asset import ( + SYSTEM_METADATA_KEY_SOURCE_ASSET_OBSERVATION, SourceAsset, wrap_source_asset_observe_fn_in_op_compute_fn, ) @@ -137,12 +139,6 @@ def create_external_asset_from_source_asset(source_asset: SourceAsset) -> Assets " should be None", ) - injected_metadata = ( - {SYSTEM_METADATA_KEY_ASSET_EXECUTION_TYPE: AssetExecutionType.UNEXECUTABLE.value} - if source_asset.observe_fn is None - else {} - ) - injected_metadata = ( {SYSTEM_METADATA_KEY_ASSET_EXECUTION_TYPE: AssetExecutionType.UNEXECUTABLE.value} if source_asset.observe_fn is None @@ -173,10 +169,11 @@ def _shim_assets_def(context: AssetExecutionContext): op_function = wrap_source_asset_observe_fn_in_op_compute_fn(source_asset) return_value = op_function.decorated_fn(context) check.invariant( - return_value is None, - "The wrapped decorated_fn should return a value. If this changes, this code path must" - " changed to process the events appopriately.", + isinstance(return_value, Output) + and SYSTEM_METADATA_KEY_SOURCE_ASSET_OBSERVATION in return_value.metadata, + "The wrapped decorated_fn should return an Output with a special metadata key.", ) + return return_value check.invariant(isinstance(_shim_assets_def, AssetsDefinition)) assert isinstance(_shim_assets_def, AssetsDefinition) # appease pyright diff --git a/python_modules/dagster/dagster/_core/definitions/op_definition.py b/python_modules/dagster/dagster/_core/definitions/op_definition.py index 0daadb8f11778..ca89af06bdc36 100644 --- a/python_modules/dagster/dagster/_core/definitions/op_definition.py +++ b/python_modules/dagster/dagster/_core/definitions/op_definition.py @@ -47,7 +47,7 @@ from .inference import infer_output_props from .input import In, InputDefinition from .output import Out, OutputDefinition -from .result import MaterializeResult +from .result import AssetResult if TYPE_CHECKING: from dagster._core.definitions.asset_layer import AssetLayer @@ -574,4 +574,4 @@ def _validate_context_type_hint(fn): def _is_result_object_type(ttype): # Is this type special result object type - return ttype in (MaterializeResult, AssetCheckResult) + return ttype in (AssetResult, AssetCheckResult) diff --git a/python_modules/dagster/dagster/_core/definitions/op_invocation.py b/python_modules/dagster/dagster/_core/definitions/op_invocation.py index d77dd6eb81d3a..8b3d2cd634d10 100644 --- a/python_modules/dagster/dagster/_core/definitions/op_invocation.py +++ b/python_modules/dagster/dagster/_core/definitions/op_invocation.py @@ -29,7 +29,7 @@ Output, ) from .output import DynamicOutputDefinition, OutputDefinition -from .result import MaterializeResult +from .result import AssetResult if TYPE_CHECKING: from ..execution.context.compute import OpExecutionContext @@ -344,7 +344,7 @@ def _resolve_inputs( return input_dict -def _key_for_result(result: MaterializeResult, context: "BaseDirectExecutionContext") -> AssetKey: +def _key_for_result(result: AssetResult, context: "BaseDirectExecutionContext") -> AssetKey: if not context.per_invocation_properties.assets_def: raise DagsterInvariantViolationError( f"Op {context.per_invocation_properties.alias} does not have an assets definition." @@ -359,13 +359,13 @@ def _key_for_result(result: MaterializeResult, context: "BaseDirectExecutionCont return next(iter(context.per_invocation_properties.assets_def.keys)) raise DagsterInvariantViolationError( - "MaterializeResult did not include asset_key and it can not be inferred. Specify which" + f"{result.__class__.__name__} did not include asset_key and it can not be inferred. Specify which" f" asset_key, options are: {context.per_invocation_properties.assets_def.keys}" ) def _output_name_for_result_obj( - event: MaterializeResult, + event: AssetResult, context: "BaseDirectExecutionContext", ): if not context.per_invocation_properties.assets_def: @@ -388,7 +388,7 @@ def _handle_gen_event( (AssetMaterialization, AssetObservation, ExpectationResult), ): return event - elif isinstance(event, MaterializeResult): + elif isinstance(event, AssetResult): output_name = _output_name_for_result_obj(event, context) outputs_seen.add(output_name) return event @@ -516,7 +516,7 @@ def _type_check_function_output( for event in validate_and_coerce_op_result_to_iterator(result, op_context, op_def.output_defs): if isinstance(event, (Output, DynamicOutput)): _type_check_output(output_defs_by_name[event.output_name], event, context) - elif isinstance(event, (MaterializeResult)): + elif isinstance(event, AssetResult): # ensure result objects are contextually valid _output_name_for_result_obj(event, context) diff --git a/python_modules/dagster/dagster/_core/definitions/result.py b/python_modules/dagster/dagster/_core/definitions/result.py index dc3b912b9895b..c32fd69726180 100644 --- a/python_modules/dagster/dagster/_core/definitions/result.py +++ b/python_modules/dagster/dagster/_core/definitions/result.py @@ -1,7 +1,7 @@ from typing import NamedTuple, Optional, Sequence import dagster._check as check -from dagster._annotations import PublicAttr +from dagster._annotations import PublicAttr, experimental from dagster._core.definitions.asset_check_result import AssetCheckResult from dagster._core.definitions.data_version import DataVersion @@ -12,9 +12,9 @@ from .metadata import MetadataUserInput -class MaterializeResult( +class AssetResult( NamedTuple( - "_MaterializeResult", + "_AssetResult", [ ("asset_key", PublicAttr[Optional[AssetKey]]), ("metadata", PublicAttr[Optional[MetadataUserInput]]), @@ -23,14 +23,7 @@ class MaterializeResult( ], ) ): - """An object representing a successful materialization of an asset. These can be returned from - @asset and @multi_asset decorated functions to pass metadata or specify specific assets were - materialized. - - Attributes: - asset_key (Optional[AssetKey]): Optional in @asset, required in @multi_asset to discern which asset this refers to. - metadata (Optional[MetadataUserInput]): Metadata to record with the corresponding AssetMaterialization event. - """ + """Base class for MaterializeResult and ObserveResult.""" def __new__( cls, @@ -62,3 +55,32 @@ def check_result_named(self, check_name: str) -> AssetCheckResult: return check_result check.failed(f"Could not find check result named {check_name}") + + +class MaterializeResult(AssetResult): + """An object representing a successful materialization of an asset. These can be returned from + @asset and @multi_asset decorated functions to pass metadata or specify specific assets were + materialized. + + Attributes: + asset_key (Optional[AssetKey]): Optional in @asset, required in @multi_asset to discern which asset this refers to. + metadata (Optional[MetadataUserInput]): Metadata to record with the corresponding AssetMaterialization event. + check_results (Optional[Sequence[AssetCheckResult]]): Check results to record with the + corresponding AssetMaterialization event. + data_version (Optional[DataVersion]): The data version of the asset that was observed. + """ + + +@experimental +class ObserveResult(AssetResult): + """An object representing a successful observation of an asset. These can be returned from + @asset and @multi_asset decorated functions to pass metadata or specify that specific assets were + observed. + + Attributes: + asset_key (Optional[AssetKey]): Optional in @asset, required in @multi_asset to discern which asset this refers to. + metadata (Optional[MetadataUserInput]): Metadata to record with the corresponding AssetMaterialization event. + check_results (Optional[Sequence[AssetCheckResult]]): Check results to record with the + corresponding AssetObservation event. + data_version (Optional[DataVersion]): The data version of the asset that was observed. + """ diff --git a/python_modules/dagster/dagster/_core/definitions/source_asset.py b/python_modules/dagster/dagster/_core/definitions/source_asset.py index 79f6ca213d571..5751ad26ca716 100644 --- a/python_modules/dagster/dagster/_core/definitions/source_asset.py +++ b/python_modules/dagster/dagster/_core/definitions/source_asset.py @@ -20,7 +20,7 @@ DataVersion, DataVersionsByPartition, ) -from dagster._core.definitions.events import AssetKey, AssetObservation, CoercibleToAssetKey +from dagster._core.definitions.events import AssetKey, AssetObservation, CoercibleToAssetKey, Output from dagster._core.definitions.metadata import ( ArbitraryMetadataMapping, MetadataMapping, @@ -59,6 +59,12 @@ # Going with this catch-all for the time-being to permit pythonic resources SourceAssetObserveFunction: TypeAlias = Callable[..., Any] +# This is a private key that is attached to the Output emitted from a source asset observation +# function and used to prevent observations from being auto-generated from it. This is a workaround +# because we cannot currently auto-convert the observation function to use `ObserveResult`. It can +# be removed when that conversion is completed. +SYSTEM_METADATA_KEY_SOURCE_ASSET_OBSERVATION = "__source_asset_observation__" + def wrap_source_asset_observe_fn_in_op_compute_fn( source_asset: "SourceAsset", @@ -78,7 +84,7 @@ def wrap_source_asset_observe_fn_in_op_compute_fn( observe_fn_has_context = is_context_provided(get_function_params(observe_fn)) - def fn(context: OpExecutionContext) -> None: + def fn(context: OpExecutionContext) -> Output[None]: resource_kwarg_keys = [param.name for param in get_resource_args(observe_fn)] resource_kwargs = {key: getattr(context.resources, key) for key in resource_kwarg_keys} observe_fn_return_value = ( @@ -124,6 +130,7 @@ def fn(context: OpExecutionContext) -> None: " DataVersionsByPartition, but returned a value of type" f" {type(observe_fn_return_value)}" ) + return Output(None, metadata={SYSTEM_METADATA_KEY_SOURCE_ASSET_OBSERVATION: True}) return DecoratedOpFunction(fn) diff --git a/python_modules/dagster/dagster/_core/execution/plan/compute.py b/python_modules/dagster/dagster/_core/execution/plan/compute.py index 2d67fb38bab2e..fcebfcbd11ce8 100644 --- a/python_modules/dagster/dagster/_core/execution/plan/compute.py +++ b/python_modules/dagster/dagster/_core/execution/plan/compute.py @@ -29,7 +29,7 @@ from dagster._core.definitions.asset_check_spec import AssetCheckKey from dagster._core.definitions.asset_layer import AssetLayer from dagster._core.definitions.op_definition import OpComputeFunction -from dagster._core.definitions.result import MaterializeResult +from dagster._core.definitions.result import AssetResult, MaterializeResult, ObserveResult from dagster._core.errors import ( DagsterExecutionStepExecutionError, DagsterInvariantViolationError, @@ -58,6 +58,7 @@ AssetCheckEvaluation, AssetCheckResult, MaterializeResult, + ObserveResult, ] @@ -114,6 +115,7 @@ def _validate_event(event: Any, step_context: StepExecutionContext) -> OpOutputU AssetCheckResult, AssetCheckEvaluation, MaterializeResult, + ObserveResult, ), ): raise DagsterInvariantViolationError( @@ -213,7 +215,7 @@ def execute_core_compute( yield step_output if isinstance(step_output, (DynamicOutput, Output)): emitted_result_names.add(step_output.output_name) - elif isinstance(step_output, MaterializeResult): + elif isinstance(step_output, AssetResult): asset_key = ( step_output.asset_key or step_context.job_def.asset_layer.asset_key_for_node(step_context.node_handle) diff --git a/python_modules/dagster/dagster/_core/execution/plan/compute_generator.py b/python_modules/dagster/dagster/_core/execution/plan/compute_generator.py index 375be39a7ea43..729ab0861a0f7 100644 --- a/python_modules/dagster/dagster/_core/execution/plan/compute_generator.py +++ b/python_modules/dagster/dagster/_core/execution/plan/compute_generator.py @@ -30,7 +30,7 @@ from dagster._core.definitions.decorators.op_decorator import DecoratedOpFunction from dagster._core.definitions.input import InputDefinition from dagster._core.definitions.op_definition import OpDefinition -from dagster._core.definitions.result import MaterializeResult +from dagster._core.definitions.result import AssetResult, ObserveResult from dagster._core.errors import DagsterInvariantViolationError from dagster._core.types.dagster_type import DagsterTypeKind, is_generic_output_annotation from dagster._utils import is_named_tuple_instance @@ -168,10 +168,10 @@ def _filter_expected_output_defs( result_tuple = ( (result,) if not isinstance(result, tuple) or is_named_tuple_instance(result) else result ) - materialize_results = [x for x in result_tuple if isinstance(x, MaterializeResult)] + asset_results = [x for x in result_tuple if isinstance(x, AssetResult)] remove_outputs = [ r.get_spec_python_identifier(asset_key=x.asset_key or context.asset_key) - for x in materialize_results + for x in asset_results for r in x.check_results or [] ] return [out for out in output_defs if out.name not in remove_outputs] @@ -257,7 +257,8 @@ def validate_and_coerce_op_result_to_iterator( "value. Check out the docs on logging events here: " "https://docs.dagster.io/concepts/ops-jobs-graphs/op-events#op-events-and-exceptions" ) - elif isinstance(result, AssetCheckResult): + # These don't correspond to output defs so pass them through + elif isinstance(result, (AssetCheckResult, ObserveResult)): yield result elif result is not None and not output_defs: raise DagsterInvariantViolationError( @@ -310,7 +311,7 @@ def validate_and_coerce_op_result_to_iterator( mapping_key=dynamic_output.mapping_key, metadata=dynamic_output.metadata, ) - elif isinstance(element, MaterializeResult): + elif isinstance(element, AssetResult): yield element # coerced in to Output in outer iterator elif isinstance(element, Output): if annotation != inspect.Parameter.empty and not is_generic_output_annotation( diff --git a/python_modules/dagster/dagster/_core/execution/plan/execute_step.py b/python_modules/dagster/dagster/_core/execution/plan/execute_step.py index f0e75917644cf..77677b7963e19 100644 --- a/python_modules/dagster/dagster/_core/execution/plan/execute_step.py +++ b/python_modules/dagster/dagster/_core/execution/plan/execute_step.py @@ -27,6 +27,7 @@ ) from dagster._core.definitions.asset_check_result import AssetCheckResult from dagster._core.definitions.asset_spec import AssetExecutionType +from dagster._core.definitions.assets import AssetsDefinition from dagster._core.definitions.data_version import ( CODE_VERSION_TAG, DATA_VERSION_IS_USER_PROVIDED_TAG, @@ -48,7 +49,8 @@ MultiPartitionKey, get_tags_from_multi_partition_key, ) -from dagster._core.definitions.result import MaterializeResult +from dagster._core.definitions.result import AssetResult +from dagster._core.definitions.source_asset import SYSTEM_METADATA_KEY_SOURCE_ASSET_OBSERVATION from dagster._core.errors import ( DagsterAssetCheckFailedError, DagsterExecutionHandleOutputError, @@ -101,23 +103,9 @@ def _process_asset_results_to_events( def _process_user_event( step_context: StepExecutionContext, user_event: OpOutputUnion ) -> Iterator[OpOutputUnion]: - if isinstance(user_event, MaterializeResult): - assets_def = step_context.job_def.asset_layer.assets_def_for_node(step_context.node_handle) - if not assets_def: - raise DagsterInvariantViolationError( - "MaterializeResult is only valid within asset computations, no backing" - " AssetsDefinition found." - ) - if user_event.asset_key: - asset_key = user_event.asset_key - else: - if len(assets_def.keys) != 1: - raise DagsterInvariantViolationError( - "MaterializeResult did not include asset_key and it can not be inferred." - f" Specify which asset_key, options are: {assets_def.keys}." - ) - asset_key = assets_def.key - + if isinstance(user_event, AssetResult): + assets_def = _get_assets_def_for_step(step_context, user_event) + asset_key = _resolve_asset_result_asset_key(user_event, assets_def) output_name = assets_def.get_output_name_for_asset_key(asset_key) for check_result in user_event.check_results or []: @@ -169,6 +157,32 @@ def _process_user_event( yield user_event +def _get_assets_def_for_step( + step_context: StepExecutionContext, user_event: OpOutputUnion +) -> AssetsDefinition: + assets_def = step_context.job_def.asset_layer.assets_def_for_node(step_context.node_handle) + if not assets_def: + raise DagsterInvariantViolationError( + f"{user_event.__class__.__name__} is only valid within asset computations, no backing" + " AssetsDefinition found." + ) + return assets_def + + +def _resolve_asset_result_asset_key( + asset_result: AssetResult, assets_def: AssetsDefinition +) -> AssetKey: + if asset_result.asset_key: + return asset_result.asset_key + else: + if len(assets_def.keys) != 1: + raise DagsterInvariantViolationError( + f"{asset_result.__class__.__name__} did not include asset_key and it can not be inferred." + f" Specify which asset_key, options are: {assets_def.keys}." + ) + return assets_def.key + + def _step_output_error_checked_user_event_sequence( step_context: StepExecutionContext, user_event_sequence: Iterator[OpOutputUnion] ) -> Iterator[OpOutputUnion]: @@ -587,23 +601,24 @@ def _materializing_asset_key_and_partitions_for_output( return None, set() -def _get_output_asset_materializations( +def _get_output_asset_events( asset_key: AssetKey, asset_partitions: AbstractSet[str], output: Union[Output, DynamicOutput], output_def: OutputDefinition, io_manager_metadata: Mapping[str, MetadataValue], step_context: StepExecutionContext, -) -> Iterator[AssetMaterialization]: + execution_type: AssetExecutionType, +) -> Iterator[Union[AssetMaterialization, AssetObservation]]: all_metadata = {**output.metadata, **io_manager_metadata} # Clear any cached record associated with this asset, since we are about to generate a new # materialization. step_context.wipe_input_asset_version_info(asset_key) - tags: Dict[str, str] if ( - step_context.is_external_input_asset_version_info_loaded + execution_type == AssetExecutionType.MATERIALIZATION + and step_context.is_external_input_asset_version_info_loaded and asset_key in step_context.job_def.asset_layer.asset_keys ): assert isinstance(output, Output) @@ -632,6 +647,11 @@ def _get_output_asset_materializations( if not step_context.has_data_version(asset_key): data_version = DataVersion(tags[DATA_VERSION_TAG]) step_context.set_data_version(asset_key, data_version) + elif execution_type == AssetExecutionType.OBSERVATION: + assert isinstance(output, Output) + tags = ( + _build_data_version_observation_tags(output.data_version) if output.data_version else {} + ) else: tags = {} @@ -639,6 +659,13 @@ def _get_output_asset_materializations( if backfill_id: tags[BACKFILL_ID_TAG] = backfill_id + if execution_type == AssetExecutionType.MATERIALIZATION: + event_class = AssetMaterialization + elif execution_type == AssetExecutionType.OBSERVATION: + event_class = AssetObservation + else: + check.failed(f"Unexpected asset execution type {execution_type}") + if asset_partitions: for partition in asset_partitions: with disable_dagster_warnings(): @@ -648,7 +675,7 @@ def _get_output_asset_materializations( else {} ) - yield AssetMaterialization( + yield event_class( asset_key=asset_key, partition=partition, metadata=all_metadata, @@ -656,7 +683,7 @@ def _get_output_asset_materializations( ) else: with disable_dagster_warnings(): - yield AssetMaterialization(asset_key=asset_key, metadata=all_metadata, tags=tags) + yield event_class(asset_key=asset_key, metadata=all_metadata, tags=tags) def _get_code_version(asset_key: AssetKey, step_context: StepExecutionContext) -> str: @@ -718,6 +745,15 @@ def _build_data_version_tags( return tags +def _build_data_version_observation_tags( + data_version: DataVersion, +) -> Dict[str, str]: + return { + DATA_VERSION_TAG: data_version.value, + DATA_VERSION_IS_USER_PROVIDED_TAG: "true", + } + + def _store_output( step_context: StepExecutionContext, step_output_handle: StepOutputHandle, @@ -859,18 +895,41 @@ def _log_asset_materialization_events_for_asset( f"Unexpected asset execution type {execution_type}", ) - yield from ( - ( - DagsterEvent.asset_materialization(step_context, materialization) - for materialization in _get_output_asset_materializations( - asset_key, - partitions, - output, - output_def, - manager_metadata, - step_context, + # This is a temporary workaround to prevent duplicate observation events from external + # observable assets that were auto-converted from source assets. These assets yield + # observation events through the context in their body, and will continue to do so until we + # can convert them to using ObserveResult, which requires a solution to partition-scoped + # metadata and data version on output. We identify these auto-converted assets by looking + # for OBSERVATION-type asset that have this special metadata key (added in + # `wrap_source_asset_observe_fn_in_op_compute_fn`), which should only occur for these + # auto-converted source assets. This can be removed when source asset observation functions + # are converted to use ObserveResult. + if ( + execution_type == AssetExecutionType.OBSERVATION + and SYSTEM_METADATA_KEY_SOURCE_ASSET_OBSERVATION in output.metadata + ): + pass + elif execution_type != AssetExecutionType.UNEXECUTABLE: + yield from ( + ( + _dagster_event_for_asset_event(step_context, event) + for event in _get_output_asset_events( + asset_key, + partitions, + output, + output_def, + manager_metadata, + step_context, + execution_type, + ) ) ) - if execution_type == AssetExecutionType.MATERIALIZATION - else () - ) + + +def _dagster_event_for_asset_event( + step_context: StepExecutionContext, asset_event: Union[AssetMaterialization, AssetObservation] +): + if isinstance(asset_event, AssetMaterialization): + return DagsterEvent.asset_materialization(step_context, asset_event) + else: # observation + return DagsterEvent.asset_observation(step_context, asset_event) diff --git a/python_modules/dagster/dagster/_core/types/dagster_type.py b/python_modules/dagster/dagster/_core/types/dagster_type.py index d279f63c07556..bd9a38e4c6552 100644 --- a/python_modules/dagster/dagster/_core/types/dagster_type.py +++ b/python_modules/dagster/dagster/_core/types/dagster_type.py @@ -840,7 +840,7 @@ def resolve_dagster_type(dagster_type: object) -> DagsterType: # circular dep from dagster._utils.typing_api import is_typing_type - from ..definitions.result import MaterializeResult + from ..definitions.result import MaterializeResult, ObserveResult from .primitive_mapping import ( is_supported_runtime_python_builtin, remap_python_builtin_for_runtime, @@ -877,6 +877,9 @@ def resolve_dagster_type(dagster_type: object) -> DagsterType: # scalar values via MaterializeResult is supported # https://github.com/dagster-io/dagster/issues/16887 dagster_type = Nothing + elif dagster_type == ObserveResult: + # ObserveResult does not include a value + dagster_type = Nothing # Then, check to see if it is part of python's typing library if is_typing_type(dagster_type): diff --git a/python_modules/dagster/dagster_tests/definitions_tests/test_observe_result.py b/python_modules/dagster/dagster_tests/definitions_tests/test_observe_result.py new file mode 100644 index 0000000000000..099f903ff7774 --- /dev/null +++ b/python_modules/dagster/dagster_tests/definitions_tests/test_observe_result.py @@ -0,0 +1,544 @@ +import asyncio +from typing import Any, Callable, Dict, Generator, Tuple + +import pytest +from dagster import ( + AssetCheckResult, + AssetCheckSpec, + AssetExecutionContext, + AssetKey, + AssetOut, + AssetSpec, + IOManager, + StaticPartitionsDefinition, + asset, + build_op_context, + instance_for_test, + materialize, + multi_asset, +) +from dagster._core.definitions.asset_check_spec import AssetCheckKey +from dagster._core.definitions.asset_spec import ( + SYSTEM_METADATA_KEY_ASSET_EXECUTION_TYPE, + AssetExecutionType, +) +from dagster._core.definitions.assets import AssetsDefinition +from dagster._core.definitions.result import ObserveResult +from dagster._core.errors import DagsterInvariantViolationError, DagsterStepOutputNotFoundError +from dagster._core.execution.context.invocation import build_asset_context +from dagster._core.storage.asset_check_execution_record import AssetCheckExecutionRecordStatus + + +def _exec_asset(asset_def, selection=None, partition_key=None): + result = materialize([asset_def], selection=selection, partition_key=partition_key) + assert result.success + return result.asset_observations_for_node(asset_def.node_def.name) + + +def _with_observe_metadata(kwargs: Dict[str, Any]) -> Dict[str, Any]: + metadata = kwargs.pop("metadata", {}) + metadata[SYSTEM_METADATA_KEY_ASSET_EXECUTION_TYPE] = AssetExecutionType.OBSERVATION.value + return {**kwargs, "metadata": metadata} + + +def _external_observable_asset(**kwargs) -> Callable[..., AssetsDefinition]: + def _decorator(fn: Callable[..., Any]) -> AssetsDefinition: + new_kwargs = _with_observe_metadata(kwargs) + return asset(**new_kwargs)(fn) + + return _decorator + + +def _external_observable_multi_asset(**kwargs) -> Callable[..., AssetsDefinition]: + def _decorator(fn: Callable[..., Any]) -> AssetsDefinition: + if "outs" in kwargs: + kwargs["outs"] = { + name: AssetOut(**_with_observe_metadata(out._asdict())) + for name, out in kwargs["outs"].items() + } + elif "specs" in kwargs: + kwargs["specs"] = [ + AssetSpec(**_with_observe_metadata(spec._asdict())) for spec in kwargs["specs"] + ] + return multi_asset(**kwargs)(fn) + + return _decorator + + +def test_observe_result_asset(): + @_external_observable_asset() + def ret_untyped(context: AssetExecutionContext): + return ObserveResult( + metadata={"one": 1}, + ) + + observations = _exec_asset(ret_untyped) + assert len(observations) == 1, observations + assert "one" in observations[0].metadata + + # key mismatch + @_external_observable_asset() + def ret_mismatch(context: AssetExecutionContext): + return ObserveResult( + asset_key="random", + metadata={"one": 1}, + ) + + # core execution + with pytest.raises( + DagsterInvariantViolationError, + match="Asset key random not found in AssetsDefinition", + ): + materialize([ret_mismatch]) + + # direct invocation + with pytest.raises( + DagsterInvariantViolationError, + match="Asset key random not found in AssetsDefinition", + ): + ret_mismatch(build_asset_context()) + + # tuple + @_external_observable_asset() + def ret_two(): + return ObserveResult(metadata={"one": 1}), ObserveResult(metadata={"two": 2}) + + # core execution + result = materialize([ret_two]) + assert result.success + + # direct invocation + direct_results = ret_two() + assert len(direct_results) == 2 + + +def test_return_observe_result_with_asset_checks(): + with instance_for_test() as instance: + + @_external_observable_asset( + check_specs=[AssetCheckSpec(name="foo_check", asset=AssetKey("ret_checks"))] + ) + def ret_checks(context: AssetExecutionContext): + return ObserveResult( + check_results=[ + AssetCheckResult(check_name="foo_check", metadata={"one": 1}, passed=True) + ] + ) + + # core execution + materialize([ret_checks], instance=instance) + asset_check_executions = instance.event_log_storage.get_asset_check_execution_history( + AssetCheckKey(asset_key=ret_checks.key, name="foo_check"), + limit=1, + ) + assert len(asset_check_executions) == 1 + assert asset_check_executions[0].status == AssetCheckExecutionRecordStatus.SUCCEEDED + + # direct invocation + context = build_asset_context() + direct_results = ret_checks(context) + assert direct_results + + +def test_multi_asset_observe_result(): + @_external_observable_multi_asset(outs={"one": AssetOut(), "two": AssetOut()}) + def outs_multi_asset(): + return ObserveResult(asset_key="one", metadata=({"foo": "bar"})), ObserveResult( + asset_key="two", metadata={"baz": "qux"} + ) + + assert materialize([outs_multi_asset]).success + + res = outs_multi_asset() + assert res[0].metadata["foo"] == "bar" + assert res[1].metadata["baz"] == "qux" + + @_external_observable_multi_asset( + specs=[ + AssetSpec(["prefix", "one"]), + AssetSpec(["prefix", "two"]), + ] + ) + def specs_multi_asset(): + return ObserveResult(asset_key=["prefix", "one"], metadata={"foo": "bar"}), ObserveResult( + asset_key=["prefix", "two"], metadata={"baz": "qux"} + ) + + assert materialize([specs_multi_asset]).success + + res = specs_multi_asset() + assert res[0].metadata["foo"] == "bar" + assert res[1].metadata["baz"] == "qux" + + +def test_yield_materialization_multi_asset(): + # + # yield successful + # + @_external_observable_multi_asset(outs={"one": AssetOut(), "two": AssetOut()}) + def multi(): + yield ObserveResult( + asset_key="one", + metadata={"one": 1}, + ) + yield ObserveResult( + asset_key="two", + metadata={"two": 2}, + ) + + mats = _exec_asset(multi) + + assert len(mats) == 2, mats + assert "one" in mats[0].metadata + assert "two" in mats[1].metadata + + direct_results = list(multi()) + assert len(direct_results) == 2 + + # + # missing a non optional out + # + @_external_observable_multi_asset(outs={"one": AssetOut(), "two": AssetOut()}) + def missing(): + yield ObserveResult( + asset_key="one", + metadata={"one": 1}, + ) + + # currently a less than ideal error + with pytest.raises( + DagsterStepOutputNotFoundError, + match=( + 'Core compute for op "missing" did not return an output for non-optional output "two"' + ), + ): + _exec_asset(missing) + + with pytest.raises( + DagsterInvariantViolationError, + match='Invocation of op "missing" did not return an output for non-optional output "two"', + ): + list(missing()) + + # + # missing asset_key + # + @_external_observable_multi_asset(outs={"one": AssetOut(), "two": AssetOut()}) + def no_key(): + yield ObserveResult( + metadata={"one": 1}, + ) + yield ObserveResult( + metadata={"two": 2}, + ) + + with pytest.raises( + DagsterInvariantViolationError, + match=( + "ObserveResult did not include asset_key and it can not be inferred. Specify which" + " asset_key, options are:" + ), + ): + _exec_asset(no_key) + + with pytest.raises( + DagsterInvariantViolationError, + match=( + "ObserveResult did not include asset_key and it can not be inferred. Specify which" + " asset_key, options are:" + ), + ): + list(no_key()) + + # + # return tuple success + # + @_external_observable_multi_asset(outs={"one": AssetOut(), "two": AssetOut()}) + def ret_multi(): + return ( + ObserveResult( + asset_key="one", + metadata={"one": 1}, + ), + ObserveResult( + asset_key="two", + metadata={"two": 2}, + ), + ) + + mats = _exec_asset(ret_multi) + + assert len(mats) == 2, mats + assert "one" in mats[0].metadata + assert "two" in mats[1].metadata + + res = ret_multi() + assert len(res) == 2 + + # + # return list error + # + @_external_observable_multi_asset(outs={"one": AssetOut(), "two": AssetOut()}) + def ret_list(): + return [ + ObserveResult( + asset_key="one", + metadata={"one": 1}, + ), + ObserveResult( + asset_key="two", + metadata={"two": 2}, + ), + ] + + # not the best + with pytest.raises( + DagsterInvariantViolationError, + match=( + "When using multiple outputs, either yield each output, or return a tuple containing a" + " value for each output." + ), + ): + _exec_asset(ret_list) + + with pytest.raises( + DagsterInvariantViolationError, + match=( + "When using multiple outputs, either yield each output, or return a tuple containing a" + " value for each output." + ), + ): + ret_list() + + +def test_observe_result_output_typing(): + # Test that the return annotation ObserveResult is interpreted as a Nothing type, since we + # coerce returned ObserveResults to Output(None) + + class TestingIOManager(IOManager): + def handle_output(self, context, obj): + assert context.dagster_type.is_nothing + return None + + def load_input(self, context): + return 1 + + @_external_observable_asset() + def asset_with_type_annotation() -> ObserveResult: + return ObserveResult(metadata={"foo": "bar"}) + + assert materialize( + [asset_with_type_annotation], resources={"io_manager": TestingIOManager()} + ).success + + @_external_observable_multi_asset(outs={"one": AssetOut(), "two": AssetOut()}) + def multi_asset_with_outs_and_type_annotation() -> Tuple[ObserveResult, ObserveResult]: + return ObserveResult(asset_key="one"), ObserveResult(asset_key="two") + + assert materialize( + [multi_asset_with_outs_and_type_annotation], resources={"io_manager": TestingIOManager()} + ).success + + @_external_observable_multi_asset(specs=[AssetSpec("one"), AssetSpec("two")]) + def multi_asset_with_specs_and_type_annotation() -> Tuple[ObserveResult, ObserveResult]: + return ObserveResult(asset_key="one"), ObserveResult(asset_key="two") + + assert materialize( + [multi_asset_with_specs_and_type_annotation], resources={"io_manager": TestingIOManager()} + ).success + + @_external_observable_multi_asset(specs=[AssetSpec("one"), AssetSpec("two")]) + def multi_asset_with_specs_and_no_type_annotation(): + return ObserveResult(asset_key="one"), ObserveResult(asset_key="two") + + assert materialize( + [multi_asset_with_specs_and_no_type_annotation], + resources={"io_manager": TestingIOManager()}, + ).success + + @_external_observable_asset( + check_specs=[ + AssetCheckSpec(name="check_one", asset="with_checks"), + AssetCheckSpec(name="check_two", asset="with_checks"), + ] + ) + def with_checks(context: AssetExecutionContext) -> ObserveResult: + return ObserveResult( + check_results=[ + AssetCheckResult( + check_name="check_one", + passed=True, + ), + AssetCheckResult( + check_name="check_two", + passed=True, + ), + ] + ) + + assert materialize( + [with_checks], + resources={"io_manager": TestingIOManager()}, + ).success + + @_external_observable_multi_asset( + specs=[ + AssetSpec("asset_one"), + AssetSpec("asset_two"), + ], + check_specs=[ + AssetCheckSpec(name="check_one", asset="asset_one"), + AssetCheckSpec(name="check_two", asset="asset_two"), + ], + ) + def multi_checks(context: AssetExecutionContext) -> Tuple[ObserveResult, ObserveResult]: + return ObserveResult( + asset_key="asset_one", + check_results=[ + AssetCheckResult( + check_name="check_one", + passed=True, + asset_key="asset_one", + ), + ], + ), ObserveResult( + asset_key="asset_two", + check_results=[ + AssetCheckResult( + check_name="check_two", + passed=True, + asset_key="asset_two", + ), + ], + ) + + assert materialize( + [multi_checks], + resources={"io_manager": TestingIOManager()}, + ).success + + +@pytest.mark.skip( + "Generator return types are interpreted as Any. See" + " https://github.com/dagster-io/dagster/pull/16906" +) +def test_generator_return_type_annotation(): + class TestingIOManager(IOManager): + def handle_output(self, context, obj): + assert context.dagster_type.is_nothing + return None + + def load_input(self, context): + return 1 + + @asset + def generator_asset() -> Generator[ObserveResult, None, None]: + yield ObserveResult(metadata={"foo": "bar"}) + + materialize([generator_asset], resources={"io_manager": TestingIOManager()}) + + +def test_observe_result_generators(): + @_external_observable_asset() + def generator_asset() -> Generator[ObserveResult, None, None]: + yield ObserveResult(metadata={"foo": "bar"}) + + res = _exec_asset(generator_asset) + assert len(res) == 1 + assert res[0].metadata["foo"].value == "bar" + + res = list(generator_asset()) + assert len(res) == 1 + assert res[0].metadata["foo"] == "bar" + + @_external_observable_multi_asset(specs=[AssetSpec("one"), AssetSpec("two")]) + def generator_specs_multi_asset(): + yield ObserveResult(asset_key="one", metadata={"foo": "bar"}) + yield ObserveResult(asset_key="two", metadata={"baz": "qux"}) + + res = _exec_asset(generator_specs_multi_asset) + assert len(res) == 2 + assert res[0].metadata["foo"].value == "bar" + assert res[1].metadata["baz"].value == "qux" + + res = list(generator_specs_multi_asset()) + assert len(res) == 2 + assert res[0].metadata["foo"] == "bar" + assert res[1].metadata["baz"] == "qux" + + @_external_observable_multi_asset(outs={"one": AssetOut(), "two": AssetOut()}) + def generator_outs_multi_asset(): + yield ObserveResult(asset_key="one", metadata={"foo": "bar"}) + yield ObserveResult(asset_key="two", metadata={"baz": "qux"}) + + res = _exec_asset(generator_outs_multi_asset) + assert len(res) == 2 + assert res[0].metadata["foo"].value == "bar" + assert res[1].metadata["baz"].value == "qux" + + res = list(generator_outs_multi_asset()) + assert len(res) == 2 + assert res[0].metadata["foo"] == "bar" + assert res[1].metadata["baz"] == "qux" + + @_external_observable_multi_asset(specs=[AssetSpec("one"), AssetSpec("two")]) + async def async_specs_multi_asset(): + return ObserveResult(asset_key="one", metadata={"foo": "bar"}), ObserveResult( + asset_key="two", metadata={"baz": "qux"} + ) + + res = _exec_asset(async_specs_multi_asset) + assert len(res) == 2 + assert res[0].metadata["foo"].value == "bar" + assert res[1].metadata["baz"].value == "qux" + + res = asyncio.run(async_specs_multi_asset()) + assert len(res) == 2 + assert res[0].metadata["foo"] == "bar" + assert res[1].metadata["baz"] == "qux" + + @_external_observable_multi_asset(specs=[AssetSpec("one"), AssetSpec("two")]) + async def async_gen_specs_multi_asset(): + yield ObserveResult(asset_key="one", metadata={"foo": "bar"}) + yield ObserveResult(asset_key="two", metadata={"baz": "qux"}) + + res = _exec_asset(async_gen_specs_multi_asset) + assert len(res) == 2 + assert res[0].metadata["foo"].value == "bar" + assert res[1].metadata["baz"].value == "qux" + + async def _run_async_gen(): + results = [] + async for result in async_gen_specs_multi_asset(): + results.append(result) + return results + + res = asyncio.run(_run_async_gen()) + assert len(res) == 2 + assert res[0].metadata["foo"] == "bar" + assert res[1].metadata["baz"] == "qux" + + +def test_observe_result_with_partitions(): + @_external_observable_asset( + partitions_def=StaticPartitionsDefinition(["red", "blue", "yellow"]) + ) + def partitioned_asset(context: AssetExecutionContext) -> ObserveResult: + return ObserveResult(metadata={"key": context.partition_key}) + + mats = _exec_asset(partitioned_asset, partition_key="red") + assert len(mats) == 1, mats + assert mats[0].metadata["key"].text == "red" + + +def test_observe_result_with_partitions_direct_invocation(): + @_external_observable_asset( + partitions_def=StaticPartitionsDefinition(["red", "blue", "yellow"]) + ) + def partitioned_asset(context: AssetExecutionContext) -> ObserveResult: + return ObserveResult(metadata={"key": context.partition_key}) + + context = build_op_context(partition_key="red") + + res = partitioned_asset(context) + assert res.metadata["key"] == "red"