diff --git a/csp/impl/types/instantiation_type_resolver.py b/csp/impl/types/instantiation_type_resolver.py index 33b964b7..131a055a 100644 --- a/csp/impl/types/instantiation_type_resolver.py +++ b/csp/impl/types/instantiation_type_resolver.py @@ -34,7 +34,7 @@ def resolve_type(self, expected_type: type, new_type: type, raise_on_error=True) if CspTypingUtils.is_generic_container(expected_type): expected_type_base = CspTypingUtils.get_orig_base(expected_type) if expected_type_base is new_type: - return expected_type_base # If new_type is Generic and expected type is Generic[T], return Generic + return expected_type # If new_type is Generic and expected type is Generic[T], return Generic if CspTypingUtils.is_generic_container(new_type): expected_origin = CspTypingUtils.get_origin(expected_type) new_type_origin = CspTypingUtils.get_origin(new_type) @@ -386,6 +386,10 @@ def _add_scalar_value(self, arg, in_out_def): def _is_scalar_value_matching_spec(self, inp_def_type, arg): if inp_def_type is typing.Any: return True + if inp_def_type is typing.Callable or ( + hasattr(inp_def_type, "__origin__") and CspTypingUtils.get_origin(inp_def_type) is collections.abc.Callable + ): + return callable(arg) # TODO: Actually check the input types if UpcastRegistry.instance().resolve_type(inp_def_type, type(arg), raise_on_error=False) is inp_def_type: return True if CspTypingUtils.is_union_type(inp_def_type): diff --git a/csp/tests/test_type_checking.py b/csp/tests/test_type_checking.py index 7c0493c3..eb0f7bfb 100644 --- a/csp/tests/test_type_checking.py +++ b/csp/tests/test_type_checking.py @@ -5,6 +5,7 @@ import typing import unittest from datetime import datetime, time, timedelta +from typing import Callable, Dict, List, Optional import csp import csp.impl.types.instantiation_type_resolver as type_resolver @@ -670,6 +671,136 @@ def g(): self.assertEqual(res["y"][0][1], set()) self.assertEqual(res["z"][0][1], {}) + def test_callable_type_checking(self): + @csp.node + def node_callable_typed(x: ts[int], my_data: Callable[[int], int]) -> ts[int]: + if csp.ticked(x): + if my_data: + return my_data(x) if callable(my_data) else 12 + + @csp.node + def node_callable_untyped(x: ts[int], my_data: Callable) -> ts[int]: + if csp.ticked(x): + if my_data: + return my_data(x) if callable(my_data) else 12 + + def graph(): + # These should work + node_callable_untyped(csp.const(10), lambda x: 2 * x) + node_callable_typed(csp.const(10), lambda x: x + 1) + + # We intentionally allow setting None to be allowed + node_callable_typed(csp.const(10), None) + node_callable_untyped(csp.const(10), None) + + # This should fail - passing non-callable + if USE_PYDANTIC: + msg = "(?s)1 validation error for node_callable_untyped.*my_data.*Input should be callable \\[type=callable_type" + else: + msg = "In function node_callable_untyped: Expected typing\\.Callable for argument 'my_data', got 11 \\(int\\)" + with self.assertRaisesRegex(TypeError, msg): + node_callable_untyped(csp.const(10), 11) + + csp.run(graph, starttime=datetime(2020, 2, 7, 9), endtime=datetime(2020, 2, 7, 9, 1)) + + def test_optional_type_checking(self): + for use_dict in [True, False]: + if use_dict: + + @csp.node + def node_optional_list_typed(x: ts[int], my_data: Optional[Dict[int, int]] = None) -> ts[int]: + if csp.ticked(x): + return my_data[0] if my_data else x + + @csp.node + def node_optional_list_untyped(x: ts[int], my_data: Optional[dict] = None) -> ts[int]: + if csp.ticked(x): + return my_data[0] if my_data else x + else: + + @csp.node + def node_optional_list_typed(x: ts[int], my_data: Optional[List[int]] = None) -> ts[int]: + if csp.ticked(x): + return my_data[0] if my_data else x + + @csp.node + def node_optional_list_untyped(x: ts[int], my_data: Optional[list] = None) -> ts[int]: + if csp.ticked(x): + return my_data[0] if my_data else x + + def graph(): + # Optional[list] tests - these should work + node_optional_list_untyped(csp.const(10), {} if use_dict else []) + node_optional_list_untyped(csp.const(10), None) + node_optional_list_untyped(csp.const(10), {9: 10} if use_dict else [9]) + + # Optional[List[int]] tests + node_optional_list_typed(csp.const(10), None) + node_optional_list_typed(csp.const(10), {} if use_dict else []) + node_optional_list_typed(csp.const(10), {9: 10} if use_dict else [9]) + + # Here the List/Dict type hints don't match the signature + # but we allow anyways. We only care that the top object matches + # The pydantic version, however, catches this. + if USE_PYDANTIC: + msg = "(?s).*validation error.* for node_optional_list_typed.*my_data.*Input should be a valid integer.*type=int_parsing" + with self.assertRaisesRegex(TypeError, msg): + node_optional_list_typed(csp.const(10), {"a": "b"} if use_dict else ["a"]) + else: + node_optional_list_typed(csp.const(10), {"a": "b"} if use_dict else ["a"]) + + # This should fail - type mismatch + if USE_PYDANTIC: + msg = "(?s)1 validation error for node_optional_list_typed.*my_data" + else: + msg = "In function node_optional_list_typed: Expected typing.Optional\\[typing(.)*" + with self.assertRaisesRegex(TypeError, msg): + node_optional_list_typed(csp.const(10), [] if use_dict else {}) + + csp.run(graph, starttime=datetime(2020, 2, 7, 9), endtime=datetime(2020, 2, 7, 9, 1)) + + def test_optional_callable_type_checking(self): + @csp.node + def node_optional_callable_typed(x: ts[int], my_data: Optional[Callable[[int], int]] = None) -> ts[int]: + if csp.ticked(x): + return my_data(x) if my_data else x + + @csp.node + def node_optional_callable_untyped(x: ts[int], my_data: Optional[Callable] = None) -> ts[int]: + if csp.ticked(x): + return my_data(x) if my_data else x + + def graph(): + # These should work for both typed and untyped + node_optional_callable_typed(csp.const(10), None) + node_optional_callable_untyped(csp.const(10), None) + + # These should also work - valid callables + node_optional_callable_typed(csp.const(10), lambda x: x + 1) + node_optional_callable_untyped(csp.const(10), lambda x: 2 * x) + + # Here the Callable's type hints don't match the signature + # but we allow anyways, both with the pydantic version and without + node_optional_callable_typed(csp.const(10), lambda x, y: "a") + + # This should fail - passing non-callable to typed version + if USE_PYDANTIC: + msg = "(?s)1 validation error for node_optional_callable_typed.*my_data.*Input should be callable \\[type=callable_type" + else: + msg = "In function node_optional_callable_typed: Expected typing\\.Optional\\[typing\\.Callable\\[\\[int\\], int\\]\\] for argument 'my_data', got 12 \\(int\\)" + with self.assertRaisesRegex(TypeError, msg): + node_optional_callable_typed(csp.const(10), 12) + + # This should fail - passing non-callable to untyped version + if USE_PYDANTIC: + msg = "(?s)1 validation error for node_optional_callable_untyped.*my_data.*Input should be callable \\[type=callable_type" + else: + msg = "In function node_optional_callable_untyped: Expected typing\\.Optional\\[typing\\.Callable\\] for argument 'my_data', got 12 \\(int\\)" + with self.assertRaisesRegex(TypeError, msg): + node_optional_callable_untyped(csp.const(10), 12) + + csp.run(graph, starttime=datetime(2020, 2, 7, 9), endtime=datetime(2020, 2, 7, 9, 1)) + if __name__ == "__main__": unittest.main()