Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updated websocket consumers for ASGI v2.3. #2002

Merged
merged 1 commit into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 23 additions & 12 deletions channels/generic/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,15 @@ def websocket_connect(self, message):
def connect(self):
self.accept()

def accept(self, subprotocol=None):
def accept(self, subprotocol=None, headers=None):
"""
Accepts an incoming socket
"""
super().send({"type": "websocket.accept", "subprotocol": subprotocol})
message = {"type": "websocket.accept", "subprotocol": subprotocol}
if headers:
message["headers"] = list(headers)

super().send(message)

def websocket_receive(self, message):
"""
Expand Down Expand Up @@ -79,14 +83,16 @@ def send(self, text_data=None, bytes_data=None, close=False):
if close:
self.close(close)

def close(self, code=None):
def close(self, code=None, reason=None):
"""
Closes the WebSocket from the server end
"""
message = {"type": "websocket.close"}
if code is not None and code is not True:
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For my curiosity, why is there a check for the "code not being True"?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, hard questions! I think this is a compatibility check. It's been a while, but IIRC, there was a case where a "True" value was being passed to close... I need to look.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, this is not my code. Phew. So, your guess is as good as mine. Maybe in the distant past, someone would call close(True) for whatever reason and this is trying to allow for that? Nothing in the library or unit tests which is doing that, though.

super().send({"type": "websocket.close", "code": code})
else:
super().send({"type": "websocket.close"})
message["code"] = code
if reason:
message["reason"] = reason
Comment on lines +93 to +94
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if those conditionals are needed, but should be fine.

Copy link
Contributor Author

@kristjanvalur kristjanvalur Jun 9, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The idea is to not pollute the response with an empty "reason" which is optional in any case.

super().send(message)

def websocket_disconnect(self, message):
"""
Expand Down Expand Up @@ -179,11 +185,14 @@ async def websocket_connect(self, message):
async def connect(self):
await self.accept()

async def accept(self, subprotocol=None):
async def accept(self, subprotocol=None, headers=None):
"""
Accepts an incoming socket
"""
await super().send({"type": "websocket.accept", "subprotocol": subprotocol})
message = {"type": "websocket.accept", "subprotocol": subprotocol}
if headers:
message["headers"] = list(headers)
await super().send(message)

async def websocket_receive(self, message):
"""
Expand Down Expand Up @@ -214,14 +223,16 @@ async def send(self, text_data=None, bytes_data=None, close=False):
if close:
await self.close(close)

async def close(self, code=None):
async def close(self, code=None, reason=None):
"""
Closes the WebSocket from the server end
"""
message = {"type": "websocket.close"}
if code is not None and code is not True:
await super().send({"type": "websocket.close", "code": code})
else:
await super().send({"type": "websocket.close"})
message["code"] = code
if reason:
message["reason"] = reason
await super().send(message)

async def websocket_disconnect(self, message):
"""
Expand Down
9 changes: 8 additions & 1 deletion channels/testing/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ class WebsocketCommunicator(ApplicationCommunicator):
(uninstantiated) along with the initial connection parameters.
"""

def __init__(self, application, path, headers=None, subprotocols=None):
def __init__(
self, application, path, headers=None, subprotocols=None, spec_version=None
):
if not isinstance(path, str):
raise TypeError("Expected str, got {}".format(type(path)))
parsed = urlparse(path)
Expand All @@ -23,7 +25,10 @@ def __init__(self, application, path, headers=None, subprotocols=None):
"headers": headers or [],
"subprotocols": subprotocols or [],
}
if spec_version:
self.scope["spec_version"] = spec_version
super().__init__(application, self.scope)
self.response_headers = None

async def connect(self, timeout=1):
"""
Expand All @@ -37,6 +42,8 @@ async def connect(self, timeout=1):
if response["type"] == "websocket.close":
return (False, response.get("code", 1000))
else:
assert response["type"] == "websocket.accept"
self.response_headers = response.get("headers", [])
return (True, response.get("subprotocol", None))

async def send_to(self, text_data=None, bytes_data=None):
Expand Down
54 changes: 54 additions & 0 deletions tests/test_generic_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,3 +424,57 @@ async def _my_private_handler(self, _):
ValueError, match=r"Malformed type in message \(leading underscore\)"
):
await communicator.receive_from()


@pytest.mark.parametrize("async_consumer", [False, True])
@pytest.mark.django_db
@pytest.mark.asyncio
async def test_accept_headers(async_consumer):
"""
Tests that JsonWebsocketConsumer is implemented correctly.
"""

class TestConsumer(WebsocketConsumer):
def connect(self):
self.accept(headers=[[b"foo", b"bar"]])

class AsyncTestConsumer(AsyncWebsocketConsumer):
async def connect(self):
await self.accept(headers=[[b"foo", b"bar"]])

app = AsyncTestConsumer() if async_consumer else TestConsumer()

# Open a connection
communicator = WebsocketCommunicator(app, "/testws/", spec_version="2.3")
connected, _ = await communicator.connect()
assert connected
assert communicator.response_headers == [[b"foo", b"bar"]]


@pytest.mark.parametrize("async_consumer", [False, True])
@pytest.mark.django_db
@pytest.mark.asyncio
async def test_close_reason(async_consumer):
"""
Tests that JsonWebsocketConsumer is implemented correctly.
"""

class TestConsumer(WebsocketConsumer):
def connect(self):
self.accept()
self.close(code=4007, reason="test reason")

class AsyncTestConsumer(AsyncWebsocketConsumer):
async def connect(self):
await self.accept()
await self.close(code=4007, reason="test reason")

app = AsyncTestConsumer() if async_consumer else TestConsumer()

# Open a connection
communicator = WebsocketCommunicator(app, "/testws/", spec_version="2.3")
connected, _ = await communicator.connect()
msg = await communicator.receive_output()
assert msg["type"] == "websocket.close"
assert msg["code"] == 4007
assert msg["reason"] == "test reason"