From c56c2ed13f343487581c3511b9483a497e2cba11 Mon Sep 17 00:00:00 2001 From: Julian Geiger Date: Thu, 19 Dec 2024 17:02:09 +0100 Subject: [PATCH] Fix issues from `hatch fmt` --- src/sirocco/core/graph_items.py | 2 +- src/sirocco/parsing/_yaml_data_models.py | 8 +++--- src/sirocco/workgraph.py | 23 ++++++++---------- tests/conftest.py | 2 +- tests/files/scripts/cleanup.py | 3 ++- tests/files/scripts/icon.py | 31 ++++++++++++------------ tests/test_wc_workflow.py | 21 +++++++++++----- 7 files changed, 48 insertions(+), 42 deletions(-) diff --git a/src/sirocco/core/graph_items.py b/src/sirocco/core/graph_items.py index ecffa73..ae57842 100644 --- a/src/sirocco/core/graph_items.py +++ b/src/sirocco/core/graph_items.py @@ -110,7 +110,7 @@ def from_config(cls, config: ConfigBaseData, coordinates: dict) -> Self: @property def path(self) -> Path: - # TODO yaml level? + # TODO: yaml level? return Path(expandvars(self.src)) diff --git a/src/sirocco/parsing/_yaml_data_models.py b/src/sirocco/parsing/_yaml_data_models.py index 4ca0c05..139d199 100644 --- a/src/sirocco/parsing/_yaml_data_models.py +++ b/src/sirocco/parsing/_yaml_data_models.py @@ -3,9 +3,9 @@ import time from dataclasses import dataclass from datetime import datetime +from os.path import expandvars from pathlib import Path from typing import Annotated, Any, ClassVar, Literal -from os.path import expandvars from isoduration import parse_duration from isoduration.types import Duration # pydantic needs type # noqa: TCH002 @@ -280,7 +280,6 @@ class ConfigShellTaskSpecs: class ConfigShellTask(ConfigBaseTask, ConfigShellTaskSpecs): - command: str = "" # PR(COMMENT) tmp hack to make script work, need to find better solution than PWD for tests @@ -288,7 +287,7 @@ class ConfigShellTask(ConfigBaseTask, ConfigShellTaskSpecs): @classmethod def expand_var(cls, value: str) -> str: """Expand environemnt variables""" - # TODO this might be not intended if we want to use environment variables on remote HPC + # TODO: this might be not intended if we want to use environment variables on remote HPC return expandvars(value) @@ -297,6 +296,7 @@ class ConfigIconTaskSpecs: plugin: ClassVar[Literal["icon"]] = "icon" namelists: dict[str, str] | None = None + class ConfigIconTask(ConfigBaseTask, ConfigIconTaskSpecs): pass @@ -331,7 +331,7 @@ def is_file_or_dir(cls, value: str) -> str: @classmethod def expand_var(cls, value: str | None) -> str | None: """Expand environemnt variables""" - # TODO this might be not intended if we want to use environment variables on remote HPC + # TODO: this might be not intended if we want to use environment variables on remote HPC return None if value is None else expandvars(value) diff --git a/src/sirocco/workgraph.py b/src/sirocco/workgraph.py index fb76faa..806c740 100644 --- a/src/sirocco/workgraph.py +++ b/src/sirocco/workgraph.py @@ -7,14 +7,15 @@ import aiida_workgraph.engine.utils # type: ignore[import-untyped] from aiida_workgraph import WorkGraph # type: ignore[import-untyped] -from sirocco import core -from sirocco.core import graph_items from sirocco.core._tasks.icon_task import IconTask from sirocco.core._tasks.shell_task import ShellTask if TYPE_CHECKING: from aiida_workgraph.socket import TaskSocket # type: ignore[import-untyped] + from sirocco import core + from sirocco.core import graph_items + # This is hack to aiida-workgraph, merging this into aiida-workgraph properly would require # some major refactor see issue https://github.com/aiidateam/aiida-workgraph/issues/168 @@ -69,7 +70,7 @@ def __init__(self, core_workflow: core.Workflow): self._validate_workflow() - self._workgraph = WorkGraph() # core_workflow.name TODO use filename + self._workgraph = WorkGraph() # core_workflow.name TODO: use filename # stores the input data available on initialization self._aiida_data_nodes: dict[str, aiida_workgraph.orm.Data] = {} @@ -136,7 +137,7 @@ def get_aiida_label_from_unrolled_data(obj: graph_items.GraphItem) -> str: @staticmethod def get_aiida_label_from_unrolled_task(obj: graph_items.GraphItem) -> str: """ """ - # TODO task is not anymore using cycle name because information is not there + # TODO: task is not anymore using cycle name because information is not there # so do we check somewhere that a task is not used in multiple cycles? # Otherwise the label is not unique # --> task name + date + parameters @@ -164,7 +165,7 @@ def _add_aiida_task_nodes(self): for task in cycle.tasks: self._add_aiida_task_node(task) # after creation we can link the wait_on tasks - # TODO check where this is now + # TODO: check where this is now # for cycle in self._core_workflow.cycles: # for task in cycle.tasks: # self._link_wait_on_to_task(task) @@ -182,7 +183,7 @@ def _add_aiida_task_node(self, task: graph_items.Task): # ? Source file env_source_files = task.env_source_files env_source_files = [env_source_files] if isinstance(env_source_files, str) else env_source_files - prepend_text = '\n'.join([f"source {env_source_file}" for env_source_file in env_source_files]) + prepend_text = "\n".join([f"source {env_source_file}" for env_source_file in env_source_files]) # Note: We don't pass the `nodes` dictionary here, as then we would need to have the sockets available when # we create the task. Instead, they are being updated via the WG internals when linking inputs/outputs to @@ -193,11 +194,7 @@ def _add_aiida_task_node(self, task: graph_items.Task): name=label, command=command, arguments=argument_list, - metadata={ - 'options': { - 'prepend_text': prepend_text - } - } + metadata={"options": {"prepend_text": prepend_text}}, ) self._aiida_task_nodes[label] = workgraph_task @@ -210,7 +207,7 @@ def _add_aiida_task_node(self, task: graph_items.Task): raise NotImplementedError(exc) def _link_wait_on_to_task(self, task: graph_items.Task): - # TODO + # TODO: to be done msg = "" raise NotImplementedError(msg) label = AiidaWorkGraph.get_aiida_label_from_unrolled_task(task) @@ -259,7 +256,7 @@ def _link_input_to_task(self, task: graph_items.Task, input_: graph_items.Data): if (workgraph_task_arguments := workgraph_task.inputs.get("arguments")) is None: msg = f"Workgraph task {workgraph_task.name!r} did not initialize arguments nodes in the workgraph before linking. This is a bug in the code, please contact devevlopers." raise ValueError(msg) - # TODO think about that the yaml file should have aiida valid labels + # TODO: think about that the yaml file should have aiida valid labels # Avoid appending the same argument twice argument_placeholder = f"{{{input_label}}}" diff --git a/tests/conftest.py b/tests/conftest.py index 66b9686..4a579bb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1 +1 @@ -pytest_plugins = ['aiida.tools.pytest_fixtures'] +pytest_plugins = ["aiida.tools.pytest_fixtures"] diff --git a/tests/files/scripts/cleanup.py b/tests/files/scripts/cleanup.py index ff9c30e..de7aeba 100755 --- a/tests/files/scripts/cleanup.py +++ b/tests/files/scripts/cleanup.py @@ -1,9 +1,10 @@ #!/usr/bin/env python + def main(): # Main script execution continues here print("Cleaning") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/tests/files/scripts/icon.py b/tests/files/scripts/icon.py index ac9e6c5..32f71ed 100755 --- a/tests/files/scripts/icon.py +++ b/tests/files/scripts/icon.py @@ -15,29 +15,29 @@ LOG_FILE = Path("icon.log") + def log(text: str): print(text) with LOG_FILE.open("a") as f: f.write(text) -def main(): - parser = argparse.ArgumentParser(description='A script mocking parts of icon in a form of a shell script.') - parser.add_argument('--init', nargs='?', type=str, help='The icon init file.') - parser.add_argument('namelist', nargs='?', default=None) - parser.add_argument('--restart', nargs='?', type=str, help='The icon restart file.') - parser.add_argument('--forcing', nargs='?', type=str, help='The icon forcing file.') +def main(): + parser = argparse.ArgumentParser(description="A script mocking parts of icon in a form of a shell script.") + parser.add_argument("--init", nargs="?", type=str, help="The icon init file.") + parser.add_argument("namelist", nargs="?", default=None) + parser.add_argument("--restart", nargs="?", type=str, help="The icon restart file.") + parser.add_argument("--forcing", nargs="?", type=str, help="The icon forcing file.") args = parser.parse_args() - - output = Path('icon_output') + output = Path("icon_output") output.write_text("") - if args.restart and args.init: - msg = "Cannot use '--init' and '--restart' option at the same time." - raise ValueError(msg) - elif args.restart: + if args.restart: + if args.init: + msg = "Cannot use '--init' and '--restart' option at the same time." + raise ValueError(msg) if not Path(args.restart).exists(): msg = f"The icon restart file {args.restart!r} was not found." raise FileNotFoundError(msg) @@ -62,10 +62,9 @@ def main(): # Main script execution continues here log("Script finished running calculations") - restart = Path('restart') + restart = Path("restart") restart.write_text("") -if __name__ == '__main__': - main() - +if __name__ == "__main__": + main() diff --git a/tests/test_wc_workflow.py b/tests/test_wc_workflow.py index 1448490..3b66394 100644 --- a/tests/test_wc_workflow.py +++ b/tests/test_wc_workflow.py @@ -8,15 +8,19 @@ from sirocco.workgraph import AiidaWorkGraph -@pytest.mark.parametrize("config_path", [ - "tests/files/configs/test_config_small.yml", - "tests/files/configs/test_config_parameters.yml", -]) +@pytest.mark.parametrize( + "config_path", + [ + "tests/files/configs/test_config_small.yml", + "tests/files/configs/test_config_parameters.yml", + ], +) def test_run_workgraph(config_path): core_workflow = Workflow.from_yaml(config_path) aiida_workflow = AiidaWorkGraph(core_workflow) out = aiida_workflow.run() - assert out.get('execution_count', None).value == 0 # TODO should be 1 but we need to update workgraph for this + assert out.get("execution_count", None).value == 0 # TODO: should be 1 but we need to update workgraph for this + # configs that are tested only tested parsing config_test_files = [ @@ -25,6 +29,7 @@ def test_run_workgraph(config_path): "tests/files/configs/test_config_parameters.yml", ] + @pytest.fixture(params=config_test_files) def config_paths(request): config_path = Path(request.param) @@ -34,17 +39,21 @@ def config_paths(request): "svg": (config_path.parent.parent / "svgs" / config_path.name).with_suffix(".svg"), } + @pytest.fixture def pprinter(): return PrettyPrinter() + def test_parse_config_file(config_paths, pprinter): reference_str = config_paths["txt"].read_text() test_str = pprinter.format(Workflow.from_yaml(config_paths["yml"])) if test_str != reference_str: new_path = Path(config_paths["txt"]).with_suffix(".new.txt") new_path.write_text(test_str) - assert reference_str == test_str, f"Workflow graph doesn't match serialized data. New graph string dumped to {new_path}." + assert ( + reference_str == test_str + ), f"Workflow graph doesn't match serialized data. New graph string dumped to {new_path}." @pytest.mark.skip(reason="don't run it each time, uncomment to regenerate serilaized data")