Skip to content

Commit

Permalink
update interfaces
Browse files Browse the repository at this point in the history
  • Loading branch information
jamiedemaria committed Nov 20, 2023
1 parent aeddec6 commit 348a8c7
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 27 deletions.
45 changes: 23 additions & 22 deletions python_modules/dagster/dagster/_core/definitions/op_invocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

if TYPE_CHECKING:
from ..execution.context.invocation import (
DirectInvocationOpExecutionContext,
BaseDirectInvocationContext,
)
from .assets import AssetsDefinition
from .composition import PendingNodeInvocation
Expand Down Expand Up @@ -225,7 +225,7 @@ def direct_invocation_result(


def _resolve_inputs(
op_def: "OpDefinition", args, kwargs, context: "DirectInvocationOpExecutionContext"
op_def: "OpDefinition", args, kwargs, context: "BaseDirectInvocationContext"
) -> Mapping[str, Any]:
from dagster._core.execution.plan.execute_step import do_type_check

Expand Down Expand Up @@ -263,9 +263,8 @@ def _resolve_inputs(
"but no context parameter was defined for the op."
)

node_label = op_def.node_type_str
raise DagsterInvalidInvocationError(
f"Too many input arguments were provided for {node_label} '{context.alias}'."
f"Too many input arguments were provided for {context.execution_properties.step_description}'."
f" {suggestion}"
)

Expand Down Expand Up @@ -308,7 +307,7 @@ def _resolve_inputs(
input_dict[k] = v

# Type check inputs
op_label = context.describe_op()
step_label = context.execution_properties.step_description

for input_name, val in input_dict.items():
input_def = input_defs_by_name[input_name]
Expand All @@ -317,7 +316,7 @@ def _resolve_inputs(
if not type_check.success:
raise DagsterTypeCheckDidNotPass(
description=(
f'Type check failed for {op_label} input "{input_def.name}" - '
f'Type check failed for {step_label} input "{input_def.name}" - '
f'expected type "{dagster_type.display_name}". '
f"Description: {type_check.description}"
),
Expand All @@ -328,33 +327,35 @@ def _resolve_inputs(
return input_dict


def _key_for_result(
result: MaterializeResult, context: "DirectInvocationOpExecutionContext"
) -> AssetKey:
def _key_for_result(result: MaterializeResult, context: "BaseDirectInvocationContext") -> AssetKey:
if result.asset_key:
return result.asset_key

if len(context.assets_def.keys) == 1:
return next(iter(context.assets_def.keys))
if len(context.execution_properties.op_execution_context.assets_def.keys) == 1:
return next(iter(context.execution_properties.op_execution_context.assets_def.keys))

raise DagsterInvariantViolationError(
"MaterializeResult did not include asset_key and it can not be inferred. Specify which"
f" asset_key, options are: {context.assets_def.keys}"
f" asset_key, options are: {context.execution_properties.op_execution_context.assets_def.keys}"
)


def _output_name_for_result_obj(
event: MaterializeResult,
context: "DirectInvocationOpExecutionContext",
context: "BaseDirectInvocationContext",
):
asset_key = _key_for_result(event, context)
return context.assets_def.get_output_name_for_asset_key(asset_key)
return (
context.execution_properties.op_execution_context.assets_def.get_output_name_for_asset_key(
asset_key
)
)


def _handle_gen_event(
event: T,
op_def: "OpDefinition",
context: "DirectInvocationOpExecutionContext",
context: "BaseDirectInvocationContext",
output_defs: Mapping[str, OutputDefinition],
outputs_seen: Set[str],
) -> T:
Expand All @@ -380,15 +381,15 @@ def _handle_gen_event(
output_def, DynamicOutputDefinition
):
raise DagsterInvariantViolationError(
f"Invocation of {op_def.node_type_str} '{context.alias}' yielded"
f"Invocation of {context.execution_properties.step_description} yielded"
f" an output '{output_def.name}' multiple times."
)
outputs_seen.add(output_def.name)
return event


def _type_check_output_wrapper(
op_def: "OpDefinition", result: Any, context: "DirectInvocationOpExecutionContext"
op_def: "OpDefinition", result: Any, context: "BaseDirectInvocationContext"
) -> Any:
"""Type checks and returns the result of a op.
Expand Down Expand Up @@ -462,7 +463,7 @@ def type_check_gen(gen):


def _type_check_function_output(
op_def: "OpDefinition", result: T, context: "DirectInvocationOpExecutionContext"
op_def: "OpDefinition", result: T, context: "BaseDirectInvocationContext"
) -> T:
from ..execution.plan.compute_generator import validate_and_coerce_op_result_to_iterator

Expand All @@ -480,25 +481,25 @@ def _type_check_function_output(
def _type_check_output(
output_def: "OutputDefinition",
output: Union[Output, DynamicOutput],
context: "DirectInvocationOpExecutionContext",
context: "BaseDirectInvocationContext",
) -> None:
"""Validates and performs core type check on a provided output.
Args:
output_def (OutputDefinition): The output definition to validate against.
output (Any): The output to validate.
context (DirectInvocationOpExecutionContext): Context containing resources to be used for type
context (BaseDirectInvocationContext): Context containing resources to be used for type
check.
"""
from ..execution.plan.execute_step import do_type_check

op_label = context.describe_op()
step_label = context.execution_properties.step_description
dagster_type = output_def.dagster_type
type_check = do_type_check(context.for_type(dagster_type), dagster_type, output.value)
if not type_check.success:
raise DagsterTypeCheckDidNotPass(
description=(
f'Type check failed for {op_label} output "{output.output_name}" - '
f'Type check failed for {step_label} output "{output.output_name}" - '
f'expected type "{dagster_type.display_name}". '
f"Description: {type_check.description}"
),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import warnings
from abc import abstractmethod
from contextlib import ExitStack
from typing import (
AbstractSet,
Expand Down Expand Up @@ -55,15 +56,22 @@
from dagster._utils.forked_pdb import ForkedPdb
from dagster._utils.merger import merge_dicts

from .compute import AssetExecutionContext, ExecutionProperties, OpExecutionContext, RunProperties
from .compute import (
AssetExecutionContext,
ContextHasExecutionProperties,
ExecutionProperties,
OpExecutionContext,
RunProperties,
)
from .system import StepExecutionContext, TypeCheckContext


def _property_msg(prop_name: str, method_name: str, step_type: str) -> str:
return f"The {prop_name} {method_name} is not set on the context when an {step_type} is directly invoked."


class BaseDirectInvocationContext:
class BaseDirectInvocationContext(ContextHasExecutionProperties):
@abstractmethod
def bind(
self,
op_def: OpDefinition,
Expand All @@ -74,6 +82,14 @@ def bind(
):
pass

@abstractmethod
def for_type(self, dagster_type: DagsterType) -> TypeCheckContext:
pass

@abstractmethod
def observe_output(self, output_name: str, mapping_key: Optional[str] = None) -> None:
pass


class DirectInvocationOpExecutionContext(OpExecutionContext, BaseDirectInvocationContext):
"""The ``context`` object available as the first argument to an op's compute function when
Expand Down Expand Up @@ -716,9 +732,18 @@ def unbind(self):

self._bound = False

@property
def op_execution_context(self) -> OpExecutionContext:
return self._op_execution_context
def for_type(self, dagster_type: DagsterType) -> TypeCheckContext:
self._check_bound(fn_name="for_type", fn_type="method")
resources = cast(NamedTuple, self.resources)
return TypeCheckContext(
self.run_id,
self.log,
ScopedResourcesBuilder(resources._asdict()),
dagster_type,
)

def observe_output(self, output_name: str, mapping_key: Optional[str] = None) -> None:
self._op_execution_context.observe_output(output_name=output_name, mapping_key=mapping_key)

@property
def run_properties(self) -> RunProperties:
Expand Down

0 comments on commit 348a8c7

Please sign in to comment.