Skip to content

Commit

Permalink
[check] call builder (#22260)
Browse files Browse the repository at this point in the history
Adds a utility function to check for programmatically building calls
based on type signatures.

## How I Tested These Changes

added tests
  • Loading branch information
alangenfeld authored Jun 6, 2024
1 parent 08ae999 commit 4af5c63
Show file tree
Hide file tree
Showing 2 changed files with 304 additions and 0 deletions.
219 changes: 219 additions & 0 deletions python_modules/dagster/dagster/_check/__init__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
import collections.abc
import inspect
import sys
from os import PathLike, fspath
from typing import (
AbstractSet,
Any,
Callable,
Dict,
ForwardRef,
Generator,
Iterable,
Iterator,
List,
Mapping,
NamedTuple,
NoReturn,
Optional,
Sequence,
Expand All @@ -19,9 +22,20 @@
TypeVar,
Union,
cast,
get_args,
get_origin,
overload,
)

from typing_extensions import Annotated

try:
    # this type only exists in python 3.10+
    from types import UnionType  # type: ignore
except ImportError:
    # On older interpreters alias the name to typing.Union so that membership
    # tests written as `origin in (UnionType, Union)` keep working unchanged.
    UnionType = Union


# Either a single class or a tuple of classes; used as the `of_type` / `ttype`
# argument type of the check helpers in this module.
TypeOrTupleOfTypes = Union[type, Tuple[type, ...]]
# Either of the two built-in real number types.
Numeric = Union[int, float]
# Generic placeholder used in the signatures of the param helpers.
T = TypeVar("T")
Expand Down Expand Up @@ -1204,6 +1218,18 @@ def opt_iterable_param(
return iterable_param(obj, param_name, of_type, additional_message)


def opt_nullable_iterable_param(
    obj: Optional[Iterable[T]],
    param_name: str,
    of_type: Optional[TypeOrTupleOfTypes] = None,
    additional_message: Optional[str] = None,
) -> Optional[Iterable[T]]:
    """Validate an optional iterable parameter, letting ``None`` pass through.

    Args:
        obj: The value to validate, or ``None``.
        param_name: Name of the parameter, forwarded to ``iterable_param``.
        of_type: Optional element type(s), forwarded to ``iterable_param``.
        additional_message: Optional extra message, forwarded to ``iterable_param``.

    Returns:
        ``None`` when ``obj`` is ``None``; otherwise whatever ``iterable_param``
        returns for the same arguments.
    """
    # None is preserved as None; every other value is handed to the strict check.
    return None if obj is None else iterable_param(obj, param_name, of_type, additional_message)


# ########################
# ##### SET
# ########################
Expand Down Expand Up @@ -1805,3 +1831,196 @@ def _check_two_dim_mapping_entries(
) # check level two

return obj


# ###################################################################################################
# ##### CALL BUILDER
# ###################################################################################################


class EvalContext(NamedTuple):
    """Utility class for managing references to global and local namespaces.

    These namespaces are passed to ForwardRef._evaluate to resolve the actual
    type from a string.
    """

    # names visible at global scope when resolving forward references
    global_ns: dict
    # names visible at local scope when resolving forward references
    local_ns: dict

    @staticmethod
    def capture_from_frame(depth: int) -> "EvalContext":
        """Build a context from a snapshot of a calling frame's namespaces.

        Args:
            depth: How many frames above the caller to look; ``0`` means the
                frame that called ``capture_from_frame``.

        Returns:
            A new EvalContext holding copies of that frame's globals and locals.
        """
        # +1 skips this function's own frame
        ctx_frame = sys._getframe(depth + 1)  # noqa # surprisingly not costly

        return EvalContext(
            # copy to not mess up frame data
            ctx_frame.f_globals.copy(),
            ctx_frame.f_locals.copy(),
        )

    def update_from_frame(self, depth: int):
        """Merge the globals and locals of a calling frame into this context in place.

        Args:
            depth: How many frames above the caller to look; ``0`` means the
                frame that called ``update_from_frame``.
        """
        # Update the global and local namespaces with symbols from the target frame
        ctx_frame = sys._getframe(depth + 1)  # noqa # surprisingly not costly
        self.global_ns.update(ctx_frame.f_globals)
        self.local_ns.update(ctx_frame.f_locals)

    def eval_forward_ref(self, ref: ForwardRef) -> Optional[Type]:
        """Resolve ``ref`` against the stored namespaces.

        Args:
            ref: The forward reference to evaluate.

        Returns:
            Whatever ``ForwardRef._evaluate`` produces for the reference.

        Raises:
            CheckError: If evaluation raises ``NameError`` (the name is not
                present in the stored namespaces).
        """
        try:
            if sys.version_info < (3, 9):
                # older signature: _evaluate(globalns, localns)
                return ref._evaluate(self.global_ns, self.local_ns)  # noqa
            if sys.version_info >= (3, 13):
                # type_params is passed explicitly to avoid the deprecation
                # warning emitted when it is omitted
                return ref._evaluate(  # type: ignore # noqa
                    self.global_ns,
                    self.local_ns,
                    type_params=(),
                    recursive_guard=frozenset(),
                )
            # recursive_guard must go by keyword: newer CPython releases (3.12.4+)
            # make it keyword-only, so a positional frozenset() would be bound to
            # a different parameter and the call would raise TypeError.
            return ref._evaluate(  # type: ignore # noqa
                self.global_ns,
                self.local_ns,
                recursive_guard=frozenset(),
            )
        except NameError as e:
            raise CheckError(f"Unable to resolve {ref}") from e


def _no_op(_): ...


def _coerce(
ttype: Optional[TypeOrTupleOfTypes],
eval_ctx: Optional[EvalContext],
) -> Optional[TypeOrTupleOfTypes]:
if ttype is Any:
return None
if isinstance(ttype, str):
if eval_ctx is None:
failed(
f"Can not generate check calls from string {ttype} (assumed ForwardRef) without EvalContext"
)
return eval_ctx.eval_forward_ref(ForwardRef(ttype))
if isinstance(ttype, ForwardRef):
if eval_ctx is None:
failed(f"Can not evaluate ForwardRef {ttype} without passing in EvalContext")
return eval_ctx.eval_forward_ref(ttype)
if get_origin(ttype) in (UnionType, Union):
optional_args = get_args(ttype)
tuple_types = _container_pair_args(optional_args, eval_ctx)
if None in tuple_types:
failed(f"Unable to turn Optional in to tuple of types for {optional_args} from {ttype}")
return tuple_types # type: ignore # static analysis cant follow above check

return ttype


def _container_pair_args(
args: Tuple[Type, ...], eval_ctx
) -> Tuple[Optional[TypeOrTupleOfTypes], Optional[TypeOrTupleOfTypes]]:
if len(args) == 2:
return _coerce(args[0], eval_ctx), _coerce(args[1], eval_ctx)

return None, None


def _container_single_arg(
args: Tuple[Type, ...], eval_ctx: Optional[EvalContext]
) -> Optional[TypeOrTupleOfTypes]:
if len(args) == 1:
return _coerce(args[0], eval_ctx)

return None


def build_check_call(
ttype: Type,
name: str,
# have this be passed in to avoid guessing which stack frame to use
eval_ctx: Optional[EvalContext] = None,
) -> Callable[[Any], Any]:
# performance notes:
# * lambdas measured to beat functools.partial
# * positional arg use measured to beat keyword arg use

origin = get_origin(ttype)
args = get_args(ttype)

# scalars
if origin is None:
if ttype is str:
return lambda o: str_param(o, name)
elif ttype is float:
return lambda o: float_param(o, name)
elif ttype is int:
return lambda o: int_param(o, name)
elif ttype is bool:
return lambda o: int_param(o, name)
elif ttype is Any:
return _no_op

# fallback to inst
inst_type = _coerce(ttype, eval_ctx)
if inst_type:
return lambda o: inst_param(o, name, inst_type)
else:
return _no_op
else:
if origin is Annotated and args:
return build_check_call(args[0], name, eval_ctx)

pair_left, pair_right = _container_pair_args(args, eval_ctx)
single = _container_single_arg(args, eval_ctx)

# containers
if origin is list:
return lambda o: list_param(o, name, single)
elif origin is dict:
return lambda o: dict_param(o, name, pair_left, pair_right)
elif origin is set:
return lambda o: set_param(o, param_name=name, of_type=single)
elif origin is collections.abc.Sequence:
return lambda o: sequence_param(o, param_name=name, of_type=single)
elif origin is collections.abc.Iterable:
return lambda o: iterable_param(o, param_name=name, of_type=single)
elif origin is collections.abc.Mapping:
return lambda o: mapping_param(o, name, pair_left, pair_right)

elif origin in (UnionType, Union):
# optional
if pair_right is type(None):
inner_origin = get_origin(pair_left)
# optional scalar
if inner_origin is None:
if pair_left is str:
return lambda o: opt_str_param(o, param_name=name)
elif pair_left is float:
return lambda o: opt_float_param(o, param_name=name)
elif pair_left is int:
return lambda o: opt_int_param(o, param_name=name)
elif pair_left is bool:
return lambda o: opt_bool_param(o, param_name=name)

# fallback to opt_inst
inst_type = _coerce(pair_left, eval_ctx)
if inst_type:
return lambda o: opt_inst_param(o, ttype=inst_type, param_name=name)
else:
return _no_op

# optional container
else:
inner_args = get_args(pair_left)
inner_pair_left, inner_pair_right = _container_pair_args(inner_args, eval_ctx)
inner_single = _container_single_arg(inner_args, eval_ctx)
if inner_origin is list:
return lambda o: opt_nullable_list_param(o, name, inner_single)
elif inner_origin is dict:
return lambda o: opt_nullable_dict_param(
o, name, inner_pair_left, inner_pair_right
)
elif inner_origin is set:
return lambda o: opt_nullable_set_param(o, name, inner_single)
elif inner_origin is collections.abc.Sequence:
return lambda o: opt_nullable_sequence_param(o, name, inner_single)
elif inner_origin is collections.abc.Iterable:
return lambda o: opt_nullable_iterable_param(o, name, inner_single)
elif inner_origin is collections.abc.Mapping:
return lambda o: opt_nullable_mapping_param(
o, name, inner_pair_left, inner_pair_right
)
# union
else:
tuple_types = _coerce(ttype, eval_ctx)
if tuple_types is not None:
return lambda o: inst_param(o, name, tuple_types)

failed(f"Unhandled {ttype}")
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,18 @@
import sys
from collections import defaultdict
from contextlib import contextmanager
from typing import Dict, Iterable, List, Mapping, Optional, Sequence, Set, Union

import dagster._check as check
import pytest
from dagster._annotations import PublicAttr
from dagster._check import (
CheckError,
ElementCheckError,
EvalContext,
NotImplementedCheckError,
ParameterCheckError,
build_check_call,
)


Expand Down Expand Up @@ -1508,3 +1512,84 @@ def test_opt_iterable():

with pytest.raises(CheckError, match="Member of iterable mismatches type"):
check.opt_iterable_param(["atr", None], "nonedoesntcount", of_type=str)


# ###################################################################################################
# ##### CHECK BUILDER
# ###################################################################################################


# Minimal fixture class; also the target of the "Foo" forward-ref cases below.
class Foo: ...


# Subclass of Foo.
class SubFoo(Foo): ...


# Fixture class unrelated to Foo, used as a mismatching value in the cases below.
class Bar: ...


# Each entry is (ttype, should_succeed, should_fail): the annotation to build a
# check for, a value the built check must accept, and a value it must reject.
# Consumed by the parametrized test_build_check_call.
BUILD_CASES = [
    (int, 4, "4"),
    (float, 4.2, "4.1"),
    (str, "hi", Foo()),
    (Bar, Bar(), Foo()),
    (Optional[Bar], Bar(), Foo()),
    (List[str], ["a", "b"], [1, 2]),
    (Sequence[str], ["a", "b"], [1, 2]),
    (Iterable[str], ["a", "b"], [1, 2]),
    (Set[str], {"a", "b"}, {1, 2}),
    (Dict[str, int], {"a": 1}, {1: "a"}),
    (Mapping[str, int], {"a": 1}, {1: "a"}),
    (Optional[int], None, "4"),
    (Optional[Bar], None, Foo()),
    (Optional[List[str]], ["a", "b"], [1, 2]),
    (Optional[Sequence[str]], ["a", "b"], [1, 2]),
    (Optional[Iterable[str]], ["a", "b"], [1, 2]),
    (Optional[Set[str]], {"a", "b"}, {1, 2}),
    (Optional[Dict[str, int]], {"a": 1}, {1: "a"}),
    (Optional[Mapping[str, int]], {"a": 1}, {1: "a"}),
    (PublicAttr[Optional[Mapping[str, int]]], {"a": 1}, {1: "a"}),
    (PublicAttr[Bar], Bar(), Foo()),
    (Union[bool, Foo], True, None),
    # fwd refs
    ("Foo", Foo(), Bar()),
    (Optional["Foo"], Foo(), Bar()),
    (PublicAttr[Optional["Foo"]], None, Bar()),
    (Mapping[str, Optional["Foo"]], {"foo": Foo()}, {"bar": Bar()}),
]


@pytest.mark.parametrize("ttype, should_succeed, should_fail", BUILD_CASES)
def test_build_check_call(ttype, should_succeed, should_fail) -> None:
    """The built check accepts ``should_succeed`` and raises CheckError on ``should_fail``."""
    validate = build_check_call(ttype, "test_param", EvalContext(globals(), locals()))

    # a matching value must pass without raising
    validate(should_succeed)

    # a mismatching value must be rejected
    with pytest.raises(CheckError):
        validate(should_fail)


def test_build_check_errors() -> None:
    """Building a check for an unresolvable forward reference raises CheckError."""
    ctx = EvalContext(globals(), locals())
    unresolved = List["NoExist"]  # type: ignore # noqa

    with pytest.raises(CheckError, match=r"Unable to resolve ForwardRef\('NoExist'\)"):
        build_check_call(unresolved, "bad", ctx)


def test_forward_ref_flow() -> None:
    # Exercises the flow where a forward reference names a class that is only
    # defined after the EvalContext was created. Statement order matters here.

    # original context captured at decl
    eval_ctx = EvalContext(globals(), locals())
    ttype = List["Late"]  # class not yet defined

    class Late: ...

    with pytest.raises(CheckError):
        # can not build call since ctx was captured before definition
        build_check_call(ttype, "ok", eval_ctx)

    # depth 0 targets the frame calling update_from_frame, i.e. this test function
    eval_ctx.update_from_frame(0)  # update from callsite frame
    # now it works
    call = build_check_call(ttype, "ok", eval_ctx)
    call([Late()])

0 comments on commit 4af5c63

Please sign in to comment.