diff --git a/python/argo_client/connection.py b/python/argo_client/connection.py index c489eaf..f96f8ea 100644 --- a/python/argo_client/connection.py +++ b/python/argo_client/connection.py @@ -158,32 +158,22 @@ def __init__(self, command: str, *, self.persist = persist super().__init__(command, environment=environment) - - def buffer_replies(self) -> None: - """Read any replies that the server has sent, and add their byte - representation to the internal buffer, freeing up space in - the pipe or socket. - """ - try: - arrived = self.socket.recv(4096) - while arrived != b'': - self.buf.extend(arrived) - arrived = self.socket.recv(4096) - return None - except BlockingIOError: - return None - def get_one_reply(self) -> Optional[str]: - """If a complete reply has been buffered, parse it from the buffer and - return it as a bytestring.""" - self.buffer_replies() - try: - (msg, rest) = netstring.decode(self.buf) - self.buf = bytearray(rest) - self._log_rx(msg) - return msg - except (ValueError, IndexError): - return None + """Return the next message if there is one. Block until the message + is ready or the socket has closed. Return None if the socket closes + and there are no buffered messages ready.""" + while True: + got = netstring.decode(self.buf) + if got is None: + arrived = self.socket.recv(4096) + if arrived == '': + return None + self.buf.extend(arrived) + else: + (msg, rest) = got + self.buf = bytearray(rest) + self._log_rx(msg) + return msg def send_one_message(self, message: str, expecting_response : bool = True) -> None: self._log_tx(message) @@ -261,33 +251,23 @@ def __init__(self, host: str, port: int, ipv6: bool=True): def setup(self) -> None: self.socket = socket.socket(socket.AF_INET6 if self.ipv6 else socket.AF_INET, socket.SOCK_STREAM) self.socket.connect((self.host, self.port)) - self.socket.setblocking(False) - - def buffer_replies(self) -> None: - """Read any replies that the server has sent, and add their byte - representation to the internal buffer, freeing up space in - the pipe or socket. - """ - try: - arrived = self.socket.recv(4096) - while arrived != b'': - self.buf.extend(arrived) - arrived = self.socket.recv(4096) - return None - except BlockingIOError: - return None def get_one_reply(self) -> Optional[str]: - """If a complete reply has been buffered, parse it from the buffer and - return it as a bytestring.""" - self.buffer_replies() - try: - (msg, rest) = netstring.decode(self.buf) - self.buf = bytearray(rest) - self._log_rx(msg) - return msg - except (ValueError, IndexError): - return None + """Return the next message if there is one. Block until the message + is ready or the socket has closed. Return None if the socket closes + and there are no buffered messages ready.""" + while True: + got = netstring.decode(self.buf) + if got is None: + arrived = self.socket.recv(4096) + if arrived == '': + return None + self.buf.extend(arrived) + else: + (msg, rest) = got + self.buf = bytearray(rest) + self._log_rx(msg) + return msg def send_one_message(self, message: str, *, expecting_response : bool = True) -> None: self._log_tx(message) @@ -425,16 +405,6 @@ def get_id(self) -> int: this connection.""" return self.ids.get() - def _process_replies(self) -> None: - """Remove all pending replies from the internal buffer, parse them - into JSON, and add them to the internal collection of replies. - """ - reply_bytes = self.process.get_one_reply() - while reply_bytes is not None: - the_reply = json.loads(reply_bytes) - self.replies[the_reply['id']] = the_reply - reply_bytes = self.process.get_one_reply() - def send_command(self, method: str, params: dict, *, timeout : Optional[float] = None) -> int: """Send a message to the server with the given JSONRPC command method and parameters. The return value is the unique request @@ -478,11 +448,17 @@ def send_notification(self, method: str, params: dict) -> None: def wait_for_reply_to(self, request_id: int) -> Any: """Block until a reply is received for the given ``request_id``. Return the reply.""" - self._process_replies() - while request_id not in self.replies: - self._process_replies() - return self.replies[request_id] #self.replies.pop(request_id) # delete reply while returning it + if request_id in self.replies: + return self.replies.pop(request_id) + + while True: + reply_bytes = self.process.get_one_reply() + the_reply = json.loads(reply_bytes) + if the_reply['id'] == request_id: + return the_reply + else: + self.replies[the_reply['id']] = the_reply def logging(self, on : bool, *, dest : TextIO = sys.stderr) -> None: """Whether to log received and transmitted JSON.""" diff --git a/python/argo_client/netstring.py b/python/argo_client/netstring.py index 4641866..286c840 100644 --- a/python/argo_client/netstring.py +++ b/python/argo_client/netstring.py @@ -2,7 +2,7 @@ as a lightweight transport layer for JSON RPC. """ -from typing import Tuple +from typing import Optional, Tuple def encode(string : str) -> bytes: """Encode a ``str`` into a netstring. @@ -13,7 +13,13 @@ def encode(string : str) -> bytes: bytestring = string.encode() return str(len(bytestring)).encode() + b':' + bytestring + b',' -def decode(netstring : bytes) -> Tuple[str, bytes]: +class InvalidNetstring(Exception): + """Exception for malformed netstrings""" + def __init__(self, message): + self.message = message + super().__init__(self.message) + +def decode(netstring : bytes) -> Optional[Tuple[str, bytes]]: """Decode the first valid netstring from a bytestring, returning its string contents and the remainder of the bytestring. @@ -22,20 +28,26 @@ def decode(netstring : bytes) -> Tuple[str, bytes]: """ - i = 0 - length_bytes = bytearray(b'') - while chr(netstring[i]).isdigit(): - length_bytes.append(netstring[i]) - i += 1 - if chr(netstring[i]).encode() != b':': - raise ValueError("Malformed netstring, missing :") - length = int(length_bytes.decode()) - i += 1 - out = bytearray(b'') - for j in range(0, length): - out.append(netstring[i]) - i += 1 - if chr(netstring[i]).encode() != b',': - raise ValueError("Malformed netstring, missing ,") - i += 1 - return (out.decode(), netstring[i:]) + colon = netstring.find(b':') + if colon == -1 and len(netstring) >= 10 or colon >= 10: + # cut things off at about a gigabyte + raise InvalidNetstring("message length too long") + + if colon == -1: + # incomplete length, wait for more bytes + return None + + lengthstring = netstring[0:colon] + if colon == 0 or not lengthstring.isdigit(): + raise InvalidNetstring("invalid format, malformed message length") + + length = int(lengthstring) + comma = colon + length + 1 + if len(netstring) < comma: + # incomplete message, wait for more bytes + return None + + if netstring[comma] != 44: # comma + raise InvalidNetstring("invalid format, missing comma") + + return (netstring[colon + 1 : comma].decode(), netstring[comma+1:])