Skip to content

Commit

Permalink
fix(2549): Ensure MultiDict and ImmutableMultiDict copy methods r…
Browse files Browse the repository at this point in the history
…eturn the instance's type (#3009)

* Fix MultiDict and ImmutableMultiDict copy methods

* Fix typing

* Fix typing
  • Loading branch information
provinzkraut authored Jan 22, 2024
1 parent 1cc4b5e commit 1475955
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 11 deletions.
26 changes: 18 additions & 8 deletions litestar/datastructures/multi_dicts.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
from __future__ import annotations

from abc import ABC
from typing import Any, Generator, Generic, Iterable, Mapping, TypeVar
from typing import TYPE_CHECKING, Any, Generator, Generic, Iterable, Mapping, TypeVar

from multidict import MultiDict as BaseMultiDict
from multidict import MultiDictProxy, MultiMapping

from litestar.datastructures.upload_file import UploadFile

if TYPE_CHECKING:
from typing_extensions import Self


__all__ = ("FormMultiDict", "ImmutableMultiDict", "MultiDict", "MultiMixin")


Expand Down Expand Up @@ -40,7 +44,8 @@ class MultiDict(BaseMultiDict[T], MultiMixin[T], Generic[T]):
"""MultiDict, using :class:`MultiDict <multidict.MultiDictProxy>`."""

def __init__(self, args: MultiMapping | Mapping[str, T] | Iterable[tuple[str, T]] | None = None) -> None:
"""Initialize ``MultiDict`` from a`MultiMapping``, :class:`Mapping <typing.Mapping>` or an iterable of tuples.
"""Initialize ``MultiDict`` from a`MultiMapping``,
:class:`Mapping <typing.Mapping>` or an iterable of tuples.
Args:
args: Mapping-like structure to create the ``MultiDict`` from
Expand All @@ -57,30 +62,35 @@ def immutable(self) -> ImmutableMultiDict[T]:
"""
return ImmutableMultiDict[T](self) # pyright: ignore

def copy(self) -> Self:
"""Return a shallow copy"""
return type(self)(list(self.multi_items()))


class ImmutableMultiDict(MultiDictProxy[T], MultiMixin[T], Generic[T]):
"""Immutable MultiDict, using class:`MultiDictProxy <multidict.MultiDictProxy>`."""

def __init__(self, args: MultiMapping | Mapping[str, Any] | Iterable[tuple[str, Any]] | None = None) -> None:
"""Initialize ``ImmutableMultiDict`` from a.
``MultiMapping``, :class:`Mapping <typing.Mapping>` or an iterable of tuples.
"""Initialize ``ImmutableMultiDict`` from a `MultiMapping``,
:class:`Mapping <typing.Mapping>` or an iterable of tuples.
Args:
args: Mapping-like structure to create the ``ImmutableMultiDict`` from
"""
super().__init__(BaseMultiDict(args or {}))

def mutable_copy(self) -> MultiDict[T]:
"""Create a mutable copy as a.
:class:`MultiDict`
"""Create a mutable copy as a :class:`MultiDict`
Returns:
A mutable multi dict
"""
return MultiDict(list(self.multi_items()))

def copy(self) -> Self: # type: ignore[override]
"""Return a shallow copy"""
return type(self)(self.items())


class FormMultiDict(ImmutableMultiDict[Any]):
"""MultiDict for form data."""
Expand Down
14 changes: 11 additions & 3 deletions tests/unit/test_datastructures/test_multi_dicts.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Type, Union
from __future__ import annotations

import pytest
from pytest_mock import MockerFixture
Expand All @@ -8,14 +8,14 @@


@pytest.mark.parametrize("multi_class", [MultiDict, ImmutableMultiDict])
def test_multi_to_dict(multi_class: Type[Union[MultiDict, ImmutableMultiDict]]) -> None:
def test_multi_to_dict(multi_class: type[MultiDict | ImmutableMultiDict]) -> None:
multi = multi_class([("key", "value"), ("key", "value2"), ("key2", "value3")])

assert multi.dict() == {"key": ["value", "value2"], "key2": ["value3"]}


@pytest.mark.parametrize("multi_class", [MultiDict, ImmutableMultiDict])
def test_multi_multi_items(multi_class: Type[Union[MultiDict, ImmutableMultiDict]]) -> None:
def test_multi_multi_items(multi_class: type[MultiDict | ImmutableMultiDict]) -> None:
data = [("key", "value"), ("key", "value2"), ("key2", "value3")]
multi = multi_class(data)

Expand Down Expand Up @@ -47,3 +47,11 @@ async def test_form_multi_dict_close(mocker: MockerFixture) -> None:
await multi.close()

assert close.call_count == 2


@pytest.mark.parametrize("type_", [MultiDict, ImmutableMultiDict])
def test_copy(type_: type[MultiDict | ImmutableMultiDict]) -> None:
d = type_([("foo", "bar"), ("foo", "baz")])
copy = d.copy()
assert set(d.multi_items()) == set(copy.multi_items())
assert isinstance(d, type_)

0 comments on commit 1475955

Please sign in to comment.