diff --git a/environment.yml b/environment.yml index fbce5ffe8c..3cbf1b4307 100644 --- a/environment.yml +++ b/environment.yml @@ -8,6 +8,7 @@ dependencies: - python~=3.9 - alembic~=1.2 - archive-path~=0.4.2 +- asyncssh~=2.19.0 - circus~=0.18.0 - click-spinner~=0.1.8 - click~=8.1 diff --git a/pyproject.toml b/pyproject.toml index 4ee43b8c01..37d9413508 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,7 @@ classifiers = [ dependencies = [ 'alembic~=1.2', 'archive-path~=0.4.2', + "asyncssh~=2.19.0", 'circus~=0.18.0', 'click-spinner~=0.1.8', 'click~=8.1', @@ -175,6 +176,7 @@ requires-python = '>=3.9' [project.entry-points.'aiida.transports'] 'core.local' = 'aiida.transports.plugins.local:LocalTransport' 'core.ssh' = 'aiida.transports.plugins.ssh:SshTransport' +'core.ssh_async' = 'aiida.transports.plugins.ssh_async:AsyncSshTransport' 'core.ssh_auto' = 'aiida.transports.plugins.ssh_auto:SshAutoTransport' [project.entry-points.'aiida.workflows'] @@ -308,6 +310,7 @@ module = 'tests.*' ignore_missing_imports = true module = [ 'ase.*', + 'asyncssh.*', 'bpython.*', 'bs4.*', 'CifFile.*', @@ -388,6 +391,7 @@ testpaths = [ 'tests' ] timeout = 240 +timeout_method = "thread" xfail_strict = true [tool.ruff] diff --git a/src/aiida/engine/daemon/execmanager.py b/src/aiida/engine/daemon/execmanager.py index e481dcdadb..5afb594fb5 100644 --- a/src/aiida/engine/daemon/execmanager.py +++ b/src/aiida/engine/daemon/execmanager.py @@ -105,7 +105,7 @@ async def upload_calculation( if dry_run: workdir = Path(folder.abspath) else: - remote_user = transport.whoami() + remote_user = await transport.whoami_async() remote_working_directory = computer.get_workdir().format(username=remote_user) if not remote_working_directory.strip(): raise exceptions.ConfigurationError( @@ -114,13 +114,13 @@ async def upload_calculation( ) # If it already exists, no exception is raised - if not transport.path_exists(remote_working_directory): + if not await 
transport.path_exists_async(remote_working_directory): logger.debug( f'[submission of calculation {node.pk}] Path ' f'{remote_working_directory} does not exist, trying to create it' ) try: - transport.makedirs(remote_working_directory) + await transport.makedirs_async(remote_working_directory) except EnvironmentError as exc: raise exceptions.ConfigurationError( f'[submission of calculation {node.pk}] ' @@ -133,14 +133,14 @@ async def upload_calculation( # and I do not have to know the logic, but I just need to # read the absolute path from the calculation properties. workdir = Path(remote_working_directory).joinpath(calc_info.uuid[:2], calc_info.uuid[2:4]) - transport.makedirs(str(workdir), ignore_existing=True) + await transport.makedirs_async(workdir, ignore_existing=True) try: # The final directory may already exist, most likely because this function was already executed once, but # failed and as a result was rescheduled by the engine. In this case it would be fine to delete the folder # and create it from scratch, except that we cannot be sure that this the actual case. 
Therefore, to err on # the safe side, we move the folder to the lost+found directory before recreating the folder from scratch - transport.mkdir(str(workdir.joinpath(calc_info.uuid[4:]))) + await transport.mkdir_async(workdir.joinpath(calc_info.uuid[4:])) except OSError: # Move the existing directory to lost+found, log a warning and create a clean directory anyway path_existing = os.path.join(str(workdir), calc_info.uuid[4:]) @@ -151,12 +151,12 @@ async def upload_calculation( ) # Make sure the lost+found directory exists, then copy the existing folder there and delete the original - transport.mkdir(path_lost_found, ignore_existing=True) - transport.copytree(path_existing, path_target) - transport.rmtree(path_existing) + await transport.mkdir_async(path_lost_found, ignore_existing=True) + await transport.copytree_async(path_existing, path_target) + await transport.rmtree_async(path_existing) # Now we can create a clean folder for this calculation - transport.mkdir(str(workdir.joinpath(calc_info.uuid[4:]))) + await transport.mkdir_async(workdir.joinpath(calc_info.uuid[4:])) finally: workdir = workdir.joinpath(calc_info.uuid[4:]) @@ -171,11 +171,11 @@ async def upload_calculation( # Note: this will possibly overwrite files for root, dirnames, filenames in code.base.repository.walk(): # mkdir of root - transport.makedirs(str(workdir.joinpath(root)), ignore_existing=True) + await transport.makedirs_async(workdir.joinpath(root), ignore_existing=True) # remotely mkdir first for dirname in dirnames: - transport.makedirs(str(workdir.joinpath(root, dirname)), ignore_existing=True) + await transport.makedirs_async(workdir.joinpath(root, dirname), ignore_existing=True) # Note, once #2579 is implemented, use the `node.open` method instead of the named temporary file in # combination with the new `Transport.put_object_from_filelike` @@ -185,11 +185,11 @@ async def upload_calculation( content = code.base.repository.get_object_content(Path(root) / filename, mode='rb') 
handle.write(content) handle.flush() - transport.put(handle.name, str(workdir.joinpath(root, filename))) + await transport.put_async(handle.name, workdir.joinpath(root, filename)) if code.filepath_executable.is_absolute(): - transport.chmod(str(code.filepath_executable), 0o755) # rwxr-xr-x + await transport.chmod_async(code.filepath_executable, 0o755) # rwxr-xr-x else: - transport.chmod(str(workdir.joinpath(code.filepath_executable)), 0o755) # rwxr-xr-x + await transport.chmod_async(workdir.joinpath(code.filepath_executable), 0o755) # rwxr-xr-x # local_copy_list is a list of tuples, each with (uuid, dest_path, rel_path) # NOTE: validation of these lists are done inside calculation.presubmit() @@ -288,7 +288,7 @@ async def _copy_remote_files(logger, node, computer, transport, remote_copy_list f'remotely, directly on the machine {computer.label}' ) try: - transport.copy(remote_abs_path, str(workdir.joinpath(dest_rel_path))) + await transport.copy_async(remote_abs_path, workdir.joinpath(dest_rel_path)) except FileNotFoundError: logger.warning( f'[submission of calculation {node.pk}] Unable to copy remote ' @@ -314,8 +314,8 @@ async def _copy_remote_files(logger, node, computer, transport, remote_copy_list ) remote_dirname = Path(dest_rel_path).parent try: - transport.makedirs(str(workdir.joinpath(remote_dirname)), ignore_existing=True) - transport.symlink(remote_abs_path, str(workdir.joinpath(dest_rel_path))) + await transport.makedirs_async(workdir.joinpath(remote_dirname), ignore_existing=True) + await transport.symlink_async(remote_abs_path, workdir.joinpath(dest_rel_path)) except OSError: logger.warning( f'[submission of calculation {node.pk}] Unable to create remote symlink ' @@ -356,14 +356,14 @@ async def _copy_local_files(logger, node, transport, inputs, local_copy_list, wo # The logic below takes care of an edge case where the source is a file but the target is a directory. 
In # this case, the v2.5.1 implementation would raise an `IsADirectoryError` exception, because it would try # to open the directory in the sandbox folder as a file when writing the contents. - if file_type_source == FileType.FILE and target and transport.isdir(str(workdir.joinpath(target))): + if file_type_source == FileType.FILE and target and await transport.isdir_async(workdir.joinpath(target)): raise IsADirectoryError # In case the source filename is specified and it is a directory that already exists in the remote, we # want to avoid nested directories in the target path to replicate the behavior of v2.5.1. This is done by # setting the target filename to '.', which means the contents of the node will be copied in the top level # of the temporary directory, whose contents are then copied into the target directory. - if filename and transport.isdir(str(workdir.joinpath(filename))): + if filename and await transport.isdir_async(workdir.joinpath(filename)): filename_target = '.' filepath_target = (dirpath / filename_target).resolve().absolute() @@ -372,9 +372,9 @@ async def _copy_local_files(logger, node, transport, inputs, local_copy_list, wo if file_type_source == FileType.DIRECTORY: # If the source object is a directory, we copy its entire contents data_node.base.repository.copy_tree(filepath_target, filename_source) - transport.put( + await transport.put_async( f'{dirpath}/*', - str(workdir.joinpath(target)) if target else str(workdir.joinpath('.')), + workdir.joinpath(target) if target else workdir.joinpath('.'), overwrite=True, ) else: @@ -382,15 +382,15 @@ async def _copy_local_files(logger, node, transport, inputs, local_copy_list, wo with filepath_target.open('wb') as handle: with data_node.base.repository.open(filename_source, 'rb') as source: shutil.copyfileobj(source, handle) - transport.makedirs(str(workdir.joinpath(Path(target).parent)), ignore_existing=True) - transport.put(str(filepath_target), str(workdir.joinpath(target))) + await 
transport.makedirs_async(workdir.joinpath(Path(target).parent), ignore_existing=True) + await transport.put_async(filepath_target, workdir.joinpath(target)) async def _copy_sandbox_files(logger, node, transport, folder, workdir: Path): """Copy the contents of the sandbox folder to the working directory.""" for filename in folder.get_content_list(): logger.debug(f'[submission of calculation {node.pk}] copying file/folder {filename}...') - transport.put(folder.get_abs_path(filename), str(workdir.joinpath(filename))) + await transport.put_async(folder.get_abs_path(filename), workdir.joinpath(filename)) def submit_calculation(calculation: CalcJobNode, transport: Transport) -> str | ExitCode: @@ -461,7 +461,7 @@ async def stash_calculation(calculation: CalcJobNode, transport: Transport) -> N for source_filename in source_list: if transport.has_magic(source_filename): copy_instructions = [] - for globbed_filename in transport.glob(str(source_basepath / source_filename)): + for globbed_filename in await transport.glob_async(source_basepath / source_filename): target_filepath = target_basepath / Path(globbed_filename).relative_to(source_basepath) copy_instructions.append((globbed_filename, target_filepath)) else: @@ -470,10 +470,10 @@ async def stash_calculation(calculation: CalcJobNode, transport: Transport) -> N for source_filepath, target_filepath in copy_instructions: # If the source file is in a (nested) directory, create those directories first in the target directory target_dirname = target_filepath.parent - transport.makedirs(str(target_dirname), ignore_existing=True) + await transport.makedirs_async(target_dirname, ignore_existing=True) try: - transport.copy(str(source_filepath), str(target_filepath)) + await transport.copy_async(source_filepath, target_filepath) except (OSError, ValueError) as exception: EXEC_LOGGER.warning(f'failed to stash {source_filepath} to {target_filepath}: {exception}') else: @@ -612,7 +612,7 @@ async def retrieve_files_from_list( upto 
what level of the original remotepath nesting the files will be copied. :param transport: the Transport instance. - :param folder: an absolute path to a folder that contains the files to copy. + :param folder: an absolute path to a folder that contains the files to retrieve. :param retrieve_list: the list of files to retrieve. """ workdir = Path(calculation.get_remote_workdir()) @@ -621,7 +621,7 @@ async def retrieve_files_from_list( tmp_rname, tmp_lname, depth = item # if there are more than one file I do something differently if transport.has_magic(tmp_rname): - remote_names = transport.glob(str(workdir.joinpath(tmp_rname))) + remote_names = await transport.glob_async(workdir.joinpath(tmp_rname)) local_names = [] for rem in remote_names: # get the relative path so to make local_names relative @@ -644,7 +644,7 @@ async def retrieve_files_from_list( abs_item = item if item.startswith('/') else str(workdir.joinpath(item)) if transport.has_magic(abs_item): - remote_names = transport.glob(abs_item) + remote_names = await transport.glob_async(abs_item) local_names = [os.path.split(rem)[1] for rem in remote_names] else: remote_names = [abs_item] @@ -656,6 +656,6 @@ async def retrieve_files_from_list( if rem.startswith('/'): to_get = rem else: - to_get = str(workdir.joinpath(rem)) + to_get = workdir.joinpath(rem) - transport.get(to_get, os.path.join(folder, loc), ignore_nonexisting=True) + await transport.get_async(to_get, os.path.join(folder, loc), ignore_nonexisting=True) diff --git a/src/aiida/orm/computers.py b/src/aiida/orm/computers.py index bae925b25c..1c695910af 100644 --- a/src/aiida/orm/computers.py +++ b/src/aiida/orm/computers.py @@ -626,12 +626,12 @@ def get_transport(self, user: Optional['User'] = None) -> 'Transport': """Return a Transport class, configured with all correct parameters. The Transport is closed (meaning that if you want to run any operation with it, you have to open it first (i.e., e.g. for a SSH transport, you have - to open a connection). 
To do this you can call ``transports.open()``, or simply + to open a connection). To do this you can call ``transport.open()``, or simply run within a ``with`` statement:: transport = Computer.get_transport() with transport: - print(transports.whoami()) + print(transport.whoami()) :param user: if None, try to obtain a transport for the default user. Otherwise, pass a valid User. diff --git a/src/aiida/orm/nodes/data/remote/base.py b/src/aiida/orm/nodes/data/remote/base.py index 97f5f067b8..dcf16e4b4a 100644 --- a/src/aiida/orm/nodes/data/remote/base.py +++ b/src/aiida/orm/nodes/data/remote/base.py @@ -125,7 +125,8 @@ def listdir_withattributes(self, path='.'): """Connects to the remote folder and lists the directory content. :param relpath: If 'relpath' is specified, lists the content of the given subfolder. - :return: a list of dictionaries, where the documentation is in :py:class:Transport.listdir_withattributes. + :return: a list of dictionaries, where the documentation + is in :py:class:Transport.listdir_withattributes. """ authinfo = self.get_authinfo() diff --git a/src/aiida/orm/nodes/process/calculation/calcjob.py b/src/aiida/orm/nodes/process/calculation/calcjob.py index a7cd20c88e..8ceeb40212 100644 --- a/src/aiida/orm/nodes/process/calculation/calcjob.py +++ b/src/aiida/orm/nodes/process/calculation/calcjob.py @@ -453,7 +453,8 @@ def get_authinfo(self) -> 'AuthInfo': def get_transport(self) -> 'Transport': """Return the transport for this calculation. 
- :return: `Transport` configured with the `AuthInfo` associated to the computer of this node + :return: Transport configured + with the `AuthInfo` associated to the computer of this node """ return self.get_authinfo().get_transport() diff --git a/src/aiida/plugins/factories.py b/src/aiida/plugins/factories.py index d007ef0dd3..925b400672 100644 --- a/src/aiida/plugins/factories.py +++ b/src/aiida/plugins/factories.py @@ -418,7 +418,7 @@ def TransportFactory(entry_point_name: str, load: Literal[False]) -> EntryPoint: def TransportFactory(entry_point_name: str, load: bool = True) -> Union[EntryPoint, Type['Transport']]: - """Return the `Transport` sub class registered under the given entry point. + """Return the Transport sub class registered under the given entry point. :param entry_point_name: the entry point name. :param load: if True, load the matched entry point and return the loaded resource instead of the entry point itself. @@ -435,7 +435,7 @@ def TransportFactory(entry_point_name: str, load: bool = True) -> Union[EntryPoi if not load: return entry_point - if isclass(entry_point) and issubclass(entry_point, Transport): + if isclass(entry_point) and (issubclass(entry_point, Transport)): return entry_point raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes) diff --git a/src/aiida/schedulers/plugins/direct.py b/src/aiida/schedulers/plugins/direct.py index 694ff93863..0bed55bda4 100644 --- a/src/aiida/schedulers/plugins/direct.py +++ b/src/aiida/schedulers/plugins/direct.py @@ -192,7 +192,7 @@ def _get_submit_command(self, submit_script): directory. IMPORTANT: submit_script should be already escaped. """ - submit_command = f'bash {submit_script} > /dev/null 2>&1 & echo $!' + submit_command = f'(bash {submit_script} > /dev/null 2>&1 & echo $!) 
&' self.logger.info(f'submitting with: {submit_command}') diff --git a/src/aiida/tools/pytest_fixtures/__init__.py b/src/aiida/tools/pytest_fixtures/__init__.py index c2729d16c5..e19d4c455e 100644 --- a/src/aiida/tools/pytest_fixtures/__init__.py +++ b/src/aiida/tools/pytest_fixtures/__init__.py @@ -22,6 +22,7 @@ aiida_computer, aiida_computer_local, aiida_computer_ssh, + aiida_computer_ssh_async, aiida_localhost, ssh_key, ) @@ -33,6 +34,7 @@ 'aiida_computer', 'aiida_computer_local', 'aiida_computer_ssh', + 'aiida_computer_ssh_async', 'aiida_config', 'aiida_config_factory', 'aiida_config_tmp', diff --git a/src/aiida/tools/pytest_fixtures/orm.py b/src/aiida/tools/pytest_fixtures/orm.py index 52fd0eb322..618125d203 100644 --- a/src/aiida/tools/pytest_fixtures/orm.py +++ b/src/aiida/tools/pytest_fixtures/orm.py @@ -191,6 +191,38 @@ def factory(label: str | None = None, configure: bool = True) -> 'Computer': return factory +@pytest.fixture +def aiida_computer_ssh_async(aiida_computer) -> t.Callable[[], 'Computer']: + """Factory to return a :class:`aiida.orm.computers.Computer` instance with ``core.ssh_async`` transport. + + Usage:: + + def test(aiida_computer_ssh_async): + computer = aiida_computer_ssh_async(label='some-label', configure=True) + assert computer.transport_type == 'core.ssh_async' + assert computer.is_configured + + The factory has the following signature: + + :param label: The computer label. If not specified, a random UUID4 is used. + :param configure: Boolean, if ``True``, ensures the computer is configured, otherwise the computer is returned + as is. Note that if a computer with the given label already exists and it was configured before, the + computer will not be "un-"configured. If an unconfigured computer is absolutely required, make sure to first + delete the existing computer or specify another label. + :return: A stored computer instance. 
+ """ + + def factory(label: str | None = None, configure: bool = True) -> 'Computer': + computer = aiida_computer(label=label, hostname='localhost', transport_type='core.ssh_async') + + if configure: + computer.configure() + + return computer + + return factory + + @pytest.fixture def aiida_localhost(aiida_computer_local) -> 'Computer': """Return a :class:`aiida.orm.computers.Computer` instance representing localhost with ``core.local`` transport. diff --git a/src/aiida/transports/__init__.py b/src/aiida/transports/__init__.py index eecd07c04f..3358a673b5 100644 --- a/src/aiida/transports/__init__.py +++ b/src/aiida/transports/__init__.py @@ -16,8 +16,11 @@ from .transport import * __all__ = ( + 'AsyncTransport', + 'BlockingTransport', 'SshTransport', 'Transport', + 'TransportPath', 'convert_to_bool', 'parse_sshconfig', ) diff --git a/src/aiida/transports/cli.py b/src/aiida/transports/cli.py index 6088eb08f6..5faa2d6f80 100644 --- a/src/aiida/transports/cli.py +++ b/src/aiida/transports/cli.py @@ -140,7 +140,7 @@ def transport_options(transport_type): """Decorate a command with all options for a computer configure subcommand for transport_type.""" def apply_options(func): - """Decorate the command functionn with the appropriate options for the transport type.""" + """Decorate the command function with the appropriate options for the transport type.""" options_list = list_transport_options(transport_type) options_list.reverse() func = arguments.COMPUTER(callback=partial(match_comp_transport, transport_type=transport_type))(func) diff --git a/src/aiida/transports/plugins/local.py b/src/aiida/transports/plugins/local.py index 8de49838e3..c4e8ccadf1 100644 --- a/src/aiida/transports/plugins/local.py +++ b/src/aiida/transports/plugins/local.py @@ -15,13 +15,15 @@ import os import shutil import subprocess +from typing import Optional +from aiida.common.warnings import warn_deprecation from aiida.transports import cli as transport_cli -from aiida.transports.transport 
import Transport, TransportInternalError +from aiida.transports.transport import BlockingTransport, TransportInternalError, TransportPath # refactor or raise the limit: issue #1784 -class LocalTransport(Transport): +class LocalTransport(BlockingTransport): """Support copy and command execution on the same host on which AiiDA is running via direct file copy and execution commands. @@ -92,7 +94,7 @@ def curdir(self): raise TransportInternalError('Error, local method called for LocalTransport without opening the channel first') - def chdir(self, path): + def chdir(self, path: TransportPath): """ PLEASE DON'T USE `chdir()` IN NEW DEVELOPMENTS, INSTEAD DIRECTLY PASS ABSOLUTE PATHS TO INTERFACE. `chdir()` is DEPRECATED and will be removed in the next major version. @@ -101,6 +103,11 @@ def chdir(self, path): :param path: path to cd into :raise OSError: if the directory does not have read attributes. """ + warn_deprecation( + '`chdir()` is deprecated and will be removed in the next major version.', + version=3, + ) + path = str(path) new_path = os.path.join(self.curdir, path) if not os.path.isdir(new_path): raise OSError(f"'{new_path}' is not a valid directory") @@ -109,13 +116,15 @@ def chdir(self, path): self._internal_dir = os.path.normpath(new_path) - def chown(self, path, uid, gid): + def chown(self, path: TransportPath, uid, gid): + path = str(path) os.chown(path, uid, gid) - def normalize(self, path='.'): + def normalize(self, path: TransportPath = '.'): """Normalizes path, eliminating double slashes, etc.. 
:param path: path to normalize """ + path = str(path) return os.path.realpath(os.path.join(self.curdir, path)) def getcwd(self): @@ -127,8 +136,9 @@ def getcwd(self): return self.curdir @staticmethod - def _os_path_split_asunder(path): + def _os_path_split_asunder(path: TransportPath): """Used by makedirs, Takes path (a str) and returns a list deconcatenating the path.""" + path = str(path) parts = [] while True: newpath, tail = os.path.split(path) @@ -142,7 +152,7 @@ def _os_path_split_asunder(path): parts.reverse() return parts - def makedirs(self, path, ignore_existing=False): + def makedirs(self, path: TransportPath, ignore_existing=False): """Super-mkdir; create a leaf directory and all intermediate ones. Works like mkdir, except that any intermediate path segment (not just the rightmost) will be created if it does not exist. @@ -153,6 +163,7 @@ def makedirs(self, path, ignore_existing=False): :raise OSError: If the directory already exists and is not ignore_existing """ + path = str(path) # check to avoid creation of empty dirs path = os.path.normpath(path) @@ -168,7 +179,7 @@ def makedirs(self, path, ignore_existing=False): if not os.path.exists(this_dir): os.mkdir(this_dir) - def mkdir(self, path, ignore_existing=False): + def mkdir(self, path: TransportPath, ignore_existing=False): """Create a folder (directory) named path. :param path: name of the folder to create @@ -177,33 +188,37 @@ def mkdir(self, path, ignore_existing=False): :raise OSError: If the directory already exists. """ + path = str(path) if ignore_existing and self.isdir(path): return os.mkdir(os.path.join(self.curdir, path)) - def rmdir(self, path): + def rmdir(self, path: TransportPath): """Removes a folder at location path. :param path: path to remove """ + path = str(path) os.rmdir(os.path.join(self.curdir, path)) - def isdir(self, path): + def isdir(self, path: TransportPath): """Checks if 'path' is a directory. 
:return: a boolean """ + path = str(path) if not path: return False return os.path.isdir(os.path.join(self.curdir, path)) - def chmod(self, path, mode): + def chmod(self, path: TransportPath, mode): """Changes permission bits of object at path :param path: path to modify :param mode: permission bits :raise OSError: if path does not exist. """ + path = str(path) if not path: raise OSError('Directory not given in input') real_path = os.path.join(self.curdir, path) @@ -214,7 +229,7 @@ def chmod(self, path, mode): # please refactor: issue #1782 - def put(self, localpath, remotepath, *args, **kwargs): + def put(self, localpath: TransportPath, remotepath: TransportPath, *args, **kwargs): """Copies a file or a folder from localpath to remotepath. Automatically redirects to putfile or puttree. @@ -228,6 +243,8 @@ def put(self, localpath, remotepath, *args, **kwargs): :raise OSError: if remotepath is not valid :raise ValueError: if localpath is not valid """ + localpath = str(localpath) + remotepath = str(remotepath) from aiida.common.warnings import warn_deprecation if 'ignore_noexisting' in kwargs: @@ -294,7 +311,7 @@ def put(self, localpath, remotepath, *args, **kwargs): else: raise OSError(f'The local path {localpath} does not exist') - def putfile(self, localpath, remotepath, *args, **kwargs): + def putfile(self, localpath: TransportPath, remotepath: TransportPath, *args, **kwargs): """Copies a file from localpath to remotepath. Automatically redirects to putfile or puttree. 
@@ -307,6 +324,9 @@ def putfile(self, localpath, remotepath, *args, **kwargs): :raise ValueError: if localpath is not valid :raise OSError: if localpath does not exist """ + localpath = str(localpath) + remotepath = str(remotepath) + overwrite = kwargs.get('overwrite', args[0] if args else True) if not remotepath: raise OSError('Input remotepath to putfile must be a non empty string') @@ -325,7 +345,7 @@ def putfile(self, localpath, remotepath, *args, **kwargs): shutil.copyfile(localpath, the_destination) - def puttree(self, localpath, remotepath, *args, **kwargs): + def puttree(self, localpath: TransportPath, remotepath: TransportPath, *args, **kwargs): """Copies a folder recursively from localpath to remotepath. Automatically redirects to putfile or puttree. @@ -340,6 +360,8 @@ def puttree(self, localpath, remotepath, *args, **kwargs): :raise ValueError: if localpath is not valid :raise OSError: if localpath does not exist """ + localpath = str(localpath) + remotepath = str(remotepath) dereference = kwargs.get('dereference', args[0] if args else True) overwrite = kwargs.get('overwrite', args[1] if len(args) > 1 else True) if not remotepath: @@ -365,11 +387,12 @@ def puttree(self, localpath, remotepath, *args, **kwargs): shutil.copytree(localpath, the_destination, symlinks=not dereference, dirs_exist_ok=overwrite) - def rmtree(self, path): + def rmtree(self, path: TransportPath): """Remove tree as rm -r would do :param path: a string to path """ + path = str(path) the_path = os.path.join(self.curdir, path) try: shutil.rmtree(the_path) @@ -383,7 +406,7 @@ def rmtree(self, path): # please refactor: issue #1781 - def get(self, remotepath, localpath, *args, **kwargs): + def get(self, remotepath: TransportPath, localpath: TransportPath, *args, **kwargs): """Copies a folder or a file recursively from 'remote' remotepath to 'local' localpath. Automatically redirects to getfile or gettree. 
@@ -398,6 +421,8 @@ def get(self, remotepath, localpath, *args, **kwargs): :raise OSError: if 'remote' remotepath is not valid :raise ValueError: if 'local' localpath is not valid """ + remotepath = str(remotepath) + localpath = str(localpath) dereference = kwargs.get('dereference', args[0] if args else True) overwrite = kwargs.get('overwrite', args[1] if len(args) > 1 else True) ignore_nonexisting = kwargs.get('ignore_nonexisting', args[2] if len(args) > 2 else False) @@ -449,7 +474,7 @@ def get(self, remotepath, localpath, *args, **kwargs): else: raise OSError(f'The remote path {remotepath} does not exist') - def getfile(self, remotepath, localpath, *args, **kwargs): + def getfile(self, remotepath: TransportPath, localpath: TransportPath, *args, **kwargs): """Copies a file recursively from 'remote' remotepath to 'local' localpath. @@ -462,6 +487,9 @@ def getfile(self, remotepath, localpath, *args, **kwargs): :raise ValueError: if 'local' localpath is not valid :raise OSError: if unintentionally overwriting """ + remotepath = str(remotepath) + localpath = str(localpath) + if not os.path.isabs(localpath): raise ValueError('localpath must be an absolute path') @@ -480,7 +508,7 @@ def getfile(self, remotepath, localpath, *args, **kwargs): shutil.copyfile(the_source, localpath) - def gettree(self, remotepath, localpath, *args, **kwargs): + def gettree(self, remotepath: TransportPath, localpath: TransportPath, *args, **kwargs): """Copies a folder recursively from 'remote' remotepath to 'local' localpath. 
@@ -493,6 +521,8 @@ def gettree(self, remotepath, localpath, *args, **kwargs): :raise ValueError: if 'local' localpath is not valid :raise OSError: if unintentionally overwriting """ + remotepath = str(remotepath) + localpath = str(localpath) dereference = kwargs.get('dereference', args[0] if args else True) overwrite = kwargs.get('overwrite', args[1] if len(args) > 1 else True) if not remotepath: @@ -519,7 +549,7 @@ def gettree(self, remotepath, localpath, *args, **kwargs): # please refactor: issue #1780 on github - def copy(self, remotesource, remotedestination, dereference=False, recursive=True): + def copy(self, remotesource: TransportPath, remotedestination: TransportPath, dereference=False, recursive=True): """Copies a file or a folder from 'remote' remotesource to 'remote' remotedestination. Automatically redirects to copyfile or copytree. @@ -532,6 +562,8 @@ def copy(self, remotesource, remotedestination, dereference=False, recursive=Tru :raise ValueError: if 'remote' remotesource or remotedestinationis not valid :raise OSError: if remotesource does not exist """ + remotesource = str(remotesource) + remotedestination = str(remotedestination) if not remotesource: raise ValueError('Input remotesource to copy must be a non empty object') if not remotedestination: @@ -579,7 +611,7 @@ def copy(self, remotesource, remotedestination, dereference=False, recursive=Tru # With self.copytree, the (possible) relative path is OK self.copytree(remotesource, remotedestination, dereference) - def copyfile(self, remotesource, remotedestination, dereference=False): + def copyfile(self, remotesource: TransportPath, remotedestination: TransportPath, dereference=False): """Copies a file from 'remote' remotesource to 'remote' remotedestination. 
@@ -590,6 +622,8 @@ def copyfile(self, remotesource, remotedestination, dereference=False): :raise ValueError: if 'remote' remotesource or remotedestination is not valid :raise OSError: if remotesource does not exist """ + remotesource = str(remotesource) + remotedestination = str(remotedestination) if not remotesource: raise ValueError('Input remotesource to copyfile must be a non empty object') if not remotedestination: @@ -605,7 +639,7 @@ def copyfile(self, remotesource, remotedestination, dereference=False): else: shutil.copyfile(the_source, the_destination) - def copytree(self, remotesource, remotedestination, dereference=False): + def copytree(self, remotesource: TransportPath, remotedestination: TransportPath, dereference=False): """Copies a folder from 'remote' remotesource to 'remote' remotedestination. @@ -616,6 +650,8 @@ def copytree(self, remotesource, remotedestination, dereference=False): :raise ValueError: if 'remote' remotesource or remotedestination is not valid :raise OSError: if remotesource does not exist """ + remotesource = str(remotesource) + remotedestination = str(remotedestination) if not remotesource: raise ValueError('Input remotesource to copytree must be a non empty object') if not remotedestination: @@ -631,11 +667,12 @@ def copytree(self, remotesource, remotedestination, dereference=False): shutil.copytree(the_source, the_destination, symlinks=not dereference) - def get_attribute(self, path): + def get_attribute(self, path: TransportPath): """Returns an object FileAttribute, as specified in aiida.transports. :param path: the path of the given file. 
""" + path = str(path) from aiida.transports.util import FileAttribute os_attr = os.lstat(os.path.join(self.curdir, path)) @@ -646,10 +683,12 @@ def get_attribute(self, path): aiida_attr[key] = getattr(os_attr, key) return aiida_attr - def _local_listdir(self, path, pattern=None): + def _local_listdir(self, path: TransportPath, pattern=None): """Act on the local folder, for the rest, same as listdir.""" import re + path = str(path) + if not pattern: return os.listdir(path) @@ -663,12 +702,13 @@ def _local_listdir(self, path, pattern=None): base_dir += os.sep return [re.sub(base_dir, '', i) for i in filtered_list] - def listdir(self, path='.', pattern=None): + def listdir(self, path: TransportPath = '.', pattern=None): """:return: a list containing the names of the entries in the directory. :param path: default ='.' :param pattern: if set, returns the list of files matching pattern. Unix only. (Use to emulate ls * for example) """ + path = str(path) the_path = os.path.join(self.curdir, path).strip() if not pattern: try: @@ -685,20 +725,22 @@ def listdir(self, path='.', pattern=None): the_path += '/' return [re.sub(the_path, '', i) for i in filtered_list] - def remove(self, path): + def remove(self, path: TransportPath): """Removes a file at position path.""" + path = str(path) os.remove(os.path.join(self.curdir, path)) - def isfile(self, path): + def isfile(self, path: TransportPath): """Checks if object at path is a file. Returns a boolean. """ + path = str(path) if not path: return False return os.path.isfile(os.path.join(self.curdir, path)) @contextlib.contextmanager - def _exec_command_internal(self, command, workdir=None, **kwargs): + def _exec_command_internal(self, command, workdir: Optional[TransportPath] = None, **kwargs): """Executes the specified command in bash login shell. 
@@ -723,12 +765,13 @@ def _exec_command_internal(self, command, workdir=None, **kwargs): """ from aiida.common.escaping import escape_for_bash + if workdir: + workdir = str(workdir) # Note: The outer shell will eat one level of escaping, while # 'bash -l -c ...' will eat another. Thus, we need to escape again. bash_commmand = f'{self._bash_command_str}-c ' command = bash_commmand + escape_for_bash(command) - if workdir: cwd = workdir else: @@ -745,7 +788,7 @@ def _exec_command_internal(self, command, workdir=None, **kwargs): ) as process: yield process - def exec_command_wait_bytes(self, command, stdin=None, workdir=None, **kwargs): + def exec_command_wait_bytes(self, command, stdin=None, workdir: Optional[TransportPath] = None, **kwargs): """Executes the specified command and waits for it to finish. :param command: the command to execute @@ -757,6 +800,8 @@ def exec_command_wait_bytes(self, command, stdin=None, workdir=None, **kwargs): :return: a tuple with (return_value, stdout, stderr) where stdout and stderr are both bytes and the return_value is an int. """ + if workdir: + workdir = str(workdir) with self._exec_command_internal(command, workdir) as process: if stdin is not None: # Implicitly assume that the desired encoding is 'utf-8' if I receive a string. @@ -799,7 +844,7 @@ def line_encoder(iterator, encoding='utf-8'): return retval, output_text, stderr_text - def gotocomputer_command(self, remotedir): + def gotocomputer_command(self, remotedir: TransportPath): """Return a string to be run using os.system in order to connect via the transport to the remote directory. 
@@ -810,11 +855,12 @@ def gotocomputer_command(self, remotedir): :param str remotedir: the full path of the remote directory """ + remotedir = str(remotedir) connect_string = self._gotocomputer_string(remotedir) cmd = f'bash -c {connect_string}' return cmd - def rename(self, oldpath, newpath): + def rename(self, oldpath: TransportPath, newpath: TransportPath): """Rename a file or folder from oldpath to newpath. :param str oldpath: existing name of the file or folder @@ -823,6 +869,8 @@ def rename(self, oldpath, newpath): :raises OSError: if src/dst is not found :raises ValueError: if src/dst is not a valid string """ + oldpath = str(oldpath) + newpath = str(newpath) if not oldpath: raise ValueError(f'Source {oldpath} is not a valid string') if not newpath: @@ -834,15 +882,15 @@ def rename(self, oldpath, newpath): shutil.move(oldpath, newpath) - def symlink(self, remotesource, remotedestination): + def symlink(self, remotesource: TransportPath, remotedestination: TransportPath): """Create a symbolic link between the remote source and the remote remotedestination :param remotesource: remote source. Can contain a pattern. 
:param remotedestination: remote destination """ - remotesource = os.path.normpath(remotesource) - remotedestination = os.path.normpath(remotedestination) + remotesource = os.path.normpath(str(remotesource)) + remotedestination = os.path.normpath(str(remotedestination)) if self.has_magic(remotesource): if self.has_magic(remotedestination): @@ -861,8 +909,9 @@ def symlink(self, remotesource, remotedestination): except OSError: raise OSError(f'!!: {remotesource}, {self.curdir}, {remotedestination}') - def path_exists(self, path): + def path_exists(self, path: TransportPath): """Check if path exists""" + path = str(path) return os.path.exists(os.path.join(self.curdir, path)) diff --git a/src/aiida/transports/plugins/ssh.py b/src/aiida/transports/plugins/ssh.py index 6858da5d2a..240d1b8153 100644 --- a/src/aiida/transports/plugins/ssh.py +++ b/src/aiida/transports/plugins/ssh.py @@ -19,8 +19,9 @@ from aiida.cmdline.params import options from aiida.cmdline.params.types.path import AbsolutePathOrEmptyParamType from aiida.common.escaping import escape_for_bash +from aiida.common.warnings import warn_deprecation -from ..transport import Transport, TransportInternalError +from ..transport import BlockingTransport, TransportInternalError, TransportPath __all__ = ('SshTransport', 'convert_to_bool', 'parse_sshconfig') @@ -61,7 +62,7 @@ def convert_to_bool(string): raise ValueError('Invalid boolean value provided') -class SshTransport(Transport): +class SshTransport(BlockingTransport): """Support connection, command execution and data transfer to remote computers via SSH+SFTP.""" # Valid keywords accepted by the connect method of paramiko.SSHClient @@ -230,6 +231,10 @@ class SshTransport(Transport): # if too large commands are sent, clogging the outputs or logs _MAX_EXEC_COMMAND_LOG_SIZE = None + # NOTE: all the methods that start with _get_ are class methods that + # return a suggestion for the specific field. 
They are being used in + # a function called transport_option_default in transports/cli.py, + # during an interactive `verdi computer configure` command. @classmethod def _get_username_suggestion_string(cls, computer): """Return a suggestion for the specific field.""" @@ -580,7 +585,7 @@ def __str__(self): return f"{'OPEN' if self._is_open else 'CLOSED'} [{conn_info}]" - def chdir(self, path): + def chdir(self, path: TransportPath): """ PLEASE DON'T USE `chdir()` IN NEW DEVELOPMENTS, INSTEAD DIRECTLY PASS ABSOLUTE PATHS TO INTERFACE. `chdir()` is DEPRECATED and will be removed in the next major version. @@ -590,8 +595,13 @@ def chdir(self, path): Differently from paramiko, if you pass None to chdir, nothing happens and the cwd is unchanged. """ + warn_deprecation( + '`chdir()` is deprecated and will be removed in the next major version.', + version=3, + ) from paramiko.sftp import SFTPError + path = str(path) old_path = self.sftp.getcwd() if path is not None: try: @@ -618,11 +628,13 @@ def chdir(self, path): self.chdir(old_path) raise OSError(str(exc)) - def normalize(self, path='.'): + def normalize(self, path: TransportPath = '.'): """Returns the normalized path (removing double slashes, etc...)""" + path = str(path) + return self.sftp.normalize(path) - def stat(self, path): + def stat(self, path: TransportPath): """Retrieve information about a file on the remote system. The return value is an object whose attributes correspond to the attributes of Python's ``stat`` structure as returned by ``os.stat``, except that it @@ -635,9 +647,11 @@ def stat(self, path): :return: a `paramiko.sftp_attr.SFTPAttributes` object containing attributes about the given file. """ + path = str(path) + return self.sftp.stat(path) - def lstat(self, path): + def lstat(self, path: TransportPath): """Retrieve information about a file on the remote system, without following symbolic links (shortcuts). This otherwise behaves exactly the same as `stat`. 
@@ -647,6 +661,8 @@ def lstat(self, path): :return: a `paramiko.sftp_attr.SFTPAttributes` object containing attributes about the given file. """ + path = str(path) + return self.sftp.lstat(path) def getcwd(self): @@ -659,9 +675,13 @@ def getcwd(self): this method will return None. But in __enter__ this is set explicitly, so this should never happen within this class. """ + warn_deprecation( + '`chdir()` is deprecated and will be removed in the next major version.', + version=3, + ) return self.sftp.getcwd() - def makedirs(self, path, ignore_existing=False): + def makedirs(self, path: TransportPath, ignore_existing: bool = False): """Super-mkdir; create a leaf directory and all intermediate ones. Works like mkdir, except that any intermediate path segment (not just the rightmost) will be created if it does not exist. @@ -676,6 +696,8 @@ def makedirs(self, path, ignore_existing=False): :raise OSError: If the directory already exists. """ + path = str(path) + # check to avoid creation of empty dirs path = os.path.normpath(path) @@ -697,7 +719,7 @@ def makedirs(self, path, ignore_existing=False): if not self.isdir(this_dir): self.mkdir(this_dir) - def mkdir(self, path, ignore_existing=False): + def mkdir(self, path: TransportPath, ignore_existing: bool = False): """Create a folder (directory) named path. :param path: name of the folder to create @@ -706,6 +728,8 @@ def mkdir(self, path, ignore_existing=False): :raise OSError: If the directory already exists. """ + path = str(path) + if ignore_existing and self.isdir(path): return @@ -725,7 +749,7 @@ def mkdir(self, path, ignore_existing=False): 'or the directory already exists? ({})'.format(path, self.getcwd(), exc) ) - def rmtree(self, path): + def rmtree(self, path: TransportPath): """Remove a file or a directory at path, recursively Flags used: -r: recursive copy; -f: force, makes the command non interactive; @@ -733,6 +757,7 @@ def rmtree(self, path): :raise OSError: if the rm execution failed. 
""" + path = str(path) # Assuming linux rm command! rm_exe = 'rm' @@ -752,25 +777,29 @@ def rmtree(self, path): self.logger.error(f"Problem executing rm. Exit code: {retval}, stdout: '{stdout}', stderr: '{stderr}'") raise OSError(f'Error while executing rm. Exit code: {retval}') - def rmdir(self, path): + def rmdir(self, path: TransportPath): """Remove the folder named 'path' if empty.""" + path = str(path) self.sftp.rmdir(path) - def chown(self, path, uid, gid): + def chown(self, path: TransportPath, uid, gid): """Change owner permissions of a file. For now, this is not implemented for the SSH transport. """ raise NotImplementedError - def isdir(self, path): + def isdir(self, path: TransportPath): """Return True if the given path is a directory, False otherwise. Return False also if the path does not exist. """ # Return False on empty string (paramiko would map this to the local # folder instead) + path = str(path) + if not path: return False + path = str(path) try: return S_ISDIR(self.stat(path).st_mode) except OSError as exc: @@ -779,21 +808,24 @@ def isdir(self, path): return False raise # Typically if I don't have permissions (errno=13) - def chmod(self, path, mode): + def chmod(self, path: TransportPath, mode): """Change permissions to path :param path: path to file :param mode: new permission bits (integer) """ + path = str(path) + if not path: raise OSError('Input path is an empty argument.') return self.sftp.chmod(path, mode) @staticmethod - def _os_path_split_asunder(path): - """Used by makedirs. Takes path (a str) + def _os_path_split_asunder(path: TransportPath): + """Used by makedirs. 
Takes path and returns a list deconcatenating the path """ + path = str(path) parts = [] while True: newpath, tail = os.path.split(path) @@ -807,7 +839,15 @@ def _os_path_split_asunder(path): parts.reverse() return parts - def put(self, localpath, remotepath, callback=None, dereference=True, overwrite=True, ignore_nonexisting=False): + def put( + self, + localpath: TransportPath, + remotepath: TransportPath, + callback=None, + dereference: bool = True, + overwrite: bool = True, + ignore_nonexisting: bool = False, + ): """Put a file or a folder from local to remote. Redirects to putfile or puttree. @@ -821,6 +861,9 @@ def put(self, localpath, remotepath, callback=None, dereference=True, overwrite= :raise ValueError: if local path is invalid :raise OSError: if the localpath does not exist """ + localpath = str(localpath) + remotepath = str(remotepath) + if not dereference: raise NotImplementedError @@ -871,7 +914,14 @@ def put(self, localpath, remotepath, callback=None, dereference=True, overwrite= elif not ignore_nonexisting: raise OSError(f'The local path {localpath} does not exist') - def putfile(self, localpath, remotepath, callback=None, dereference=True, overwrite=True): + def putfile( + self, + localpath: TransportPath, + remotepath: TransportPath, + callback=None, + dereference: bool = True, + overwrite: bool = True, + ): """Put a file from local to remote. 
:param localpath: an (absolute) local path @@ -883,6 +933,9 @@ def putfile(self, localpath, remotepath, callback=None, dereference=True, overwr :raise OSError: if the localpath does not exist, or unintentionally overwriting """ + localpath = str(localpath) + remotepath = str(remotepath) + if not dereference: raise NotImplementedError @@ -894,7 +947,14 @@ def putfile(self, localpath, remotepath, callback=None, dereference=True, overwr return self.sftp.put(localpath, remotepath, callback=callback) - def puttree(self, localpath, remotepath, callback=None, dereference=True, overwrite=True): + def puttree( + self, + localpath: TransportPath, + remotepath: TransportPath, + callback=None, + dereference: bool = True, + overwrite: bool = True, + ): """Put a folder recursively from local to remote. By default, overwrite. @@ -913,6 +973,9 @@ def puttree(self, localpath, remotepath, callback=None, dereference=True, overwr .. note:: setting dereference equal to True could cause infinite loops. see os.walk() documentation """ + localpath = str(localpath) + remotepath = str(remotepath) + if not dereference: raise NotImplementedError @@ -958,7 +1021,15 @@ def puttree(self, localpath, remotepath, callback=None, dereference=True, overwr this_remote_file = os.path.join(remotepath, this_basename, this_file) self.putfile(this_local_file, this_remote_file) - def get(self, remotepath, localpath, callback=None, dereference=True, overwrite=True, ignore_nonexisting=False): + def get( + self, + remotepath: TransportPath, + localpath: TransportPath, + callback=None, + dereference: bool = True, + overwrite: bool = True, + ignore_nonexisting: bool = False, + ): """Get a file or folder from remote to local. Redirects to getfile or gettree. 
@@ -973,6 +1044,9 @@ def get(self, remotepath, localpath, callback=None, dereference=True, overwrite= :raise ValueError: if local path is invalid :raise OSError: if the remotepath is not found """ + remotepath = str(remotepath) + localpath = str(localpath) + if not dereference: raise NotImplementedError @@ -1020,7 +1094,14 @@ def get(self, remotepath, localpath, callback=None, dereference=True, overwrite= else: raise OSError(f'The remote path {remotepath} does not exist') - def getfile(self, remotepath, localpath, callback=None, dereference=True, overwrite=True): + def getfile( + self, + remotepath: TransportPath, + localpath: TransportPath, + callback=None, + dereference: bool = True, + overwrite: bool = True, + ): """Get a file from remote to local. :param remotepath: a remote path @@ -1031,6 +1112,9 @@ def getfile(self, remotepath, localpath, callback=None, dereference=True, overwr :raise ValueError: if local path is invalid :raise OSError: if unintentionally overwriting """ + remotepath = str(remotepath) + localpath = str(localpath) + if not os.path.isabs(localpath): raise ValueError('localpath must be an absolute path') @@ -1050,7 +1134,14 @@ def getfile(self, remotepath, localpath, callback=None, dereference=True, overwr pass raise - def gettree(self, remotepath, localpath, callback=None, dereference=True, overwrite=True): + def gettree( + self, + remotepath: TransportPath, + localpath: TransportPath, + callback=None, + dereference: bool = True, + overwrite: bool = True, + ): """Get a folder recursively from remote to local. :param remotepath: a remote path @@ -1059,12 +1150,14 @@ def gettree(self, remotepath, localpath, callback=None, dereference=True, overwr Default = True (default behaviour in paramiko). False is not implemented. :param overwrite: if True overwrites files and folders. 
- Default = False + Default = True :raise ValueError: if local path is invalid :raise OSError: if the remotepath is not found :raise OSError: if unintentionally overwriting """ + remotepath = str(remotepath) + localpath = str(localpath) if not dereference: raise NotImplementedError @@ -1101,10 +1194,11 @@ def gettree(self, remotepath, localpath, callback=None, dereference=True, overwr else: self.getfile(os.path.join(remotepath, item), os.path.join(dest, item)) - def get_attribute(self, path): + def get_attribute(self, path: TransportPath): """Returns the object Fileattribute, specified in aiida.transports Receives in input the path of a given file. """ + path = str(path) from aiida.transports.util import FileAttribute paramiko_attr = self.lstat(path) @@ -1115,13 +1209,25 @@ def get_attribute(self, path): aiida_attr[key] = getattr(paramiko_attr, key) return aiida_attr - def copyfile(self, remotesource, remotedestination, dereference=False): + def copyfile(self, remotesource: TransportPath, remotedestination: TransportPath, dereference: bool = False): + remotesource = str(remotesource) + remotedestination = str(remotedestination) + return self.copy(remotesource, remotedestination, dereference) - def copytree(self, remotesource, remotedestination, dereference=False): + def copytree(self, remotesource: TransportPath, remotedestination: TransportPath, dereference: bool = False): + remotesource = str(remotesource) + remotedestination = str(remotedestination) + return self.copy(remotesource, remotedestination, dereference, recursive=True) - def copy(self, remotesource, remotedestination, dereference=False, recursive=True): + def copy( + self, + remotesource: TransportPath, + remotedestination: TransportPath, + dereference: bool = False, + recursive: bool = True, + ): """Copy a file or a directory from remote source to remote destination. 
Flags used: ``-r``: recursive copy; ``-f``: force, makes the command non interactive; ``-L`` follows symbolic links @@ -1138,6 +1244,9 @@ def copy(self, remotesource, remotedestination, dereference=False, recursive=Tru .. note:: setting dereference equal to True could cause infinite loops. """ + remotesource = str(remotesource) + remotedestination = str(remotedestination) + # In the majority of cases, we should deal with linux cp commands cp_flags = '-f' if recursive: @@ -1179,7 +1288,7 @@ def copy(self, remotesource, remotedestination, dereference=False, recursive=Tru else: self._exec_cp(cp_exe, cp_flags, remotesource, remotedestination) - def _exec_cp(self, cp_exe, cp_flags, src, dst): + def _exec_cp(self, cp_exe: str, cp_flags: str, src: str, dst: str): """Execute the ``cp`` command on the remote machine.""" # to simplify writing the above copy function command = f'{cp_exe} {cp_flags} {escape_for_bash(src)} {escape_for_bash(dst)}' @@ -1205,7 +1314,7 @@ def _exec_cp(self, cp_exe, cp_flags, src, dst): ) @staticmethod - def _local_listdir(path, pattern=None): + def _local_listdir(path: str, pattern=None): """Acts on the local folder, for the rest, same as listdir""" if not pattern: return os.listdir(path) @@ -1219,13 +1328,15 @@ def _local_listdir(path, pattern=None): base_dir += os.sep return [re.sub(base_dir, '', i) for i in filtered_list] - def listdir(self, path='.', pattern=None): + def listdir(self, path: TransportPath = '.', pattern=None): """Get the list of files at path. :param path: default = '.' :param pattern: returns the list of files matching pattern. Unix only. 
(Use to emulate ``ls *`` for example) """ + path = str(path) + if path.startswith('/'): abs_dir = path else: @@ -1239,33 +1350,41 @@ def listdir(self, path='.', pattern=None): abs_dir += '/' return [re.sub(abs_dir, '', i) for i in filtered_list] - def remove(self, path): + def remove(self, path: TransportPath): """Remove a single file at 'path'""" + path = str(path) return self.sftp.remove(path) - def rename(self, oldpath, newpath): + def rename(self, oldpath: TransportPath, newpath: TransportPath): """Rename a file or folder from oldpath to newpath. :param str oldpath: existing name of the file or folder :param str newpath: new name for the file or folder :raises OSError: if oldpath/newpath is not found - :raises ValueError: if sroldpathc/newpath is not a valid string + :raises ValueError: if sroldpathc/newpath is not a valid path """ if not oldpath: - raise ValueError(f'Source {oldpath} is not a valid string') + raise ValueError(f'Source {oldpath} is not a valid path') if not newpath: - raise ValueError(f'Destination {newpath} is not a valid string') + raise ValueError(f'Destination {newpath} is not a valid path') + + oldpath = str(oldpath) + newpath = str(newpath) + if not self.isfile(oldpath): if not self.isdir(oldpath): raise OSError(f'Source {oldpath} does not exist') + # TODO: this seems to be a bug (?) + # why to raise an OSError if the newpath does not exist? + # ofcourse newpath shouldn't exist, that's why we are renaming it! if not self.isfile(newpath): if not self.isdir(newpath): raise OSError(f'Destination {newpath} does not exist') return self.sftp.rename(oldpath, newpath) - def isfile(self, path): + def isfile(self, path: TransportPath): """Return True if the given path is a file, False otherwise. Return False also if the path does not exist. 
""" @@ -1274,6 +1393,8 @@ def isfile(self, path): # but this is just to be sure if not path: return False + + path = str(path) try: self.logger.debug( f"stat for path '{path}' ('{self.normalize(path)}'): {self.stat(path)} [{self.stat(path).st_mode}]" @@ -1334,7 +1455,7 @@ def _exec_command_internal(self, command, combine_stderr=False, bufsize=-1, work return stdin, stdout, stderr, channel def exec_command_wait_bytes( - self, command, stdin=None, combine_stderr=False, bufsize=-1, timeout=0.01, workdir=None + self, command, stdin=None, combine_stderr=False, bufsize=-1, timeout=0.01, workdir: TransportPath = None ): """Executes the specified command and waits for it to finish. @@ -1354,6 +1475,9 @@ def exec_command_wait_bytes( import socket import time + if workdir: + workdir = str(workdir) + ssh_stdin, stdout, stderr, channel = self._exec_command_internal( command, combine_stderr, bufsize=bufsize, workdir=workdir ) @@ -1447,10 +1571,12 @@ def exec_command_wait_bytes( return (retval, b''.join(stdout_bytes), b''.join(stderr_bytes)) - def gotocomputer_command(self, remotedir): + def gotocomputer_command(self, remotedir: TransportPath): """Specific gotocomputer string to connect to a given remote computer via ssh and directly go to the calculation folder. 
""" + remotedir = str(remotedir) + further_params = [] if 'username' in self._connect_args: further_params.append(f"-l {escape_for_bash(self._connect_args['username'])}") @@ -1473,21 +1599,25 @@ def gotocomputer_command(self, remotedir): cmd = f'ssh -t {self._machine} {further_params_str} {connect_string}' return cmd - def _symlink(self, source, dest): + def _symlink(self, source: TransportPath, dest: TransportPath): """Wrap SFTP symlink call without breaking API :param source: source of link :param dest: link to create """ + source = str(source) + dest = str(dest) self.sftp.symlink(source, dest) - def symlink(self, remotesource, remotedestination): + def symlink(self, remotesource: TransportPath, remotedestination: TransportPath): """Create a symbolic link between the remote source and the remote destination. :param remotesource: remote source. Can contain a pattern. :param remotedestination: remote destination """ + remotesource = str(remotesource) + remotedestination = str(remotedestination) # paramiko gives some errors if path is starting with '.' 
source = os.path.normpath(remotesource) dest = os.path.normpath(remotedestination) @@ -1495,7 +1625,7 @@ def symlink(self, remotesource, remotedestination): if self.has_magic(source): if self.has_magic(dest): # if there are patterns in dest, I don't know which name to assign - raise ValueError('Remotedestination cannot have patterns') + raise ValueError('`remotedestination` cannot have patterns') # find all files matching pattern for this_source in self.glob(source): @@ -1505,10 +1635,12 @@ def symlink(self, remotesource, remotedestination): else: self._symlink(source, dest) - def path_exists(self, path): + def path_exists(self, path: TransportPath): """Check if path exists""" import errno + path = str(path) + try: self.stat(path) except OSError as exc: diff --git a/src/aiida/transports/plugins/ssh_async.py b/src/aiida/transports/plugins/ssh_async.py new file mode 100644 index 0000000000..da51017b5a --- /dev/null +++ b/src/aiida/transports/plugins/ssh_async.py @@ -0,0 +1,1272 @@ +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. 
# +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Plugin for transport over SSH asynchronously.""" + +## TODO: put & get methods could be simplified with the asyncssh.sftp.mget() & put() method or sftp.glob() +import asyncio +import glob +import os +from pathlib import Path, PurePath +from typing import Optional + +import asyncssh +import click +from asyncssh import SFTPFileAlreadyExists + +from aiida.common.escaping import escape_for_bash +from aiida.common.exceptions import InvalidOperation +from aiida.transports.transport import ( + AsyncTransport, + Transport, + TransportInternalError, + TransportPath, + validate_positive_number, +) + +__all__ = ('AsyncSshTransport',) + + +def validate_script(ctx, param, value: str): + if value == 'None': + return value + if not os.path.isabs(value): + raise click.BadParameter(f'{value} is not an absolute path') + if not os.path.isfile(value): + raise click.BadParameter(f'The script file: {value} does not exist') + if not os.access(value, os.X_OK): + raise click.BadParameter(f'The script {value} is not executable') + return value + + +def validate_machine(ctx, param, value: str): + async def attempt_connection(): + try: + await asyncssh.connect(value) + except Exception: + return False + return True + + if not asyncio.run(attempt_connection()): + raise click.BadParameter("Couldn't connect! " 'Please make sure `ssh {value}` would work without password') + else: + click.echo(f'`ssh {value}` successful!') + + return value + + +class AsyncSshTransport(AsyncTransport): + """Transport plugin via SSH, asynchronously.""" + + _DEFAULT_max_io_allowed = 8 + + # note, I intentionally wanted to keep connection parameters as simple as possible. 
+ _valid_auth_options = [ + ( + # the underscore is added to avoid conflict with the machine property + # which is passed to __init__ as parameter `machine=computer.hostname` + 'machine_or_host', + { + 'type': str, + 'prompt': 'Machine(or host) name as in `ssh ` command.' + ' (It should be a password-less setup)', + 'help': 'Password-less host-setup to connect, as in command `ssh `. ' + "You'll need to have a `Host ` entry defined in your `~/.ssh/config` file.", + 'non_interactive_default': True, + 'callback': validate_machine, + }, + ), + ( + 'max_io_allowed', + { + 'type': int, + 'default': _DEFAULT_max_io_allowed, + 'prompt': 'Maximum number of concurrent I/O operations.', + 'help': 'Depends on various factors, such as your network bandwidth, the server load, etc.' + ' (An experimental number)', + 'non_interactive_default': True, + 'callback': validate_positive_number, + }, + ), + ( + 'script_before', + { + 'type': str, + 'default': 'None', + 'prompt': 'Local script to run *before* opening connection (path)', + 'help': ' (optional) Specify a script to run *before* opening SSH connection. ' + 'The script should be executable', + 'non_interactive_default': True, + 'callback': validate_script, + }, + ), + ] + + @classmethod + def _get_machine_suggestion_string(cls, computer): + """Return a suggestion for the parameter machine.""" + # Originally set as 'Hostname' during `verdi computer setup` + # and is passed as `machine=computer.hostname` in the codebase + # unfortunately, name of hostname and machine are used interchangeably in the aiida-core codebase + # TODO: open an issue to unify the naming + return computer.hostname + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # the machine is passed as `machine=computer.hostname` in the codebase + # 'machine' is immutable. + # 'machine_or_host' is mutable, so it can be changed via command: + # 'verdi computer configure core.ssh_async