Skip to content

Commit

Permalink
Preserve _keep_empty in copying and encoding.
Browse files Browse the repository at this point in the history
  • Loading branch information
serhiy-storchaka committed Dec 5, 2024
1 parent eaa9ce6 commit e5c31dd
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 49 deletions.
123 changes: 79 additions & 44 deletions Lib/test/test_urlparse.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import copy
import functools
import sys
import unicodedata
import unittest
import urllib.parse
from urllib.parse import urlparse, urlsplit, urlunparse, urlunsplit
from urllib.parse import urldefrag, urlparse, urlsplit, urlunparse, urlunsplit

RFC1808_BASE = "http://a/b/c/d;p?q#f"
RFC2396_BASE = "http://a/b/c/d;p?q"
Expand Down Expand Up @@ -391,14 +392,14 @@ def checkJoin(self, base, relurl, expected, *, relroundtrip=True):
self.assertEqual(urllib.parse.urljoin(baseb, relurlb), expectedb)

if relroundtrip:
relurl2 = urllib.parse.urlunsplit(urllib.parse.urlsplit(relurl))
relurl2 = urlunsplit(urlsplit(relurl))
self.assertEqual(urllib.parse.urljoin(base, relurl2), expected)
relurlb2 = urllib.parse.urlunsplit(urllib.parse.urlsplit(relurlb))
relurlb2 = urlunsplit(urlsplit(relurlb))
self.assertEqual(urllib.parse.urljoin(baseb, relurlb2), expectedb)

relurl3 = urllib.parse.urlunsplit(urllib.parse.urlsplit(relurl, allow_none=True))
relurl3 = urlunsplit(urlsplit(relurl, allow_none=True))
self.assertEqual(urllib.parse.urljoin(base, relurl3), expected)
relurlb3 = urllib.parse.urlunsplit(urllib.parse.urlsplit(relurlb, allow_none=True))
relurlb3 = urlunsplit(urlsplit(relurlb, allow_none=True))
self.assertEqual(urllib.parse.urljoin(baseb, relurlb3), expectedb)

def test_unparse_parse(self):
Expand Down Expand Up @@ -458,9 +459,9 @@ def test_RFC1808(self):

def test_RFC2368(self):
# Issue 11467: path that starts with a number is not parsed correctly
self.assertEqual(urllib.parse.urlparse('mailto:[email protected]'),
self.assertEqual(urlparse('mailto:[email protected]'),
('mailto', '', '[email protected]', '', '', ''))
self.assertEqual(urllib.parse.urlparse('mailto:[email protected]', allow_none=True),
self.assertEqual(urlparse('mailto:[email protected]', allow_none=True),
('mailto', None, '[email protected]', None, None, None))

def test_RFC2396(self):
Expand Down Expand Up @@ -1119,50 +1120,50 @@ def test_withoutscheme(self, allow_none):
# RFC 1808 specifies that netloc should start with //, urlparse expects
# the same, otherwise it classifies the portion of url as path.
none = None if allow_none else ''
self.assertEqual(urllib.parse.urlparse("path", allow_none=allow_none),
self.assertEqual(urlparse("path", allow_none=allow_none),
(none, none, 'path', none, none, none))
self.assertEqual(urllib.parse.urlparse("//www.python.org:80", allow_none=allow_none),
self.assertEqual(urlparse("//www.python.org:80", allow_none=allow_none),
(none, 'www.python.org:80', '', none, none, none))
self.assertEqual(urllib.parse.urlparse("http://www.python.org:80", allow_none=allow_none),
self.assertEqual(urlparse("http://www.python.org:80", allow_none=allow_none),
('http', 'www.python.org:80', '', none, none, none))
# Repeat for bytes input
none = None if allow_none else b''
self.assertEqual(urllib.parse.urlparse(b"path", allow_none=allow_none),
self.assertEqual(urlparse(b"path", allow_none=allow_none),
(none, none, b'path', none, none, none))
self.assertEqual(urllib.parse.urlparse(b"//www.python.org:80", allow_none=allow_none),
self.assertEqual(urlparse(b"//www.python.org:80", allow_none=allow_none),
(none, b'www.python.org:80', b'', none, none, none))
self.assertEqual(urllib.parse.urlparse(b"http://www.python.org:80", allow_none=allow_none),
self.assertEqual(urlparse(b"http://www.python.org:80", allow_none=allow_none),
(b'http', b'www.python.org:80', b'', none, none, none))

@parametrise_allow_none
def test_portseparator(self, allow_none):
# Issue 754016 makes changes for port separator ':' from scheme separator
none = None if allow_none else ''
self.assertEqual(urllib.parse.urlparse("http:80", allow_none=allow_none),
self.assertEqual(urlparse("http:80", allow_none=allow_none),
('http', none, '80', none, none, none))
self.assertEqual(urllib.parse.urlparse("https:80", allow_none=allow_none),
self.assertEqual(urlparse("https:80", allow_none=allow_none),
('https', none, '80', none, none, none))
self.assertEqual(urllib.parse.urlparse("path:80", allow_none=allow_none),
self.assertEqual(urlparse("path:80", allow_none=allow_none),
('path', none, '80', none, none, none))
self.assertEqual(urllib.parse.urlparse("http:", allow_none=allow_none),
self.assertEqual(urlparse("http:", allow_none=allow_none),
('http', none, '', none, none, none))
self.assertEqual(urllib.parse.urlparse("https:", allow_none=allow_none),
self.assertEqual(urlparse("https:", allow_none=allow_none),
('https', none, '', none, none, none))
self.assertEqual(urllib.parse.urlparse("http://www.python.org:80", allow_none=allow_none),
self.assertEqual(urlparse("http://www.python.org:80", allow_none=allow_none),
('http', 'www.python.org:80', '', none, none, none))
# As usual, need to check bytes input as well
none = None if allow_none else b''
self.assertEqual(urllib.parse.urlparse(b"http:80", allow_none=allow_none),
self.assertEqual(urlparse(b"http:80", allow_none=allow_none),
(b'http', none, b'80', none, none, none))
self.assertEqual(urllib.parse.urlparse(b"https:80", allow_none=allow_none),
self.assertEqual(urlparse(b"https:80", allow_none=allow_none),
(b'https', none, b'80', none, none, none))
self.assertEqual(urllib.parse.urlparse(b"path:80", allow_none=allow_none),
self.assertEqual(urlparse(b"path:80", allow_none=allow_none),
(b'path', none, b'80', none, none, none))
self.assertEqual(urllib.parse.urlparse(b"http:", allow_none=allow_none),
self.assertEqual(urlparse(b"http:", allow_none=allow_none),
(b'http', none, b'', none, none, none))
self.assertEqual(urllib.parse.urlparse(b"https:", allow_none=allow_none),
self.assertEqual(urlparse(b"https:", allow_none=allow_none),
(b'https', none, b'', none, none, none))
self.assertEqual(urllib.parse.urlparse(b"http://www.python.org:80", allow_none=allow_none),
self.assertEqual(urlparse(b"http://www.python.org:80", allow_none=allow_none),
(b'http', b'www.python.org:80', b'', none, none, none))

def test_usingsys(self):
Expand All @@ -1173,24 +1174,24 @@ def test_usingsys(self):
def test_anyscheme(self, allow_none):
# Issue 7904: s3://foo.com/stuff has netloc "foo.com".
none = None if allow_none else ''
self.assertEqual(urllib.parse.urlparse("s3://foo.com/stuff", allow_none=allow_none),
self.assertEqual(urlparse("s3://foo.com/stuff", allow_none=allow_none),
('s3', 'foo.com', '/stuff', none, none, none))
self.assertEqual(urllib.parse.urlparse("x-newscheme://foo.com/stuff", allow_none=allow_none),
self.assertEqual(urlparse("x-newscheme://foo.com/stuff", allow_none=allow_none),
('x-newscheme', 'foo.com', '/stuff', none, none, none))
self.assertEqual(urllib.parse.urlparse("x-newscheme://foo.com/stuff?query#fragment", allow_none=allow_none),
self.assertEqual(urlparse("x-newscheme://foo.com/stuff?query#fragment", allow_none=allow_none),
('x-newscheme', 'foo.com', '/stuff', none, 'query', 'fragment'))
self.assertEqual(urllib.parse.urlparse("x-newscheme://foo.com/stuff?query", allow_none=allow_none),
self.assertEqual(urlparse("x-newscheme://foo.com/stuff?query", allow_none=allow_none),
('x-newscheme', 'foo.com', '/stuff', none, 'query', none))

# And for bytes...
none = None if allow_none else b''
self.assertEqual(urllib.parse.urlparse(b"s3://foo.com/stuff", allow_none=allow_none),
self.assertEqual(urlparse(b"s3://foo.com/stuff", allow_none=allow_none),
(b's3', b'foo.com', b'/stuff', none, none, none))
self.assertEqual(urllib.parse.urlparse(b"x-newscheme://foo.com/stuff", allow_none=allow_none),
self.assertEqual(urlparse(b"x-newscheme://foo.com/stuff", allow_none=allow_none),
(b'x-newscheme', b'foo.com', b'/stuff', none, none, none))
self.assertEqual(urllib.parse.urlparse(b"x-newscheme://foo.com/stuff?query#fragment", allow_none=allow_none),
self.assertEqual(urlparse(b"x-newscheme://foo.com/stuff?query#fragment", allow_none=allow_none),
(b'x-newscheme', b'foo.com', b'/stuff', none, b'query', b'fragment'))
self.assertEqual(urllib.parse.urlparse(b"x-newscheme://foo.com/stuff?query", allow_none=allow_none),
self.assertEqual(urlparse(b"x-newscheme://foo.com/stuff?query", allow_none=allow_none),
(b'x-newscheme', b'foo.com', b'/stuff', none, b'query', none))

def test_default_scheme(self):
Expand Down Expand Up @@ -1274,12 +1275,10 @@ def test_mixed_types_rejected(self):
with self.assertRaisesRegex(TypeError, "Cannot mix str"):
urllib.parse.urljoin(b"http://python.org", "http://python.org")

def _check_result_type(self, str_type):
num_args = len(str_type._fields)
def _check_result_type(self, str_type, str_args):
bytes_type = str_type._encoded_counterpart
self.assertIs(bytes_type._decoded_counterpart, str_type)
str_args = ('',) * num_args
bytes_args = (b'',) * num_args
bytes_args = tuple(self._encode(s) for s in str_args)
str_result = str_type(*str_args)
bytes_result = bytes_type(*bytes_args)
encoding = 'ascii'
Expand All @@ -1298,16 +1297,52 @@ def _check_result_type(self, str_type):
self.assertEqual(str_result.encode(encoding), bytes_result)
self.assertEqual(str_result.encode(encoding, errors), bytes_args)
self.assertEqual(str_result.encode(encoding, errors), bytes_result)
for result in str_result, bytes_result:
self.assertEqual(copy.copy(result), result)
self.assertEqual(copy.deepcopy(result), result)
self.assertEqual(copy.replace(result), result)
self.assertEqual(result._replace(), result)

def test_result_pairs(self):
    """Check encoding/decoding round-trips for each str/bytes result pair.

    Each result type is exercised twice: once with all-empty-string
    components and once with None components (the allow_none=True
    representation), since _check_result_type now takes the component
    tuple explicitly as its second argument.
    """
    # Check encoding and decoding between result pairs
    self._check_result_type(urllib.parse.DefragResult, ('', ''))
    self._check_result_type(urllib.parse.DefragResult, ('', None))
    self._check_result_type(urllib.parse.SplitResult, ('', '', '', '', ''))
    self._check_result_type(urllib.parse.SplitResult, (None, None, '', None, None))
    self._check_result_type(urllib.parse.ParseResult, ('', '', '', '', '', ''))
    self._check_result_type(urllib.parse.ParseResult, (None, None, '', None, None, None))

def test_result_encoding_decoding(self):
    """Verify that .encode()/.decode() on result objects round-trip.

    Uses a URL with present-but-empty query and fragment ('?#') so the
    private _keep_empty flag matters: after converting between the str
    and bytes result types, geturl() must still reproduce the original
    URL instead of dropping the empty components.
    """
    def check(str_result, bytes_result):
        # str -> bytes conversion must match the directly-parsed bytes
        # result, both structurally and when re-serialized.
        self.assertEqual(str_result.encode(), bytes_result)
        self.assertEqual(str_result.encode().geturl(), bytes_result.geturl())
        # bytes -> str conversion, same two checks in the other direction.
        self.assertEqual(bytes_result.decode(), str_result)
        self.assertEqual(bytes_result.decode().geturl(), str_result.geturl())

    url = 'http://example.com/?#'
    burl = url.encode()
    for func in urldefrag, urlsplit, urlparse:
        # Both parsing modes: allow_none=True (None for absent parts)
        # and the default (empty strings for absent parts).
        check(func(url, allow_none=True), func(burl, allow_none=True))
        check(func(url), func(burl))

def test_result_copying(self):
    """Verify that copying/replacing a result preserves its behavior.

    copy.copy, copy.deepcopy, copy.replace (the __replace__ protocol)
    and namedtuple._replace must all yield a result that compares equal
    to the original AND produces the same geturl() output — i.e. the
    private _keep_empty flag must survive every copying path.
    """
    def check(result):
        self.assertEqual(copy.copy(result), result)
        self.assertEqual(copy.copy(result).geturl(), result.geturl())
        self.assertEqual(copy.deepcopy(result), result)
        self.assertEqual(copy.deepcopy(result).geturl(), result.geturl())
        self.assertEqual(copy.replace(result), result)
        self.assertEqual(copy.replace(result).geturl(), result.geturl())
        self.assertEqual(result._replace(), result)
        self.assertEqual(result._replace().geturl(), result.geturl())

    # Empty-but-present query and fragment exercise _keep_empty.
    url = 'http://example.com/?#'
    burl = url.encode()
    for func in urldefrag, urlsplit, urlparse:
        # All four combinations: str/bytes input x default/allow_none.
        check(func(url))
        check(func(url, allow_none=True))
        check(func(burl))
        check(func(burl, allow_none=True))

def test_parse_qs_encoding(self):
result = urllib.parse.parse_qs("key=\u0141%E9", encoding="latin-1")
Expand Down
44 changes: 39 additions & 5 deletions Lib/urllib/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,19 +146,29 @@ class _ResultMixinStr(object):
__slots__ = ()

def encode(self, encoding='ascii', errors='strict'):
    """Return the parallel bytes result, encoding each str component.

    Components that are None (produced with allow_none=True) are passed
    through unchanged rather than encoded.
    """
    result = self._encoded_counterpart(*(x.encode(encoding, errors)
                                         if x is not None else None
                                         for x in self))
    # Propagate the private _keep_empty flag so geturl() on the encoded
    # result treats empty components the same way the original did.
    # Results that never had the flag raise AttributeError; in that
    # case there is nothing to copy.
    try:
        result._keep_empty = self._keep_empty
    except AttributeError:
        pass
    return result


class _ResultMixinBytes(object):
"""Standard approach to decoding parsed results from bytes to str"""
__slots__ = ()

def decode(self, encoding='ascii', errors='strict'):
    """Return the parallel str result, decoding each bytes component.

    Components that are None (produced with allow_none=True) are passed
    through unchanged rather than decoded.
    """
    result = self._decoded_counterpart(*(x.decode(encoding, errors)
                                         if x is not None else None
                                         for x in self))
    # Propagate the private _keep_empty flag so geturl() on the decoded
    # result treats empty components the same way the original did.
    # Results that never had the flag raise AttributeError; in that
    # case there is nothing to copy.
    try:
        result._keep_empty = self._keep_empty
    except AttributeError:
        pass
    return result


class _NetlocResultMixinBase(object):
Expand Down Expand Up @@ -270,20 +280,44 @@ def _hostinfo(self):
_UNSPECIFIED = ['not specified']
_ALLOW_NONE_DEFAULT = False

class _DefragResultBase(namedtuple('_DefragResultBase', 'url fragment')):
class _ResultBase:
    """Mixin for result named tuples that keeps the private
    ``_keep_empty`` flag alive across replace and copy operations.

    The namedtuple replace helpers build a brand-new tuple, which would
    silently drop the flag; this mixin re-attaches it after delegating.
    Results that never had the flag set raise AttributeError on the
    ``self._keep_empty`` read, in which case there is nothing to copy.
    """

    def __replace__(self, /, **kwargs):
        # copy.replace() protocol: delegate to the namedtuple
        # implementation, then re-attach the flag if present.
        result = super().__replace__(**kwargs)
        try:
            result._keep_empty = self._keep_empty
        except AttributeError:
            pass
        return result

    def _replace(self, /, **kwargs):
        # Same propagation for the classic namedtuple _replace() API.
        result = super()._replace(**kwargs)
        try:
            result._keep_empty = self._keep_empty
        except AttributeError:
            pass
        return result

    def __copy__(self):
        # Results are conceptually immutable, so a shallow copy can be
        # the object itself -- the flag travels with it for free.
        return self

    def __deepcopy__(self, memo):
        # Same reasoning as __copy__: immutable, so return self.
        return self


class _DefragResultBase(_ResultBase, namedtuple('_DefragResultBase', 'url fragment')):
    # 2-tuple (url, fragment) returned by urldefrag(); _ResultBase
    # preserves the _keep_empty flag across copy/replace.
    def geturl(self):
        """Reassemble the original URL from (url, fragment).

        The '#' separator is emitted when the fragment is non-empty, or
        when it is an empty (but present) string and _keep_empty says
        empty components should be kept; absent in _keep_empty falls
        back to _ALLOW_NONE_DEFAULT. A None fragment never re-adds '#'.
        """
        if self.fragment or (self.fragment is not None and
                             getattr(self, '_keep_empty', _ALLOW_NONE_DEFAULT)):
            # self._HASH is '#' or b'#' depending on the concrete
            # str/bytes subclass -- assumed from naming; TODO confirm
            # against the subclass definitions.
            return self.url + self._HASH + self.fragment
        else:
            return self.url

class _SplitResultBase(_ResultBase, namedtuple(
        '_SplitResultBase', 'scheme netloc path query fragment')):
    # 5-tuple returned by urlsplit(); _ResultBase preserves the
    # _keep_empty flag across copy/replace operations.
    def geturl(self):
        """Reassemble the URL from the five split components."""
        return urlunsplit(self)

class _ParseResultBase(_ResultBase, namedtuple(
        '_ParseResultBase', 'scheme netloc path params query fragment')):
    # 6-tuple returned by urlparse() (adds params); _ResultBase
    # preserves the _keep_empty flag across copy/replace operations.
    def geturl(self):
        """Reassemble the URL from the six parsed components."""
        return urlunparse(self)
Expand Down

0 comments on commit e5c31dd

Please sign in to comment.