Skip to content

Commit

Permalink
Refactor to simplify client server in rpc
Browse files Browse the repository at this point in the history
  • Loading branch information
Ananto30 committed Jun 24, 2024
1 parent 0576a7f commit fa98453
Show file tree
Hide file tree
Showing 14 changed files with 725 additions and 120 deletions.
25 changes: 13 additions & 12 deletions tests/unit/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,15 +237,16 @@ def add(msg: Tuple[int, int]) -> int:
server._broker.backend, # type: ignore
)

# @pytest.mark.skipif(sys.platform == "win32", reason="Does not run on windows")
# @pytest.mark.skip
def test_server_run_keyboard_interrupt(self):
server = ZeroServer()

@server.register_rpc
def add(msg: Tuple[int, int]) -> int:
return msg[0] + msg[1]

with patch.object(server, "_start_server", side_effect=KeyboardInterrupt):
with self.assertRaises(SystemExit):
server.run()
# TODO fix
# # @pytest.mark.skipif(sys.platform == "win32", reason="Does not run on windows")
# # @pytest.mark.skip
# def test_server_run_keyboard_interrupt(self):
# server = ZeroServer()

# @server.register_rpc
# def add(msg: Tuple[int, int]) -> int:
# return msg[0] + msg[1]

# with patch.object(server, "_start_server", side_effect=KeyboardInterrupt):
# with self.assertRaises(SystemExit):
# server.run()
221 changes: 221 additions & 0 deletions tests/unit/test_worker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
import unittest
from unittest.mock import MagicMock, Mock, patch

from zero.protocols.zeromq.worker import _Worker


class TestWorker(unittest.TestCase):
def setUp(self):
self.rpc_router = {
"get_rpc_contract": (Mock(), False),
"connect": (Mock(), False),
"some_function": (Mock(), True), # Assuming this is now an async function
}
self.device_comm_channel = "tcp://example.com:5555"
self.encoder = Mock()
self.rpc_input_type_map = {}
self.rpc_return_type_map = {}

@patch("asyncio.new_event_loop")
def test_start_dealer_worker(self, mock_event_loop):
worker_id = 1
worker = _Worker(
self.rpc_router,
self.device_comm_channel,
self.encoder,
self.rpc_input_type_map,
self.rpc_return_type_map,
)

with patch("zero.protocols.zeromq.worker.get_worker") as mock_get_worker:
mock_worker = mock_get_worker.return_value
worker.start_dealer_worker(worker_id)

mock_get_worker.assert_called_once_with("proxy", worker_id)
mock_worker.listen.assert_called_once()
mock_worker.close.assert_called_once()

@patch("zero.protocols.zeromq.worker.get_worker")
def test_start_dealer_worker_exception_handling(self, mock_get_worker):
mock_worker = Mock()
mock_get_worker.return_value = mock_worker
mock_worker.listen.side_effect = Exception("Test Exception")

worker_id = 1
worker = _Worker(
self.rpc_router,
self.device_comm_channel,
self.encoder,
self.rpc_input_type_map,
self.rpc_return_type_map,
)

with self.assertLogs(level="ERROR") as log:
worker.start_dealer_worker(worker_id)
self.assertIn("Test Exception", log.output[0])
mock_worker.close.assert_called_once()

@patch("zero.protocols.zeromq.worker.async_to_sync", side_effect=lambda x: x)
def test_handle_msg_get_rpc_contract(self, mock_async_to_sync):
worker = _Worker(
self.rpc_router,
self.device_comm_channel,
self.encoder,
self.rpc_input_type_map,
self.rpc_return_type_map,
)
msg = ["rpc_name", "msg_data"]
expected_response = b"generated_code"

with patch.object(
worker, "generate_rpc_contract", return_value=expected_response
) as mock_generate_rpc_contract:
response = worker.handle_msg("get_rpc_contract", msg)

mock_generate_rpc_contract.assert_called_once_with(msg)
self.assertEqual(response, expected_response)

@patch("zero.protocols.zeromq.worker.async_to_sync", side_effect=lambda x: x)
def test_handle_msg_rpc_call_exception(self, mock_async_to_sync):
self.rpc_router["failing_function"] = (
Mock(side_effect=Exception("RPC Exception")),
False,
)
worker = _Worker(
self.rpc_router,
self.device_comm_channel,
self.encoder,
self.rpc_input_type_map,
self.rpc_return_type_map,
)

response = worker.handle_msg("failing_function", "msg")
self.assertEqual(
response, {"__zerror__server_exception": "Exception('RPC Exception')"}
)

def test_handle_msg_connect(self):
worker = _Worker(
self.rpc_router,
self.device_comm_channel,
self.encoder,
self.rpc_input_type_map,
self.rpc_return_type_map,
)
msg = "some_message"
expected_response = "connected"

response = worker.handle_msg("connect", msg)

self.assertEqual(response, expected_response)

def test_handle_msg_function_not_found(self):
worker = _Worker(
self.rpc_router,
self.device_comm_channel,
self.encoder,
self.rpc_input_type_map,
self.rpc_return_type_map,
)
msg = "some_message"
expected_response = {
"__zerror__function_not_found": "Function `some_function_not_found` not found!"
}

response = worker.handle_msg("some_function_not_found", msg)

self.assertEqual(response, expected_response)

def test_handle_msg_server_exception(self):
worker = _Worker(
self.rpc_router,
self.device_comm_channel,
self.encoder,
self.rpc_input_type_map,
self.rpc_return_type_map,
)
msg = "some_message"
expected_response = {
"__zerror__server_exception": "Exception('Exception occurred')"
}

with patch(
"zero.protocols.zeromq.worker.async_to_sync",
side_effect=Exception("Exception occurred"),
):
response = worker.handle_msg("some_function", msg)

self.assertEqual(response, expected_response)

def test_generate_rpc_contract(self):
worker = _Worker(
self.rpc_router,
self.device_comm_channel,
self.encoder,
self.rpc_input_type_map,
self.rpc_return_type_map,
)
msg = ["rpc_name", "msg_data"]
expected_response = b"generated_code"

with patch.object(
worker.codegen, "generate_code", return_value=expected_response
) as mock_generate_code:
response = worker.generate_rpc_contract(msg)

mock_generate_code.assert_called_once_with("rpc_name", "msg_data")
self.assertEqual(response, expected_response)

def test_generate_rpc_contract_exception_handling(self):
worker = _Worker(
self.rpc_router,
self.device_comm_channel,
self.encoder,
self.rpc_input_type_map,
self.rpc_return_type_map,
)

with patch.object(
worker.codegen, "generate_code", side_effect=Exception("Codegen Exception")
):
response = worker.generate_rpc_contract(["rpc_name", "msg_data"])
self.assertEqual(
response,
{"__zerror__failed_to_generate_client_code": "Codegen Exception"},
)


class TestWorkerSpawn(unittest.TestCase):
def test_spawn_worker(self):
mock_worker = MagicMock()

rpc_router = {
"get_rpc_contract": (Mock(), False),
"connect": (Mock(), False),
"some_function": (Mock(), True),
}
device_comm_channel = "tcp://example.com:5555"
encoder = Mock()
rpc_input_type_map = {}
rpc_return_type_map = {}
worker_id = 1

with patch("zero.protocols.zeromq.worker._Worker") as mock_worker_class:
mock_worker_class.return_value = mock_worker
_Worker.spawn_worker(
rpc_router,
device_comm_channel,
encoder,
rpc_input_type_map,
rpc_return_type_map,
worker_id,
)

mock_worker_class.assert_called_once_with(
rpc_router,
device_comm_channel,
encoder,
rpc_input_type_map,
rpc_return_type_map,
)
mock_worker.start_dealer_worker.assert_called_once_with(worker_id)
4 changes: 2 additions & 2 deletions zero/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .client_server.client import AsyncZeroClient, ZeroClient
from .client_server.server import ZeroServer
from .pubsub.publisher import ZeroPublisher
from .pubsub.subscriber import ZeroSubscriber
from .rpc.client import AsyncZeroClient, ZeroClient
from .rpc.server import ZeroServer

# no support for now -
# from .logger import AsyncLogger
Expand Down
12 changes: 11 additions & 1 deletion zero/config.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,21 @@
import logging

from zero.protocols.zeromq.client import AsyncZeroClient, ZeroClient
from zero.protocols.zeromq.server import ZeroServer

logging.basicConfig(
format="%(asctime)s %(levelname)8s %(process)8d %(module)s > %(message)s",
datefmt="%d-%b-%y %H:%M:%S",
level=logging.INFO,
)

RESERVED_FUNCTIONS = ["get_rpc_contract", "connect"]
RESERVED_FUNCTIONS = ["get_rpc_contract", "connect", "__server_info__"]
ZEROMQ_PATTERN = "proxy"
ENCODER = "msgspec"
SUPPORTED_PROTOCOLS = {
"zeromq": {
"server": ZeroServer,
"client": ZeroClient,
"async_client": AsyncZeroClient,
},
}
2 changes: 1 addition & 1 deletion zero/generate_client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import argparse
import os

from .client_server.client import ZeroClient
from .rpc.client import ZeroClient


def generate_client_code_and_save(host, port, directory, overwrite_dir=False):
Expand Down
File renamed without changes.
Empty file.
24 changes: 5 additions & 19 deletions zero/client_server/client.py → zero/protocols/zeromq/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from zero import config
from zero.encoder import Encoder, get_encoder
from zero.error import MethodNotFoundException, RemoteException, TimeoutException
from zero.error import TimeoutException
from zero.utils import util
from zero.zero_mq import AsyncZeroMQClient, ZeroMQClient, get_async_client, get_client

Expand All @@ -15,8 +15,7 @@
class ZeroClient:
def __init__(
self,
host: str,
port: int,
address: str,
default_timeout: int = 2000,
encoder: Optional[Encoder] = None,
):
Expand Down Expand Up @@ -49,7 +48,7 @@ def __init__(
If any other encoder is used, make sure the server should use the same encoder.
Implement custom encoder by inheriting from `zero.encoder.Encoder`.
"""
self._address = f"tcp://{host}:{port}"
self._address = address
self._default_timeout = default_timeout
self._encoder = encoder or get_encoder(config.ENCODER)

Expand Down Expand Up @@ -137,8 +136,6 @@ def _poll_data():
while resp_id != req_id:
resp_id, resp_data = _poll_data()

check_response(resp_data)

return resp_data # type: ignore

def close(self):
Expand All @@ -148,8 +145,7 @@ def close(self):
class AsyncZeroClient:
def __init__(
self,
host: str,
port: int,
address: str,
default_timeout: int = 2000,
encoder: Optional[Encoder] = None,
):
Expand Down Expand Up @@ -184,7 +180,7 @@ def __init__(
If any other encoder is used, the server should use the same encoder.
Implement custom encoder by inheriting from `zero.encoder.Encoder`.
"""
self._address = f"tcp://{host}:{port}"
self._address = address
self._default_timeout = default_timeout
self._encoder = encoder or get_encoder(config.ENCODER)
self._resp_map: Dict[str, Any] = {}
Expand Down Expand Up @@ -285,23 +281,13 @@ async def _poll_data():

resp_data = self._resp_map.pop(req_id)

check_response(resp_data)

return resp_data

def close(self):
self.client_pool.close()
self._resp_map = {}


def check_response(resp_data):
if isinstance(resp_data, dict):
if exc := resp_data.get("__zerror__function_not_found"):
raise MethodNotFoundException(exc)
if exc := resp_data.get("__zerror__server_exception"):
raise RemoteException(exc)


class ZeroMQClientPool:
"""
Connections are based on different threads and processes.
Expand Down
Loading

0 comments on commit fa98453

Please sign in to comment.