From ccc42a63c99997b093e0030de05dc10b12239b0c Mon Sep 17 00:00:00 2001 From: Eric Peterson Date: Wed, 27 Mar 2019 13:46:39 -0400 Subject: [PATCH] add option to allow exception propagation (#44) --- VERSION.txt | 2 +- rpcq/_server.py | 75 ++++++++++++++++++++++++++++++++++++++++--------- rpcq/_spec.py | 20 +++++++++---- 3 files changed, 76 insertions(+), 21 deletions(-) diff --git a/VERSION.txt b/VERSION.txt index 197c4d5..005119b 100644 --- a/VERSION.txt +++ b/VERSION.txt @@ -1 +1 @@ -2.4.0 +2.4.1 diff --git a/rpcq/_server.py b/rpcq/_server.py index 6e7524d..2b23d4f 100644 --- a/rpcq/_server.py +++ b/rpcq/_server.py @@ -35,15 +35,26 @@ class Server: """ Server that accepts JSON RPC calls through a socket. """ - def __init__(self, rpc_spec: RPCSpec = None, announce_timing: bool = False): + def __init__(self, rpc_spec: RPCSpec = None, announce_timing: bool = False, + serialize_exceptions: bool = True): """ Create a server that will be linked to a socket :param rpc_spec: JSON RPC spec + :param announce_timing: + :param serialize_exceptions: If set to True, this Server will catch all exceptions occurring + internally to it and, when possible, communicate them to the interrogating Client. If + set to False, this Server will re-raise any exceptions it encounters (including, but not + limited to, those which might occur through method calls to rpc_spec) for Server's + local owner to handle. + + IMPORTANT NOTE: When set to False, this *almost definitely* means an unrecoverable + crash, and the Server should then be _shutdown(). """ self.announce_timing = announce_timing + self.serialize_exceptions = serialize_exceptions - self.rpc_spec = rpc_spec if rpc_spec else RPCSpec() + self.rpc_spec = rpc_spec if rpc_spec else RPCSpec(serialize_exceptions=serialize_exceptions) self._exit_handlers = [] self._socket = None @@ -74,17 +85,49 @@ async def run_async(self, endpoint: str): """ self._connect(endpoint) - while True: - try: - # empty_frame may either be: - # 1. a single null frame if the client is a REQ socket - # 2. an empty list (ie. no frames) if the client is a DEALER socket - identity, *empty_frame, msg = await self._socket.recv_multipart() - request = from_msgpack(msg) + # spawn an initial listen task + listen_task = asyncio.ensure_future(self._socket.recv_multipart()) + task_list = [listen_task] - asyncio.ensure_future(self._process_request(identity, empty_frame, request)) - except Exception: - _log.exception('Exception thrown in Server run loop') + while True: + dones, pendings = await asyncio.wait(task_list, return_when=asyncio.FIRST_COMPLETED) + + # grab one "done" task to handle + task_list, done_list = list(pendings), list(dones) + done = done_list.pop() + task_list += done_list + + if done == listen_task: + try: + # empty_frame may either be: + # 1. a single null frame if the client is a REQ socket + # 2. an empty list (ie. no frames) if the client is a DEALER socket + identity, *empty_frame, msg = done.result() + request = from_msgpack(msg) + + # spawn a processing task + task_list.append(asyncio.ensure_future( + self._process_request(identity, empty_frame, request))) + except Exception as e: + if self.serialize_exceptions: + _log.exception('Exception thrown in Server run loop during request ' + 'reception: {}'.format(str(e))) + else: + raise e + finally: + # spawn a new listen task + listen_task = asyncio.ensure_future(self._socket.recv_multipart()) + task_list.append(listen_task) + else: + # if there's been an exception during processing, consider reraising it + try: + done.result() + except Exception as e: + if self.serialize_exceptions: + _log.exception('Exception thrown in Server run loop during request ' + 'dispatch: {}'.format(str(e))) + else: + raise e def run(self, endpoint: str, loop: AbstractEventLoop = None): """ @@ -151,5 +194,9 @@ async def _process_request(self, identity: bytes, empty_frame: list, request: RP _log.debug("Sending client %s reply: %s", identity, reply) await self._socket.send_multipart([identity, *empty_frame, to_msgpack(reply)]) - except Exception: - _log.exception('Exception thrown in _process_request') + except Exception as e: + if self.serialize_exceptions: + _log.exception('Exception thrown in _process_request') + else: + raise e + diff --git a/rpcq/_spec.py b/rpcq/_spec.py index 8596eda..1bd3276 100644 --- a/rpcq/_spec.py +++ b/rpcq/_spec.py @@ -31,7 +31,7 @@ class RPCSpec(object): """ Class for keeping track of class methods that are exposed to the JSON RPC interface """ - def __init__(self, *, provide_tracebacks: bool = True): + def __init__(self, *, provide_tracebacks: bool = True, serialize_exceptions: bool = True): """ Create a JsonRpcSpec object. @@ -61,9 +61,14 @@ def add(obj, *args): implementations will have their tracebacks forwarded to the calling client as part of the generated RPCError reply objject. If set to False, the generated RPCError reply will omit this information (but the traceback will still get written to the logfile). + :param serialize_exceptions: If set to True, unhandled exceptions which occur during RPC + call implementations will be serialized into RPCError messages (which the Server + instance will then probably send to the corresponding Client). If set to False, the + exception is re-raised and left for the local caller to handle further. """ self._json_rpc_methods = {} self.provide_tracebacks = provide_tracebacks + self.serialize_exceptions = serialize_exceptions def add_handler(self, f): """ @@ -114,11 +119,14 @@ async def run_handler(self, request: RPCRequest) -> Union[RPCReply, RPCError]: result = await result except Exception as e: - _traceback = traceback.format_exc() - _log.error(_traceback) - if self.provide_tracebacks: - return rpc_error(request.id, "{}\n{}".format(str(e), _traceback)) + if self.serialize_exceptions: + _traceback = traceback.format_exc() + _log.error(_traceback) + if self.provide_tracebacks: + return rpc_error(request.id, "{}\n{}".format(str(e), _traceback)) + else: + return rpc_error(request.id, str(e)) else: - return rpc_error(request.id, str(e)) + raise e return rpc_reply(request.id, result)