From 4af5c6303687469dd9221a23b3b2de58c8fd2271 Mon Sep 17 00:00:00 2001 From: Alex Langenfeld Date: Thu, 6 Jun 2024 14:43:11 -0500 Subject: [PATCH] [check] call builder (#22260) Adds a utility function to check for programmatically building calls based on type signatures. ## How I Tested These Changes added tests --- .../dagster/dagster/_check/__init__.py | 219 ++++++++++++++++++ .../general_tests/check_tests/test_check.py | 85 +++++++ 2 files changed, 304 insertions(+) diff --git a/python_modules/dagster/dagster/_check/__init__.py b/python_modules/dagster/dagster/_check/__init__.py index 152944c0b1e97..d1e40b8876099 100644 --- a/python_modules/dagster/dagster/_check/__init__.py +++ b/python_modules/dagster/dagster/_check/__init__.py @@ -1,16 +1,19 @@ import collections.abc import inspect +import sys from os import PathLike, fspath from typing import ( AbstractSet, Any, Callable, Dict, + ForwardRef, Generator, Iterable, Iterator, List, Mapping, + NamedTuple, NoReturn, Optional, Sequence, @@ -19,9 +22,20 @@ TypeVar, Union, cast, + get_args, + get_origin, overload, ) +from typing_extensions import Annotated + +try: + # this type only exists in python 3.10+ + from types import UnionType # type: ignore +except ImportError: + UnionType = Union + + TypeOrTupleOfTypes = Union[type, Tuple[type, ...]] Numeric = Union[int, float] T = TypeVar("T") @@ -1204,6 +1218,18 @@ def opt_iterable_param( return iterable_param(obj, param_name, of_type, additional_message) +def opt_nullable_iterable_param( + obj: Optional[Iterable[T]], + param_name: str, + of_type: Optional[TypeOrTupleOfTypes] = None, + additional_message: Optional[str] = None, +) -> Optional[Iterable[T]]: + if obj is None: + return None + + return iterable_param(obj, param_name, of_type, additional_message) + + # ######################## # ##### SET # ######################## @@ -1805,3 +1831,196 @@ def _check_two_dim_mapping_entries( ) # check level two return obj + + +# 
################################################################################################### +# ##### CALL BUILDER +# ################################################################################################### + + +class EvalContext(NamedTuple): + """Utility class for managing references to global and local namespaces. + + These namespaces are passed to ForwardRef._evaluate to resolve the actual + type from a string. + """ + + global_ns: dict + local_ns: dict + + @staticmethod + def capture_from_frame(depth: int) -> "EvalContext": + ctx_frame = sys._getframe(depth + 1) # noqa # surprisingly not costly + + return EvalContext( + # copy to not mess up frame data + ctx_frame.f_globals.copy(), + ctx_frame.f_locals.copy(), + ) + + def update_from_frame(self, depth: int): + # Update the global and local namespaces with symbols from the target frame + ctx_frame = sys._getframe(depth + 1) # noqa # surprisingly not costly + self.global_ns.update(ctx_frame.f_globals) + self.local_ns.update(ctx_frame.f_locals) + + def eval_forward_ref(self, ref: ForwardRef) -> Optional[Type]: + try: + if sys.version_info <= (3, 9): + return ref._evaluate(self.global_ns, self.local_ns) # noqa + else: + return ref._evaluate(self.global_ns, self.local_ns, frozenset()) # noqa + except NameError as e: + raise CheckError(f"Unable to resolve {ref}") from e + + +def _no_op(_): ... 
+ + +def _coerce( + ttype: Optional[TypeOrTupleOfTypes], + eval_ctx: Optional[EvalContext], +) -> Optional[TypeOrTupleOfTypes]: + if ttype is Any: + return None + if isinstance(ttype, str): + if eval_ctx is None: + failed( + f"Can not generate check calls from string {ttype} (assumed ForwardRef) without EvalContext" + ) + return eval_ctx.eval_forward_ref(ForwardRef(ttype)) + if isinstance(ttype, ForwardRef): + if eval_ctx is None: + failed(f"Can not evaluate ForwardRef {ttype} without passing in EvalContext") + return eval_ctx.eval_forward_ref(ttype) + if get_origin(ttype) in (UnionType, Union): + optional_args = get_args(ttype) + tuple_types = _container_pair_args(optional_args, eval_ctx) + if None in tuple_types: + failed(f"Unable to turn Optional in to tuple of types for {optional_args} from {ttype}") + return tuple_types # type: ignore # static analysis cant follow above check + + return ttype + + +def _container_pair_args( + args: Tuple[Type, ...], eval_ctx: Optional[EvalContext] +) -> Tuple[Optional[TypeOrTupleOfTypes], Optional[TypeOrTupleOfTypes]]: + if len(args) == 2: + return _coerce(args[0], eval_ctx), _coerce(args[1], eval_ctx) + + return None, None + + +def _container_single_arg( + args: Tuple[Type, ...], eval_ctx: Optional[EvalContext] +) -> Optional[TypeOrTupleOfTypes]: + if len(args) == 1: + return _coerce(args[0], eval_ctx) + + return None + + +def build_check_call( + ttype: Type, + name: str, + # have this be passed in to avoid guessing which stack frame to use + eval_ctx: Optional[EvalContext] = None, +) -> Callable[[Any], Any]: + # performance notes: + # * lambdas measured to beat functools.partial + # * positional arg use measured to beat keyword arg use + + origin = get_origin(ttype) + args = get_args(ttype) + + # scalars + if origin is None: + if ttype is str: + return lambda o: str_param(o, name) + elif ttype is float: + return lambda o: float_param(o, name) + elif ttype is int: + return lambda o: int_param(o, name) + elif ttype is bool: + return lambda o: 
bool_param(o, name) + elif ttype is Any: + return _no_op + + # fallback to inst + inst_type = _coerce(ttype, eval_ctx) + if inst_type: + return lambda o: inst_param(o, name, inst_type) + else: + return _no_op + else: + if origin is Annotated and args: + return build_check_call(args[0], name, eval_ctx) + + pair_left, pair_right = _container_pair_args(args, eval_ctx) + single = _container_single_arg(args, eval_ctx) + + # containers + if origin is list: + return lambda o: list_param(o, name, single) + elif origin is dict: + return lambda o: dict_param(o, name, pair_left, pair_right) + elif origin is set: + return lambda o: set_param(o, param_name=name, of_type=single) + elif origin is collections.abc.Sequence: + return lambda o: sequence_param(o, param_name=name, of_type=single) + elif origin is collections.abc.Iterable: + return lambda o: iterable_param(o, param_name=name, of_type=single) + elif origin is collections.abc.Mapping: + return lambda o: mapping_param(o, name, pair_left, pair_right) + + elif origin in (UnionType, Union): + # optional + if pair_right is type(None): + inner_origin = get_origin(pair_left) + # optional scalar + if inner_origin is None: + if pair_left is str: + return lambda o: opt_str_param(o, param_name=name) + elif pair_left is float: + return lambda o: opt_float_param(o, param_name=name) + elif pair_left is int: + return lambda o: opt_int_param(o, param_name=name) + elif pair_left is bool: + return lambda o: opt_bool_param(o, param_name=name) + + # fallback to opt_inst + inst_type = _coerce(pair_left, eval_ctx) + if inst_type: + return lambda o: opt_inst_param(o, ttype=inst_type, param_name=name) + else: + return _no_op + + # optional container + else: + inner_args = get_args(pair_left) + inner_pair_left, inner_pair_right = _container_pair_args(inner_args, eval_ctx) + inner_single = _container_single_arg(inner_args, eval_ctx) + if inner_origin is list: + return lambda o: opt_nullable_list_param(o, name, inner_single) + elif inner_origin is 
dict: + return lambda o: opt_nullable_dict_param( + o, name, inner_pair_left, inner_pair_right + ) + elif inner_origin is set: + return lambda o: opt_nullable_set_param(o, name, inner_single) + elif inner_origin is collections.abc.Sequence: + return lambda o: opt_nullable_sequence_param(o, name, inner_single) + elif inner_origin is collections.abc.Iterable: + return lambda o: opt_nullable_iterable_param(o, name, inner_single) + elif inner_origin is collections.abc.Mapping: + return lambda o: opt_nullable_mapping_param( + o, name, inner_pair_left, inner_pair_right + ) + # union + else: + tuple_types = _coerce(ttype, eval_ctx) + if tuple_types is not None: + return lambda o: inst_param(o, name, tuple_types) + + failed(f"Unhandled {ttype}") diff --git a/python_modules/dagster/dagster_tests/general_tests/check_tests/test_check.py b/python_modules/dagster/dagster_tests/general_tests/check_tests/test_check.py index e913180dba5cd..92954699245bc 100644 --- a/python_modules/dagster/dagster_tests/general_tests/check_tests/test_check.py +++ b/python_modules/dagster/dagster_tests/general_tests/check_tests/test_check.py @@ -3,14 +3,18 @@ import sys from collections import defaultdict from contextlib import contextmanager +from typing import Dict, Iterable, List, Mapping, Optional, Sequence, Set, Union import dagster._check as check import pytest +from dagster._annotations import PublicAttr from dagster._check import ( CheckError, ElementCheckError, + EvalContext, NotImplementedCheckError, ParameterCheckError, + build_check_call, ) @@ -1508,3 +1512,84 @@ def test_opt_iterable(): with pytest.raises(CheckError, match="Member of iterable mismatches type"): check.opt_iterable_param(["atr", None], "nonedoesntcount", of_type=str) + + +# ################################################################################################### +# ##### CHECK BUILDER +# ################################################################################################### + + +class Foo: ... 
+ + +class SubFoo(Foo): ... + + +class Bar: ... + + +BUILD_CASES = [ + (int, 4, "4"), + (float, 4.2, "4.1"), + (str, "hi", Foo()), + (Bar, Bar(), Foo()), + (Optional[Bar], Bar(), Foo()), + (List[str], ["a", "b"], [1, 2]), + (Sequence[str], ["a", "b"], [1, 2]), + (Iterable[str], ["a", "b"], [1, 2]), + (Set[str], {"a", "b"}, {1, 2}), + (Dict[str, int], {"a": 1}, {1: "a"}), + (Mapping[str, int], {"a": 1}, {1: "a"}), + (Optional[int], None, "4"), + (Optional[Bar], None, Foo()), + (Optional[List[str]], ["a", "b"], [1, 2]), + (Optional[Sequence[str]], ["a", "b"], [1, 2]), + (Optional[Iterable[str]], ["a", "b"], [1, 2]), + (Optional[Set[str]], {"a", "b"}, {1, 2}), + (Optional[Dict[str, int]], {"a": 1}, {1: "a"}), + (Optional[Mapping[str, int]], {"a": 1}, {1: "a"}), + (PublicAttr[Optional[Mapping[str, int]]], {"a": 1}, {1: "a"}), + (PublicAttr[Bar], Bar(), Foo()), + (Union[bool, Foo], True, None), + # fwd refs + ("Foo", Foo(), Bar()), + (Optional["Foo"], Foo(), Bar()), + (PublicAttr[Optional["Foo"]], None, Bar()), + (Mapping[str, Optional["Foo"]], {"foo": Foo()}, {"bar": Bar()}), +] + + +@pytest.mark.parametrize("ttype, should_succeed, should_fail", BUILD_CASES) +def test_build_check_call(ttype, should_succeed, should_fail) -> None: + eval_ctx = EvalContext(globals(), locals()) + check_call = build_check_call(ttype, "test_param", eval_ctx) + + check_call(should_succeed) + with pytest.raises(CheckError): + check_call(should_fail) + + +def test_build_check_errors() -> None: + with pytest.raises(CheckError, match=r"Unable to resolve ForwardRef\('NoExist'\)"): + build_check_call( + List["NoExist"], # type: ignore # noqa + "bad", + EvalContext(globals(), locals()), + ) + + +def test_forward_ref_flow() -> None: + # original context captured at decl + eval_ctx = EvalContext(globals(), locals()) + ttype = List["Late"] # class not yet defined + + class Late: ... 
+ + with pytest.raises(CheckError): + # can not build call since ctx was captured before definition + build_check_call(ttype, "ok", eval_ctx) + + eval_ctx.update_from_frame(0) # update from callsite frame + # now it works + call = build_check_call(ttype, "ok", eval_ctx) + call([Late()])