From 0cd2ae5c94a049c8070bfdada0d1695e3f36e751 Mon Sep 17 00:00:00 2001 From: Julian Geiger Date: Sat, 21 Dec 2024 06:29:33 +0100 Subject: [PATCH] Pass single argument string to `ShellTasks` (#72) * Try to make current example YAML files run through. * Update class names after rebase/merge * Add `no_icon` config file * First working version apart from flags. * Try fixing argument format. * Current state before branch-off. * Pass one multi-line string as `cli_argument` * Cleanup. * Pass arguments as list. * Remove `workgraph-dev.py` dev file. * Fix issues from `hatch fmt` --- src/sirocco/core/graph_items.py | 2 +- src/sirocco/parsing/_yaml_data_models.py | 50 ++--------- src/sirocco/workgraph.py | 85 +++++++++++++------ tests/conftest.py | 2 +- .../configs/test_config_small_no_icon.yml | 52 ++++++++++++ tests/files/data/data-xyz | 0 tests/files/data/dummy_source_file.sh | 0 tests/files/data/initial_conditions2 | 0 tests/files/scripts/cleanup.py | 3 +- tests/files/scripts/icon.py | 31 ++++--- tests/files/scripts/shell_task.sh | 55 ++++++++++++ tests/test_wc_workflow.py | 21 +++-- 12 files changed, 205 insertions(+), 96 deletions(-) create mode 100644 tests/files/configs/test_config_small_no_icon.yml create mode 100644 tests/files/data/data-xyz create mode 100644 tests/files/data/dummy_source_file.sh create mode 100644 tests/files/data/initial_conditions2 create mode 100755 tests/files/scripts/shell_task.sh diff --git a/src/sirocco/core/graph_items.py b/src/sirocco/core/graph_items.py index ecffa73..ae57842 100644 --- a/src/sirocco/core/graph_items.py +++ b/src/sirocco/core/graph_items.py @@ -110,7 +110,7 @@ def from_config(cls, config: ConfigBaseData, coordinates: dict) -> Self: @property def path(self) -> Path: - # TODO yaml level? + # TODO: yaml level? return Path(expandvars(self.src)) diff --git a/src/sirocco/parsing/_yaml_data_models.py b/src/sirocco/parsing/_yaml_data_models.py index e6e16f6..139d199 100644 --- a/src/sirocco/parsing/_yaml_data_models.py +++ b/src/sirocco/parsing/_yaml_data_models.py @@ -3,9 +3,9 @@ import time from dataclasses import dataclass from datetime import datetime +from os.path import expandvars from pathlib import Path from typing import Annotated, Any, ClassVar, Literal -from os.path import expandvars from isoduration import parse_duration from isoduration.types import Duration # pydantic needs type # noqa: TCH002 @@ -83,43 +83,6 @@ def convert_datetime(cls, value) -> datetime: return datetime.fromisoformat(value) -class _CliArgsBaseModel(BaseModel): - """Base class for cli_arguments specifications""" - - # TODO: Even allow for `str`, or always require list? - positional: str | list[str] | None = None - # Field needed for child class doing pydantic parsing - keyword: dict[str, str] | None = Field(default_factory=dict) - flags: str | list[str] | None = None - source_file: str | list[str] | None = None - - # TODO: Should we allow users to pass it without the hyphen(s), and prepend them automatically? - # TODO: While convenient, it could be a bad idea, if users put in wrong things. Better to be explicit. 
-    @field_validator("keyword", mode="before")
-    @classmethod
-    def validate_keyword_args(cls, value):
-        """Ensure keyword arguments start with '-' or '--'."""
-        if value is not None:
-            invalid_keys = [key for key in value if not key.startswith(("-", "--"))]
-            if invalid_keys:
-                invalid_kwarg_exc = f"Invalid keyword arguments: {', '.join(invalid_keys)}"
-                raise ValueError(invalid_kwarg_exc)
-        return value
-
-    @field_validator("flags", mode="before")
-    @classmethod
-    def validate_flag_args(cls, value):
-        """Ensure positional arguments start with '-' or '--'."""
-        if value is not None:
-            if isinstance(value, str):
-                value = [value]
-            invalid_flags = [arg for arg in value if not arg.startswith(("-", "--"))]
-            if invalid_flags:
-                invalid_flags_exc = f"Invalid positional arguments: {', '.join(invalid_flags)}"
-                raise ValueError(invalid_flags_exc)
-        return value
-
-
 class TargetNodesBaseModel(_NamedBaseModel):
     """class for targeting other task or data nodes in the graph

@@ -311,20 +274,20 @@ class ConfigRootTask(ConfigBaseTask):
 class ConfigShellTaskSpecs:
     plugin: ClassVar[Literal["shell"]] = "shell"
     command: str = ""
-    cli_arguments: _CliArgsBaseModel | None = None
+    cli_argument: str = ""
+    env_source_files: str | list[str] | None = None
     src: str | None = None


 class ConfigShellTask(ConfigBaseTask, ConfigShellTaskSpecs):
-    # PR(COMMENT) tmp hack to make script work, need to find better solution than PWD for tests
     command: str = ""
-    command_option: str = ""
+    # PR(COMMENT): temporary hack to make the script work; we need a better solution than PWD for the tests

     @field_validator("command", "src")
     @classmethod
     def expand_var(cls, value: str) -> str:
         """Expand environemnt variables"""
-        # TODO this might be not intended if we want to use environment variables on remote HPC
+        # TODO: this might not be intended if we want to use environment variables on remote HPC
         return expandvars(value)


@@ -333,6 +296,7 @@ class ConfigIconTaskSpecs:
     plugin: ClassVar[Literal["icon"]] = "icon"
     namelists: dict[str, str] | None = None

+
 class ConfigIconTask(ConfigBaseTask, ConfigIconTaskSpecs):
     pass

@@ -367,7 +331,7 @@ def is_file_or_dir(cls, value: str) -> str:
     @classmethod
     def expand_var(cls, value: str | None) -> str | None:
         """Expand environemnt variables"""
-        # TODO this might be not intended if we want to use environment variables on remote HPC
+        # TODO: this might not be intended if we want to use environment variables on remote HPC
         return None if value is None else expandvars(value)


diff --git a/src/sirocco/workgraph.py b/src/sirocco/workgraph.py
index 2636352..806c740 100644
--- a/src/sirocco/workgraph.py
+++ b/src/sirocco/workgraph.py
@@ -7,9 +7,14 @@
 import aiida_workgraph.engine.utils  # type: ignore[import-untyped]
 from aiida_workgraph import WorkGraph  # type: ignore[import-untyped]

+from sirocco.core._tasks.icon_task import IconTask
+from sirocco.core._tasks.shell_task import ShellTask
+
 if TYPE_CHECKING:
     from aiida_workgraph.socket import TaskSocket  # type: ignore[import-untyped]
-    from wcflow import core
+
+    from sirocco import core
+    from sirocco.core import graph_items


 # This is hack to aiida-workgraph, merging this into aiida-workgraph properly would require
@@ -65,7 +70,7 @@ def __init__(self, core_workflow: core.Workflow):

         self._validate_workflow()

-        self._workgraph = WorkGraph()  # core_workflow.name TODO use filename
+        self._workgraph = WorkGraph()  # core_workflow.name TODO: use filename

         # stores the input data available on initialization
         self._aiida_data_nodes: dict[str, aiida_workgraph.orm.Data] = {}
@@ -123,16 
+128,16 @@ def parse_to_aiida_label(label: str) -> str:
         return label

     @staticmethod
-    def get_aiida_label_from_unrolled_data(obj: core.BaseNode) -> str:
+    def get_aiida_label_from_unrolled_data(obj: graph_items.GraphItem) -> str:
         """ """
         return AiidaWorkGraph.parse_to_aiida_label(
             f"{obj.name}" + "__".join(f"_{key}_{value}" for key, value in obj.coordinates.items())
         )

     @staticmethod
-    def get_aiida_label_from_unrolled_task(obj: core.BaseNode) -> str:
+    def get_aiida_label_from_unrolled_task(obj: graph_items.GraphItem) -> str:
         """ """
-        # TODO task is not anymore using cycle name because information is not there
+        # TODO: the task label no longer uses the cycle name because the information is not there
         # so do we check somewhere that a task is not used in multiple cycles?
         # Otherwise the label is not unique
         # --> task name + date + parameters
@@ -140,7 +145,7 @@ def get_aiida_label_from_unrolled_task(obj: core.BaseNode) -> str:
             "__".join([f"{obj.name}"] + [f"_{key}_{value}" for key, value in obj.coordinates.items()])
         )

-    def _add_aiida_input_data_node(self, input_: core.UnrolledData):
+    def _add_aiida_input_data_node(self, input_: graph_items.Data):
         """
         Create an :class:`aiida.orm.Data` instance from this wc data instance.

@@ -160,27 +165,49 @@ def _add_aiida_task_nodes(self):
             for task in cycle.tasks:
                 self._add_aiida_task_node(task)
         # after creation we can link the wait_on tasks
-        # TODO check where this is now
-        #for cycle in self._core_workflow.cycles:
+        # TODO: check where this is now
+        # for cycle in self._core_workflow.cycles:
         #    for task in cycle.tasks:
         #        self._link_wait_on_to_task(task)

-    def _add_aiida_task_node(self, task: core.UnrolledTask):
+    def _add_aiida_task_node(self, task: graph_items.Task):
         label = AiidaWorkGraph.get_aiida_label_from_unrolled_task(task)
-        if task.command is None:
-            msg = f"The command is None of task {task}."
-            raise ValueError(msg)
-        workgraph_task = self._workgraph.tasks.new(
-            "ShellJob",
-            name=label,
-            command=task.command,
-        )
-        workgraph_task.set({"arguments": []})
-        workgraph_task.set({"nodes": {}})
-        self._aiida_task_nodes[label] = workgraph_task
+        if isinstance(task, ShellTask):
+            if task.command is None:
+                msg = f"The command of task {task} is None."
+                raise ValueError(msg)
+
+            # TODO: add proper resolving of the command to an installed/portable code (still to be decided)
+            command = task.command
+
+            # Environment source files are sourced via the job's prepend_text
+            env_source_files = task.env_source_files
+            env_source_files = [env_source_files] if isinstance(env_source_files, str) else (env_source_files or [])
+            prepend_text = "\n".join([f"source {env_source_file}" for env_source_file in env_source_files])
+
+            # Note: we don't pass the `nodes` dictionary here, as we would then need to have the sockets available
+            # when we create the task. Instead, they are updated via the WorkGraph internals when linking
+            # inputs/outputs to the tasks.
+            argument_list = task.cli_argument.split()
+            workgraph_task = self._workgraph.tasks.new(
+                "ShellJob",
+                name=label,
+                command=command,
+                arguments=argument_list,
+                metadata={"options": {"prepend_text": prepend_text}},
+            )
+
+            self._aiida_task_nodes[label] = workgraph_task
+
+        elif isinstance(task, IconTask):
+            exc = "IconTask not implemented yet."
+            raise NotImplementedError(exc)
+        else:
+            exc = f"Task: {task.name} not implemented yet."
+                raise NotImplementedError(exc)

-    def _link_wait_on_to_task(self, task: core.UnrolledTask):
-        # TODO
+    def _link_wait_on_to_task(self, task: graph_items.Task):
+        # TODO: to be done
         msg = ""
         raise NotImplementedError(msg)
         label = AiidaWorkGraph.get_aiida_label_from_unrolled_task(task)
@@ -202,7 +229,7 @@ def _add_aiida_links_from_cycle(self, cycle: core.UnrolledCycle):
             for output in task.outputs:
                 self._link_output_to_task(task, output)

-    def _link_input_to_task(self, task: core.Task, input_: core.UnrolledData):
+    def _link_input_to_task(self, task: graph_items.Task, input_: graph_items.Data):
         """
         task: the task corresponding to the input
         input: ...
@@ -229,12 +256,14 @@ def _link_input_to_task(self, task: core.Task, input_: core.UnrolledData):
         if (workgraph_task_arguments := workgraph_task.inputs.get("arguments")) is None:
             msg = f"Workgraph task {workgraph_task.name!r} did not initialize arguments nodes in the workgraph before linking. This is a bug in the code, please contact devevlopers."
             raise ValueError(msg)
-        # TODO think about that the yaml file should have aiida valid labels
-        if (arg_option := task.input_arg_options.get(input_.name, None)) is not None:
-            workgraph_task_arguments.value.append(f"{arg_option}")
-            workgraph_task_arguments.value.append(f"{{{input_label}}}")
+        # TODO: think about whether the yaml file should use aiida-valid labels
+
+        # Avoid appending the same argument twice
+        argument_placeholder = f"{{{input_label}}}"
+        if argument_placeholder not in workgraph_task_arguments.value:
+            workgraph_task_arguments.value.append(argument_placeholder)

-    def _link_output_to_task(self, task: core.Task, output: core.UnrolledData):
+    def _link_output_to_task(self, task: graph_items.Task, output: graph_items.Data):
         """
         task: the task corresponding to the output
         output: ...
diff --git a/tests/conftest.py b/tests/conftest.py
index 66b9686..4a579bb 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1 +1 @@
-pytest_plugins = ['aiida.tools.pytest_fixtures']
+pytest_plugins = ["aiida.tools.pytest_fixtures"]
diff --git a/tests/files/configs/test_config_small_no_icon.yml b/tests/files/configs/test_config_small_no_icon.yml
new file mode 100644
index 0000000..7432c8b
--- /dev/null
+++ b/tests/files/configs/test_config_small_no_icon.yml
@@ -0,0 +1,52 @@
+---
+start_date: &root_start_date '2026-01-01T00:00'
+end_date: &root_end_date '2026-02-01T00:00'
+cycles:
+  - bimonthly_tasks:
+      start_date: *root_start_date
+      end_date: *root_end_date
+      period: P2M
+      tasks:
+        - shell_task:
+            inputs:
+              - initial_conditions:
+                  when:
+                    at: *root_start_date
+              - data1:
+                  when:
+                    at: *root_start_date
+            outputs: [restart]
+tasks:
+  - shell_task:
+      plugin: shell
+      # Currently the full path, so that running the script actually works
+      command: /home/geiger_j/aiida_projects/swiss-twins/git-repos/Sirocco/tests/files/scripts/shell_task.sh
+      cli_argument: >
+        --restart_kwarg restart_value
+        --verbosity 2
+        --init {initial_conditions}
+        --test-flag
+        {data1}
+        data2
+      # Unfortunately it is not possible to add comments inside a multi-line string, which was my original idea to
+      # make it more readable: https://stackoverflow.com/questions/20890445/yaml-comments-in-multi-line-strings
+      # --restart_kwarg restart_value     # ? Keyword with normal str/int/float value
+      # --verbosity 2                     # ? Keyword with normal int value
+      # --init {initial_conditions}       # ? Keyword with reference to AiiDA node (available data) or AiiDA-WG socket
+      # --test-flag                       # ? Normal flag
+      # {data1}                           # ? Positional argument with reference
+      # data2                             # ? 
Positional argument without AiiDA entity reference (even required)? + env_source_files: + - tests/files/data/dummy_source_file.sh +data: + available: + - initial_conditions: + type: file + src: tests/files/data/initial_conditions2 + - data1: + type: file + src: tests/files/data/data-xyz + generated: + - restart: + type: file + src: restart diff --git a/tests/files/data/data-xyz b/tests/files/data/data-xyz new file mode 100644 index 0000000..e69de29 diff --git a/tests/files/data/dummy_source_file.sh b/tests/files/data/dummy_source_file.sh new file mode 100644 index 0000000..e69de29 diff --git a/tests/files/data/initial_conditions2 b/tests/files/data/initial_conditions2 new file mode 100644 index 0000000..e69de29 diff --git a/tests/files/scripts/cleanup.py b/tests/files/scripts/cleanup.py index ff9c30e..de7aeba 100755 --- a/tests/files/scripts/cleanup.py +++ b/tests/files/scripts/cleanup.py @@ -1,9 +1,10 @@ #!/usr/bin/env python + def main(): # Main script execution continues here print("Cleaning") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/tests/files/scripts/icon.py b/tests/files/scripts/icon.py index ac9e6c5..32f71ed 100755 --- a/tests/files/scripts/icon.py +++ b/tests/files/scripts/icon.py @@ -15,29 +15,29 @@ LOG_FILE = Path("icon.log") + def log(text: str): print(text) with LOG_FILE.open("a") as f: f.write(text) -def main(): - parser = argparse.ArgumentParser(description='A script mocking parts of icon in a form of a shell script.') - parser.add_argument('--init', nargs='?', type=str, help='The icon init file.') - parser.add_argument('namelist', nargs='?', default=None) - parser.add_argument('--restart', nargs='?', type=str, help='The icon restart file.') - parser.add_argument('--forcing', nargs='?', type=str, help='The icon forcing file.') +def main(): + parser = argparse.ArgumentParser(description="A script mocking parts of icon in a form of a shell script.") + parser.add_argument("--init", nargs="?", type=str, help="The icon init file.") + parser.add_argument("namelist", nargs="?", default=None) + parser.add_argument("--restart", nargs="?", type=str, help="The icon restart file.") + parser.add_argument("--forcing", nargs="?", type=str, help="The icon forcing file.") args = parser.parse_args() - - output = Path('icon_output') + output = Path("icon_output") output.write_text("") - if args.restart and args.init: - msg = "Cannot use '--init' and '--restart' option at the same time." - raise ValueError(msg) - elif args.restart: + if args.restart: + if args.init: + msg = "Cannot use '--init' and '--restart' option at the same time." + raise ValueError(msg) if not Path(args.restart).exists(): msg = f"The icon restart file {args.restart!r} was not found." raise FileNotFoundError(msg) @@ -62,10 +62,9 @@ def main(): # Main script execution continues here log("Script finished running calculations") - restart = Path('restart') + restart = Path("restart") restart.write_text("") -if __name__ == '__main__': - main() - +if __name__ == "__main__": + main() diff --git a/tests/files/scripts/shell_task.sh b/tests/files/scripts/shell_task.sh new file mode 100755 index 0000000..9740fd7 --- /dev/null +++ b/tests/files/scripts/shell_task.sh @@ -0,0 +1,55 @@ +#!/bin/bash + +# Dummy script to validate inputs passed via CLI arguments + +# Function to extract and print positional arguments, keywords, and flags +process_args() { + echo "Processing inputs..." 
+ + # Arrays to store different types of arguments + positional=() + keywords=() + flags=() + + while [[ $# -gt 0 ]]; do + case "$1" in + --*) # Keyword arguments or flags + if [[ "$2" && ! "$2" =~ ^-- ]]; then + keywords+=("$1=$2") + shift 2 + else + flags+=("$1") + shift + fi + ;; + *) # Positional arguments + positional+=("$1") + shift + ;; + esac + done + + # Print positional arguments + echo "Positional arguments:" + for arg in "${positional[@]}"; do + echo " $arg" + done + + # Print keyword arguments + echo "Keyword arguments:" + for keyword in "${keywords[@]}"; do + echo " $keyword" + done + + # Print flags + echo "Flags:" + for flag in "${flags[@]}"; do + echo " $flag" + done +} + +# Process all passed arguments +process_args "$@" + +# Test complete +echo "Test complete. All inputs received and categorized." | tee restart diff --git a/tests/test_wc_workflow.py b/tests/test_wc_workflow.py index 1448490..3b66394 100644 --- a/tests/test_wc_workflow.py +++ b/tests/test_wc_workflow.py @@ -8,15 +8,19 @@ from sirocco.workgraph import AiidaWorkGraph -@pytest.mark.parametrize("config_path", [ - "tests/files/configs/test_config_small.yml", - "tests/files/configs/test_config_parameters.yml", -]) +@pytest.mark.parametrize( + "config_path", + [ + "tests/files/configs/test_config_small.yml", + "tests/files/configs/test_config_parameters.yml", + ], +) def test_run_workgraph(config_path): core_workflow = Workflow.from_yaml(config_path) aiida_workflow = AiidaWorkGraph(core_workflow) out = aiida_workflow.run() - assert out.get('execution_count', None).value == 0 # TODO should be 1 but we need to update workgraph for this + assert out.get("execution_count", None).value == 0 # TODO: should be 1 but we need to update workgraph for this + # configs that are tested only tested parsing config_test_files = [ @@ -25,6 +29,7 @@ def test_run_workgraph(config_path): "tests/files/configs/test_config_parameters.yml", ] + @pytest.fixture(params=config_test_files) def config_paths(request): config_path = Path(request.param) @@ -34,17 +39,21 @@ def config_paths(request): "svg": (config_path.parent.parent / "svgs" / config_path.name).with_suffix(".svg"), } + @pytest.fixture def pprinter(): return PrettyPrinter() + def test_parse_config_file(config_paths, pprinter): reference_str = config_paths["txt"].read_text() test_str = pprinter.format(Workflow.from_yaml(config_paths["yml"])) if test_str != reference_str: new_path = Path(config_paths["txt"]).with_suffix(".new.txt") new_path.write_text(test_str) - assert reference_str == test_str, f"Workflow graph doesn't match serialized data. New graph string dumped to {new_path}." + assert ( + reference_str == test_str + ), f"Workflow graph doesn't match serialized data. New graph string dumped to {new_path}." @pytest.mark.skip(reason="don't run it each time, uncomment to regenerate serilaized data")
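
Note (editor's addition, not part of the patch): the sketch below illustrates the argument-handling convention this PR introduces, namely a single `cli_argument` string that is split on whitespace into the ShellJob argument list, with `{placeholder}` entries appended at link time only if they are not already present. It is a minimal, self-contained Python sketch: the values mirror the test config above, and `link_input` is a hypothetical stand-in for the deduplication check in `_link_input_to_task`, not an aiida-workgraph API.

# Minimal sketch of how a single `cli_argument` string becomes a ShellJob argument list.
cli_argument = (
    "--restart_kwarg restart_value "
    "--verbosity 2 "
    "--init {initial_conditions} "
    "--test-flag "
    "{data1} "
    "data2"
)

# Step 1: split the whole string on whitespace, as done in AiidaWorkGraph._add_aiida_task_node.
argument_list = cli_argument.split()
assert argument_list[:3] == ["--restart_kwarg", "restart_value", "--verbosity"]


def link_input(arguments: list[str], input_label: str) -> None:
    """Append the placeholder for an input only if it is not already present
    (mirrors the dedup check added to _link_input_to_task)."""
    placeholder = f"{{{input_label}}}"
    if placeholder not in arguments:
        arguments.append(placeholder)


# Step 2: linking inputs appends their placeholders without creating duplicates.
link_input(argument_list, "initial_conditions")  # already present in the string -> unchanged
link_input(argument_list, "extra_input")         # hypothetical label -> appended once
print(argument_list)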