Skip to content

Commit

Permalink
Fix UPath.rename type signature (#258)
Browse files Browse the repository at this point in the history
* typesafety: add checks for rename signature

* upath: fix type signature of UPath.rename

* tests: fix issue with ssl detection on test setup
  • Loading branch information
ap-- authored Aug 27, 2024
1 parent e53d8a4 commit 320425f
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 13 deletions.
10 changes: 10 additions & 0 deletions typesafety/test_upath_interface.yml
Original file line number Diff line number Diff line change
Expand Up @@ -565,3 +565,13 @@
from upath import UPath
reveal_type(UPath("abc").walk()) # N: Revealed type is "typing.Iterator[tuple[upath.core.UPath, builtins.list[builtins.str], builtins.list[builtins.str]]]"
- case: upath_rename_extra_kwargs
disable_cache: false
main: |
from upath import UPath
UPath("abc").rename("efg")
UPath("recursive bool").rename("efg", recursive=True)
UPath("maxdepth int").rename("efg", maxdepth=1)
UPath("untyped extras").rename("efg", overwrite=True, something="else")
16 changes: 9 additions & 7 deletions upath/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,14 @@
from typing import Mapping
from typing import Sequence
from typing import TextIO
from typing import TypedDict
from typing import TypeVar
from typing import overload
from urllib.parse import urlsplit

if sys.version_info >= (3, 11):
from typing import Self
from typing import Unpack
else:
from typing_extensions import Self
from typing_extensions import Unpack

from fsspec.registry import get_filesystem_class
from fsspec.spec import AbstractFileSystem
Expand Down Expand Up @@ -94,9 +91,7 @@ def _make_instance(cls, args, kwargs):
return cls(*args, **kwargs)


class _UPathRenameParams(TypedDict, total=False):
recursive: bool
maxdepth: int | None
_unset: Any = object()


# accessors are deprecated
Expand Down Expand Up @@ -1016,7 +1011,10 @@ def rmdir(self, recursive: bool = True) -> None: # fixme: non-standard
def rename(
self,
target: str | os.PathLike[str] | UPath,
**kwargs: Unpack[_UPathRenameParams], # note: non-standard compared to pathlib
*, # note: non-standard compared to pathlib
recursive: bool = _unset,
maxdepth: int | None = _unset,
**kwargs: Any,
) -> Self:
if isinstance(target, str) and self.storage_options:
target = UPath(target, **self.storage_options)
Expand All @@ -1040,6 +1038,10 @@ def rename(
parent = parent.resolve()
target_ = parent.joinpath(os.path.normpath(target))
assert isinstance(target_, type(self)), "identical protocols enforced above"
if recursive is not _unset:
kwargs["recursive"] = recursive
if maxdepth is not _unset:
kwargs["maxdepth"] = maxdepth
self.fs.mv(
self.path,
target_.path,
Expand Down
11 changes: 7 additions & 4 deletions upath/implementations/smb.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,18 @@
import os
import sys
import warnings
from typing import Any

if sys.version_info >= (3, 11):
from typing import Self
from typing import Unpack
else:
from typing_extensions import Self
from typing_extensions import Unpack

import smbprotocol.exceptions

from upath import UPath
from upath.core import _UPathRenameParams

_unset: Any = object()


class SMBPath(UPath):
Expand Down Expand Up @@ -44,7 +44,10 @@ def iterdir(self):
def rename(
self,
target: str | os.PathLike[str] | UPath,
**kwargs: Unpack[_UPathRenameParams], # note: non-standard compared to pathlib
*,
recursive: bool = _unset,
maxdepth: int | None = _unset,
**kwargs: Any,
) -> Self:
if kwargs.pop("recursive", None) is not None:
warnings.warn(
Expand Down
6 changes: 4 additions & 2 deletions upath/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,11 @@ def xfail_if_version(module, *, reason, **conditions):
def xfail_if_no_ssl_connection(func):
try:
import requests

except ImportError:
return pytest.mark.skip(reason="requests not installed")(func)
try:
requests.get("https://example.com")
except (ImportError, requests.exceptions.SSLError):
except (requests.exceptions.ConnectionError, requests.exceptions.SSLError):
return pytest.mark.xfail(reason="No SSL connection")(func)
else:
return func
Expand Down

0 comments on commit 320425f

Please sign in to comment.