Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extract _get_op_def_compute_fn into wrap_source_asset_observe_fn_in_op_compute_fn #16618

Merged
merged 2 commits into from
Sep 21, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 76 additions & 61 deletions python_modules/dagster/dagster/_core/definitions/source_asset.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import (
TYPE_CHECKING,
AbstractSet,
Any,
Callable,
Expand Down Expand Up @@ -46,6 +47,11 @@
DagsterInvalidInvocationError,
DagsterInvalidObservationError,
)

if TYPE_CHECKING:
from dagster._core.definitions.decorators.op_decorator import (
DecoratedOpFunction,
)
from dagster._core.storage.io_manager import IOManagerDefinition
from dagster._utils.merger import merge_dicts
from dagster._utils.warnings import disable_dagster_warnings
Expand All @@ -54,6 +60,75 @@
SourceAssetObserveFunction: TypeAlias = Callable[..., Any]


@staticmethod
def wrap_source_asset_observe_fn_in_op_compute_fn(
source_asset: "SourceAsset",
) -> "DecoratedOpFunction":
from dagster._core.definitions.decorators.op_decorator import (
DecoratedOpFunction,
is_context_provided,
)
from dagster._core.execution.context.compute import (
OpExecutionContext,
)

check.not_none(source_asset.observe_fn, "Must be an observable source asset")
assert source_asset.observe_fn # for type checker

observe_fn = source_asset.observe_fn

observe_fn_has_context = is_context_provided(get_function_params(observe_fn))

def fn(context: OpExecutionContext):
resource_kwarg_keys = [param.name for param in get_resource_args(observe_fn)]
resource_kwargs = {key: getattr(context.resources, key) for key in resource_kwarg_keys}
observe_fn_return_value = (
observe_fn(context, **resource_kwargs)
if observe_fn_has_context
else observe_fn(**resource_kwargs)
)

if isinstance(observe_fn_return_value, DataVersion):
if source_asset.partitions_def is not None:
raise DagsterInvalidObservationError(
f"{source_asset.key} is partitioned, so its observe function should return a"
" DataVersionsByPartition, not a DataVersion"
)

context.log_event(
AssetObservation(
asset_key=source_asset.key,
tags={DATA_VERSION_TAG: observe_fn_return_value.value},
)
)
elif isinstance(observe_fn_return_value, DataVersionsByPartition):
if source_asset.partitions_def is None:
raise DagsterInvalidObservationError(
f"{source_asset.key} is not partitioned, so its observe function should return"
" a DataVersion, not a DataVersionsByPartition"
)

for (
partition_key,
data_version,
) in observe_fn_return_value.data_versions_by_partition.items():
context.log_event(
AssetObservation(
asset_key=source_asset.key,
tags={DATA_VERSION_TAG: data_version.value},
partition=partition_key,
)
)
else:
raise DagsterInvalidObservationError(
f"Observe function for {source_asset.key} must return a DataVersion or"
" DataVersionsByPartition, but returned a value of type"
f" {type(observe_fn_return_value)}"
)

return DecoratedOpFunction(fn)


@experimental_param(param="resource_defs")
@experimental_param(param="io_manager_def")
class SourceAsset(ResourceAddable):
Expand Down Expand Up @@ -180,66 +255,6 @@ def is_observable(self) -> bool:
"""bool: Whether the asset is observable."""
return self.node_def is not None

def _get_op_def_compute_fn(self, observe_fn: SourceAssetObserveFunction):
from dagster._core.definitions.decorators.op_decorator import (
DecoratedOpFunction,
is_context_provided,
)
from dagster._core.execution.context.compute import (
OpExecutionContext,
)

observe_fn_has_context = is_context_provided(get_function_params(observe_fn))

def fn(context: OpExecutionContext):
resource_kwarg_keys = [param.name for param in get_resource_args(observe_fn)]
resource_kwargs = {key: getattr(context.resources, key) for key in resource_kwarg_keys}
observe_fn_return_value = (
observe_fn(context, **resource_kwargs)
if observe_fn_has_context
else observe_fn(**resource_kwargs)
)

if isinstance(observe_fn_return_value, DataVersion):
if self.partitions_def is not None:
raise DagsterInvalidObservationError(
f"{self.key} is partitioned, so its observe function should return a"
" DataVersionsByPartition, not a DataVersion"
)

context.log_event(
AssetObservation(
asset_key=self.key,
tags={DATA_VERSION_TAG: observe_fn_return_value.value},
)
)
elif isinstance(observe_fn_return_value, DataVersionsByPartition):
if self.partitions_def is None:
raise DagsterInvalidObservationError(
f"{self.key} is not partitioned, so its observe function should return a"
" DataVersion, not a DataVersionsByPartition"
)

for (
partition_key,
data_version,
) in observe_fn_return_value.data_versions_by_partition.items():
context.log_event(
AssetObservation(
asset_key=self.key,
tags={DATA_VERSION_TAG: data_version.value},
partition=partition_key,
)
)
else:
raise DagsterInvalidObservationError(
f"Observe function for {self.key} must return a DataVersion or"
" DataVersionsByPartition, but returned a value of type"
f" {type(observe_fn_return_value)}"
)

return DecoratedOpFunction(fn)

@property
def required_resource_keys(self) -> AbstractSet[str]:
return {requirement.key for requirement in self.get_resource_requirements()}
Expand All @@ -252,7 +267,7 @@ def node_def(self) -> Optional[OpDefinition]:

if self._node_def is None:
self._node_def = OpDefinition(
compute_fn=self._get_op_def_compute_fn(self.observe_fn),
compute_fn=wrap_source_asset_observe_fn_in_op_compute_fn(self),
name=self.key.to_python_identifier(),
description=self.description,
required_resource_keys=self._required_resource_keys,
Expand Down