diff --git a/Dockerfile.test.py310 b/Dockerfile.test.py310 deleted file mode 100644 index fc33837..0000000 --- a/Dockerfile.test.py310 +++ /dev/null @@ -1,9 +0,0 @@ -FROM python:3.10-slim - -COPY tests/requirements.txt . -RUN pip install -r requirements.txt - -COPY zero ./zero -COPY tests ./tests - -CMD ["pytest", "tests", "--cov=zero", "--cov-report=term-missing", "-vv"] diff --git a/Dockerfile.test.py38 b/Dockerfile.test.py38 deleted file mode 100644 index 209a32e..0000000 --- a/Dockerfile.test.py38 +++ /dev/null @@ -1,9 +0,0 @@ -FROM python:3.8-slim - -COPY tests/requirements.txt . -RUN pip install -r requirements.txt - -COPY zero ./zero -COPY tests ./tests - -CMD ["pytest", "tests", "--cov=zero", "--cov-report=term-missing", "-vv"] diff --git a/Dockerfile.test.py39 b/Dockerfile.test.py39 deleted file mode 100644 index 5e6c2fd..0000000 --- a/Dockerfile.test.py39 +++ /dev/null @@ -1,9 +0,0 @@ -FROM python:3.9-slim - -COPY tests/requirements.txt . -RUN pip install -r requirements.txt - -COPY zero ./zero -COPY tests ./tests - -CMD ["pytest", "tests", "--cov=zero", "--cov-report=term-missing", "-vv"] diff --git a/examples/basic/schema.py b/examples/basic/schema.py index ecb8e66..9e4057b 100644 --- a/examples/basic/schema.py +++ b/examples/basic/schema.py @@ -1,9 +1,30 @@ +from dataclasses import dataclass +from datetime import date from typing import List import msgspec +class Address(msgspec.Struct): + street: str + city: str + zip: int + + class User(msgspec.Struct): name: str age: int emails: List[str] + addresses: List[Address] + registered_at: date + + +@dataclass +class Teacher: + name: str + + +class Student(User): + roll_no: int + marks: List[int] + teachers: List[Teacher] diff --git a/examples/basic/server.py b/examples/basic/server.py index ffc3c36..2bac606 100644 --- a/examples/basic/server.py +++ b/examples/basic/server.py @@ -5,7 +5,7 @@ from zero import ZeroServer -from .schema import User +from .schema import Student, Teacher, User app = ZeroServer(port=5559) @@ -42,6 +42,17 @@ def hello_users(users: typing.List[User]) -> str: return f"Hello {', '.join([user.name for user in users])}! Your emails are {', '.join([email for user in users for email in user.emails])}!" +teachers = [ + Teacher(name="Teacher1"), + Teacher(name="Teacher2"), +] + + +@app.register_rpc +def hello_student(student: Student) -> str: + return f"Hello {student.name}! You are {student.age} years old. Your email is {student.emails[0]}! Your roll no. is {student.roll_no} and your marks are {student.marks}!" + + if __name__ == "__main__": app.register_rpc(echo) app.register_rpc(hello_world) diff --git a/tests/functional/codegen/__init__.py b/tests/functional/codegen/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/functional/codegen/test_codegen.py b/tests/functional/codegen/test_codegen.py new file mode 100644 index 0000000..8d1584a --- /dev/null +++ b/tests/functional/codegen/test_codegen.py @@ -0,0 +1,710 @@ +import dataclasses +import datetime +import decimal +import enum +import typing +import unittest +import uuid +from dataclasses import dataclass +from datetime import date +from typing import Dict, List, Optional, Tuple, Union + +import msgspec +from msgspec import Struct + +from zero.codegen.codegen import CodeGen + + +@dataclass +class SimpleDataclass: + a: int + b: str + + +@dataclasses.dataclass +class SimpleDataclass2: + c: int + d: str + + +class ChildDataclass(SimpleDataclass): + e: int + f: str + + +class SimpleStruct(Struct): + h: int + i: str + + +class ComplexStruct(msgspec.Struct): + a: int + b: str + c: SimpleStruct + d: List[SimpleStruct] + e: Dict[str, SimpleStruct] + f: Tuple[SimpleDataclass, SimpleStruct] + g: Union[SimpleStruct, SimpleDataclass, SimpleDataclass2] + + +class ChildComplexStruct(ComplexStruct): + h: int + i: str + + +class SimpleEnum(enum.Enum): + ONE = 1 + TWO = 2 + + +class SimpleIntEnum(enum.IntEnum): + ONE = 1 + TWO = 2 + + +def func_none(arg: None) -> str: + return "Received None" + + +def func_bool(arg: bool) -> str: + return f"Received bool: {arg}" + + +def func_int(arg: int) -> str: + return f"Received int: {arg}" + + +def func_float(arg: float) -> str: + return f"Received float: {arg}" + + +def func_str(arg: str) -> str: + return f"Received str: {arg}" + + +def func_bytes(arg: bytes) -> str: + return f"Received bytes: {arg}" + + +def func_bytearray(arg: bytearray) -> str: + return f"Received bytearray: {arg}" + + +def func_tuple(arg: tuple) -> str: + return f"Received tuple: {arg}" + + +def func_list(arg: list) -> str: + return f"Received list: {arg}" + + +def func_dict(arg: dict) -> str: + return f"Received dict: {arg}" + + +def func_optional_dict(arg: Optional[dict]) -> str: + return f"Received dict: {arg}" + + +def func_set(arg: set) -> str: + return f"Received set: {arg}" + + +def func_frozenset(arg: frozenset) -> str: + return f"Received frozenset: {arg}" + + +def func_datetime(arg: datetime.datetime) -> str: + return f"Received datetime: {arg}" + + +def func_date(arg: date) -> str: + return f"Received date: {arg}" + + +def func_time(arg: datetime.time) -> str: + return f"Received time: {arg}" + + +def func_uuid(arg: uuid.UUID) -> str: + return f"Received UUID: {arg}" + + +def func_decimal(arg: decimal.Decimal) -> str: + return f"Received Decimal: {arg}" + + +def func_enum(arg: SimpleEnum) -> str: + return f"Received Enum: {arg}" + + +def func_intenum(arg: SimpleIntEnum) -> str: + return f"Received IntEnum: {arg}" + + +def func_dataclass(arg: SimpleDataclass) -> str: + return f"Received dataclass: {arg}" + + +def func_tuple_typing(arg: typing.Tuple[int, str]) -> str: + return f"Received typing.Tuple: {arg}" + + +def func_list_typing(arg: typing.List[int]) -> str: + return f"Received typing.List: {arg}" + + +def func_dict_typing(arg: typing.Dict[str, int]) -> str: + return f"Received typing.Dict: {arg}" + + +def func_set_typing(arg: typing.Set[int]) -> str: + return f"Received typing.Set: {arg}" + + +def func_frozenset_typing(arg: typing.FrozenSet[int]) -> str: + return f"Received typing.FrozenSet: {arg}" + + +def func_any_typing(arg: typing.Any) -> str: + return f"Received typing.Any: {arg}" + + +def func_union_typing(arg: typing.Union[int, str]) -> str: + return f"Received typing.Union: {arg}" + + +def func_optional_typing(arg: typing.Optional[int]) -> str: + return f"Received typing.Optional: {arg}" + + +def func_msgspec_struct(arg: SimpleStruct) -> str: + return f"Received msgspec.Struct: {arg}" + + +def func_msgspec_struct_complex(arg: ComplexStruct) -> str: + return f"Received msgspec.Struct: {arg}" + + +def func_child_complex_struct(arg: ChildComplexStruct) -> str: + return f"Received msgspec.Struct: {arg}" + + +def func_return_optional_child_complex_struct() -> Optional[ChildComplexStruct]: + return None + + +def func_return_complex_struct() -> ComplexStruct: + return ComplexStruct( + a=1, + b="hello", + c=SimpleStruct(h=1, i="hello"), + d=[SimpleStruct(h=1, i="hello")], + e={"1": SimpleStruct(h=1, i="hello")}, + f=(SimpleDataclass(a=1, b="hello"), SimpleStruct(h=1, i="hello")), + g=SimpleDataclass(a=1, b="hello"), + ) + + +def func_take_optional_child_dataclass_return_optional_child_complex_struct( + arg: Optional[ChildDataclass], +) -> Optional[ChildComplexStruct]: + return None + + +class TestCodegen(unittest.TestCase): + def setUp(self) -> None: + self.maxDiff = None + self._rpc_router = { + "func_none": (func_none, False), + "func_bool": (func_bool, False), + "func_int": (func_int, False), + "func_float": (func_float, False), + "func_str": (func_str, False), + "func_bytes": (func_bytes, False), + "func_bytearray": (func_bytearray, False), + "func_tuple": (func_tuple, False), + "func_list": (func_list, False), + "func_dict": (func_dict, False), + "func_optional_dict": (func_optional_dict, False), + "func_set": (func_set, False), + "func_frozenset": (func_frozenset, False), + "func_datetime": (func_datetime, False), + "func_date": (func_date, False), + "func_time": (func_time, False), + "func_uuid": (func_uuid, False), + "func_decimal": (func_decimal, False), + "func_enum": (func_enum, False), + "func_intenum": (func_intenum, False), + "func_dataclass": (func_dataclass, False), + "func_tuple_typing": (func_tuple_typing, False), + "func_list_typing": (func_list_typing, False), + "func_dict_typing": (func_dict_typing, False), + "func_set_typing": (func_set_typing, False), + "func_frozenset_typing": (func_frozenset_typing, False), + "func_any_typing": (func_any_typing, False), + "func_union_typing": (func_union_typing, False), + "func_optional_typing": (func_optional_typing, False), + "func_msgspec_struct": (func_msgspec_struct, False), + "func_msgspec_struct_complex": (func_msgspec_struct_complex, False), + "func_child_complex_struct": (func_child_complex_struct, False), + "func_return_complex_struct": (func_return_complex_struct, False), + } + self._rpc_input_type_map = { + "func_none": None, + "func_bool": bool, + "func_int": int, + "func_float": float, + "func_str": str, + "func_bytes": bytes, + "func_bytearray": bytearray, + "func_tuple": tuple, + "func_list": list, + "func_dict": dict, + "func_optional_dict": Optional[dict], + "func_set": set, + "func_frozenset": frozenset, + "func_datetime": datetime.datetime, + "func_date": datetime.date, + "func_time": datetime.time, + "func_uuid": uuid.UUID, + "func_decimal": decimal.Decimal, + "func_enum": SimpleEnum, + "func_intenum": SimpleIntEnum, + "func_dataclass": SimpleDataclass, + "func_tuple_typing": typing.Tuple[int, str], + "func_list_typing": typing.List[int], + "func_dict_typing": typing.Dict[str, int], + "func_set_typing": typing.Set[int], + "func_frozenset_typing": typing.FrozenSet[int], + "func_any_typing": typing.Any, + "func_union_typing": typing.Union[int, str], + "func_optional_typing": typing.Optional[int], + "func_msgspec_struct": SimpleStruct, + "func_msgspec_struct_complex": ComplexStruct, + "func_child_complex_struct": ChildComplexStruct, + "func_return_complex_struct": None, + } + self._rpc_return_type_map = { + "func_none": str, + "func_bool": str, + "func_int": str, + "func_float": str, + "func_str": str, + "func_bytes": str, + "func_bytearray": str, + "func_tuple": str, + "func_list": str, + "func_dict": str, + "func_optional_dict": Optional[str], + "func_set": str, + "func_frozenset": str, + "func_datetime": str, + "func_date": str, + "func_time": str, + "func_uuid": str, + "func_decimal": str, + "func_enum": str, + "func_intenum": str, + "func_dataclass": str, + "func_tuple_typing": str, + "func_list_typing": str, + "func_dict_typing": str, + "func_set_typing": str, + "func_frozenset_typing": str, + "func_any_typing": str, + "func_union_typing": str, + "func_optional_typing": str, + "func_msgspec_struct": str, + "func_msgspec_struct_complex": str, + "func_child_complex_struct": str, + "func_return_complex_struct": ComplexStruct, + } + + def test_codegen(self): + codegen = CodeGen( + self._rpc_router, self._rpc_input_type_map, self._rpc_return_type_map + ) + code = codegen.generate_code() + expected_code = """# Generated by Zero +# import types as per needed, not all imports are shown here +import dataclasses +from dataclasses import dataclass +from datetime import date, datetime, time +import decimal +import enum +import msgspec +from msgspec import Struct +from typing import Dict, FrozenSet, List, Optional, Set, Tuple, Union +import uuid + +from zero import ZeroClient + + +zero_client = ZeroClient("localhost", 5559) + +class SimpleEnum(enum.Enum): + ONE = 1 + TWO = 2 + + +class SimpleIntEnum(enum.IntEnum): + ONE = 1 + TWO = 2 + + +@dataclass +class SimpleDataclass: + a: int + b: str + + +class SimpleStruct(Struct): + h: int + i: str + + +@dataclasses.dataclass +class SimpleDataclass2: + c: int + d: str + + +class ComplexStruct(msgspec.Struct): + a: int + b: str + c: SimpleStruct + d: List[SimpleStruct] + e: Dict[str, SimpleStruct] + f: Tuple[SimpleDataclass, SimpleStruct] + g: Union[SimpleStruct, SimpleDataclass, SimpleDataclass2] + + +class ChildComplexStruct(ComplexStruct): + h: int + i: str + + + +class RpcClient: + def __init__(self, zero_client: ZeroClient): + self._zero_client = zero_client + + def func_none(selfarg: None) -> str: + return self._zero_client.call("func_none", None) + + def func_bool(self, arg: bool) -> str: + return self._zero_client.call("func_bool", arg) + + def func_int(self, arg: int) -> str: + return self._zero_client.call("func_int", arg) + + def func_float(self, arg: float) -> str: + return self._zero_client.call("func_float", arg) + + def func_str(self, arg: str) -> str: + return self._zero_client.call("func_str", arg) + + def func_bytes(self, arg: bytes) -> str: + return self._zero_client.call("func_bytes", arg) + + def func_bytearray(self, arg: bytearray) -> str: + return self._zero_client.call("func_bytearray", arg) + + def func_tuple(self, arg: tuple) -> str: + return self._zero_client.call("func_tuple", arg) + + def func_list(self, arg: list) -> str: + return self._zero_client.call("func_list", arg) + + def func_dict(self, arg: dict) -> str: + return self._zero_client.call("func_dict", arg) + + def func_optional_dict(self, arg: Optional[dict]) -> str: + return self._zero_client.call("func_optional_dict", arg) + + def func_set(self, arg: set) -> str: + return self._zero_client.call("func_set", arg) + + def func_frozenset(self, arg: frozenset) -> str: + return self._zero_client.call("func_frozenset", arg) + + def func_datetime(self, arg: datetime.datetime) -> str: + return self._zero_client.call("func_datetime", arg) + + def func_date(self, arg: date) -> str: + return self._zero_client.call("func_date", arg) + + def func_time(self, arg: datetime.time) -> str: + return self._zero_client.call("func_time", arg) + + def func_uuid(self, arg: uuid.UUID) -> str: + return self._zero_client.call("func_uuid", arg) + + def func_decimal(self, arg: decimal.Decimal) -> str: + return self._zero_client.call("func_decimal", arg) + + def func_enum(self, arg: SimpleEnum) -> str: + return self._zero_client.call("func_enum", arg) + + def func_intenum(self, arg: SimpleIntEnum) -> str: + return self._zero_client.call("func_intenum", arg) + + def func_dataclass(self, arg: SimpleDataclass) -> str: + return self._zero_client.call("func_dataclass", arg) + + def func_tuple_typing(self, arg: Tuple[int, str]) -> str: + return self._zero_client.call("func_tuple_typing", arg) + + def func_list_typing(self, arg: List[int]) -> str: + return self._zero_client.call("func_list_typing", arg) + + def func_dict_typing(self, arg: Dict[str, int]) -> str: + return self._zero_client.call("func_dict_typing", arg) + + def func_set_typing(self, arg: Set[int]) -> str: + return self._zero_client.call("func_set_typing", arg) + + def func_frozenset_typing(self, arg: FrozenSet[int]) -> str: + return self._zero_client.call("func_frozenset_typing", arg) + + def func_any_typing(self, arg: Any) -> str: + return self._zero_client.call("func_any_typing", arg) + + def func_union_typing(self, arg: Union[int, str]) -> str: + return self._zero_client.call("func_union_typing", arg) + + def func_optional_typing(self, arg: Optional[int]) -> str: + return self._zero_client.call("func_optional_typing", arg) + + def func_msgspec_struct(self, arg: SimpleStruct) -> str: + return self._zero_client.call("func_msgspec_struct", arg) + + def func_msgspec_struct_complex(self, arg: ComplexStruct) -> str: + return self._zero_client.call("func_msgspec_struct_complex", arg) + + def func_child_complex_struct(self, arg: ChildComplexStruct) -> str: + return self._zero_client.call("func_child_complex_struct", arg) + + def func_return_complex_struct(self) -> ComplexStruct: + return self._zero_client.call("func_return_complex_struct", None) +""" + self.assertEqual(code, expected_code) + + def test_codegen_return_single_complex_struct(self): + rpc_router = { + "func_return_complex_struct": (func_return_complex_struct, False), + } + rpc_input_type_map = { + "func_return_complex_struct": None, + } + rpc_return_type_map = { + "func_return_complex_struct": ComplexStruct, + } + codegen = CodeGen(rpc_router, rpc_input_type_map, rpc_return_type_map) + code = codegen.generate_code() + expected_code = """# Generated by Zero +# import types as per needed, not all imports are shown here +import dataclasses +from dataclasses import dataclass +import enum +import msgspec +from msgspec import Struct +from typing import Dict, List, Tuple, Union + +from zero import ZeroClient + + +zero_client = ZeroClient("localhost", 5559) + +class SimpleStruct(Struct): + h: int + i: str + + +@dataclass +class SimpleDataclass: + a: int + b: str + + +@dataclasses.dataclass +class SimpleDataclass2: + c: int + d: str + + +class ComplexStruct(msgspec.Struct): + a: int + b: str + c: SimpleStruct + d: List[SimpleStruct] + e: Dict[str, SimpleStruct] + f: Tuple[SimpleDataclass, SimpleStruct] + g: Union[SimpleStruct, SimpleDataclass, SimpleDataclass2] + + + +class RpcClient: + def __init__(self, zero_client: ZeroClient): + self._zero_client = zero_client + + def func_return_complex_struct(self) -> ComplexStruct: + return self._zero_client.call("func_return_complex_struct", None) +""" + self.assertEqual(code, expected_code) + + def test_codegen_return_optional_complex_struct(self): + rpc_router = { + "func_return_optional_child_complex_struct": ( + func_return_optional_child_complex_struct, + False, + ), + } + rpc_input_type_map = { + "func_return_optional_child_complex_struct": None, + } + rpc_return_type_map = { + "func_return_optional_child_complex_struct": Optional[ChildComplexStruct], + } + codegen = CodeGen(rpc_router, rpc_input_type_map, rpc_return_type_map) + code = codegen.generate_code() + expected_code = """# Generated by Zero +# import types as per needed, not all imports are shown here +import dataclasses +from dataclasses import dataclass +import enum +import msgspec +from msgspec import Struct +from typing import Dict, List, Optional, Tuple, Union + +from zero import ZeroClient + + +zero_client = ZeroClient("localhost", 5559) + +class SimpleStruct(Struct): + h: int + i: str + + +@dataclass +class SimpleDataclass: + a: int + b: str + + +@dataclasses.dataclass +class SimpleDataclass2: + c: int + d: str + + +class ComplexStruct(msgspec.Struct): + a: int + b: str + c: SimpleStruct + d: List[SimpleStruct] + e: Dict[str, SimpleStruct] + f: Tuple[SimpleDataclass, SimpleStruct] + g: Union[SimpleStruct, SimpleDataclass, SimpleDataclass2] + + +class ChildComplexStruct(ComplexStruct): + h: int + i: str + + + +class RpcClient: + def __init__(self, zero_client: ZeroClient): + self._zero_client = zero_client + + def func_return_optional_child_complex_struct(self) -> Optional[ChildComplexStruct]: + return self._zero_client.call("func_return_optional_child_complex_struct", None) +""" + self.assertEqual(code, expected_code) + + def test_codegen_optional_child_dataclass_return_optional_child_complex_struct( + self, + ): + rpc_router = { + "func_take_optional_child_dataclass_return_optional_child_complex_struct": ( + func_take_optional_child_dataclass_return_optional_child_complex_struct, + False, + ), + } + rpc_input_type_map = { + "func_take_optional_child_dataclass_return_optional_child_complex_struct": Optional[ + ChildDataclass + ], + } + rpc_return_type_map = { + "func_take_optional_child_dataclass_return_optional_child_complex_struct": Optional[ + ChildComplexStruct + ], + } + codegen = CodeGen(rpc_router, rpc_input_type_map, rpc_return_type_map) + code = codegen.generate_code() + expected_code = """# Generated by Zero +# import types as per needed, not all imports are shown here +import dataclasses +from dataclasses import dataclass +import enum +import msgspec +from msgspec import Struct +from typing import Dict, List, Optional, Tuple, Union + +from zero import ZeroClient + + +zero_client = ZeroClient("localhost", 5559) + +@dataclass +class SimpleDataclass: + a: int + b: str + + +class ChildDataclass(SimpleDataclass): + e: int + f: str + + +class SimpleStruct(Struct): + h: int + i: str + + +@dataclasses.dataclass +class SimpleDataclass2: + c: int + d: str + + +class ComplexStruct(msgspec.Struct): + a: int + b: str + c: SimpleStruct + d: List[SimpleStruct] + e: Dict[str, SimpleStruct] + f: Tuple[SimpleDataclass, SimpleStruct] + g: Union[SimpleStruct, SimpleDataclass, SimpleDataclass2] + + +class ChildComplexStruct(ComplexStruct): + h: int + i: str + + + +class RpcClient: + def __init__(self, zero_client: ZeroClient): + self._zero_client = zero_client + + def func_take_optional_child_dataclass_return_optional_child_complex_struct(self, + arg: Optional[ChildDataclass], +) -> Optional[ChildComplexStruct]: + return self._zero_client.call("func_take_optional_child_dataclass_return_optional_child_complex_struct", arg) +""" + self.assertEqual(code, expected_code) diff --git a/tests/functional/single_server/client_generation_test.py b/tests/functional/single_server/client_generation_test.py index 22d3fc7..1a9c59c 100644 --- a/tests/functional/single_server/client_generation_test.py +++ b/tests/functional/single_server/client_generation_test.py @@ -17,13 +17,22 @@ def test_codegeneration(): assert ( code == """# Generated by Zero -# import types as per needed +# import types as per needed, not all imports are shown here +from datetime import datetime +import enum +import msgspec +from typing import Dict, List, Tuple, Union from zero import ZeroClient zero_client = ZeroClient("localhost", 5559) +class Message(msgspec.Struct): + msg: str + start_time: datetime.datetime + + class RpcClient: def __init__(self, zero_client: ZeroClient): @@ -53,19 +62,19 @@ def hello_world(self) -> str: def decode_jwt(self, msg: str) -> str: return self._zero_client.call("decode_jwt", msg) - def sum_list(self, msg: typing.List[int]) -> int: + def sum_list(self, msg: List[int]) -> int: return self._zero_client.call("sum_list", msg) - def echo_dict(self, msg: typing.Dict[int, str]) -> typing.Dict[int, str]: + def echo_dict(self, msg: Dict[int, str]) -> Dict[int, str]: return self._zero_client.call("echo_dict", msg) - def echo_tuple(self, msg: typing.Tuple[int, str]) -> typing.Tuple[int, str]: + def echo_tuple(self, msg: Tuple[int, str]) -> Tuple[int, str]: return self._zero_client.call("echo_tuple", msg) - def echo_union(self, msg: typing.Union[int, str]) -> typing.Union[int, str]: + def echo_union(self, msg: Union[int, str]) -> Union[int, str]: return self._zero_client.call("echo_union", msg) - def divide(self, msg: typing.Tuple[int, int]) -> int: + def divide(self, msg: Tuple[int, int]) -> int: return self._zero_client.call("divide", msg) """ ) diff --git a/zero/codegen/codegen.py b/zero/codegen/codegen.py index 9ffd6f4..c28ba80 100644 --- a/zero/codegen/codegen.py +++ b/zero/codegen/codegen.py @@ -1,86 +1,231 @@ +import datetime +import decimal +import enum import inspect +import uuid +from dataclasses import is_dataclass +from typing import ( + Callable, + Dict, + List, + Optional, + Set, + Tuple, + Type, + Union, + get_args, + get_origin, + get_type_hints, +) + +import msgspec + +from zero.utils.type_util import typing_types # from pydantic import BaseModel class CodeGen: - def __init__(self, rpc_router, rpc_input_type_map, rpc_return_type_map): + def __init__( + self, + rpc_router: Dict[str, Tuple[Callable, bool]], + rpc_input_type_map: Dict[str, Optional[type]], + rpc_return_type_map: Dict[str, Optional[type]], + ): self._rpc_router = rpc_router self._rpc_input_type_map = rpc_input_type_map self._rpc_return_type_map = rpc_return_type_map - self._typing_imports = set() + + # for imports + self._typing_imports: list[str] = [ + str(typ).replace("typing.", "") for typ in typing_types + ] + self._typing_imports.sort() + self._datetime_imports: set[str] = set() + self._has_uuid = False + self._has_decimal = False + self._has_enum = True def generate_code(self, host="localhost", port=5559): code = f"""# Generated by Zero -# import types as per needed +# import types as per needed, not all imports are shown here from zero import ZeroClient zero_client = ZeroClient("{host}", {port}) - +""" + code += self.generate_models() + code += """ class RpcClient: def __init__(self, zero_client: ZeroClient): self._zero_client = zero_client """ for func_name in self._rpc_router: + input_param_name = ( + None + if self._rpc_input_type_map[func_name] is None + else self.get_function_input_param_name(func_name) + ) code += f""" {self.get_function_str(func_name)} - return self._zero_client.call("{func_name}", { - None if self._rpc_input_type_map[func_name] is None - else self.get_function_input_param_name(func_name) - }) + return self._zero_client.call("{func_name}", {input_param_name}) """ - # self.generate_data_classes() TODO: next feature + + # add imports after first 2 lines + code_lines = code.split("\n") + code_lines.insert(2, self.get_imports(code)) + code = "\n".join(code_lines) + + if "typing." in code: + code = code.replace("typing.", "") + return code - def get_imports(self): - return f"from typing import {', '.join(i for i in self._typing_imports)}" + def get_imports(self, code): + for func_name in self._rpc_input_type_map: + input_type = self._rpc_input_type_map[func_name] + self._track_imports(input_type) - def get_input_type_str(self, func_name: str): # pragma: no cover - if self._rpc_input_type_map[func_name] is None: - return "" - if self._rpc_input_type_map[func_name].__module__ == "typing": - type_name = self._rpc_input_type_map[func_name]._name - self._typing_imports.add(type_name) - return ": " + type_name - return ": " + self._rpc_input_type_map[func_name].__name__ - - def get_return_type_str(self, func_name: str): # pragma: no cover - if self._rpc_return_type_map[func_name].__module__ == "typing": - type_name = self._rpc_return_type_map[func_name]._name - self._typing_imports.add(type_name) - return type_name - return self._rpc_return_type_map[func_name].__name__ + for typ in list(self._typing_imports): + if typ + "[" not in code: + self._typing_imports.remove(typ) + + import_lines = [] + if "@dataclasses.dataclass" in code: + import_lines.append("import dataclasses") + if "@dataclass" in code: + import_lines.append("from dataclasses import dataclass") + + if self._datetime_imports: + import_lines.append( + "from datetime import " + ", ".join(sorted(self._datetime_imports)) + ) + + if self._has_decimal: + import_lines.append("import decimal") + if self._has_enum: + import_lines.append("import enum") + + if "(msgspec.Struct)" in code: + import_lines.append("import msgspec") + + if "(Struct)" in code: + import_lines.append("from msgspec import Struct") + + if self._typing_imports: + import_lines.append("from typing import " + ", ".join(self._typing_imports)) + + if self._has_uuid: + import_lines.append("import uuid") + + return "\n".join(import_lines) + + def _track_imports(self, input_type): + if not input_type: + return + if input_type in (datetime.datetime, datetime.date, datetime.time): + self._datetime_imports.add(input_type.__name__) + elif input_type == uuid.UUID: + self._has_uuid = True + elif input_type == decimal.Decimal: + self._has_decimal = True def get_function_str(self, func_name: str): func = self._rpc_router[func_name][0] func_lines = inspect.getsourcelines(func)[0] - def_line = [line for line in func_lines if "def" in line][0] + func_str = "".join(func_lines) + # from def to -> + def_str = func_str.split("def")[1].split("->")[0].strip() + def_str = "def " + def_str - # put self after the first ( - def_line = def_line.replace(f"{func_name}(", f"{func_name}(self").replace( - "async ", "" - ) + # Insert 'self' as the first parameter + insert_index = def_str.index("(") + 1 + if self._rpc_input_type_map[func_name]: # If there is input, add 'self, ' + def_str = def_str[:insert_index] + "self, " + def_str[insert_index:] + else: # If there is no input, just add 'self' + def_str = def_str[:insert_index] + "self" + def_str[insert_index:] - # if there is input, add comma after self - if self._rpc_input_type_map[func_name]: - def_line = def_line.replace(f"{func_name}(self", f"{func_name}(self, ") + # from -> to : + return_type_str = func_str.split("->")[1].split(":")[0].strip() + # add return type + def_str = def_str + f" -> {return_type_str}:" - return def_line.replace("\n", "") + return def_str.strip() def get_function_input_param_name(self, func_name: str): func = self._rpc_router[func_name][0] func_lines = inspect.getsourcelines(func)[0] - def_line = [line for line in func_lines if "def" in line][0] - params = def_line.split("(")[1].split(")")[0] - return params.split(":")[0].strip() - - # def generate_data_classes(self): - # code = "" - # for func_name in self._rpc_input_type_map: - # input_class = self._rpc_input_type_map[func_name] - # if input_class and is_pydantic(input_class): - # code += inspect.getsource(input_class) + func_str = "".join(func_lines) + # from bracket to bracket + input_param_name = func_str.split("(")[1].split(")")[0] + # everything until : + input_param_name = input_param_name.split(":")[0] + return input_param_name.strip() + + def _generate_class_code(self, cls: Type, already_generated: Set[Type]) -> str: + if cls in already_generated: + return "" + + code = self._generate_code_for_bases(cls, already_generated) + code += self._generate_code_for_fields(cls, already_generated) + code += inspect.getsource(cls) + "\n\n" + already_generated.add(cls) + return code + + def _generate_code_for_bases(self, cls: Type, already_generated: Set[Type]) -> str: + code = "" + for base_cls in cls.__bases__: + if issubclass(base_cls, msgspec.Struct) and base_cls is not msgspec.Struct: + code += self._generate_class_code(base_cls, already_generated) + elif is_dataclass(base_cls): + code += self._generate_class_code(base_cls, already_generated) + return code + + def _generate_code_for_fields(self, cls: Type, already_generated: Set[Type]) -> str: + code = "" + for field_type in get_type_hints(cls).values(): + code += self._generate_code_for_type(field_type, already_generated) + return code + + def _generate_code_for_type(self, typ: Type, already_generated: Set[Type]) -> str: + code = "" + typs = self._resolve_field_type(typ) + for it in typs: + self._track_imports(it) + if isinstance(it, type) and ( + issubclass(it, (msgspec.Struct, enum.Enum, enum.IntEnum)) + or is_dataclass(it) + ): + code += self._generate_class_code(it, already_generated) + return code + + def _resolve_field_type(self, field_type) -> List[Type]: + origin = get_origin(field_type) + if origin in (list, tuple, set, frozenset, Optional): + return [get_args(field_type)[0]] + elif origin == dict: + return [get_args(field_type)[1]] + elif origin == Union: + return list(get_args(field_type)) + + return [field_type] + + def generate_models(self) -> str: + already_generated: Set[Type] = set() + code = "" + + merged_types = list(self._rpc_input_type_map.values()) + list( + self._rpc_return_type_map.values() + ) + # retain order and remove duplicates + merged_types = list(dict.fromkeys(merged_types)) + + for input_type in merged_types: + if input_type is None: + continue + code += self._generate_code_for_type(input_type, already_generated) + + return code diff --git a/zero/config.py b/zero/config.py index bd89c09..0ce5feb 100644 --- a/zero/config.py +++ b/zero/config.py @@ -11,7 +11,6 @@ RESERVED_FUNCTIONS = ["get_rpc_contract", "connect", "__server_info__"] ZEROMQ_PATTERN = "proxy" -ENCODER = "msgspec" SUPPORTED_PROTOCOLS = { "zeromq": { "server": ZMQServer, diff --git a/zero/encoder/__init__.py b/zero/encoder/__init__.py index 6b48645..67da175 100644 --- a/zero/encoder/__init__.py +++ b/zero/encoder/__init__.py @@ -1,2 +1,3 @@ -from .factory import get_encoder from .protocols import Encoder + +__all__ = ["Encoder"] diff --git a/zero/encoder/factory.py b/zero/encoder/factory.py deleted file mode 100644 index 1ed02b2..0000000 --- a/zero/encoder/factory.py +++ /dev/null @@ -1,9 +0,0 @@ -from .msgspc import MsgspecEncoder -from .protocols import Encoder - - -def get_encoder(name: str) -> Encoder: - if name == "msgspec": - return MsgspecEncoder() - - raise ValueError(f"unknown encoder: {name}") diff --git a/zero/protocols/zeromq/client.py b/zero/protocols/zeromq/client.py index 1b80791..da185c0 100644 --- a/zero/protocols/zeromq/client.py +++ b/zero/protocols/zeromq/client.py @@ -1,9 +1,11 @@ import logging import threading -from typing import Dict, Optional, Type, TypeVar, Union +from typing import Dict, Optional, Type, TypeVar from zero import config -from zero.encoder import Encoder, get_encoder +from zero.encoder import Encoder +from zero.encoder.msgspc import MsgspecEncoder +from zero.utils.type_util import AllowedType from zero.zeromq_patterns import ( AsyncZeroMQClient, ZeroMQClient, @@ -23,7 +25,7 @@ def __init__( ): self._address = address self._default_timeout = default_timeout - self._encoder = encoder or get_encoder(config.ENCODER) + self._encoder = encoder or MsgspecEncoder() self.client_pool = ZMQClientPool( self._address, @@ -34,7 +36,7 @@ def __init__( def call( self, rpc_func_name: str, - msg: Union[int, float, str, dict, list, tuple, None], + msg: AllowedType, timeout: Optional[int] = None, return_type: Optional[Type[T]] = None, ) -> T: @@ -65,7 +67,7 @@ def __init__( ): self._address = address self._default_timeout = default_timeout - self._encoder = encoder or get_encoder(config.ENCODER) + self._encoder = encoder or MsgspecEncoder() self.client_pool = AsyncZMQClientPool( self._address, @@ -76,7 +78,7 @@ def __init__( async def call( self, rpc_func_name: str, - msg: Union[int, float, str, dict, list, tuple, None], + msg: AllowedType, timeout: Optional[int] = None, return_type: Optional[Type[T]] = None, ) -> T: @@ -114,7 +116,7 @@ def __init__( self._pool: Dict[int, ZeroMQClient] = {} self._address = address self._timeout = timeout - self._encoder = encoder or get_encoder(config.ENCODER) + self._encoder = encoder or MsgspecEncoder() def get(self) -> ZeroMQClient: thread_id = threading.get_ident() @@ -146,7 +148,7 @@ def __init__( self._pool: Dict[int, AsyncZeroMQClient] = {} self._address = address self._timeout = timeout - self._encoder = encoder or get_encoder(config.ENCODER) + self._encoder = encoder or MsgspecEncoder() async def get(self) -> AsyncZeroMQClient: thread_id = threading.get_ident() diff --git a/zero/rpc/client.py b/zero/rpc/client.py index 77d54e8..3fd2507 100644 --- a/zero/rpc/client.py +++ b/zero/rpc/client.py @@ -1,8 +1,10 @@ -from typing import TYPE_CHECKING, Optional, Type, TypeVar, Union +from typing import TYPE_CHECKING, Optional, Type, TypeVar from zero import config -from zero.encoder import Encoder, get_encoder +from zero.encoder import Encoder +from zero.encoder.msgspc import MsgspecEncoder from zero.error import MethodNotFoundException, RemoteException, ValidationException +from zero.utils.type_util import AllowedType if TYPE_CHECKING: from zero.rpc.protocols import AsyncZeroClientProtocol, ZeroClientProtocol @@ -55,7 +57,7 @@ def __init__( """ self._address = f"tcp://{host}:{port}" self._default_timeout = default_timeout - self._encoder = encoder or get_encoder(config.ENCODER) + self._encoder = encoder or MsgspecEncoder() self._client_inst: "ZeroClientProtocol" = self._determine_client_cls(protocol)( self._address, self._default_timeout, @@ -79,7 +81,7 @@ def _determine_client_cls(self, protocol: str) -> Type["ZeroClientProtocol"]: def call( self, rpc_func_name: str, - msg: Union[int, float, str, dict, list, tuple, None], + msg: AllowedType, timeout: Optional[int] = None, return_type: Optional[Type[T]] = None, ) -> T: @@ -173,7 +175,7 @@ def __init__( """ self._address = f"tcp://{host}:{port}" self._default_timeout = default_timeout - self._encoder = encoder or get_encoder(config.ENCODER) + self._encoder = encoder or MsgspecEncoder() self._client_inst: "AsyncZeroClientProtocol" = self._determine_client_cls( "zeromq" )( @@ -199,7 +201,7 @@ def _determine_client_cls(self, protocol: str) -> Type["AsyncZeroClientProtocol" async def call( self, rpc_func_name: str, - msg: Union[int, float, str, dict, list, tuple, None], + msg: AllowedType, timeout: Optional[int] = None, return_type: Optional[Type[T]] = None, ) -> Optional[T]: diff --git a/zero/rpc/protocols.py b/zero/rpc/protocols.py index 3752229..4ac9a02 100644 --- a/zero/rpc/protocols.py +++ b/zero/rpc/protocols.py @@ -6,11 +6,11 @@ Tuple, Type, TypeVar, - Union, runtime_checkable, ) from zero.encoder import Encoder +from zero.utils.type_util import AllowedType T = TypeVar("T") @@ -47,7 +47,7 @@ def __init__( def call( self, rpc_func_name: str, - msg: Union[int, float, str, dict, list, tuple, None], + msg: AllowedType, timeout: Optional[int] = None, return_type: Optional[Type[T]] = None, ) -> Optional[T]: @@ -70,7 +70,7 @@ def __init__( async def call( self, rpc_func_name: str, - msg: Union[int, float, str, dict, list, tuple, None], + msg: AllowedType, timeout: Optional[int] = None, return_type: Optional[Type[T]] = None, ) -> Optional[T]: diff --git a/zero/rpc/server.py b/zero/rpc/server.py index f9cbc14..5347d3f 100644 --- a/zero/rpc/server.py +++ b/zero/rpc/server.py @@ -1,10 +1,21 @@ import logging import os from asyncio import iscoroutinefunction -from typing import TYPE_CHECKING, Callable, Dict, Optional, Tuple, Type +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Coroutine, + Dict, + Optional, + Tuple, + Type, + Union, +) from zero import config -from zero.encoder import Encoder, get_encoder +from zero.encoder import Encoder +from zero.encoder.msgspc import MsgspecEncoder from zero.utils import type_util if TYPE_CHECKING: @@ -47,7 +58,7 @@ def __init__( self._address = f"tcp://{self._host}:{self._port}" # to encode/decode messages from/to client - self._encoder = encoder or get_encoder(config.ENCODER) + self._encoder = encoder or MsgspecEncoder() # Stores rpc functions against their names # and if they are coroutines @@ -79,7 +90,7 @@ def _determine_server_cls(self, protocol: str) -> Type["ZeroServerProtocol"]: ) return server_cls - def register_rpc(self, func: Callable): + def register_rpc(self, func: Callable[..., Union[Any, Coroutine]]): """ Register a function available for clients. Function should have a single argument. diff --git a/zero/utils/type_util.py b/zero/utils/type_util.py index fd69d4c..1d1d812 100644 --- a/zero/utils/type_util.py +++ b/zero/utils/type_util.py @@ -4,7 +4,18 @@ import enum import typing import uuid -from typing import Callable, Optional, get_origin, get_type_hints +from typing import ( + Any, + Callable, + ClassVar, + Dict, + Optional, + Protocol, + Type, + Union, + get_origin, + get_type_hints, +) import msgspec @@ -17,15 +28,10 @@ bytes, bytearray, tuple, - typing.Tuple, list, - typing.List, dict, - typing.Dict, set, - typing.Set, frozenset, - typing.FrozenSet, ] std_lib_types: typing.List = [ @@ -40,6 +46,11 @@ ] typing_types: typing.List = [ + typing.Tuple, + typing.List, + typing.Dict, + typing.Set, + typing.FrozenSet, typing.Any, typing.Union, typing.Optional, @@ -54,6 +65,46 @@ allowed_types = builtin_types + std_lib_types + typing_types +class IsDataclass(Protocol): + # as already noted in comments, checking for this attribute is currently + # the most reliable way to ascertain that something is a dataclass + __dataclass_fields__: ClassVar[Dict[str, Any]] + + +AllowedType = Union[ + None, + bool, + int, + float, + str, + bytes, + bytearray, + tuple, + list, + dict, + set, + frozenset, + datetime.datetime, + datetime.date, + datetime.time, + uuid.UUID, + decimal.Decimal, + enum.Enum, + enum.IntEnum, + IsDataclass, + typing.Tuple, + typing.List, + typing.Dict, + typing.Set, + typing.FrozenSet, + typing.Any, + msgspec.Struct, + msgspec.Raw, + Type[enum.Enum], # For enum classes + Type[enum.IntEnum], # For int enum classes +] + + def verify_function_args(func: Callable) -> None: arg_count = func.__code__.co_argcount if arg_count < 1: @@ -125,6 +176,11 @@ def verify_function_input_type(func: Callable): def verify_function_return_type(func: Callable): return_type = get_function_return_class(func) + if return_type is None: + raise TypeError( + f"{func.__name__} returns None; RPC functions must return a value" + ) + if return_type in allowed_types: return @@ -149,24 +205,3 @@ def verify_allowed_type(msg, rpc_method: Optional[str] = None): f"{msg} is not allowed {method_name}; allowed types are: \n" + "\n".join([str(t) for t in allowed_types]) ) - - -def verify_incoming_rpc_call_input_type( - msg, rpc_method: str, rpc_input_type_map: dict -): # pragma: no cover - input_type = rpc_input_type_map[rpc_method] - if input_type is None: - return - - if input_type in builtin_types: - if input_type != type(msg): - raise TypeError( - f"{msg} is not allowed for method `{rpc_method}`; allowed type: {input_type}" - ) - - origin_type = get_origin(input_type) - if origin_type in builtin_types: - if origin_type != type(msg): - raise TypeError( - f"{msg} is not allowed for method `{rpc_method}`; allowed type: {input_type}" - ) diff --git a/zero/zeromq_patterns/__init__.py b/zero/zeromq_patterns/__init__.py index 9b05bc0..2af176f 100644 --- a/zero/zeromq_patterns/__init__.py +++ b/zero/zeromq_patterns/__init__.py @@ -1,2 +1,13 @@ from .factory import get_async_client, get_broker, get_client, get_worker from .interfaces import AsyncZeroMQClient, ZeroMQBroker, ZeroMQClient, ZeroMQWorker + +__all__ = [ + "get_async_client", + "get_broker", + "get_client", + "get_worker", + "AsyncZeroMQClient", + "ZeroMQBroker", + "ZeroMQClient", + "ZeroMQWorker", +] diff --git a/zero/zeromq_patterns/queue_device/__init__.py b/zero/zeromq_patterns/queue_device/__init__.py index 5fae288..4011159 100644 --- a/zero/zeromq_patterns/queue_device/__init__.py +++ b/zero/zeromq_patterns/queue_device/__init__.py @@ -1,3 +1,5 @@ from .broker import ZeroMQBroker from .client import AsyncZeroMQClient, ZeroMQClient from .worker import ZeroMQWorker + +__all__ = ["ZeroMQBroker", "ZeroMQClient", "AsyncZeroMQClient", "ZeroMQWorker"]