diff --git a/python_modules/dagster/dagster/_core/decorator_utils.py b/python_modules/dagster/dagster/_core/decorator_utils.py index 87a7d3b309cb1..a3492632691a3 100644 --- a/python_modules/dagster/dagster/_core/decorator_utils.py +++ b/python_modules/dagster/dagster/_core/decorator_utils.py @@ -270,3 +270,9 @@ def is_resource_def(obj: Any) -> TypeGuard["ResourceDefinition"]: """ class_names = [cls.__name__ for cls in inspect.getmro(obj.__class__)] return "ResourceDefinition" in class_names + + +def is_context_provided(params: Sequence[Parameter]) -> bool: + if len(params) == 0: + return False + return params[0].name in get_valid_name_permutations("context") diff --git a/python_modules/dagster/dagster/_core/definitions/decorators/asset_decorator.py b/python_modules/dagster/dagster/_core/definitions/decorators/asset_decorator.py index 917f30d695667..16b24dada8ced 100644 --- a/python_modules/dagster/dagster/_core/definitions/decorators/asset_decorator.py +++ b/python_modules/dagster/dagster/_core/definitions/decorators/asset_decorator.py @@ -21,8 +21,12 @@ from dagster._annotations import deprecated_param, experimental_param from dagster._builtins import Nothing from dagster._config import UserConfigSchema +<<<<<<< HEAD from dagster._core.decorator_utils import get_function_params, get_valid_name_permutations from dagster._core.definitions.asset_dep import AssetDep, CoercibleToAssetDep +======= +from dagster._core.decorator_utils import get_function_params +>>>>>>> eed11d9bde (error on bad type annotation at def time) from dagster._core.definitions.auto_materialize_policy import AutoMaterializePolicy from dagster._core.definitions.config import ConfigMapping from dagster._core.definitions.freshness_policy import FreshnessPolicy @@ -326,6 +330,7 @@ def __call__(self, fn: Callable) -> AssetsDefinition: from dagster._core.execution.build_resources import wrap_resources_for_execution validate_resource_annotated_function(fn) + _validate_context_type_hint(fn) asset_name = self.name or fn.__name__ asset_ins = build_asset_ins(fn, self.ins or {}, {dep.asset_key for dep in self.deps}) @@ -832,11 +837,10 @@ def inner(fn: Callable[..., Any]) -> AssetsDefinition: def get_function_params_without_context_or_config_or_resources(fn: Callable) -> List[Parameter]: + from dagster._core.decorator_utils import is_context_provided + params = get_function_params(fn) - is_context_provided = len(params) > 0 and params[0].name in get_valid_name_permutations( - "context" - ) - input_params = params[1:] if is_context_provided else params + input_params = params[1:] if is_context_provided(params) else params resource_arg_names = {arg.name for arg in get_resource_args(fn)} @@ -1311,3 +1315,18 @@ def _get_partition_mappings_from_deps( ) return partition_mappings + + +def _validate_context_type_hint(fn): + from inspect import _empty as EmptyAnnotation + + from dagster._core.decorator_utils import get_function_params, is_context_provided + from dagster._core.execution.context.compute import AssetExecutionContext, OpExecutionContext + + params = get_function_params(fn) + if is_context_provided(params): + if not isinstance(params[0], (AssetExecutionContext, OpExecutionContext, EmptyAnnotation)): + raise DagsterInvalidDefinitionError( + f"Cannot annotate `context` parameter with type {params[0].annotation}. `context`" + " must be annotated with AssetExecutionContext, OpExecutionContext, or left blank." + ) diff --git a/python_modules/dagster/dagster/_core/definitions/decorators/op_decorator.py b/python_modules/dagster/dagster/_core/definitions/decorators/op_decorator.py index 49c216330d805..e745aad0f4c78 100644 --- a/python_modules/dagster/dagster/_core/definitions/decorators/op_decorator.py +++ b/python_modules/dagster/dagster/_core/definitions/decorators/op_decorator.py @@ -21,7 +21,7 @@ from dagster._core.decorator_utils import ( format_docstring_for_description, get_function_params, - get_valid_name_permutations, + is_context_provided, param_is_var_keyword, positional_arg_name_list, ) @@ -285,10 +285,8 @@ def has_context_arg(self) -> bool: return is_context_provided(get_function_params(self.decorated_fn)) def get_context_arg(self) -> Parameter: - for param in get_function_params(self.decorated_fn): - if param.name == "context": - return param - + if self.has_context_arg(): + return get_function_params(self.decorated_fn)[0] check.failed("Requested context arg on function that does not have one") @lru_cache(maxsize=1) @@ -344,12 +342,6 @@ def has_context_arg(self) -> bool: return False -def is_context_provided(params: Sequence[Parameter]) -> bool: - if len(params) == 0: - return False - return params[0].name in get_valid_name_permutations("context") - - def resolve_checked_op_fn_inputs( decorator_name: str, fn_name: str, diff --git a/python_modules/dagster/dagster/_core/definitions/op_definition.py b/python_modules/dagster/dagster/_core/definitions/op_definition.py index fb4c1c3dca5da..30de8ca6b1574 100644 --- a/python_modules/dagster/dagster/_core/definitions/op_definition.py +++ b/python_modules/dagster/dagster/_core/definitions/op_definition.py @@ -30,7 +30,11 @@ OutputManagerRequirement, ResourceRequirement, ) -from dagster._core.errors import DagsterInvalidInvocationError, DagsterInvariantViolationError +from dagster._core.errors import ( + DagsterInvalidDefinitionError, + DagsterInvalidInvocationError, + DagsterInvariantViolationError, +) from dagster._core.types.dagster_type import DagsterType, DagsterTypeKind from dagster._utils import IHasInternalInit from dagster._utils.warnings import normalize_renamed_param @@ -143,9 +147,11 @@ def __init__( exclude_nothing=True, ) self._compute_fn = compute_fn + _validate_context_type_hint(self._compute_fn.decorated_fn) else: resolved_input_defs = input_defs self._compute_fn = check.callable_param(compute_fn, "compute_fn") + _validate_context_type_hint(self._compute_fn) code_version = normalize_renamed_param( code_version, @@ -504,3 +510,18 @@ def _resolve_output_defs_from_outs( ) return output_defs + + +def _validate_context_type_hint(fn): + from inspect import _empty as EmptyAnnotation + + from dagster._core.decorator_utils import get_function_params, is_context_provided + from dagster._core.execution.context.compute import OpExecutionContext + + params = get_function_params(fn) + if is_context_provided(params): + if not isinstance(params[0], (OpExecutionContext, EmptyAnnotation)): + raise DagsterInvalidDefinitionError( + f"Cannot annotate `context` parameter with type {params[0].annotation}. `context`" + " must be annotated with OpExecutionContext or left blank." + ) diff --git a/python_modules/dagster/dagster/_core/definitions/source_asset.py b/python_modules/dagster/dagster/_core/definitions/source_asset.py index 33a06525883fc..53e950a8332b5 100644 --- a/python_modules/dagster/dagster/_core/definitions/source_asset.py +++ b/python_modules/dagster/dagster/_core/definitions/source_asset.py @@ -181,10 +181,8 @@ def is_observable(self) -> bool: return self.node_def is not None def _get_op_def_compute_fn(self, observe_fn: SourceAssetObserveFunction): - from dagster._core.definitions.decorators.op_decorator import ( - DecoratedOpFunction, - is_context_provided, - ) + from dagster._core.decorator_utils import is_context_provided + from dagster._core.definitions.decorators.op_decorator import DecoratedOpFunction from dagster._core.execution.context.compute import ( OpExecutionContext, ) diff --git a/python_modules/dagster/dagster/_core/execution/context/compute.py b/python_modules/dagster/dagster/_core/execution/context/compute.py index 2918ff0d47e8a..22f34e0a5db99 100644 --- a/python_modules/dagster/dagster/_core/execution/context/compute.py +++ b/python_modules/dagster/dagster/_core/execution/context/compute.py @@ -39,7 +39,6 @@ from dagster._core.definitions.step_launcher import StepLauncher from dagster._core.definitions.time_window_partitions import TimeWindow from dagster._core.errors import ( - DagsterInvalidDefinitionError, DagsterInvalidPropertyError, DagsterInvariantViolationError, ) @@ -1524,10 +1523,4 @@ def build_execution_context( return AssetExecutionContext(op_context) if is_sda_step else op_context if context_annotation is AssetExecutionContext: return AssetExecutionContext(op_context) - if context_annotation is OpExecutionContext: - return op_context - - raise DagsterInvalidDefinitionError( - f"Cannot annotate `context` parameter with type {context_annotation}. `context` must be" - " annotated with AssetExecutionContext, OpExecutionContext, or left blank." - ) + return op_context diff --git a/python_modules/dagster/dagster_tests/core_tests/execution_tests/test_context.py b/python_modules/dagster/dagster_tests/core_tests/execution_tests/test_context.py index 9222e04e8fd71..7b9ad96d8649f 100644 --- a/python_modules/dagster/dagster_tests/core_tests/execution_tests/test_context.py +++ b/python_modules/dagster/dagster_tests/core_tests/execution_tests/test_context.py @@ -236,16 +236,20 @@ def op_annotation(context: OpExecutionContext, *args): def test_error_on_invalid_context_annotation(): - @op - def the_op(context: int): - pass + with pytest.raises( + DagsterInvalidDefinitionError, + match="must be annotated with OpExecutionContext or left blank", + ): - @job - def the_job(): - the_op() + @op + def the_op(context: int): + pass with pytest.raises( DagsterInvalidDefinitionError, match="must be annotated with AssetExecutionContext, OpExecutionContext, or left blank", ): - assert the_job.execute_in_process() + + @asset + def the_asset(context: int): + pass