diff --git a/.gitignore b/.gitignore index d023e69..d7bf0c1 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,7 @@ __pycache__/ .vscode venv/ +venv* venv-pypy/ .pytest_cache .idea diff --git a/benchmarks/dockerize/aiohttp/Dockerfile b/benchmarks/dockerize/aiohttp/Dockerfile index 521bd59..03ee481 100755 --- a/benchmarks/dockerize/aiohttp/Dockerfile +++ b/benchmarks/dockerize/aiohttp/Dockerfile @@ -1,4 +1,4 @@ -FROM python:3.9-slim +FROM python:3.11-slim COPY . . RUN pip install -r requirements.txt \ No newline at end of file diff --git a/benchmarks/dockerize/aiohttp/requirements.txt b/benchmarks/dockerize/aiohttp/requirements.txt index 20007c1..15cbcca 100644 --- a/benchmarks/dockerize/aiohttp/requirements.txt +++ b/benchmarks/dockerize/aiohttp/requirements.txt @@ -1,7 +1,6 @@ aiohttp gunicorn PyJWT -aioredis +redis>=4.2.0rc1 uvloop -msgpack -redis \ No newline at end of file +msgpack \ No newline at end of file diff --git a/benchmarks/dockerize/aiohttp/shared.py b/benchmarks/dockerize/aiohttp/shared.py index 2fb4b47..d021efe 100644 --- a/benchmarks/dockerize/aiohttp/shared.py +++ b/benchmarks/dockerize/aiohttp/shared.py @@ -1,9 +1,9 @@ from dataclasses import dataclass from datetime import datetime -import aioredis import msgpack import redis +from redis import asyncio as aioredis @dataclass diff --git a/benchmarks/dockerize/aiozmq/Dockerfile b/benchmarks/dockerize/aiozmq/Dockerfile index 521bd59..03ee481 100755 --- a/benchmarks/dockerize/aiozmq/Dockerfile +++ b/benchmarks/dockerize/aiozmq/Dockerfile @@ -1,4 +1,4 @@ -FROM python:3.9-slim +FROM python:3.11-slim COPY . . RUN pip install -r requirements.txt \ No newline at end of file diff --git a/benchmarks/dockerize/aiozmq/requirements.txt b/benchmarks/dockerize/aiozmq/requirements.txt index 709ac6c..b0b5ece 100644 --- a/benchmarks/dockerize/aiozmq/requirements.txt +++ b/benchmarks/dockerize/aiozmq/requirements.txt @@ -1,9 +1,8 @@ aiozmq PyJWT -aioredis +redis>=4.2.0rc1 uvloop msgpack aiohttp gunicorn -redis sanic \ No newline at end of file diff --git a/benchmarks/dockerize/aiozmq/shared.py b/benchmarks/dockerize/aiozmq/shared.py index 2fb4b47..d021efe 100644 --- a/benchmarks/dockerize/aiozmq/shared.py +++ b/benchmarks/dockerize/aiozmq/shared.py @@ -1,9 +1,9 @@ from dataclasses import dataclass from datetime import datetime -import aioredis import msgpack import redis +from redis import asyncio as aioredis @dataclass diff --git a/benchmarks/dockerize/blacksheep/Dockerfile b/benchmarks/dockerize/blacksheep/Dockerfile index 521bd59..03ee481 100755 --- a/benchmarks/dockerize/blacksheep/Dockerfile +++ b/benchmarks/dockerize/blacksheep/Dockerfile @@ -1,4 +1,4 @@ -FROM python:3.9-slim +FROM python:3.11-slim COPY . . RUN pip install -r requirements.txt \ No newline at end of file diff --git a/benchmarks/dockerize/blacksheep/requirements.txt b/benchmarks/dockerize/blacksheep/requirements.txt index 5e7db40..4306756 100644 --- a/benchmarks/dockerize/blacksheep/requirements.txt +++ b/benchmarks/dockerize/blacksheep/requirements.txt @@ -1,7 +1,6 @@ blacksheep uvicorn PyJWT -aioredis +redis>=4.2.0rc1 uvloop -msgpack -redis \ No newline at end of file +msgpack \ No newline at end of file diff --git a/benchmarks/dockerize/blacksheep/shared.py b/benchmarks/dockerize/blacksheep/shared.py index 2fb4b47..d021efe 100644 --- a/benchmarks/dockerize/blacksheep/shared.py +++ b/benchmarks/dockerize/blacksheep/shared.py @@ -1,9 +1,9 @@ from dataclasses import dataclass from datetime import datetime -import aioredis import msgpack import redis +from redis import asyncio as aioredis @dataclass diff --git a/benchmarks/dockerize/fast_api/Dockerfile b/benchmarks/dockerize/fast_api/Dockerfile index 521bd59..03ee481 100755 --- a/benchmarks/dockerize/fast_api/Dockerfile +++ b/benchmarks/dockerize/fast_api/Dockerfile @@ -1,4 +1,4 @@ -FROM python:3.9-slim +FROM python:3.11-slim COPY . . RUN pip install -r requirements.txt \ No newline at end of file diff --git a/benchmarks/dockerize/fast_api/requirements.txt b/benchmarks/dockerize/fast_api/requirements.txt index d1935fa..aa0bcf6 100644 --- a/benchmarks/dockerize/fast_api/requirements.txt +++ b/benchmarks/dockerize/fast_api/requirements.txt @@ -1,8 +1,7 @@ fastapi uvicorn PyJWT -aioredis +redis>=4.2.0rc1 uvloop msgpack -aiohttp -redis \ No newline at end of file +aiohttp \ No newline at end of file diff --git a/benchmarks/dockerize/fast_api/shared.py b/benchmarks/dockerize/fast_api/shared.py index 2fb4b47..d021efe 100644 --- a/benchmarks/dockerize/fast_api/shared.py +++ b/benchmarks/dockerize/fast_api/shared.py @@ -1,9 +1,9 @@ from dataclasses import dataclass from datetime import datetime -import aioredis import msgpack import redis +from redis import asyncio as aioredis @dataclass diff --git a/benchmarks/dockerize/sanic/Dockerfile b/benchmarks/dockerize/sanic/Dockerfile index 521bd59..03ee481 100755 --- a/benchmarks/dockerize/sanic/Dockerfile +++ b/benchmarks/dockerize/sanic/Dockerfile @@ -1,4 +1,4 @@ -FROM python:3.9-slim +FROM python:3.11-slim COPY . . RUN pip install -r requirements.txt \ No newline at end of file diff --git a/benchmarks/dockerize/sanic/requirements.txt b/benchmarks/dockerize/sanic/requirements.txt index 6ae6432..83a2bb4 100644 --- a/benchmarks/dockerize/sanic/requirements.txt +++ b/benchmarks/dockerize/sanic/requirements.txt @@ -1,7 +1,6 @@ sanic PyJWT -aioredis +redis>=4.2.0rc1 uvloop msgpack -aiohttp -redis \ No newline at end of file +aiohttp \ No newline at end of file diff --git a/benchmarks/dockerize/sanic/shared.py b/benchmarks/dockerize/sanic/shared.py index 2fb4b47..d021efe 100644 --- a/benchmarks/dockerize/sanic/shared.py +++ b/benchmarks/dockerize/sanic/shared.py @@ -1,9 +1,9 @@ from dataclasses import dataclass from datetime import datetime -import aioredis import msgpack import redis +from redis import asyncio as aioredis @dataclass diff --git a/benchmarks/dockerize/zero/Dockerfile b/benchmarks/dockerize/zero/Dockerfile index 521bd59..03ee481 100755 --- a/benchmarks/dockerize/zero/Dockerfile +++ b/benchmarks/dockerize/zero/Dockerfile @@ -1,4 +1,4 @@ -FROM python:3.9-slim +FROM python:3.11-slim COPY . . RUN pip install -r requirements.txt \ No newline at end of file diff --git a/benchmarks/dockerize/zero/requirements.txt b/benchmarks/dockerize/zero/requirements.txt index 78da89c..8406191 100644 --- a/benchmarks/dockerize/zero/requirements.txt +++ b/benchmarks/dockerize/zero/requirements.txt @@ -1,10 +1,9 @@ zeroapi PyJWT -aioredis +redis>=4.2.0rc1 uvloop aiohttp gunicorn -redis sanic msgpack uvicorn diff --git a/benchmarks/dockerize/zero/shared.py b/benchmarks/dockerize/zero/shared.py index 2fb4b47..d021efe 100644 --- a/benchmarks/dockerize/zero/shared.py +++ b/benchmarks/dockerize/zero/shared.py @@ -1,9 +1,9 @@ from dataclasses import dataclass from datetime import datetime -import aioredis import msgpack import redis +from redis import asyncio as aioredis @dataclass diff --git a/benchmarks/local/zero/Dockerfile b/benchmarks/local/zero/Dockerfile index dcd64fc..77944ff 100755 --- a/benchmarks/local/zero/Dockerfile +++ b/benchmarks/local/zero/Dockerfile @@ -1,4 +1,4 @@ -FROM python:3.9-slim +FROM python:3.11-slim COPY requirements.txt requirements.txt RUN pip install -r requirements.txt diff --git a/benchmarks/local/zero/requirements.txt b/benchmarks/local/zero/requirements.txt index 8526f03..7189534 100644 --- a/benchmarks/local/zero/requirements.txt +++ b/benchmarks/local/zero/requirements.txt @@ -1,11 +1,10 @@ pyzmq msgspec PyJWT -aioredis +redis>=4.2.0rc1 uvloop aiohttp gunicorn -redis sanic msgpack blacksheep diff --git a/benchmarks/local/zero/shared.py b/benchmarks/local/zero/shared.py index 2fb4b47..d021efe 100644 --- a/benchmarks/local/zero/shared.py +++ b/benchmarks/local/zero/shared.py @@ -1,9 +1,9 @@ from dataclasses import dataclass from datetime import datetime -import aioredis import msgpack import redis +from redis import asyncio as aioredis @dataclass diff --git a/benchmarks/others/requirements.txt b/benchmarks/others/requirements.txt index ed5880f..1a86cda 100644 --- a/benchmarks/others/requirements.txt +++ b/benchmarks/others/requirements.txt @@ -7,7 +7,6 @@ httpx fastapi sanic PyJWT -redis -aioredis @ git+https://github.com/aio-libs/aioredis@ff5a8fe068ebda837d14c3b3777a6182e610854a +redis>=4.2.0rc1 uvloop pydantic diff --git a/tests/concurrency/rps_async.py b/tests/concurrency/rps_async.py index bd4f1a0..3f2e12c 100644 --- a/tests/concurrency/rps_async.py +++ b/tests/concurrency/rps_async.py @@ -19,11 +19,11 @@ async def task(semaphore, items): async def process_tasks(items_chunk): - conc = 8 + conc = 16 semaphore = asyncio.BoundedSemaphore(conc) tasks = [task(semaphore, items) for items in items_chunk] await asyncio.gather(*tasks) - await async_client.close() + async_client.close() def run_chunk(items_chunk): diff --git a/tests/concurrency/single_req.py b/tests/concurrency/single_req.py deleted file mode 100644 index 66d02cd..0000000 --- a/tests/concurrency/single_req.py +++ /dev/null @@ -1,10 +0,0 @@ -from zero import ZeroClient - -client = ZeroClient("localhost", 5559) - -if __name__ == "__main__": - for i in range(10): - res = client.call("sleep", 100) - if res != "slept for 100 msecs": - print(f"expected: slept for 100 msecs, got: {res}") - print(res) diff --git a/tests/concurrency/single_req_async.py b/tests/concurrency/single_req_async.py new file mode 100644 index 0000000..ad855d9 --- /dev/null +++ b/tests/concurrency/single_req_async.py @@ -0,0 +1,26 @@ +import asyncio + +from zero import AsyncZeroClient, ZeroClient + +client = ZeroClient("localhost", 5559) +async_client = AsyncZeroClient("localhost", 5559) + +# Create a semaphore outside of the task function +semaphore = asyncio.BoundedSemaphore(32) + + +async def task(sleep_time, i): + # Use the semaphore as an async context manager to limit concurrency + async with semaphore: + res = await async_client.call("sleep", sleep_time) + assert res == f"slept for {sleep_time} msecs" + print(res, i) + + +async def main(): + tasks = [task(200, i) for i in range(500)] + await asyncio.gather(*tasks) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/tests/concurrency/single_req_sync.py b/tests/concurrency/single_req_sync.py new file mode 100644 index 0000000..f386ca5 --- /dev/null +++ b/tests/concurrency/single_req_sync.py @@ -0,0 +1,9 @@ +from zero import ZeroClient + +client = ZeroClient("localhost", 5559) + +if __name__ == "__main__": + for i in range(50): + res = client.call("sleep", 200) + assert res == "slept for 200 msecs" + print(res) diff --git a/tests/concurrency/sleep_test_async.py b/tests/concurrency/sleep_test_async.py index efc4283..e6ab7ae 100644 --- a/tests/concurrency/sleep_test_async.py +++ b/tests/concurrency/sleep_test_async.py @@ -1,18 +1,9 @@ import asyncio import random import time -from contextlib import contextmanager from zero import AsyncZeroClient - -@contextmanager -def get_client(): - client = AsyncZeroClient("localhost", 5559) - yield client - client.close() - - async_client = AsyncZeroClient("localhost", 5559) @@ -25,7 +16,7 @@ async def task(semaphore, sleep_time): async def test(): - conc = 10 + conc = 32 semaphore = asyncio.BoundedSemaphore(conc) sleep_times = [] diff --git a/tests/concurrency/sleep_test_sync.py b/tests/concurrency/sleep_test_sync.py index 057205d..4c44d06 100644 --- a/tests/concurrency/sleep_test_sync.py +++ b/tests/concurrency/sleep_test_sync.py @@ -20,7 +20,7 @@ def get_and_print(msg): resp = func(msg) if resp != f"slept for {msg} msecs": print(f"expected: slept for {msg} msecs, got: {resp}") - print(resp) + # print(resp) if __name__ == "__main__": diff --git a/tests/functional/single_server/client_test.py b/tests/functional/single_server/client_test.py index 8f5ce41..462f07a 100644 --- a/tests/functional/single_server/client_test.py +++ b/tests/functional/single_server/client_test.py @@ -118,6 +118,24 @@ def test_timeout_all(): assert msg is None +def test_timeout_all_async(): + client = AsyncZeroClient(server.HOST, server.PORT) + + with pytest.raises(zero.error.TimeoutException): + msg = asyncio.run(client.call("sleep", 1000, timeout=10)) + assert msg is None + + with pytest.raises(zero.error.TimeoutException): + msg = asyncio.run(client.call("sleep", 1000, timeout=200)) + assert msg is None + + # the server is 2 cores, so even if the timeout is greater, + # server couldn't complete the last 2 calls and will timeout + with pytest.raises(zero.error.TimeoutException): + msg = asyncio.run(client.call("sleep", 50, timeout=300)) + assert msg is None + + # TODO fix server is blocked until a long running call is completed # def test_one_call_should_not_affect_another(): # client = ZeroClient(server.HOST, server.PORT) @@ -145,8 +163,12 @@ def test_timeout_all(): def test_random_timeout(): client = ZeroClient(server.HOST, server.PORT) + fails = 0 + should_fail = 0 for _ in range(100): sleep_time = random.randint(10, 100) + # error margin of 10 ms + should_fail += sleep_time > 60 try: msg = client.call("sleep", sleep_time, timeout=50) assert msg == f"slept for {sleep_time} msecs" @@ -154,13 +176,22 @@ def test_random_timeout(): assert ( sleep_time > 1 ) # considering network latency, 50 msecs is too low in github actions + fails += 1 + + client.close() + + assert fails >= should_fail def test_random_timeout_async(): client = AsyncZeroClient(server.HOST, server.PORT) + fails = 0 + should_fail = 0 for _ in range(100): sleep_time = random.randint(10, 100) + # error margin of 10 ms + should_fail += sleep_time > 60 try: msg = asyncio.run(client.call("sleep", sleep_time, timeout=50)) assert msg == f"slept for {sleep_time} msecs" @@ -168,22 +199,25 @@ def test_random_timeout_async(): assert ( sleep_time > 1 ) # considering network latency, 50 msecs is too low in github actions + fails += 1 client.close() + assert fails >= should_fail -@pytest.mark.asyncio -async def test_async_sleep(): - client = AsyncZeroClient(server.HOST, server.PORT) - async def task(sleep_time): - res = await client.call("sleep", sleep_time) - assert res == f"slept for {sleep_time} msecs" +# @pytest.mark.asyncio +# async def test_async_sleep(): +# client = AsyncZeroClient(server.HOST, server.PORT) - start = time.time() - tasks = [task(200) for _ in range(5)] - await asyncio.gather(*tasks) - end = time.time() - time_taken_ms = 1e3 * (end - start) +# async def task(sleep_time): +# res = await client.call("sleep_async", sleep_time) +# assert res == f"slept for {sleep_time} msecs" + +# tasks = [task(200) for _ in range(5)] + +# start = time.perf_counter() +# await asyncio.gather(*tasks) +# time_taken_ms = (time.perf_counter() - start) * 1000 - assert time_taken_ms < 1000 +# assert time_taken_ms < 1000 diff --git a/zero/codegen/codegen.py b/zero/codegen/codegen.py index b7449e0..31015e3 100644 --- a/zero/codegen/codegen.py +++ b/zero/codegen/codegen.py @@ -182,7 +182,7 @@ def _generate_class_code(self, cls: Type, already_generated: Set[Type]) -> str: if python_version >= (3, 9): code += inspect.getsource(cls) + "\n\n" - else: + else: # pragma: no cover # python 3.8 doesnt return @dataclass decorator if is_dataclass(cls): code += f"@dataclass\n{inspect.getsource(cls)}\n\n" @@ -209,23 +209,23 @@ def _generate_code_for_fields(self, cls: Type, already_generated: Set[Type]) -> def _generate_code_for_type(self, typ: Type, already_generated: Set[Type]) -> str: code = "" - typs = self._resolve_field_type(typ) - for it in typs: - self._track_imports(it) - if isinstance(it, type) and ( - issubclass(it, (msgspec.Struct, enum.Enum, enum.IntEnum)) - or is_dataclass(it) + all_possible_typs = self._resolve_field_type(typ) + for possible_typ in all_possible_typs: + self._track_imports(possible_typ) + if isinstance(possible_typ, type) and ( + issubclass(possible_typ, (msgspec.Struct, enum.Enum, enum.IntEnum)) + or is_dataclass(possible_typ) ): - code += self._generate_class_code(it, already_generated) + code += self._generate_class_code(possible_typ, already_generated) return code def _resolve_field_type(self, field_type) -> List[Type]: origin = get_origin(field_type) if origin in (list, tuple, set, frozenset, Optional): return [get_args(field_type)[0]] - elif origin == dict: + if origin == dict: return [get_args(field_type)[1]] - elif origin == Union: + if origin == Union: return list(get_args(field_type)) return [field_type] diff --git a/zero/protocols/zeromq/client.py b/zero/protocols/zeromq/client.py index da185c0..7661d98 100644 --- a/zero/protocols/zeromq/client.py +++ b/zero/protocols/zeromq/client.py @@ -20,18 +20,12 @@ class ZMQClient: def __init__( self, address: str, - default_timeout: int = 2000, - encoder: Optional[Encoder] = None, + default_timeout: int, + encoder: Encoder, ): - self._address = address - self._default_timeout = default_timeout self._encoder = encoder or MsgspecEncoder() - self.client_pool = ZMQClientPool( - self._address, - self._default_timeout, - self._encoder, - ) + self.client_pool = ZMQClientPool(address, default_timeout) def call( self, @@ -62,18 +56,12 @@ class AsyncZMQClient: def __init__( self, address: str, - default_timeout: int = 2000, - encoder: Optional[Encoder] = None, + default_timeout: int, + encoder: Encoder, ): - self._address = address - self._default_timeout = default_timeout - self._encoder = encoder or MsgspecEncoder() + self._encoder = encoder - self.client_pool = AsyncZMQClientPool( - self._address, - self._default_timeout, - self._encoder, - ) + self.client_pool = AsyncZMQClientPool(address, default_timeout) async def call( self, @@ -108,15 +96,12 @@ class ZMQClientPool: If the connection is not available, it creates a new connection and stores it in the pool. """ - __slots__ = ["_pool", "_address", "_timeout", "_encoder"] + __slots__ = ["_pool", "_address", "_timeout"] - def __init__( - self, address: str, timeout: int = 2000, encoder: Optional[Encoder] = None - ): + def __init__(self, address: str, timeout: int): self._pool: Dict[int, ZeroMQClient] = {} self._address = address self._timeout = timeout - self._encoder = encoder or MsgspecEncoder() def get(self) -> ZeroMQClient: thread_id = threading.get_ident() @@ -140,15 +125,12 @@ class AsyncZMQClientPool: If the connection is not available, it creates a new connection and stores it in the pool. """ - __slots__ = ["_pool", "_address", "_timeout", "_encoder"] + __slots__ = ["_pool", "_address", "_timeout"] - def __init__( - self, address: str, timeout: int = 2000, encoder: Optional[Encoder] = None - ): + def __init__(self, address: str, timeout: int): self._pool: Dict[int, AsyncZeroMQClient] = {} self._address = address self._timeout = timeout - self._encoder = encoder or MsgspecEncoder() async def get(self) -> AsyncZeroMQClient: thread_id = threading.get_ident() diff --git a/zero/rpc/client.py b/zero/rpc/client.py index 1595aba..22b857b 100644 --- a/zero/rpc/client.py +++ b/zero/rpc/client.py @@ -141,6 +141,7 @@ def __init__( port: int, default_timeout: int = 2000, encoder: Optional[Encoder] = None, + protocol: str = "zeromq", ): """ AsyncZeroClient provides the asynchronous client interface for calling the ZeroServer. @@ -165,19 +166,25 @@ def __init__( Port of the ZeroServer. default_timeout: int - Default timeout for all calls. Default is 2000 ms. + Default timeout for all calls in milliseconds. + Default is 2000 milliseconds (2 seconds). encoder: Optional[Encoder] Encoder to encode/decode messages from/to client. Default is msgspec. If any other encoder is used, the server should use the same encoder. Implement custom encoder by inheriting from `zero.encoder.Encoder`. + + protocol: str + Protocol to use for communication. + Default is zeromq. + If any other protocol is used, the server should use the same protocol. """ self._address = f"tcp://{host}:{port}" self._default_timeout = default_timeout self._encoder = encoder or MsgspecEncoder() self._client_inst: "AsyncZeroClientProtocol" = self._determine_client_cls( - "zeromq" + protocol )( self._address, self._default_timeout, @@ -246,8 +253,9 @@ async def call( Or zeromq cannot receive the response from the server. Mainly represents zmq.error.Again exception. """ + _timeout = timeout or self._default_timeout resp_data = await self._client_inst.call( - rpc_func_name, msg, timeout, return_type + rpc_func_name, msg, _timeout, return_type ) check_response(resp_data) return resp_data diff --git a/zero/rpc/protocols.py b/zero/rpc/protocols.py index a4f8f5f..f335d1b 100644 --- a/zero/rpc/protocols.py +++ b/zero/rpc/protocols.py @@ -76,5 +76,5 @@ async def call( ) -> Optional[T]: ... - async def close(self): + def close(self): ... diff --git a/zero/utils/util.py b/zero/utils/util.py index be234c4..4295cda 100644 --- a/zero/utils/util.py +++ b/zero/utils/util.py @@ -1,8 +1,8 @@ import logging +import os import signal import socket import sys -import time import uuid from typing import Callable @@ -53,17 +53,8 @@ def unique_id() -> str: return str(uuid.uuid4()).replace("-", "") -def current_time_us() -> int: - """ - Get current time in microseconds. - - Returns - ------- - int - Current time in microseconds. - - """ - return int(time.time() * 1e6) +def unique_id_bytes() -> bytes: + return os.urandom(16) def register_signal_term(sigterm_handler: Callable): diff --git a/zero/zeromq_patterns/factory.py b/zero/zeromq_patterns/factory.py index e5a1abd..ee0352a 100644 --- a/zero/zeromq_patterns/factory.py +++ b/zero/zeromq_patterns/factory.py @@ -10,7 +10,7 @@ def get_client(pattern: str, default_timeout: int = 2000) -> ZeroMQClient: raise ValueError(f"Invalid pattern: {pattern}") -def get_async_client(pattern: str, default_timeout: int = 2000) -> AsyncZeroMQClient: +def get_async_client(pattern: str, default_timeout: int) -> AsyncZeroMQClient: if pattern == "proxy": return queue_device.AsyncZeroMQClient(default_timeout) diff --git a/zero/zeromq_patterns/queue_device/client.py b/zero/zeromq_patterns/queue_device/client.py index f159c39..535d5e9 100644 --- a/zero/zeromq_patterns/queue_device/client.py +++ b/zero/zeromq_patterns/queue_device/client.py @@ -1,7 +1,6 @@ import asyncio import logging import sys -from asyncio import Event from typing import Dict, Optional import zmq @@ -13,8 +12,8 @@ class ZeroMQClient: - def __init__(self, default_timeout): - self._address = None + def __init__(self, default_timeout: int): + self._address = "" self._default_timeout = default_timeout self._context = zmq.Context.instance() @@ -29,12 +28,12 @@ def __init__(self, default_timeout): def connect(self, address: str) -> None: self._address = address self.socket.connect(address) - self._send(util.unique_id().encode() + b"connect" + b"") + self._send(util.unique_id_bytes() + b"connect" + b"") self._recv() logging.info("Connected to server at %s", self._address) def request(self, message: bytes, timeout: Optional[int] = None) -> bytes: - _timeout = self._default_timeout if timeout is None else timeout + _timeout = timeout or self._default_timeout def _poll_data(): # poll is slow, need to find a better way @@ -45,16 +44,16 @@ def _poll_data(): rcv_data = self._recv() - # first 32 bytes as response id - resp_id = rcv_data[:32].decode() + # first 16 bytes as response id + resp_id = rcv_data[:16] # the rest is response data - resp_data = rcv_data[32:] + resp_data = rcv_data[16:] return resp_id, resp_data - req_id = util.unique_id() - self._send(req_id.encode() + message) + req_id = util.unique_id_bytes() + self._send(req_id + message) resp_id, resp_data = None, None # as the client is synchronous, we know that the response will be available any next poll @@ -91,7 +90,7 @@ def _recv(self) -> bytes: class AsyncZeroMQClient: - def __init__(self, default_timeout: int = 2000): + def __init__(self, default_timeout: int): if sys.platform == "win32": # windows need special event loop policy to work with zmq asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) @@ -105,80 +104,53 @@ def __init__(self, default_timeout: int = 2000): self.socket.setsockopt(zmq.RCVTIMEO, default_timeout) self.socket.setsockopt(zmq.SNDTIMEO, default_timeout) - self.poller = zmqasync.Poller() - self.poller.register(self.socket, zmq.POLLIN) - - self._resp_map: Dict[str, bytes] = {} - - # self.peer1, self.peer2 = zpipe_async(self._context) + self._resp_events: Dict[bytes, asyncio.Event] = {} + self._resp_data: Dict[bytes, bytes] = {} async def connect(self, address: str) -> None: self._address = address self.socket.connect(address) - await self._send(util.unique_id().encode() + b"connect" + b"") + await self._send(util.unique_id_bytes() + b"connect" + b"") await self._recv() logging.info("Connected to server at %s", self._address) async def request(self, message: bytes, timeout: Optional[int] = None) -> bytes: - _timeout = self._default_timeout if timeout is None else timeout - expire_at = util.current_time_us() + (_timeout * 1000) - - is_data = Event() + _timeout = timeout or self._default_timeout + _expire_at = (asyncio.get_event_loop().time() * 1e3) + _timeout async def _poll_data(): - # async has issue with poller, after 3-4 calls, it returns empty - # if not await self._poll(_timeout): - # raise TimeoutException(f"Timeout while sending message at {self._address}") - resp = await self._recv() - # first 32 bytes as response id - resp_id = resp[:32].decode() - - # the rest is response data - resp_data = resp[32:] - self._resp_map[resp_id] = resp_data - - # pipe is a good way to notify the main event loop that there is a response - # but pipe is actually slower than sleep, because it is a zmq socket - # yes it uses inproc, but still slower than asyncio.sleep - # try: - # await self.peer1.send(b"") - # except zmqerr.Again: - # # if the pipe is full, just pass - # pass - - is_data.set() + resp_id = resp[:16] + resp_data = resp[16:] - req_id = util.unique_id() - await self._send(req_id.encode() + message) + if resp_id in self._resp_events: + self._resp_data[resp_id] = resp_data + self._resp_events[resp_id].set() - # poll can get response of a different call - # so we poll until we get the response of this call or timeout - await _poll_data() + req_id = util.unique_id_bytes() + self._resp_events[req_id] = asyncio.Event() - while req_id not in self._resp_map: - if util.current_time_us() > expire_at: - raise TimeoutException( - f"Timeout while waiting for response at {self._address}" - ) - - # await asyncio.sleep(1e-6) - await asyncio.wait_for(is_data.wait(), timeout=_timeout) - - # try: - # await self.peer2.recv() - # except zmqerr.Again: - # # if the pipe is empty, just pass - # pass - - resp_data = self._resp_map.pop(req_id) + await self._send(req_id + message) - return resp_data + try: + await asyncio.wait_for(_poll_data(), _timeout / 1e3) + remaining_time = _expire_at - (asyncio.get_event_loop().time() * 1e3) + await asyncio.wait_for( + self._resp_events[req_id].wait(), remaining_time / 1e3 + ) + return self._resp_data.pop(req_id, b"") + except asyncio.TimeoutError as exc: + self._resp_events.pop(req_id, None) + self._resp_data.pop(req_id, None) + raise TimeoutException( + f"Timeout while waiting for response at {self._address}" + ) from exc def close(self) -> None: self.socket.close() - self._resp_map.clear() + self._resp_events.clear() + self._resp_data.clear() async def _send(self, message: bytes) -> None: try: @@ -188,10 +160,6 @@ async def _send(self, message: bytes) -> None: f"Connection error for send at {self._address}" ) from exc - async def _poll(self, timeout: int) -> bool: - socks = dict(await self.poller.poll(timeout)) - return self.socket in socks - async def _recv(self) -> bytes: try: return await self.socket.recv() diff --git a/zero/zeromq_patterns/queue_device/worker.py b/zero/zeromq_patterns/queue_device/worker.py index ebed4d5..9d62cce 100644 --- a/zero/zeromq_patterns/queue_device/worker.py +++ b/zero/zeromq_patterns/queue_device/worker.py @@ -39,9 +39,9 @@ def _recv_and_process(self, msg_handler: Callable[[bytes, bytes], Optional[bytes # so the broker knows who to send the response to ident, data = frames - # first 32 bytes is request id - req_id = data[:32] - data = data[32:] + # first 16 bytes is request id + req_id = data[:16] + data = data[16:] # then 80 bytes is function name func_name = data[:80].strip()