From 9b3718314a9b0345baa667518a917afaacf786f6 Mon Sep 17 00:00:00 2001 From: Nicholas Schrock Date: Tue, 19 Sep 2023 06:23:49 -0400 Subject: [PATCH] Extract _get_op_def_compute_fn into wrap_source_asset_observe_fn_in_op_compute_fn This refactoring will be useful in a subsequent PR --- .../dagster/_core/definitions/source_asset.py | 138 ++++++++++-------- 1 file changed, 76 insertions(+), 62 deletions(-) diff --git a/python_modules/dagster/dagster/_core/definitions/source_asset.py b/python_modules/dagster/dagster/_core/definitions/source_asset.py index 33a06525883fc..6ebf646f48766 100644 --- a/python_modules/dagster/dagster/_core/definitions/source_asset.py +++ b/python_modules/dagster/dagster/_core/definitions/source_asset.py @@ -9,7 +9,7 @@ cast, ) -from typing_extensions import TypeAlias +from typing_extensions import TYPE_CHECKING, TypeAlias import dagster._check as check from dagster._annotations import PublicAttr, experimental_param, public @@ -50,10 +50,84 @@ from dagster._utils.merger import merge_dicts from dagster._utils.warnings import disable_dagster_warnings +if TYPE_CHECKING: + from dagster._core.definitions.decorators.op_decorator import ( + DecoratedOpFunction, + ) + # Going with this catch-all for the time-being to permit pythonic resources SourceAssetObserveFunction: TypeAlias = Callable[..., Any] +@staticmethod +def wrap_source_asset_observe_fn_in_op_compute_fn( + source_asset: "SourceAsset", +) -> "DecoratedOpFunction": + from dagster._core.definitions.decorators.op_decorator import ( + DecoratedOpFunction, + is_context_provided, + ) + from dagster._core.execution.context.compute import ( + OpExecutionContext, + ) + + check.not_none(source_asset.observe_fn, "Must be an observable source asset") + assert source_asset.observe_fn # for type checker + + observe_fn = source_asset.observe_fn + + observe_fn_has_context = is_context_provided(get_function_params(observe_fn)) + + def fn(context: OpExecutionContext): + resource_kwarg_keys = [param.name for param in get_resource_args(observe_fn)] + resource_kwargs = {key: getattr(context.resources, key) for key in resource_kwarg_keys} + observe_fn_return_value = ( + observe_fn(context, **resource_kwargs) + if observe_fn_has_context + else observe_fn(**resource_kwargs) + ) + + if isinstance(observe_fn_return_value, DataVersion): + if source_asset.partitions_def is not None: + raise DagsterInvalidObservationError( + f"{source_asset.key} is partitioned, so its observe function should return a" + " DataVersionsByPartition, not a DataVersion" + ) + + context.log_event( + AssetObservation( + asset_key=source_asset.key, + tags={DATA_VERSION_TAG: observe_fn_return_value.value}, + ) + ) + elif isinstance(observe_fn_return_value, DataVersionsByPartition): + if source_asset.partitions_def is None: + raise DagsterInvalidObservationError( + f"{source_asset.key} is not partitioned, so its observe function should return" + " a DataVersion, not a DataVersionsByPartition" + ) + + for ( + partition_key, + data_version, + ) in observe_fn_return_value.data_versions_by_partition.items(): + context.log_event( + AssetObservation( + asset_key=source_asset.key, + tags={DATA_VERSION_TAG: data_version.value}, + partition=partition_key, + ) + ) + else: + raise DagsterInvalidObservationError( + f"Observe function for {source_asset.key} must return a DataVersion or" + " DataVersionsByPartition, but returned a value of type" + f" {type(observe_fn_return_value)}" + ) + + return DecoratedOpFunction(fn) + + @experimental_param(param="resource_defs") @experimental_param(param="io_manager_def") class SourceAsset(ResourceAddable): @@ -180,66 +254,6 @@ def is_observable(self) -> bool: """bool: Whether the asset is observable.""" return self.node_def is not None - def _get_op_def_compute_fn(self, observe_fn: SourceAssetObserveFunction): - from dagster._core.definitions.decorators.op_decorator import ( - DecoratedOpFunction, - is_context_provided, - ) - from dagster._core.execution.context.compute import ( - OpExecutionContext, - ) - - observe_fn_has_context = is_context_provided(get_function_params(observe_fn)) - - def fn(context: OpExecutionContext): - resource_kwarg_keys = [param.name for param in get_resource_args(observe_fn)] - resource_kwargs = {key: getattr(context.resources, key) for key in resource_kwarg_keys} - observe_fn_return_value = ( - observe_fn(context, **resource_kwargs) - if observe_fn_has_context - else observe_fn(**resource_kwargs) - ) - - if isinstance(observe_fn_return_value, DataVersion): - if self.partitions_def is not None: - raise DagsterInvalidObservationError( - f"{self.key} is partitioned, so its observe function should return a" - " DataVersionsByPartition, not a DataVersion" - ) - - context.log_event( - AssetObservation( - asset_key=self.key, - tags={DATA_VERSION_TAG: observe_fn_return_value.value}, - ) - ) - elif isinstance(observe_fn_return_value, DataVersionsByPartition): - if self.partitions_def is None: - raise DagsterInvalidObservationError( - f"{self.key} is not partitioned, so its observe function should return a" - " DataVersion, not a DataVersionsByPartition" - ) - - for ( - partition_key, - data_version, - ) in observe_fn_return_value.data_versions_by_partition.items(): - context.log_event( - AssetObservation( - asset_key=self.key, - tags={DATA_VERSION_TAG: data_version.value}, - partition=partition_key, - ) - ) - else: - raise DagsterInvalidObservationError( - f"Observe function for {self.key} must return a DataVersion or" - " DataVersionsByPartition, but returned a value of type" - f" {type(observe_fn_return_value)}" - ) - - return DecoratedOpFunction(fn) - @property def required_resource_keys(self) -> AbstractSet[str]: return {requirement.key for requirement in self.get_resource_requirements()} @@ -252,7 +266,7 @@ def node_def(self) -> Optional[OpDefinition]: if self._node_def is None: self._node_def = OpDefinition( - compute_fn=self._get_op_def_compute_fn(self.observe_fn), + compute_fn=wrap_source_asset_observe_fn_in_op_compute_fn(self), name=self.key.to_python_identifier(), description=self.description, required_resource_keys=self._required_resource_keys,