From 7adad97b5e4f99d54ab6029330840a8c27db98d2 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 11 Nov 2024 21:13:39 +0100 Subject: [PATCH] Support max_queue=None like the legacy implementation. Fix #1540. --- docs/project/changelog.rst | 7 ++++ src/websockets/asyncio/client.py | 7 ++-- src/websockets/asyncio/connection.py | 4 +- src/websockets/asyncio/messages.py | 23 ++++++++--- src/websockets/asyncio/server.py | 7 ++-- src/websockets/sync/client.py | 7 ++-- src/websockets/sync/connection.py | 4 +- src/websockets/sync/messages.py | 25 +++++++++--- src/websockets/sync/server.py | 7 ++-- tests/asyncio/test_connection.py | 12 +++++- tests/asyncio/test_messages.py | 61 +++++++++++++++++++++++++--- tests/sync/test_connection.py | 15 ++++++- tests/sync/test_messages.py | 59 +++++++++++++++++++++++++-- 13 files changed, 198 insertions(+), 40 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 074a81c8..8e1ad81f 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -34,6 +34,13 @@ notice. .. _14.0: +Improvements +............ + +* Supported ``max_queue=None`` in the :mod:`asyncio` and :mod:`threading` + implementations for consistency with the legacy implementation, even though + this is never a good idea. + Bug fixes ......... diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index 74ae70f0..cdd9bfac 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -60,7 +60,7 @@ def __init__( ping_interval: float | None = 20, ping_timeout: float | None = 20, close_timeout: float | None = 10, - max_queue: int | tuple[int, int | None] = 16, + max_queue: int | None | tuple[int | None, int | None] = 16, write_limit: int | tuple[int, int | None] = 2**15, ) -> None: self.protocol: ClientProtocol @@ -222,7 +222,8 @@ class connect: max_queue: High-water mark of the buffer where frames are received. It defaults to 16 frames. The low-water mark defaults to ``max_queue // 4``. You may pass a ``(high, low)`` tuple to set the high-water - and low-water marks. + and low-water marks. If you want to disable flow control entirely, + you may set it to ``None``, although that's a bad idea. write_limit: High-water mark of write buffer in bytes. It is passed to :meth:`~asyncio.WriteTransport.set_write_buffer_limits`. It defaults to 32 KiB. You may pass a ``(high, low)`` tuple to set the @@ -283,7 +284,7 @@ def __init__( close_timeout: float | None = 10, # Limits max_size: int | None = 2**20, - max_queue: int | tuple[int, int | None] = 16, + max_queue: int | None | tuple[int | None, int | None] = 16, write_limit: int | tuple[int, int | None] = 2**15, # Logging logger: LoggerLike | None = None, diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 186846ef..f1dcbada 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -56,14 +56,14 @@ def __init__( ping_interval: float | None = 20, ping_timeout: float | None = 20, close_timeout: float | None = 10, - max_queue: int | tuple[int, int | None] = 16, + max_queue: int | None | tuple[int | None, int | None] = 16, write_limit: int | tuple[int, int | None] = 2**15, ) -> None: self.protocol = protocol self.ping_interval = ping_interval self.ping_timeout = ping_timeout self.close_timeout = close_timeout - if isinstance(max_queue, int): + if isinstance(max_queue, int) or max_queue is None: max_queue = (max_queue, None) self.max_queue = max_queue if isinstance(write_limit, int): diff --git a/src/websockets/asyncio/messages.py b/src/websockets/asyncio/messages.py index 14ea7bf9..e6d1d31c 100644 --- a/src/websockets/asyncio/messages.py +++ b/src/websockets/asyncio/messages.py @@ -84,7 +84,7 @@ class Assembler: # coverage reports incorrectly: "line NN didn't jump to the function exit" def __init__( # pragma: no cover self, - high: int = 16, + high: int | None = None, low: int | None = None, pause: Callable[[], Any] = lambda: None, resume: Callable[[], Any] = lambda: None, @@ -96,12 +96,15 @@ def __init__( # pragma: no cover # call to Protocol.data_received() could produce thousands of frames, # which must be buffered. Instead, we pause reading when the buffer goes # above the high limit and we resume when it goes under the low limit. - if low is None: + if high is not None and low is None: low = high // 4 - if low < 0: - raise ValueError("low must be positive or equal to zero") - if high < low: - raise ValueError("high must be greater than or equal to low") + if high is None and low is not None: + high = low * 4 + if high is not None and low is not None: + if low < 0: + raise ValueError("low must be positive or equal to zero") + if high < low: + raise ValueError("high must be greater than or equal to low") self.high, self.low = high, low self.pause = pause self.resume = resume @@ -256,6 +259,10 @@ def put(self, frame: Frame) -> None: def maybe_pause(self) -> None: """Pause the writer if queue is above the high water mark.""" + # Skip if flow control is disabled + if self.high is None: + return + # Check for "> high" to support high = 0 if len(self.frames) > self.high and not self.paused: self.paused = True @@ -263,6 +270,10 @@ def maybe_pause(self) -> None: def maybe_resume(self) -> None: """Resume the writer if queue is below the low water mark.""" + # Skip if flow control is disabled + if self.low is None: + return + # Check for "<= low" to support low = 0 if len(self.frames) <= self.low and self.paused: self.paused = False diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index 15c9ba13..fdb92800 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -71,7 +71,7 @@ def __init__( ping_interval: float | None = 20, ping_timeout: float | None = 20, close_timeout: float | None = 10, - max_queue: int | tuple[int, int | None] = 16, + max_queue: int | None | tuple[int | None, int | None] = 16, write_limit: int | tuple[int, int | None] = 2**15, ) -> None: self.protocol: ServerProtocol @@ -643,7 +643,8 @@ def handler(websocket): max_queue: High-water mark of the buffer where frames are received. It defaults to 16 frames. The low-water mark defaults to ``max_queue // 4``. You may pass a ``(high, low)`` tuple to set the high-water - and low-water marks. + and low-water marks. If you want to disable flow control entirely, + you may set it to ``None``, although that's a bad idea. write_limit: High-water mark of write buffer in bytes. It is passed to :meth:`~asyncio.WriteTransport.set_write_buffer_limits`. It defaults to 32 KiB. You may pass a ``(high, low)`` tuple to set the @@ -713,7 +714,7 @@ def __init__( close_timeout: float | None = 10, # Limits max_size: int | None = 2**20, - max_queue: int | tuple[int, int | None] = 16, + max_queue: int | None | tuple[int | None, int | None] = 16, write_limit: int | tuple[int, int | None] = 2**15, # Logging logger: LoggerLike | None = None, diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index 54d0aef6..9e6da7ca 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -55,7 +55,7 @@ def __init__( protocol: ClientProtocol, *, close_timeout: float | None = 10, - max_queue: int | tuple[int, int | None] = 16, + max_queue: int | None | tuple[int | None, int | None] = 16, ) -> None: self.protocol: ClientProtocol self.response_rcvd = threading.Event() @@ -139,7 +139,7 @@ def connect( close_timeout: float | None = 10, # Limits max_size: int | None = 2**20, - max_queue: int | tuple[int, int | None] = 16, + max_queue: int | None | tuple[int | None, int | None] = 16, # Logging logger: LoggerLike | None = None, # Escape hatch for advanced customization @@ -191,7 +191,8 @@ def connect( max_queue: High-water mark of the buffer where frames are received. It defaults to 16 frames. The low-water mark defaults to ``max_queue // 4``. You may pass a ``(high, low)`` tuple to set the high-water - and low-water marks. + and low-water marks. If you want to disable flow control entirely, + you may set it to ``None``, although that's a bad idea. logger: Logger for this client. It defaults to ``logging.getLogger("websockets.client")``. See the :doc:`logging guide <../../topics/logging>` for details. diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 77f803c9..be3381c8 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -49,12 +49,12 @@ def __init__( protocol: Protocol, *, close_timeout: float | None = 10, - max_queue: int | tuple[int, int | None] = 16, + max_queue: int | None | tuple[int | None, int | None] = 16, ) -> None: self.socket = socket self.protocol = protocol self.close_timeout = close_timeout - if isinstance(max_queue, int): + if isinstance(max_queue, int) or max_queue is None: max_queue = (max_queue, None) self.max_queue = max_queue diff --git a/src/websockets/sync/messages.py b/src/websockets/sync/messages.py index af8635f1..98490797 100644 --- a/src/websockets/sync/messages.py +++ b/src/websockets/sync/messages.py @@ -33,7 +33,7 @@ class Assembler: def __init__( self, - high: int = 16, + high: int | None = None, low: int | None = None, pause: Callable[[], Any] = lambda: None, resume: Callable[[], Any] = lambda: None, @@ -49,12 +49,15 @@ def __init__( # call to Protocol.data_received() could produce thousands of frames, # which must be buffered. Instead, we pause reading when the buffer goes # above the high limit and we resume when it goes under the low limit. - if low is None: + if high is not None and low is None: low = high // 4 - if low < 0: - raise ValueError("low must be positive or equal to zero") - if high < low: - raise ValueError("high must be greater than or equal to low") + if high is None and low is not None: + high = low * 4 + if high is not None and low is not None: + if low < 0: + raise ValueError("low must be positive or equal to zero") + if high < low: + raise ValueError("high must be greater than or equal to low") self.high, self.low = high, low self.pause = pause self.resume = resume @@ -260,7 +263,12 @@ def put(self, frame: Frame) -> None: def maybe_pause(self) -> None: """Pause the writer if queue is above the high water mark.""" + # Skip if flow control is disabled + if self.high is None: + return + assert self.mutex.locked() + # Check for "> high" to support high = 0 if self.frames.qsize() > self.high and not self.paused: self.paused = True @@ -268,7 +276,12 @@ def maybe_pause(self) -> None: def maybe_resume(self) -> None: """Resume the writer if queue is below the low water mark.""" + # Skip if flow control is disabled + if self.low is None: + return + assert self.mutex.locked() + # Check for "<= low" to support low = 0 if self.frames.qsize() <= self.low and self.paused: self.paused = False diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index 8601ccef..9506d683 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -66,7 +66,7 @@ def __init__( protocol: ServerProtocol, *, close_timeout: float | None = 10, - max_queue: int | tuple[int, int | None] = 16, + max_queue: int | None | tuple[int | None, int | None] = 16, ) -> None: self.protocol: ServerProtocol self.request_rcvd = threading.Event() @@ -356,7 +356,7 @@ def serve( close_timeout: float | None = 10, # Limits max_size: int | None = 2**20, - max_queue: int | tuple[int, int | None] = 16, + max_queue: int | None | tuple[int | None, int | None] = 16, # Logging logger: LoggerLike | None = None, # Escape hatch for advanced customization @@ -438,7 +438,8 @@ def handler(websocket): max_queue: High-water mark of the buffer where frames are received. It defaults to 16 frames. The low-water mark defaults to ``max_queue // 4``. You may pass a ``(high, low)`` tuple to set the high-water - and low-water marks. + and low-water marks. If you want to disable flow control entirely, + you may set it to ``None``, although that's a bad idea. logger: Logger for this server. It defaults to ``logging.getLogger("websockets.server")``. See the :doc:`logging guide <../../topics/logging>` for details. diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index b1c57c8c..8dd0a033 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -1066,14 +1066,22 @@ async def test_close_timeout(self): self.assertEqual(connection.close_timeout, 42 * MS) async def test_max_queue(self): - """max_queue parameter configures high-water mark of frames buffer.""" + """max_queue configures high-water mark of frames buffer.""" connection = Connection(Protocol(self.LOCAL), max_queue=4) transport = Mock() connection.connection_made(transport) self.assertEqual(connection.recv_messages.high, 4) + async def test_max_queue_none(self): + """max_queue disables high-water mark of frames buffer.""" + connection = Connection(Protocol(self.LOCAL), max_queue=None) + transport = Mock() + connection.connection_made(transport) + self.assertEqual(connection.recv_messages.high, None) + self.assertEqual(connection.recv_messages.low, None) + async def test_max_queue_tuple(self): - """max_queue parameter configures high-water mark of frames buffer.""" + """max_queue configures high-water and low-water marks of frames buffer.""" connection = Connection( Protocol(self.LOCAL), max_queue=(4, 2), diff --git a/tests/asyncio/test_messages.py b/tests/asyncio/test_messages.py index 566f71ce..186e3506 100644 --- a/tests/asyncio/test_messages.py +++ b/tests/asyncio/test_messages.py @@ -153,7 +153,7 @@ async def test_get_decoded_binary_message(self): self.assertEqual(message, "tea") async def test_get_resumes_reading(self): - """get resumes reading when queue goes below the high-water mark.""" + """get resumes reading when queue goes below the low-water mark.""" self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) self.assembler.put(Frame(OP_TEXT, b"more caf\xc3\xa9")) self.assembler.put(Frame(OP_TEXT, b"water")) @@ -170,6 +170,19 @@ async def test_get_resumes_reading(self): await self.assembler.get() self.resume.assert_called_once_with() + async def test_get_does_not_resume_reading(self): + """get does not resume reading when the low-water mark is unset.""" + self.assembler.low = None + + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.put(Frame(OP_TEXT, b"more caf\xc3\xa9")) + self.assembler.put(Frame(OP_TEXT, b"water")) + await self.assembler.get() + await self.assembler.get() + await self.assembler.get() + + self.resume.assert_not_called() + async def test_cancel_get_before_first_frame(self): """get can be canceled safely before reading the first frame.""" getter_task = asyncio.create_task(self.assembler.get()) @@ -302,7 +315,7 @@ async def test_get_iter_decoded_binary_message(self): self.assertEqual(fragments, ["t", "e", "a"]) async def test_get_iter_resumes_reading(self): - """get_iter resumes reading when queue goes below the high-water mark.""" + """get_iter resumes reading when queue goes below the low-water mark.""" self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assembler.put(Frame(OP_CONT, b"a")) @@ -321,6 +334,20 @@ async def test_get_iter_resumes_reading(self): await anext(iterator) self.resume.assert_called_once_with() + async def test_get_iter_does_not_resume_reading(self): + """get_iter does not resume reading when the low-water mark is unset.""" + self.assembler.low = None + + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + iterator = aiter(self.assembler.get_iter()) + await anext(iterator) + await anext(iterator) + await anext(iterator) + + self.resume.assert_not_called() + async def test_cancel_get_iter_before_first_frame(self): """get_iter can be canceled safely before reading the first frame.""" getter_task = asyncio.create_task(alist(self.assembler.get_iter())) @@ -367,6 +394,17 @@ async def test_put_pauses_reading(self): self.assembler.put(Frame(OP_CONT, b"a")) self.pause.assert_called_once_with() + async def test_put_does_not_pause_reading(self): + """put does not pause reading when the high-water mark is unset.""" + self.assembler.high = None + + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + + self.pause.assert_not_called() + # Test termination async def test_get_fails_when_interrupted_by_close(self): @@ -495,16 +533,29 @@ async def test_get_iter_fails_when_get_iter_is_running(self): # Test setting limits async def test_set_high_water_mark(self): - """high sets the high-water mark.""" + """high sets the high-water and low-water marks.""" assembler = Assembler(high=10) self.assertEqual(assembler.high, 10) + self.assertEqual(assembler.low, 2) + + async def test_set_low_water_mark(self): + """low sets the low-water and high-water marks.""" + assembler = Assembler(low=5) + self.assertEqual(assembler.low, 5) + self.assertEqual(assembler.high, 20) - async def test_set_high_and_low_water_mark(self): - """high sets the high-water mark and low-water mark.""" + async def test_set_high_and_low_water_marks(self): + """high and low set the high-water and low-water marks.""" assembler = Assembler(high=10, low=5) self.assertEqual(assembler.high, 10) self.assertEqual(assembler.low, 5) + async def test_unset_high_and_low_water_marks(self): + """high and low unset the high-water and low-water marks.""" + assembler = Assembler() + self.assertEqual(assembler.high, None) + self.assertEqual(assembler.low, None) + async def test_set_invalid_high_water_mark(self): """high must be a non-negative integer.""" with self.assertRaises(ValueError): diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index 408b9697..0ea487c4 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -760,8 +760,21 @@ def test_max_queue(self): ) self.assertEqual(connection.recv_messages.high, 4) + def test_max_queue_none(self): + """max_queue parameter disables high-water mark of frames buffer.""" + socket_, remote_socket = socket.socketpair() + self.addCleanup(socket_.close) + self.addCleanup(remote_socket.close) + connection = Connection( + socket_, + Protocol(self.LOCAL), + max_queue=None, + ) + self.assertEqual(connection.recv_messages.high, None) + self.assertEqual(connection.recv_messages.high, None) + def test_max_queue_tuple(self): - """max_queue parameter configures high-water mark of frames buffer.""" + """max_queue configures high-water and low-water marks of frames buffer.""" socket_, remote_socket = socket.socketpair() self.addCleanup(socket_.close) self.addCleanup(remote_socket.close) diff --git a/tests/sync/test_messages.py b/tests/sync/test_messages.py index 9ebe4508..c84d3000 100644 --- a/tests/sync/test_messages.py +++ b/tests/sync/test_messages.py @@ -145,7 +145,7 @@ def test_get_decoded_binary_message(self): self.assertEqual(message, "tea") def test_get_resumes_reading(self): - """get resumes reading when queue goes below the high-water mark.""" + """get resumes reading when queue goes below the low-water mark.""" self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) self.assembler.put(Frame(OP_TEXT, b"more caf\xc3\xa9")) self.assembler.put(Frame(OP_TEXT, b"water")) @@ -162,6 +162,19 @@ def test_get_resumes_reading(self): self.assembler.get() self.resume.assert_called_once_with() + def test_get_does_not_resume_reading(self): + """get does not resume reading when the low-water mark is unset.""" + self.assembler.low = None + + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.put(Frame(OP_TEXT, b"more caf\xc3\xa9")) + self.assembler.put(Frame(OP_TEXT, b"water")) + self.assembler.get() + self.assembler.get() + self.assembler.get() + + self.resume.assert_not_called() + def test_get_timeout_before_first_frame(self): """get times out before reading the first frame.""" with self.assertRaises(TimeoutError): @@ -319,6 +332,20 @@ def test_get_iter_resumes_reading(self): next(iterator) self.resume.assert_called_once_with() + def test_get_iter_does_not_resume_reading(self): + """get_iter does not resume reading when the low-water mark is unset.""" + self.assembler.low = None + + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + iterator = self.assembler.get_iter() + next(iterator) + next(iterator) + next(iterator) + + self.resume.assert_not_called() + # Test put def test_put_pauses_reading(self): @@ -336,6 +363,17 @@ def test_put_pauses_reading(self): self.assembler.put(Frame(OP_CONT, b"a")) self.pause.assert_called_once_with() + def test_put_does_not_pause_reading(self): + """put does not pause reading when the high-water mark is unset.""" + self.assembler.high = None + + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + + self.pause.assert_not_called() + # Test termination def test_get_fails_when_interrupted_by_close(self): @@ -470,16 +508,29 @@ def test_get_iter_fails_when_get_iter_is_running(self): # Test setting limits def test_set_high_water_mark(self): - """high sets the high-water mark.""" + """high sets the high-water and low-water marks.""" assembler = Assembler(high=10) self.assertEqual(assembler.high, 10) + self.assertEqual(assembler.low, 2) + + def test_set_low_water_mark(self): + """low sets the low-water and high-water marks.""" + assembler = Assembler(low=5) + self.assertEqual(assembler.low, 5) + self.assertEqual(assembler.high, 20) - def test_set_high_and_low_water_mark(self): - """high sets the high-water mark and low-water mark.""" + def test_set_high_and_low_water_marks(self): + """high and low set the high-water and low-water marks.""" assembler = Assembler(high=10, low=5) self.assertEqual(assembler.high, 10) self.assertEqual(assembler.low, 5) + def test_unset_high_and_low_water_marks(self): + """high and low unset the high-water and low-water marks.""" + assembler = Assembler() + self.assertEqual(assembler.high, None) + self.assertEqual(assembler.low, None) + def test_set_invalid_high_water_mark(self): """high must be a non-negative integer.""" with self.assertRaises(ValueError):