diff --git a/tests/functional/multiple_servers/server1.py b/tests/functional/multiple_servers/server1.py index 0458fc7..59e1236 100644 --- a/tests/functional/multiple_servers/server1.py +++ b/tests/functional/multiple_servers/server1.py @@ -11,7 +11,7 @@ def hello() -> str: def run(port): print("Starting server1 on port", port) - app = ZeroServer(port=port) + app = ZeroServer(port=port, use_threads=True) app.register_rpc(echo) app.register_rpc(hello) app.run(2) diff --git a/tests/functional/single_server/client_test.py b/tests/functional/single_server/client_test.py index 462f07a..7158d5f 100644 --- a/tests/functional/single_server/client_test.py +++ b/tests/functional/single_server/client_test.py @@ -5,6 +5,7 @@ import pytest import zero.error +from tests.functional.single_server import threaded_server from zero import AsyncZeroClient, ZeroClient from . import server @@ -221,3 +222,8 @@ def test_random_timeout_async(): # time_taken_ms = (time.perf_counter() - start) * 1000 # assert time_taken_ms < 1000 + + +def test_threaded_server_hello_world(): + client = ZeroClient(threaded_server.HOST, threaded_server.PORT) + assert client.call("hello_world", "") == "hello world" diff --git a/tests/functional/single_server/conftest.py b/tests/functional/single_server/conftest.py index 1437cbe..9af5abc 100644 --- a/tests/functional/single_server/conftest.py +++ b/tests/functional/single_server/conftest.py @@ -15,3 +15,10 @@ def base_server(): process = start_subprocess("tests.functional.single_server.server") yield kill_subprocess(process) + + +@pytest.fixture(autouse=True, scope="session") +def threaded_server(): + process = start_subprocess("tests.functional.single_server.threaded_server") + yield + kill_subprocess(process) diff --git a/tests/functional/single_server/threaded_server.py b/tests/functional/single_server/threaded_server.py new file mode 100644 index 0000000..c76ad4e --- /dev/null +++ b/tests/functional/single_server/threaded_server.py @@ -0,0 +1,16 @@ +from zero.rpc.server import ZeroServer + +PORT = 7777 +HOST = "localhost" + + +app = ZeroServer(port=PORT, use_threads=True) + + +@app.register_rpc +async def hello_world() -> str: + return "hello world" + + +if __name__ == "__main__": + app.run(2) diff --git a/tests/utils.py b/tests/utils.py index 7d226dc..4696e3d 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -35,6 +35,7 @@ def _ping(port: int) -> bool: def kill_process(process: Process): process.terminate() + process.kill() _wait_for_process_to_die(process) process.join() diff --git a/zero/protocols/zeromq/server.py b/zero/protocols/zeromq/server.py index a47dd22..6099e38 100644 --- a/zero/protocols/zeromq/server.py +++ b/zero/protocols/zeromq/server.py @@ -3,7 +3,7 @@ import signal import sys from functools import partial -from multiprocessing.pool import Pool +from multiprocessing.pool import Pool, ThreadPool from typing import Callable, Dict, Optional, Tuple import zmq.utils.win32 @@ -26,6 +26,7 @@ def __init__( rpc_input_type_map: Dict[str, Optional[type]], rpc_return_type_map: Dict[str, Optional[type]], encoder: Encoder, + use_threads: bool, ): self._broker: ZeroMQBroker = None # type: ignore self._device_comm_channel: str = None # type: ignore @@ -37,6 +38,7 @@ def __init__( self._rpc_input_type_map = rpc_input_type_map self._rpc_return_type_map = rpc_return_type_map self._encoder = encoder + self._use_threads = use_threads def start(self, workers: int = os.cpu_count() or 1): """ @@ -67,7 +69,10 @@ def start(self, workers: int = os.cpu_count() or 1): self._start_server(workers, spawn_worker) def _start_server(self, workers: int, spawn_worker: Callable[[int], None]): - self._pool = Pool(workers) + if self._use_threads: + self._pool = ThreadPool(workers) + else: + self._pool = Pool(workers) # process termination signals util.register_signal_term(self._sig_handler) @@ -113,4 +118,4 @@ def _remove_ipc(self): def _terminate_pool(self): self._pool.terminate() self._pool.close() - self._pool.join() + # self._pool.join() diff --git a/zero/rpc/protocols.py b/zero/rpc/protocols.py index f335d1b..250ca3b 100644 --- a/zero/rpc/protocols.py +++ b/zero/rpc/protocols.py @@ -24,6 +24,7 @@ def __init__( rpc_input_type_map: Dict[str, Optional[type]], rpc_return_type_map: Dict[str, Optional[type]], encoder: Encoder, + use_threads: bool, ): ... diff --git a/zero/rpc/server.py b/zero/rpc/server.py index 4f3eba6..0553894 100644 --- a/zero/rpc/server.py +++ b/zero/rpc/server.py @@ -31,6 +31,7 @@ def __init__( port: int = 5559, encoder: Optional[Encoder] = None, protocol: str = "zeromq", + use_threads: bool = False, ): """ ZeroServer registers and exposes rpc functions that can be called from a ZeroClient. @@ -41,21 +42,30 @@ def __init__( ---------- host: str Host of the ZeroServer. + port: int Port of the ZeroServer. + encoder: Optional[Encoder] Encoder to encode/decode messages from/to client. Default is msgspec. If any other encoder is used, the client should use the same encoder. Implement custom encoder by inheriting from `zero.encoder.Encoder`. + protocol: str Protocol to use for communication. Default is zeromq. If any other protocol is used, the client should use the same protocol. + + use_threads: bool + Use threads instead of processes. + By default it uses processes. + If True, the server uses threads instead of processes. GIL will be watching you! """ self._host = host self._port = port self._address = f"tcp://{self._host}:{self._port}" + self._use_threads = use_threads # to encode/decode messages from/to client if encoder and not isinstance(encoder, Encoder): @@ -78,6 +88,7 @@ def __init__( self._rpc_input_type_map, self._rpc_return_type_map, self._encoder, + self._use_threads, ) def _determine_server_cls(self, protocol: str) -> Type["ZeroServerProtocol"]: