Skip to content

Commit

Permalink
Improve client using pipe
Browse files Browse the repository at this point in the history
  • Loading branch information
Ananto30 committed Jun 25, 2024
1 parent a060a11 commit c045801
Show file tree
Hide file tree
Showing 5 changed files with 136 additions and 219 deletions.
2 changes: 1 addition & 1 deletion tests/concurrency/rps_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
async def task(semaphore, items):
async with semaphore:
try:
await async_client.call("sum_async", items)
await async_client.call("sum_sync", items)
# res = await async_client.call("sum_async", items)
# print(res)
except Exception as e:
Expand Down
120 changes: 17 additions & 103 deletions zero/protocols/zeromq/client.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
import asyncio
import logging
import threading
from typing import Any, Dict, Optional, Type, TypeVar, Union

from zero import config
from zero.encoder import Encoder, get_encoder
from zero.error import TimeoutException
from zero.utils import util
from zero.zeromq_patterns import (
AsyncZeroMQClient,
ZeroMQClient,
Expand All @@ -28,7 +25,7 @@ def __init__(
self._default_timeout = default_timeout
self._encoder = encoder or get_encoder(config.ENCODER)

self.client_pool = ZeroMQClientPool(
self.client_pool = ZMQClientPool(
self._address,
self._default_timeout,
self._encoder,
Expand All @@ -43,47 +40,17 @@ def call(
) -> T:
zmqc = self.client_pool.get()

_timeout = self._default_timeout if timeout is None else timeout

def _poll_data():
# TODO poll is slow, need to find a better way
if not zmqc.poll(_timeout):
raise TimeoutException(
f"Timeout while sending message at {self._address}"
)

rcv_data = zmqc.recv()

# first 32 bytes as response id
resp_id = rcv_data[:32].decode()

# the rest is response data
resp_data_encoded = rcv_data[32:]
resp_data = (
self._encoder.decode(resp_data_encoded)
if return_type is None
else self._encoder.decode_type(resp_data_encoded, return_type)
)

return resp_id, resp_data

req_id = util.unique_id()

# function name exactly 120 bytes
func_name_bytes = rpc_func_name.ljust(120).encode()

msg_bytes = b"" if msg is None else self._encoder.encode(msg)
zmqc.send(req_id.encode() + func_name_bytes + msg_bytes)

resp_id, resp_data = None, None
# as the client is synchronous, we know that the response will be available any next poll
# we try to get the response until timeout because a previous call might be timed out
# and the response is still in the socket,
# so we poll until we get the response for this call
while resp_id != req_id:
resp_id, resp_data = _poll_data()
resp_data_bytes = zmqc.request(func_name_bytes + msg_bytes, timeout)

return resp_data # type: ignore
return (
self._encoder.decode(resp_data_bytes)
if return_type is None
else self._encoder.decode_type(resp_data_bytes, return_type)
)

def close(self):
self.client_pool.close()
Expand All @@ -99,9 +66,8 @@ def __init__(
self._address = address
self._default_timeout = default_timeout
self._encoder = encoder or get_encoder(config.ENCODER)
self._resp_map: Dict[str, Any] = {}

self.client_pool = AsyncZeroMQClientPool(
self.client_pool = AsyncZMQClientPool(
self._address,
self._default_timeout,
self._encoder,
Expand All @@ -116,63 +82,23 @@ async def call(
) -> T:
zmqc = await self.client_pool.get()

_timeout = self._default_timeout if timeout is None else timeout
expire_at = util.current_time_us() + (_timeout * 1000)

async def _poll_data():
# TODO async has issue with poller, after 3-4 calls, it returns empty
# if not await zmqc.poll(_timeout):
# raise TimeoutException(f"Timeout while sending message at {self._address}")

# first 32 bytes as response id
resp = await zmqc.recv()
resp_id = resp[:32].decode()

# the rest is response data
resp_data_encoded = resp[32:]
resp_data = (
self._encoder.decode(resp_data_encoded)
if return_type is None
else self._encoder.decode_type(resp_data_encoded, return_type)
)
self._resp_map[resp_id] = resp_data

# TODO try to use pipe instead of sleep
# await self.peer1.send(b"")

req_id = util.unique_id()

# function name exactly 120 bytes
func_name_bytes = rpc_func_name.ljust(120).encode()

msg_bytes = b"" if msg is None else self._encoder.encode(msg)
await zmqc.send(req_id.encode() + func_name_bytes + msg_bytes)

# every request poll the data, so whenever a response comes, it will be stored in __resps
# dont need to poll again in the while loop
await _poll_data()

while req_id not in self._resp_map and util.current_time_us() <= expire_at:
# TODO the problem with the zpipe is that we can miss some response
# when we come to this line
# await self.peer2.recv()
await asyncio.sleep(1e-6)
resp_data_bytes = await zmqc.request(func_name_bytes + msg_bytes, timeout)

if util.current_time_us() > expire_at:
raise TimeoutException(
f"Timeout while waiting for response at {self._address}"
)

resp_data = self._resp_map.pop(req_id)

return resp_data
return (
self._encoder.decode(resp_data_bytes)
if return_type is None
else self._encoder.decode_type(resp_data_bytes, return_type)
)

def close(self):
self.client_pool.close()
self._resp_map = {}


class ZeroMQClientPool:
class ZMQClientPool:
"""
Connections are based on different threads and processes.
Each time a call is made it tries to get the connection from the pool,
Expand All @@ -196,21 +122,15 @@ def get(self) -> ZeroMQClient:
logging.debug("No connection found in current thread, creating new one")
self._pool[thread_id] = get_client(config.ZEROMQ_PATTERN, self._timeout)
self._pool[thread_id].connect(self._address)
self._try_connect_ping(self._pool[thread_id])
return self._pool[thread_id]

def _try_connect_ping(self, client: ZeroMQClient):
client.send(util.unique_id().encode() + b"connect" + b"")
client.recv()
logging.info("Connected to server at %s", self._address)

def close(self):
for client in self._pool.values():
client.close()
self._pool = {}


class AsyncZeroMQClientPool:
class AsyncZMQClientPool:
"""
Connections are based on different threads and processes.
Each time a call is made it tries to get the connection from the pool,
Expand All @@ -235,15 +155,9 @@ async def get(self) -> AsyncZeroMQClient:
self._pool[thread_id] = get_async_client(
config.ZEROMQ_PATTERN, self._timeout
)
self._pool[thread_id].connect(self._address)
await self._try_connect_ping(self._pool[thread_id])
await self._pool[thread_id].connect(self._address)
return self._pool[thread_id]

async def _try_connect_ping(self, client: AsyncZeroMQClient):
await client.send(util.unique_id().encode() + b"connect" + b"")
await client.recv()
logging.info("Connected to server at %s", self._address)

def close(self):
for client in self._pool.values():
client.close()
Expand Down
6 changes: 3 additions & 3 deletions zero/zeromq_patterns/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


def zpipe_async(
ctx: zmq.asyncio.Context, timeout: int = 1000
ctx: zmq.asyncio.Context,
) -> Tuple[zmq.asyncio.Socket, zmq.asyncio.Socket]: # pragma: no cover
"""
Build inproc pipe for talking to threads
Expand All @@ -20,8 +20,8 @@ def zpipe_async(
sock_b = ctx.socket(zmq.PAIR)
sock_a.linger = sock_b.linger = 0
sock_a.hwm = sock_b.hwm = 1
sock_a.sndtimeo = sock_b.sndtimeo = timeout
sock_a.rcvtimeo = sock_b.rcvtimeo = timeout
sock_a.sndtimeo = sock_b.sndtimeo = 0
sock_a.rcvtimeo = sock_b.rcvtimeo = 0
iface = f"inproc://{util.unique_id()}"
sock_a.bind(iface)
sock_b.connect(iface)
Expand Down
59 changes: 16 additions & 43 deletions zero/zeromq_patterns/protocols.py
Original file line number Diff line number Diff line change
@@ -1,68 +1,41 @@
from typing import Any, Callable, Optional, Protocol, runtime_checkable

import zmq
import zmq.asyncio
from typing import Callable, Optional, Protocol, runtime_checkable


@runtime_checkable
class ZeroMQClient(Protocol): # pragma: no cover
@property
def context(self) -> zmq.Context:
def __init__(
self,
address: str,
default_timeout: int = 2000,
):
...

def connect(self, address: str) -> None:
...

def close(self) -> None:
...

def send(self, message: bytes) -> None:
def request(self, message: bytes, timeout: Optional[int] = None) -> bytes:
...

def send_multipart(self, message: list) -> None:
...

def poll(self, timeout: int) -> bool:
...

def recv(self) -> bytes:
...

def recv_multipart(self) -> list:
...

def request(self, message: bytes) -> Any:
def close(self) -> None:
...


@runtime_checkable
class AsyncZeroMQClient(Protocol): # pragma: no cover
@property
def context(self) -> zmq.asyncio.Context:
def __init__(
self,
address: str,
default_timeout: int = 2000,
):
...

def connect(self, address: str) -> None:
...

def close(self) -> None:
async def connect(self, address: str) -> None:
...

async def send(self, message: bytes) -> None:
async def request(self, message: bytes, timeout: Optional[int] = None) -> bytes:
...

async def send_multipart(self, message: list) -> None:
...

async def poll(self, timeout: int) -> bool:
...

async def recv(self) -> bytes:
...

async def recv_multipart(self) -> list:
...

async def request(self, message: bytes) -> Any:
def close(self) -> None:
...


Expand Down
Loading

0 comments on commit c045801

Please sign in to comment.