Skip to content

Commit

Permalink
UPath.joinpath raise error on protocol mismatch (#264)
Browse files Browse the repository at this point in the history
* tests: add test defining protocol mismatch behavior

* upath: fix UPath to raise ValueError on mismatch instead of TypeError

* upath.implementations.cloud: raise early if bucket/container missing

* upath: fix protocol matching on <=3.11
  • Loading branch information
ap-- authored Aug 31, 2024
1 parent e2451e9 commit 3d4ec00
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 17 deletions.
17 changes: 17 additions & 0 deletions upath/_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,16 @@
import os
import re
from pathlib import PurePath
from typing import TYPE_CHECKING
from typing import Any

if TYPE_CHECKING:
from upath.core import UPath

# Names exported by this module (what ``from <module> import *`` provides).
__all__ = [
"get_upath_protocol",
"normalize_empty_netloc",
"compatible_protocol",
]

# Regular expression to match fsspec style protocols.
Expand Down Expand Up @@ -59,3 +64,15 @@ def normalize_empty_netloc(pth: str) -> str:
path = m.group("path")
pth = f"{protocol}:///{path}"
return pth


def compatible_protocol(protocol: str, *args: str | os.PathLike[str] | UPath) -> bool:
    """Check whether every item in ``args`` can be combined with ``protocol``.

    Protocols are compared by their base name, i.e. the part before the
    first "+". An argument whose protocol is empty ("") is treated as
    compatible with any protocol.

    Parameters
    ----------
    protocol:
        The reference protocol that all arguments are compared against.
    *args:
        Path-like objects whose protocol is obtained via
        ``get_upath_protocol``.

    Returns
    -------
    bool
        ``False`` as soon as one argument has a non-empty protocol whose
        base name differs from the base name of ``protocol``; ``True``
        otherwise (including when ``args`` is empty).
    """
    # Normalize the reference protocol the same way as the argument protocols
    # below. Without this, a compound protocol such as "a+b" would be compared
    # against the truncated "a" and never be compatible with itself.
    base_protocol = (protocol or "").partition("+")[0]
    for arg in args:
        # consider protocols equivalent if they match up to the first "+"
        other_protocol = get_upath_protocol(arg).partition("+")[0]
        # only identical (or empty "") protocols can combine
        if other_protocol and other_protocol != base_protocol:
            return False
    return True
24 changes: 7 additions & 17 deletions upath/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from upath._flavour import LazyFlavourDescriptor
from upath._flavour import upath_get_kwargs_from_url
from upath._flavour import upath_urijoin
from upath._protocol import compatible_protocol
from upath._protocol import get_upath_protocol
from upath._stat import UPathStatResult
from upath.registry import get_upath_class
Expand Down Expand Up @@ -251,23 +252,12 @@ def __init__(
self._storage_options = storage_options.copy()

# check that UPath subclasses in args are compatible
# --> ensures items in _raw_paths are compatible
for arg in args:
if not isinstance(arg, UPath):
continue
# protocols: only identical (or empty "") protocols can combine
if arg.protocol and arg.protocol != self._protocol:
raise TypeError("can't combine different UPath protocols as parts")
# storage_options: args may not define other storage_options
if any(
self._storage_options.get(key) != value
for key, value in arg.storage_options.items()
):
# TODO:
# Future versions of UPath could verify that storage_options
# can be combined between UPath instances. Not sure if this
# is really necessary though. A warning might be enough...
pass
# TODO:
# Future versions of UPath could verify that storage_options
# can be combined between UPath instances. Not sure if this
# is really necessary though. A warning might be enough...
if not compatible_protocol(self._protocol, *args):
raise ValueError("can't combine incompatible UPath protocols")

# fill ._raw_paths
if hasattr(self, "_raw_paths"):
Expand Down
7 changes: 7 additions & 0 deletions upath/implementations/cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,13 @@
class CloudPath(UPath):
__slots__ = ()

def __init__(
    self, *args, protocol: str | None = None, **storage_options: Any
) -> None:
    """Initialize via the parent class, then validate the resulting path.

    Raises
    ------
    ValueError
        If the path has more than one part but an empty ``drive``. The
        error message indicates that the drive is expected to hold the
        bucket/container name.
    """
    super().__init__(*args, protocol=protocol, **storage_options)
    # A path with a drive, or with at most one part, is accepted as-is.
    if self.drive or len(self.parts) <= 1:
        return
    raise ValueError("non key-like path provided (bucket/container missing)")

@classmethod
def _transform_init_args(
cls,
Expand Down
15 changes: 15 additions & 0 deletions upath/implementations/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from typing import MutableMapping
from urllib.parse import SplitResult

from upath._protocol import compatible_protocol
from upath.core import UPath

__all__ = [
Expand Down Expand Up @@ -141,6 +142,8 @@ def __new__(
raise NotImplementedError(
f"cannot instantiate {cls.__name__} on your system"
)
if not compatible_protocol("", *args):
raise ValueError("can't combine incompatible UPath protocols")
obj = super().__new__(cls, *args)
obj._protocol = ""
return obj # type: ignore[return-value]
Expand All @@ -152,6 +155,11 @@ def __init__(
self._drv, self._root, self._parts = type(self)._parse_args(args)
_upath_init(self)

def _make_child(self, args):
    """Create a child path from ``args`` after checking their protocols.

    Delegates to the parent implementation when ``compatible_protocol``
    accepts all ``args`` for this instance's protocol; otherwise raises
    ``ValueError``.
    """
    if compatible_protocol(self._protocol, *args):
        return super()._make_child(args)
    raise ValueError("can't combine incompatible UPath protocols")

@classmethod
def _from_parts(cls, *args, **kwargs):
obj = super(Path, cls)._from_parts(*args, **kwargs)
Expand Down Expand Up @@ -205,6 +213,8 @@ def __new__(
raise NotImplementedError(
f"cannot instantiate {cls.__name__} on your system"
)
if not compatible_protocol("", *args):
raise ValueError("can't combine incompatible UPath protocols")
obj = super().__new__(cls, *args)
obj._protocol = ""
return obj # type: ignore[return-value]
Expand All @@ -216,6 +226,11 @@ def __init__(
self._drv, self._root, self._parts = self._parse_args(args)
_upath_init(self)

def _make_child(self, args):
    """Create a child path from ``args`` after checking their protocols.

    Delegates to the parent implementation when ``compatible_protocol``
    accepts all ``args`` for this instance's protocol; otherwise raises
    ``ValueError``.
    """
    if compatible_protocol(self._protocol, *args):
        return super()._make_child(args)
    raise ValueError("can't combine incompatible UPath protocols")

@classmethod
def _from_parts(cls, *args, **kwargs):
obj = super(Path, cls)._from_parts(*args, **kwargs)
Expand Down
29 changes: 29 additions & 0 deletions upath/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,3 +410,32 @@ def test_query_string(uri, query_str):
p = UPath(uri)
assert str(p).endswith(query_str)
assert p.path.endswith(query_str)


@pytest.mark.parametrize(
    "base,join",
    [
        ("/a", "s3://bucket/b"),
        ("s3://bucket/a", "gs://b/c"),
        ("gs://bucket/a", "memory://b/c"),
        ("memory://bucket/a", "s3://b/c"),
    ],
)
def test_joinpath_on_protocol_mismatch(base, join):
    """Joining paths with different protocols must raise ValueError.

    Both the ``joinpath`` method and the ``/`` operator are exercised.
    """

    def via_method():
        return UPath(base).joinpath(UPath(join))

    def via_operator():
        return UPath(base) / UPath(join)

    # The paths are constructed inside the ``raises`` context on purpose, so
    # an error raised during construction is covered by it as well.
    for combine in (via_method, via_operator):
        with pytest.raises(ValueError):
            combine()


@pytest.mark.parametrize(
    "base,join",
    [
        ("/a", "s3://bucket/b"),
        ("s3://bucket/a", "gs://b/c"),
        ("gs://bucket/a", "memory://b/c"),
        ("memory://bucket/a", "s3://b/c"),
    ],
)
def test_joinuri_on_protocol_mismatch(base, join):
    """``joinuri`` with a path of another protocol yields that other path."""
    joined = UPath(base).joinuri(UPath(join))
    expected = UPath(join)
    assert joined == expected

0 comments on commit 3d4ec00

Please sign in to comment.