Skip to content

Commit

Permalink
error on bad type annotation at def time
Browse files Browse the repository at this point in the history
  • Loading branch information
jamiedemaria committed Sep 18, 2023
1 parent 45fe573 commit 625e2b0
Show file tree
Hide file tree
Showing 7 changed files with 68 additions and 35 deletions.
6 changes: 6 additions & 0 deletions python_modules/dagster/dagster/_core/decorator_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,3 +270,9 @@ def is_resource_def(obj: Any) -> TypeGuard["ResourceDefinition"]:
"""
class_names = [cls.__name__ for cls in inspect.getmro(obj.__class__)]
return "ResourceDefinition" in class_names


def is_context_provided(params: Sequence[Parameter]) -> bool:
if len(params) == 0:
return False
return params[0].name in get_valid_name_permutations("context")
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,12 @@
from dagster._annotations import deprecated_param, experimental_param
from dagster._builtins import Nothing
from dagster._config import UserConfigSchema
<<<<<<< HEAD
from dagster._core.decorator_utils import get_function_params, get_valid_name_permutations
from dagster._core.definitions.asset_dep import AssetDep, CoercibleToAssetDep
=======
from dagster._core.decorator_utils import get_function_params
>>>>>>> eed11d9bde (error on bad type annotation at def time)
from dagster._core.definitions.auto_materialize_policy import AutoMaterializePolicy
from dagster._core.definitions.config import ConfigMapping
from dagster._core.definitions.freshness_policy import FreshnessPolicy
Expand Down Expand Up @@ -326,6 +330,7 @@ def __call__(self, fn: Callable) -> AssetsDefinition:
from dagster._core.execution.build_resources import wrap_resources_for_execution

validate_resource_annotated_function(fn)
_validate_context_type_hint(fn)
asset_name = self.name or fn.__name__

asset_ins = build_asset_ins(fn, self.ins or {}, {dep.asset_key for dep in self.deps})
Expand Down Expand Up @@ -832,11 +837,10 @@ def inner(fn: Callable[..., Any]) -> AssetsDefinition:


def get_function_params_without_context_or_config_or_resources(fn: Callable) -> List[Parameter]:
from dagster._core.decorator_utils import is_context_provided

params = get_function_params(fn)
is_context_provided = len(params) > 0 and params[0].name in get_valid_name_permutations(
"context"
)
input_params = params[1:] if is_context_provided else params
input_params = params[1:] if is_context_provided(params) else params

resource_arg_names = {arg.name for arg in get_resource_args(fn)}

Expand Down Expand Up @@ -1311,3 +1315,18 @@ def _get_partition_mappings_from_deps(
)

return partition_mappings


def _validate_context_type_hint(fn):
from inspect import _empty as EmptyAnnotation

from dagster._core.decorator_utils import get_function_params, is_context_provided
from dagster._core.execution.context.compute import AssetExecutionContext, OpExecutionContext

params = get_function_params(fn)
if is_context_provided(params):
if not isinstance(params[0], (AssetExecutionContext, OpExecutionContext, EmptyAnnotation)):
raise DagsterInvalidDefinitionError(
f"Cannot annotate `context` parameter with type {params[0].annotation}. `context`"
" must be annotated with AssetExecutionContext, OpExecutionContext, or left blank."
)
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from dagster._core.decorator_utils import (
format_docstring_for_description,
get_function_params,
get_valid_name_permutations,
is_context_provided,
param_is_var_keyword,
positional_arg_name_list,
)
Expand Down Expand Up @@ -285,10 +285,8 @@ def has_context_arg(self) -> bool:
return is_context_provided(get_function_params(self.decorated_fn))

def get_context_arg(self) -> Parameter:
for param in get_function_params(self.decorated_fn):
if param.name == "context":
return param

if self.has_context_arg():
return get_function_params(self.decorated_fn)[0]
check.failed("Requested context arg on function that does not have one")

@lru_cache(maxsize=1)
Expand Down Expand Up @@ -344,12 +342,6 @@ def has_context_arg(self) -> bool:
return False


def is_context_provided(params: Sequence[Parameter]) -> bool:
if len(params) == 0:
return False
return params[0].name in get_valid_name_permutations("context")


def resolve_checked_op_fn_inputs(
decorator_name: str,
fn_name: str,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,11 @@
OutputManagerRequirement,
ResourceRequirement,
)
from dagster._core.errors import DagsterInvalidInvocationError, DagsterInvariantViolationError
from dagster._core.errors import (
DagsterInvalidDefinitionError,
DagsterInvalidInvocationError,
DagsterInvariantViolationError,
)
from dagster._core.types.dagster_type import DagsterType, DagsterTypeKind
from dagster._utils import IHasInternalInit
from dagster._utils.warnings import normalize_renamed_param
Expand Down Expand Up @@ -143,9 +147,11 @@ def __init__(
exclude_nothing=True,
)
self._compute_fn = compute_fn
_validate_context_type_hint(self._compute_fn.decorated_fn)
else:
resolved_input_defs = input_defs
self._compute_fn = check.callable_param(compute_fn, "compute_fn")
_validate_context_type_hint(self._compute_fn)

code_version = normalize_renamed_param(
code_version,
Expand Down Expand Up @@ -504,3 +510,18 @@ def _resolve_output_defs_from_outs(
)

return output_defs


def _validate_context_type_hint(fn):
from inspect import _empty as EmptyAnnotation

from dagster._core.decorator_utils import get_function_params, is_context_provided
from dagster._core.execution.context.compute import OpExecutionContext

params = get_function_params(fn)
if is_context_provided(params):
if not isinstance(params[0], (OpExecutionContext, EmptyAnnotation)):
raise DagsterInvalidDefinitionError(
f"Cannot annotate `context` parameter with type {params[0].annotation}. `context`"
" must be annotated with OpExecutionContext or left blank."
)
Original file line number Diff line number Diff line change
Expand Up @@ -181,10 +181,8 @@ def is_observable(self) -> bool:
return self.node_def is not None

def _get_op_def_compute_fn(self, observe_fn: SourceAssetObserveFunction):
from dagster._core.definitions.decorators.op_decorator import (
DecoratedOpFunction,
is_context_provided,
)
from dagster._core.decorator_utils import is_context_provided
from dagster._core.definitions.decorators.op_decorator import DecoratedOpFunction
from dagster._core.execution.context.compute import (
OpExecutionContext,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
from dagster._core.definitions.step_launcher import StepLauncher
from dagster._core.definitions.time_window_partitions import TimeWindow
from dagster._core.errors import (
DagsterInvalidDefinitionError,
DagsterInvalidPropertyError,
DagsterInvariantViolationError,
)
Expand Down Expand Up @@ -1524,10 +1523,4 @@ def build_execution_context(
return AssetExecutionContext(op_context) if is_sda_step else op_context
if context_annotation is AssetExecutionContext:
return AssetExecutionContext(op_context)
if context_annotation is OpExecutionContext:
return op_context

raise DagsterInvalidDefinitionError(
f"Cannot annotate `context` parameter with type {context_annotation}. `context` must be"
" annotated with AssetExecutionContext, OpExecutionContext, or left blank."
)
return op_context
Original file line number Diff line number Diff line change
Expand Up @@ -236,16 +236,20 @@ def op_annotation(context: OpExecutionContext, *args):


def test_error_on_invalid_context_annotation():
@op
def the_op(context: int):
pass
with pytest.raises(
DagsterInvalidDefinitionError,
match="must be annotated with OpExecutionContext or left blank",
):

@job
def the_job():
the_op()
@op
def the_op(context: int):
pass

with pytest.raises(
DagsterInvalidDefinitionError,
match="must be annotated with AssetExecutionContext, OpExecutionContext, or left blank",
):
assert the_job.execute_in_process()

@asset
def the_asset(context: int):
pass

0 comments on commit 625e2b0

Please sign in to comment.