Skip to content

Commit

Permalink
Improve client using asyncio event
Browse files Browse the repository at this point in the history
  • Loading branch information
Ananto30 committed Jun 25, 2024
1 parent a060a11 commit 8d4d732
Show file tree
Hide file tree
Showing 12 changed files with 224 additions and 277 deletions.
2 changes: 1 addition & 1 deletion tests/concurrency/rps_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
async def task(semaphore, items):
async with semaphore:
try:
await async_client.call("sum_async", items)
await async_client.call("sum_sync", items)
# res = await async_client.call("sum_async", items)
# print(res)
except Exception as e:
Expand Down
13 changes: 12 additions & 1 deletion tests/unit/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from zero import ZeroServer
from zero.encoder.protocols import Encoder
from zero.zeromq_patterns.protocols import ZeroMQBroker
from zero.zeromq_patterns.interfaces import ZeroMQBroker

DEFAULT_PORT = 5559
DEFAULT_HOST = "0.0.0.0"
Expand Down Expand Up @@ -216,6 +216,17 @@ class Message:
def add(msg: Message) -> Message:
return Message()

def test_register_rpc_with_long_name(self):
server = ZeroServer()

with self.assertRaises(ValueError):

@server.register_rpc
def add_this_is_a_very_long_name_for_a_function_more_than_120_characters_ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff(
msg: Tuple[int, int]
) -> int:
return msg[0] + msg[1]

def test_server_run(self):
server = ZeroServer()

Expand Down
2 changes: 2 additions & 0 deletions zero/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,5 @@
"async_client": AsyncZMQClient,
},
}

MAX_FUNC_NAME_LEN = 80
128 changes: 21 additions & 107 deletions zero/protocols/zeromq/client.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
import asyncio
import logging
import threading
from typing import Any, Dict, Optional, Type, TypeVar, Union

from zero import config
from zero.encoder import Encoder, get_encoder
from zero.error import TimeoutException
from zero.utils import util
from zero.zeromq_patterns import (
AsyncZeroMQClient,
ZeroMQClient,
Expand All @@ -28,7 +25,7 @@ def __init__(
self._default_timeout = default_timeout
self._encoder = encoder or get_encoder(config.ENCODER)

self.client_pool = ZeroMQClientPool(
self.client_pool = ZMQClientPool(
self._address,
self._default_timeout,
self._encoder,
Expand All @@ -43,47 +40,17 @@ def call(
) -> T:
zmqc = self.client_pool.get()

_timeout = self._default_timeout if timeout is None else timeout

def _poll_data():
# TODO poll is slow, need to find a better way
if not zmqc.poll(_timeout):
raise TimeoutException(
f"Timeout while sending message at {self._address}"
)

rcv_data = zmqc.recv()

# first 32 bytes as response id
resp_id = rcv_data[:32].decode()

# the rest is response data
resp_data_encoded = rcv_data[32:]
resp_data = (
self._encoder.decode(resp_data_encoded)
if return_type is None
else self._encoder.decode_type(resp_data_encoded, return_type)
)

return resp_id, resp_data

req_id = util.unique_id()

# function name exactly 120 bytes
func_name_bytes = rpc_func_name.ljust(120).encode()

# make function name exactly 80 bytes
func_name_bytes = rpc_func_name.ljust(config.MAX_FUNC_NAME_LEN).encode()
msg_bytes = b"" if msg is None else self._encoder.encode(msg)
zmqc.send(req_id.encode() + func_name_bytes + msg_bytes)

resp_id, resp_data = None, None
# as the client is synchronous, we know that the response will be available any next poll
# we try to get the response until timeout because a previous call might be timed out
# and the response is still in the socket,
# so we poll until we get the response for this call
while resp_id != req_id:
resp_id, resp_data = _poll_data()
resp_data_bytes = zmqc.request(func_name_bytes + msg_bytes, timeout)

return resp_data # type: ignore
return (
self._encoder.decode(resp_data_bytes)
if return_type is None
else self._encoder.decode_type(resp_data_bytes, return_type)
)

def close(self):
self.client_pool.close()
Expand All @@ -99,9 +66,8 @@ def __init__(
self._address = address
self._default_timeout = default_timeout
self._encoder = encoder or get_encoder(config.ENCODER)
self._resp_map: Dict[str, Any] = {}

self.client_pool = AsyncZeroMQClientPool(
self.client_pool = AsyncZMQClientPool(
self._address,
self._default_timeout,
self._encoder,
Expand All @@ -116,63 +82,23 @@ async def call(
) -> T:
zmqc = await self.client_pool.get()

_timeout = self._default_timeout if timeout is None else timeout
expire_at = util.current_time_us() + (_timeout * 1000)

async def _poll_data():
# TODO async has issue with poller, after 3-4 calls, it returns empty
# if not await zmqc.poll(_timeout):
# raise TimeoutException(f"Timeout while sending message at {self._address}")

# first 32 bytes as response id
resp = await zmqc.recv()
resp_id = resp[:32].decode()

# the rest is response data
resp_data_encoded = resp[32:]
resp_data = (
self._encoder.decode(resp_data_encoded)
if return_type is None
else self._encoder.decode_type(resp_data_encoded, return_type)
)
self._resp_map[resp_id] = resp_data

# TODO try to use pipe instead of sleep
# await self.peer1.send(b"")

req_id = util.unique_id()

# function name exactly 120 bytes
func_name_bytes = rpc_func_name.ljust(120).encode()

# make function name exactly 80 bytes
func_name_bytes = rpc_func_name.ljust(config.MAX_FUNC_NAME_LEN).encode()
msg_bytes = b"" if msg is None else self._encoder.encode(msg)
await zmqc.send(req_id.encode() + func_name_bytes + msg_bytes)

# every request poll the data, so whenever a response comes, it will be stored in __resps
# dont need to poll again in the while loop
await _poll_data()

while req_id not in self._resp_map and util.current_time_us() <= expire_at:
# TODO the problem with the zpipe is that we can miss some response
# when we come to this line
# await self.peer2.recv()
await asyncio.sleep(1e-6)
resp_data_bytes = await zmqc.request(func_name_bytes + msg_bytes, timeout)

if util.current_time_us() > expire_at:
raise TimeoutException(
f"Timeout while waiting for response at {self._address}"
)

resp_data = self._resp_map.pop(req_id)

return resp_data
return (
self._encoder.decode(resp_data_bytes)
if return_type is None
else self._encoder.decode_type(resp_data_bytes, return_type)
)

def close(self):
self.client_pool.close()
self._resp_map = {}


class ZeroMQClientPool:
class ZMQClientPool:
"""
Connections are based on different threads and processes.
Each time a call is made it tries to get the connection from the pool,
Expand All @@ -196,21 +122,15 @@ def get(self) -> ZeroMQClient:
logging.debug("No connection found in current thread, creating new one")
self._pool[thread_id] = get_client(config.ZEROMQ_PATTERN, self._timeout)
self._pool[thread_id].connect(self._address)
self._try_connect_ping(self._pool[thread_id])
return self._pool[thread_id]

def _try_connect_ping(self, client: ZeroMQClient):
client.send(util.unique_id().encode() + b"connect" + b"")
client.recv()
logging.info("Connected to server at %s", self._address)

def close(self):
for client in self._pool.values():
client.close()
self._pool = {}


class AsyncZeroMQClientPool:
class AsyncZMQClientPool:
"""
Connections are based on different threads and processes.
Each time a call is made it tries to get the connection from the pool,
Expand All @@ -235,15 +155,9 @@ async def get(self) -> AsyncZeroMQClient:
self._pool[thread_id] = get_async_client(
config.ZEROMQ_PATTERN, self._timeout
)
self._pool[thread_id].connect(self._address)
await self._try_connect_ping(self._pool[thread_id])
await self._pool[thread_id].connect(self._address)
return self._pool[thread_id]

async def _try_connect_ping(self, client: AsyncZeroMQClient):
await client.send(util.unique_id().encode() + b"connect" + b"")
await client.recv()
logging.info("Connected to server at %s", self._address)

def close(self):
for client in self._pool.values():
client.close()
Expand Down
6 changes: 4 additions & 2 deletions zero/rpc/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,6 @@ def register_rpc(self, func: Callable):
Function should have a single argument.
Argument and return should have a type hint.
If the function got exception, client will get None as return value.
Parameters
----------
func: Callable
Expand Down Expand Up @@ -135,6 +133,10 @@ def run(self, workers: int = os.cpu_count() or 1):
def _verify_function_name(self, func):
if not isinstance(func, Callable):
raise ValueError(f"register function; not {type(func)}")
if len(func.__name__) > config.MAX_FUNC_NAME_LEN:
raise ValueError(
f"function name can be at max {config.MAX_FUNC_NAME_LEN} characters; {func.__name__}"
)
if func.__name__ in self._rpc_router:
raise ValueError(
f"cannot have two RPC function same name: `{func.__name__}`"
Expand Down
2 changes: 1 addition & 1 deletion zero/zeromq_patterns/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .factory import get_async_client, get_broker, get_client, get_worker
from .protocols import AsyncZeroMQClient, ZeroMQBroker, ZeroMQClient, ZeroMQWorker
from .interfaces import AsyncZeroMQClient, ZeroMQBroker, ZeroMQClient, ZeroMQWorker
2 changes: 1 addition & 1 deletion zero/zeromq_patterns/factory.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from zero.zeromq_patterns import queue_device

from .protocols import AsyncZeroMQClient, ZeroMQBroker, ZeroMQClient, ZeroMQWorker
from .interfaces import AsyncZeroMQClient, ZeroMQBroker, ZeroMQClient, ZeroMQWorker


def get_client(pattern: str, default_timeout: int = 2000) -> ZeroMQClient:
Expand Down
6 changes: 3 additions & 3 deletions zero/zeromq_patterns/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


def zpipe_async(
ctx: zmq.asyncio.Context, timeout: int = 1000
ctx: zmq.asyncio.Context,
) -> Tuple[zmq.asyncio.Socket, zmq.asyncio.Socket]: # pragma: no cover
"""
Build inproc pipe for talking to threads
Expand All @@ -20,8 +20,8 @@ def zpipe_async(
sock_b = ctx.socket(zmq.PAIR)
sock_a.linger = sock_b.linger = 0
sock_a.hwm = sock_b.hwm = 1
sock_a.sndtimeo = sock_b.sndtimeo = timeout
sock_a.rcvtimeo = sock_b.rcvtimeo = timeout
sock_a.sndtimeo = sock_b.sndtimeo = 0
sock_a.rcvtimeo = sock_b.rcvtimeo = 0
iface = f"inproc://{util.unique_id()}"
sock_a.bind(iface)
sock_b.connect(iface)
Expand Down
59 changes: 59 additions & 0 deletions zero/zeromq_patterns/interfaces.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from typing import Callable, Optional, Protocol, runtime_checkable


@runtime_checkable
class ZeroMQClient(Protocol): # pragma: no cover
def __init__(
self,
address: str,
default_timeout: int = 2000,
):
...

def connect(self, address: str) -> None:
...

def request(self, message: bytes, timeout: Optional[int] = None) -> bytes:
...

def close(self) -> None:
...


@runtime_checkable
class AsyncZeroMQClient(Protocol): # pragma: no cover
def __init__(
self,
address: str,
default_timeout: int = 2000,
):
...

async def connect(self, address: str) -> None:
...

async def request(self, message: bytes, timeout: Optional[int] = None) -> bytes:
...

def close(self) -> None:
...


@runtime_checkable
class ZeroMQBroker(Protocol): # pragma: no cover
def listen(self, address: str, channel: str) -> None:
...

def close(self) -> None:
...


@runtime_checkable
class ZeroMQWorker(Protocol): # pragma: no cover
def listen(
self, address: str, msg_handler: Callable[[bytes, bytes], Optional[bytes]]
) -> None:
...

def close(self) -> None:
...
Loading

0 comments on commit 8d4d732

Please sign in to comment.