Skip to content

Commit

Permalink
Don't log an error when process_request returns a response.
Browse files Browse the repository at this point in the history
Fix #1513.
  • Loading branch information
aaugustin committed Nov 3, 2024
1 parent 810bdeb commit cdeb882
Show file tree
Hide file tree
Showing 9 changed files with 163 additions and 80 deletions.
6 changes: 3 additions & 3 deletions src/websockets/asyncio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,9 @@ async def handshake(
return_when=asyncio.FIRST_COMPLETED,
)

# self.protocol.handshake_exc is always set when the connection is lost
# before receiving a response, when the response cannot be parsed, or
# when the response fails the handshake.
# self.protocol.handshake_exc is set when the connection is lost before
# receiving a response, when the response cannot be parsed, or when the
# response fails the handshake.

if self.protocol.handshake_exc is not None:
raise self.protocol.handshake_exc
Expand Down
17 changes: 12 additions & 5 deletions src/websockets/asyncio/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,10 +192,13 @@ async def handshake(

self.protocol.send_response(self.response)

# self.protocol.handshake_exc is always set when the connection is lost
# before receiving a request, when the request cannot be parsed, when
# the handshake encounters an error, or when process_request or
# process_response sends an HTTP response that rejects the handshake.
# self.protocol.handshake_exc is set when the connection is lost before
# receiving a request, when the request cannot be parsed, or when the
# handshake fails, including when process_request or process_response
# raises an exception.

# It isn't set when process_request or process_response sends an HTTP
# response that rejects the handshake.

if self.protocol.handshake_exc is not None:
raise self.protocol.handshake_exc
Expand Down Expand Up @@ -360,7 +363,11 @@ async def conn_handler(self, connection: ServerConnection) -> None:
connection.close_transport()
return

assert connection.protocol.state is OPEN
if connection.protocol.state is not OPEN:
# process_request or process_response rejected the handshake.
connection.close_transport()
return

try:
connection.start_keepalive()
await self.handler(connection)
Expand Down
29 changes: 24 additions & 5 deletions src/websockets/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,15 +518,34 @@ def close_expected(self) -> bool:
Whether the TCP connection is expected to close soon.
"""
# We expect a TCP close if and only if we sent a close frame:
# During the opening handshake, when our state is CONNECTING, we expect
# a TCP close if and only if the hansdake fails. When it does, we start
# the TCP closing handshake by sending EOF with send_eof().

# Once the opening handshake completes successfully, we expect a TCP
# close if and only if we sent a close frame, meaning that our state
# progressed to CLOSING:

# * Normal closure: once we send a close frame, we expect a TCP close:
# server waits for client to complete the TCP closing handshake;
# client waits for server to initiate the TCP closing handshake.

# * Abnormal closure: we always send a close frame and the same logic
# applies, except on EOFError where we don't send a close frame
# because we already received the TCP close, so we don't expect it.
# We already got a TCP Close if and only if the state is CLOSED.
return self.state is CLOSING or self.handshake_exc is not None

# If our state is CLOSED, we already received a TCP close so we don't
# expect it anymore.

# Micro-optimization: put the most common case first
if self.state is OPEN:
return False
if self.state is CLOSING:
return True
if self.state is CLOSED:
return False
assert self.state is CONNECTING
return self.eof_sent

# Private methods for receiving data.

Expand Down Expand Up @@ -616,14 +635,14 @@ def discard(self) -> Generator[None]:
# connection in the same circumstances where discard() replaces parse().
# The client closes it when it receives EOF from the server or times
# out. (The latter case cannot be handled in this Sans-I/O layer.)
assert (self.state == CONNECTING or self.side is SERVER) == (self.eof_sent)
assert (self.side is SERVER or self.state is CONNECTING) == (self.eof_sent)
while not (yield from self.reader.at_eof()):
self.reader.discard()
if self.debug:
self.logger.debug("< EOF")
# A server closes the TCP connection immediately, while a client
# waits for the server to close the TCP connection.
if self.state != CONNECTING and self.side is CLIENT:
if self.side is CLIENT and self.state is not CONNECTING:
self.send_eof()
self.state = CLOSED
# If discard() completes normally, execution ends here.
Expand Down
6 changes: 0 additions & 6 deletions src/websockets/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
InvalidHeader,
InvalidHeaderValue,
InvalidOrigin,
InvalidStatus,
InvalidUpgrade,
NegotiationError,
)
Expand Down Expand Up @@ -536,11 +535,6 @@ def send_response(self, response: Response) -> None:
self.logger.info("connection open")

else:
# handshake_exc may be already set if accept() encountered an error.
# If the connection isn't open, set handshake_exc to guarantee that
# handshake_exc is None if and only if opening handshake succeeded.
if self.handshake_exc is None:
self.handshake_exc = InvalidStatus(response)
self.logger.info(
"connection rejected (%d %s)",
response.status_code,
Expand Down
6 changes: 3 additions & 3 deletions src/websockets/sync/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,9 @@ def handshake(
if not self.response_rcvd.wait(timeout):
raise TimeoutError("timed out during handshake")

# self.protocol.handshake_exc is always set when the connection is lost
# before receiving a response, when the response cannot be parsed, or
# when the response fails the handshake.
# self.protocol.handshake_exc is set when the connection is lost before
# receiving a response, when the response cannot be parsed, or when the
# response fails the handshake.

if self.protocol.handshake_exc is not None:
raise self.protocol.handshake_exc
Expand Down
11 changes: 7 additions & 4 deletions src/websockets/sync/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,10 +170,13 @@ def handshake(

self.protocol.send_response(self.response)

# self.protocol.handshake_exc is always set when the connection is lost
# before receiving a request, when the request cannot be parsed, when
# the handshake encounters an error, or when process_request or
# process_response sends an HTTP response that rejects the handshake.
# self.protocol.handshake_exc is set when the connection is lost before
# receiving a request, when the request cannot be parsed, or when the
# handshake fails, including when process_request or process_response
# raises an exception.

# It isn't set when process_request or process_response sends an HTTP
# response that rejects the handshake.

if self.protocol.handshake_exc is not None:
raise self.protocol.handshake_exc
Expand Down
2 changes: 1 addition & 1 deletion tests/asyncio/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ async def asyncTearDown(self):
if sys.version_info[:2] < (3, 10): # pragma: no cover

@contextlib.contextmanager
def assertNoLogs(self, logger="websockets", level=logging.ERROR):
def assertNoLogs(self, logger=None, level=None):
"""
No message is logged on the given logger with at least the given level.
Expand Down
146 changes: 94 additions & 52 deletions tests/asyncio/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,14 +148,17 @@ def process_request(ws, request):
async def handler(ws):
self.fail("handler must not run")

async with serve(handler, *args[1:], process_request=process_request) as server:
with self.assertRaises(InvalidStatus) as raised:
async with connect(get_uri(server)):
self.fail("did not raise")
self.assertEqual(
str(raised.exception),
"server rejected WebSocket connection: HTTP 403",
)
with self.assertNoLogs("websockets", logging.ERROR):
async with serve(
handler, *args[1:], process_request=process_request
) as server:
with self.assertRaises(InvalidStatus) as raised:
async with connect(get_uri(server)):
self.fail("did not raise")
self.assertEqual(
str(raised.exception),
"server rejected WebSocket connection: HTTP 403",
)

async def test_async_process_request_returns_response(self):
"""Server aborts handshake if async process_request returns a response."""
Expand All @@ -166,44 +169,65 @@ async def process_request(ws, request):
async def handler(ws):
self.fail("handler must not run")

async with serve(handler, *args[1:], process_request=process_request) as server:
with self.assertRaises(InvalidStatus) as raised:
async with connect(get_uri(server)):
self.fail("did not raise")
self.assertEqual(
str(raised.exception),
"server rejected WebSocket connection: HTTP 403",
)
with self.assertNoLogs("websockets", logging.ERROR):
async with serve(
handler, *args[1:], process_request=process_request
) as server:
with self.assertRaises(InvalidStatus) as raised:
async with connect(get_uri(server)):
self.fail("did not raise")
self.assertEqual(
str(raised.exception),
"server rejected WebSocket connection: HTTP 403",
)

async def test_process_request_raises_exception(self):
"""Server returns an error if process_request raises an exception."""

def process_request(ws, request):
raise RuntimeError
raise RuntimeError("BOOM")

async with serve(*args, process_request=process_request) as server:
with self.assertRaises(InvalidStatus) as raised:
async with connect(get_uri(server)):
self.fail("did not raise")
self.assertEqual(
str(raised.exception),
"server rejected WebSocket connection: HTTP 500",
)
with self.assertLogs("websockets", logging.ERROR) as logs:
async with serve(*args, process_request=process_request) as server:
with self.assertRaises(InvalidStatus) as raised:
async with connect(get_uri(server)):
self.fail("did not raise")
self.assertEqual(
str(raised.exception),
"server rejected WebSocket connection: HTTP 500",
)
self.assertEqual(
[record.getMessage() for record in logs.records],
["opening handshake failed"],
)
self.assertEqual(
[str(record.exc_info[1]) for record in logs.records],
["BOOM"],
)

async def test_async_process_request_raises_exception(self):
"""Server returns an error if async process_request raises an exception."""

async def process_request(ws, request):
raise RuntimeError
raise RuntimeError("BOOM")

async with serve(*args, process_request=process_request) as server:
with self.assertRaises(InvalidStatus) as raised:
async with connect(get_uri(server)):
self.fail("did not raise")
self.assertEqual(
str(raised.exception),
"server rejected WebSocket connection: HTTP 500",
)
with self.assertLogs("websockets", logging.ERROR) as logs:
async with serve(*args, process_request=process_request) as server:
with self.assertRaises(InvalidStatus) as raised:
async with connect(get_uri(server)):
self.fail("did not raise")
self.assertEqual(
str(raised.exception),
"server rejected WebSocket connection: HTTP 500",
)
self.assertEqual(
[record.getMessage() for record in logs.records],
["opening handshake failed"],
)
self.assertEqual(
[str(record.exc_info[1]) for record in logs.records],
["BOOM"],
)

async def test_process_response_returns_none(self):
"""Server runs process_response but keeps the handshake response."""
Expand Down Expand Up @@ -277,31 +301,49 @@ async def test_process_response_raises_exception(self):
"""Server returns an error if process_response raises an exception."""

def process_response(ws, request, response):
raise RuntimeError
raise RuntimeError("BOOM")

async with serve(*args, process_response=process_response) as server:
with self.assertRaises(InvalidStatus) as raised:
async with connect(get_uri(server)):
self.fail("did not raise")
self.assertEqual(
str(raised.exception),
"server rejected WebSocket connection: HTTP 500",
)
with self.assertLogs("websockets", logging.ERROR) as logs:
async with serve(*args, process_response=process_response) as server:
with self.assertRaises(InvalidStatus) as raised:
async with connect(get_uri(server)):
self.fail("did not raise")
self.assertEqual(
str(raised.exception),
"server rejected WebSocket connection: HTTP 500",
)
self.assertEqual(
[record.getMessage() for record in logs.records],
["opening handshake failed"],
)
self.assertEqual(
[str(record.exc_info[1]) for record in logs.records],
["BOOM"],
)

async def test_async_process_response_raises_exception(self):
"""Server returns an error if async process_response raises an exception."""

async def process_response(ws, request, response):
raise RuntimeError
raise RuntimeError("BOOM")

async with serve(*args, process_response=process_response) as server:
with self.assertRaises(InvalidStatus) as raised:
async with connect(get_uri(server)):
self.fail("did not raise")
self.assertEqual(
str(raised.exception),
"server rejected WebSocket connection: HTTP 500",
)
with self.assertLogs("websockets", logging.ERROR) as logs:
async with serve(*args, process_response=process_response) as server:
with self.assertRaises(InvalidStatus) as raised:
async with connect(get_uri(server)):
self.fail("did not raise")
self.assertEqual(
str(raised.exception),
"server rejected WebSocket connection: HTTP 500",
)
self.assertEqual(
[record.getMessage() for record in logs.records],
["opening handshake failed"],
)
self.assertEqual(
[str(record.exc_info[1]) for record in logs.records],
["BOOM"],
)

async def test_override_server(self):
"""Server can override Server header with server_header."""
Expand Down
20 changes: 19 additions & 1 deletion tests/test_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
Frame,
)
from websockets.protocol import *
from websockets.protocol import CLIENT, CLOSED, CLOSING, SERVER
from websockets.protocol import CLIENT, CLOSED, CLOSING, CONNECTING, SERVER

from .extensions.utils import Rsv2Extension
from .test_frames import FramesTestCase
Expand Down Expand Up @@ -1696,6 +1696,24 @@ def test_server_fails_connection(self):
server.fail(CloseCode.PROTOCOL_ERROR)
self.assertTrue(server.close_expected())

def test_client_is_connecting(self):
client = Protocol(CLIENT, state=CONNECTING)
self.assertFalse(client.close_expected())

def test_server_is_connecting(self):
server = Protocol(SERVER, state=CONNECTING)
self.assertFalse(server.close_expected())

def test_client_failed_connecting(self):
client = Protocol(CLIENT, state=CONNECTING)
client.send_eof()
self.assertTrue(client.close_expected())

def test_server_failed_connecting(self):
server = Protocol(SERVER, state=CONNECTING)
server.send_eof()
self.assertTrue(server.close_expected())


class ConnectionClosedTests(ProtocolTestCase):
"""
Expand Down

0 comments on commit cdeb882

Please sign in to comment.