diff --git a/aiida/engine/__init__.py b/aiida/engine/__init__.py index 68b6a6d626..f9b4577579 100644 --- a/aiida/engine/__init__.py +++ b/aiida/engine/__init__.py @@ -59,6 +59,7 @@ 'WorkChain', 'append_', 'assign_', + 'await_processes', 'calcfunction', 'construct_awaitable', 'get_daemon_client', diff --git a/aiida/engine/launch.py b/aiida/engine/launch.py index ef48a09e3b..b277609ae2 100644 --- a/aiida/engine/launch.py +++ b/aiida/engine/launch.py @@ -14,6 +14,7 @@ import typing as t from aiida.common import InvalidOperation +from aiida.common.lang import type_check from aiida.common.log import AIIDA_LOGGER from aiida.manage import manager from aiida.orm import ProcessNode @@ -24,7 +25,7 @@ from .runners import ResultAndPk from .utils import instantiate_process, is_process_scoped # pylint: disable=no-name-in-module -__all__ = ('run', 'run_get_pk', 'run_get_node', 'submit') +__all__ = ('run', 'run_get_pk', 'run_get_node', 'submit', 'await_processes') TYPE_RUN_PROCESS = t.Union[Process, t.Type[Process], ProcessBuilder] # pylint: disable=invalid-name # run can also be process function, but it is not clear what type this should be @@ -130,6 +131,28 @@ def submit(process: TYPE_SUBMIT_PROCESS, wait: bool = False, wait_interval: int return node +def await_processes(nodes: t.Sequence[ProcessNode], wait_interval: int = 1) -> None: + """Run a loop until all processes are terminated. + + :param nodes: Sequence of nodes that represent the processes to await. + :param wait_interval: The interval between each iteration of checking the status of all processes. + """ + type_check(nodes, (list, tuple)) + + if any(not isinstance(node, ProcessNode) for node in nodes): + raise TypeError(f'`nodes` should be a list of `ProcessNode`s but got: {nodes}') + + start_time = time.time() + terminated = False + + while not terminated: + running = [not node.is_terminated for node in nodes] + terminated = not any(running) + seconds_passed = time.time() - start_time + LOGGER.report(f'{running.count(False)} out of {len(nodes)} processes terminated. [{round(seconds_passed)} s]') + time.sleep(wait_interval) + + # Allow one to also use run.get_node and run.get_pk as a shortcut, without having to import the functions themselves run.get_node = run_get_node # type: ignore[attr-defined] run.get_pk = run_get_pk # type: ignore[attr-defined] diff --git a/docs/source/topics/processes/usage.rst b/docs/source/topics/processes/usage.rst index 479ac72ba7..b5cbb9a309 100644 --- a/docs/source/topics/processes/usage.rst +++ b/docs/source/topics/processes/usage.rst @@ -355,6 +355,23 @@ The function will submit the calculation to the daemon and immediately return co This can be useful for tutorials and demos in interactive notebooks where the user should not continue before the process is done. One could of course also use ``run`` (see below), but then the process would be lost if the interpreter gets accidentally shut down. By using ``submit``, the process is run by the daemon which takes care of saving checkpoints so it can always be restarted in case of problems. + If you need to launch multiple processes in parallel and want to wait for all of them to be finished, simply use ``submit`` with the default ``wait=False`` and collect the returned nodes in a list. + You can then pass them to :func:`aiida.engine.launch.await_processes` which will return once all processes have terminated: + + .. code:: python + + from aiida.engine import submit, await_processes + + nodes = [] + + for i in range(5): + node = submit(...) + nodes.append(node) + + await_processes(nodes, wait_interval=10) + + The ``await_processes`` function will loop every ``wait_interval`` seconds and check whether all processes (represented by the ``ProcessNode`` in the ``nodes`` list) have terminated. + The ``run`` function is called identically: diff --git a/tests/engine/test_launch.py b/tests/engine/test_launch.py index 3fc6ca8a4f..f593a486b4 100644 --- a/tests/engine/test_launch.py +++ b/tests/engine/test_launch.py @@ -82,6 +82,35 @@ def test_submit_wait(aiida_local_code_factory): assert node.is_finished_ok, node.exit_code +def test_await_processes_invalid(): + """Test :func:`aiida.engine.launch.await_processes` for invalid inputs.""" + with pytest.raises(TypeError): + launch.await_processes(None) + + with pytest.raises(TypeError): + launch.await_processes([orm.Data()]) + + with pytest.raises(TypeError): + launch.await_processes(orm.ProcessNode()) + + +@pytest.mark.usefixtures('started_daemon_client') +def test_await_processes(aiida_local_code_factory, caplog): + """Test :func:`aiida.engine.launch.await_processes`.""" + builder = ArithmeticAddCalculation.get_builder() + builder.code = aiida_local_code_factory('core.arithmetic.add', '/bin/bash') + builder.x = orm.Int(1) + builder.y = orm.Int(2) + builder.metadata = {'options': {'resources': {'num_machines': 1}}} + node = launch.submit(builder) + + assert not node.is_terminated + launch.await_processes([node]) + assert node.is_terminated + assert len(caplog.records) > 0 + assert 'out of 1 processes terminated.' in caplog.records[0].message + + @pytest.mark.requires_rmq class TestLaunchers: """Class to test process launchers."""