Skip to content

Commit

Permalink
Support thread workers (#60)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ananto30 authored Jul 4, 2024
1 parent c68b307 commit 0691f88
Show file tree
Hide file tree
Showing 8 changed files with 51 additions and 4 deletions.
2 changes: 1 addition & 1 deletion tests/functional/multiple_servers/server1.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def hello() -> str:

def run(port):
print("Starting server1 on port", port)
app = ZeroServer(port=port)
app = ZeroServer(port=port, use_threads=True)
app.register_rpc(echo)
app.register_rpc(hello)
app.run(2)
6 changes: 6 additions & 0 deletions tests/functional/single_server/client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest

import zero.error
from tests.functional.single_server import threaded_server
from zero import AsyncZeroClient, ZeroClient

from . import server
Expand Down Expand Up @@ -221,3 +222,8 @@ def test_random_timeout_async():
# time_taken_ms = (time.perf_counter() - start) * 1000

# assert time_taken_ms < 1000


def test_threaded_server_hello_world():
client = ZeroClient(threaded_server.HOST, threaded_server.PORT)
assert client.call("hello_world", "") == "hello world"
7 changes: 7 additions & 0 deletions tests/functional/single_server/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,10 @@ def base_server():
process = start_subprocess("tests.functional.single_server.server")
yield
kill_subprocess(process)


@pytest.fixture(autouse=True, scope="session")
def threaded_server():
process = start_subprocess("tests.functional.single_server.threaded_server")
yield
kill_subprocess(process)
16 changes: 16 additions & 0 deletions tests/functional/single_server/threaded_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from zero.rpc.server import ZeroServer

PORT = 7777
HOST = "localhost"


app = ZeroServer(port=PORT, use_threads=True)


@app.register_rpc
async def hello_world() -> str:
return "hello world"


if __name__ == "__main__":
app.run(2)
1 change: 1 addition & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def _ping(port: int) -> bool:

def kill_process(process: Process):
process.terminate()
process.kill()
_wait_for_process_to_die(process)
process.join()

Expand Down
11 changes: 8 additions & 3 deletions zero/protocols/zeromq/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import signal
import sys
from functools import partial
from multiprocessing.pool import Pool
from multiprocessing.pool import Pool, ThreadPool
from typing import Callable, Dict, Optional, Tuple

import zmq.utils.win32
Expand All @@ -26,6 +26,7 @@ def __init__(
rpc_input_type_map: Dict[str, Optional[type]],
rpc_return_type_map: Dict[str, Optional[type]],
encoder: Encoder,
use_threads: bool,
):
self._broker: ZeroMQBroker = None # type: ignore
self._device_comm_channel: str = None # type: ignore
Expand All @@ -37,6 +38,7 @@ def __init__(
self._rpc_input_type_map = rpc_input_type_map
self._rpc_return_type_map = rpc_return_type_map
self._encoder = encoder
self._use_threads = use_threads

def start(self, workers: int = os.cpu_count() or 1):
"""
Expand Down Expand Up @@ -67,7 +69,10 @@ def start(self, workers: int = os.cpu_count() or 1):
self._start_server(workers, spawn_worker)

def _start_server(self, workers: int, spawn_worker: Callable[[int], None]):
self._pool = Pool(workers)
if self._use_threads:
self._pool = ThreadPool(workers)
else:
self._pool = Pool(workers)

# process termination signals
util.register_signal_term(self._sig_handler)
Expand Down Expand Up @@ -113,4 +118,4 @@ def _remove_ipc(self):
def _terminate_pool(self):
self._pool.terminate()
self._pool.close()
self._pool.join()
# self._pool.join()
1 change: 1 addition & 0 deletions zero/rpc/protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def __init__(
rpc_input_type_map: Dict[str, Optional[type]],
rpc_return_type_map: Dict[str, Optional[type]],
encoder: Encoder,
use_threads: bool,
):
...

Expand Down
11 changes: 11 additions & 0 deletions zero/rpc/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def __init__(
port: int = 5559,
encoder: Optional[Encoder] = None,
protocol: str = "zeromq",
use_threads: bool = False,
):
"""
ZeroServer registers and exposes rpc functions that can be called from a ZeroClient.
Expand All @@ -41,21 +42,30 @@ def __init__(
----------
host: str
Host of the ZeroServer.
port: int
Port of the ZeroServer.
encoder: Optional[Encoder]
Encoder to encode/decode messages from/to client.
Default is msgspec.
If any other encoder is used, the client should use the same encoder.
Implement custom encoder by inheriting from `zero.encoder.Encoder`.
protocol: str
Protocol to use for communication.
Default is zeromq.
If any other protocol is used, the client should use the same protocol.
use_threads: bool
Use threads instead of processes.
By default it uses processes.
If True, the server uses threads instead of processes. GIL will be watching you!
"""
self._host = host
self._port = port
self._address = f"tcp://{self._host}:{self._port}"
self._use_threads = use_threads

# to encode/decode messages from/to client
if encoder and not isinstance(encoder, Encoder):
Expand All @@ -78,6 +88,7 @@ def __init__(
self._rpc_input_type_map,
self._rpc_return_type_map,
self._encoder,
self._use_threads,
)

def _determine_server_cls(self, protocol: str) -> Type["ZeroServerProtocol"]:
Expand Down

0 comments on commit 0691f88

Please sign in to comment.