Skip to content

Commit

Permalink
Revert "Extract _get_op_def_compute_fn into wrap_source_asset_observe…
Browse files Browse the repository at this point in the history
…_fn_in_op_compute_fn (#16618)"

This reverts commit 4d3369a.
  • Loading branch information
schrockn committed Sep 21, 2023
1 parent 239c17e commit fa35eb2
Showing 1 changed file with 61 additions and 76 deletions.
137 changes: 61 additions & 76 deletions python_modules/dagster/dagster/_core/definitions/source_asset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from typing import (
TYPE_CHECKING,
AbstractSet,
Any,
Callable,
Expand Down Expand Up @@ -47,11 +46,6 @@
DagsterInvalidInvocationError,
DagsterInvalidObservationError,
)

if TYPE_CHECKING:
from dagster._core.definitions.decorators.op_decorator import (
DecoratedOpFunction,
)
from dagster._core.storage.io_manager import IOManagerDefinition
from dagster._utils.merger import merge_dicts
from dagster._utils.warnings import disable_dagster_warnings
Expand All @@ -60,75 +54,6 @@
SourceAssetObserveFunction: TypeAlias = Callable[..., Any]


@staticmethod
def wrap_source_asset_observe_fn_in_op_compute_fn(
source_asset: "SourceAsset",
) -> "DecoratedOpFunction":
from dagster._core.definitions.decorators.op_decorator import (
DecoratedOpFunction,
is_context_provided,
)
from dagster._core.execution.context.compute import (
OpExecutionContext,
)

check.not_none(source_asset.observe_fn, "Must be an observable source asset")
assert source_asset.observe_fn # for type checker

observe_fn = source_asset.observe_fn

observe_fn_has_context = is_context_provided(get_function_params(observe_fn))

def fn(context: OpExecutionContext):
resource_kwarg_keys = [param.name for param in get_resource_args(observe_fn)]
resource_kwargs = {key: getattr(context.resources, key) for key in resource_kwarg_keys}
observe_fn_return_value = (
observe_fn(context, **resource_kwargs)
if observe_fn_has_context
else observe_fn(**resource_kwargs)
)

if isinstance(observe_fn_return_value, DataVersion):
if source_asset.partitions_def is not None:
raise DagsterInvalidObservationError(
f"{source_asset.key} is partitioned, so its observe function should return a"
" DataVersionsByPartition, not a DataVersion"
)

context.log_event(
AssetObservation(
asset_key=source_asset.key,
tags={DATA_VERSION_TAG: observe_fn_return_value.value},
)
)
elif isinstance(observe_fn_return_value, DataVersionsByPartition):
if source_asset.partitions_def is None:
raise DagsterInvalidObservationError(
f"{source_asset.key} is not partitioned, so its observe function should return"
" a DataVersion, not a DataVersionsByPartition"
)

for (
partition_key,
data_version,
) in observe_fn_return_value.data_versions_by_partition.items():
context.log_event(
AssetObservation(
asset_key=source_asset.key,
tags={DATA_VERSION_TAG: data_version.value},
partition=partition_key,
)
)
else:
raise DagsterInvalidObservationError(
f"Observe function for {source_asset.key} must return a DataVersion or"
" DataVersionsByPartition, but returned a value of type"
f" {type(observe_fn_return_value)}"
)

return DecoratedOpFunction(fn)


@experimental_param(param="resource_defs")
@experimental_param(param="io_manager_def")
class SourceAsset(ResourceAddable):
Expand Down Expand Up @@ -255,6 +180,66 @@ def is_observable(self) -> bool:
"""bool: Whether the asset is observable."""
return self.node_def is not None

def _get_op_def_compute_fn(self, observe_fn: SourceAssetObserveFunction):
from dagster._core.definitions.decorators.op_decorator import (
DecoratedOpFunction,
is_context_provided,
)
from dagster._core.execution.context.compute import (
OpExecutionContext,
)

observe_fn_has_context = is_context_provided(get_function_params(observe_fn))

def fn(context: OpExecutionContext):
resource_kwarg_keys = [param.name for param in get_resource_args(observe_fn)]
resource_kwargs = {key: getattr(context.resources, key) for key in resource_kwarg_keys}
observe_fn_return_value = (
observe_fn(context, **resource_kwargs)
if observe_fn_has_context
else observe_fn(**resource_kwargs)
)

if isinstance(observe_fn_return_value, DataVersion):
if self.partitions_def is not None:
raise DagsterInvalidObservationError(
f"{self.key} is partitioned, so its observe function should return a"
" DataVersionsByPartition, not a DataVersion"
)

context.log_event(
AssetObservation(
asset_key=self.key,
tags={DATA_VERSION_TAG: observe_fn_return_value.value},
)
)
elif isinstance(observe_fn_return_value, DataVersionsByPartition):
if self.partitions_def is None:
raise DagsterInvalidObservationError(
f"{self.key} is not partitioned, so its observe function should return a"
" DataVersion, not a DataVersionsByPartition"
)

for (
partition_key,
data_version,
) in observe_fn_return_value.data_versions_by_partition.items():
context.log_event(
AssetObservation(
asset_key=self.key,
tags={DATA_VERSION_TAG: data_version.value},
partition=partition_key,
)
)
else:
raise DagsterInvalidObservationError(
f"Observe function for {self.key} must return a DataVersion or"
" DataVersionsByPartition, but returned a value of type"
f" {type(observe_fn_return_value)}"
)

return DecoratedOpFunction(fn)

@property
def required_resource_keys(self) -> AbstractSet[str]:
return {requirement.key for requirement in self.get_resource_requirements()}
Expand All @@ -267,7 +252,7 @@ def node_def(self) -> Optional[OpDefinition]:

if self._node_def is None:
self._node_def = OpDefinition(
compute_fn=wrap_source_asset_observe_fn_in_op_compute_fn(self),
compute_fn=self._get_op_def_compute_fn(self.observe_fn),
name=self.key.to_python_identifier(),
description=self.description,
required_resource_keys=self._required_resource_keys,
Expand Down

0 comments on commit fa35eb2

Please sign in to comment.