From 320425fae9a71f7baee8d300cf75cbc533815b29 Mon Sep 17 00:00:00 2001 From: Andreas Poehlmann Date: Wed, 28 Aug 2024 01:44:54 +0200 Subject: [PATCH] Fix UPath.rename type signature (#258) * typesafety: add checks for rename signature * upath: fix type signature of UPath.rename * tests: fix issue with ssl detection on test setup --- typesafety/test_upath_interface.yml | 10 ++++++++++ upath/core.py | 16 +++++++++------- upath/implementations/smb.py | 11 +++++++---- upath/tests/utils.py | 6 ++++-- 4 files changed, 30 insertions(+), 13 deletions(-) diff --git a/typesafety/test_upath_interface.yml b/typesafety/test_upath_interface.yml index 219b3a49..116411fb 100644 --- a/typesafety/test_upath_interface.yml +++ b/typesafety/test_upath_interface.yml @@ -565,3 +565,13 @@ from upath import UPath reveal_type(UPath("abc").walk()) # N: Revealed type is "typing.Iterator[tuple[upath.core.UPath, builtins.list[builtins.str], builtins.list[builtins.str]]]" + +- case: upath_rename_extra_kwargs + disable_cache: false + main: | + from upath import UPath + + UPath("abc").rename("efg") + UPath("recursive bool").rename("efg", recursive=True) + UPath("maxdepth int").rename("efg", maxdepth=1) + UPath("untyped extras").rename("efg", overwrite=True, something="else") diff --git a/upath/core.py b/upath/core.py index 9a84f21c..3e11dfa2 100644 --- a/upath/core.py +++ b/upath/core.py @@ -15,17 +15,14 @@ from typing import Mapping from typing import Sequence from typing import TextIO -from typing import TypedDict from typing import TypeVar from typing import overload from urllib.parse import urlsplit if sys.version_info >= (3, 11): from typing import Self - from typing import Unpack else: from typing_extensions import Self - from typing_extensions import Unpack from fsspec.registry import get_filesystem_class from fsspec.spec import AbstractFileSystem @@ -94,9 +91,7 @@ def _make_instance(cls, args, kwargs): return cls(*args, **kwargs) -class _UPathRenameParams(TypedDict, total=False): - recursive: bool - maxdepth: int | None +_unset: Any = object() # accessors are deprecated @@ -1016,7 +1011,10 @@ def rmdir(self, recursive: bool = True) -> None: # fixme: non-standard def rename( self, target: str | os.PathLike[str] | UPath, - **kwargs: Unpack[_UPathRenameParams], # note: non-standard compared to pathlib + *, # note: non-standard compared to pathlib + recursive: bool = _unset, + maxdepth: int | None = _unset, + **kwargs: Any, ) -> Self: if isinstance(target, str) and self.storage_options: target = UPath(target, **self.storage_options) @@ -1040,6 +1038,10 @@ def rename( parent = parent.resolve() target_ = parent.joinpath(os.path.normpath(target)) assert isinstance(target_, type(self)), "identical protocols enforced above" + if recursive is not _unset: + kwargs["recursive"] = recursive + if maxdepth is not _unset: + kwargs["maxdepth"] = maxdepth self.fs.mv( self.path, target_.path, diff --git a/upath/implementations/smb.py b/upath/implementations/smb.py index ef43de05..8d2642e6 100644 --- a/upath/implementations/smb.py +++ b/upath/implementations/smb.py @@ -3,18 +3,18 @@ import os import sys import warnings +from typing import Any if sys.version_info >= (3, 11): from typing import Self - from typing import Unpack else: from typing_extensions import Self - from typing_extensions import Unpack import smbprotocol.exceptions from upath import UPath -from upath.core import _UPathRenameParams + +_unset: Any = object() class SMBPath(UPath): @@ -44,7 +44,10 @@ def iterdir(self): def rename( self, target: str | os.PathLike[str] | UPath, - **kwargs: Unpack[_UPathRenameParams], # note: non-standard compared to pathlib + *, + recursive: bool = _unset, + maxdepth: int | None = _unset, + **kwargs: Any, ) -> Self: if kwargs.pop("recursive", None) is not None: warnings.warn( diff --git a/upath/tests/utils.py b/upath/tests/utils.py index 463ed0a8..4158738b 100644 --- a/upath/tests/utils.py +++ b/upath/tests/utils.py @@ -39,9 +39,11 @@ def xfail_if_version(module, *, reason, **conditions): def xfail_if_no_ssl_connection(func): try: import requests - + except ImportError: + return pytest.mark.skip(reason="requests not installed")(func) + try: requests.get("https://example.com") - except (ImportError, requests.exceptions.SSLError): + except (requests.exceptions.ConnectionError, requests.exceptions.SSLError): return pytest.mark.xfail(reason="No SSL connection")(func) else: return func