Skip to content

Commit

Permalink
update test to use new function
Browse files Browse the repository at this point in the history
  • Loading branch information
jamiedemaria committed Sep 19, 2023
1 parent 05aabfd commit 682fb19
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 46 deletions.
46 changes: 33 additions & 13 deletions python_modules/dagster/dagster/_core/definitions/op_invocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
DagsterInvariantViolationError,
DagsterTypeCheckDidNotPass,
)
from dagster._core.execution.context.invocation import UnboundAssetExecutionContext

from .events import (
AssetMaterialization,
Expand All @@ -30,7 +29,7 @@
from .output import DynamicOutputDefinition

if TYPE_CHECKING:
from ..execution.context.invocation import BoundOpExecutionContext
from ..execution.context.invocation import BoundAssetExecutionContext, BoundOpExecutionContext
from .assets import AssetsDefinition
from .composition import PendingNodeInvocation
from .decorators.op_decorator import DecoratedOpFunction
Expand Down Expand Up @@ -108,6 +107,7 @@ def direct_invocation_result(
) -> Any:
from dagster._config.pythonic_config import Config
from dagster._core.execution.context.invocation import (
UnboundAssetExecutionContext,
UnboundOpExecutionContext,
build_op_context,
)
Expand Down Expand Up @@ -155,7 +155,10 @@ def direct_invocation_result(
f"Decorated function '{compute_fn.name}' has context argument, "
"but no context was provided when invoking."
)
context = cast(UnboundOpExecutionContext, args[0])
if isinstance(args[0], UnboundAssetExecutionContext):
context = cast(UnboundAssetExecutionContext, args[0])
else:
context = cast(UnboundOpExecutionContext, args[0])
# update args to omit context
args = args[1:]
else: # context argument is provided under kwargs
Expand All @@ -166,14 +169,22 @@ def direct_invocation_result(
f"'{context_param_name}', but no value for '{context_param_name}' was "
f"found when invoking. Provided kwargs: {kwargs}"
)
context = cast(UnboundOpExecutionContext, kwargs[context_param_name])
if isinstance(kwargs[context_param_name], UnboundAssetExecutionContext):
context = cast(UnboundAssetExecutionContext, kwargs[context_param_name])
else:
context = cast(UnboundOpExecutionContext, kwargs[context_param_name])
# update kwargs to remove context
kwargs = {
kwarg: val for kwarg, val in kwargs.items() if not kwarg == context_param_name
}
# allow passing context, even if the function doesn't have an arg for it
elif len(args) > 0 and isinstance(args[0], UnboundOpExecutionContext):
context = cast(UnboundOpExecutionContext, args[0])
elif len(args) > 0 and isinstance(
args[0], (UnboundOpExecutionContext, UnboundAssetExecutionContext)
):
if isinstance(args[0], UnboundAssetExecutionContext):
context = cast(UnboundAssetExecutionContext, args[0])
else:
context = cast(UnboundOpExecutionContext, args[0])
args = args[1:]

resource_arg_mapping = {arg.name: arg.name for arg in compute_fn.get_resource_args()}
Expand Down Expand Up @@ -224,7 +235,10 @@ def direct_invocation_result(


def _resolve_inputs(
op_def: "OpDefinition", args, kwargs, context: "BoundOpExecutionContext"
op_def: "OpDefinition",
args,
kwargs,
context: Union["BoundOpExecutionContext", "BoundAssetExecutionContext"],
) -> Mapping[str, Any]:
from dagster._core.execution.plan.execute_step import do_type_check

Expand Down Expand Up @@ -307,7 +321,7 @@ def _resolve_inputs(
input_dict[k] = v

# Type check inputs
op_label = context.describe_op()
step_label = context.describe_step()

for input_name, val in input_dict.items():
input_def = input_defs_by_name[input_name]
Expand All @@ -316,7 +330,7 @@ def _resolve_inputs(
if not type_check.success:
raise DagsterTypeCheckDidNotPass(
description=(
f'Type check failed for {op_label} input "{input_def.name}" - '
f'Type check failed for {step_label} input "{input_def.name}" - '
f'expected type "{dagster_type.display_name}". '
f"Description: {type_check.description}"
),
Expand All @@ -328,7 +342,9 @@ def _resolve_inputs(


def _type_check_output_wrapper(
op_def: "OpDefinition", result: Any, context: "BoundOpExecutionContext"
op_def: "OpDefinition",
result: Any,
context: Union["BoundOpExecutionContext", "BoundAssetExecutionContext"],
) -> Any:
"""Type checks and returns the result of a op.
Expand Down Expand Up @@ -436,7 +452,9 @@ def type_check_gen(gen):


def _type_check_function_output(
op_def: "OpDefinition", result: T, context: "BoundOpExecutionContext"
op_def: "OpDefinition",
result: T,
context: Union["BoundOpExecutionContext", "BoundAssetExecutionContext"],
) -> T:
from ..execution.plan.compute_generator import validate_and_coerce_op_result_to_iterator

Expand All @@ -447,7 +465,9 @@ def _type_check_function_output(


def _type_check_output(
output_def: "OutputDefinition", output: T, context: "BoundOpExecutionContext"
output_def: "OutputDefinition",
output: T,
context: Union["BoundOpExecutionContext", "BoundAssetExecutionContext"],
) -> T:
"""Validates and performs core type check on a provided output.
Expand All @@ -459,7 +479,7 @@ def _type_check_output(
"""
from ..execution.plan.execute_step import do_type_check

op_label = context.describe_op()
op_label = context.describe_step()

if isinstance(output, (Output, DynamicOutput)):
dagster_type = output_def.dagster_type
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,9 @@ def get_mapping_key(self) -> Optional[str]:
return self._mapping_key

def describe_op(self) -> str:
return self.describe_step()

def describe_step(self) -> str:
if isinstance(self.op_def, OpDefinition):
return f'op "{self.op_def.name}"'

Expand Down Expand Up @@ -797,7 +800,6 @@ def __init__(
instance: Optional[DagsterInstance],
partition_key: Optional[str],
partition_key_range: Optional[PartitionKeyRange],
mapping_key: Optional[str],
assets_def: Optional[AssetsDefinition],
):
self._op_execution_context = build_op_context(
Expand All @@ -807,7 +809,7 @@ def __init__(
instance=instance,
partition_key_range=partition_key_range,
partition_key=partition_key,
mapping_key=mapping_key,
mapping_key=None,
_assets_def=assets_def,
)

Expand Down Expand Up @@ -865,7 +867,7 @@ def test_my_op():
expectation_results = [event for event in all_user_events if isinstance(event, ExpectationResult)]
...
"""
return self._op_execution_context._user_events
return self._op_execution_context._user_events # noqa: SLF001

def get_output_metadata(
self, output_name: str, mapping_key: Optional[str] = None
Expand All @@ -881,13 +883,13 @@ def get_output_metadata(
Returns:
Optional[Mapping[str, Any]]: The metadata values present for the output_name/mapping_key combination, if present.
"""
metadata = self._op_execution_context._output_metadata.get(output_name)
metadata = self._op_execution_context._output_metadata.get(output_name) # noqa: SLF001
if mapping_key and metadata:
return metadata.get(mapping_key)
return metadata

def get_mapping_key(self) -> Optional[str]:
return self._op_execution_context._mapping_key
return self._op_execution_context._mapping_key # noqa: SLF001


class BoundAssetExecutionContext(AssetExecutionContext):
Expand All @@ -902,13 +904,16 @@ def __init__(self, bound_op_execution_context: BoundOpExecutionContext):

@property
def alias(self) -> str:
return self._op_execution_context._alias
return self._op_execution_context._alias # noqa: SLF001

def describe_step(self) -> str:
return f" asset '{self._op_execution_context.op_def.name}'"

def for_type(self, dagster_type: DagsterType) -> TypeCheckContext:
return self._op_execution_context.for_type(dagster_type=dagster_type)

def get_mapping_key(self) -> Optional[str]:
return self._op_execution_context._mapping_key
return self._op_execution_context._mapping_key # noqa: SLF001

def observe_output(self, output_name: str, mapping_key: Optional[str] = None) -> None:
self._op_execution_context.observe_output(output_name=output_name, mapping_key=mapping_key)
Expand All @@ -928,7 +933,6 @@ def build_asset_context(
partition_key_range: Optional[PartitionKeyRange] = None,
# TODO - the below params were not originally params for this function, but are used by `build_op_context`
# figure out what they are for anf if we need them
mapping_key: Optional[str] = None,
_assets_def: Optional[AssetsDefinition] = None,
):
"""Builds asset execution context from provided parameters.
Expand Down Expand Up @@ -968,6 +972,5 @@ def build_asset_context(
partition_key_range, "partition_key_range", PartitionKeyRange
),
instance=check.opt_inst_param(instance, "instance", DagsterInstance),
mapping_key=check.opt_str_param(mapping_key, "mapping_key"),
assets_def=check.opt_inst_param(_assets_def, "_assets_def", AssetsDefinition),
)
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from dagster._utils import is_named_tuple_instance
from dagster._utils.warnings import disable_dagster_warnings

from ..context.compute import OpExecutionContext
from ..context.compute import AssetExecutionContext, OpExecutionContext


class NoAnnotationSentinel:
Expand Down Expand Up @@ -242,15 +242,21 @@ def _check_output_object_name(


def validate_and_coerce_op_result_to_iterator(
result: Any, context: OpExecutionContext, output_defs: Sequence[OutputDefinition]
result: Any,
context: Union[OpExecutionContext, AssetExecutionContext],
output_defs: Sequence[OutputDefinition],
) -> Iterator[Any]:
if isinstance(context, AssetExecutionContext):
step_description = f" asset '{context._op_execution_context.op_def.name}'" # noqa: SLF001
else:
step_description = context.describe_op()
if inspect.isgenerator(result):
# this happens when a user explicitly returns a generator in the op
for event in result:
yield event
elif isinstance(result, (AssetMaterialization, ExpectationResult)):
raise DagsterInvariantViolationError(
f"Error in {context.describe_op()}: If you are "
f"Error in {step_description}: If you are "
"returning an AssetMaterialization "
"or an ExpectationResult from "
f"{context.op_def.node_type_str} you must yield them "
Expand All @@ -263,7 +269,7 @@ def validate_and_coerce_op_result_to_iterator(
yield result
elif result is not None and not output_defs:
raise DagsterInvariantViolationError(
f"Error in {context.describe_op()}: Unexpectedly returned output of type"
f"Error in {step_description}: Unexpectedly returned output of type"
f" {type(result)}. {context.op_def.node_type_str.capitalize()} is explicitly defined to"
" return no results."
)
Expand All @@ -275,15 +281,15 @@ def validate_and_coerce_op_result_to_iterator(
if output_def.is_dynamic:
if not isinstance(element, list):
raise DagsterInvariantViolationError(
f"Error with output for {context.describe_op()}: "
f"Error with output for {step_description}: "
f"dynamic output '{output_def.name}' expected a list of "
"DynamicOutput objects, but instead received instead an "
f"object of type {type(element)}."
)
for item in element:
if not isinstance(item, DynamicOutput):
raise DagsterInvariantViolationError(
f"Error with output for {context.describe_op()}: "
f"Error with output for {step_description}: "
f"dynamic output '{output_def.name}' at position {position} expected a "
"list of DynamicOutput objects, but received an "
f"item with type {type(item)}."
Expand All @@ -305,7 +311,7 @@ def validate_and_coerce_op_result_to_iterator(
annotation
):
raise DagsterInvariantViolationError(
f"Error with output for {context.describe_op()}: received Output object for"
f"Error with output for {step_description}: received Output object for"
f" output '{output_def.name}' which does not have an Output annotation."
f" Annotation has type {annotation}."
)
Expand All @@ -323,15 +329,15 @@ def validate_and_coerce_op_result_to_iterator(
# output object was not received, throw an error.
if is_generic_output_annotation(annotation):
raise DagsterInvariantViolationError(
f"Error with output for {context.describe_op()}: output "
f"Error with output for {step_description}: output "
f"'{output_def.name}' has generic output annotation, "
"but did not receive an Output object for this output. "
f"Received instead an object of type {type(element)}."
)
if result is None and output_def.is_required is False:
context.log.warning(
'Value "None" returned for non-required output '
f'"{output_def.name}" of {context.describe_op()}. '
f'"{output_def.name}" of {step_description}. '
"This value will be passed to downstream "
f"{context.op_def.node_type_str}s. For conditional "
"execution, results must be yielded: "
Expand Down
Loading

0 comments on commit 682fb19

Please sign in to comment.