diff --git a/csp/impl/types/instantiation_type_resolver.py b/csp/impl/types/instantiation_type_resolver.py index 131a055a..fb432d98 100644 --- a/csp/impl/types/instantiation_type_resolver.py +++ b/csp/impl/types/instantiation_type_resolver.py @@ -34,7 +34,7 @@ def resolve_type(self, expected_type: type, new_type: type, raise_on_error=True) if CspTypingUtils.is_generic_container(expected_type): expected_type_base = CspTypingUtils.get_orig_base(expected_type) if expected_type_base is new_type: - return expected_type # If new_type is Generic and expected type is Generic[T], return Generic + return expected_type_base # If new_type is Generic and expected type is Generic[T], return Generic if CspTypingUtils.is_generic_container(new_type): expected_origin = CspTypingUtils.get_origin(expected_type) new_type_origin = CspTypingUtils.get_origin(new_type) @@ -389,8 +389,14 @@ def _is_scalar_value_matching_spec(self, inp_def_type, arg): if inp_def_type is typing.Callable or ( hasattr(inp_def_type, "__origin__") and CspTypingUtils.get_origin(inp_def_type) is collections.abc.Callable ): - return callable(arg) # TODO: Actually check the input types - if UpcastRegistry.instance().resolve_type(inp_def_type, type(arg), raise_on_error=False) is inp_def_type: + return callable(arg) + resolved_type = UpcastRegistry.instance().resolve_type(inp_def_type, type(arg), raise_on_error=False) + if resolved_type is inp_def_type: + return True + elif ( + CspTypingUtils.is_generic_container(inp_def_type) + and CspTypingUtils.get_orig_base(inp_def_type) is resolved_type + ): return True if CspTypingUtils.is_union_type(inp_def_type): types = inp_def_type.__args__ diff --git a/csp/tests/test_type_checking.py b/csp/tests/test_type_checking.py index 330de186..17b65a64 100644 --- a/csp/tests/test_type_checking.py +++ b/csp/tests/test_type_checking.py @@ -5,7 +5,7 @@ import typing import unittest from datetime import datetime, time, timedelta -from typing import Callable, Dict, List, Optional +from typing import Callable, Dict, List, Optional, Union import csp import csp.impl.types.instantiation_type_resolver as type_resolver @@ -733,6 +733,11 @@ def graph(): node_callable_typed(csp.const(10), None) node_callable_untyped(csp.const(10), None) + # Here the Callable's type hints don't match the signature + # but we allow anyways, both with the pydantic version and without + node_callable_typed(csp.const(10), lambda x, y: "a") + node_callable_untyped(csp.const(10), lambda x, y: "a") + # This should fail - passing non-callable if USE_PYDANTIC: msg = "(?s)1 validation error for node_callable_untyped.*my_data.*Input should be callable \\[type=callable_type" @@ -822,6 +827,7 @@ def graph(): # Here the Callable's type hints don't match the signature # but we allow anyways, both with the pydantic version and without node_optional_callable_typed(csp.const(10), lambda x, y: "a") + node_optional_callable_untyped(csp.const(10), lambda x, y: "a") # This should fail - passing non-callable to typed version if USE_PYDANTIC: @@ -841,6 +847,75 @@ def graph(): csp.run(graph, starttime=datetime(2020, 2, 7, 9), endtime=datetime(2020, 2, 7, 9, 1)) + def test_union_type_checking(self): + @csp.node + def node_union_typed(x: ts[int], my_data: Union[int, str]) -> ts[int]: + if csp.ticked(x): + return x + int(my_data) if isinstance(my_data, str) else x + my_data + + def graph(): + # These should work - valid int inputs + node_union_typed(csp.const(10), 5) + + # These should also work - valid str inputs + node_union_typed(csp.const(10), "123") + + # These should fail - passing float when expecting Union[int, str] + if USE_PYDANTIC: + msg = "(?s)2 validation errors for node_union_typed.*my_data\\.int.*Input should be a valid integer, got a number with a fractional part.*my_data\\.str.*Input should be a valid string" + else: + msg = "In function node_union_typed: Expected typing\\.Union\\[int, str\\] for argument 'my_data', got 12\\.5 \\(float\\)" + with self.assertRaisesRegex(TypeError, msg): + node_union_typed(csp.const(10), 12.5) + + csp.run(graph, starttime=datetime(2020, 2, 7, 9), endtime=datetime(2020, 2, 7, 9, 1)) + + def test_union_list_type_checking(self): + @csp.node + def node_union_typed(x: ts[int], my_data: Union[List[str], int] = None) -> ts[int]: + if csp.ticked(x): + if isinstance(my_data, list): + return x + len(my_data) + return x + my_data + + @csp.node + def node_union_untyped(x: ts[int], my_data: Union[list, int] = None) -> ts[int]: + if csp.ticked(x): + if isinstance(my_data, list): + return x + len(my_data) + return x + my_data + + def graph(): + # These should work - valid int inputs + node_union_typed(csp.const(10), 5) + node_union_untyped(csp.const(10), 42) + + # These should work - valid list inputs + node_union_typed(csp.const(10), ["hello", "world"]) + node_union_untyped(csp.const(10), ["hello", "world"]) + + # This should fail - passing float when expecting Union[List[str], int] + if USE_PYDANTIC: + msg = "(?s)2 validation errors for node_union_typed.*my_data\\.list.*Input should be a valid list.*my_data\\.int.*Input should be a valid integer, got a number with a fractional part" + else: + msg = "In function node_union_typed: Expected typing\\.Union\\[typing\\.List\\[str\\], int\\] for argument 'my_data', got 12\\.5 \\(float\\)" + with self.assertRaisesRegex(TypeError, msg): + node_union_typed(csp.const(10), 12.5) + + # This should fail - passing list with wrong element type + if USE_PYDANTIC: + msg = "(?s)3 validation errors for node_union_typed.*my_data\\.list\\[str\\]\\.0.*Input should be a valid string.*my_data\\.list\\[str\\]\\.1.*Input should be a valid string.*my_data\\.int.*Input should be a valid integer" + with self.assertRaisesRegex(TypeError, msg): + node_union_typed(csp.const(10), [1, 2]) # List of ints instead of strings + else: + # We choose to intentionally not enforce the types provided + # to maintain previous flexibility when not using pydantic type validation + node_union_typed(csp.const(10), [1, 2]) + + node_union_untyped(csp.const(10), [1, 2]) + + csp.run(graph, starttime=datetime(2020, 2, 7, 9), endtime=datetime(2020, 2, 7, 9, 1)) + if __name__ == "__main__": unittest.main()