From c85072d80f53d8bbff7dfcba3b75e73adc911dfd Mon Sep 17 00:00:00 2001 From: Azizul Haque Ananto Date: Thu, 27 Jun 2024 18:05:55 +0200 Subject: [PATCH] Codegen now produce models --- .coveragerc | 6 + Dockerfile.test.py310 | 9 - Dockerfile.test.py38 | 9 - Dockerfile.test.py39 | 9 - Makefile | 2 +- README.md | 153 ++-- benchmarks/dockerize/README.md | 9 +- examples/basic/schema.py | 21 + examples/basic/server.py | 13 +- tests/functional/codegen/__init__.py | 0 tests/functional/codegen/test_codegen.py | 708 ++++++++++++++++++ .../single_server/client_generation_test.py | 129 +++- .../single_server/client_server_test.py | 230 +++++- tests/functional/single_server/server.py | 213 +++++- tests/functional/test_async_to_sync.py | 51 ++ tests/unit/test_type_util.py | 134 ++++ tests/unit/test_util.py | 21 + tests/unit/test_worker.py | 130 +++- zero/codegen/codegen.py | 256 +++++-- zero/config.py | 1 - zero/encoder/__init__.py | 3 +- zero/encoder/factory.py | 9 - zero/protocols/zeromq/client.py | 18 +- zero/protocols/zeromq/worker.py | 65 +- zero/rpc/client.py | 14 +- zero/rpc/protocols.py | 12 +- zero/rpc/server.py | 19 +- zero/utils/type_util.py | 136 ++-- zero/zeromq_patterns/__init__.py | 11 + zero/zeromq_patterns/queue_device/__init__.py | 2 + 30 files changed, 2049 insertions(+), 344 deletions(-) create mode 100644 .coveragerc delete mode 100644 Dockerfile.test.py310 delete mode 100644 Dockerfile.test.py38 delete mode 100644 Dockerfile.test.py39 create mode 100644 tests/functional/codegen/__init__.py create mode 100644 tests/functional/codegen/test_codegen.py create mode 100644 tests/functional/test_async_to_sync.py create mode 100644 tests/unit/test_type_util.py create mode 100644 tests/unit/test_util.py delete mode 100644 zero/encoder/factory.py diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000..7ef80e0 --- /dev/null +++ b/.coveragerc @@ -0,0 +1,6 @@ +[run] +omit = + zero/zeromq_patterns/factory.py + zero/zeromq_patterns/helpers.py + zero/logger.py + zero/rpc/protocols.py \ No newline at end of file diff --git a/Dockerfile.test.py310 b/Dockerfile.test.py310 deleted file mode 100644 index fc33837..0000000 --- a/Dockerfile.test.py310 +++ /dev/null @@ -1,9 +0,0 @@ -FROM python:3.10-slim - -COPY tests/requirements.txt . -RUN pip install -r requirements.txt - -COPY zero ./zero -COPY tests ./tests - -CMD ["pytest", "tests", "--cov=zero", "--cov-report=term-missing", "-vv"] diff --git a/Dockerfile.test.py38 b/Dockerfile.test.py38 deleted file mode 100644 index 209a32e..0000000 --- a/Dockerfile.test.py38 +++ /dev/null @@ -1,9 +0,0 @@ -FROM python:3.8-slim - -COPY tests/requirements.txt . -RUN pip install -r requirements.txt - -COPY zero ./zero -COPY tests ./tests - -CMD ["pytest", "tests", "--cov=zero", "--cov-report=term-missing", "-vv"] diff --git a/Dockerfile.test.py39 b/Dockerfile.test.py39 deleted file mode 100644 index 5e6c2fd..0000000 --- a/Dockerfile.test.py39 +++ /dev/null @@ -1,9 +0,0 @@ -FROM python:3.9-slim - -COPY tests/requirements.txt . -RUN pip install -r requirements.txt - -COPY zero ./zero -COPY tests ./tests - -CMD ["pytest", "tests", "--cov=zero", "--cov-report=term-missing", "-vv"] diff --git a/Makefile b/Makefile index 413330e..ca8d107 100644 --- a/Makefile +++ b/Makefile @@ -9,7 +9,7 @@ setup: ) test: - python3 -m pytest tests --cov=zero --cov-report=term-missing -vv --durations=10 --timeout=280 + python3 -m pytest tests --cov=zero --cov-report=term-missing --cov-config=.coveragerc -vv --durations=10 --timeout=280 docker-test: docker build -t zero-test -f Dockerfile.test.py38 . diff --git a/README.md b/README.md index 1754f17..fb4a253 100644 --- a/README.md +++ b/README.md @@ -27,100 +27,101 @@ **Features**: -* Zero provides **faster communication** (see [benchmarks](https://github.com/Ananto30/zero#benchmarks-)) between the microservices using [zeromq](https://zeromq.org/) under the hood. -* Zero uses messages for communication and traditional **client-server** or **request-reply** pattern is supported. -* Support for both **async** and **sync**. -* The base server (ZeroServer) **utilizes all cpu cores**. -* **Code generation**! See [example](https://github.com/Ananto30/zero#code-generation-) ๐Ÿ‘‡ +- Zero provides **faster communication** (see [benchmarks](https://github.com/Ananto30/zero#benchmarks-)) between the microservices using [zeromq](https://zeromq.org/) under the hood. +- Zero uses messages for communication and traditional **client-server** or **request-reply** pattern is supported. +- Support for both **async** and **sync**. +- The base server (ZeroServer) **utilizes all cpu cores**. +- **Code generation**! See [example](https://github.com/Ananto30/zero#code-generation-) ๐Ÿ‘‡ **Philosophy** behind Zero: -* **Zero learning curve**: The learning curve is tends to zero. Just add functions and spin up a server, literally that's it! The framework hides the complexity of messaging pattern that enables faster communication. -* **ZeroMQ**: An awesome messaging library enables the power of Zero. +- **Zero learning curve**: The learning curve is tends to zero. Just add functions and spin up a server, literally that's it! The framework hides the complexity of messaging pattern that enables faster communication. +- **ZeroMQ**: An awesome messaging library enables the power of Zero. Let's get started! # Getting started ๐Ÿš€ -*Ensure Python 3.8+* +_Ensure Python 3.8+_ pip install zeroapi **For Windows**, [tornado](https://pypi.org/project/tornado/) needs to be installed separately (for async operations). It's not included with `zeroapi` because for linux and mac-os, tornado is not needed as they have their own event loops. -* Create a `server.py` +- Create a `server.py` - ```python - from zero import ZeroServer + ```python + from zero import ZeroServer - app = ZeroServer(port=5559) + app = ZeroServer(port=5559) - @app.register_rpc - def echo(msg: str) -> str: - return msg + @app.register_rpc + def echo(msg: str) -> str: + return msg - @app.register_rpc - async def hello_world() -> str: - return "hello world" + @app.register_rpc + async def hello_world() -> str: + return "hello world" - if __name__ == "__main__": - app.run() - ``` + if __name__ == "__main__": + app.run() + ``` -* The **RPC functions only support one argument** (`msg`) for now. +- The **RPC functions only support one argument** (`msg`) for now. -* Also note that server **RPC functions are type hinted**. Type hint is **must** in Zero server. Supported types can be found [here](/zero/utils/type_util.py#L11). +- Also note that server **RPC functions are type hinted**. Type hint is **must** in Zero server. Supported types can be found [here](/zero/utils/type_util.py#L11). -* Run the server - ```shell - python -m server - ``` +- Run the server -* Call the rpc methods + ```shell + python -m server + ``` - ```python - from zero import ZeroClient +- Call the rpc methods - zero_client = ZeroClient("localhost", 5559) + ```python + from zero import ZeroClient - def echo(): - resp = zero_client.call("echo", "Hi there!") - print(resp) + zero_client = ZeroClient("localhost", 5559) - def hello(): - resp = zero_client.call("hello_world", None) - print(resp) + def echo(): + resp = zero_client.call("echo", "Hi there!") + print(resp) + def hello(): + resp = zero_client.call("hello_world", None) + print(resp) - if __name__ == "__main__": - echo() - hello() - ``` -* Or using async client - + if __name__ == "__main__": + echo() + hello() + ``` - ```python - import asyncio +- Or using async client - - from zero import AsyncZeroClient + ```python + import asyncio - zero_client = AsyncZeroClient("localhost", 5559) + from zero import AsyncZeroClient - async def echo(): - resp = await zero_client.call("echo", "Hi there!") - print(resp) + zero_client = AsyncZeroClient("localhost", 5559) - async def hello(): - resp = await zero_client.call("hello_world", None) - print(resp) + async def echo(): + resp = await zero_client.call("echo", "Hi there!") + print(resp) + async def hello(): + resp = await zero_client.call("hello_world", None) + print(resp) - if __name__ == "__main__": - loop = asyncio.get_event_loop() - loop.run_until_complete(echo()) - loop.run_until_complete(hello()) - ``` + + if __name__ == "__main__": + loop = asyncio.get_event_loop() + loop.run_until_complete(echo()) + loop.run_until_complete(hello()) + ``` # Serialization ๐Ÿ“ฆ @@ -220,13 +221,13 @@ if __name__ == "__main__": Currently, the code generation tool supports only `ZeroClient` and not `AsyncZeroClient`. -*WIP - Generate models from server code.* +_WIP - Generate models from server code._ # Important notes! ๐Ÿ“ -* `ZeroServer` should always be run under `if __name__ == "__main__":`, as it uses multiprocessing. -* `ZeroServer` creates the workers in different processes, so anything global in your code will be instantiated N times where N is the number of workers. So if you want to initiate them once, put them under `if __name__ == "__main__":`. But recommended to not use global vars. And Databases, Redis, other clients, creating them N times in different processes is fine and preferred. -* The methods which are under `register_rpc()` in `ZeroServer` should have **type hinting**, like `def echo(msg: str) -> str:` +- `ZeroServer` should always be run under `if __name__ == "__main__":`, as it uses multiprocessing. +- `ZeroServer` creates the workers in different processes, so anything global in your code will be instantiated N times where N is the number of workers. So if you want to initiate them once, put them under `if __name__ == "__main__":`. But recommended to not use global vars. And Databases, Redis, other clients, creating them N times in different processes is fine and preferred. +- The methods which are under `register_rpc()` in `ZeroServer` should have **type hinting**, like `def echo(msg: str) -> str:` # Let's do some benchmarking! ๐ŸŽ @@ -236,8 +237,8 @@ So we will be testing a gateway calling another server for some data. Check the There are two endpoints in every tests, -* `/hello`: Just call for a hello world response ๐Ÿ˜… -* `/order`: Save a Order object in redis +- `/hello`: Just call for a hello world response ๐Ÿ˜… +- `/order`: Save a Order object in redis Compare the results! ๐Ÿ‘‡ @@ -245,25 +246,25 @@ Compare the results! ๐Ÿ‘‡ 11th Gen Intelยฎ Coreโ„ข i7-11800H @ 2.30GHz, 8 cores, 16 threads, 16GB RAM (Docker in Ubuntu 22.04.2 LTS) -*(Sorted alphabetically)* +_(Sorted alphabetically)_ -Framework | "hello world" (req/s) | 99% latency (ms) | redis save (req/s) | 99% latency (ms) ------------ | --------------------- | ---------------- | ------------------ | ---------------- -aiohttp | 14949.57 | 8.91 | 9753.87 | 13.75 -aiozmq | 13844.67 | 9.55 | 5239.14 | 30.92 -blacksheep | 32967.27 | 3.03 | 18010.67 | 6.79 -fastApi | 13154.96 | 9.07 | 8369.87 | 15.91 -sanic | 18793.08 | 5.88 | 12739.37 | 8.78 -zero(sync) | 28471.47 | 4.12 | 18114.84 | 6.69 -zero(async) | 29012.03 | 3.43 | 20956.48 | 5.80 +| Framework | "hello world" (req/s) | 99% latency (ms) | redis save (req/s) | 99% latency (ms) | +| ----------- | --------------------- | ---------------- | ------------------ | ---------------- | +| aiohttp | 14949.57 | 8.91 | 9753.87 | 13.75 | +| aiozmq | 13844.67 | 9.55 | 5239.14 | 30.92 | +| blacksheep | 32967.27 | 3.03 | 18010.67 | 6.79 | +| fastApi | 13154.96 | 9.07 | 8369.87 | 15.91 | +| sanic | 18793.08 | 5.88 | 12739.37 | 8.78 | +| zero(sync) | 28471.47 | 4.12 | 18114.84 | 6.69 | +| zero(async) | 29012.03 | 3.43 | 20956.48 | 5.80 | Seems like blacksheep is faster on hello world, but in more complex operations like saving to redis, zero is the winner! ๐Ÿ† # Roadmap ๐Ÿ—บ -* [x] Make msgspec as default serializer -* [ ] Add support for async server (currently the sync server runs async functions in the eventloop, which is blocking) -* [ ] Add pub/sub support +- [x] Make msgspec as default serializer +- [ ] Add support for async server (currently the sync server runs async functions in the eventloop, which is blocking) +- [ ] Add pub/sub support # Contribution diff --git a/benchmarks/dockerize/README.md b/benchmarks/dockerize/README.md index 412128a..3250e76 100644 --- a/benchmarks/dockerize/README.md +++ b/benchmarks/dockerize/README.md @@ -42,7 +42,7 @@ I have used 2x cpu threads so `-t 16` and 16x25 = 400 connections. 11th Gen Intel(R) Core(TM) i7-1165G7 @ 2.80GHz, 4 cores, 8 threads, 12GB RAM -*(Sorted alphabetically)* +_(Sorted alphabetically)_ | Framework | "hello world" (req/s) | 99% latency (ms) | redis save (req/s) | 99% latency (ms) | | --------- | --------------------- | ---------------- | ------------------ | ---------------- | @@ -52,12 +52,11 @@ I have used 2x cpu threads so `-t 16` and 16x25 = 400 connections. | sanic | 13195.99 | 20.04 | 7226.72 | 25.24 | | zero | 18867.00 | 11.48 | 12293.81 | 11.68 | - ## Old benchmark results Intel Core i3 10100, 4 cores, 8 threads, 16GB RAM, with docker limits **cpu 40% and memory 256m** -*(Sorted alphabetically)* +_(Sorted alphabetically)_ | Framework | "hello world" example | redis save example | | --------- | --------------------- | ------------------ | @@ -67,10 +66,9 @@ Intel Core i3 10100, 4 cores, 8 threads, 16GB RAM, with docker limits **cpu 40% | sanic | 3,085.80 req/s | 547.02 req/s | | zero | 5,000.77 req/s | 784.51 req/s | - MacBook Pro (13-inch, M1, 2020), Apple M1, 8 cores (4 performance and 4 efficiency), 8 GB RAM -*(Sorted alphabetically)* +_(Sorted alphabetically)_ | Framework | "hello world" example | redis save example | | --------- | --------------------- | ------------------ | @@ -81,7 +79,6 @@ MacBook Pro (13-inch, M1, 2020), Apple M1, 8 cores (4 performance and 4 efficien More about MacBook benchmarks [here](https://github.com/Ananto30/zero/blob/main/benchmarks/others/mac-results.md) - ### Note! Please note that sometimes just `docker-compose up` will not run the `wrk`. Because you know about the docker `depends_on` only ensures the service is up, not running or healthy. So you may need to run wrk service after other services are up and running. diff --git a/examples/basic/schema.py b/examples/basic/schema.py index ecb8e66..9e4057b 100644 --- a/examples/basic/schema.py +++ b/examples/basic/schema.py @@ -1,9 +1,30 @@ +from dataclasses import dataclass +from datetime import date from typing import List import msgspec +class Address(msgspec.Struct): + street: str + city: str + zip: int + + class User(msgspec.Struct): name: str age: int emails: List[str] + addresses: List[Address] + registered_at: date + + +@dataclass +class Teacher: + name: str + + +class Student(User): + roll_no: int + marks: List[int] + teachers: List[Teacher] diff --git a/examples/basic/server.py b/examples/basic/server.py index ffc3c36..2bac606 100644 --- a/examples/basic/server.py +++ b/examples/basic/server.py @@ -5,7 +5,7 @@ from zero import ZeroServer -from .schema import User +from .schema import Student, Teacher, User app = ZeroServer(port=5559) @@ -42,6 +42,17 @@ def hello_users(users: typing.List[User]) -> str: return f"Hello {', '.join([user.name for user in users])}! Your emails are {', '.join([email for user in users for email in user.emails])}!" +teachers = [ + Teacher(name="Teacher1"), + Teacher(name="Teacher2"), +] + + +@app.register_rpc +def hello_student(student: Student) -> str: + return f"Hello {student.name}! You are {student.age} years old. Your email is {student.emails[0]}! Your roll no. is {student.roll_no} and your marks are {student.marks}!" + + if __name__ == "__main__": app.register_rpc(echo) app.register_rpc(hello_world) diff --git a/tests/functional/codegen/__init__.py b/tests/functional/codegen/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/functional/codegen/test_codegen.py b/tests/functional/codegen/test_codegen.py new file mode 100644 index 0000000..df43014 --- /dev/null +++ b/tests/functional/codegen/test_codegen.py @@ -0,0 +1,708 @@ +import dataclasses +import datetime +import decimal +import enum +import typing +import unittest +import uuid +from dataclasses import dataclass +from datetime import date +from typing import Dict, List, Optional, Tuple, Union + +import msgspec +from msgspec import Struct + +from zero.codegen.codegen import CodeGen + + +@dataclass +class SimpleDataclass: + a: int + b: str + + +@dataclasses.dataclass +class SimpleDataclass2: + c: int + d: str + + +@dataclass +class ChildDataclass(SimpleDataclass): + e: int + f: str + + +class SimpleStruct(Struct): + h: int + i: str + + +class ComplexStruct(msgspec.Struct): + a: int + b: str + c: SimpleStruct + d: List[SimpleStruct] + e: Dict[str, SimpleStruct] + f: Tuple[SimpleDataclass, SimpleStruct] + g: Union[SimpleStruct, SimpleDataclass, SimpleDataclass2] + + +class ChildComplexStruct(ComplexStruct): + h: int + i: str + + +class SimpleEnum(enum.Enum): + ONE = 1 + TWO = 2 + + +class SimpleIntEnum(enum.IntEnum): + ONE = 1 + TWO = 2 + + +def func_none(arg: None) -> str: + return "Received None" + + +def func_bool(arg: bool) -> str: + return f"Received bool: {arg}" + + +def func_int(arg: int) -> str: + return f"Received int: {arg}" + + +def func_float(arg: float) -> str: + return f"Received float: {arg}" + + +def func_str(arg: str) -> str: + return f"Received str: {arg}" + + +def func_bytes(arg: bytes) -> str: + return f"Received bytes: {arg}" + + +def func_bytearray(arg: bytearray) -> str: + return f"Received bytearray: {arg}" + + +def func_tuple(arg: tuple) -> str: + return f"Received tuple: {arg}" + + +def func_list(arg: list) -> str: + return f"Received list: {arg}" + + +def func_dict(arg: dict) -> str: + return f"Received dict: {arg}" + + +def func_optional_dict(arg: Optional[dict]) -> str: + return f"Received dict: {arg}" + + +def func_set(arg: set) -> str: + return f"Received set: {arg}" + + +def func_frozenset(arg: frozenset) -> str: + return f"Received frozenset: {arg}" + + +def func_datetime(arg: datetime.datetime) -> str: + return f"Received datetime: {arg}" + + +def func_date(arg: date) -> str: + return f"Received date: {arg}" + + +def func_time(arg: datetime.time) -> str: + return f"Received time: {arg}" + + +def func_uuid(arg: uuid.UUID) -> str: + return f"Received UUID: {arg}" + + +def func_decimal(arg: decimal.Decimal) -> str: + return f"Received Decimal: {arg}" + + +def func_enum(arg: SimpleEnum) -> str: + return f"Received Enum: {arg}" + + +def func_intenum(arg: SimpleIntEnum) -> str: + return f"Received IntEnum: {arg}" + + +def func_dataclass(arg: SimpleDataclass) -> str: + return f"Received dataclass: {arg}" + + +def func_tuple_typing(arg: typing.Tuple[int, str]) -> str: + return f"Received typing.Tuple: {arg}" + + +def func_list_typing(arg: typing.List[int]) -> str: + return f"Received typing.List: {arg}" + + +def func_dict_typing(arg: typing.Dict[str, int]) -> str: + return f"Received typing.Dict: {arg}" + + +def func_set_typing(arg: typing.Set[int]) -> str: + return f"Received typing.Set: {arg}" + + +def func_frozenset_typing(arg: typing.FrozenSet[int]) -> str: + return f"Received typing.FrozenSet: {arg}" + + +def func_any_typing(arg: typing.Any) -> str: + return f"Received typing.Any: {arg}" + + +def func_union_typing(arg: typing.Union[int, str]) -> str: + return f"Received typing.Union: {arg}" + + +def func_optional_typing(arg: typing.Optional[int]) -> str: + return f"Received typing.Optional: {arg}" + + +def func_msgspec_struct(arg: SimpleStruct) -> str: + return f"Received msgspec.Struct: {arg}" + + +def func_msgspec_struct_complex(arg: ComplexStruct) -> str: + return f"Received msgspec.Struct: {arg}" + + +def func_child_complex_struct(arg: ChildComplexStruct) -> str: + return f"Received msgspec.Struct: {arg}" + + +def func_return_optional_child_complex_struct() -> Optional[ChildComplexStruct]: + return None + + +def func_return_complex_struct() -> ComplexStruct: + return ComplexStruct( + a=1, + b="hello", + c=SimpleStruct(h=1, i="hello"), + d=[SimpleStruct(h=1, i="hello")], + e={"1": SimpleStruct(h=1, i="hello")}, + f=(SimpleDataclass(a=1, b="hello"), SimpleStruct(h=1, i="hello")), + g=SimpleDataclass(a=1, b="hello"), + ) + + +def func_take_optional_child_dataclass_return_optional_child_complex_struct( + arg: Optional[ChildDataclass], +) -> Optional[ChildComplexStruct]: + return None + + +class TestCodegen(unittest.TestCase): + def setUp(self) -> None: + self.maxDiff = None + self._rpc_router = { + "func_none": (func_none, False), + "func_bool": (func_bool, False), + "func_int": (func_int, False), + "func_float": (func_float, False), + "func_str": (func_str, False), + "func_bytes": (func_bytes, False), + "func_bytearray": (func_bytearray, False), + "func_tuple": (func_tuple, False), + "func_list": (func_list, False), + "func_dict": (func_dict, False), + "func_optional_dict": (func_optional_dict, False), + "func_set": (func_set, False), + "func_frozenset": (func_frozenset, False), + "func_datetime": (func_datetime, False), + "func_date": (func_date, False), + "func_time": (func_time, False), + "func_uuid": (func_uuid, False), + "func_decimal": (func_decimal, False), + "func_enum": (func_enum, False), + "func_intenum": (func_intenum, False), + "func_dataclass": (func_dataclass, False), + "func_tuple_typing": (func_tuple_typing, False), + "func_list_typing": (func_list_typing, False), + "func_dict_typing": (func_dict_typing, False), + "func_set_typing": (func_set_typing, False), + "func_frozenset_typing": (func_frozenset_typing, False), + "func_any_typing": (func_any_typing, False), + "func_union_typing": (func_union_typing, False), + "func_optional_typing": (func_optional_typing, False), + "func_msgspec_struct": (func_msgspec_struct, False), + "func_msgspec_struct_complex": (func_msgspec_struct_complex, False), + "func_child_complex_struct": (func_child_complex_struct, False), + "func_return_complex_struct": (func_return_complex_struct, False), + } + self._rpc_input_type_map = { + "func_none": None, + "func_bool": bool, + "func_int": int, + "func_float": float, + "func_str": str, + "func_bytes": bytes, + "func_bytearray": bytearray, + "func_tuple": tuple, + "func_list": list, + "func_dict": dict, + "func_optional_dict": Optional[dict], + "func_set": set, + "func_frozenset": frozenset, + "func_datetime": datetime.datetime, + "func_date": datetime.date, + "func_time": datetime.time, + "func_uuid": uuid.UUID, + "func_decimal": decimal.Decimal, + "func_enum": SimpleEnum, + "func_intenum": SimpleIntEnum, + "func_dataclass": SimpleDataclass, + "func_tuple_typing": typing.Tuple[int, str], + "func_list_typing": typing.List[int], + "func_dict_typing": typing.Dict[str, int], + "func_set_typing": typing.Set[int], + "func_frozenset_typing": typing.FrozenSet[int], + "func_any_typing": typing.Any, + "func_union_typing": typing.Union[int, str], + "func_optional_typing": typing.Optional[int], + "func_msgspec_struct": SimpleStruct, + "func_msgspec_struct_complex": ComplexStruct, + "func_child_complex_struct": ChildComplexStruct, + "func_return_complex_struct": None, + } + self._rpc_return_type_map = { + "func_none": str, + "func_bool": str, + "func_int": str, + "func_float": str, + "func_str": str, + "func_bytes": str, + "func_bytearray": str, + "func_tuple": str, + "func_list": str, + "func_dict": str, + "func_optional_dict": Optional[str], + "func_set": str, + "func_frozenset": str, + "func_datetime": str, + "func_date": str, + "func_time": str, + "func_uuid": str, + "func_decimal": str, + "func_enum": str, + "func_intenum": str, + "func_dataclass": str, + "func_tuple_typing": str, + "func_list_typing": str, + "func_dict_typing": str, + "func_set_typing": str, + "func_frozenset_typing": str, + "func_any_typing": str, + "func_union_typing": str, + "func_optional_typing": str, + "func_msgspec_struct": str, + "func_msgspec_struct_complex": str, + "func_child_complex_struct": str, + "func_return_complex_struct": ComplexStruct, + } + + def test_codegen(self): + codegen = CodeGen( + self._rpc_router, self._rpc_input_type_map, self._rpc_return_type_map + ) + code = codegen.generate_code() + expected_code = """# Generated by Zero +# import types as per needed, not all imports are shown here +from dataclasses import dataclass +from datetime import date, datetime, time +import decimal +import enum +import msgspec +from msgspec import Struct +from typing import Dict, FrozenSet, List, Optional, Set, Tuple, Union +import uuid + +from zero import ZeroClient + + +zero_client = ZeroClient("localhost", 5559) + +class SimpleEnum(enum.Enum): + ONE = 1 + TWO = 2 + + +class SimpleIntEnum(enum.IntEnum): + ONE = 1 + TWO = 2 + + +@dataclass +class SimpleDataclass: + a: int + b: str + + +class SimpleStruct(Struct): + h: int + i: str + + +@dataclass +class SimpleDataclass2: + c: int + d: str + + +class ComplexStruct(msgspec.Struct): + a: int + b: str + c: SimpleStruct + d: List[SimpleStruct] + e: Dict[str, SimpleStruct] + f: Tuple[SimpleDataclass, SimpleStruct] + g: Union[SimpleStruct, SimpleDataclass, SimpleDataclass2] + + +class ChildComplexStruct(ComplexStruct): + h: int + i: str + + + +class RpcClient: + def __init__(self, zero_client: ZeroClient): + self._zero_client = zero_client + + def func_none(selfarg: None) -> str: + return self._zero_client.call("func_none", None) + + def func_bool(self, arg: bool) -> str: + return self._zero_client.call("func_bool", arg) + + def func_int(self, arg: int) -> str: + return self._zero_client.call("func_int", arg) + + def func_float(self, arg: float) -> str: + return self._zero_client.call("func_float", arg) + + def func_str(self, arg: str) -> str: + return self._zero_client.call("func_str", arg) + + def func_bytes(self, arg: bytes) -> str: + return self._zero_client.call("func_bytes", arg) + + def func_bytearray(self, arg: bytearray) -> str: + return self._zero_client.call("func_bytearray", arg) + + def func_tuple(self, arg: tuple) -> str: + return self._zero_client.call("func_tuple", arg) + + def func_list(self, arg: list) -> str: + return self._zero_client.call("func_list", arg) + + def func_dict(self, arg: dict) -> str: + return self._zero_client.call("func_dict", arg) + + def func_optional_dict(self, arg: Optional[dict]) -> str: + return self._zero_client.call("func_optional_dict", arg) + + def func_set(self, arg: set) -> str: + return self._zero_client.call("func_set", arg) + + def func_frozenset(self, arg: frozenset) -> str: + return self._zero_client.call("func_frozenset", arg) + + def func_datetime(self, arg: datetime) -> str: + return self._zero_client.call("func_datetime", arg) + + def func_date(self, arg: date) -> str: + return self._zero_client.call("func_date", arg) + + def func_time(self, arg: time) -> str: + return self._zero_client.call("func_time", arg) + + def func_uuid(self, arg: uuid.UUID) -> str: + return self._zero_client.call("func_uuid", arg) + + def func_decimal(self, arg: decimal.Decimal) -> str: + return self._zero_client.call("func_decimal", arg) + + def func_enum(self, arg: SimpleEnum) -> str: + return self._zero_client.call("func_enum", arg) + + def func_intenum(self, arg: SimpleIntEnum) -> str: + return self._zero_client.call("func_intenum", arg) + + def func_dataclass(self, arg: SimpleDataclass) -> str: + return self._zero_client.call("func_dataclass", arg) + + def func_tuple_typing(self, arg: Tuple[int, str]) -> str: + return self._zero_client.call("func_tuple_typing", arg) + + def func_list_typing(self, arg: List[int]) -> str: + return self._zero_client.call("func_list_typing", arg) + + def func_dict_typing(self, arg: Dict[str, int]) -> str: + return self._zero_client.call("func_dict_typing", arg) + + def func_set_typing(self, arg: Set[int]) -> str: + return self._zero_client.call("func_set_typing", arg) + + def func_frozenset_typing(self, arg: FrozenSet[int]) -> str: + return self._zero_client.call("func_frozenset_typing", arg) + + def func_any_typing(self, arg: Any) -> str: + return self._zero_client.call("func_any_typing", arg) + + def func_union_typing(self, arg: Union[int, str]) -> str: + return self._zero_client.call("func_union_typing", arg) + + def func_optional_typing(self, arg: Optional[int]) -> str: + return self._zero_client.call("func_optional_typing", arg) + + def func_msgspec_struct(self, arg: SimpleStruct) -> str: + return self._zero_client.call("func_msgspec_struct", arg) + + def func_msgspec_struct_complex(self, arg: ComplexStruct) -> str: + return self._zero_client.call("func_msgspec_struct_complex", arg) + + def func_child_complex_struct(self, arg: ChildComplexStruct) -> str: + return self._zero_client.call("func_child_complex_struct", arg) + + def func_return_complex_struct(self) -> ComplexStruct: + return self._zero_client.call("func_return_complex_struct", None) +""" + self.assertEqual(code, expected_code) + + def test_codegen_return_single_complex_struct(self): + rpc_router = { + "func_return_complex_struct": (func_return_complex_struct, False), + } + rpc_input_type_map = { + "func_return_complex_struct": None, + } + rpc_return_type_map = { + "func_return_complex_struct": ComplexStruct, + } + codegen = CodeGen(rpc_router, rpc_input_type_map, rpc_return_type_map) + code = codegen.generate_code() + expected_code = """# Generated by Zero +# import types as per needed, not all imports are shown here +from dataclasses import dataclass +import enum +import msgspec +from msgspec import Struct +from typing import Dict, List, Tuple, Union + +from zero import ZeroClient + + +zero_client = ZeroClient("localhost", 5559) + +class SimpleStruct(Struct): + h: int + i: str + + +@dataclass +class SimpleDataclass: + a: int + b: str + + +@dataclass +class SimpleDataclass2: + c: int + d: str + + +class ComplexStruct(msgspec.Struct): + a: int + b: str + c: SimpleStruct + d: List[SimpleStruct] + e: Dict[str, SimpleStruct] + f: Tuple[SimpleDataclass, SimpleStruct] + g: Union[SimpleStruct, SimpleDataclass, SimpleDataclass2] + + + +class RpcClient: + def __init__(self, zero_client: ZeroClient): + self._zero_client = zero_client + + def func_return_complex_struct(self) -> ComplexStruct: + return self._zero_client.call("func_return_complex_struct", None) +""" + self.assertEqual(code, expected_code) + + def test_codegen_return_optional_complex_struct(self): + rpc_router = { + "func_return_optional_child_complex_struct": ( + func_return_optional_child_complex_struct, + False, + ), + } + rpc_input_type_map = { + "func_return_optional_child_complex_struct": None, + } + rpc_return_type_map = { + "func_return_optional_child_complex_struct": Optional[ChildComplexStruct], + } + codegen = CodeGen(rpc_router, rpc_input_type_map, rpc_return_type_map) + code = codegen.generate_code() + expected_code = """# Generated by Zero +# import types as per needed, not all imports are shown here +from dataclasses import dataclass +import enum +import msgspec +from msgspec import Struct +from typing import Dict, List, Optional, Tuple, Union + +from zero import ZeroClient + + +zero_client = ZeroClient("localhost", 5559) + +class SimpleStruct(Struct): + h: int + i: str + + +@dataclass +class SimpleDataclass: + a: int + b: str + + +@dataclass +class SimpleDataclass2: + c: int + d: str + + +class ComplexStruct(msgspec.Struct): + a: int + b: str + c: SimpleStruct + d: List[SimpleStruct] + e: Dict[str, SimpleStruct] + f: Tuple[SimpleDataclass, SimpleStruct] + g: Union[SimpleStruct, SimpleDataclass, SimpleDataclass2] + + +class ChildComplexStruct(ComplexStruct): + h: int + i: str + + + +class RpcClient: + def __init__(self, zero_client: ZeroClient): + self._zero_client = zero_client + + def func_return_optional_child_complex_struct(self) -> Optional[ChildComplexStruct]: + return self._zero_client.call("func_return_optional_child_complex_struct", None) +""" + self.assertEqual(code, expected_code) + + def test_codegen_optional_child_dataclass_return_optional_child_complex_struct( + self, + ): + rpc_router = { + "func_take_optional_child_dataclass_return_optional_child_complex_struct": ( + func_take_optional_child_dataclass_return_optional_child_complex_struct, + False, + ), + } + rpc_input_type_map = { + "func_take_optional_child_dataclass_return_optional_child_complex_struct": Optional[ + ChildDataclass + ], + } + rpc_return_type_map = { + "func_take_optional_child_dataclass_return_optional_child_complex_struct": Optional[ + ChildComplexStruct + ], + } + codegen = CodeGen(rpc_router, rpc_input_type_map, rpc_return_type_map) + code = codegen.generate_code() + expected_code = """# Generated by Zero +# import types as per needed, not all imports are shown here +from dataclasses import dataclass +import enum +import msgspec +from msgspec import Struct +from typing import Dict, List, Optional, Tuple, Union + +from zero import ZeroClient + + +zero_client = ZeroClient("localhost", 5559) + +@dataclass +class SimpleDataclass: + a: int + b: str + + +@dataclass +class ChildDataclass(SimpleDataclass): + e: int + f: str + + +class SimpleStruct(Struct): + h: int + i: str + + +@dataclass +class SimpleDataclass2: + c: int + d: str + + +class ComplexStruct(msgspec.Struct): + a: int + b: str + c: SimpleStruct + d: List[SimpleStruct] + e: Dict[str, SimpleStruct] + f: Tuple[SimpleDataclass, SimpleStruct] + g: Union[SimpleStruct, SimpleDataclass, SimpleDataclass2] + + +class ChildComplexStruct(ComplexStruct): + h: int + i: str + + + +class RpcClient: + def __init__(self, zero_client: ZeroClient): + self._zero_client = zero_client + + def func_take_optional_child_dataclass_return_optional_child_complex_struct(self, + arg: Optional[ChildDataclass], +) -> Optional[ChildComplexStruct]: + return self._zero_client.call("func_take_optional_child_dataclass_return_optional_child_complex_struct", arg) +""" + self.assertEqual(code, expected_code) diff --git a/tests/functional/single_server/client_generation_test.py b/tests/functional/single_server/client_generation_test.py index 22d3fc7..da1b33c 100644 --- a/tests/functional/single_server/client_generation_test.py +++ b/tests/functional/single_server/client_generation_test.py @@ -17,18 +17,129 @@ def test_codegeneration(): assert ( code == """# Generated by Zero -# import types as per needed +# import types as per needed, not all imports are shown here +from dataclasses import dataclass +from datetime import date, datetime, time +import decimal +import enum +import msgspec +from typing import Dict, FrozenSet, List, Optional, Set, Tuple, Union +import uuid from zero import ZeroClient zero_client = ZeroClient("localhost", 5559) +class Color(enum.Enum): + RED = 1 + GREEN = 2 + BLUE = 3 + + +class ColorInt(enum.IntEnum): + RED = 1 + GREEN = 2 + BLUE = 3 + + +@dataclass +class Dataclass: + name: str + age: int + + +class Message(msgspec.Struct): + msg: str + start_time: datetime + + class RpcClient: def __init__(self, zero_client: ZeroClient): self._zero_client = zero_client + def echo_bool(self, msg: bool) -> bool: + return self._zero_client.call("echo_bool", msg) + + def echo_int(self, msg: int) -> int: + return self._zero_client.call("echo_int", msg) + + def echo_float(self, msg: float) -> float: + return self._zero_client.call("echo_float", msg) + + def echo_str(self, msg: str) -> str: + return self._zero_client.call("echo_str", msg) + + def echo_bytes(self, msg: bytes) -> bytes: + return self._zero_client.call("echo_bytes", msg) + + def echo_bytearray(self, msg: bytearray) -> bytearray: + return self._zero_client.call("echo_bytearray", msg) + + def echo_tuple(self, msg: Tuple[int, str]) -> Tuple[int, str]: + return self._zero_client.call("echo_tuple", msg) + + def echo_list(self, msg: List[int]) -> List[int]: + return self._zero_client.call("echo_list", msg) + + def echo_dict(self, msg: Dict[int, str]) -> Dict[int, str]: + return self._zero_client.call("echo_dict", msg) + + def echo_set(self, msg: Set[int]) -> Set[int]: + return self._zero_client.call("echo_set", msg) + + def echo_frozenset(self, msg: FrozenSet[int]) -> FrozenSet[int]: + return self._zero_client.call("echo_frozenset", msg) + + def echo_datetime(self, msg: datetime) -> datetime: + return self._zero_client.call("echo_datetime", msg) + + def echo_date(self, msg: date) -> date: + return self._zero_client.call("echo_date", msg) + + def echo_time(self, msg: time) -> time: + return self._zero_client.call("echo_time", msg) + + def echo_uuid(self, msg: uuid.UUID) -> uuid.UUID: + return self._zero_client.call("echo_uuid", msg) + + def echo_decimal(self, msg: decimal.Decimal) -> decimal.Decimal: + return self._zero_client.call("echo_decimal", msg) + + def echo_enum(self, msg: Color) -> Color: + return self._zero_client.call("echo_enum", msg) + + def echo_enum_int(self, msg: ColorInt) -> ColorInt: + return self._zero_client.call("echo_enum_int", msg) + + def echo_dataclass(self, msg: Dataclass) -> Dataclass: + return self._zero_client.call("echo_dataclass", msg) + + def echo_typing_tuple(self, msg: Tuple[int, str]) -> Tuple[int, str]: + return self._zero_client.call("echo_typing_tuple", msg) + + def echo_typing_list(self, msg: List[int]) -> List[int]: + return self._zero_client.call("echo_typing_list", msg) + + def echo_typing_dict(self, msg: Dict[int, str]) -> Dict[int, str]: + return self._zero_client.call("echo_typing_dict", msg) + + def echo_typing_set(self, msg: Set[int]) -> Set[int]: + return self._zero_client.call("echo_typing_set", msg) + + def echo_typing_frozenset(self, msg: FrozenSet[int]) -> FrozenSet[int]: + return self._zero_client.call("echo_typing_frozenset", msg) + + def echo_typing_union(self, msg: Union[int, str]) -> Union[int, str]: + return self._zero_client.call("echo_typing_union", msg) + + def echo_typing_optional(self, msg: Optional[int]) -> int: + return self._zero_client.call("echo_typing_optional", msg) + + def echo_msgspec_struct(self, msg: Message) -> Message: + return self._zero_client.call("echo_msgspec_struct", msg) + def sleep(self, msec: int) -> str: return self._zero_client.call("sleep", msec) @@ -38,9 +149,6 @@ def sleep_async(self, msec: int) -> str: def error(self, msg: str) -> str: return self._zero_client.call("error", msg) - def msgspec_struct(self, start: datetime.datetime) -> Message: - return self._zero_client.call("msgspec_struct", start) - def send_bytes(self, msg: bytes) -> bytes: return self._zero_client.call("send_bytes", msg) @@ -53,19 +161,10 @@ def hello_world(self) -> str: def decode_jwt(self, msg: str) -> str: return self._zero_client.call("decode_jwt", msg) - def sum_list(self, msg: typing.List[int]) -> int: + def sum_list(self, msg: List[int]) -> int: return self._zero_client.call("sum_list", msg) - def echo_dict(self, msg: typing.Dict[int, str]) -> typing.Dict[int, str]: - return self._zero_client.call("echo_dict", msg) - - def echo_tuple(self, msg: typing.Tuple[int, str]) -> typing.Tuple[int, str]: - return self._zero_client.call("echo_tuple", msg) - - def echo_union(self, msg: typing.Union[int, str]) -> typing.Union[int, str]: - return self._zero_client.call("echo_union", msg) - - def divide(self, msg: typing.Tuple[int, int]) -> int: + def divide(self, msg: Tuple[int, int]) -> int: return self._zero_client.call("divide", msg) """ ) diff --git a/tests/functional/single_server/client_server_test.py b/tests/functional/single_server/client_server_test.py index 6d79246..b29a809 100644 --- a/tests/functional/single_server/client_server_test.py +++ b/tests/functional/single_server/client_server_test.py @@ -1,4 +1,7 @@ import datetime +import decimal +import typing +import uuid import pytest import requests @@ -11,14 +14,193 @@ from .server import Message -def test_hello_world(): - zero_client = ZeroClient(server.HOST, server.PORT) +@pytest.fixture +def zero_client(): + return ZeroClient(server.HOST, server.PORT) + + +# bool input +def test_echo_bool(zero_client): + assert zero_client.call("echo_bool", True) is True + + +# int input +def test_echo_int(zero_client): + assert zero_client.call("echo_int", 42) == 42 + + +# float input +def test_echo_float(zero_client): + assert zero_client.call("echo_float", 3.14) == 3.14 + + +# str input +def test_echo_str(zero_client): + assert zero_client.call("echo_str", "hello") == "hello" + + +# bytes input +def test_echo_bytes(zero_client): + assert zero_client.call("echo_bytes", b"hello") == b"hello" + + +# bytearray input +def test_echo_bytearray(zero_client): + assert zero_client.call("echo_bytearray", bytearray(b"hello")) == bytearray( + b"hello" + ) + + +# tuple input +def test_echo_tuple(zero_client): + assert zero_client.call("echo_tuple", (1, "a"), return_type=tuple) == (1, "a") + + +# list input +def test_echo_list(zero_client): + assert zero_client.call("echo_list", [1, 2, 3]) == [1, 2, 3] + + +# dict input +def test_echo_dict(zero_client): + assert zero_client.call("echo_dict", {1: "a"}) == {1: "a"} + + +# set input +def test_echo_set(zero_client): + assert zero_client.call("echo_set", {1, 2, 3}, return_type=set) == {1, 2, 3} + + +# frozenset input +def test_echo_frozenset(zero_client): + assert zero_client.call( + "echo_frozenset", frozenset({1, 2, 3}), return_type=frozenset + ) == frozenset({1, 2, 3}) + + +# datetime input +def test_echo_datetime(zero_client): + now = datetime.datetime.now() + assert zero_client.call("echo_datetime", now, return_type=datetime.datetime) == now + + +# date input +def test_echo_date(zero_client): + today = datetime.date.today() + assert zero_client.call("echo_date", today, return_type=datetime.date) == today + + +# time input +def test_echo_time(zero_client): + now = datetime.datetime.now().time() + assert zero_client.call("echo_time", now, return_type=datetime.time) == now + + +# uuid input +def test_echo_uuid(zero_client): + uid = uuid.uuid4() + assert zero_client.call("echo_uuid", uid, return_type=uuid.UUID) == uid + + +# decimal input +def test_echo_decimal(zero_client): + value = decimal.Decimal("10.1") + assert zero_client.call("echo_decimal", value, return_type=decimal.Decimal) == value + + +# enum input +def test_echo_enum(zero_client): + assert ( + zero_client.call("echo_enum", server.Color.RED, return_type=server.Color) + == server.Color.RED + ) + + +# enum int input +def test_echo_enum_int(zero_client): + assert ( + zero_client.call("echo_enum_int", server.ColorInt.GREEN) + == server.ColorInt.GREEN + ) + + +# dataclass input +def test_echo_dataclass(zero_client): + data = server.Dataclass(name="John", age=30) + result = zero_client.call("echo_dataclass", data, return_type=server.Dataclass) + assert result == data + + +# typing.Tuple input +def test_echo_typing_tuple(zero_client): + assert zero_client.call( + "echo_typing_tuple", (1, "a"), return_type=typing.Tuple + ) == (1, "a") + + +# typing.List input +def test_echo_typing_list(zero_client): + assert zero_client.call("echo_typing_list", [1, 2, 3]) == [1, 2, 3] + + +# typing.Dict input +def test_echo_typing_dict(zero_client): + assert zero_client.call("echo_typing_dict", {1: "a"}, return_type=typing.Dict) == { + 1: "a" + } + + +# typing.Set input +def test_echo_typing_set(zero_client): + assert zero_client.call("echo_typing_set", {1, 2, 3}, return_type=typing.Set) == { + 1, + 2, + 3, + } + + +# typing.FrozenSet input +def test_echo_typing_frozenset(zero_client): + assert zero_client.call( + "echo_typing_frozenset", frozenset({1, 2, 3}), return_type=typing.FrozenSet + ) == frozenset({1, 2, 3}) + + +# typing.Union input +def test_echo_typing_union(zero_client): + assert ( + zero_client.call("echo_typing_union", 1, return_type=typing.Union[str, int]) + == 1 + ) + assert ( + zero_client.call("echo_typing_union", "a", return_type=typing.Union[str, int]) + == "a" + ) + + +# typing.Optional input +def test_echo_typing_optional(zero_client): + assert zero_client.call("echo_typing_optional", None) == 0 + assert ( + zero_client.call("echo_typing_optional", 1, return_type=typing.Optional[int]) + == 1 + ) + + +# msgspec.Struct input +def test_echo_msgspec_struct(zero_client): + msg = server.Message(msg="hello world", start_time=datetime.datetime.now()) + result = zero_client.call("echo_msgspec_struct", msg, return_type=server.Message) + assert result.msg == msg.msg + assert result.start_time == msg.start_time + + +def test_hello_world(zero_client): msg = zero_client.call("hello_world", "") assert msg == "hello world" -def test_necho(): - zero_client = ZeroClient(server.HOST, server.PORT) +def test_necho(zero_client): with pytest.raises(zero.error.MethodNotFoundException): msg = zero_client.call("necho", "hello") assert msg is None @@ -31,20 +213,12 @@ def test_echo_wrong_port(): assert msg is None -def test_sum_list(): - zero_client = ZeroClient(server.HOST, server.PORT) +def test_sum_list(zero_client): msg = zero_client.call("sum_list", [1, 2, 3]) assert msg == 6 -def test_echo_dict(): - zero_client = ZeroClient(server.HOST, server.PORT) - msg = zero_client.call("echo_dict", {1: "b"}) - assert msg == {1: "b"} - - -def test_echo_dict_validation_error(): - zero_client = ZeroClient(server.HOST, server.PORT) +def test_echo_dict_validation_error(zero_client): with pytest.raises(ValidationException): msg = zero_client.call("echo_dict", {"a": "b"}) assert msg == { @@ -52,16 +226,14 @@ def test_echo_dict_validation_error(): } -def test_echo_tuple(): - zero_client = ZeroClient(server.HOST, server.PORT) +def test_echo_tuple_2(zero_client): msg = zero_client.call("echo_tuple", (1, "a")) assert isinstance(msg, list) # IMPORTANT assert msg == [1, "a"] -def test_echo_union(): - zero_client = ZeroClient(server.HOST, server.PORT) - msg = zero_client.call("echo_union", 1) +def test_echo_union(zero_client): + msg = zero_client.call("echo_typing_union", 1) assert msg == 1 @@ -106,11 +278,11 @@ class Example: def test_msgspec_struct(): - now = datetime.datetime.now() + msg = Message("hello world", datetime.datetime.now()) zero_client = ZeroClient(server.HOST, server.PORT) - msg = zero_client.call("msgspec_struct", now, return_type=Message) + msg = zero_client.call("echo_msgspec_struct", msg, return_type=Message) assert msg.msg == "hello world" - assert msg.start_time == now + assert msg.start_time == msg.start_time def test_send_bytes(): @@ -121,18 +293,18 @@ def test_send_bytes(): def test_send_http_request(): with pytest.raises(requests.exceptions.ReadTimeout): - requests.get(f"http://{server.HOST}:{server.PORT}", timeout=2) + requests.get(f"http://{server.HOST}:{server.PORT}", timeout=0.1) def test_server_works_after_multiple_http_requests(): """Because of this issue https://github.com/Ananto30/zero/issues/41""" try: - requests.get(f"http://{server.HOST}:{server.PORT}", timeout=2) - requests.get(f"http://{server.HOST}:{server.PORT}", timeout=2) - requests.get(f"http://{server.HOST}:{server.PORT}", timeout=2) - requests.get(f"http://{server.HOST}:{server.PORT}", timeout=2) - requests.get(f"http://{server.HOST}:{server.PORT}", timeout=2) - requests.get(f"http://{server.HOST}:{server.PORT}", timeout=2) + requests.get(f"http://{server.HOST}:{server.PORT}", timeout=0.1) + requests.get(f"http://{server.HOST}:{server.PORT}", timeout=0.1) + requests.get(f"http://{server.HOST}:{server.PORT}", timeout=0.1) + requests.get(f"http://{server.HOST}:{server.PORT}", timeout=0.1) + requests.get(f"http://{server.HOST}:{server.PORT}", timeout=0.1) + requests.get(f"http://{server.HOST}:{server.PORT}", timeout=0.1) except requests.exceptions.ReadTimeout: pass zero_client = ZeroClient(server.HOST, server.PORT) diff --git a/tests/functional/single_server/server.py b/tests/functional/single_server/server.py index 7d4d278..f6940c2 100644 --- a/tests/functional/single_server/server.py +++ b/tests/functional/single_server/server.py @@ -1,7 +1,11 @@ import asyncio import datetime +import decimal +import enum import time import typing +import uuid +from dataclasses import dataclass import jwt import msgspec @@ -14,36 +18,210 @@ app = ZeroServer(port=PORT) -async def echo(msg: str) -> str: +# None input +async def hello_world() -> str: + return "hello world" + + +# bool input +@app.register_rpc +def echo_bool(msg: bool) -> bool: return msg -async def hello_world() -> str: - return "hello world" +# int input +@app.register_rpc +def echo_int(msg: int) -> int: + return msg -async def decode_jwt(msg: str) -> str: - encoded_jwt = jwt.encode(msg, "secret", algorithm="HS256") # type: ignore - decoded_jwt = jwt.decode(encoded_jwt, "secret", algorithms=["HS256"]) - return decoded_jwt # type: ignore +# float input +@app.register_rpc +def echo_float(msg: float) -> float: + return msg -def sum_list(msg: typing.List[int]) -> int: - return sum(msg) +# str input +@app.register_rpc +def echo_str(msg: str) -> str: + return msg -def echo_dict(msg: typing.Dict[int, str]) -> typing.Dict[int, str]: +# bytes input +@app.register_rpc +def echo_bytes(msg: bytes) -> bytes: return msg +# bytearray input +@app.register_rpc +def echo_bytearray(msg: bytearray) -> bytearray: + return msg + + +# tuple input +@app.register_rpc def echo_tuple(msg: typing.Tuple[int, str]) -> typing.Tuple[int, str]: return msg -def echo_union(msg: typing.Union[int, str]) -> typing.Union[int, str]: +# list input +@app.register_rpc +def echo_list(msg: typing.List[int]) -> typing.List[int]: + return msg + + +# dict input +@app.register_rpc +def echo_dict(msg: typing.Dict[int, str]) -> typing.Dict[int, str]: return msg +# set input +@app.register_rpc +def echo_set(msg: typing.Set[int]) -> typing.Set[int]: + return msg + + +# frozenset input +@app.register_rpc +def echo_frozenset(msg: typing.FrozenSet[int]) -> typing.FrozenSet[int]: + return msg + + +# datetime input +@app.register_rpc +def echo_datetime(msg: datetime.datetime) -> datetime.datetime: + return msg + + +# date input +@app.register_rpc +def echo_date(msg: datetime.date) -> datetime.date: + return msg + + +# time input +@app.register_rpc +def echo_time(msg: datetime.time) -> datetime.time: + return msg + + +# uuid input +@app.register_rpc +def echo_uuid(msg: uuid.UUID) -> uuid.UUID: + return msg + + +# decimal input +@app.register_rpc +def echo_decimal(msg: decimal.Decimal) -> decimal.Decimal: + return msg + + +# enum input +class Color(enum.Enum): + RED = 1 + GREEN = 2 + BLUE = 3 + + +@app.register_rpc +def echo_enum(msg: Color) -> Color: + return msg + + +# enum int input +class ColorInt(enum.IntEnum): + RED = 1 + GREEN = 2 + BLUE = 3 + + +@app.register_rpc +def echo_enum_int(msg: ColorInt) -> ColorInt: + return msg + + +# dataclass input +@dataclass +class Dataclass: + name: str + age: int + + +@app.register_rpc +def echo_dataclass(msg: Dataclass) -> Dataclass: + return msg + + +# typing.Tuple input +@app.register_rpc +def echo_typing_tuple(msg: typing.Tuple[int, str]) -> typing.Tuple[int, str]: + return msg + + +# typing.List input +@app.register_rpc +def echo_typing_list(msg: typing.List[int]) -> typing.List[int]: + return msg + + +# typing.Dict input +@app.register_rpc +def echo_typing_dict(msg: typing.Dict[int, str]) -> typing.Dict[int, str]: + return msg + + +# typing.Set input +@app.register_rpc +def echo_typing_set(msg: typing.Set[int]) -> typing.Set[int]: + return msg + + +# typing.FrozenSet input +@app.register_rpc +def echo_typing_frozenset(msg: typing.FrozenSet[int]) -> typing.FrozenSet[int]: + return msg + + +# typing.Union input +@app.register_rpc +def echo_typing_union(msg: typing.Union[int, str]) -> typing.Union[int, str]: + return msg + + +# typing.Optional input +@app.register_rpc +def echo_typing_optional(msg: typing.Optional[int]) -> int: + return msg or 0 + + +# msgspec.Struct input +class Message(msgspec.Struct): + msg: str + start_time: datetime.datetime + + +@app.register_rpc +def echo_msgspec_struct(msg: Message) -> Message: + return msg + + +async def echo(msg: str) -> str: + return msg + + +async def decode_jwt(msg: str) -> str: + encoded_jwt = jwt.encode(msg, "secret", algorithm="HS256") # type: ignore + decoded_jwt = jwt.decode(encoded_jwt, "secret", algorithms=["HS256"]) + return decoded_jwt # type: ignore + + +def sum_list(msg: typing.List[int]) -> int: + return sum(msg) + + def divide(msg: typing.Tuple[int, int]) -> int: return int(msg[0] / msg[1]) @@ -69,16 +247,6 @@ def error(msg: str) -> str: raise RuntimeError(msg) -class Message(msgspec.Struct): - msg: str - start_time: datetime.datetime - - -@app.register_rpc -def msgspec_struct(start: datetime.datetime) -> Message: - return Message(msg="hello world", start_time=start) - - @app.register_rpc def send_bytes(msg: bytes) -> bytes: return msg @@ -90,9 +258,6 @@ def run(port): app.register_rpc(hello_world) app.register_rpc(decode_jwt) app.register_rpc(sum_list) - app.register_rpc(echo_dict) - app.register_rpc(echo_tuple) - app.register_rpc(echo_union) app.register_rpc(divide) app.run(2) diff --git a/tests/functional/test_async_to_sync.py b/tests/functional/test_async_to_sync.py new file mode 100644 index 0000000..de585aa --- /dev/null +++ b/tests/functional/test_async_to_sync.py @@ -0,0 +1,51 @@ +import asyncio + +import pytest + +from zero.utils.async_to_sync import async_to_sync + + +# Test case 1: Test a simple async function +async def simple_async_function(x): + await asyncio.sleep(0.1) # Simulate async work + return x * 2 + + +def test_simple_async_function(): + sync_function = async_to_sync(simple_async_function) + result = sync_function(5) + assert result == 10, "The async function should return 10 when called with 5" + + +# Test case 2: Test an async function that raises an exception +async def async_function_raises_exception(): + raise ValueError("This is a test exception") + + +def test_async_function_exception(): + sync_function = async_to_sync(async_function_raises_exception) + with pytest.raises(ValueError) as exc_info: + sync_function() + assert ( + str(exc_info.value) == "This is a test exception" + ), "The exception message should be 'This is a test exception'" + + +# Test case 3: Test the reusability of async_to_sync for multiple functions +async def another_simple_async_function(x): + await asyncio.sleep(0.1) # Simulate async work + return x + 100 + + +def test_reusability_of_async_to_sync(): + sync_function_1 = async_to_sync(simple_async_function) + result_1 = sync_function_1(5) + assert ( + result_1 == 10 + ), "The first async function should return 10 when called with 5" + + sync_function_2 = async_to_sync(another_simple_async_function) + result_2 = sync_function_2(5) + assert ( + result_2 == 105 + ), "The second async function should return 105 when called with 5" diff --git a/tests/unit/test_type_util.py b/tests/unit/test_type_util.py new file mode 100644 index 0000000..813fe0b --- /dev/null +++ b/tests/unit/test_type_util.py @@ -0,0 +1,134 @@ +import unittest +from typing import Optional +from unittest.mock import MagicMock + +from zero.utils.type_util import ( + get_function_input_class, + get_function_return_class, + verify_function_args, + verify_function_input_type, + verify_function_return, + verify_function_return_type, +) + + +class TestVerifyFunctionReturnType(unittest.TestCase): + def test_valid_return_type(self): + def func() -> int: + return 1 + + verify_function_return_type(func) + + def test_none_return_type(self): + def func() -> None: + return None + + with self.assertRaises(TypeError): + verify_function_return_type(func) + + def test_optional_return_type(self): + def func() -> Optional[int]: + return None + + with self.assertRaises(TypeError): + verify_function_return_type(func) + + def test_invalid_return_type(self): + class CustomType: + pass + + def func() -> CustomType: + return CustomType() + + with self.assertRaises(TypeError): + verify_function_return_type(func) + + def test_mocked_return_type(self): + def func() -> MagicMock: + return MagicMock() + + with self.assertRaises(TypeError): + verify_function_return_type(func) + + def test__verify_function_args__ok(self): + def func(a: int) -> int: + return a + + verify_function_args(func) + + def test__verify_function_args__multiple_args(self): + def func(a: int, b: int) -> int: + return a + b + + with self.assertRaises(ValueError): + verify_function_args(func) + + def test__verify_function_args__no_type_hint(self): + def func(a): + return a + + with self.assertRaises(TypeError): + verify_function_args(func) + + def test__verify_function_return__ok(self): + def func() -> int: + return 1 + + verify_function_return(func) + + def test__verify_function_return__no_type_hint(self): + def func(): + return 1 + + with self.assertRaises(TypeError): + verify_function_return(func) + + def test__get_function_input_class__ok(self): + def func(a: int) -> int: + return a + + self.assertEqual(get_function_input_class(func), int) + + def test__get_function_input_class__no_args(self): + def func() -> int: + return 1 + + self.assertEqual(get_function_input_class(func), None) + + def test__get_function_input_class__multiple_args(self): + def func(a: int, b: int) -> int: + return a + b + + self.assertEqual(get_function_input_class(func), None) + + def test__get_function_return_class__ok(self): + def func() -> int: + return 1 + + self.assertEqual(get_function_return_class(func), int) + + def test__get_function_return_class__no_return(self): + def func(): + return 1 + + self.assertEqual(get_function_return_class(func), None) + + def test__verify_function_input_type__ok(self): + def func(a: int) -> int: + return a + + verify_function_input_type(func) + + def test__verify_function_input_type__invalid(self): + def func(a: MagicMock) -> int: + return a + + with self.assertRaises(TypeError): + verify_function_input_type(func) + + def test__verify_function_input_type__no_type_hint(self): + def func(a) -> int: + return a + + with self.assertRaises(KeyError): + verify_function_input_type(func) diff --git a/tests/unit/test_util.py b/tests/unit/test_util.py new file mode 100644 index 0000000..125eaa3 --- /dev/null +++ b/tests/unit/test_util.py @@ -0,0 +1,21 @@ +import logging +import unittest +from unittest.mock import patch + +from zero.utils.util import log_error + + +class TestLogError(unittest.TestCase): + def test_log_error(self): + @log_error + def divide(a, b): + return a / b + + with patch.object(logging, "exception") as mock_exception: + result = divide(10, 2) + self.assertEqual(result, 5) + mock_exception.assert_not_called() + + result = divide(10, 0) + self.assertIsNone(result) + mock_exception.assert_called_once() diff --git a/tests/unit/test_worker.py b/tests/unit/test_worker.py index efa0da4..0ce97d9 100644 --- a/tests/unit/test_worker.py +++ b/tests/unit/test_worker.py @@ -1,6 +1,10 @@ import unittest from unittest.mock import MagicMock, Mock, patch +import msgspec + +from zero.encoder.protocols import Encoder +from zero.error import SERVER_PROCESSING_ERROR from zero.protocols.zeromq.worker import _Worker @@ -55,6 +59,26 @@ def test_start_dealer_worker_exception_handling(self, mock_get_worker): self.assertIn("Test Exception", log.output[0]) mock_worker.close.assert_called_once() + @patch("zero.protocols.zeromq.worker.get_worker") + def test_start_dealer_worker_keyboard_interrupt_handling(self, mock_get_worker): + mock_worker = Mock() + mock_get_worker.return_value = mock_worker + mock_worker.listen.side_effect = KeyboardInterrupt + + worker_id = 1 + worker = _Worker( + self.rpc_router, + self.device_comm_channel, + self.encoder, + self.rpc_input_type_map, + self.rpc_return_type_map, + ) + + with self.assertLogs(level="WARNING") as log: + worker.start_dealer_worker(worker_id) + self.assertIn("terminating worker", log.output[0]) + mock_worker.close.assert_called_once() + @patch("zero.protocols.zeromq.worker.async_to_sync", side_effect=lambda x: x) def test_handle_msg_get_rpc_contract(self, mock_async_to_sync): worker = _Worker( @@ -70,7 +94,7 @@ def test_handle_msg_get_rpc_contract(self, mock_async_to_sync): with patch.object( worker, "generate_rpc_contract", return_value=expected_response ) as mock_generate_rpc_contract: - response = worker.handle_msg("get_rpc_contract", msg) + response = worker.execute_rpc("get_rpc_contract", msg) mock_generate_rpc_contract.assert_called_once_with(msg) self.assertEqual(response, expected_response) @@ -89,7 +113,7 @@ def test_handle_msg_rpc_call_exception(self, mock_async_to_sync): self.rpc_return_type_map, ) - response = worker.handle_msg("failing_function", "msg") + response = worker.execute_rpc("failing_function", "msg") self.assertEqual( response, {"__zerror__server_exception": "Exception('RPC Exception')"} ) @@ -105,7 +129,7 @@ def test_handle_msg_connect(self): msg = "some_message" expected_response = "connected" - response = worker.handle_msg("connect", msg) + response = worker.execute_rpc("connect", msg) self.assertEqual(response, expected_response) @@ -122,7 +146,7 @@ def test_handle_msg_function_not_found(self): "__zerror__function_not_found": "Function `some_function_not_found` not found!" } - response = worker.handle_msg("some_function_not_found", msg) + response = worker.execute_rpc("some_function_not_found", msg) self.assertEqual(response, expected_response) @@ -143,7 +167,7 @@ def test_handle_msg_server_exception(self): "zero.protocols.zeromq.worker.async_to_sync", side_effect=Exception("Exception occurred"), ): - response = worker.handle_msg("some_function", msg) + response = worker.execute_rpc("some_function", msg) self.assertEqual(response, expected_response) @@ -219,3 +243,99 @@ def test_spawn_worker(self): rpc_return_type_map, ) mock_worker.start_dealer_worker.assert_called_once_with(worker_id) + + +def some_function(msg: str) -> str: + return msg + + +class TestWorkerHandleMsg(unittest.TestCase): + def setUp(self): + self.rpc_router = { + "get_rpc_contract": (MagicMock(), False), + "connect": (MagicMock(), False), + "some_function": (some_function, False), + } + self.device_comm_channel = "tcp://example.com:5555" + self.encoder = MagicMock(spec=Encoder) + self.rpc_input_type_map = { + "some_function": str, + } + self.rpc_return_type_map = { + "some_function": str, + } + + def test_handle_msg_with_valid_input(self): + worker = _Worker( + self.rpc_router, + self.device_comm_channel, + self.encoder, + self.rpc_input_type_map, + self.rpc_return_type_map, + ) + func_name_encoded = b"some_function" + data = self.encoder.encode("msg_data") + + worker.execute_rpc = Mock() + worker.execute_rpc.return_value = "response" + self.encoder.decode_type.return_value = "msg_data" + + response = worker.handle_msg(func_name_encoded, data) + + worker.execute_rpc.assert_called_once_with( + func_name_encoded.decode(), "msg_data" + ) + self.encoder.encode.assert_called_with("response") + self.assertEqual(response, self.encoder.encode.return_value) + + def test_handle_msg_with_validation_error(self): + worker = _Worker( + self.rpc_router, + self.device_comm_channel, + self.encoder, + self.rpc_input_type_map, + self.rpc_return_type_map, + ) + func_name_encoded = b"some_function" + data = b"msg_data" + expected_error = "__zerror__validation_error" + expected_error_message = "Validation Error" + expected_encoded_error = b"encoded_error" + + self.encoder.decode_type.side_effect = msgspec.ValidationError( + expected_error_message + ) + self.encoder.encode.return_value = expected_encoded_error + + response = worker.handle_msg(func_name_encoded, data) + + self.encoder.decode_type.assert_called_once_with(data, str) + self.encoder.encode.assert_called_once_with( + {expected_error: expected_error_message} + ) + self.assertEqual(response, expected_encoded_error) + self.assertEqual(response, expected_encoded_error) + + def test_handle_msg_with_server_exception(self): + worker = _Worker( + self.rpc_router, + self.device_comm_channel, + self.encoder, + self.rpc_input_type_map, + self.rpc_return_type_map, + ) + func_name_encoded = b"some_function" + data = self.encoder.encode("msg_data") + + worker.execute_rpc = Mock() + worker.execute_rpc.side_effect = Exception("Server Exception") + self.encoder.decode_type.return_value = "msg_data" + + worker.handle_msg(func_name_encoded, data) + + worker.execute_rpc.assert_called_once_with( + func_name_encoded.decode(), "msg_data" + ) + self.encoder.encode.assert_called_with( + {"__zerror__server_exception": SERVER_PROCESSING_ERROR} + ) diff --git a/zero/codegen/codegen.py b/zero/codegen/codegen.py index 9ffd6f4..b7449e0 100644 --- a/zero/codegen/codegen.py +++ b/zero/codegen/codegen.py @@ -1,86 +1,248 @@ +import datetime +import decimal +import enum import inspect - -# from pydantic import BaseModel +import sys +import uuid +from dataclasses import is_dataclass +from typing import ( + Callable, + Dict, + List, + Optional, + Set, + Tuple, + Type, + Union, + get_args, + get_origin, + get_type_hints, +) + +import msgspec + +from zero.utils.type_util import typing_types + +python_version = sys.version_info class CodeGen: - def __init__(self, rpc_router, rpc_input_type_map, rpc_return_type_map): + def __init__( + self, + rpc_router: Dict[str, Tuple[Callable, bool]], + rpc_input_type_map: Dict[str, Optional[type]], + rpc_return_type_map: Dict[str, Optional[type]], + ): self._rpc_router = rpc_router self._rpc_input_type_map = rpc_input_type_map self._rpc_return_type_map = rpc_return_type_map - self._typing_imports = set() + + # for imports + self._typing_imports: List[str] = [ + str(typ).replace("typing.", "") for typ in typing_types + ] + self._typing_imports.sort() + self._datetime_imports: Set[str] = set() + self._has_uuid = False + self._has_decimal = False + self._has_enum = True def generate_code(self, host="localhost", port=5559): code = f"""# Generated by Zero -# import types as per needed +# import types as per needed, not all imports are shown here from zero import ZeroClient zero_client = ZeroClient("{host}", {port}) - +""" + code += self.generate_models() + code += """ class RpcClient: def __init__(self, zero_client: ZeroClient): self._zero_client = zero_client """ for func_name in self._rpc_router: + input_param_name = ( + None + if self._rpc_input_type_map[func_name] is None + else self.get_function_input_param_name(func_name) + ) code += f""" {self.get_function_str(func_name)} - return self._zero_client.call("{func_name}", { - None if self._rpc_input_type_map[func_name] is None - else self.get_function_input_param_name(func_name) - }) + return self._zero_client.call("{func_name}", {input_param_name}) """ - # self.generate_data_classes() TODO: next feature + + # add imports after first 2 lines + code_lines = code.split("\n") + code_lines.insert(2, self.get_imports(code)) + code = "\n".join(code_lines) + + if "typing." in code: + code = code.replace("typing.", "") + if "@dataclasses.dataclass" in code: + code = code.replace("@dataclasses.dataclass", "@dataclass") + if "datetime.datetime" in code: + code = code.replace("datetime.datetime", "datetime") + if "datetime.date" in code: + code = code.replace("datetime.date", "date") + if "datetime.time" in code: + code = code.replace("datetime.time", "time") + return code - def get_imports(self): - return f"from typing import {', '.join(i for i in self._typing_imports)}" + def get_imports(self, code): + for func_name in self._rpc_input_type_map: + input_type = self._rpc_input_type_map[func_name] + self._track_imports(input_type) - def get_input_type_str(self, func_name: str): # pragma: no cover - if self._rpc_input_type_map[func_name] is None: - return "" - if self._rpc_input_type_map[func_name].__module__ == "typing": - type_name = self._rpc_input_type_map[func_name]._name - self._typing_imports.add(type_name) - return ": " + type_name - return ": " + self._rpc_input_type_map[func_name].__name__ - - def get_return_type_str(self, func_name: str): # pragma: no cover - if self._rpc_return_type_map[func_name].__module__ == "typing": - type_name = self._rpc_return_type_map[func_name]._name - self._typing_imports.add(type_name) - return type_name - return self._rpc_return_type_map[func_name].__name__ + for typ in list(self._typing_imports): + if typ + "[" not in code: + self._typing_imports.remove(typ) + + import_lines = [] + + if "@dataclasses.dataclass" in code or "@dataclass" in code: + import_lines.append("from dataclasses import dataclass") + + if self._datetime_imports: + import_lines.append( + "from datetime import " + ", ".join(sorted(self._datetime_imports)) + ) + + if self._has_decimal: + import_lines.append("import decimal") + if self._has_enum: + import_lines.append("import enum") + + if "(msgspec.Struct)" in code: + import_lines.append("import msgspec") + + if "(Struct)" in code: + import_lines.append("from msgspec import Struct") + + if self._typing_imports: + import_lines.append("from typing import " + ", ".join(self._typing_imports)) + + if self._has_uuid: + import_lines.append("import uuid") + + return "\n".join(import_lines) + + def _track_imports(self, input_type): + if not input_type: + return + if input_type in (datetime.datetime, datetime.date, datetime.time): + self._datetime_imports.add(input_type.__name__) + elif input_type == uuid.UUID: + self._has_uuid = True + elif input_type == decimal.Decimal: + self._has_decimal = True def get_function_str(self, func_name: str): func = self._rpc_router[func_name][0] func_lines = inspect.getsourcelines(func)[0] - def_line = [line for line in func_lines if "def" in line][0] + func_str = "".join(func_lines) + # from def to -> + def_str = func_str.split("def")[1].split("->")[0].strip() + def_str = "def " + def_str - # put self after the first ( - def_line = def_line.replace(f"{func_name}(", f"{func_name}(self").replace( - "async ", "" - ) + # Insert 'self' as the first parameter + insert_index = def_str.index("(") + 1 + if self._rpc_input_type_map[func_name]: # If there is input, add 'self, ' + def_str = def_str[:insert_index] + "self, " + def_str[insert_index:] + else: # If there is no input, just add 'self' + def_str = def_str[:insert_index] + "self" + def_str[insert_index:] - # if there is input, add comma after self - if self._rpc_input_type_map[func_name]: - def_line = def_line.replace(f"{func_name}(self", f"{func_name}(self, ") + # from -> to : + return_type_str = func_str.split("->")[1].split(":")[0].strip() + # add return type + def_str = def_str + f" -> {return_type_str}:" - return def_line.replace("\n", "") + return def_str.strip() def get_function_input_param_name(self, func_name: str): func = self._rpc_router[func_name][0] func_lines = inspect.getsourcelines(func)[0] - def_line = [line for line in func_lines if "def" in line][0] - params = def_line.split("(")[1].split(")")[0] - return params.split(":")[0].strip() - - # def generate_data_classes(self): - # code = "" - # for func_name in self._rpc_input_type_map: - # input_class = self._rpc_input_type_map[func_name] - # if input_class and is_pydantic(input_class): - # code += inspect.getsource(input_class) + func_str = "".join(func_lines) + # from bracket to bracket + input_param_name = func_str.split("(")[1].split(")")[0] + # everything until : + input_param_name = input_param_name.split(":")[0] + return input_param_name.strip() + + def _generate_class_code(self, cls: Type, already_generated: Set[Type]) -> str: + if cls in already_generated: + return "" + + code = self._generate_code_for_bases(cls, already_generated) + code += self._generate_code_for_fields(cls, already_generated) + + if python_version >= (3, 9): + code += inspect.getsource(cls) + "\n\n" + else: + # python 3.8 doesnt return @dataclass decorator + if is_dataclass(cls): + code += f"@dataclass\n{inspect.getsource(cls)}\n\n" + else: + code += inspect.getsource(cls) + "\n\n" + + already_generated.add(cls) + return code + + def _generate_code_for_bases(self, cls: Type, already_generated: Set[Type]) -> str: + code = "" + for base_cls in cls.__bases__: + if issubclass(base_cls, msgspec.Struct) and base_cls is not msgspec.Struct: + code += self._generate_class_code(base_cls, already_generated) + elif is_dataclass(base_cls): + code += self._generate_class_code(base_cls, already_generated) + return code + + def _generate_code_for_fields(self, cls: Type, already_generated: Set[Type]) -> str: + code = "" + for field_type in get_type_hints(cls).values(): + code += self._generate_code_for_type(field_type, already_generated) + return code + + def _generate_code_for_type(self, typ: Type, already_generated: Set[Type]) -> str: + code = "" + typs = self._resolve_field_type(typ) + for it in typs: + self._track_imports(it) + if isinstance(it, type) and ( + issubclass(it, (msgspec.Struct, enum.Enum, enum.IntEnum)) + or is_dataclass(it) + ): + code += self._generate_class_code(it, already_generated) + return code + + def _resolve_field_type(self, field_type) -> List[Type]: + origin = get_origin(field_type) + if origin in (list, tuple, set, frozenset, Optional): + return [get_args(field_type)[0]] + elif origin == dict: + return [get_args(field_type)[1]] + elif origin == Union: + return list(get_args(field_type)) + + return [field_type] + + def generate_models(self) -> str: + already_generated: Set[Type] = set() + code = "" + + merged_types = list(self._rpc_input_type_map.values()) + list( + self._rpc_return_type_map.values() + ) + # retain order and remove duplicates + merged_types = list(dict.fromkeys(merged_types)) + + for input_type in merged_types: + if input_type is None: + continue + code += self._generate_code_for_type(input_type, already_generated) + + return code diff --git a/zero/config.py b/zero/config.py index bd89c09..0ce5feb 100644 --- a/zero/config.py +++ b/zero/config.py @@ -11,7 +11,6 @@ RESERVED_FUNCTIONS = ["get_rpc_contract", "connect", "__server_info__"] ZEROMQ_PATTERN = "proxy" -ENCODER = "msgspec" SUPPORTED_PROTOCOLS = { "zeromq": { "server": ZMQServer, diff --git a/zero/encoder/__init__.py b/zero/encoder/__init__.py index 6b48645..67da175 100644 --- a/zero/encoder/__init__.py +++ b/zero/encoder/__init__.py @@ -1,2 +1,3 @@ -from .factory import get_encoder from .protocols import Encoder + +__all__ = ["Encoder"] diff --git a/zero/encoder/factory.py b/zero/encoder/factory.py deleted file mode 100644 index 1ed02b2..0000000 --- a/zero/encoder/factory.py +++ /dev/null @@ -1,9 +0,0 @@ -from .msgspc import MsgspecEncoder -from .protocols import Encoder - - -def get_encoder(name: str) -> Encoder: - if name == "msgspec": - return MsgspecEncoder() - - raise ValueError(f"unknown encoder: {name}") diff --git a/zero/protocols/zeromq/client.py b/zero/protocols/zeromq/client.py index 1b80791..da185c0 100644 --- a/zero/protocols/zeromq/client.py +++ b/zero/protocols/zeromq/client.py @@ -1,9 +1,11 @@ import logging import threading -from typing import Dict, Optional, Type, TypeVar, Union +from typing import Dict, Optional, Type, TypeVar from zero import config -from zero.encoder import Encoder, get_encoder +from zero.encoder import Encoder +from zero.encoder.msgspc import MsgspecEncoder +from zero.utils.type_util import AllowedType from zero.zeromq_patterns import ( AsyncZeroMQClient, ZeroMQClient, @@ -23,7 +25,7 @@ def __init__( ): self._address = address self._default_timeout = default_timeout - self._encoder = encoder or get_encoder(config.ENCODER) + self._encoder = encoder or MsgspecEncoder() self.client_pool = ZMQClientPool( self._address, @@ -34,7 +36,7 @@ def __init__( def call( self, rpc_func_name: str, - msg: Union[int, float, str, dict, list, tuple, None], + msg: AllowedType, timeout: Optional[int] = None, return_type: Optional[Type[T]] = None, ) -> T: @@ -65,7 +67,7 @@ def __init__( ): self._address = address self._default_timeout = default_timeout - self._encoder = encoder or get_encoder(config.ENCODER) + self._encoder = encoder or MsgspecEncoder() self.client_pool = AsyncZMQClientPool( self._address, @@ -76,7 +78,7 @@ def __init__( async def call( self, rpc_func_name: str, - msg: Union[int, float, str, dict, list, tuple, None], + msg: AllowedType, timeout: Optional[int] = None, return_type: Optional[Type[T]] = None, ) -> T: @@ -114,7 +116,7 @@ def __init__( self._pool: Dict[int, ZeroMQClient] = {} self._address = address self._timeout = timeout - self._encoder = encoder or get_encoder(config.ENCODER) + self._encoder = encoder or MsgspecEncoder() def get(self) -> ZeroMQClient: thread_id = threading.get_ident() @@ -146,7 +148,7 @@ def __init__( self._pool: Dict[int, AsyncZeroMQClient] = {} self._address = address self._timeout = timeout - self._encoder = encoder or get_encoder(config.ENCODER) + self._encoder = encoder or MsgspecEncoder() async def get(self) -> AsyncZeroMQClient: thread_id = threading.get_ident() diff --git a/zero/protocols/zeromq/worker.py b/zero/protocols/zeromq/worker.py index 99b1ebe..7722667 100644 --- a/zero/protocols/zeromq/worker.py +++ b/zero/protocols/zeromq/worker.py @@ -1,7 +1,7 @@ import asyncio import logging import time -from typing import Optional +from typing import Any, Optional from msgspec import ValidationError @@ -37,43 +37,48 @@ def __init__( ) def start_dealer_worker(self, worker_id): - def process_message(func_name_encoded: bytes, data: bytes) -> Optional[bytes]: - try: - func_name = func_name_encoded.decode() - input_type = self._rpc_input_type_map.get(func_name) - - msg = "" - if data: - if input_type: - msg = self._encoder.decode_type(data, input_type) - else: - msg = self._encoder.decode(data) - - response = self.handle_msg(func_name, msg) - return self._encoder.encode(response) - except ValidationError as exc: - logging.exception(exc) - return self._encoder.encode({"__zerror__validation_error": str(exc)}) - except Exception as inner_exc: # pylint: disable=broad-except - logging.exception(inner_exc) - return self._encoder.encode( - {"__zerror__server_exception": SERVER_PROCESSING_ERROR} - ) - worker = get_worker(config.ZEROMQ_PATTERN, worker_id) try: - worker.listen(self._device_comm_channel, process_message) + worker.listen(self._device_comm_channel, self.handle_msg) + except KeyboardInterrupt: logging.warning( "Caught KeyboardInterrupt, terminating worker %d", worker_id ) + except Exception as exc: # pylint: disable=broad-except logging.exception(exc) + finally: logging.warning("Closing worker %d", worker_id) worker.close() - def handle_msg(self, rpc, msg): + def handle_msg(self, func_name_encoded: bytes, data: bytes) -> Optional[bytes]: + try: + func_name = func_name_encoded.decode() + input_type = self._rpc_input_type_map.get(func_name) + + msg = "" + if data: + if input_type: + msg = self._encoder.decode_type(data, input_type) + else: + msg = self._encoder.decode(data) + + response = self.execute_rpc(func_name, msg) + return self._encoder.encode(response) + + except ValidationError as exc: + logging.exception(exc) + return self._encoder.encode({"__zerror__validation_error": str(exc)}) + + except Exception as inner_exc: # pylint: disable=broad-except + logging.exception(inner_exc) + return self._encoder.encode( + {"__zerror__server_exception": SERVER_PROCESSING_ERROR} + ) + + def execute_rpc(self, rpc: str, msg: Any): if rpc == "get_rpc_contract": return self.generate_rpc_contract(msg) @@ -88,10 +93,11 @@ def handle_msg(self, rpc, msg): ret = None try: - if is_coro: - ret = async_to_sync(func)(msg) if msg else async_to_sync(func)() + func_to_call = async_to_sync(func) if is_coro else func + if self._rpc_input_type_map.get(rpc): + ret = func_to_call(msg) else: - ret = func(msg) if msg else func() + ret = func_to_call() except Exception as exc: # pylint: disable=broad-except logging.exception(exc) @@ -102,6 +108,7 @@ def handle_msg(self, rpc, msg): def generate_rpc_contract(self, msg): try: return self.codegen.generate_code(msg[0], msg[1]) + except Exception as exc: # pylint: disable=broad-except logging.exception(exc) return {"__zerror__failed_to_generate_client_code": str(exc)} diff --git a/zero/rpc/client.py b/zero/rpc/client.py index 77d54e8..3fd2507 100644 --- a/zero/rpc/client.py +++ b/zero/rpc/client.py @@ -1,8 +1,10 @@ -from typing import TYPE_CHECKING, Optional, Type, TypeVar, Union +from typing import TYPE_CHECKING, Optional, Type, TypeVar from zero import config -from zero.encoder import Encoder, get_encoder +from zero.encoder import Encoder +from zero.encoder.msgspc import MsgspecEncoder from zero.error import MethodNotFoundException, RemoteException, ValidationException +from zero.utils.type_util import AllowedType if TYPE_CHECKING: from zero.rpc.protocols import AsyncZeroClientProtocol, ZeroClientProtocol @@ -55,7 +57,7 @@ def __init__( """ self._address = f"tcp://{host}:{port}" self._default_timeout = default_timeout - self._encoder = encoder or get_encoder(config.ENCODER) + self._encoder = encoder or MsgspecEncoder() self._client_inst: "ZeroClientProtocol" = self._determine_client_cls(protocol)( self._address, self._default_timeout, @@ -79,7 +81,7 @@ def _determine_client_cls(self, protocol: str) -> Type["ZeroClientProtocol"]: def call( self, rpc_func_name: str, - msg: Union[int, float, str, dict, list, tuple, None], + msg: AllowedType, timeout: Optional[int] = None, return_type: Optional[Type[T]] = None, ) -> T: @@ -173,7 +175,7 @@ def __init__( """ self._address = f"tcp://{host}:{port}" self._default_timeout = default_timeout - self._encoder = encoder or get_encoder(config.ENCODER) + self._encoder = encoder or MsgspecEncoder() self._client_inst: "AsyncZeroClientProtocol" = self._determine_client_cls( "zeromq" )( @@ -199,7 +201,7 @@ def _determine_client_cls(self, protocol: str) -> Type["AsyncZeroClientProtocol" async def call( self, rpc_func_name: str, - msg: Union[int, float, str, dict, list, tuple, None], + msg: AllowedType, timeout: Optional[int] = None, return_type: Optional[Type[T]] = None, ) -> Optional[T]: diff --git a/zero/rpc/protocols.py b/zero/rpc/protocols.py index 3752229..a4f8f5f 100644 --- a/zero/rpc/protocols.py +++ b/zero/rpc/protocols.py @@ -6,17 +6,17 @@ Tuple, Type, TypeVar, - Union, runtime_checkable, ) from zero.encoder import Encoder +from zero.utils.type_util import AllowedType T = TypeVar("T") @runtime_checkable -class ZeroServerProtocol(Protocol): # pragma: no cover +class ZeroServerProtocol(Protocol): def __init__( self, address: str, @@ -35,7 +35,7 @@ def stop(self): @runtime_checkable -class ZeroClientProtocol(Protocol): # pragma: no cover +class ZeroClientProtocol(Protocol): def __init__( self, address: str, @@ -47,7 +47,7 @@ def __init__( def call( self, rpc_func_name: str, - msg: Union[int, float, str, dict, list, tuple, None], + msg: AllowedType, timeout: Optional[int] = None, return_type: Optional[Type[T]] = None, ) -> Optional[T]: @@ -58,7 +58,7 @@ def close(self): @runtime_checkable -class AsyncZeroClientProtocol(Protocol): # pragma: no cover +class AsyncZeroClientProtocol(Protocol): def __init__( self, address: str, @@ -70,7 +70,7 @@ def __init__( async def call( self, rpc_func_name: str, - msg: Union[int, float, str, dict, list, tuple, None], + msg: AllowedType, timeout: Optional[int] = None, return_type: Optional[Type[T]] = None, ) -> Optional[T]: diff --git a/zero/rpc/server.py b/zero/rpc/server.py index f9cbc14..5347d3f 100644 --- a/zero/rpc/server.py +++ b/zero/rpc/server.py @@ -1,10 +1,21 @@ import logging import os from asyncio import iscoroutinefunction -from typing import TYPE_CHECKING, Callable, Dict, Optional, Tuple, Type +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Coroutine, + Dict, + Optional, + Tuple, + Type, + Union, +) from zero import config -from zero.encoder import Encoder, get_encoder +from zero.encoder import Encoder +from zero.encoder.msgspc import MsgspecEncoder from zero.utils import type_util if TYPE_CHECKING: @@ -47,7 +58,7 @@ def __init__( self._address = f"tcp://{self._host}:{self._port}" # to encode/decode messages from/to client - self._encoder = encoder or get_encoder(config.ENCODER) + self._encoder = encoder or MsgspecEncoder() # Stores rpc functions against their names # and if they are coroutines @@ -79,7 +90,7 @@ def _determine_server_cls(self, protocol: str) -> Type["ZeroServerProtocol"]: ) return server_cls - def register_rpc(self, func: Callable): + def register_rpc(self, func: Callable[..., Union[Any, Coroutine]]): """ Register a function available for clients. Function should have a single argument. diff --git a/zero/utils/type_util.py b/zero/utils/type_util.py index fd69d4c..256bc84 100644 --- a/zero/utils/type_util.py +++ b/zero/utils/type_util.py @@ -4,7 +4,18 @@ import enum import typing import uuid -from typing import Callable, Optional, get_origin, get_type_hints +from typing import ( + Any, + Callable, + ClassVar, + Dict, + Optional, + Protocol, + Type, + Union, + get_origin, + get_type_hints, +) import msgspec @@ -17,15 +28,10 @@ bytes, bytearray, tuple, - typing.Tuple, list, - typing.List, dict, - typing.Dict, set, - typing.Set, frozenset, - typing.FrozenSet, ] std_lib_types: typing.List = [ @@ -40,20 +46,61 @@ ] typing_types: typing.List = [ - typing.Any, + typing.Tuple, + typing.List, + typing.Dict, + typing.Set, + typing.FrozenSet, typing.Union, typing.Optional, ] msgspec_types: typing.List = [ msgspec.Struct, - msgspec.Raw, ] allowed_types = builtin_types + std_lib_types + typing_types +class IsDataclass(Protocol): + # as already noted in comments, checking for this attribute is currently + # the most reliable way to ascertain that something is a dataclass + __dataclass_fields__: ClassVar[Dict[str, Any]] + + +AllowedType = Union[ + None, + bool, + int, + float, + str, + bytes, + bytearray, + tuple, + list, + dict, + set, + frozenset, + datetime.datetime, + datetime.date, + datetime.time, + uuid.UUID, + decimal.Decimal, + enum.Enum, + enum.IntEnum, + IsDataclass, + typing.Tuple, + typing.List, + typing.Dict, + typing.Set, + typing.FrozenSet, + msgspec.Struct, + Type[enum.Enum], # For enum classes + Type[enum.IntEnum], # For int enum classes +] + + def verify_function_args(func: Callable) -> None: arg_count = func.__code__.co_argcount if arg_count < 1: @@ -73,13 +120,6 @@ def verify_function_args(func: Callable) -> None: def verify_function_return(func: Callable) -> None: - return_count = func.__code__.co_argcount - if return_count > 1: - raise ValueError( - f"`{func.__name__}` has more than 1 return values; " - "RPC functions can have only one return value" - ) - types = get_type_hints(func) if not types.get("return"): raise TypeError( @@ -106,17 +146,10 @@ def get_function_return_class(func: Callable): def verify_function_input_type(func: Callable): input_type = get_function_input_class(func) - if input_type in allowed_types: - return - origin_type = get_origin(input_type) - if origin_type is not None and origin_type in allowed_types: + if is_allowed_type(input_type): return - for mtype in msgspec_types: - if input_type is not None and issubclass(input_type, mtype): - return - raise TypeError( f"{func.__name__} has type {input_type} which is not allowed; " "allowed types are: \n" + "\n".join([str(t) for t in allowed_types]) @@ -125,16 +158,21 @@ def verify_function_input_type(func: Callable): def verify_function_return_type(func: Callable): return_type = get_function_return_class(func) - if return_type in allowed_types: - return - origin_type = get_origin(return_type) - if origin_type is not None and origin_type in allowed_types: - return + # None is not allowed as return type + if return_type is None: + raise TypeError( + f"{func.__name__} returns None; RPC functions must return a value" + ) - for typ in msgspec_types: - if issubclass(return_type, typ): - return + # Optional is not allowed as return type + if get_origin(return_type) == typing.Union and type(None) in return_type.__args__: + raise TypeError( + f"{func.__name__} returns Optional; RPC functions must return a value" + ) + + if is_allowed_type(return_type): + return raise TypeError( f"{func.__name__} has return type {return_type} which is not allowed; " @@ -151,22 +189,22 @@ def verify_allowed_type(msg, rpc_method: Optional[str] = None): ) -def verify_incoming_rpc_call_input_type( - msg, rpc_method: str, rpc_input_type_map: dict -): # pragma: no cover - input_type = rpc_input_type_map[rpc_method] - if input_type is None: - return +def is_allowed_type(typ: Type): + if typ in allowed_types: + return True + + if str(typ).startswith("