Skip to content

Commit

Permalink
Fix flakey test, add more tests for validation, union types
Browse files Browse the repository at this point in the history
Signed-off-by: Nijat K <[email protected]>
  • Loading branch information
NeejWeej committed Jan 10, 2025
1 parent 4092860 commit 8ed0013
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 4 deletions.
12 changes: 9 additions & 3 deletions csp/impl/types/instantiation_type_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def resolve_type(self, expected_type: type, new_type: type, raise_on_error=True)
if CspTypingUtils.is_generic_container(expected_type):
expected_type_base = CspTypingUtils.get_orig_base(expected_type)
if expected_type_base is new_type:
return expected_type # If new_type is Generic and expected type is Generic[T], return Generic
return expected_type_base # If new_type is Generic and expected type is Generic[T], return Generic
if CspTypingUtils.is_generic_container(new_type):
expected_origin = CspTypingUtils.get_origin(expected_type)
new_type_origin = CspTypingUtils.get_origin(new_type)
Expand Down Expand Up @@ -389,8 +389,14 @@ def _is_scalar_value_matching_spec(self, inp_def_type, arg):
if inp_def_type is typing.Callable or (
hasattr(inp_def_type, "__origin__") and CspTypingUtils.get_origin(inp_def_type) is collections.abc.Callable
):
return callable(arg) # TODO: Actually check the input types
if UpcastRegistry.instance().resolve_type(inp_def_type, type(arg), raise_on_error=False) is inp_def_type:
return callable(arg)
resolved_type = UpcastRegistry.instance().resolve_type(inp_def_type, type(arg), raise_on_error=False)
if resolved_type is inp_def_type:
return True
elif (
CspTypingUtils.is_generic_container(inp_def_type)
and CspTypingUtils.get_orig_base(inp_def_type) is resolved_type
):
return True
if CspTypingUtils.is_union_type(inp_def_type):
types = inp_def_type.__args__
Expand Down
77 changes: 76 additions & 1 deletion csp/tests/test_type_checking.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import typing
import unittest
from datetime import datetime, time, timedelta
from typing import Callable, Dict, List, Optional
from typing import Callable, Dict, List, Optional, Union

import csp
import csp.impl.types.instantiation_type_resolver as type_resolver
Expand Down Expand Up @@ -733,6 +733,11 @@ def graph():
node_callable_typed(csp.const(10), None)
node_callable_untyped(csp.const(10), None)

# Here the Callable's type hints don't match the signature
# but we allow anyways, both with the pydantic version and without
node_callable_typed(csp.const(10), lambda x, y: "a")
node_callable_untyped(csp.const(10), lambda x, y: "a")

# This should fail - passing non-callable
if USE_PYDANTIC:
msg = "(?s)1 validation error for node_callable_untyped.*my_data.*Input should be callable \\[type=callable_type"
Expand Down Expand Up @@ -822,6 +827,7 @@ def graph():
# Here the Callable's type hints don't match the signature
# but we allow anyways, both with the pydantic version and without
node_optional_callable_typed(csp.const(10), lambda x, y: "a")
node_optional_callable_untyped(csp.const(10), lambda x, y: "a")

# This should fail - passing non-callable to typed version
if USE_PYDANTIC:
Expand All @@ -841,6 +847,75 @@ def graph():

csp.run(graph, starttime=datetime(2020, 2, 7, 9), endtime=datetime(2020, 2, 7, 9, 1))

def test_union_type_checking(self):
@csp.node
def node_union_typed(x: ts[int], my_data: Union[int, str]) -> ts[int]:
if csp.ticked(x):
return x + int(my_data) if isinstance(my_data, str) else x + my_data

def graph():
# These should work - valid int inputs
node_union_typed(csp.const(10), 5)

# These should also work - valid str inputs
node_union_typed(csp.const(10), "123")

# These should fail - passing float when expecting Union[int, str]
if USE_PYDANTIC:
msg = "(?s)2 validation errors for node_union_typed.*my_data\\.int.*Input should be a valid integer, got a number with a fractional part.*my_data\\.str.*Input should be a valid string"
else:
msg = "In function node_union_typed: Expected typing\\.Union\\[int, str\\] for argument 'my_data', got 12\\.5 \\(float\\)"
with self.assertRaisesRegex(TypeError, msg):
node_union_typed(csp.const(10), 12.5)

csp.run(graph, starttime=datetime(2020, 2, 7, 9), endtime=datetime(2020, 2, 7, 9, 1))

def test_union_list_type_checking(self):
@csp.node
def node_union_typed(x: ts[int], my_data: Union[List[str], int] = None) -> ts[int]:
if csp.ticked(x):
if isinstance(my_data, list):
return x + len(my_data)
return x + my_data

@csp.node
def node_union_untyped(x: ts[int], my_data: Union[list, int] = None) -> ts[int]:
if csp.ticked(x):
if isinstance(my_data, list):
return x + len(my_data)
return x + my_data

def graph():
# These should work - valid int inputs
node_union_typed(csp.const(10), 5)
node_union_untyped(csp.const(10), 42)

# These should work - valid list inputs
node_union_typed(csp.const(10), ["hello", "world"])
node_union_untyped(csp.const(10), ["hello", "world"])

# This should fail - passing float when expecting Union[List[str], int]
if USE_PYDANTIC:
msg = "(?s)2 validation errors for node_union_typed.*my_data\\.list.*Input should be a valid list.*my_data\\.int.*Input should be a valid integer, got a number with a fractional part"
else:
msg = "In function node_union_typed: Expected typing\\.Union\\[typing\\.List\\[str\\], int\\] for argument 'my_data', got 12\\.5 \\(float\\)"
with self.assertRaisesRegex(TypeError, msg):
node_union_typed(csp.const(10), 12.5)

# This should fail - passing list with wrong element type
if USE_PYDANTIC:
msg = "(?s)3 validation errors for node_union_typed.*my_data\\.list\\[str\\]\\.0.*Input should be a valid string.*my_data\\.list\\[str\\]\\.1.*Input should be a valid string.*my_data\\.int.*Input should be a valid integer"
with self.assertRaisesRegex(TypeError, msg):
node_union_typed(csp.const(10), [1, 2]) # List of ints instead of strings
else:
# We choose to intentionally not enforce the types provided
# to maintain previous flexibility when not using pydantic type validation
node_union_typed(csp.const(10), [1, 2])

node_union_untyped(csp.const(10), [1, 2])

csp.run(graph, starttime=datetime(2020, 2, 7, 9), endtime=datetime(2020, 2, 7, 9, 1))


if __name__ == "__main__":
unittest.main()

0 comments on commit 8ed0013

Please sign in to comment.