Skip to content

Commit

Permalink
Fix #425 and add more tests
Browse files Browse the repository at this point in the history
Signed-off-by: Nijat K <[email protected]>
  • Loading branch information
NeejWeej committed Jan 9, 2025
1 parent 1054f5c commit 4092860
Show file tree
Hide file tree
Showing 2 changed files with 176 additions and 1 deletion.
6 changes: 5 additions & 1 deletion csp/impl/types/instantiation_type_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def resolve_type(self, expected_type: type, new_type: type, raise_on_error=True)
if CspTypingUtils.is_generic_container(expected_type):
expected_type_base = CspTypingUtils.get_orig_base(expected_type)
if expected_type_base is new_type:
return expected_type_base # If new_type is Generic and expected type is Generic[T], return Generic
return expected_type # If new_type is Generic and expected type is Generic[T], return Generic
if CspTypingUtils.is_generic_container(new_type):
expected_origin = CspTypingUtils.get_origin(expected_type)
new_type_origin = CspTypingUtils.get_origin(new_type)
Expand Down Expand Up @@ -386,6 +386,10 @@ def _add_scalar_value(self, arg, in_out_def):
def _is_scalar_value_matching_spec(self, inp_def_type, arg):
if inp_def_type is typing.Any:
return True
if inp_def_type is typing.Callable or (
hasattr(inp_def_type, "__origin__") and CspTypingUtils.get_origin(inp_def_type) is collections.abc.Callable
):
return callable(arg) # TODO: Actually check the input types
if UpcastRegistry.instance().resolve_type(inp_def_type, type(arg), raise_on_error=False) is inp_def_type:
return True
if CspTypingUtils.is_union_type(inp_def_type):
Expand Down
171 changes: 171 additions & 0 deletions csp/tests/test_type_checking.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import typing
import unittest
from datetime import datetime, time, timedelta
from typing import Callable, Dict, List, Optional

import csp
import csp.impl.types.instantiation_type_resolver as type_resolver
Expand Down Expand Up @@ -621,6 +622,46 @@ def main():

csp.run(main, starttime=datetime.utcnow(), endtime=timedelta())

def test_typed_to_untyped_container_wrong(self):
@csp.graph
def g1(d: csp.ts[dict]):
pass

@csp.graph
def g2(d: csp.ts[set]):
pass

@csp.graph
def g3(d: csp.ts[list]):
pass

def main():
# This should fail - wrong key type in Dict
if USE_PYDANTIC:
msg = "(?s)1 validation error for csp.const.*Input should be a valid integer \\[type=int_type"
else:
msg = "In function csp\\.const: Expected ~T for argument 'value', got .* \\(dict\\)\\(T=typing\\.Dict\\[int, int\\]\\)"
with self.assertRaisesRegex(TypeError, msg):
g1(d=csp.const.using(T=typing.Dict[int, int])({"a": 10}))

# This should fail - wrong element type in Set
if USE_PYDANTIC:
msg = "(?s)1 validation error for csp.const.*Input should be a valid integer \\[type=int_type"
else:
msg = "In function csp\\.const: Expected ~T for argument 'value', got .* \\(set\\)\\(T=typing\\.Set\\[int\\]\\)"
with self.assertRaisesRegex(TypeError, msg):
g2(d=csp.const.using(T=typing.Set[int])(set(["z"])))

# This should fail - wrong element type in List
if USE_PYDANTIC:
msg = "(?s)1 validation error for csp.const.*Input should be a valid integer \\[type=int_type"
else:
msg = "In function csp\\.const: Expected ~T for argument 'value', got .* \\(list\\)\\(T=typing\\.List\\[int\\]\\)"
with self.assertRaisesRegex(TypeError, msg):
g3(d=csp.const.using(T=typing.List[int])(["d"]))

csp.run(main, starttime=datetime.utcnow(), endtime=timedelta())

def test_time_tzinfo(self):
import pytz

Expand Down Expand Up @@ -670,6 +711,136 @@ def g():
self.assertEqual(res["y"][0][1], set())
self.assertEqual(res["z"][0][1], {})

def test_callable_type_checking(self):
@csp.node
def node_callable_typed(x: ts[int], my_data: Callable[[int], int]) -> ts[int]:
if csp.ticked(x):
if my_data:
return my_data(x) if callable(my_data) else 12

@csp.node
def node_callable_untyped(x: ts[int], my_data: Callable) -> ts[int]:
if csp.ticked(x):
if my_data:
return my_data(x) if callable(my_data) else 12

def graph():
# These should work
node_callable_untyped(csp.const(10), lambda x: 2 * x)
node_callable_typed(csp.const(10), lambda x: x + 1)

# We intentionally allow setting None to be allowed
node_callable_typed(csp.const(10), None)
node_callable_untyped(csp.const(10), None)

# This should fail - passing non-callable
if USE_PYDANTIC:
msg = "(?s)1 validation error for node_callable_untyped.*my_data.*Input should be callable \\[type=callable_type"
else:
msg = "In function node_callable_untyped: Expected typing\\.Callable for argument 'my_data', got 11 \\(int\\)"
with self.assertRaisesRegex(TypeError, msg):
node_callable_untyped(csp.const(10), 11)

csp.run(graph, starttime=datetime(2020, 2, 7, 9), endtime=datetime(2020, 2, 7, 9, 1))

def test_optional_type_checking(self):
for use_dict in [True, False]:
if use_dict:

@csp.node
def node_optional_list_typed(x: ts[int], my_data: Optional[Dict[int, int]] = None) -> ts[int]:
if csp.ticked(x):
return my_data[0] if my_data else x

@csp.node
def node_optional_list_untyped(x: ts[int], my_data: Optional[dict] = None) -> ts[int]:
if csp.ticked(x):
return my_data[0] if my_data else x
else:

@csp.node
def node_optional_list_typed(x: ts[int], my_data: Optional[List[int]] = None) -> ts[int]:
if csp.ticked(x):
return my_data[0] if my_data else x

@csp.node
def node_optional_list_untyped(x: ts[int], my_data: Optional[list] = None) -> ts[int]:
if csp.ticked(x):
return my_data[0] if my_data else x

def graph():
# Optional[list] tests - these should work
node_optional_list_untyped(csp.const(10), {} if use_dict else [])
node_optional_list_untyped(csp.const(10), None)
node_optional_list_untyped(csp.const(10), {9: 10} if use_dict else [9])

# Optional[List[int]] tests
node_optional_list_typed(csp.const(10), None)
node_optional_list_typed(csp.const(10), {} if use_dict else [])
node_optional_list_typed(csp.const(10), {9: 10} if use_dict else [9])

# Here the List/Dict type hints don't match the signature
# But, for backwards compatibility (as this was the behavior with Optional in version 0.0.5)
# The pydantic version of the checks, however, catches this.
if USE_PYDANTIC:
msg = "(?s).*validation error.* for node_optional_list_typed.*my_data.*Input should be a valid integer.*type=int_parsing"
with self.assertRaisesRegex(TypeError, msg):
node_optional_list_typed(csp.const(10), {"a": "b"} if use_dict else ["a"])
else:
node_optional_list_typed(csp.const(10), {"a": "b"} if use_dict else ["a"])

# This should fail - type mismatch
if USE_PYDANTIC:
msg = "(?s)1 validation error for node_optional_list_typed.*my_data"
else:
msg = "In function node_optional_list_typed: Expected typing.Optional\\[typing(.)*"
with self.assertRaisesRegex(TypeError, msg):
node_optional_list_typed(csp.const(10), [] if use_dict else {})

csp.run(graph, starttime=datetime(2020, 2, 7, 9), endtime=datetime(2020, 2, 7, 9, 1))

def test_optional_callable_type_checking(self):
@csp.node
def node_optional_callable_typed(x: ts[int], my_data: Optional[Callable[[int], int]] = None) -> ts[int]:
if csp.ticked(x):
return my_data(x) if my_data else x

@csp.node
def node_optional_callable_untyped(x: ts[int], my_data: Optional[Callable] = None) -> ts[int]:
if csp.ticked(x):
return my_data(x) if my_data else x

def graph():
# These should work for both typed and untyped
node_optional_callable_typed(csp.const(10), None)
node_optional_callable_untyped(csp.const(10), None)

# These should also work - valid callables
node_optional_callable_typed(csp.const(10), lambda x: x + 1)
node_optional_callable_untyped(csp.const(10), lambda x: 2 * x)

# Here the Callable's type hints don't match the signature
# but we allow anyways, both with the pydantic version and without
node_optional_callable_typed(csp.const(10), lambda x, y: "a")

# This should fail - passing non-callable to typed version
if USE_PYDANTIC:
msg = "(?s)1 validation error for node_optional_callable_typed.*my_data.*Input should be callable \\[type=callable_type"
else:
msg = "In function node_optional_callable_typed: Expected typing\\.Optional\\[typing\\.Callable\\[\\[int\\], int\\]\\] for argument 'my_data', got 12 \\(int\\)"
with self.assertRaisesRegex(TypeError, msg):
node_optional_callable_typed(csp.const(10), 12)

# This should fail - passing non-callable to untyped version
if USE_PYDANTIC:
msg = "(?s)1 validation error for node_optional_callable_untyped.*my_data.*Input should be callable \\[type=callable_type"
else:
msg = "In function node_optional_callable_untyped: Expected typing\\.Optional\\[typing\\.Callable\\] for argument 'my_data', got 12 \\(int\\)"
with self.assertRaisesRegex(TypeError, msg):
node_optional_callable_untyped(csp.const(10), 12)

csp.run(graph, starttime=datetime(2020, 2, 7, 9), endtime=datetime(2020, 2, 7, 9, 1))


if __name__ == "__main__":
unittest.main()

0 comments on commit 4092860

Please sign in to comment.