diff --git a/multipart/multipart.py b/multipart/multipart.py index ea8ccca..bd836bd 100644 --- a/multipart/multipart.py +++ b/multipart/multipart.py @@ -18,19 +18,20 @@ if TYPE_CHECKING: # pragma: no cover from typing import Callable, Protocol, TypedDict - class QuerystringCallbacks(TypedDict, total=False): + class BaseCallbacks(TypedDict, total=False): + on_end: Callable[[], None] + + class QuerystringCallbacks(BaseCallbacks, total=False): on_field_start: Callable[[], None] on_field_name: Callable[[bytes, int, int], None] on_field_data: Callable[[bytes, int, int], None] on_field_end: Callable[[], None] - on_end: Callable[[], None] - class OctetStreamCallbacks(TypedDict, total=False): + class OctetStreamCallbacks(BaseCallbacks, total=False): on_start: Callable[[], None] on_data: Callable[[bytes, int, int], None] - on_end: Callable[[], None] - class MultipartCallbacks(TypedDict, total=False): + class MultipartCallbacks(BaseCallbacks, total=False): on_part_begin: Callable[[], None] on_part_data: Callable[[bytes, int, int], None] on_part_end: Callable[[], None] @@ -39,7 +40,6 @@ class MultipartCallbacks(TypedDict, total=False): on_header_value: Callable[[bytes, int, int], None] on_header_end: Callable[[], None] on_headers_finished: Callable[[], None] - on_end: Callable[[], None] class FormParserConfig(TypedDict): UPLOAD_DIR: str | None @@ -608,8 +608,9 @@ class BaseParser: performance. """ - def __init__(self) -> None: + def __init__(self, callbacks: BaseCallbacks) -> None: self.logger = logging.getLogger(__name__) + self.callbacks = callbacks def callback(self, name: str, data: bytes | None = None, start: int | None = None, end: int | None = None): """This function calls a provided callback with some data. If the @@ -696,8 +697,7 @@ class OctetStreamParser(BaseParser): """ def __init__(self, callbacks: OctetStreamCallbacks = {}, max_size: float = float("inf")): - super().__init__() - self.callbacks = callbacks + super().__init__(callbacks) self._started = False if not isinstance(max_size, Number) or max_size < 1: @@ -795,12 +795,10 @@ class QuerystringParser(BaseParser): def __init__( self, callbacks: QuerystringCallbacks = {}, strict_parsing: bool = False, max_size: float = float("inf") ) -> None: - super().__init__() + super().__init__(callbacks) self.state = QuerystringState.BEFORE_FIELD self._found_sep = False - self.callbacks = callbacks - # Max-size stuff if not isinstance(max_size, Number) or max_size < 1: raise ValueError("max_size must be a positive number, not %r" % max_size) @@ -1055,12 +1053,10 @@ def __init__( self, boundary: bytes | str, callbacks: MultipartCallbacks = {}, max_size: float = float("inf") ) -> None: # Initialize parser state. - super().__init__() + super().__init__(callbacks) self.state = MultipartState.START self.index = self.flags = 0 - self.callbacks = callbacks - if not isinstance(max_size, Number) or max_size < 1: raise ValueError("max_size must be a positive number, not %r" % max_size) self.max_size = max_size diff --git a/tests/test_multipart.py b/tests/test_multipart.py index 93fd38d..086d32e 100644 --- a/tests/test_multipart.py +++ b/tests/test_multipart.py @@ -290,8 +290,8 @@ def test_handles_rfc_2231(self): class TestBaseParser(unittest.TestCase): def setUp(self): - self.b = BaseParser() - self.b.callbacks = {} + callbacks = {} + self.b = BaseParser(callbacks) def test_callbacks(self): called = 0