Skip to content

Commit

Permalink
feat(typing): type asgi.reader, asgi.structures, asgi.stream (#2297)
Browse files Browse the repository at this point in the history
* typing: type app

* typing: type websocket module

* typing: type asgi.reader, asgi.structures, asgi.stream

---------

Co-authored-by: Vytautas Liuolia <[email protected]>
  • Loading branch information
CaselIT and vytas7 authored Aug 30, 2024
1 parent f36a23e commit 3e32ff7
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 76 deletions.
92 changes: 60 additions & 32 deletions falcon/asgi/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@

"""Buffered ASGI stream reader."""

from __future__ import annotations

import io
from typing import AsyncIterator, List, NoReturn, Optional, Protocol

from falcon.errors import DelimiterError
from falcon.errors import OperationNotAllowed
Expand Down Expand Up @@ -45,7 +48,17 @@ class BufferedReader:
'_source',
]

def __init__(self, source, chunk_size=None):
_buffer: bytes
_buffer_len: int
_buffer_pos: int
_chunk_size: int
_consumed: int
_exhausted: bool
_iteration_started: bool
_max_join_size: int
_source: AsyncIterator[bytes]

def __init__(self, source: AsyncIterator[bytes], chunk_size: Optional[int] = None):
self._source = self._iter_normalized(source)
self._chunk_size = chunk_size or DEFAULT_CHUNK_SIZE
self._max_join_size = self._chunk_size * _MAX_JOIN_CHUNKS
Expand All @@ -57,7 +70,9 @@ def __init__(self, source, chunk_size=None):
self._exhausted = False
self._iteration_started = False

async def _iter_normalized(self, source):
async def _iter_normalized(
self, source: AsyncIterator[bytes]
) -> AsyncIterator[bytes]:
chunk = b''
chunk_size = self._chunk_size

Expand All @@ -77,7 +92,7 @@ async def _iter_normalized(self, source):

self._exhausted = True

async def _iter_with_buffer(self, size_hint=0):
async def _iter_with_buffer(self, size_hint: int = 0) -> AsyncIterator[bytes]:
if self._buffer_len > self._buffer_pos:
if 0 < size_hint < self._buffer_len - self._buffer_pos:
buffer_pos = self._buffer_pos
Expand All @@ -91,7 +106,9 @@ async def _iter_with_buffer(self, size_hint=0):
async for chunk in self._source:
yield chunk

async def _iter_delimited(self, delimiter, size_hint=0):
async def _iter_delimited(
self, delimiter: bytes, size_hint: int = 0
) -> AsyncIterator[bytes]:
delimiter_len_1 = len(delimiter) - 1
if not 0 <= delimiter_len_1 < self._chunk_size:
raise ValueError('delimiter length must be within [1, chunk_size]')
Expand Down Expand Up @@ -152,13 +169,13 @@ async def _iter_delimited(self, delimiter, size_hint=0):

yield self._buffer

async def _consume_delimiter(self, delimiter):
async def _consume_delimiter(self, delimiter: bytes) -> None:
delimiter_len = len(delimiter)
if await self.peek(delimiter_len) != delimiter:
raise DelimiterError('expected delimiter missing')
self._buffer_pos += delimiter_len

def _prepend_buffer(self, chunk):
def _prepend_buffer(self, chunk: bytes) -> None:
if self._buffer_len > self._buffer_pos:
self._buffer = chunk + self._buffer[self._buffer_pos :]
self._buffer_len = len(self._buffer)
Expand All @@ -168,25 +185,25 @@ def _prepend_buffer(self, chunk):

self._buffer_pos = 0

def _trim_buffer(self):
def _trim_buffer(self) -> None:
self._buffer = self._buffer[self._buffer_pos :]
self._buffer_len -= self._buffer_pos
self._buffer_pos = 0

async def _read_from(self, source, size=-1):
async def _read_from(self, source: AsyncIterator[bytes], size: int = -1) -> bytes:
if size == -1 or size is None:
result = io.BytesIO()
result_bytes = io.BytesIO()
async for chunk in source:
result.write(chunk)
return result.getvalue()
result_bytes.write(chunk)
return result_bytes.getvalue()

if size <= 0:
return b''

remaining = size

if size <= self._max_join_size:
result = []
result: List[bytes] = []
async for chunk in source:
chunk_len = len(chunk)
if remaining < chunk_len:
Expand All @@ -203,29 +220,29 @@ async def _read_from(self, source, size=-1):
return result[0] if len(result) == 1 else b''.join(result)

# NOTE(vytas): size > self._max_join_size
result = io.BytesIO()
result_bytes = io.BytesIO()
async for chunk in source:
chunk_len = len(chunk)
if remaining < chunk_len:
result.write(chunk[:remaining])
result_bytes.write(chunk[:remaining])
self._prepend_buffer(chunk[remaining:])
break

result.write(chunk)
result_bytes.write(chunk)
remaining -= chunk_len
if remaining == 0: # pragma: no py39,py310 cover
break

return result.getvalue()
return result_bytes.getvalue()

def delimit(self, delimiter):
def delimit(self, delimiter: bytes) -> BufferedReader: # TODO: should se self
return type(self)(self._iter_delimited(delimiter), chunk_size=self._chunk_size)

# -------------------------------------------------------------------------
# Asynchronous IO interface.
# -------------------------------------------------------------------------

def __aiter__(self):
def __aiter__(self) -> AsyncIterator[bytes]:
if self._iteration_started:
raise OperationNotAllowed('This stream is already being iterated over.')

Expand All @@ -236,10 +253,10 @@ def __aiter__(self):
return self._iter_with_buffer()
return self._source

async def exhaust(self):
async def exhaust(self) -> None:
await self.pipe()

async def peek(self, size=-1):
async def peek(self, size: int = -1) -> bytes:
if size < 0 or size > self._chunk_size:
size = self._chunk_size

Expand All @@ -255,23 +272,28 @@ async def peek(self, size=-1):

return self._buffer[:size]

async def pipe(self, destination=None):
async def pipe(self, destination: Optional[AsyncWritableIO] = None) -> None:
async for chunk in self._iter_with_buffer():
if destination is not None:
await destination.write(chunk)

async def pipe_until(self, delimiter, destination=None, consume_delimiter=False):
async def pipe_until(
self,
delimiter: bytes,
destination: Optional[AsyncWritableIO] = None,
consume_delimiter: bool = False,
) -> None:
async for chunk in self._iter_delimited(delimiter):
if destination is not None:
await destination.write(chunk)

if consume_delimiter:
await self._consume_delimiter(delimiter)

async def read(self, size=-1):
async def read(self, size: int = -1) -> bytes:
return await self._read_from(self._iter_with_buffer(size_hint=size or 0), size)

async def readall(self):
async def readall(self) -> bytes:
"""Read and return all remaining data in the request body.
Warning:
Expand All @@ -286,7 +308,9 @@ async def readall(self):
"""
return await self._read_from(self._iter_with_buffer())

async def read_until(self, delimiter, size=-1, consume_delimiter=False):
async def read_until(
self, delimiter: bytes, size: int = -1, consume_delimiter: bool = False
) -> bytes:
result = await self._read_from(
self._iter_delimited(delimiter, size_hint=size or 0), size
)
Expand All @@ -306,30 +330,34 @@ async def read_until(self, delimiter, size=-1, consume_delimiter=False):
# pass

@property
def eof(self):
def eof(self) -> bool:
"""Whether the stream is at EOF."""
return self._exhausted and self._buffer_len == self._buffer_pos

def fileno(self):
def fileno(self) -> NoReturn:
"""Raise an instance of OSError since a file descriptor is not used."""
raise OSError('This IO object does not use a file descriptor')

def isatty(self):
def isatty(self) -> bool:
"""Return ``False`` always."""
return False

def readable(self):
def readable(self) -> bool:
"""Return ``True`` always."""
return True

def seekable(self):
def seekable(self) -> bool:
"""Return ``False`` always."""
return False

def writable(self):
def writable(self) -> bool:
"""Return ``False`` always."""
return False

def tell(self):
def tell(self) -> int:
"""Return the number of bytes read from the stream so far."""
return self._consumed - (self._buffer_len - self._buffer_pos)


class AsyncWritableIO(Protocol):
async def write(self, data: bytes, /) -> None: ...
Loading

0 comments on commit 3e32ff7

Please sign in to comment.