Skip to content

Commit

Permalink
Using blocking network IO to avoid busy-wait
Browse files Browse the repository at this point in the history
This change also makes the netstring parser more robust to support the way that it's used in the network loop
  • Loading branch information
glguy committed Sep 4, 2024
1 parent 8da9e41 commit dd4ed90
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 83 deletions.
109 changes: 45 additions & 64 deletions python/argo_client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,32 +158,22 @@ def __init__(self, command: str, *,
self.persist = persist
super().__init__(command, environment=environment)


def buffer_replies(self) -> None:
"""Read any replies that the server has sent, and add their byte
representation to the internal buffer, freeing up space in
the pipe or socket.
"""
try:
arrived = self.socket.recv(4096)
while arrived != b'':
self.buf.extend(arrived)
arrived = self.socket.recv(4096)
return None
except BlockingIOError:
return None

def get_one_reply(self) -> Optional[str]:
"""If a complete reply has been buffered, parse it from the buffer and
return it as a bytestring."""
self.buffer_replies()
try:
(msg, rest) = netstring.decode(self.buf)
self.buf = bytearray(rest)
self._log_rx(msg)
return msg
except (ValueError, IndexError):
return None
"""Return the next message if there is one. Block until the message
is ready or the socket has closed. Return None if the socket closes
and there are no buffered messages ready."""
while True:
got = netstring.decode(self.buf)
if got is None:
arrived = self.socket.recv(4096)
if arrived == '':
return None
self.buf.extend(arrived)
else:
(msg, rest) = got
self.buf = bytearray(rest)
self._log_rx(msg)
return msg

def send_one_message(self, message: str, expecting_response : bool = True) -> None:
self._log_tx(message)
Expand Down Expand Up @@ -261,33 +251,23 @@ def __init__(self, host: str, port: int, ipv6: bool=True):
def setup(self) -> None:
self.socket = socket.socket(socket.AF_INET6 if self.ipv6 else socket.AF_INET, socket.SOCK_STREAM)
self.socket.connect((self.host, self.port))
self.socket.setblocking(False)

def buffer_replies(self) -> None:
"""Read any replies that the server has sent, and add their byte
representation to the internal buffer, freeing up space in
the pipe or socket.
"""
try:
arrived = self.socket.recv(4096)
while arrived != b'':
self.buf.extend(arrived)
arrived = self.socket.recv(4096)
return None
except BlockingIOError:
return None

def get_one_reply(self) -> Optional[str]:
"""If a complete reply has been buffered, parse it from the buffer and
return it as a bytestring."""
self.buffer_replies()
try:
(msg, rest) = netstring.decode(self.buf)
self.buf = bytearray(rest)
self._log_rx(msg)
return msg
except (ValueError, IndexError):
return None
"""Return the next message if there is one. Block until the message
is ready or the socket has closed. Return None if the socket closes
and there are no buffered messages ready."""
while True:
got = netstring.decode(self.buf)
if got is None:
arrived = self.socket.recv(4096)
if arrived == '':
return None
self.buf.extend(arrived)
else:
(msg, rest) = got
self.buf = bytearray(rest)
self._log_rx(msg)
return msg

def send_one_message(self, message: str, *, expecting_response : bool = True) -> None:
self._log_tx(message)
Expand Down Expand Up @@ -425,16 +405,6 @@ def get_id(self) -> int:
this connection."""
return self.ids.get()

def _process_replies(self) -> None:
"""Remove all pending replies from the internal buffer, parse them
into JSON, and add them to the internal collection of replies.
"""
reply_bytes = self.process.get_one_reply()
while reply_bytes is not None:
the_reply = json.loads(reply_bytes)
self.replies[the_reply['id']] = the_reply
reply_bytes = self.process.get_one_reply()

def send_command(self, method: str, params: dict, *, timeout : Optional[float] = None) -> int:
"""Send a message to the server with the given JSONRPC command
method and parameters. The return value is the unique request
Expand Down Expand Up @@ -478,11 +448,22 @@ def send_notification(self, method: str, params: dict) -> None:
def wait_for_reply_to(self, request_id: int) -> Any:
"""Block until a reply is received for the given
``request_id``. Return the reply."""
self._process_replies()
while request_id not in self.replies:
self._process_replies()

return self.replies[request_id] #self.replies.pop(request_id) # delete reply while returning it
if request_id in self.replies:
return self.replies.pop(request_id)

while True:
reply_bytes = self.process.get_one_reply()

# The connection has closed, no more replies will be received
if reply_bytes is None:
raise ValueError("connection closed before reply found")

the_reply = json.loads(reply_bytes)
if the_reply['id'] == request_id:
return the_reply
else:
self.replies[the_reply['id']] = the_reply

def logging(self, on : bool, *, dest : TextIO = sys.stderr) -> None:
"""Whether to log received and transmitted JSON."""
Expand Down
50 changes: 31 additions & 19 deletions python/argo_client/netstring.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
as a lightweight transport layer for JSON RPC.
"""

from typing import Tuple
from typing import Optional, Tuple

def encode(string : str) -> bytes:
"""Encode a ``str`` into a netstring.
Expand All @@ -13,7 +13,13 @@ def encode(string : str) -> bytes:
bytestring = string.encode()
return str(len(bytestring)).encode() + b':' + bytestring + b','

def decode(netstring : bytes) -> Tuple[str, bytes]:
class InvalidNetstring(Exception):
"""Exception for malformed netstrings"""
def __init__(self, message: str):
self.message = message
super().__init__(self.message)

def decode(netstring : bytes) -> Optional[Tuple[str, bytes]]:
"""Decode the first valid netstring from a bytestring, returning its
string contents and the remainder of the bytestring.
Expand All @@ -22,20 +28,26 @@ def decode(netstring : bytes) -> Tuple[str, bytes]:
"""

i = 0
length_bytes = bytearray(b'')
while chr(netstring[i]).isdigit():
length_bytes.append(netstring[i])
i += 1
if chr(netstring[i]).encode() != b':':
raise ValueError("Malformed netstring, missing :")
length = int(length_bytes.decode())
i += 1
out = bytearray(b'')
for j in range(0, length):
out.append(netstring[i])
i += 1
if chr(netstring[i]).encode() != b',':
raise ValueError("Malformed netstring, missing ,")
i += 1
return (out.decode(), netstring[i:])
colon = netstring.find(b':')
if colon == -1 and len(netstring) >= 10 or colon >= 10:
# cut things off at about a gigabyte
raise InvalidNetstring("message length too long")

if colon == -1:
# incomplete length, wait for more bytes
return None

lengthstring = netstring[0:colon]
if colon == 0 or not lengthstring.isdigit():
raise InvalidNetstring("invalid format, malformed message length")

length = int(lengthstring)
comma = colon + length + 1
if len(netstring) <= comma:
# incomplete message, wait for more bytes
return None

if netstring[comma] != 44: # comma
raise InvalidNetstring("invalid format, missing comma")

return (netstring[colon + 1 : comma].decode(), netstring[comma+1:])

0 comments on commit dd4ed90

Please sign in to comment.