diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 4650cf16..0b36202e 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -297,6 +297,8 @@ def send_continuation(self, data: bytes, fin: bool) -> None: """ if not self.expect_continuation_frame: raise ProtocolError("unexpected continuation frame") + if self._state is not OPEN: + raise InvalidState(f"connection is {self.state.name.lower()}") self.expect_continuation_frame = not fin self.send_frame(Frame(OP_CONT, data, fin)) @@ -318,6 +320,8 @@ def send_text(self, data: bytes, fin: bool = True) -> None: """ if self.expect_continuation_frame: raise ProtocolError("expected a continuation frame") + if self._state is not OPEN: + raise InvalidState(f"connection is {self.state.name.lower()}") self.expect_continuation_frame = not fin self.send_frame(Frame(OP_TEXT, data, fin)) @@ -339,6 +343,8 @@ def send_binary(self, data: bytes, fin: bool = True) -> None: """ if self.expect_continuation_frame: raise ProtocolError("expected a continuation frame") + if self._state is not OPEN: + raise InvalidState(f"connection is {self.state.name.lower()}") self.expect_continuation_frame = not fin self.send_frame(Frame(OP_BINARY, data, fin)) @@ -358,6 +364,10 @@ def send_close(self, code: Optional[int] = None, reason: str = "") -> None: without a code. """ + # While RFC 6455 doesn't rule out sending more than one close Frame, + # websockets is conservative in what it sends and doesn't allow that. + if self._state is not OPEN: + raise InvalidState(f"connection is {self.state.name.lower()}") if code is None: if reason != "": raise ProtocolError("cannot send a reason without a code") @@ -383,6 +393,9 @@ def send_ping(self, data: bytes) -> None: data: payload containing arbitrary binary data. """ + # RFC 6455 allows control frames after starting the closing handshake. + if self._state is not OPEN and self._state is not CLOSING: + raise InvalidState(f"connection is {self.state.name.lower()}") self.send_frame(Frame(OP_PING, data)) def send_pong(self, data: bytes) -> None: @@ -396,6 +409,9 @@ def send_pong(self, data: bytes) -> None: data: payload containing arbitrary binary data. """ + # RFC 6455 allows control frames after starting the closing handshake. + if self._state is not OPEN and self._state is not CLOSING: + raise InvalidState(f"connection is {self.state.name.lower()}") self.send_frame(Frame(OP_PONG, data)) def fail(self, code: int, reason: str = "") -> None: @@ -675,6 +691,8 @@ def recv_frame(self, frame: Frame) -> None: # 1.4. Closing Handshake: "after receiving a control frame # indicating the connection should be closed, a peer discards # any further data received." + # RFC 6455 allows reading Ping and Pong frames after a Close frame. + # However, that doesn't seem useful; websockets doesn't support it. self.parser = self.discard() next(self.parser) # start coroutine @@ -687,15 +705,13 @@ def recv_frame(self, frame: Frame) -> None: # Private methods for sending events. def send_frame(self, frame: Frame) -> None: - if self.state is not OPEN: - raise InvalidState( - f"cannot write to a WebSocket in the {self.state.name} state" - ) - if self.debug: self.logger.debug("> %s", frame) self.writes.append( - frame.serialize(mask=self.side is CLIENT, extensions=self.extensions) + frame.serialize( + mask=self.side is CLIENT, + extensions=self.extensions, + ) ) def send_eof(self) -> None: diff --git a/tests/test_protocol.py b/tests/test_protocol.py index b53c8a1e..1d5dab7a 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -465,15 +465,17 @@ def test_client_sends_text_after_sending_close(self): with self.enforce_mask(b"\x00\x00\x00\x00"): client.send_close(CloseCode.GOING_AWAY) self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) - with self.assertRaises(InvalidState): + with self.assertRaises(InvalidState) as raised: client.send_text(b"") + self.assertEqual(str(raised.exception), "connection is closing") def test_server_sends_text_after_sending_close(self): server = Protocol(SERVER) server.send_close(CloseCode.NORMAL_CLOSURE) self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) - with self.assertRaises(InvalidState): + with self.assertRaises(InvalidState) as raised: server.send_text(b"") + self.assertEqual(str(raised.exception), "connection is closing") def test_client_receives_text_after_receiving_close(self): client = Protocol(CLIENT) @@ -679,15 +681,17 @@ def test_client_sends_binary_after_sending_close(self): with self.enforce_mask(b"\x00\x00\x00\x00"): client.send_close(CloseCode.GOING_AWAY) self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) - with self.assertRaises(InvalidState): + with self.assertRaises(InvalidState) as raised: client.send_binary(b"") + self.assertEqual(str(raised.exception), "connection is closing") def test_server_sends_binary_after_sending_close(self): server = Protocol(SERVER) server.send_close(CloseCode.NORMAL_CLOSURE) self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) - with self.assertRaises(InvalidState): + with self.assertRaises(InvalidState) as raised: server.send_binary(b"") + self.assertEqual(str(raised.exception), "connection is closing") def test_client_receives_binary_after_receiving_close(self): client = Protocol(CLIENT) @@ -956,6 +960,37 @@ def test_server_receives_close_with_non_utf8_reason(self): ) self.assertIs(server.state, CLOSING) + def test_client_sends_close_twice(self): + client = Protocol(CLIENT) + with self.enforce_mask(b"\x00\x00\x00\x00"): + client.send_close(CloseCode.GOING_AWAY) + self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) + with self.assertRaises(InvalidState) as raised: + client.send_close(CloseCode.GOING_AWAY) + self.assertEqual(str(raised.exception), "connection is closing") + + def test_server_sends_close_twice(self): + server = Protocol(SERVER) + server.send_close(CloseCode.NORMAL_CLOSURE) + self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) + with self.assertRaises(InvalidState) as raised: + server.send_close(CloseCode.NORMAL_CLOSURE) + self.assertEqual(str(raised.exception), "connection is closing") + + def test_client_sends_close_after_connection_is_closed(self): + client = Protocol(CLIENT) + client.receive_eof() + with self.assertRaises(InvalidState) as raised: + client.send_close(CloseCode.GOING_AWAY) + self.assertEqual(str(raised.exception), "connection is closed") + + def test_server_sends_close_after_connection_is_closed(self): + server = Protocol(SERVER) + server.receive_eof() + with self.assertRaises(InvalidState) as raised: + server.send_close(CloseCode.NORMAL_CLOSURE) + self.assertEqual(str(raised.exception), "connection is closed") + class PingTests(ProtocolTestCase): """ @@ -1072,35 +1107,23 @@ def test_client_sends_ping_after_sending_close(self): with self.enforce_mask(b"\x00\x00\x00\x00"): client.send_close(CloseCode.GOING_AWAY) self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) - # The spec says: "An endpoint MAY send a Ping frame any time (...) - # before the connection is closed" but websockets doesn't support - # sending a Ping frame after a Close frame. - with self.assertRaises(InvalidState) as raised: + with self.enforce_mask(b"\x00\x44\x88\xcc"): client.send_ping(b"") - self.assertEqual( - str(raised.exception), - "cannot write to a WebSocket in the CLOSING state", - ) + self.assertEqual(client.data_to_send(), [b"\x89\x80\x00\x44\x88\xcc"]) def test_server_sends_ping_after_sending_close(self): server = Protocol(SERVER) server.send_close(CloseCode.NORMAL_CLOSURE) self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) - # The spec says: "An endpoint MAY send a Ping frame any time (...) - # before the connection is closed" but websockets doesn't support - # sending a Ping frame after a Close frame. - with self.assertRaises(InvalidState) as raised: - server.send_ping(b"") - self.assertEqual( - str(raised.exception), - "cannot write to a WebSocket in the CLOSING state", - ) + server.send_ping(b"") + self.assertEqual(server.data_to_send(), [b"\x89\x00"]) def test_client_receives_ping_after_receiving_close(self): client = Protocol(CLIENT) client.receive_data(b"\x88\x02\x03\xe8") self.assertConnectionClosing(client, CloseCode.NORMAL_CLOSURE) client.receive_data(b"\x89\x04\x22\x66\xaa\xee") + # websockets ignores control frames after a close frame. self.assertFrameReceived(client, None) self.assertFrameSent(client, None) @@ -1109,9 +1132,24 @@ def test_server_receives_ping_after_receiving_close(self): server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") self.assertConnectionClosing(server, CloseCode.GOING_AWAY) server.receive_data(b"\x89\x84\x00\x44\x88\xcc\x22\x22\x22\x22") + # websockets ignores control frames after a close frame. self.assertFrameReceived(server, None) self.assertFrameSent(server, None) + def test_client_sends_ping_after_connection_is_closed(self): + client = Protocol(CLIENT) + client.receive_eof() + with self.assertRaises(InvalidState) as raised: + client.send_ping(b"") + self.assertEqual(str(raised.exception), "connection is closed") + + def test_server_sends_ping_after_connection_is_closed(self): + server = Protocol(SERVER) + server.receive_eof() + with self.assertRaises(InvalidState) as raised: + server.send_ping(b"") + self.assertEqual(str(raised.exception), "connection is closed") + class PongTests(ProtocolTestCase): """ @@ -1212,23 +1250,23 @@ def test_client_sends_pong_after_sending_close(self): with self.enforce_mask(b"\x00\x00\x00\x00"): client.send_close(CloseCode.GOING_AWAY) self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) - # websockets doesn't support sending a Pong frame after a Close frame. - with self.assertRaises(InvalidState): + with self.enforce_mask(b"\x00\x44\x88\xcc"): client.send_pong(b"") + self.assertEqual(client.data_to_send(), [b"\x8a\x80\x00\x44\x88\xcc"]) def test_server_sends_pong_after_sending_close(self): server = Protocol(SERVER) server.send_close(CloseCode.NORMAL_CLOSURE) self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) - # websockets doesn't support sending a Pong frame after a Close frame. - with self.assertRaises(InvalidState): - server.send_pong(b"") + server.send_pong(b"") + self.assertEqual(server.data_to_send(), [b"\x8a\x00"]) def test_client_receives_pong_after_receiving_close(self): client = Protocol(CLIENT) client.receive_data(b"\x88\x02\x03\xe8") self.assertConnectionClosing(client, CloseCode.NORMAL_CLOSURE) client.receive_data(b"\x8a\x04\x22\x66\xaa\xee") + # websockets ignores control frames after a close frame. self.assertFrameReceived(client, None) self.assertFrameSent(client, None) @@ -1237,9 +1275,24 @@ def test_server_receives_pong_after_receiving_close(self): server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") self.assertConnectionClosing(server, CloseCode.GOING_AWAY) server.receive_data(b"\x8a\x84\x00\x44\x88\xcc\x22\x22\x22\x22") + # websockets ignores control frames after a close frame. self.assertFrameReceived(server, None) self.assertFrameSent(server, None) + def test_client_sends_pong_after_connection_is_closed(self): + client = Protocol(CLIENT) + client.receive_eof() + with self.assertRaises(InvalidState) as raised: + client.send_pong(b"") + self.assertEqual(str(raised.exception), "connection is closed") + + def test_server_sends_pong_after_connection_is_closed(self): + server = Protocol(SERVER) + server.receive_eof() + with self.assertRaises(InvalidState) as raised: + server.send_pong(b"") + self.assertEqual(str(raised.exception), "connection is closed") + class FailTests(ProtocolTestCase): """ @@ -1370,8 +1423,9 @@ def test_client_send_close_in_fragmented_message(self): client.send_close() self.assertEqual(client.data_to_send(), [b"\x88\x80\x3c\x3c\x3c\x3c"]) self.assertIs(client.state, CLOSING) - with self.assertRaises(InvalidState): + with self.assertRaises(InvalidState) as raised: client.send_continuation(b"Eggs", fin=True) + self.assertEqual(str(raised.exception), "connection is closing") def test_server_send_close_in_fragmented_message(self): server = Protocol(SERVER) @@ -1380,8 +1434,9 @@ def test_server_send_close_in_fragmented_message(self): server.send_close() self.assertEqual(server.data_to_send(), [b"\x88\x00"]) self.assertIs(server.state, CLOSING) - with self.assertRaises(InvalidState): + with self.assertRaises(InvalidState) as raised: server.send_continuation(b"Eggs", fin=True) + self.assertEqual(str(raised.exception), "connection is closing") def test_client_receive_close_in_fragmented_message(self): client = Protocol(CLIENT)