From 42bf4a3d0de71f2401c3a759068a817d40578f0c Mon Sep 17 00:00:00 2001 From: Carlton Gibson Date: Fri, 5 Apr 2024 08:39:51 +0200 Subject: [PATCH 1/9] Removed Django version number from project description. Falls out of sync with other (canonical) metadata. --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 8a2a442c..f956c46e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -4,7 +4,7 @@ version = attr: channels.__version__ url = http://github.com/django/channels author = Django Software Foundation author_email = foundation@djangoproject.com -description = Brings async, event-driven capabilities to Django 3.2 and up. +description = Brings async, event-driven capabilities to Django. long_description = file: README.rst long_description_content_type = text/x-rst license = BSD From 8087d475f01bf40d51495ea57aee4b38458a910f Mon Sep 17 00:00:00 2001 From: Jon Janzen Date: Mon, 15 Apr 2024 00:12:42 -0700 Subject: [PATCH 2/9] Drop long deleted cookie_date (#2091) --- channels/sessions.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/channels/sessions.py b/channels/sessions.py index fa16ec99..0a4cadd2 100644 --- a/channels/sessions.py +++ b/channels/sessions.py @@ -10,14 +10,10 @@ from django.utils import timezone from django.utils.encoding import force_str from django.utils.functional import LazyObject +from django.utils.http import http_date from channels.db import database_sync_to_async -try: - from django.utils.http import http_date -except ImportError: - from django.utils.http import cookie_date as http_date - class CookieMiddleware: """ From 42deaca0e25f5dbb6c5133dc969366b93526960f Mon Sep 17 00:00:00 2001 From: Karel Hovorka Date: Mon, 13 May 2024 13:22:18 +0700 Subject: [PATCH 3/9] Made WebsocketCommunicator assertions more informative. (#2098) --- channels/testing/websocket.py | 14 ++++++++++---- tests/test_testing.py | 26 ++++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 4 deletions(-) diff --git a/channels/testing/websocket.py b/channels/testing/websocket.py index dd48686d..57ea4a65 100644 --- a/channels/testing/websocket.py +++ b/channels/testing/websocket.py @@ -78,19 +78,23 @@ async def receive_from(self, timeout=1): """ response = await self.receive_output(timeout) # Make sure this is a send message - assert response["type"] == "websocket.send" + assert ( + response["type"] == "websocket.send" + ), f"Expected type 'websocket.send', but was '{response['type']}'" # Make sure there's exactly one key in the response assert ("text" in response) != ( "bytes" in response ), "The response needs exactly one of 'text' or 'bytes'" # Pull out the right key and typecheck it for our users if "text" in response: - assert isinstance(response["text"], str), "Text frame payload is not str" + assert isinstance( + response["text"], str + ), f"Text frame payload is not str, it is {type(response['text'])}" return response["text"] else: assert isinstance( response["bytes"], bytes - ), "Binary frame payload is not bytes" + ), f"Binary frame payload is not bytes, it is {type(response['bytes'])}" return response["bytes"] async def receive_json_from(self, timeout=1): @@ -98,7 +102,9 @@ async def receive_json_from(self, timeout=1): Receives a JSON text frame payload and decodes it """ payload = await self.receive_from(timeout) - assert isinstance(payload, str), "JSON data is not a text frame" + assert isinstance( + payload, str + ), f"JSON data is not a text frame, it is {type(payload)}" return json.loads(payload) async def disconnect(self, code=1000, timeout=1): diff --git a/tests/test_testing.py b/tests/test_testing.py index 12164b9a..6beae7d5 100644 --- a/tests/test_testing.py +++ b/tests/test_testing.py @@ -47,6 +47,13 @@ def receive(self, text_data=None, bytes_data=None): self.send(text_data=text_data, bytes_data=bytes_data) +class AcceptCloseWebsocketApp(WebsocketConsumer): + def connect(self): + assert self.scope["path"] == "/testws/" + self.accept() + self.close() + + class ErrorWebsocketApp(WebsocketConsumer): """ Barebones WebSocket ASGI app for error testing. @@ -93,6 +100,25 @@ async def test_websocket_communicator(): await communicator.disconnect() +@pytest.mark.django_db +@pytest.mark.asyncio +async def test_websocket_incorrect_read_json(): + """ + When using an invalid communicator method, an assertion error will be raised with + informative message. + In this test, the server accepts and then immediately closes the connection so + the server is not in a valid state to handle "receive_from". + """ + communicator = WebsocketCommunicator(AcceptCloseWebsocketApp(), "/testws/") + await communicator.connect() + with pytest.raises(AssertionError) as exception_info: + await communicator.receive_from() + assert ( + str(exception_info.value) + == "Expected type 'websocket.send', but was 'websocket.close'" + ) + + @pytest.mark.django_db @pytest.mark.asyncio async def test_websocket_application(): From 1d12e4c8942f279a1bc808010e78301be83737af Mon Sep 17 00:00:00 2001 From: Devid <13779643+sevdog@users.noreply.github.com> Date: Thu, 13 Jun 2024 07:36:04 +0100 Subject: [PATCH 4/9] Removed outdated deprecation message (#2103) --- channels/routing.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/channels/routing.py b/channels/routing.py index d09d5ffe..f48c4d33 100644 --- a/channels/routing.py +++ b/channels/routing.py @@ -33,20 +33,6 @@ def get_default_application(): return value -DEPRECATION_MSG = """ -Using ProtocolTypeRouter without an explicit "http" key is deprecated. -Given that you have not passed the "http" you likely should use Django's -get_asgi_application(): - - from django.core.asgi import get_asgi_application - - application = ProtocolTypeRouter( - "http": get_asgi_application() - # Other protocols here. - ) -""" - - class ProtocolTypeRouter: """ Takes a mapping of protocol type names to other Application instances, From 5d8ddd98148781d9021235c9c1ad5a126163bac4 Mon Sep 17 00:00:00 2001 From: Carlton Gibson Date: Sat, 15 Jun 2024 08:18:16 +0200 Subject: [PATCH 5/9] Added testing against Django 5.1. --- setup.cfg | 1 + tox.ini | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index f956c46e..3fa8b058 100644 --- a/setup.cfg +++ b/setup.cfg @@ -24,6 +24,7 @@ classifiers = Framework :: Django Framework :: Django :: 4.2 Framework :: Django :: 5.0 + Framework :: Django :: 5.1 Topic :: Internet :: WWW/HTTP [options] diff --git a/tox.ini b/tox.ini index 552ea833..88518b8d 100644 --- a/tox.ini +++ b/tox.ini @@ -2,6 +2,7 @@ envlist = py{38,39,310,311}-dj42 py{310,311,312}-dj50 + py{310,311,312}-dj51 py{310,311,312}-djmain qa @@ -11,7 +12,8 @@ commands = pytest -v {posargs} deps = dj42: Django>=4.2,<5.0 - dj50: Django>=5.0rc1,<5.1 + dj50: Django>=5.0,<5.1 + dj51: Django>=5.1a1,<5.2 djmain: https://github.com/django/django/archive/main.tar.gz [testenv:qa] From aa91c280953dc649f92b83c709407017fd8e055f Mon Sep 17 00:00:00 2001 From: Jon Janzen Date: Wed, 10 Jul 2024 01:32:41 +0100 Subject: [PATCH 6/9] Don't actually close DB connections during tests (#2101) --- channels/testing/__init__.py | 3 +- channels/testing/application.py | 17 ++++++++++ channels/testing/http.py | 2 +- channels/testing/websocket.py | 2 +- docs/topics/testing.rst | 4 +-- tests/conftest.py | 9 +++++- tests/test_database.py | 55 +++++++++++++++++++++++++++++++++ 7 files changed, 85 insertions(+), 7 deletions(-) create mode 100644 channels/testing/application.py create mode 100644 tests/test_database.py diff --git a/channels/testing/__init__.py b/channels/testing/__init__.py index f96625cd..d7dee3ef 100644 --- a/channels/testing/__init__.py +++ b/channels/testing/__init__.py @@ -1,5 +1,4 @@ -from asgiref.testing import ApplicationCommunicator # noqa - +from .application import ApplicationCommunicator # noqa from .http import HttpCommunicator # noqa from .live import ChannelsLiveServerTestCase # noqa from .websocket import WebsocketCommunicator # noqa diff --git a/channels/testing/application.py b/channels/testing/application.py new file mode 100644 index 00000000..2003178c --- /dev/null +++ b/channels/testing/application.py @@ -0,0 +1,17 @@ +from unittest import mock + +from asgiref.testing import ApplicationCommunicator as BaseApplicationCommunicator + + +def no_op(): + pass + + +class ApplicationCommunicator(BaseApplicationCommunicator): + async def send_input(self, message): + with mock.patch("channels.db.close_old_connections", no_op): + return await super().send_input(message) + + async def receive_output(self, timeout=1): + with mock.patch("channels.db.close_old_connections", no_op): + return await super().receive_output(timeout) diff --git a/channels/testing/http.py b/channels/testing/http.py index 6b1514ca..8130265a 100644 --- a/channels/testing/http.py +++ b/channels/testing/http.py @@ -1,6 +1,6 @@ from urllib.parse import unquote, urlparse -from asgiref.testing import ApplicationCommunicator +from channels.testing.application import ApplicationCommunicator class HttpCommunicator(ApplicationCommunicator): diff --git a/channels/testing/websocket.py b/channels/testing/websocket.py index 57ea4a65..24e58d36 100644 --- a/channels/testing/websocket.py +++ b/channels/testing/websocket.py @@ -1,7 +1,7 @@ import json from urllib.parse import unquote, urlparse -from asgiref.testing import ApplicationCommunicator +from channels.testing.application import ApplicationCommunicator class WebsocketCommunicator(ApplicationCommunicator): diff --git a/docs/topics/testing.rst b/docs/topics/testing.rst index a3c14a00..c3547fd8 100644 --- a/docs/topics/testing.rst +++ b/docs/topics/testing.rst @@ -73,8 +73,8 @@ you might need to fall back to it if you are testing things like HTTP chunked responses or long-polling, which aren't supported in ``HttpCommunicator`` yet. .. note:: - ``ApplicationCommunicator`` is actually provided by the base ``asgiref`` - package, but we let you import it from ``channels.testing`` for convenience. + ``ApplicationCommunicator`` extends the class provided by the base ``asgiref`` + package. Channels adds support for running unit tests with async consumers. To construct it, pass it an application and a scope: diff --git a/tests/conftest.py b/tests/conftest.py index 8e7b3155..94c9803a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,7 +4,14 @@ def pytest_configure(): settings.configure( - DATABASES={"default": {"ENGINE": "django.db.backends.sqlite3"}}, + DATABASES={ + "default": { + "ENGINE": "django.db.backends.sqlite3", + # Override Django’s default behaviour of using an in-memory database + # in tests for SQLite, since that avoids connection.close() working. + "TEST": {"NAME": "test_db.sqlite3"}, + } + }, INSTALLED_APPS=[ "django.contrib.auth", "django.contrib.contenttypes", diff --git a/tests/test_database.py b/tests/test_database.py new file mode 100644 index 00000000..3faf05b5 --- /dev/null +++ b/tests/test_database.py @@ -0,0 +1,55 @@ +from django import db +from django.test import TestCase + +from channels.db import database_sync_to_async +from channels.generic.http import AsyncHttpConsumer +from channels.generic.websocket import AsyncWebsocketConsumer +from channels.testing import HttpCommunicator, WebsocketCommunicator + + +@database_sync_to_async +def basic_query(): + with db.connections["default"].cursor() as cursor: + cursor.execute("SELECT 1234") + return cursor.fetchone()[0] + + +class WebsocketConsumer(AsyncWebsocketConsumer): + async def connect(self): + await basic_query() + await self.accept("fun") + + +class HttpConsumer(AsyncHttpConsumer): + async def handle(self, body): + await basic_query() + await self.send_response( + 200, + b"", + headers={b"Content-Type": b"text/plain"}, + ) + + +class ConnectionClosingTests(TestCase): + async def test_websocket(self): + self.assertNotRegex( + db.connections["default"].settings_dict.get("NAME"), + "memorydb", + "This bug only occurs when the database is materialized on disk", + ) + communicator = WebsocketCommunicator(WebsocketConsumer.as_asgi(), "/") + connected, subprotocol = await communicator.connect() + self.assertTrue(connected) + self.assertEqual(subprotocol, "fun") + + async def test_http(self): + self.assertNotRegex( + db.connections["default"].settings_dict.get("NAME"), + "memorydb", + "This bug only occurs when the database is materialized on disk", + ) + communicator = HttpCommunicator( + HttpConsumer.as_asgi(), method="GET", path="/test/" + ) + connected = await communicator.get_response() + self.assertTrue(connected) From 8d90b07ba40bdbfe4d0b928e319c5a47a5a3ec5d Mon Sep 17 00:00:00 2001 From: Jon Janzen Date: Thu, 11 Jul 2024 03:59:18 +0100 Subject: [PATCH 7/9] Improve async Django support and improve docs (#2090) --- channels/consumer.py | 3 ++- channels/db.py | 6 +++++- channels/generic/http.py | 2 ++ channels/generic/websocket.py | 2 ++ docs/topics/consumers.rst | 3 ++- docs/topics/databases.rst | 23 ++++++++++++++++++++++- docs/tutorial/part_3.rst | 10 +++++++--- tests/security/test_websocket.py | 1 + tests/test_generic_http.py | 4 ++++ tests/test_generic_websocket.py | 7 +++++++ tests/test_testing.py | 1 + 11 files changed, 55 insertions(+), 7 deletions(-) diff --git a/channels/consumer.py b/channels/consumer.py index 85543d6c..fc065432 100644 --- a/channels/consumer.py +++ b/channels/consumer.py @@ -3,7 +3,7 @@ from asgiref.sync import async_to_sync from . import DEFAULT_CHANNEL_LAYER -from .db import database_sync_to_async +from .db import aclose_old_connections, database_sync_to_async from .exceptions import StopConsumer from .layers import get_channel_layer from .utils import await_many_dispatch @@ -70,6 +70,7 @@ async def dispatch(self, message): """ handler = getattr(self, get_handler_name(message), None) if handler: + await aclose_old_connections() await handler(message) else: raise ValueError("No handler for message type %s" % message["type"]) diff --git a/channels/db.py b/channels/db.py index 0650e7a8..2961b5cd 100644 --- a/channels/db.py +++ b/channels/db.py @@ -1,4 +1,4 @@ -from asgiref.sync import SyncToAsync +from asgiref.sync import SyncToAsync, sync_to_async from django.db import close_old_connections @@ -17,3 +17,7 @@ def thread_handler(self, loop, *args, **kwargs): # The class is TitleCased, but we want to encourage use as a callable/decorator database_sync_to_async = DatabaseSyncToAsync + + +async def aclose_old_connections(): + return await sync_to_async(close_old_connections)() diff --git a/channels/generic/http.py b/channels/generic/http.py index 909e8570..0d043cc3 100644 --- a/channels/generic/http.py +++ b/channels/generic/http.py @@ -1,5 +1,6 @@ from channels.consumer import AsyncConsumer +from ..db import aclose_old_connections from ..exceptions import StopConsumer @@ -88,4 +89,5 @@ async def http_disconnect(self, message): Let the user do their cleanup and close the consumer. """ await self.disconnect() + await aclose_old_connections() raise StopConsumer() diff --git a/channels/generic/websocket.py b/channels/generic/websocket.py index 9ce2657b..6d41c8ee 100644 --- a/channels/generic/websocket.py +++ b/channels/generic/websocket.py @@ -3,6 +3,7 @@ from asgiref.sync import async_to_sync from ..consumer import AsyncConsumer, SyncConsumer +from ..db import aclose_old_connections from ..exceptions import ( AcceptConnection, DenyConnection, @@ -247,6 +248,7 @@ async def websocket_disconnect(self, message): "BACKEND is unconfigured or doesn't support groups" ) await self.disconnect(message["code"]) + await aclose_old_connections() raise StopConsumer() async def disconnect(self, code): diff --git a/docs/topics/consumers.rst b/docs/topics/consumers.rst index 69249147..294a9aed 100644 --- a/docs/topics/consumers.rst +++ b/docs/topics/consumers.rst @@ -112,7 +112,8 @@ callable into an asynchronous coroutine. If you want to call the Django ORM from an ``AsyncConsumer`` (or any other asynchronous code), you should use the ``database_sync_to_async`` adapter - instead. See :doc:`/topics/databases` for more. + or use the async versions of the methods (prefixed with ``a``, like ``aget``). + See :doc:`/topics/databases` for more. Closing Consumers diff --git a/docs/topics/databases.rst b/docs/topics/databases.rst index 5d06bebe..e0d2c4af 100644 --- a/docs/topics/databases.rst +++ b/docs/topics/databases.rst @@ -11,7 +11,8 @@ code is already run in a synchronous mode and Channels will do the cleanup for you as part of the ``SyncConsumer`` code. If you are writing asynchronous code, however, you will need to call -database methods in a safe, synchronous context, using ``database_sync_to_async``. +database methods in a safe, synchronous context, using ``database_sync_to_async`` +or by using the asynchronous methods prefixed with ``a`` like ``Model.objects.aget()``. Database Connections @@ -26,6 +27,11 @@ Python 3.7 and below, and `min(32, os.cpu_count() + 4)` for Python 3.8+. To avoid having too many threads idling in connections, you can instead rewrite your code to use async consumers and only dip into threads when you need to use Django's ORM (using ``database_sync_to_async``). +When using async consumers Channels will automatically call Django's ``close_old_connections`` method when a new connection is started, when a connection is closed, and whenever anything is received from the client. +This mirrors Django's logic for closing old connections at the start and end of a request, to the extent possible. Connections are *not* automatically closed when sending data from a consumer since Channels has no way +to determine if this is a one-off send (and connections could be closed) or a series of sends (in which closing connections would kill performance). Instead, if you have a long-lived async consumer you should +periodically call ``aclose_old_connections`` (see below). + database_sync_to_async ---------------------- @@ -58,3 +64,18 @@ You can also use it as a decorator: @database_sync_to_async def get_name(self): return User.objects.all()[0].name + +aclose_old_connections +---------------------- + +``django.db.aclose_old_connections`` is an async wrapper around Django's +``close_old_connections``. When using a long-lived ``AsyncConsumer`` that +calls the Django ORM it is important to call this function periodically. + +Preferrably, this function should be called before making the first query +in a while. For example, it should be called if the Consumer is woken up +by a channels layer event and needs to make a few ORM queries to determine +what to send to the client. This function should be called *before* making +those queries. Calling this function more than necessary is not necessarily +a bad thing, but it does require a context switch to synchronous code and +so incurs a small penalty. \ No newline at end of file diff --git a/docs/tutorial/part_3.rst b/docs/tutorial/part_3.rst index 99182362..dbda8474 100644 --- a/docs/tutorial/part_3.rst +++ b/docs/tutorial/part_3.rst @@ -15,16 +15,20 @@ asynchronous consumers can provide a higher level of performance since they don't need to create additional threads when handling requests. ``ChatConsumer`` only uses async-native libraries (Channels and the channel layer) -and in particular it does not access synchronous Django models. Therefore it can +and in particular it does not access synchronous code. Therefore it can be rewritten to be asynchronous without complications. .. note:: - Even if ``ChatConsumer`` *did* access Django models or other synchronous code it + Even if ``ChatConsumer`` *did* access Django models or synchronous code it would still be possible to rewrite it as asynchronous. Utilities like :ref:`asgiref.sync.sync_to_async ` and :doc:`channels.db.database_sync_to_async ` can be used to call synchronous code from an asynchronous consumer. The performance - gains however would be less than if it only used async-native libraries. + gains however would be less than if it only used async-native libraries. Django + models include methods prefixed with ``a`` that can be used safely from async + contexts, provided that + :doc:`channels.db.aclose_old_connections ` is called + occasionally. Let's rewrite ``ChatConsumer`` to be asynchronous. Put the following code in ``chat/consumers.py``: diff --git a/tests/security/test_websocket.py b/tests/security/test_websocket.py index 52f9e21b..1444ea82 100644 --- a/tests/security/test_websocket.py +++ b/tests/security/test_websocket.py @@ -5,6 +5,7 @@ from channels.testing import WebsocketCommunicator +@pytest.mark.django_db(transaction=True) @pytest.mark.asyncio async def test_origin_validator(): """ diff --git a/tests/test_generic_http.py b/tests/test_generic_http.py index bfb889c0..0b6d0ecb 100644 --- a/tests/test_generic_http.py +++ b/tests/test_generic_http.py @@ -8,6 +8,7 @@ from channels.testing import HttpCommunicator +@pytest.mark.django_db(transaction=True) @pytest.mark.asyncio async def test_async_http_consumer(): """ @@ -38,6 +39,7 @@ async def handle(self, body): assert response["headers"] == [(b"Content-Type", b"application/json")] +@pytest.mark.django_db(transaction=True) @pytest.mark.asyncio async def test_error(): class TestConsumer(AsyncHttpConsumer): @@ -51,6 +53,7 @@ async def handle(self, body): assert str(excinfo.value) == "Error correctly raised" +@pytest.mark.django_db(transaction=True) @pytest.mark.asyncio async def test_per_scope_consumers(): """ @@ -87,6 +90,7 @@ async def handle(self, body): assert response["body"] != second_response["body"] +@pytest.mark.django_db(transaction=True) @pytest.mark.asyncio async def test_async_http_consumer_future(): """ diff --git a/tests/test_generic_websocket.py b/tests/test_generic_websocket.py index 73cdb486..c553eb84 100644 --- a/tests/test_generic_websocket.py +++ b/tests/test_generic_websocket.py @@ -154,6 +154,7 @@ def receive(self, text_data=None, bytes_data=None): assert channel_layer.groups == {} +@pytest.mark.django_db(transaction=True) @pytest.mark.asyncio async def test_async_websocket_consumer(): """ @@ -195,6 +196,7 @@ async def disconnect(self, code): assert "disconnected" in results +@pytest.mark.django_db(transaction=True) @pytest.mark.asyncio async def test_async_websocket_consumer_subprotocol(): """ @@ -217,6 +219,7 @@ async def connect(self): assert subprotocol == "subprotocol2" +@pytest.mark.django_db(transaction=True) @pytest.mark.asyncio async def test_async_websocket_consumer_groups(): """ @@ -253,6 +256,7 @@ async def receive(self, text_data=None, bytes_data=None): assert channel_layer.groups == {} +@pytest.mark.django_db(transaction=True) @pytest.mark.asyncio async def test_async_websocket_consumer_specific_channel_layer(): """ @@ -323,6 +327,7 @@ def receive_json(self, data=None): await communicator.wait() +@pytest.mark.django_db(transaction=True) @pytest.mark.asyncio async def test_async_json_websocket_consumer(): """ @@ -355,6 +360,7 @@ async def receive_json(self, data=None): await communicator.wait() +@pytest.mark.django_db(transaction=True) @pytest.mark.asyncio async def test_block_underscored_type_function_call(): """ @@ -390,6 +396,7 @@ async def _my_private_handler(self, _): await communicator.receive_from() +@pytest.mark.django_db(transaction=True) @pytest.mark.asyncio async def test_block_leading_dot_type_function_call(): """ diff --git a/tests/test_testing.py b/tests/test_testing.py index 6beae7d5..fbfbf436 100644 --- a/tests/test_testing.py +++ b/tests/test_testing.py @@ -23,6 +23,7 @@ async def http_request(self, event): await self.send({"type": "http.response.body", "body": b"test response"}) +@pytest.mark.django_db(transaction=True) @pytest.mark.asyncio async def test_http_communicator(): """ From e39fe13dac76ed9a43d454d2d5616dc44e36fa8f Mon Sep 17 00:00:00 2001 From: Jon Janzen Date: Mon, 29 Jul 2024 17:49:35 -0700 Subject: [PATCH 8/9] Use the async sessions api if it exists (#2092) --- channels/sessions.py | 14 ++++--- docs/topics/sessions.rst | 3 +- tests/test_http.py | 89 ++++++++++++++++++++++++++++++++++++++-- 3 files changed, 95 insertions(+), 11 deletions(-) diff --git a/channels/sessions.py b/channels/sessions.py index 0a4cadd2..f1d51d47 100644 --- a/channels/sessions.py +++ b/channels/sessions.py @@ -2,6 +2,7 @@ import time from importlib import import_module +import django from django.conf import settings from django.contrib.sessions.backends.base import UpdateError from django.core.exceptions import SuspiciousOperation @@ -163,9 +164,7 @@ def __init__(self, scope, send): async def resolve_session(self): session_key = self.scope["cookies"].get(self.cookie_name) - self.scope["session"]._wrapped = await database_sync_to_async( - self.session_store - )(session_key) + self.scope["session"]._wrapped = self.session_store(session_key) async def send(self, message): """ @@ -183,7 +182,7 @@ async def send(self, message): and message.get("status", 200) != 500 and (modified or settings.SESSION_SAVE_EVERY_REQUEST) ): - await database_sync_to_async(self.save_session)() + await self.save_session() # If this is a message type that can transport cookies back to the # client, then do so. if message["type"] in self.cookie_response_message_types: @@ -221,12 +220,15 @@ async def send(self, message): # Pass up the send return await self.real_send(message) - def save_session(self): + async def save_session(self): """ Saves the current session. """ try: - self.scope["session"].save() + if django.VERSION >= (5, 1): + await self.scope["session"].asave() + else: + await database_sync_to_async(self.scope["session"].save)() except UpdateError: raise SuspiciousOperation( "The request's session was deleted before the " diff --git a/docs/topics/sessions.rst b/docs/topics/sessions.rst index 29abb9cd..871194c1 100644 --- a/docs/topics/sessions.rst +++ b/docs/topics/sessions.rst @@ -73,7 +73,8 @@ whenever the session is modified. If you are in a WebSocket consumer, however, the session is populated **but will never be saved automatically** - you must call -``scope["session"].save()`` yourself whenever you want to persist a session +``scope["session"].save()`` (or the asynchronous version, +``scope["session"].asave()``) yourself whenever you want to persist a session to your session store. If you don't save, the session will still work correctly inside the consumer (as it's stored as an instance variable), but other connections or HTTP views won't be able to see the changes. diff --git a/tests/test_http.py b/tests/test_http.py index aa1fbe50..bb55ba0c 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -1,6 +1,9 @@ import re +from importlib import import_module +import django import pytest +from django.conf import settings from channels.consumer import AsyncConsumer from channels.db import database_sync_to_async @@ -93,15 +96,12 @@ async def test_session_samesite_invalid(samesite_invalid): @pytest.mark.django_db(transaction=True) @pytest.mark.asyncio -async def test_muliple_sessions(): +async def test_multiple_sessions(): """ Create two application instances and test then out of order to verify that separate scopes are used. """ - async def inner(scope, receive, send): - send(scope["path"]) - class SimpleHttpApp(AsyncConsumer): async def http_request(self, event): await database_sync_to_async(self.scope["session"].save)() @@ -123,3 +123,84 @@ async def http_request(self, event): first_response = await first_communicator.get_response() assert first_response["body"] == b"/first/" + + +@pytest.mark.django_db(transaction=True) +@pytest.mark.asyncio +async def test_session_saves(): + """ + Saves information to a session and validates that it actually saves to the backend + """ + + class SimpleHttpApp(AsyncConsumer): + @database_sync_to_async + def set_fav_color(self): + self.scope["session"]["fav_color"] = "blue" + + async def http_request(self, event): + if django.VERSION >= (5, 1): + await self.scope["session"].aset("fav_color", "blue") + else: + await self.set_fav_color() + await self.send( + {"type": "http.response.start", "status": 200, "headers": []} + ) + await self.send( + { + "type": "http.response.body", + "body": self.scope["session"].session_key.encode(), + } + ) + + app = SessionMiddlewareStack(SimpleHttpApp.as_asgi()) + + communicator = HttpCommunicator(app, "GET", "/first/") + + response = await communicator.get_response() + session_key = response["body"].decode() + + SessionStore = import_module(settings.SESSION_ENGINE).SessionStore + session = SessionStore(session_key=session_key) + if django.VERSION >= (5, 1): + session_fav_color = await session.aget("fav_color") + else: + session_fav_color = await database_sync_to_async(session.get)("fav_color") + + assert session_fav_color == "blue" + + +@pytest.mark.django_db(transaction=True) +@pytest.mark.asyncio +async def test_session_save_update_error(): + """ + Intentionally deletes the session to ensure that SuspiciousOperation is raised + """ + + async def inner(scope, receive, send): + send(scope["path"]) + + class SimpleHttpApp(AsyncConsumer): + @database_sync_to_async + def set_fav_color(self): + self.scope["session"]["fav_color"] = "blue" + + async def http_request(self, event): + # Create a session as normal: + await database_sync_to_async(self.scope["session"].save)() + + # Then simulate it's deletion from somewhere else: + # (e.g. logging out from another request) + SessionStore = import_module(settings.SESSION_ENGINE).SessionStore + session = SessionStore(session_key=self.scope["session"].session_key) + await database_sync_to_async(session.flush)() + + await self.send( + {"type": "http.response.start", "status": 200, "headers": []} + ) + + app = SessionMiddlewareStack(SimpleHttpApp.as_asgi()) + + communicator = HttpCommunicator(app, "GET", "/first/") + + with pytest.raises(django.core.exceptions.SuspiciousOperation): + await communicator.get_response() From e5331869fd7b26dd396ea1f60743feadfc1f4ca2 Mon Sep 17 00:00:00 2001 From: Alexander Date: Tue, 30 Jul 2024 03:22:04 +0200 Subject: [PATCH 9/9] InMemoryChannelLayer improvements, test fixes (#1976) --- channels/layers.py | 63 +++++++++++++++++++---------------- tests/test_inmemorychannel.py | 30 +++++++++++++++-- 2 files changed, 62 insertions(+), 31 deletions(-) diff --git a/channels/layers.py b/channels/layers.py index 12bbd2b8..48f7baca 100644 --- a/channels/layers.py +++ b/channels/layers.py @@ -198,13 +198,13 @@ def __init__( group_expiry=86400, capacity=100, channel_capacity=None, - **kwargs + **kwargs, ): super().__init__( expiry=expiry, capacity=capacity, channel_capacity=channel_capacity, - **kwargs + **kwargs, ) self.channels = {} self.groups = {} @@ -225,13 +225,14 @@ async def send(self, channel, message): # name in message assert "__asgi_channel__" not in message - queue = self.channels.setdefault(channel, asyncio.Queue()) - # Are we full - if queue.qsize() >= self.capacity: - raise ChannelFull(channel) - + queue = self.channels.setdefault( + channel, asyncio.Queue(maxsize=self.get_capacity(channel)) + ) # Add message - await queue.put((time.time() + self.expiry, deepcopy(message))) + try: + queue.put_nowait((time.time() + self.expiry, deepcopy(message))) + except asyncio.queues.QueueFull: + raise ChannelFull(channel) async def receive(self, channel): """ @@ -242,14 +243,16 @@ async def receive(self, channel): assert self.valid_channel_name(channel) self._clean_expired() - queue = self.channels.setdefault(channel, asyncio.Queue()) + queue = self.channels.setdefault( + channel, asyncio.Queue(maxsize=self.get_capacity(channel)) + ) # Do a plain direct receive try: _, message = await queue.get() finally: if queue.empty(): - del self.channels[channel] + self.channels.pop(channel, None) return message @@ -279,19 +282,17 @@ def _clean_expired(self): self._remove_from_groups(channel) # Is the channel now empty and needs deleting? if queue.empty(): - del self.channels[channel] + self.channels.pop(channel, None) # Group Expiration timeout = int(time.time()) - self.group_expiry - for group in self.groups: - for channel in list(self.groups.get(group, set())): - # If join time is older than group_expiry end the group membership - if ( - self.groups[group][channel] - and int(self.groups[group][channel]) < timeout - ): + for channels in self.groups.values(): + for name, timestamp in list(channels.items()): + # If join time is older than group_expiry + # end the group membership + if timestamp and timestamp < timeout: # Delete from group - del self.groups[group][channel] + channels.pop(name, None) # Flush extension @@ -308,8 +309,7 @@ def _remove_from_groups(self, channel): Removes a channel from all groups. Used when a message on it expires. """ for channels in self.groups.values(): - if channel in channels: - del channels[channel] + channels.pop(channel, None) # Groups extension @@ -329,11 +329,13 @@ async def group_discard(self, group, channel): assert self.valid_channel_name(channel), "Invalid channel name" assert self.valid_group_name(group), "Invalid group name" # Remove from group set - if group in self.groups: - if channel in self.groups[group]: - del self.groups[group][channel] - if not self.groups[group]: - del self.groups[group] + group_channels = self.groups.get(group, None) + if group_channels: + # remove channel if in group + group_channels.pop(channel, None) + # is group now empty? If yes remove it + if not group_channels: + self.groups.pop(group, None) async def group_send(self, group, message): # Check types @@ -341,10 +343,15 @@ async def group_send(self, group, message): assert self.valid_group_name(group), "Invalid group name" # Run clean self._clean_expired() + # Send to each channel - for channel in self.groups.get(group, set()): + ops = [] + if group in self.groups: + for channel in self.groups[group].keys(): + ops.append(asyncio.create_task(self.send(channel, message))) + for send_result in asyncio.as_completed(ops): try: - await self.send(channel, message) + await send_result except ChannelFull: pass diff --git a/tests/test_inmemorychannel.py b/tests/test_inmemorychannel.py index 3f05ed7e..4ba4bfab 100644 --- a/tests/test_inmemorychannel.py +++ b/tests/test_inmemorychannel.py @@ -26,9 +26,36 @@ async def test_send_receive(channel_layer): await channel_layer.send( "test-channel-1", {"type": "test.message", "text": "Ahoy-hoy!"} ) + await channel_layer.send( + "test-channel-1", {"type": "test.message", "text": "Ahoy-hoy!"} + ) message = await channel_layer.receive("test-channel-1") assert message["type"] == "test.message" assert message["text"] == "Ahoy-hoy!" + # not removed because not empty + assert "test-channel-1" in channel_layer.channels + message = await channel_layer.receive("test-channel-1") + assert message["type"] == "test.message" + assert message["text"] == "Ahoy-hoy!" + # removed because empty + assert "test-channel-1" not in channel_layer.channels + + +@pytest.mark.asyncio +async def test_race_empty(channel_layer): + """ + Makes sure the race is handled gracefully. + """ + receive_task = asyncio.create_task(channel_layer.receive("test-channel-1")) + await asyncio.sleep(0.1) + await channel_layer.send( + "test-channel-1", {"type": "test.message", "text": "Ahoy-hoy!"} + ) + del channel_layer.channels["test-channel-1"] + await asyncio.sleep(0.1) + message = await receive_task + assert message["type"] == "test.message" + assert message["text"] == "Ahoy-hoy!" @pytest.mark.asyncio @@ -62,7 +89,6 @@ async def test_multi_send_receive(channel_layer): """ Tests overlapping sends and receives, and ordering. """ - channel_layer = InMemoryChannelLayer() await channel_layer.send("test-channel-3", {"type": "message.1"}) await channel_layer.send("test-channel-3", {"type": "message.2"}) await channel_layer.send("test-channel-3", {"type": "message.3"}) @@ -76,7 +102,6 @@ async def test_groups_basic(channel_layer): """ Tests basic group operation. """ - channel_layer = InMemoryChannelLayer() await channel_layer.group_add("test-group", "test-gr-chan-1") await channel_layer.group_add("test-group", "test-gr-chan-2") await channel_layer.group_add("test-group", "test-gr-chan-3") @@ -97,7 +122,6 @@ async def test_groups_channel_full(channel_layer): """ Tests that group_send ignores ChannelFull """ - channel_layer = InMemoryChannelLayer() await channel_layer.group_add("test-group", "test-gr-chan-1") await channel_layer.group_send("test-group", {"type": "message.1"}) await channel_layer.group_send("test-group", {"type": "message.1"})