Skip to content

Commit

Permalink
Extract _get_op_def_compute_fn into wrap_source_asset_observe_fn_in_o…
Browse files Browse the repository at this point in the history
…p_compute_fn (Take two) (#16699)

## Summary & Motivation

This is another shot at #16618 which was reverted in #16688 after @johannkm  discovered a bug.

It turns out I had spuriously left a `@staticmethod` decoration on wrap_source_asset_observe_fn_in_op_compute_fn
which worked fine both locally and in CI. This because this only worked in Python 3.10 and later.

https://docs.python.org/3/whatsnew/3.10.html#other-language-changes

```
Static methods (@staticmethod) and class methods (@classmethod) now inherit the method attributes 
(__module__, __name__, __qualname__, __doc__, __annotations__) and have a new __wrapped__ 
attribute. Moreover, static methods are now callable as regular functions.
```

We only run 3.10 in most CI now for cost control. However, sometimes we pay the iron price for this optimization.

## How I Tested These Changes

Load original PR locally in python 3.9. Confirm error on original PR. Apply this patch. See no error.
  • Loading branch information
schrockn authored Sep 21, 2023
1 parent f670c9f commit 0eb0443
Showing 1 changed file with 75 additions and 61 deletions.
136 changes: 75 additions & 61 deletions python_modules/dagster/dagster/_core/definitions/source_asset.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import (
TYPE_CHECKING,
AbstractSet,
Any,
Callable,
Expand Down Expand Up @@ -46,6 +47,11 @@
DagsterInvalidInvocationError,
DagsterInvalidObservationError,
)

if TYPE_CHECKING:
from dagster._core.definitions.decorators.op_decorator import (
DecoratedOpFunction,
)
from dagster._core.storage.io_manager import IOManagerDefinition
from dagster._utils.merger import merge_dicts
from dagster._utils.warnings import disable_dagster_warnings
Expand All @@ -54,6 +60,74 @@
SourceAssetObserveFunction: TypeAlias = Callable[..., Any]


def wrap_source_asset_observe_fn_in_op_compute_fn(
source_asset: "SourceAsset",
) -> "DecoratedOpFunction":
from dagster._core.definitions.decorators.op_decorator import (
DecoratedOpFunction,
is_context_provided,
)
from dagster._core.execution.context.compute import (
OpExecutionContext,
)

check.not_none(source_asset.observe_fn, "Must be an observable source asset")
assert source_asset.observe_fn # for type checker

observe_fn = source_asset.observe_fn

observe_fn_has_context = is_context_provided(get_function_params(observe_fn))

def fn(context: OpExecutionContext):
resource_kwarg_keys = [param.name for param in get_resource_args(observe_fn)]
resource_kwargs = {key: getattr(context.resources, key) for key in resource_kwarg_keys}
observe_fn_return_value = (
observe_fn(context, **resource_kwargs)
if observe_fn_has_context
else observe_fn(**resource_kwargs)
)

if isinstance(observe_fn_return_value, DataVersion):
if source_asset.partitions_def is not None:
raise DagsterInvalidObservationError(
f"{source_asset.key} is partitioned, so its observe function should return a"
" DataVersionsByPartition, not a DataVersion"
)

context.log_event(
AssetObservation(
asset_key=source_asset.key,
tags={DATA_VERSION_TAG: observe_fn_return_value.value},
)
)
elif isinstance(observe_fn_return_value, DataVersionsByPartition):
if source_asset.partitions_def is None:
raise DagsterInvalidObservationError(
f"{source_asset.key} is not partitioned, so its observe function should return"
" a DataVersion, not a DataVersionsByPartition"
)

for (
partition_key,
data_version,
) in observe_fn_return_value.data_versions_by_partition.items():
context.log_event(
AssetObservation(
asset_key=source_asset.key,
tags={DATA_VERSION_TAG: data_version.value},
partition=partition_key,
)
)
else:
raise DagsterInvalidObservationError(
f"Observe function for {source_asset.key} must return a DataVersion or"
" DataVersionsByPartition, but returned a value of type"
f" {type(observe_fn_return_value)}"
)

return DecoratedOpFunction(fn)


@experimental_param(param="resource_defs")
@experimental_param(param="io_manager_def")
class SourceAsset(ResourceAddable):
Expand Down Expand Up @@ -180,66 +254,6 @@ def is_observable(self) -> bool:
"""bool: Whether the asset is observable."""
return self.node_def is not None

def _get_op_def_compute_fn(self, observe_fn: SourceAssetObserveFunction):
from dagster._core.definitions.decorators.op_decorator import (
DecoratedOpFunction,
is_context_provided,
)
from dagster._core.execution.context.compute import (
OpExecutionContext,
)

observe_fn_has_context = is_context_provided(get_function_params(observe_fn))

def fn(context: OpExecutionContext):
resource_kwarg_keys = [param.name for param in get_resource_args(observe_fn)]
resource_kwargs = {key: getattr(context.resources, key) for key in resource_kwarg_keys}
observe_fn_return_value = (
observe_fn(context, **resource_kwargs)
if observe_fn_has_context
else observe_fn(**resource_kwargs)
)

if isinstance(observe_fn_return_value, DataVersion):
if self.partitions_def is not None:
raise DagsterInvalidObservationError(
f"{self.key} is partitioned, so its observe function should return a"
" DataVersionsByPartition, not a DataVersion"
)

context.log_event(
AssetObservation(
asset_key=self.key,
tags={DATA_VERSION_TAG: observe_fn_return_value.value},
)
)
elif isinstance(observe_fn_return_value, DataVersionsByPartition):
if self.partitions_def is None:
raise DagsterInvalidObservationError(
f"{self.key} is not partitioned, so its observe function should return a"
" DataVersion, not a DataVersionsByPartition"
)

for (
partition_key,
data_version,
) in observe_fn_return_value.data_versions_by_partition.items():
context.log_event(
AssetObservation(
asset_key=self.key,
tags={DATA_VERSION_TAG: data_version.value},
partition=partition_key,
)
)
else:
raise DagsterInvalidObservationError(
f"Observe function for {self.key} must return a DataVersion or"
" DataVersionsByPartition, but returned a value of type"
f" {type(observe_fn_return_value)}"
)

return DecoratedOpFunction(fn)

@property
def required_resource_keys(self) -> AbstractSet[str]:
return {requirement.key for requirement in self.get_resource_requirements()}
Expand All @@ -252,7 +266,7 @@ def node_def(self) -> Optional[OpDefinition]:

if self._node_def is None:
self._node_def = OpDefinition(
compute_fn=self._get_op_def_compute_fn(self.observe_fn),
compute_fn=wrap_source_asset_observe_fn_in_op_compute_fn(self),
name=self.key.to_python_identifier(),
description=self.description,
required_resource_keys=self._required_resource_keys,
Expand Down

0 comments on commit 0eb0443

Please sign in to comment.