Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Websocket sansio implementataion #2060

Closed
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
64 commits
Select commit Hold shift + click to select a range
ca826db
Implement websockets_sans_impl.py
gourav-kandoria Jul 26, 2023
559617b
add surrogate errors in decode
gourav-kandoria Jul 27, 2023
56d2152
fix lint issues
gourav-kandoria Jul 27, 2023
d12e72a
fix mypy failing issues
gourav-kandoria Jul 27, 2023
f24527b
fix lint issues
gourav-kandoria Jul 27, 2023
ba972e0
fix typing issues
gourav-kandoria Jul 27, 2023
b81f762
Fix extension tests failing
gourav-kandoria Jul 27, 2023
0f59f77
Fix extension tests failing
gourav-kandoria Jul 27, 2023
38e1629
correct types import
gourav-kandoria Jul 29, 2023
29d2d09
correct types import and mypy issues
gourav-kandoria Jul 29, 2023
28f2714
fix typo
gourav-kandoria Jul 29, 2023
10933b6
Merge remote-tracking branch 'upstream/master' into websocket-sansio-…
gourav-kandoria Aug 1, 2023
3a72504
Replace ServerConnection with ServerProtocol due to upgradation of we…
gourav-kandoria Aug 1, 2023
1239297
Merge remote-tracking branch 'upstream/master' into websocket-sansio-…
gourav-kandoria Aug 8, 2023
39e3c33
Remove conditional on imports
Kludex Aug 27, 2023
931e78e
Fix typos, and small details
Kludex Aug 27, 2023
3d57661
Refactor small things
Kludex Aug 28, 2023
d76cdc6
Fix linter
Kludex Aug 28, 2023
ca6f63b
Merge branch 'websocket-sansio-implementataion' of https://github.com…
gourav-kandoria Aug 29, 2023
d246cc4
Merge branch 'master' of https://github.com/encode/uvicorn into webso…
gourav-kandoria Aug 29, 2023
aed00c8
Add tests for websocket server for receiving multiple frames
gourav-kandoria Aug 29, 2023
808f951
Remove checking of PONG event after receiving data
gourav-kandoria Aug 29, 2023
803100c
Revert "Remove checking of PONG event after receiving data"
gourav-kandoria Aug 29, 2023
7519e6b
"Remove checking of PONG event after receiving data"
gourav-kandoria Aug 29, 2023
87ad36a
Create WSType on the test suite
Kludex Aug 30, 2023
1048c18
Add WebSocketsSansIOProtocol to the CLI
Kludex Aug 30, 2023
37a686f
Make changes for testing payload max_size limit
gourav-kandoria Sep 1, 2023
4f76f62
Make changes for testing payload max_size limit
gourav-kandoria Sep 4, 2023
cbe36ba
fix lint issue
gourav-kandoria Sep 4, 2023
348b6ac
increase msg size from 11 to 32
gourav-kandoria Sep 4, 2023
48b1d5f
increase client max_limit
gourav-kandoria Sep 5, 2023
12bb2d2
Empty-Commit-to-trigger-pipeline
gourav-kandoria Sep 5, 2023
498eaf5
Implement websockets_sans_impl.py
gourav-kandoria Jul 26, 2023
e378770
add surrogate errors in decode
gourav-kandoria Jul 27, 2023
3ce1611
fix lint issues
gourav-kandoria Jul 27, 2023
4f601dc
fix mypy failing issues
gourav-kandoria Jul 27, 2023
d48f8c1
fix lint issues
gourav-kandoria Jul 27, 2023
ab12969
fix typing issues
gourav-kandoria Jul 27, 2023
65057b8
Fix extension tests failing
gourav-kandoria Jul 27, 2023
b81bd5a
Fix extension tests failing
gourav-kandoria Jul 27, 2023
63e6f68
correct types import
gourav-kandoria Jul 29, 2023
b89d732
correct types import and mypy issues
gourav-kandoria Jul 29, 2023
526fe56
fix typo
gourav-kandoria Jul 29, 2023
0699a7e
Replace ServerConnection with ServerProtocol due to upgradation of we…
gourav-kandoria Aug 1, 2023
82f3f6e
Remove conditional on imports
Kludex Aug 27, 2023
7a90b8b
Fix typos, and small details
Kludex Aug 27, 2023
e72cd54
Refactor small things
Kludex Aug 28, 2023
d9a4ea0
Fix linter
Kludex Aug 28, 2023
252bdc1
Add tests for websocket server for receiving multiple frames
gourav-kandoria Aug 29, 2023
9ff1a2e
Remove checking of PONG event after receiving data
gourav-kandoria Aug 29, 2023
bc35b4f
Revert "Remove checking of PONG event after receiving data"
gourav-kandoria Aug 29, 2023
b82a8ce
"Remove checking of PONG event after receiving data"
gourav-kandoria Aug 29, 2023
09d3072
Create WSType on the test suite
Kludex Aug 30, 2023
5a22d00
Add WebSocketsSansIOProtocol to the CLI
Kludex Aug 30, 2023
5f5f5f4
Make changes for testing payload max_size limit
gourav-kandoria Sep 1, 2023
8dcd505
Make changes for testing payload max_size limit
gourav-kandoria Sep 4, 2023
3f4eecb
fix lint issue
gourav-kandoria Sep 4, 2023
60460f9
increase msg size from 11 to 32
gourav-kandoria Sep 4, 2023
78c6941
increase client max_limit
gourav-kandoria Sep 5, 2023
0ade3d4
Empty-Commit-to-trigger-pipeline
gourav-kandoria Sep 5, 2023
db31c56
Use WSProtocolType
Kludex Dec 26, 2023
bf00ada
Use future annotations on websocket sansio implementation
Kludex Dec 26, 2023
64d6eb7
WIP websockets denial response extension
Kludex Dec 26, 2023
064a7fc
Merge branch 'websocket-sansio-implementataion' of https://github.com…
gourav-kandoria Dec 26, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ def unused_tcp_port() -> int:
),
),
"uvicorn.protocols.websockets.websockets_impl:WebSocketProtocol",
"uvicorn.protocols.websockets.websockets_sansio_impl:WebSocketsSansIOProtocol",
]
)
def ws_protocol_cls(request: pytest.FixtureRequest):
Expand Down
389 changes: 389 additions & 0 deletions uvicorn/protocols/websockets/websockets_sansio_impl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,389 @@
import asyncio
import logging
import sys
import typing
from asyncio.transports import BaseTransport, Transport
from http import HTTPStatus
from urllib.parse import unquote

import websockets
from websockets.extensions.permessage_deflate import ServerPerMessageDeflateFactory
from websockets.frames import Frame
from websockets.http11 import Request, Response
from websockets.server import ServerProtocol

from uvicorn.config import Config
from uvicorn.logging import TRACE_LOG_LEVEL
from uvicorn.protocols.utils import (
get_local_addr,
get_path_with_query_string,
get_remote_addr,
is_ssl,
)
from uvicorn.server import ServerState

if sys.version_info < (3, 8):
from typing_extensions import Literal
Kludex marked this conversation as resolved.
Show resolved Hide resolved
else:
from typing import Literal

if typing.TYPE_CHECKING:
from uvicorn._types import (
ASGIReceiveEvent,
ASGISendEvent,
WebSocketAcceptEvent,
WebSocketCloseEvent,
WebSocketDisconnectEvent,
WebSocketReceiveEvent,
WebSocketScope,
WebSocketSendEvent,
)


class WebSocketsSansIOProtocol(asyncio.Protocol):
def __init__(
self,
config: Config,
server_state: ServerState,
app_state: typing.Dict[str, typing.Any],
_loop: typing.Optional[asyncio.AbstractEventLoop] = None,
) -> None:
if not config.loaded:
config.load()
Kludex marked this conversation as resolved.
Show resolved Hide resolved

self.config = config
self.app = config.loaded_app
self.loop = _loop or asyncio.get_event_loop()
self.logger = logging.getLogger("uvicorn.error")
self.root_path = config.root_path
self.app_state = app_state

# Shared server state
self.connections = server_state.connections
self.tasks = server_state.tasks
self.default_headers = server_state.default_headers

# Connection state
self.transport: asyncio.Transport = None # type: ignore[assignment]
self.server: typing.Optional[typing.Tuple[str, int]] = None
self.client: typing.Optional[typing.Tuple[str, int]] = None
self.scheme: Literal["wss", "ws"] = None # type: ignore[assignment]

# WebSocket state
self.queue: asyncio.Queue["ASGIReceiveEvent"] = asyncio.Queue()
self.handshake_initiated = False
self.handshake_complete = False
self.close_sent = False

extensions = []
if self.config.ws_per_message_deflate:
extensions.append(ServerPerMessageDeflateFactory())
self.conn = ServerProtocol(extensions=extensions)
self.request: Request
self.response: Response
self.curr_msg_data_type: str

self.read_paused = False
self.writable = asyncio.Event()
self.writable.set()

# Buffers
self.bytes: "bytes" = b""

def connection_made(self, transport: BaseTransport) -> None:
"""Called when a connection is made."""
transport = typing.cast(Transport, transport)
self.connections.add(self)
self.transport = transport
self.server = get_local_addr(transport)
self.client = get_remote_addr(transport)
self.scheme = "wss" if is_ssl(transport) else "ws"

if self.logger.level <= TRACE_LOG_LEVEL:
prefix = "%s:%d - " % self.client if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection made", prefix)

def connection_lost(self, exc: typing.Optional[Exception]) -> None:
self.connections.remove(self)
if self.logger.level <= TRACE_LOG_LEVEL:
prefix = "%s:%d - " % self.client if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection lost", prefix)
if self.handshake_initiated and not self.close_sent:
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006})

def data_received(self, data: bytes) -> None:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

https://websockets.readthedocs.io/en/stable/howto/sansio.html
if I'm following this right after receiving data you should call receive_eof()

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@valentin994 @Kludex we don't have to call receive_eof() after receiving data. Because, protocol ensures that (https://docs.python.org/3/library/asyncio-protocol.html#asyncio.Protocol.data_received) data_received will only be called only when the data is there. In case when eof is received from client side, the transport will close and connection_lost will be called.

try:
self.conn.receive_data(data)
except Exception:
self.logger.exception("Exception in ASGI server")
self.transport.close()
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This lines(117-119) are not being covered. Because, I guess in the test suite there not any test which after connection establishment sends data which causes receive_data to throw exception such as data which don't follow websocket spec.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We still need to test it. It's unlikely that we want Exception there as well.

self.handle_events()

def shutdown(self) -> None:
if not self.transport.is_closing():
if self.handshake_complete:
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1012})
self.close_sent = True
self.conn.send_close(1012)
output = self.conn.data_to_send()
self.transport.writelines(output)
elif self.handshake_initiated:
self.send_500_response()
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006})
self.transport.close()

def handle_events(self) -> None:
for event in self.conn.events_received():
if isinstance(event, Request):
self.handle_connect(event)
if isinstance(event, Frame):
if event.opcode == websockets.frames.Opcode.CONT:
self.handle_cont(event)
Kludex marked this conversation as resolved.
Show resolved Hide resolved
elif event.opcode == websockets.frames.Opcode.TEXT:
self.handle_text(event)
elif event.opcode == websockets.frames.Opcode.BINARY:
self.handle_bytes(event)
elif event.opcode == websockets.frames.Opcode.PING:
self.handle_ping(event)
elif event.opcode == websockets.frames.Opcode.PONG:
self.handle_pong(event)
Kludex marked this conversation as resolved.
Show resolved Hide resolved
elif event.opcode == websockets.frames.Opcode.CLOSE:
self.handle_close(event)
Copy link
Author

@gourav-kandoria gourav-kandoria Aug 28, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Kludex @aaugustin Could you please help with this?
one scenario, where we can expect close frame is when the client sends it.
What other scenarios are there where we can expect it?
Can we expect it in scenarios where the we received unexpected data
and how can we have access to the close code, so that same code can be passed in disconnect event to asgi app


# Event handlers

def handle_connect(self, event: Request) -> None:
self.request = event
self.response = self.conn.accept(event)
self.handshake_initiated = True
# if status_code is not 101 return response
if self.response.status_code != 101:
self.handshake_complete = True
self.close_sent = True
self.conn.send_response(self.response)
output = self.conn.data_to_send()
self.transport.writelines(output)
self.transport.close()
return

headers = [
(key.encode("ascii"), value.encode("ascii", errors="surrogateescape"))
for key, value in event.headers.raw_items()
]
raw_path, _, query_string = event.path.partition("?")
self.scope: "WebSocketScope" = {
"type": "websocket",
"asgi": {"version": self.config.asgi_version, "spec_version": "2.3"},
"http_version": "1.1",
"scheme": self.scheme,
"server": self.server,
"client": self.client,
"root_path": self.root_path,
"path": unquote(raw_path),
"raw_path": raw_path.encode("ascii"),
"query_string": query_string.encode("ascii"),
"headers": headers,
"subprotocols": event.headers.get_all("Sec-WebSocket-Protocol"),
"state": self.app_state.copy(),
}
self.queue.put_nowait({"type": "websocket.connect"})
task = self.loop.create_task(self.run_asgi())
task.add_done_callback(self.on_task_complete)
self.tasks.add(task)

def handle_cont(self, event: Frame) -> None:
self.bytes += event.data
if event.fin:
self.send_receive_event_to_app()
Kludex marked this conversation as resolved.
Show resolved Hide resolved

def handle_text(self, event: Frame) -> None:
self.bytes = event.data
self.curr_msg_data_type = "text"
if event.fin:
self.send_receive_event_to_app()

def handle_bytes(self, event: Frame) -> None:
self.bytes = event.data
self.curr_msg_data_type = "bytes"
if event.fin:
self.send_receive_event_to_app()

def send_receive_event_to_app(self) -> None:
data: typing.Union[str, bytes]
if self.curr_msg_data_type == "text":
data = self.bytes.decode()
else:
data = self.bytes

msg: "WebSocketReceiveEvent" = {
"type": "websocket.receive",
self.curr_msg_data_type: data, # type: ignore[misc]
}
self.queue.put_nowait(msg)
if not self.read_paused:
self.read_paused = True
self.transport.pause_reading()

def handle_ping(self, event: Frame) -> None:
output = self.conn.data_to_send()
self.transport.writelines(output)

def handle_pong(self, event: Frame) -> None:
pass
Kludex marked this conversation as resolved.
Show resolved Hide resolved

def handle_close(self, event: Frame) -> None:
if not self.close_sent and not self.transport.is_closing():
disconnect_event: "WebSocketDisconnectEvent" = {
"type": "websocket.disconnect",
"code": self.conn.close_rcvd.code, # type: ignore[union-attr]
}
self.queue.put_nowait(disconnect_event)
output = self.conn.data_to_send()
self.transport.writelines(output)
self.close_sent = True
self.transport.close()

def on_task_complete(self, task: asyncio.Task) -> None:
self.tasks.discard(task)

async def run_asgi(self) -> None:
try:
result = await self.app(self.scope, self.receive, self.send)
except BaseException:
self.logger.exception("Exception in ASGI application\n")
if not self.handshake_complete:
self.send_500_response()
self.transport.close()
else:
if not self.handshake_complete:
msg = "ASGI callable returned without completing handshake."
self.logger.error(msg)
self.send_500_response()
self.transport.close()
elif result is not None:
msg = "ASGI callable should return None, but returned '%s'."
self.logger.error(msg, result)
self.transport.close()

def send_500_response(self) -> None:
msg = b"Internal Server Error"
content = [
b"HTTP/1.1 500 Internal Server Error\r\n"
b"content-type: text/plain; charset=utf-8\r\n",
b"content-length: " + str(len(msg)).encode("ascii") + b"\r\n",
b"connection: close\r\n",
b"\r\n",
msg,
]
self.transport.write(b"".join(content))

async def send(self, message: "ASGISendEvent") -> None:
await self.writable.wait()

message_type = message["type"]

if not self.handshake_complete:
if message_type == "websocket.accept" and not self.transport.is_closing():
message = typing.cast("WebSocketAcceptEvent", message)
self.logger.info(
'%s - "WebSocket %s" [accepted]',
self.scope["client"],
get_path_with_query_string(self.scope),
)
headers = [
(
key.decode("ascii"),
value.decode("ascii", errors="surrogateescape"),
)
for key, value in self.default_headers
+ list(message.get("headers", []))
]

accepted_subprotocol = message.get("subprotocol")
if accepted_subprotocol:
headers.append(("Sec-WebSocket-Protocol", accepted_subprotocol))

self.handshake_complete = True
self.response.headers.update(headers)
self.conn.send_response(self.response)
output = self.conn.data_to_send()
self.transport.writelines(output)

elif message_type == "websocket.close" and not self.transport.is_closing():
message = typing.cast("WebSocketCloseEvent", message)
self.queue.put_nowait(
{
"type": "websocket.disconnect",
"code": message.get("code", 1000) or 1000,
}
)
self.logger.info(
'%s - "WebSocket %s" 403',
self.scope["client"],
get_path_with_query_string(self.scope),
)
extra_headers = [
(
key.decode("ascii"),
value.decode("ascii", errors="surrogateescape"),
)
for key, value in self.default_headers
]

response = self.conn.reject(
HTTPStatus.FORBIDDEN, message.get("reason", "") or ""
)
response.headers.update(extra_headers)
self.conn.send_response(response)
output = self.conn.data_to_send()
self.close_sent = True
self.hankshake_complete = True
self.transport.writelines(output)
self.transport.close()

else:
msg = (
"Expected ASGI message 'websocket.accept' or 'websocket.close', "
"but got '%s'."
)
raise RuntimeError(msg % message_type)

elif not self.close_sent:
if message_type == "websocket.send" and not self.transport.is_closing():
message = typing.cast("WebSocketSendEvent", message)
bytes_data = message.get("bytes")
text_data = message.get("text")
if text_data:
self.conn.send_text(text_data.encode())
elif bytes_data:
self.conn.send_binary(bytes_data)
output = self.conn.data_to_send()
self.transport.writelines(output)

elif message_type == "websocket.close" and not self.transport.is_closing():
message = typing.cast("WebSocketCloseEvent", message)
code = message.get("code", 1000)
reason = message.get("reason", "") or ""
self.queue.put_nowait({"type": "websocket.disconnect", "code": code})
self.conn.send_close(code, reason)
output = self.conn.data_to_send()
self.transport.writelines(output)
self.close_sent = True
self.transport.close()
else:
msg = (
"Expected ASGI message 'websocket.send' or 'websocket.close',"
" but got '%s'."
)
raise RuntimeError(msg % message_type)

else:
msg = "Unexpected ASGI message '%s', after sending 'websocket.close'."
raise RuntimeError(msg % message_type)

async def receive(self) -> "ASGIReceiveEvent":
message = await self.queue.get()
if self.read_paused and self.queue.empty():
self.read_paused = False
self.transport.resume_reading()
return message
Loading