Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[wip] add AssetExecutionContext to direct invocation #18044

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 22 additions & 21 deletions python_modules/dagster/dagster/_core/definitions/op_invocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@
from .result import MaterializeResult

if TYPE_CHECKING:
from ..execution.context.invocation import RunlessOpExecutionContext
from ..execution.context.invocation import (
BaseDirectInvocationContext,
)
from .assets import AssetsDefinition
from .composition import PendingNodeInvocation
from .decorators.op_decorator import DecoratedOpFunction
Expand Down Expand Up @@ -109,7 +111,7 @@ def direct_invocation_result(
) -> Any:
from dagster._config.pythonic_config import Config
from dagster._core.execution.context.invocation import (
RunlessOpExecutionContext,
BaseDirectInvocationContext,
build_op_context,
)

Expand Down Expand Up @@ -149,12 +151,12 @@ def direct_invocation_result(
" no context was provided when invoking."
)
if len(args) > 0:
if args[0] is not None and not isinstance(args[0], RunlessOpExecutionContext):
if args[0] is not None and not isinstance(args[0], BaseDirectInvocationContext):
raise DagsterInvalidInvocationError(
f"Decorated function '{compute_fn.name}' has context argument, "
"but no context was provided when invoking."
)
context = cast(RunlessOpExecutionContext, args[0])
context = args[0]
# update args to omit context
args = args[1:]
else: # context argument is provided under kwargs
Expand All @@ -165,14 +167,14 @@ def direct_invocation_result(
f"'{context_param_name}', but no value for '{context_param_name}' was "
f"found when invoking. Provided kwargs: {kwargs}"
)
context = cast(RunlessOpExecutionContext, kwargs[context_param_name])
context = kwargs[context_param_name]
# update kwargs to remove context
kwargs = {
kwarg: val for kwarg, val in kwargs.items() if not kwarg == context_param_name
}
# allow passing context, even if the function doesn't have an arg for it
elif len(args) > 0 and isinstance(args[0], RunlessOpExecutionContext):
context = cast(RunlessOpExecutionContext, args[0])
elif len(args) > 0 and isinstance(args[0], BaseDirectInvocationContext):
context = args[0]
args = args[1:]

resource_arg_mapping = {arg.name: arg.name for arg in compute_fn.get_resource_args()}
Expand Down Expand Up @@ -230,7 +232,7 @@ def direct_invocation_result(


def _resolve_inputs(
op_def: "OpDefinition", args, kwargs, context: "RunlessOpExecutionContext"
op_def: "OpDefinition", args, kwargs, context: "BaseDirectInvocationContext"
) -> Mapping[str, Any]:
from dagster._core.execution.plan.execute_step import do_type_check

Expand Down Expand Up @@ -268,9 +270,8 @@ def _resolve_inputs(
"but no context parameter was defined for the op."
)

node_label = op_def.node_type_str
raise DagsterInvalidInvocationError(
f"Too many input arguments were provided for {node_label} '{context.bound_properties.alias}'."
f"Too many input arguments were provided for {context.bound_properties.step_description}'."
f" {suggestion}"
)

Expand Down Expand Up @@ -313,7 +314,7 @@ def _resolve_inputs(
input_dict[k] = v

# Type check inputs
op_label = context.bound_properties.step_description
step_label = context.bound_properties.step_description

for input_name, val in input_dict.items():
input_def = input_defs_by_name[input_name]
Expand All @@ -322,7 +323,7 @@ def _resolve_inputs(
if not type_check.success:
raise DagsterTypeCheckDidNotPass(
description=(
f'Type check failed for {op_label} input "{input_def.name}" - '
f'Type check failed for {step_label} input "{input_def.name}" - '
f'expected type "{dagster_type.display_name}". '
f"Description: {type_check.description}"
),
Expand Down Expand Up @@ -352,7 +353,7 @@ def _key_for_result(result: MaterializeResult, context: "RunlessOpExecutionConte

def _output_name_for_result_obj(
event: MaterializeResult,
context: "RunlessOpExecutionContext",
context: "BaseDirectInvocationContext",
):
if not context.bound_properties.assets_def:
raise DagsterInvariantViolationError(
Expand All @@ -365,7 +366,7 @@ def _output_name_for_result_obj(
def _handle_gen_event(
event: T,
op_def: "OpDefinition",
context: "RunlessOpExecutionContext",
context: "BaseDirectInvocationContext",
output_defs: Mapping[str, OutputDefinition],
outputs_seen: Set[str],
) -> T:
Expand All @@ -391,15 +392,15 @@ def _handle_gen_event(
output_def, DynamicOutputDefinition
):
raise DagsterInvariantViolationError(
f"Invocation of {op_def.node_type_str} '{context.bound_properties.alias}' yielded"
f"Invocation of {context.bound_properties.step_description} yielded"
f" an output '{output_def.name}' multiple times."
)
outputs_seen.add(output_def.name)
return event


def _type_check_output_wrapper(
op_def: "OpDefinition", result: Any, context: "RunlessOpExecutionContext"
op_def: "OpDefinition", result: Any, context: "BaseDirectInvocationContext"
) -> Any:
"""Type checks and returns the result of a op.

Expand Down Expand Up @@ -493,7 +494,7 @@ def type_check_gen(gen):


def _type_check_function_output(
op_def: "OpDefinition", result: T, context: "RunlessOpExecutionContext"
op_def: "OpDefinition", result: T, context: "BaseDirectInvocationContext"
) -> T:
from ..execution.plan.compute_generator import validate_and_coerce_op_result_to_iterator

Expand All @@ -512,25 +513,25 @@ def _type_check_function_output(
def _type_check_output(
output_def: "OutputDefinition",
output: Union[Output, DynamicOutput],
context: "RunlessOpExecutionContext",
context: "BaseDirectInvocationContext",
) -> None:
"""Validates and performs core type check on a provided output.

Args:
output_def (OutputDefinition): The output definition to validate against.
output (Any): The output to validate.
context (RunlessOpExecutionContext): Context containing resources to be used for type
context (BaseDirectInvocationContext): Context containing resources to be used for type
check.
"""
from ..execution.plan.execute_step import do_type_check

op_label = context.bound_properties.step_description
step_label = context.bound_properties.step_description
dagster_type = output_def.dagster_type
type_check = do_type_check(context.for_type(dagster_type), dagster_type, output.value)
if not type_check.success:
raise DagsterTypeCheckDidNotPass(
description=(
f'Type check failed for {op_label} output "{output.output_name}" - '
f'Type check failed for {step_label} output "{output.output_name}" - '
f'expected type "{dagster_type.display_name}". '
f"Description: {type_check.description}"
),
Expand Down
Loading