Skip to content

Commit

Permalink
Merge branch 'main' into ruff
Browse files Browse the repository at this point in the history
  • Loading branch information
cclauss authored Aug 15, 2024
2 parents 530adc3 + e533186 commit 7e06b02
Show file tree
Hide file tree
Showing 26 changed files with 339 additions and 81 deletions.
3 changes: 2 additions & 1 deletion channels/consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from asgiref.sync import async_to_sync

from . import DEFAULT_CHANNEL_LAYER
from .db import database_sync_to_async
from .db import aclose_old_connections, database_sync_to_async
from .exceptions import StopConsumer
from .layers import get_channel_layer
from .utils import await_many_dispatch
Expand Down Expand Up @@ -70,6 +70,7 @@ async def dispatch(self, message):
"""
handler = getattr(self, get_handler_name(message), None)
if handler:
await aclose_old_connections()
await handler(message)
else:
raise ValueError("No handler for message type %s" % message["type"])
Expand Down
6 changes: 5 additions & 1 deletion channels/db.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from asgiref.sync import SyncToAsync
from asgiref.sync import SyncToAsync, sync_to_async
from django.db import close_old_connections


Expand All @@ -17,3 +17,7 @@ def thread_handler(self, loop, *args, **kwargs):

# The class is TitleCased, but we want to encourage use as a callable/decorator
database_sync_to_async = DatabaseSyncToAsync


async def aclose_old_connections():
return await sync_to_async(close_old_connections)()
2 changes: 2 additions & 0 deletions channels/generic/http.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from channels.consumer import AsyncConsumer

from ..db import aclose_old_connections
from ..exceptions import StopConsumer


Expand Down Expand Up @@ -88,4 +89,5 @@ async def http_disconnect(self, message):
Let the user do their cleanup and close the consumer.
"""
await self.disconnect()
await aclose_old_connections()
raise StopConsumer()
2 changes: 2 additions & 0 deletions channels/generic/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from asgiref.sync import async_to_sync

from ..consumer import AsyncConsumer, SyncConsumer
from ..db import aclose_old_connections
from ..exceptions import (
AcceptConnection,
DenyConnection,
Expand Down Expand Up @@ -247,6 +248,7 @@ async def websocket_disconnect(self, message):
"BACKEND is unconfigured or doesn't support groups"
)
await self.disconnect(message["code"])
await aclose_old_connections()
raise StopConsumer()

async def disconnect(self, code):
Expand Down
63 changes: 35 additions & 28 deletions channels/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,13 +198,13 @@ def __init__(
group_expiry=86400,
capacity=100,
channel_capacity=None,
**kwargs
**kwargs,
):
super().__init__(
expiry=expiry,
capacity=capacity,
channel_capacity=channel_capacity,
**kwargs
**kwargs,
)
self.channels = {}
self.groups = {}
Expand All @@ -225,13 +225,14 @@ async def send(self, channel, message):
# name in message
assert "__asgi_channel__" not in message

queue = self.channels.setdefault(channel, asyncio.Queue())
# Are we full
if queue.qsize() >= self.capacity:
raise ChannelFull(channel)

queue = self.channels.setdefault(
channel, asyncio.Queue(maxsize=self.get_capacity(channel))
)
# Add message
await queue.put((time.time() + self.expiry, deepcopy(message)))
try:
queue.put_nowait((time.time() + self.expiry, deepcopy(message)))
except asyncio.queues.QueueFull:
raise ChannelFull(channel)

async def receive(self, channel):
"""
Expand All @@ -242,14 +243,16 @@ async def receive(self, channel):
assert self.valid_channel_name(channel)
self._clean_expired()

queue = self.channels.setdefault(channel, asyncio.Queue())
queue = self.channels.setdefault(
channel, asyncio.Queue(maxsize=self.get_capacity(channel))
)

# Do a plain direct receive
try:
_, message = await queue.get()
finally:
if queue.empty():
del self.channels[channel]
self.channels.pop(channel, None)

return message

Expand Down Expand Up @@ -279,19 +282,17 @@ def _clean_expired(self):
self._remove_from_groups(channel)
# Is the channel now empty and needs deleting?
if queue.empty():
del self.channels[channel]
self.channels.pop(channel, None)

# Group Expiration
timeout = int(time.time()) - self.group_expiry
for group in self.groups:
for channel in list(self.groups.get(group, set())):
# If join time is older than group_expiry end the group membership
if (
self.groups[group][channel]
and int(self.groups[group][channel]) < timeout
):
for channels in self.groups.values():
for name, timestamp in list(channels.items()):
# If join time is older than group_expiry
# end the group membership
if timestamp and timestamp < timeout:
# Delete from group
del self.groups[group][channel]
channels.pop(name, None)

# Flush extension

Expand All @@ -308,8 +309,7 @@ def _remove_from_groups(self, channel):
Removes a channel from all groups. Used when a message on it expires.
"""
for channels in self.groups.values():
if channel in channels:
del channels[channel]
channels.pop(channel, None)

# Groups extension

Expand All @@ -329,22 +329,29 @@ async def group_discard(self, group, channel):
assert self.valid_channel_name(channel), "Invalid channel name"
assert self.valid_group_name(group), "Invalid group name"
# Remove from group set
if group in self.groups:
if channel in self.groups[group]:
del self.groups[group][channel]
if not self.groups[group]:
del self.groups[group]
group_channels = self.groups.get(group, None)
if group_channels:
# remove channel if in group
group_channels.pop(channel, None)
# is group now empty? If yes remove it
if not group_channels:
self.groups.pop(group, None)

async def group_send(self, group, message):
# Check types
assert isinstance(message, dict), "Message is not a dict"
assert self.valid_group_name(group), "Invalid group name"
# Run clean
self._clean_expired()

# Send to each channel
for channel in self.groups.get(group, set()):
ops = []
if group in self.groups:
for channel in self.groups[group].keys():
ops.append(asyncio.create_task(self.send(channel, message)))
for send_result in asyncio.as_completed(ops):
try:
await self.send(channel, message)
await send_result
except ChannelFull:
pass

Expand Down
14 changes: 0 additions & 14 deletions channels/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,20 +33,6 @@ def get_default_application():
return value


DEPRECATION_MSG = """
Using ProtocolTypeRouter without an explicit "http" key is deprecated.
Given that you have not passed the "http" you likely should use Django's
get_asgi_application():
from django.core.asgi import get_asgi_application
application = ProtocolTypeRouter(
"http": get_asgi_application()
# Other protocols here.
)
"""


class ProtocolTypeRouter:
"""
Takes a mapping of protocol type names to other Application instances,
Expand Down
20 changes: 9 additions & 11 deletions channels/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import time
from importlib import import_module

import django
from django.conf import settings
from django.contrib.sessions.backends.base import UpdateError
from django.core.exceptions import SuspiciousOperation
Expand All @@ -10,14 +11,10 @@
from django.utils import timezone
from django.utils.encoding import force_str
from django.utils.functional import LazyObject
from django.utils.http import http_date

from channels.db import database_sync_to_async

try:
from django.utils.http import http_date
except ImportError:
from django.utils.http import cookie_date as http_date


class CookieMiddleware:
"""
Expand Down Expand Up @@ -167,9 +164,7 @@ def __init__(self, scope, send):

async def resolve_session(self):
session_key = self.scope["cookies"].get(self.cookie_name)
self.scope["session"]._wrapped = await database_sync_to_async(
self.session_store
)(session_key)
self.scope["session"]._wrapped = self.session_store(session_key)

async def send(self, message):
"""
Expand All @@ -187,7 +182,7 @@ async def send(self, message):
and message.get("status", 200) != 500
and (modified or settings.SESSION_SAVE_EVERY_REQUEST)
):
await database_sync_to_async(self.save_session)()
await self.save_session()
# If this is a message type that can transport cookies back to the
# client, then do so.
if message["type"] in self.cookie_response_message_types:
Expand Down Expand Up @@ -225,12 +220,15 @@ async def send(self, message):
# Pass up the send
return await self.real_send(message)

def save_session(self):
async def save_session(self):
"""
Saves the current session.
"""
try:
self.scope["session"].save()
if django.VERSION >= (5, 1):
await self.scope["session"].asave()
else:
await database_sync_to_async(self.scope["session"].save)()
except UpdateError:
raise SuspiciousOperation(
"The request's session was deleted before the "
Expand Down
3 changes: 1 addition & 2 deletions channels/testing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from asgiref.testing import ApplicationCommunicator # noqa

from .application import ApplicationCommunicator # noqa
from .http import HttpCommunicator # noqa
from .live import ChannelsLiveServerTestCase # noqa
from .websocket import WebsocketCommunicator # noqa
Expand Down
17 changes: 17 additions & 0 deletions channels/testing/application.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from unittest import mock

from asgiref.testing import ApplicationCommunicator as BaseApplicationCommunicator


def no_op():
pass


class ApplicationCommunicator(BaseApplicationCommunicator):
async def send_input(self, message):
with mock.patch("channels.db.close_old_connections", no_op):
return await super().send_input(message)

async def receive_output(self, timeout=1):
with mock.patch("channels.db.close_old_connections", no_op):
return await super().receive_output(timeout)
2 changes: 1 addition & 1 deletion channels/testing/http.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from urllib.parse import unquote, urlparse

from asgiref.testing import ApplicationCommunicator
from channels.testing.application import ApplicationCommunicator


class HttpCommunicator(ApplicationCommunicator):
Expand Down
16 changes: 11 additions & 5 deletions channels/testing/websocket.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
from urllib.parse import unquote, urlparse

from asgiref.testing import ApplicationCommunicator
from channels.testing.application import ApplicationCommunicator


class WebsocketCommunicator(ApplicationCommunicator):
Expand Down Expand Up @@ -78,27 +78,33 @@ async def receive_from(self, timeout=1):
"""
response = await self.receive_output(timeout)
# Make sure this is a send message
assert response["type"] == "websocket.send"
assert (
response["type"] == "websocket.send"
), f"Expected type 'websocket.send', but was '{response['type']}'"
# Make sure there's exactly one key in the response
assert ("text" in response) != (
"bytes" in response
), "The response needs exactly one of 'text' or 'bytes'"
# Pull out the right key and typecheck it for our users
if "text" in response:
assert isinstance(response["text"], str), "Text frame payload is not str"
assert isinstance(
response["text"], str
), f"Text frame payload is not str, it is {type(response['text'])}"
return response["text"]
else:
assert isinstance(
response["bytes"], bytes
), "Binary frame payload is not bytes"
), f"Binary frame payload is not bytes, it is {type(response['bytes'])}"
return response["bytes"]

async def receive_json_from(self, timeout=1):
"""
Receives a JSON text frame payload and decodes it
"""
payload = await self.receive_from(timeout)
assert isinstance(payload, str), "JSON data is not a text frame"
assert isinstance(
payload, str
), f"JSON data is not a text frame, it is {type(payload)}"
return json.loads(payload)

async def disconnect(self, code=1000, timeout=1):
Expand Down
3 changes: 2 additions & 1 deletion docs/topics/consumers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,8 @@ callable into an asynchronous coroutine.

If you want to call the Django ORM from an ``AsyncConsumer`` (or any other
asynchronous code), you should use the ``database_sync_to_async`` adapter
instead. See :doc:`/topics/databases` for more.
or use the async versions of the methods (prefixed with ``a``, like ``aget``).
See :doc:`/topics/databases` for more.


Closing Consumers
Expand Down
Loading

0 comments on commit 7e06b02

Please sign in to comment.