diff --git a/fsspec/implementations/sftp.py b/fsspec/implementations/sftp.py
index 77f7b370c..c8625072a 100644
--- a/fsspec/implementations/sftp.py
+++ b/fsspec/implementations/sftp.py
@@ -13,6 +13,43 @@
 logger = logging.getLogger("fsspec.sftp")
 
 
+def _patch_SFTPFile(file):
+    """
+    Make a ``paramiko.sftp_file.SFTPFile`` behave more like ``io.IOBase``.
+
+    ``seek`` and ``write`` are wrapped on the instance so that they return
+    the new position and the number of bytes accepted, instead of ``None``.
+    See https://github.com/paramiko/paramiko/issues/2452
+
+    Patching is idempotent: an already patched file is returned unchanged.
+    """
+    if getattr(file, "_patched_by_fsspec", False):
+        return file
+
+    real_seek = file.seek
+    real_write = file.write
+
+    def seek(offset: int, whence: int = 0) -> int:
+        result = real_seek(offset, whence)
+        return file.tell() if result is None else result
+
+    def write(data) -> int:
+        result = real_write(data)
+        if result is not None:
+            return result
+        # Do not derive the count from tell(): with bufsize > 1 paramiko
+        # buffers writes, so the position need not move by len(data).
+        if isinstance(data, str):
+            # str payloads are sent UTF-8 encoded, so report encoded bytes
+            return len(data.encode("utf-8"))
+        return memoryview(data).nbytes
+
+    file.seek = seek
+    file.write = write
+    file._patched_by_fsspec = True
+    return file
+
+
 class SFTPFileSystem(AbstractFileSystem):
     """Files over SFTP/SSH
 
@@ -141,6 +178,10 @@ def get_file(self, rpath, lpath, **kwargs):
         else:
             self.ftp.get(self._strip_protocol(rpath), lpath)
 
+    def _open_patched(self, path, mode="rb", **kwargs):
+        """Open ``path`` over SFTP and return the io.IOBase-compatible file."""
+        return _patch_SFTPFile(self.ftp.open(path, mode, **kwargs))
+
     def _open(self, path, mode="rb", block_size=None, **kwargs):
         """
         block_size: int or None
@@ -151,14 +192,16 @@ def _open(self, path, mode="rb", block_size=None, **kwargs):
         if kwargs.get("autocommit", True) is False:
             # writes to temporary file, move on commit
             path2 = "/".join([self.temppath, str(uuid.uuid4())])
-            f = self.ftp.open(path2, mode, bufsize=block_size if block_size else -1)
+            f = self._open_patched(
+                path2, mode, bufsize=block_size if block_size else -1
+            )
             f.temppath = path2
             f.targetpath = path
             f.fs = self
             f.commit = types.MethodType(commit_a_file, f)
             f.discard = types.MethodType(discard_a_file, f)
         else:
-            f = self.ftp.open(path, mode, bufsize=block_size if block_size else -1)
+            f = self._open_patched(path, mode, bufsize=block_size if block_size else -1)
         return f
 
     def _rm(self, path):