Skip to content

Commit

Permalink
Fix data pattern with list value
Browse files Browse the repository at this point in the history
  • Loading branch information
lundberg committed Apr 2, 2024
1 parent 366dd0b commit 9118255
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 13 deletions.
16 changes: 12 additions & 4 deletions respx/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,9 +548,17 @@ class Data(MultiItemsMixin, Pattern):
key = "data"
value: MultiItems

def clean(self, value: Dict) -> MultiItems:
def _normalize_value(self, value: Any) -> Union[str, List[str]]:
if value is None:
return ""
elif isinstance(value, (tuple, list)):
return [str(v) for v in value]
else:
return str(value)

def clean(self, value: Dict[str, Any]) -> MultiItems:
return MultiItems(
(key, "" if value is None else str(value)) for key, value in value.items()
(key, self._normalize_value(value)) for key, value in value.items()
)

def parse(self, request: httpx.Request) -> Any:
Expand All @@ -563,7 +571,7 @@ class Files(MultiItemsMixin, Pattern):
key = "files"
value: MultiItems

def _normalize_file_value(self, value: FileTypes) -> Tuple[Any, Any]:
def _normalize_file_value(self, value: FileTypes) -> Tuple[Tuple[Any, Any]]:
# Mimic httpx `FileField` to normalize `files` kwarg to shortest tuple style
if isinstance(value, tuple):
filename, fileobj = value[:2]
Expand All @@ -580,7 +588,7 @@ def _normalize_file_value(self, value: FileTypes) -> Tuple[Any, Any]:
elif isinstance(fileobj, str):
fileobj = fileobj.encode()

return filename, fileobj
return ((filename, fileobj),)

def clean(self, value: RequestFiles) -> MultiItems:
if isinstance(value, Mapping):
Expand Down
29 changes: 20 additions & 9 deletions respx/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import email
from collections import defaultdict
from datetime import datetime
from email.message import Message
from typing import (
Any,
Dict,
Iterable,
List,
NamedTuple,
Optional,
Expand All @@ -23,15 +25,24 @@
import httpx


class MultiItems(dict):
class MultiItems(defaultdict):
def __init__(self, values: Optional[Iterable[Tuple[str, Any]]] = None) -> None:
super().__init__(tuple)
if values is not None:
for key, value in values:
if isinstance(value, (tuple, list)):
self[key] += tuple(value) # Convert list to tuple and extend
else:
self[key] += (value,) # Extend with value

def get_list(self, key: str) -> List[Any]:
try:
return [self[key]]
except KeyError: # pragma: no cover
return []
return list(self[key])

def multi_items(self) -> List[Tuple[str, str]]:
return [(key, value) for key, values in self.items() for value in values]

def multi_items(self) -> List[Tuple[str, Any]]:
return list(self.items())
def append(self, key: str, value: Any) -> None:
self[key] += (value,)


def _parse_multipart_form_data(
Expand All @@ -55,10 +66,10 @@ def _parse_multipart_form_data(
assert isinstance(value, bytes)
if content_type.startswith("text/") and filename is None:
# Text field
data[name] = value.decode(payload.get_content_charset() or "utf-8")
data.append(name, value.decode(payload.get_content_charset() or "utf-8"))
else:
# File field
files[name] = filename, value
files.append(name, (filename, value))

return data, files

Expand Down
6 changes: 6 additions & 0 deletions tests/test_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,12 @@ def test_content_pattern(lookup, content, expected):
None,
True,
),
(
Lookup.EQUAL,
{"foo": "bar", "ham": ["spam", "egg"]},
None,
True,
),
(
Lookup.EQUAL,
{"foo": "bar", "ham": "spam"},
Expand Down

0 comments on commit 9118255

Please sign in to comment.