Skip to content

Commit

Permalink
feat(typing): add type annotation to websocket module (#2295)
Browse files Browse the repository at this point in the history
* typing: type app

* typing: type websocket module

---------

Co-authored-by: Vytautas Liuolia <[email protected]>
  • Loading branch information
CaselIT and vytas7 authored Aug 30, 2024
1 parent 5cb2b89 commit f36a23e
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 55 deletions.
110 changes: 59 additions & 51 deletions falcon/asgi/ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,52 +2,34 @@

import asyncio
import collections
from enum import auto
from enum import Enum
from typing import (
Any,
Deque,
Dict,
Iterable,
Mapping,
Optional,
Union,
)
from typing import Any, Deque, Dict, Iterable, Mapping, Optional, Tuple, Union

from falcon import errors
from falcon import media
from falcon import status_codes
from falcon.asgi_spec import AsgiEvent
from falcon.asgi_spec import AsgiSendMsg
from falcon.asgi_spec import EventType
from falcon.asgi_spec import WSCloseCode
from falcon.constants import WebSocketPayloadType
from falcon.typing import AsgiReceive
from falcon.typing import AsgiSend
from falcon.typing import HeaderList
from falcon.util import misc

_WebSocketState = Enum('_WebSocketState', 'HANDSHAKE ACCEPTED CLOSED')
__all__ = ('WebSocket',)


__all__ = ('WebSocket',)
class _WebSocketState(Enum):
HANDSHAKE = auto()
ACCEPTED = auto()
CLOSED = auto()


class WebSocket:
"""Represents a single WebSocket connection with a client.
Attributes:
ready (bool): ``True`` if the WebSocket connection has been
accepted and the client is still connected, ``False`` otherwise.
unaccepted (bool)): ``True`` if the WebSocket connection has not yet
been accepted, ``False`` otherwise.
closed (bool): ``True`` if the WebSocket connection has been closed
by the server or the client has disconnected.
subprotocols (tuple[str]): The list of subprotocol strings advertised
by the client, or an empty tuple if no subprotocols were
specified.
supports_accept_headers (bool): ``True`` if the ASGI server hosting
the app supports sending headers when accepting the WebSocket
connection, ``False`` otherwise.
"""
"""Represents a single WebSocket connection with a client."""

__slots__ = (
'_asgi_receive',
Expand All @@ -65,6 +47,13 @@ class WebSocket:
'subprotocols',
)

_state: _WebSocketState
_close_code: Optional[int]
subprotocols: Tuple[str, ...]
"""The list of subprotocol strings advertised by the client, or an empty tuple if
no subprotocols were specified.
"""

def __init__(
self,
ver: str,
Expand Down Expand Up @@ -105,35 +94,47 @@ def __init__(

self._close_reasons = default_close_reasons
self._state = _WebSocketState.HANDSHAKE
self._close_code = None # type: Optional[int]
self._close_code = None

@property
def unaccepted(self) -> bool:
"""``True`` if the WebSocket connection has not yet been accepted,
``False`` otherwise.
""" # noqa: D205
return self._state == _WebSocketState.HANDSHAKE

@property
def closed(self) -> bool:
"""``True`` if the WebSocket connection has been closed by the server or the
client has disconnected.
""" # noqa: D205
return (
self._state == _WebSocketState.CLOSED
or self._buffered_receiver.client_disconnected
)

@property
def ready(self) -> bool:
"""``True`` if the WebSocket connection has been accepted and the client is
still connected, ``False`` otherwise.
""" # noqa: D205
return (
self._state == _WebSocketState.ACCEPTED
and not self._buffered_receiver.client_disconnected
)

@property
def supports_accept_headers(self) -> bool:
"""``True`` if the ASGI server hosting the app supports sending headers when
accepting the WebSocket connection, ``False`` otherwise.
""" # noqa: D205
return self._supports_accept_headers

async def accept(
self,
subprotocol: Optional[str] = None,
headers: Optional[Union[Iterable[Iterable[str]], Mapping[str, str]]] = None,
):
headers: Optional[HeaderList] = None,
) -> None:
"""Accept the incoming WebSocket connection.
If, after examining the connection's attributes (headers, advertised
Expand All @@ -154,7 +155,7 @@ async def accept(
client may choose to abandon the connection in this case,
if it does not receive an explicit protocol selection.
headers (Iterable[[str, str]]): An iterable of ``[name: str, value: str]``
headers (HeaderList): An iterable of ``(name: str, value: str)``
two-item iterables, representing a collection of HTTP headers to
include in the handshake response. Both *name* and *value* must
be of type ``str`` and contain only US-ASCII characters.
Expand Down Expand Up @@ -199,13 +200,14 @@ async def accept(
)

header_items = getattr(headers, 'items', None)

if callable(header_items):
headers = header_items()
headers_iterable: Iterable[tuple[str, str]] = header_items()
else:
headers_iterable = headers # type: ignore[assignment]

event['headers'] = parsed_headers = [
(name.lower().encode('ascii'), value.encode('ascii'))
for name, value in headers # type: ignore
for name, value in headers_iterable
]

for name, __ in parsed_headers:
Expand Down Expand Up @@ -348,7 +350,6 @@ async def send_text(self, payload: str) -> None:
"""

self._require_accepted()

# NOTE(kgriffs): We have to check ourselves because some ASGI
# servers are not very strict which can lead to hard-to-debug
# errors.
Expand All @@ -369,14 +370,13 @@ async def send_data(self, payload: Union[bytes, bytearray, memoryview]) -> None:
payload (Union[bytes, bytearray, memoryview]): The binary data to send.
"""

self._require_accepted()
# NOTE(kgriffs): We have to check ourselves because some ASGI
# servers are not very strict which can lead to hard-to-debug
# errors.
if not isinstance(payload, (bytes, bytearray, memoryview)):
raise TypeError('payload must be a byte string')

self._require_accepted()

await self._send(
{
'type': EventType.WS_SEND,
Expand Down Expand Up @@ -464,7 +464,7 @@ async def receive_media(self) -> object:

return self._mh_bin_deserialize(data)

async def _send(self, msg: dict):
async def _send(self, msg: AsgiSendMsg) -> None:
if self._buffered_receiver.client_disconnected:
self._state = _WebSocketState.CLOSED
self._close_code = self._buffered_receiver.client_disconnected_code
Expand All @@ -489,7 +489,7 @@ async def _send(self, msg: dict):
# obscure the traceback.
raise

async def _receive(self) -> dict:
async def _receive(self) -> AsgiEvent:
event = await self._asgi_receive()

event_type = event['type']
Expand All @@ -506,15 +506,15 @@ async def _receive(self) -> dict:

return event

def _require_accepted(self):
def _require_accepted(self) -> None:
if self._state == _WebSocketState.HANDSHAKE:
raise errors.OperationNotAllowed(
'WebSocket connection has not yet been accepted'
)
elif self._state == _WebSocketState.CLOSED:
raise errors.WebSocketDisconnected(self._close_code)

def _translate_webserver_error(self, ex):
def _translate_webserver_error(self, ex: Exception) -> Optional[Exception]:
s = str(ex)

# NOTE(kgriffs): uvicorn or any other server using the "websockets"
Expand Down Expand Up @@ -656,13 +656,20 @@ class _BufferedReceiver:
'client_disconnected_code',
]

def __init__(self, asgi_receive: AsgiReceive, max_queue: int):
_pop_message_waiter: Optional[asyncio.Future[None]]
_put_message_waiter: Optional[asyncio.Future[None]]
_pump_task: Optional[asyncio.Task[None]]
_messages: Deque[AsgiEvent]
client_disconnected: bool
client_disconnected_code: Optional[int]

def __init__(self, asgi_receive: AsgiReceive, max_queue: int) -> None:
self._asgi_receive = asgi_receive
self._max_queue = max_queue

self._loop = asyncio.get_running_loop()

self._messages: Deque[AsgiEvent] = collections.deque()
self._messages = collections.deque()
self._pop_message_waiter = None
self._put_message_waiter = None

Expand All @@ -671,12 +678,12 @@ def __init__(self, asgi_receive: AsgiReceive, max_queue: int):
self.client_disconnected = False
self.client_disconnected_code = None

def start(self):
if not self._pump_task:
def start(self) -> None:
if self._pump_task is None:
self._pump_task = asyncio.create_task(self._pump())

async def stop(self):
if not self._pump_task:
async def stop(self) -> None:
if self._pump_task is None:
return

self._pump_task.cancel()
Expand All @@ -687,13 +694,14 @@ async def stop(self):

self._pump_task = None

async def receive(self):
async def receive(self) -> AsgiEvent:
# NOTE(kgriffs): Since this class is only used internally, we
# use an assertion to mitigate against framework bugs.
#
# receive() may not be called again while another coroutine
# is already waiting for the next message.
assert not self._pop_message_waiter
assert self._pop_message_waiter is None
assert self._pump_task is not None

# NOTE(kgriffs): Wait for a message if none are available. This pattern
# was borrowed from the websockets.protocol module.
Expand Down Expand Up @@ -737,7 +745,7 @@ async def receive(self):

return message

async def _pump(self):
async def _pump(self) -> None:
while not self.client_disconnected:
received_event = await self._asgi_receive()
if received_event['type'] == EventType.WS_DISCONNECT:
Expand Down
8 changes: 6 additions & 2 deletions falcon/constants.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from enum import auto
from enum import Enum
import os
import sys
Expand Down Expand Up @@ -187,5 +188,8 @@
_UNSET = object() # TODO: remove once replaced with missing


WebSocketPayloadType = Enum('WebSocketPayloadType', 'TEXT BINARY')
"""Enum representing the two possible WebSocket payload types."""
class WebSocketPayloadType(Enum):
"""Enum representing the two possible WebSocket payload types."""

TEXT = auto()
BINARY = auto()
8 changes: 7 additions & 1 deletion falcon/testing/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from collections import defaultdict
from collections import deque
import contextlib
from enum import auto
from enum import Enum
import io
import itertools
Expand Down Expand Up @@ -365,7 +366,12 @@ async def collect(self, event: Dict[str, Any]):
__call__ = collect


_WebSocketState = Enum('_WebSocketState', 'CONNECT HANDSHAKE ACCEPTED DENIED CLOSED')
class _WebSocketState(Enum):
CONNECT = auto()
HANDSHAKE = auto()
ACCEPTED = auto()
DENIED = auto()
CLOSED = auto()


class ASGIWebSocketSimulator:
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
"falcon.asgi.reader",
"falcon.asgi.response",
"falcon.asgi.stream",
"falcon.asgi.ws",
"falcon.media.json",
"falcon.media.msgpack",
"falcon.media.multipart",
Expand Down

0 comments on commit f36a23e

Please sign in to comment.