From 07923e8c7d431b7033b6ec51724e1f26687594cb Mon Sep 17 00:00:00 2001 From: jamiedemaria Date: Mon, 29 Jan 2024 13:15:04 -0500 Subject: [PATCH] Only have one kind of context for direct invocation (#17554) --- .../_core/definitions/op_invocation.py | 123 ++-- .../_core/execution/context/invocation.py | 646 +++++++++--------- .../dagster/dagster/_core/pipes/context.py | 6 +- .../test_partitioned_assets.py | 2 +- .../test_direct_invocation.py | 60 +- .../core_tests/test_op_invocation.py | 258 ++++++- 6 files changed, 676 insertions(+), 419 deletions(-) diff --git a/python_modules/dagster/dagster/_core/definitions/op_invocation.py b/python_modules/dagster/dagster/_core/definitions/op_invocation.py index 984dd8003626b..3dc8e5efc1a67 100644 --- a/python_modules/dagster/dagster/_core/definitions/op_invocation.py +++ b/python_modules/dagster/dagster/_core/definitions/op_invocation.py @@ -32,7 +32,7 @@ from .result import MaterializeResult if TYPE_CHECKING: - from ..execution.context.invocation import BoundOpExecutionContext + from ..execution.context.invocation import DirectOpExecutionContext from .assets import AssetsDefinition from .composition import PendingNodeInvocation from .decorators.op_decorator import DecoratedOpFunction @@ -109,7 +109,7 @@ def direct_invocation_result( ) -> Any: from dagster._config.pythonic_config import Config from dagster._core.execution.context.invocation import ( - UnboundOpExecutionContext, + DirectOpExecutionContext, build_op_context, ) @@ -149,12 +149,12 @@ def direct_invocation_result( " no context was provided when invoking." ) if len(args) > 0: - if args[0] is not None and not isinstance(args[0], UnboundOpExecutionContext): + if args[0] is not None and not isinstance(args[0], DirectOpExecutionContext): raise DagsterInvalidInvocationError( f"Decorated function '{compute_fn.name}' has context argument, " "but no context was provided when invoking." ) - context = cast(UnboundOpExecutionContext, args[0]) + context = cast(DirectOpExecutionContext, args[0]) # update args to omit context args = args[1:] else: # context argument is provided under kwargs @@ -165,14 +165,14 @@ def direct_invocation_result( f"'{context_param_name}', but no value for '{context_param_name}' was " f"found when invoking. Provided kwargs: {kwargs}" ) - context = cast(UnboundOpExecutionContext, kwargs[context_param_name]) + context = cast(DirectOpExecutionContext, kwargs[context_param_name]) # update kwargs to remove context kwargs = { kwarg: val for kwarg, val in kwargs.items() if not kwarg == context_param_name } # allow passing context, even if the function doesn't have an arg for it - elif len(args) > 0 and isinstance(args[0], UnboundOpExecutionContext): - context = cast(UnboundOpExecutionContext, args[0]) + elif len(args) > 0 and isinstance(args[0], DirectOpExecutionContext): + context = cast(DirectOpExecutionContext, args[0]) args = args[1:] resource_arg_mapping = {arg.name: arg.name for arg in compute_fn.get_resource_args()} @@ -206,24 +206,31 @@ def direct_invocation_result( ), ) - input_dict = _resolve_inputs(op_def, input_args, input_kwargs, bound_context) + try: + # if the compute function fails, we want to ensure we unbind the context. This + # try-except handles "vanilla" asset and op invocation (generators and async handled in + # _type_check_output_wrapper) - result = invoke_compute_fn( - fn=compute_fn.decorated_fn, - context=bound_context, - kwargs=input_dict, - context_arg_provided=compute_fn.has_context_arg(), - config_arg_cls=( - compute_fn.get_config_arg().annotation if compute_fn.has_config_arg() else None - ), - resource_args=resource_arg_mapping, - ) + input_dict = _resolve_inputs(op_def, input_args, input_kwargs, bound_context) - return _type_check_output_wrapper(op_def, result, bound_context) + result = invoke_compute_fn( + fn=compute_fn.decorated_fn, + context=bound_context, + kwargs=input_dict, + context_arg_provided=compute_fn.has_context_arg(), + config_arg_cls=( + compute_fn.get_config_arg().annotation if compute_fn.has_config_arg() else None + ), + resource_args=resource_arg_mapping, + ) + return _type_check_output_wrapper(op_def, result, bound_context) + except Exception: + bound_context.unbind() + raise def _resolve_inputs( - op_def: "OpDefinition", args, kwargs, context: "BoundOpExecutionContext" + op_def: "OpDefinition", args, kwargs, context: "DirectOpExecutionContext" ) -> Mapping[str, Any]: from dagster._core.execution.plan.execute_step import do_type_check @@ -263,7 +270,7 @@ def _resolve_inputs( node_label = op_def.node_type_str raise DagsterInvalidInvocationError( - f"Too many input arguments were provided for {node_label} '{context.alias}'." + f"Too many input arguments were provided for {node_label} '{context.per_invocation_properties.alias}'." f" {suggestion}" ) @@ -306,7 +313,7 @@ def _resolve_inputs( input_dict[k] = v # Type check inputs - op_label = context.describe_op() + op_label = context.per_invocation_properties.step_description for input_name, val in input_dict.items(): input_def = input_defs_by_name[input_name] @@ -326,31 +333,42 @@ def _resolve_inputs( return input_dict -def _key_for_result(result: MaterializeResult, context: "BoundOpExecutionContext") -> AssetKey: +def _key_for_result(result: MaterializeResult, context: "DirectOpExecutionContext") -> AssetKey: + if not context.per_invocation_properties.assets_def: + raise DagsterInvariantViolationError( + f"Op {context.per_invocation_properties.alias} does not have an assets definition." + ) if result.asset_key: return result.asset_key - if len(context.assets_def.keys) == 1: - return next(iter(context.assets_def.keys)) + if ( + context.per_invocation_properties.assets_def + and len(context.per_invocation_properties.assets_def.keys) == 1 + ): + return next(iter(context.per_invocation_properties.assets_def.keys)) raise DagsterInvariantViolationError( "MaterializeResult did not include asset_key and it can not be inferred. Specify which" - f" asset_key, options are: {context.assets_def.keys}" + f" asset_key, options are: {context.per_invocation_properties.assets_def.keys}" ) def _output_name_for_result_obj( event: MaterializeResult, - context: "BoundOpExecutionContext", + context: "DirectOpExecutionContext", ): + if not context.per_invocation_properties.assets_def: + raise DagsterInvariantViolationError( + f"Op {context.per_invocation_properties.alias} does not have an assets definition." + ) asset_key = _key_for_result(event, context) - return context.assets_def.get_output_name_for_asset_key(asset_key) + return context.per_invocation_properties.assets_def.get_output_name_for_asset_key(asset_key) def _handle_gen_event( event: T, op_def: "OpDefinition", - context: "BoundOpExecutionContext", + context: "DirectOpExecutionContext", output_defs: Mapping[str, OutputDefinition], outputs_seen: Set[str], ) -> T: @@ -376,7 +394,7 @@ def _handle_gen_event( output_def, DynamicOutputDefinition ): raise DagsterInvariantViolationError( - f"Invocation of {op_def.node_type_str} '{context.alias}' yielded" + f"Invocation of {op_def.node_type_str} '{context.per_invocation_properties.alias}' yielded" f" an output '{output_def.name}' multiple times." ) outputs_seen.add(output_def.name) @@ -384,7 +402,7 @@ def _handle_gen_event( def _type_check_output_wrapper( - op_def: "OpDefinition", result: Any, context: "BoundOpExecutionContext" + op_def: "OpDefinition", result: Any, context: "DirectOpExecutionContext" ) -> Any: """Type checks and returns the result of a op. @@ -399,8 +417,14 @@ def _type_check_output_wrapper( async def to_gen(async_gen): outputs_seen = set() - async for event in async_gen: - yield _handle_gen_event(event, op_def, context, output_defs, outputs_seen) + try: + # if the compute function fails, we want to ensure we unbind the context. For + # async generators, the errors will only be surfaced here + async for event in async_gen: + yield _handle_gen_event(event, op_def, context, output_defs, outputs_seen) + except Exception: + context.unbind() + raise for output_def in op_def.output_defs: if ( @@ -413,9 +437,10 @@ async def to_gen(async_gen): yield Output(output_name=output_def.name, value=None) else: raise DagsterInvariantViolationError( - f"Invocation of {op_def.node_type_str} '{context.alias}' did not" + f"Invocation of {op_def.node_type_str} '{context.per_invocation_properties.alias}' did not" f" return an output for non-optional output '{output_def.name}'" ) + context.unbind() return to_gen(result) @@ -423,7 +448,13 @@ async def to_gen(async_gen): elif inspect.iscoroutine(result): async def type_check_coroutine(coro): - out = await coro + try: + # if the compute function fails, we want to ensure we unbind the context. For + # async, the errors will only be surfaced here + out = await coro + except Exception: + context.unbind() + raise return _type_check_function_output(op_def, out, context) return type_check_coroutine(result) @@ -433,8 +464,14 @@ async def type_check_coroutine(coro): def type_check_gen(gen): outputs_seen = set() - for event in gen: - yield _handle_gen_event(event, op_def, context, output_defs, outputs_seen) + try: + # if the compute function fails, we want to ensure we unbind the context. For + # generators, the errors will only be surfaced here + for event in gen: + yield _handle_gen_event(event, op_def, context, output_defs, outputs_seen) + except Exception: + context.unbind() + raise for output_def in op_def.output_defs: if ( @@ -447,9 +484,10 @@ def type_check_gen(gen): yield Output(output_name=output_def.name, value=None) else: raise DagsterInvariantViolationError( - f'Invocation of {op_def.node_type_str} "{context.alias}" did not' + f'Invocation of {op_def.node_type_str} "{context.per_invocation_properties.alias}" did not' f' return an output for non-optional output "{output_def.name}"' ) + context.unbind() return type_check_gen(result) @@ -458,7 +496,7 @@ def type_check_gen(gen): def _type_check_function_output( - op_def: "OpDefinition", result: T, context: "BoundOpExecutionContext" + op_def: "OpDefinition", result: T, context: "DirectOpExecutionContext" ) -> T: from ..execution.plan.compute_generator import validate_and_coerce_op_result_to_iterator @@ -470,25 +508,26 @@ def _type_check_function_output( # ensure result objects are contextually valid _output_name_for_result_obj(event, context) + context.unbind() return result def _type_check_output( output_def: "OutputDefinition", output: Union[Output, DynamicOutput], - context: "BoundOpExecutionContext", + context: "DirectOpExecutionContext", ) -> None: """Validates and performs core type check on a provided output. Args: output_def (OutputDefinition): The output definition to validate against. output (Any): The output to validate. - context (BoundOpExecutionContext): Context containing resources to be used for type + context (DirectOpExecutionContext): Context containing resources to be used for type check. """ from ..execution.plan.execute_step import do_type_check - op_label = context.describe_op() + op_label = context.per_invocation_properties.step_description dagster_type = output_def.dagster_type type_check = do_type_check(context.for_type(dagster_type), dagster_type, output.value) if not type_check.success: diff --git a/python_modules/dagster/dagster/_core/execution/context/invocation.py b/python_modules/dagster/dagster/_core/execution/context/invocation.py index c53d673e3200e..02038f15127ed 100644 --- a/python_modules/dagster/dagster/_core/execution/context/invocation.py +++ b/python_modules/dagster/dagster/_core/execution/context/invocation.py @@ -54,6 +54,7 @@ from dagster._core.types.dagster_type import DagsterType from dagster._utils.forked_pdb import ForkedPdb from dagster._utils.merger import merge_dicts +from dagster._utils.warnings import deprecation_warning from .compute import OpExecutionContext from .system import StepExecutionContext, TypeCheckContext @@ -61,12 +62,73 @@ def _property_msg(prop_name: str, method_name: str) -> str: return ( - f"The {prop_name} {method_name} is not set on the context when a solid is directly invoked." + f"The {prop_name} {method_name} is not set on the context when an op is directly invoked." ) -class UnboundOpExecutionContext(OpExecutionContext): - """The ``context`` object available as the first argument to a solid's compute function when +class PerInvocationProperties( + NamedTuple( + "_PerInvocationProperties", + [ + ("op_def", OpDefinition), + ("tags", Mapping[Any, Any]), + ("hook_defs", Optional[AbstractSet[HookDefinition]]), + ("alias", str), + ("assets_def", Optional[AssetsDefinition]), + ("resources", Resources), + ("op_config", Any), + ("step_description", str), + ], + ) +): + """Maintains properties that are only available once the context has been bound to a particular + asset or op invocation. By splitting these out into a separate object, it is easier to ensure that + all properties bound to an invocation are cleared once the execution is complete. + """ + + def __new__( + cls, + op_def: OpDefinition, + tags: Mapping[Any, Any], + hook_defs: Optional[AbstractSet[HookDefinition]], + alias: str, + assets_def: Optional[AssetsDefinition], + resources: Resources, + op_config: Any, + step_description: str, + ): + return super(PerInvocationProperties, cls).__new__( + cls, + op_def=check.inst_param(op_def, "op_def", OpDefinition), + tags=check.dict_param(tags, "tags"), + hook_defs=check.opt_set_param(hook_defs, "hook_defs", HookDefinition), + alias=check.str_param(alias, "alias"), + assets_def=check.opt_inst_param(assets_def, "assets_def", AssetsDefinition), + resources=check.inst_param(resources, "resources", Resources), + op_config=op_config, + step_description=step_description, + ) + + +class DirectExecutionProperties: + """Maintains information about the execution that can only be updated during execution (when + the context is bound), but can be read after execution is complete. It needs to be cleared before + the context is used for another execution. + + This is not implemented as a NamedTuple because the various attributes will be mutated during + execution. + """ + + def __init__(self): + self.user_events: List[UserEvent] = [] + self.seen_outputs: Dict[str, Union[str, Set[str]]] = {} + self.output_metadata: Dict[str, Dict[str, Union[Any, Mapping[str, Any]]]] = {} + self.requires_typed_event_stream: bool = False + self.typed_event_stream_error_message: Optional[str] = None + + +class DirectOpExecutionContext(OpExecutionContext): + """The ``context`` object available as the first argument to an op's compute function when being invoked directly. Can also be used as a context manager. """ @@ -79,7 +141,6 @@ def __init__( partition_key: Optional[str], partition_key_range: Optional[PartitionKeyRange], mapping_key: Optional[str], - assets_def: Optional[AssetsDefinition], ): from dagster._core.execution.api import ephemeral_instance_if_missing from dagster._core.execution.context_creation_job import initialize_console_manager @@ -114,10 +175,28 @@ def __init__( ) self._partition_key = partition_key self._partition_key_range = partition_key_range - self._user_events: List[UserEvent] = [] - self._output_metadata: Dict[str, Any] = {} - self._assets_def = check.opt_inst_param(assets_def, "assets_def", AssetsDefinition) + # Maintains the properties on the context that are bound to a particular invocation + # of an op + # @op + # def my_op(context): + # # context._per_invocation_properties.alias is "my_op" + # ... + # ctx = build_op_context() # ctx._per_invocation_properties is None + # my_op(ctx) + # ctx._per_invocation_properties is None # ctx is unbound at the end of invocation + self._per_invocation_properties = None + + # Maintains the properties on the context that are modified during invocation + # @op + # def my_op(context): + # # context._execution_properties can be modified with output metadata etc. + # ... + # ctx = build_op_context() # ctx._execution_properties is empty + # my_op(ctx) + # ctx._execution_properties.output_metadata # information is retained after invocation + # my_op(ctx) # ctx._execution_properties is cleared at the beginning of the next invocation + self._execution_properties = DirectExecutionProperties() def __enter__(self): self._cm_scope_entered = True @@ -129,9 +208,125 @@ def __exit__(self, *exc): def __del__(self): self._exit_stack.close() + def _check_bound_to_invocation(self, fn_name: str, fn_type: str) -> PerInvocationProperties: + if self._per_invocation_properties is None: + raise DagsterInvalidPropertyError(_property_msg(fn_name, fn_type)) + # return self._per_invocation_properties so that the calling function can access properties + # of self._per_invocation_properties without causing pyright errors + return self._per_invocation_properties + + def bind( + self, + op_def: OpDefinition, + pending_invocation: Optional[PendingNodeInvocation[OpDefinition]], + assets_def: Optional[AssetsDefinition], + config_from_args: Optional[Mapping[str, Any]], + resources_from_args: Optional[Mapping[str, Any]], + ) -> "DirectOpExecutionContext": + from dagster._core.definitions.resource_invocation import resolve_bound_config + + if self._per_invocation_properties is not None: + raise DagsterInvalidInvocationError( + f"This context is currently being used to execute {self.alias}. The context cannot be used to execute another op until {self.alias} has finished executing." + ) + + # reset execution_properties + self._execution_properties = DirectExecutionProperties() + + # update the bound context with properties relevant to the invocation of the op + invocation_tags = ( + pending_invocation.tags + if isinstance(pending_invocation, PendingNodeInvocation) + else None + ) + tags = merge_dicts(op_def.tags, invocation_tags) if invocation_tags else op_def.tags + + hook_defs = ( + pending_invocation.hook_defs + if isinstance(pending_invocation, PendingNodeInvocation) + else None + ) + invocation_alias = ( + pending_invocation.given_alias + if isinstance(pending_invocation, PendingNodeInvocation) + else None + ) + alias = invocation_alias if invocation_alias else op_def.name + + if resources_from_args: + if self._resource_defs: + raise DagsterInvalidInvocationError( + "Cannot provide resources in both context and kwargs" + ) + resource_defs = wrap_resources_for_execution(resources_from_args) + # add new resources context to the stack to be cleared on exit + resources = self._exit_stack.enter_context( + build_resources(resource_defs, self.instance) + ) + elif assets_def and assets_def.resource_defs: + for key in sorted(list(assets_def.resource_defs.keys())): + if key in self._resource_defs: + raise DagsterInvalidInvocationError( + f"Error when invoking {assets_def!s} resource '{key}' " + "provided on both the definition and invocation context. Please " + "provide on only one or the other." + ) + resource_defs = wrap_resources_for_execution( + {**self._resource_defs, **assets_def.resource_defs} + ) + # add new resources context to the stack to be cleared on exit + resources = self._exit_stack.enter_context( + build_resources(resource_defs, self.instance, self._resources_config) + ) + else: + # this runs the check in resources() to ensure we are in a context manager if necessary + resources = self.resources + + resource_defs = self._resource_defs + + _validate_resource_requirements(resource_defs, op_def) + + if self._op_config and config_from_args: + raise DagsterInvalidInvocationError("Cannot provide config in both context and kwargs") + op_config = resolve_bound_config(config_from_args or self._op_config, op_def) + + step_description = f'op "{op_def.name}"' + + self._per_invocation_properties = PerInvocationProperties( + op_def=op_def, + tags=tags, + hook_defs=hook_defs, + alias=alias, + assets_def=assets_def, + resources=resources, + op_config=op_config, + step_description=step_description, + ) + + return self + + def unbind(self): + self._per_invocation_properties = None + + @property + def is_bound(self) -> bool: + return self._per_invocation_properties is not None + + @property + def execution_properties(self) -> DirectExecutionProperties: + return self._execution_properties + + @property + def per_invocation_properties(self) -> PerInvocationProperties: + return self._check_bound_to_invocation( + fn_name="_per_invocation_properties", fn_type="property" + ) + @property def op_config(self) -> Any: - return self._op_config + if self._per_invocation_properties is None: + return self._op_config + return self._per_invocation_properties.op_config @property def resource_keys(self) -> AbstractSet[str]: @@ -139,6 +334,8 @@ def resource_keys(self) -> AbstractSet[str]: @property def resources(self) -> Resources: + if self._per_invocation_properties is not None: + return self._per_invocation_properties.resources if self._resources_contain_cm and not self._cm_scope_entered: raise DagsterInvariantViolationError( "At least one provided resource is a generator, but attempting to access " @@ -149,7 +346,7 @@ def resources(self) -> Resources: @property def dagster_run(self) -> DagsterRun: - raise DagsterInvalidPropertyError(_property_msg("pipeline_run", "property")) + raise DagsterInvalidPropertyError(_property_msg("dagster_run", "property")) @property def instance(self) -> DagsterInstance: @@ -183,7 +380,19 @@ def run_id(self) -> str: @property def run_config(self) -> dict: - raise DagsterInvalidPropertyError(_property_msg("run_config", "property")) + per_invocation_properties = self._check_bound_to_invocation( + fn_name="run_config", fn_type="property" + ) + + run_config: Dict[str, object] = {} + if self._op_config and per_invocation_properties.op_def: + run_config["ops"] = { + per_invocation_properties.op_def.name: { + "config": per_invocation_properties.op_config + } + } + run_config["resources"] = self._resources_config + return run_config @property def job_def(self) -> JobDefinition: @@ -200,10 +409,10 @@ def log(self) -> DagsterLogManager: @property def node_handle(self) -> NodeHandle: - raise DagsterInvalidPropertyError(_property_msg("solid_handle", "property")) + raise DagsterInvalidPropertyError(_property_msg("node_handle", "property")) @property - def op(self) -> JobDefinition: + def op(self) -> Node: raise DagsterInvalidPropertyError(_property_msg("op", "property")) @property @@ -212,11 +421,29 @@ def solid(self) -> Node: @property def op_def(self) -> OpDefinition: - raise DagsterInvalidPropertyError(_property_msg("op_def", "property")) + per_invocation_properties = self._check_bound_to_invocation( + fn_name="op_def", fn_type="property" + ) + return cast(OpDefinition, per_invocation_properties.op_def) + + @property + def has_assets_def(self) -> bool: + per_invocation_properties = self._check_bound_to_invocation( + fn_name="has_assets_def", fn_type="property" + ) + return per_invocation_properties.assets_def is not None @property def assets_def(self) -> AssetsDefinition: - raise DagsterInvalidPropertyError(_property_msg("assets_def", "property")) + per_invocation_properties = self._check_bound_to_invocation( + fn_name="assets_def", fn_type="property" + ) + + if per_invocation_properties.assets_def is None: + raise DagsterInvalidPropertyError( + f"Op {self.op_def.name} does not have an assets definition." + ) + return per_invocation_properties.assets_def @property def has_partition_key(self) -> bool: @@ -246,89 +473,26 @@ def asset_partition_key_for_output(self, output_name: str = "result") -> str: return self.partition_key def has_tag(self, key: str) -> bool: - raise DagsterInvalidPropertyError(_property_msg("has_tag", "method")) - - def get_tag(self, key: str) -> str: - raise DagsterInvalidPropertyError(_property_msg("get_tag", "method")) - - def get_step_execution_context(self) -> StepExecutionContext: - raise DagsterInvalidPropertyError(_property_msg("get_step_execution_context", "methods")) - - def bind( - self, - op_def: OpDefinition, - pending_invocation: Optional[PendingNodeInvocation[OpDefinition]], - assets_def: Optional[AssetsDefinition], - config_from_args: Optional[Mapping[str, Any]], - resources_from_args: Optional[Mapping[str, Any]], - ) -> "BoundOpExecutionContext": - from dagster._core.definitions.resource_invocation import resolve_bound_config - - if resources_from_args: - if self._resource_defs: - raise DagsterInvalidInvocationError( - "Cannot provide resources in both context and kwargs" - ) - resource_defs = wrap_resources_for_execution(resources_from_args) - # add new resources context to the stack to be cleared on exit - resources = self._exit_stack.enter_context( - build_resources(resource_defs, self.instance) - ) - elif assets_def and assets_def.resource_defs: - for key in sorted(list(assets_def.resource_defs.keys())): - if key in self._resource_defs: - raise DagsterInvalidInvocationError( - f"Error when invoking {assets_def!s} resource '{key}' " - "provided on both the definition and invocation context. Please " - "provide on only one or the other." - ) - resource_defs = wrap_resources_for_execution( - {**self._resource_defs, **assets_def.resource_defs} - ) - # add new resources context to the stack to be cleared on exit - resources = self._exit_stack.enter_context( - build_resources(resource_defs, self.instance, self._resources_config) - ) - else: - resources = self.resources - resource_defs = self._resource_defs - - _validate_resource_requirements(resource_defs, op_def) + per_invocation_properties = self._check_bound_to_invocation( + fn_name="has_tag", fn_type="method" + ) + return key in per_invocation_properties.tags - if self.op_config and config_from_args: - raise DagsterInvalidInvocationError("Cannot provide config in both context and kwargs") - op_config = resolve_bound_config(config_from_args or self.op_config, op_def) + def get_tag(self, key: str) -> Optional[str]: + per_invocation_properties = self._check_bound_to_invocation( + fn_name="get_tag", fn_type="method" + ) + return per_invocation_properties.tags.get(key) - return BoundOpExecutionContext( - op_def=op_def, - op_config=op_config, - resources=resources, - resources_config=self._resources_config, - instance=self.instance, - log_manager=self.log, - pdb=self.pdb, - tags=( - pending_invocation.tags - if isinstance(pending_invocation, PendingNodeInvocation) - else None - ), - hook_defs=( - pending_invocation.hook_defs - if isinstance(pending_invocation, PendingNodeInvocation) - else None - ), - alias=( - pending_invocation.given_alias - if isinstance(pending_invocation, PendingNodeInvocation) - else None - ), - user_events=self._user_events, - output_metadata=self._output_metadata, - mapping_key=self._mapping_key, - partition_key=self._partition_key, - partition_key_range=self._partition_key_range, - assets_def=assets_def, + @property + def alias(self) -> str: + per_invocation_properties = self._check_bound_to_invocation( + fn_name="alias", fn_type="property" ) + return cast(str, per_invocation_properties.alias) + + def get_step_execution_context(self) -> StepExecutionContext: + raise DagsterInvalidPropertyError(_property_msg("get_step_execution_context", "method")) def get_events(self) -> Sequence[UserEvent]: """Retrieve the list of user-generated events that were logged via the context. @@ -351,7 +515,7 @@ def test_my_op(): expectation_results = [event for event in all_user_events if isinstance(event, ExpectationResult)] ... """ - return self._user_events + return self._execution_properties.user_events def get_output_metadata( self, output_name: str, mapping_key: Optional[str] = None @@ -367,7 +531,7 @@ def get_output_metadata( Returns: Optional[Mapping[str, Any]]: The metadata values present for the output_name/mapping_key combination, if present. """ - metadata = self._output_metadata.get(output_name) + metadata = self._execution_properties.output_metadata.get(output_name) if mapping_key and metadata: return metadata.get(mapping_key) return metadata @@ -375,186 +539,8 @@ def get_output_metadata( def get_mapping_key(self) -> Optional[str]: return self._mapping_key - -def _validate_resource_requirements( - resource_defs: Mapping[str, ResourceDefinition], op_def: OpDefinition -) -> None: - """Validate correctness of resources against required resource keys.""" - if cast(DecoratedOpFunction, op_def.compute_fn).has_context_arg(): - for requirement in op_def.get_resource_requirements(): - if not requirement.is_io_manager_requirement: - ensure_requirements_satisfied(resource_defs, [requirement]) - - -class BoundOpExecutionContext(OpExecutionContext): - """The op execution context that is passed to the compute function during invocation. - - This context is bound to a specific op definition, for which the resources and config have - been validated. - """ - - _op_def: OpDefinition - _op_config: Any - _resources: "Resources" - _resources_config: Mapping[str, Any] - _instance: DagsterInstance - _log_manager: DagsterLogManager - _pdb: Optional[ForkedPdb] - _tags: Mapping[str, str] - _hook_defs: Optional[AbstractSet[HookDefinition]] - _alias: str - _user_events: List[UserEvent] - _seen_outputs: Dict[str, Union[str, Set[str]]] - _output_metadata: Dict[str, Any] - _mapping_key: Optional[str] - _partition_key: Optional[str] - _partition_key_range: Optional[PartitionKeyRange] - _assets_def: Optional[AssetsDefinition] - - def __init__( - self, - op_def: OpDefinition, - op_config: Any, - resources: "Resources", - resources_config: Mapping[str, Any], - instance: DagsterInstance, - log_manager: DagsterLogManager, - pdb: Optional[ForkedPdb], - tags: Optional[Mapping[str, str]], - hook_defs: Optional[AbstractSet[HookDefinition]], - alias: Optional[str], - user_events: List[UserEvent], - output_metadata: Dict[str, Any], - mapping_key: Optional[str], - partition_key: Optional[str], - partition_key_range: Optional[PartitionKeyRange], - assets_def: Optional[AssetsDefinition], - ): - self._op_def = op_def - self._op_config = op_config - self._resources = resources - self._instance = instance - self._log = log_manager - self._pdb = pdb - self._tags = merge_dicts(self._op_def.tags, tags) if tags else self._op_def.tags - self._hook_defs = hook_defs - self._alias = alias if alias else self._op_def.name - self._resources_config = resources_config - self._user_events = user_events - self._seen_outputs = {} - self._output_metadata = output_metadata - self._mapping_key = mapping_key - self._partition_key = partition_key - self._partition_key_range = partition_key_range - self._assets_def = assets_def - self._requires_typed_event_stream = False - self._typed_event_stream_error_message = None - - @property - def op_config(self) -> Any: - return self._op_config - - @property - def resources(self) -> Resources: - return self._resources - - @property - def dagster_run(self) -> DagsterRun: - raise DagsterInvalidPropertyError(_property_msg("pipeline_run", "property")) - - @property - def instance(self) -> DagsterInstance: - return self._instance - - @property - def pdb(self) -> ForkedPdb: - """dagster.utils.forked_pdb.ForkedPdb: Gives access to pdb debugging from within the solid. - - Example: - .. code-block:: python - - @solid - def debug_solid(context): - context.pdb.set_trace() - - """ - if self._pdb is None: - self._pdb = ForkedPdb() - - return self._pdb - - @property - def step_launcher(self) -> Optional[StepLauncher]: - raise DagsterInvalidPropertyError(_property_msg("step_launcher", "property")) - - @property - def run_id(self) -> str: - """str: Hard-coded value to indicate that we are directly invoking solid.""" - return "EPHEMERAL" - - @property - def run_config(self) -> Mapping[str, object]: - run_config: Dict[str, object] = {} - if self._op_config: - run_config["ops"] = {self._op_def.name: {"config": self._op_config}} - run_config["resources"] = self._resources_config - return run_config - - @property - def job_def(self) -> JobDefinition: - raise DagsterInvalidPropertyError(_property_msg("job_def", "property")) - - @property - def job_name(self) -> str: - raise DagsterInvalidPropertyError(_property_msg("job_name", "property")) - - @property - def log(self) -> DagsterLogManager: - """DagsterLogManager: A console manager constructed for this context.""" - return self._log - - @property - def node_handle(self) -> NodeHandle: - raise DagsterInvalidPropertyError(_property_msg("node_handle", "property")) - - @property - def op(self) -> Node: - raise DagsterInvalidPropertyError(_property_msg("op", "property")) - - @property - def op_def(self) -> OpDefinition: - return self._op_def - - @property - def has_assets_def(self) -> bool: - return self._assets_def is not None - - @property - def assets_def(self) -> AssetsDefinition: - if self._assets_def is None: - raise DagsterInvalidPropertyError( - f"Op {self.op_def.name} does not have an assets definition." - ) - return self._assets_def - - @property - def has_partition_key(self) -> bool: - return self._partition_key is not None - - def has_tag(self, key: str) -> bool: - return key in self._tags - - def get_tag(self, key: str) -> Optional[str]: - return self._tags.get(key) - - @property - def alias(self) -> str: - return self._alias - - def get_step_execution_context(self) -> StepExecutionContext: - raise DagsterInvalidPropertyError(_property_msg("get_step_execution_context", "methods")) - def for_type(self, dagster_type: DagsterType) -> TypeCheckContext: + self._check_bound_to_invocation(fn_name="for_type", fn_type="method") resources = cast(NamedTuple, self.resources) return TypeCheckContext( self.run_id, @@ -563,62 +549,42 @@ def for_type(self, dagster_type: DagsterType) -> TypeCheckContext: dagster_type, ) - def get_mapping_key(self) -> Optional[str]: - return self._mapping_key - def describe_op(self) -> str: - if isinstance(self.op_def, OpDefinition): - return f'op "{self.op_def.name}"' - - return f'solid "{self.op_def.name}"' + per_invocation_properties = self._check_bound_to_invocation( + fn_name="describe_op", fn_type="method" + ) + return per_invocation_properties.step_description def log_event(self, event: UserEvent) -> None: + self._check_bound_to_invocation(fn_name="log_event", fn_type="method") check.inst_param( event, "event", (AssetMaterialization, AssetObservation, ExpectationResult), ) - self._user_events.append(event) + self._execution_properties.user_events.append(event) def observe_output(self, output_name: str, mapping_key: Optional[str] = None) -> None: + self._check_bound_to_invocation(fn_name="observe_output", fn_type="method") if mapping_key: - if output_name not in self._seen_outputs: - self._seen_outputs[output_name] = set() - cast(Set[str], self._seen_outputs[output_name]).add(mapping_key) + if output_name not in self._execution_properties.seen_outputs: + self._execution_properties.seen_outputs[output_name] = set() + cast(Set[str], self._execution_properties.seen_outputs[output_name]).add(mapping_key) else: - self._seen_outputs[output_name] = "seen" + self._execution_properties.seen_outputs[output_name] = "seen" def has_seen_output(self, output_name: str, mapping_key: Optional[str] = None) -> bool: if mapping_key: return ( - output_name in self._seen_outputs and mapping_key in self._seen_outputs[output_name] + output_name in self._execution_properties.seen_outputs + and mapping_key in self._execution_properties.seen_outputs[output_name] ) - return output_name in self._seen_outputs - - @property - def partition_key(self) -> str: - if self._partition_key is not None: - return self._partition_key - check.failed("Tried to access partition_key for a non-partitioned asset") - - @property - def partition_key_range(self) -> PartitionKeyRange: - """The range of partition keys for the current run. - - If run is for a single partition key, return a `PartitionKeyRange` with the same start and - end. Raises an error if the current run is not a partitioned run. - """ - if self._partition_key_range: - return self._partition_key_range - elif self._partition_key: - return PartitionKeyRange(self._partition_key, self._partition_key) - else: - check.failed("Tried to access partition_key range for a non-partitioned run") - - def asset_partition_key_for_output(self, output_name: str = "result") -> str: - return self.partition_key + return output_name in self._execution_properties.seen_outputs def asset_partitions_time_window_for_output(self, output_name: str = "result") -> TimeWindow: + self._check_bound_to_invocation( + fn_name="asset_partitions_time_window_for_output", fn_type="method" + ) partitions_def = self.assets_def.partitions_def if partitions_def is None: check.failed("Tried to access partition_key for a non-partitioned asset") @@ -667,6 +633,7 @@ def add_metadata_two_outputs(context) -> Tuple[str, int]: return ("dog", 5) """ + self._check_bound_to_invocation(fn_name="add_output_metadata", fn_type="method") metadata = check.mapping_param(metadata, "metadata", key_type=str) output_name = check.opt_str_param(output_name, "output_name") mapping_key = check.opt_str_param(mapping_key, "mapping_key") @@ -702,33 +669,51 @@ def add_metadata_two_outputs(context) -> Tuple[str, int]: ) output_name = output_def.name - if output_name in self._output_metadata: - if not mapping_key or mapping_key in self._output_metadata[output_name]: + if output_name in self._execution_properties.output_metadata: + if ( + not mapping_key + or mapping_key in self._execution_properties.output_metadata[output_name] + ): raise DagsterInvariantViolationError( f"In {self.op_def.node_type_str} '{self.op_def.name}', attempted to log" f" metadata for output '{output_name}' more than once." ) if mapping_key: - if output_name not in self._output_metadata: - self._output_metadata[output_name] = {} - self._output_metadata[output_name][mapping_key] = metadata + if output_name not in self._execution_properties.output_metadata: + self._execution_properties.output_metadata[output_name] = {} + self._execution_properties.output_metadata[output_name][mapping_key] = metadata else: - self._output_metadata[output_name] = metadata + self._execution_properties.output_metadata[output_name] = metadata - # In this mode no conversion is done on returned values and missing but expected outputs are not + # In bound mode no conversion is done on returned values and missing but expected outputs are not # allowed. @property def requires_typed_event_stream(self) -> bool: - return self._requires_typed_event_stream + self._check_bound_to_invocation(fn_name="requires_typed_event_stream", fn_type="property") + return self._execution_properties.requires_typed_event_stream @property def typed_event_stream_error_message(self) -> Optional[str]: - return self._typed_event_stream_error_message + self._check_bound_to_invocation( + fn_name="typed_event_stream_error_message", fn_type="property" + ) + return self._execution_properties.typed_event_stream_error_message def set_requires_typed_event_stream(self, *, error_message: Optional[str]) -> None: - self._requires_typed_event_stream = True - self._typed_event_stream_error_message = error_message + self._check_bound_to_invocation(fn_name="set_requires_typed_event_stream", fn_type="method") + self._execution_properties.requires_typed_event_stream = True + self._execution_properties.typed_event_stream_error_message = error_message + + +def _validate_resource_requirements( + resource_defs: Mapping[str, ResourceDefinition], op_def: OpDefinition +) -> None: + """Validate correctness of resources against required resource keys.""" + if cast(DecoratedOpFunction, op_def.compute_fn).has_context_arg(): + for requirement in op_def.get_resource_requirements(): + if not requirement.is_io_manager_requirement: + ensure_requirements_satisfied(resource_defs, [requirement]) def build_op_context( @@ -741,7 +726,7 @@ def build_op_context( partition_key_range: Optional[PartitionKeyRange] = None, mapping_key: Optional[str] = None, _assets_def: Optional[AssetsDefinition] = None, -) -> UnboundOpExecutionContext: +) -> DirectOpExecutionContext: """Builds op execution context from provided parameters. ``build_op_context`` can be used as either a function or context manager. If there is a @@ -778,8 +763,18 @@ def build_op_context( "legacy version, ``config``. Please provide one or the other." ) + if _assets_def: + deprecation_warning( + subject="build_op_context", + additional_warn_text=( + "Parameter '_assets_def' was passed to build_op_context. This parameter was intended for internal use only, and has been deprecated " + ), + breaking_version="1.8.0", + stacklevel=1, + ) + op_config = op_config if op_config else config - return UnboundOpExecutionContext( + return DirectOpExecutionContext( resources_dict=check.opt_mapping_param(resources, "resources", key_type=str), resources_config=check.opt_mapping_param( resources_config, "resources_config", key_type=str @@ -791,7 +786,6 @@ def build_op_context( partition_key_range, "partition_key_range", PartitionKeyRange ), mapping_key=check.opt_str_param(mapping_key, "mapping_key"), - assets_def=check.opt_inst_param(_assets_def, "_assets_def", AssetsDefinition), ) diff --git a/python_modules/dagster/dagster/_core/pipes/context.py b/python_modules/dagster/dagster/_core/pipes/context.py index 96004f036e027..5b61647bbf43b 100644 --- a/python_modules/dagster/dagster/_core/pipes/context.py +++ b/python_modules/dagster/dagster/_core/pipes/context.py @@ -39,7 +39,7 @@ from dagster._core.errors import DagsterPipesExecutionError from dagster._core.events import EngineEventData from dagster._core.execution.context.compute import OpExecutionContext -from dagster._core.execution.context.invocation import BoundOpExecutionContext +from dagster._core.execution.context.invocation import DirectOpExecutionContext from dagster._utils.error import ( ExceptionInfo, SerializableErrorInfo, @@ -406,8 +406,8 @@ def build_external_execution_context_data( _convert_time_window(partition_time_window) if partition_time_window else None ), run_id=context.run_id, - job_name=None if isinstance(context, BoundOpExecutionContext) else context.job_name, - retry_number=0 if isinstance(context, BoundOpExecutionContext) else context.retry_number, + job_name=None if isinstance(context, DirectOpExecutionContext) else context.job_name, + retry_number=0 if isinstance(context, DirectOpExecutionContext) else context.retry_number, extras=extras or {}, ) diff --git a/python_modules/dagster/dagster_tests/asset_defs_tests/test_partitioned_assets.py b/python_modules/dagster/dagster_tests/asset_defs_tests/test_partitioned_assets.py index 4fa06e28b85e2..ac81b56b54a45 100644 --- a/python_modules/dagster/dagster_tests/asset_defs_tests/test_partitioned_assets.py +++ b/python_modules/dagster/dagster_tests/asset_defs_tests/test_partitioned_assets.py @@ -227,7 +227,7 @@ def partitioned_asset(context): @asset def non_partitioned_asset(context): with pytest.raises( - CheckError, match="Tried to access partition_key for a non-partitioned asset" + CheckError, match="Tried to access partition_key for a non-partitioned run" ): context.asset_partition_key_for_output() diff --git a/python_modules/dagster/dagster_tests/core_tests/resource_tests/pythonic_resources/test_direct_invocation.py b/python_modules/dagster/dagster_tests/core_tests/resource_tests/pythonic_resources/test_direct_invocation.py index 268677e9e51b9..3a259def54306 100644 --- a/python_modules/dagster/dagster_tests/core_tests/resource_tests/pythonic_resources/test_direct_invocation.py +++ b/python_modules/dagster/dagster_tests/core_tests/resource_tests/pythonic_resources/test_direct_invocation.py @@ -1,5 +1,3 @@ -import asyncio - import pytest from dagster import ( AssetExecutionContext, @@ -8,7 +6,7 @@ asset, op, ) -from dagster._core.errors import DagsterInvalidInvocationError, DagsterInvariantViolationError +from dagster._core.errors import DagsterInvalidInvocationError from dagster._core.execution.context.invocation import build_asset_context, build_op_context @@ -431,50 +429,20 @@ def an_asset( executed.clear() -def test_direct_invocation_output_metadata(): - @asset - def my_asset(context): - context.add_output_metadata({"foo": "bar"}) - - @asset - def my_other_asset(context): - context.add_output_metadata({"baz": "qux"}) - - ctx = build_asset_context() - - my_asset(ctx) - assert ctx.get_output_metadata("result") == {"foo": "bar"} +def test_direct_invocation_resource_context_manager(): + from dagster import resource - with pytest.raises( - DagsterInvariantViolationError, - match="attempted to log metadata for output 'result' more than once", - ): - my_other_asset(ctx) + class YieldedResource: + def get_value(self): + return 1 + @resource + def yielding_resource(context): + yield YieldedResource() -def test_async_assets_with_shared_context(): - @asset - async def async_asset_one(context): - assert context.asset_key.to_user_string() == "async_asset_one" - await asyncio.sleep(0.01) - return "one" - - @asset - async def async_asset_two(context): - assert context.asset_key.to_user_string() == "async_asset_two" - await asyncio.sleep(0.01) - return "two" - - # test that we can run two ops/assets with the same context at the same time without - # overriding op/asset specific attributes - ctx = build_asset_context() - - async def main(): - return await asyncio.gather( - async_asset_one(ctx), - async_asset_two(ctx), - ) + @asset(required_resource_keys={"yielded_resource"}) + def my_asset(context): + assert context.resources.yielded_resource.get_value() == 1 - result = asyncio.run(main()) - assert result[0] == "one" - assert result[1] == "two" + with build_op_context(resources={"yielded_resource": yielding_resource}) as ctx: + my_asset(ctx) diff --git a/python_modules/dagster/dagster_tests/core_tests/test_op_invocation.py b/python_modules/dagster/dagster_tests/core_tests/test_op_invocation.py index d575e54ca5358..b41669e98b265 100644 --- a/python_modules/dagster/dagster_tests/core_tests/test_op_invocation.py +++ b/python_modules/dagster/dagster_tests/core_tests/test_op_invocation.py @@ -44,7 +44,10 @@ DagsterTypeCheckDidNotPass, ) from dagster._core.execution.context.compute import AssetExecutionContext, OpExecutionContext -from dagster._core.execution.context.invocation import build_asset_context +from dagster._core.execution.context.invocation import ( + DirectOpExecutionContext, + build_asset_context, +) from dagster._utils.test import wrap_op_in_graph_and_execute @@ -1333,3 +1336,256 @@ def foo(context: AssetExecutionContext): partition_key_range=PartitionKeyRange("2023-01-01", "2023-01-02"), ) assert foo(context) == {"2023-01-01": True, "2023-01-02": True} + + +def test_direct_invocation_output_metadata(): + @asset + def my_asset(context): + context.add_output_metadata({"foo": "bar"}) + + @asset + def my_other_asset(context): + context.add_output_metadata({"baz": "qux"}) + + ctx = build_asset_context() + + my_asset(ctx) + assert ctx.get_output_metadata("result") == {"foo": "bar"} + + # context is unbound when used in another invocation. This allows the metadata to be + # added in my_other_asset + my_other_asset(ctx) + + +def test_async_assets_with_shared_context(): + @asset + async def async_asset_one(context): + assert context.asset_key.to_user_string() == "async_asset_one" + await asyncio.sleep(0.01) + return "one" + + @asset + async def async_asset_two(context): + assert context.asset_key.to_user_string() == "async_asset_two" + await asyncio.sleep(0.01) + return "two" + + # test that we can run two ops/assets with the same context at the same time without + # overriding op/asset specific attributes + ctx = build_asset_context() + + async def main(): + return await asyncio.gather( + async_asset_one(ctx), + async_asset_two(ctx), + ) + + with pytest.raises( + DagsterInvalidInvocationError, + match=r"This context is currently being used to execute .* The context" + r" cannot be used to execute another op until .* has finished executing", + ): + asyncio.run(main()) + + +def assert_context_unbound(context: DirectOpExecutionContext): + # to assert that the context is correctly unbound after op invocation + assert not context.is_bound + + +def assert_context_bound(context: DirectOpExecutionContext): + # to assert that the context is correctly bound during op invocation + assert context.is_bound + + +def assert_execution_properties_cleared(context: DirectOpExecutionContext): + # to assert that the invocation properties are reset at the beginning of op invocation + assert len(context.execution_properties.output_metadata.keys()) == 0 + + +def assert_execution_properties_exist(context: DirectOpExecutionContext): + # to assert that the invocation properties remain accessible after op invocation + assert len(context.execution_properties.output_metadata.keys()) > 0 + + +def test_context_bound_state_non_generator(): + @asset + def my_asset(context): + assert_context_bound(context) + assert_execution_properties_cleared(context) + context.add_output_metadata({"foo": "bar"}) + + ctx = build_op_context() + assert_context_unbound(ctx) + + my_asset(ctx) + assert_context_unbound(ctx) + assert_execution_properties_exist(ctx) + + my_asset(ctx) + assert_context_unbound(ctx) + assert_execution_properties_exist(ctx) + + +def test_context_bound_state_generator(): + @op(out={"first": Out(), "second": Out()}) + def generator(context): + assert_context_bound(context) + assert_execution_properties_cleared(context) + context.add_output_metadata({"foo": "bar"}, output_name="first") + yield Output("one", output_name="first") + yield Output("two", output_name="second") + + ctx = build_op_context() + + result = list(generator(ctx)) + assert result[0].value == "one" + assert result[1].value == "two" + assert_context_unbound(ctx) + assert_execution_properties_exist(ctx) + + result = list(generator(ctx)) + assert result[0].value == "one" + assert result[1].value == "two" + assert_context_unbound(ctx) + assert_execution_properties_exist(ctx) + + +def test_context_bound_state_async(): + @asset + async def async_asset(context): + assert_context_bound(context) + assert_execution_properties_cleared(context) + assert context.asset_key.to_user_string() == "async_asset" + context.add_output_metadata({"foo": "bar"}) + await asyncio.sleep(0.01) + return "one" + + ctx = build_asset_context() + + result = asyncio.run(async_asset(ctx)) + assert result == "one" + assert_context_unbound(ctx) + assert_execution_properties_exist(ctx) + + result = asyncio.run(async_asset(ctx)) + assert result == "one" + assert_context_unbound(ctx) + assert_execution_properties_exist(ctx) + + +def test_context_bound_state_async_generator(): + @op(out={"first": Out(), "second": Out()}) + async def async_generator(context): + assert_context_bound(context) + assert_execution_properties_cleared(context) + context.add_output_metadata({"foo": "bar"}, output_name="first") + yield Output("one", output_name="first") + await asyncio.sleep(0.01) + yield Output("two", output_name="second") + + ctx = build_op_context() + + async def get_results(): + res = [] + async for output in async_generator(ctx): + res.append(output) + return res + + result = asyncio.run(get_results()) + assert result[0].value == "one" + assert result[1].value == "two" + assert_context_unbound(ctx) + assert_execution_properties_exist(ctx) + + result = asyncio.run(get_results()) + assert result[0].value == "one" + assert result[1].value == "two" + assert_context_unbound(ctx) + assert_execution_properties_exist(ctx) + + +def test_bound_state_with_error_assets(): + @asset + def throws_error(context): + assert context.asset_key.to_user_string() == "throws_error" + raise Failure("something bad happened!") + + ctx = build_asset_context() + + with pytest.raises(Failure): + throws_error(ctx) + + assert_context_unbound(ctx) + + @asset + def no_error(context): + assert context.alias == "no_error" + + no_error(ctx) + + +def test_context_bound_state_with_error_ops(): + @op(out={"first": Out(), "second": Out()}) + def throws_error(context): + assert_context_bound(ctx) + raise Failure("something bad happened!") + + ctx = build_op_context() + + with pytest.raises(Failure): + throws_error(ctx) + + assert_context_unbound(ctx) + + +def test_context_bound_state_with_error_generator(): + @op(out={"first": Out(), "second": Out()}) + def generator(context): + assert_context_bound(ctx) + yield Output("one", output_name="first") + raise Failure("something bad happened!") + + ctx = build_op_context() + + with pytest.raises(Failure): + list(generator(ctx)) + + assert_context_unbound(ctx) + + +def test_context_bound_state_with_error_async(): + @asset + async def async_asset(context): + assert_context_bound(ctx) + await asyncio.sleep(0.01) + raise Failure("something bad happened!") + + ctx = build_asset_context() + + with pytest.raises(Failure): + asyncio.run(async_asset(ctx)) + + assert_context_unbound(ctx) + + +def test_context_bound_state_with_error_async_generator(): + @op(out={"first": Out(), "second": Out()}) + async def async_generator(context): + assert_context_bound(ctx) + yield Output("one", output_name="first") + await asyncio.sleep(0.01) + raise Failure("something bad happened!") + + ctx = build_op_context() + + async def get_results(): + res = [] + async for output in async_generator(ctx): + res.append(output) + return res + + with pytest.raises(Failure): + asyncio.run(get_results()) + + assert_context_unbound(ctx)