Skip to content

Commit

Permalink
Pass single argument string to ShellTasks (#72)
Browse files Browse the repository at this point in the history
* Try to make current example YAML files run through.

* Update class names after rebase/merge

* Add `no_icon` config file

* First working version apart from flags.

* Try fixing argument format.

* Current state before branch-off.

* Pass one multi-line string as `cli_argument`

* Cleanup.

* Pass arguments as list.

* Remove `workgraph-dev.py` dev file.

* Fix issues from `hatch fmt`
  • Loading branch information
GeigerJ2 authored Dec 21, 2024
1 parent 2daa2ad commit 0cd2ae5
Show file tree
Hide file tree
Showing 12 changed files with 205 additions and 96 deletions.
2 changes: 1 addition & 1 deletion src/sirocco/core/graph_items.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def from_config(cls, config: ConfigBaseData, coordinates: dict) -> Self:

@property
def path(self) -> Path:
# TODO yaml level?
# TODO: yaml level?
return Path(expandvars(self.src))


Expand Down
50 changes: 7 additions & 43 deletions src/sirocco/parsing/_yaml_data_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
import time
from dataclasses import dataclass
from datetime import datetime
from os.path import expandvars
from pathlib import Path
from typing import Annotated, Any, ClassVar, Literal
from os.path import expandvars

from isoduration import parse_duration
from isoduration.types import Duration # pydantic needs type # noqa: TCH002
Expand Down Expand Up @@ -83,43 +83,6 @@ def convert_datetime(cls, value) -> datetime:
return datetime.fromisoformat(value)


class _CliArgsBaseModel(BaseModel):
"""Base class for cli_arguments specifications"""

# TODO: Even allow for `str`, or always require list?
positional: str | list[str] | None = None
# Field needed for child class doing pydantic parsing
keyword: dict[str, str] | None = Field(default_factory=dict)
flags: str | list[str] | None = None
source_file: str | list[str] | None = None

# TODO: Should we allow users to pass it without the hyphen(s), and prepend them automatically?
# TODO: While convenient, it could be a bad idea, if users put in wrong things. Better to be explicit.
@field_validator("keyword", mode="before")
@classmethod
def validate_keyword_args(cls, value):
"""Ensure keyword arguments start with '-' or '--'."""
if value is not None:
invalid_keys = [key for key in value if not key.startswith(("-", "--"))]
if invalid_keys:
invalid_kwarg_exc = f"Invalid keyword arguments: {', '.join(invalid_keys)}"
raise ValueError(invalid_kwarg_exc)
return value

@field_validator("flags", mode="before")
@classmethod
def validate_flag_args(cls, value):
"""Ensure positional arguments start with '-' or '--'."""
if value is not None:
if isinstance(value, str):
value = [value]
invalid_flags = [arg for arg in value if not arg.startswith(("-", "--"))]
if invalid_flags:
invalid_flags_exc = f"Invalid positional arguments: {', '.join(invalid_flags)}"
raise ValueError(invalid_flags_exc)
return value


class TargetNodesBaseModel(_NamedBaseModel):
"""class for targeting other task or data nodes in the graph
Expand Down Expand Up @@ -311,20 +274,20 @@ class ConfigRootTask(ConfigBaseTask):
class ConfigShellTaskSpecs:
plugin: ClassVar[Literal["shell"]] = "shell"
command: str = ""
cli_arguments: _CliArgsBaseModel | None = None
cli_argument: str = ""
env_source_files: str | list[str] = None
src: str | None = None


class ConfigShellTask(ConfigBaseTask, ConfigShellTaskSpecs):
# PR(COMMENT) tmp hack to make script work, need to find better solution than PWD for tests
command: str = ""
command_option: str = ""

# PR(COMMENT) tmp hack to make script work, need to find better solution than PWD for tests
@field_validator("command", "src")
@classmethod
def expand_var(cls, value: str) -> str:
"""Expand environemnt variables"""
# TODO this might be not intended if we want to use environment variables on remote HPC
# TODO: this might be not intended if we want to use environment variables on remote HPC
return expandvars(value)


Expand All @@ -333,6 +296,7 @@ class ConfigIconTaskSpecs:
plugin: ClassVar[Literal["icon"]] = "icon"
namelists: dict[str, str] | None = None


class ConfigIconTask(ConfigBaseTask, ConfigIconTaskSpecs):
pass

Expand Down Expand Up @@ -367,7 +331,7 @@ def is_file_or_dir(cls, value: str) -> str:
@classmethod
def expand_var(cls, value: str | None) -> str | None:
"""Expand environemnt variables"""
# TODO this might be not intended if we want to use environment variables on remote HPC
# TODO: this might be not intended if we want to use environment variables on remote HPC
return None if value is None else expandvars(value)


Expand Down
85 changes: 57 additions & 28 deletions src/sirocco/workgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,14 @@
import aiida_workgraph.engine.utils # type: ignore[import-untyped]
from aiida_workgraph import WorkGraph # type: ignore[import-untyped]

from sirocco.core._tasks.icon_task import IconTask
from sirocco.core._tasks.shell_task import ShellTask

if TYPE_CHECKING:
from aiida_workgraph.socket import TaskSocket # type: ignore[import-untyped]
from wcflow import core

from sirocco import core
from sirocco.core import graph_items


# This is hack to aiida-workgraph, merging this into aiida-workgraph properly would require
Expand Down Expand Up @@ -65,7 +70,7 @@ def __init__(self, core_workflow: core.Workflow):

self._validate_workflow()

self._workgraph = WorkGraph() # core_workflow.name TODO use filename
self._workgraph = WorkGraph() # core_workflow.name TODO: use filename

# stores the input data available on initialization
self._aiida_data_nodes: dict[str, aiida_workgraph.orm.Data] = {}
Expand Down Expand Up @@ -123,24 +128,24 @@ def parse_to_aiida_label(label: str) -> str:
return label

@staticmethod
def get_aiida_label_from_unrolled_data(obj: core.BaseNode) -> str:
def get_aiida_label_from_unrolled_data(obj: graph_items.GraphItem) -> str:
""" """
return AiidaWorkGraph.parse_to_aiida_label(
f"{obj.name}" + "__".join(f"_{key}_{value}" for key, value in obj.coordinates.items())
)

@staticmethod
def get_aiida_label_from_unrolled_task(obj: core.BaseNode) -> str:
def get_aiida_label_from_unrolled_task(obj: graph_items.GraphItem) -> str:
""" """
# TODO task is not anymore using cycle name because information is not there
# TODO: task is not anymore using cycle name because information is not there
# so do we check somewhere that a task is not used in multiple cycles?
# Otherwise the label is not unique
# --> task name + date + parameters
return AiidaWorkGraph.parse_to_aiida_label(
"__".join([f"{obj.name}"] + [f"_{key}_{value}" for key, value in obj.coordinates.items()])
)

def _add_aiida_input_data_node(self, input_: core.UnrolledData):
def _add_aiida_input_data_node(self, input_: graph_items.Data):
"""
Create an :class:`aiida.orm.Data` instance from this wc data instance.
Expand All @@ -160,27 +165,49 @@ def _add_aiida_task_nodes(self):
for task in cycle.tasks:
self._add_aiida_task_node(task)
# after creation we can link the wait_on tasks
# TODO check where this is now
#for cycle in self._core_workflow.cycles:
# TODO: check where this is now
# for cycle in self._core_workflow.cycles:
# for task in cycle.tasks:
# self._link_wait_on_to_task(task)

def _add_aiida_task_node(self, task: core.UnrolledTask):
def _add_aiida_task_node(self, task: graph_items.Task):
label = AiidaWorkGraph.get_aiida_label_from_unrolled_task(task)
if task.command is None:
msg = f"The command is None of task {task}."
raise ValueError(msg)
workgraph_task = self._workgraph.tasks.new(
"ShellJob",
name=label,
command=task.command,
)
workgraph_task.set({"arguments": []})
workgraph_task.set({"nodes": {}})
self._aiida_task_nodes[label] = workgraph_task
if isinstance(task, ShellTask):
if task.command is None:
msg = f"The command is None of task {task}."
raise ValueError(msg)

# TODO: Add proper resolving of code/command to installed/portable code, etc., to think about
command = task.command

# ? Source file
env_source_files = task.env_source_files
env_source_files = [env_source_files] if isinstance(env_source_files, str) else env_source_files
prepend_text = "\n".join([f"source {env_source_file}" for env_source_file in env_source_files])

# Note: We don't pass the `nodes` dictionary here, as then we would need to have the sockets available when
# we create the task. Instead, they are being updated via the WG internals when linking inputs/outputs to
# tasks
argument_list = task.cli_argument.split()
workgraph_task = self._workgraph.tasks.new(
"ShellJob",
name=label,
command=command,
arguments=argument_list,
metadata={"options": {"prepend_text": prepend_text}},
)

self._aiida_task_nodes[label] = workgraph_task

elif isinstance(task, IconTask):
exc = "IconTask not implemented yet."
raise NotImplementedError(exc)
else:
exc = f"Task: {task.name} not implemented yet."
raise NotImplementedError(exc)

def _link_wait_on_to_task(self, task: core.UnrolledTask):
# TODO
def _link_wait_on_to_task(self, task: graph_items.Task):
# TODO: to be done
msg = ""
raise NotImplementedError(msg)
label = AiidaWorkGraph.get_aiida_label_from_unrolled_task(task)
Expand All @@ -202,7 +229,7 @@ def _add_aiida_links_from_cycle(self, cycle: core.UnrolledCycle):
for output in task.outputs:
self._link_output_to_task(task, output)

def _link_input_to_task(self, task: core.Task, input_: core.UnrolledData):
def _link_input_to_task(self, task: graph_items.Task, input_: graph_items.Data):
"""
task: the task corresponding to the input
input: ...
Expand All @@ -229,12 +256,14 @@ def _link_input_to_task(self, task: core.Task, input_: core.UnrolledData):
if (workgraph_task_arguments := workgraph_task.inputs.get("arguments")) is None:
msg = f"Workgraph task {workgraph_task.name!r} did not initialize arguments nodes in the workgraph before linking. This is a bug in the code, please contact devevlopers."
raise ValueError(msg)
# TODO think about that the yaml file should have aiida valid labels
if (arg_option := task.input_arg_options.get(input_.name, None)) is not None:
workgraph_task_arguments.value.append(f"{arg_option}")
workgraph_task_arguments.value.append(f"{{{input_label}}}")
# TODO: think about that the yaml file should have aiida valid labels

# Avoid appending the same argument twice
argument_placeholder = f"{{{input_label}}}"
if argument_placeholder not in workgraph_task_arguments.value:
workgraph_task_arguments.value.append()

def _link_output_to_task(self, task: core.Task, output: core.UnrolledData):
def _link_output_to_task(self, task: graph_items.Task, output: graph_items.Data):
"""
task: the task corresponding to the output
output: ...
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
pytest_plugins = ['aiida.tools.pytest_fixtures']
pytest_plugins = ["aiida.tools.pytest_fixtures"]
52 changes: 52 additions & 0 deletions tests/files/configs/test_config_small_no_icon.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
---
start_date: &root_start_date '2026-01-01T00:00'
end_date: &root_end_date '2026-02-01T00:00'
cycles:
- bimonthly_tasks:
start_date: *root_start_date
end_date: *root_end_date
period: P2M
tasks:
- shell_task:
inputs:
- initial_conditions:
when:
at: *root_start_date
- data1:
when:
at: *root_start_date
outputs: [restart]
tasks:
- shell_task:
plugin: shell
# Currently full task, so running the script actually works
command: /home/geiger_j/aiida_projects/swiss-twins/git-repos/Sirocco/tests/files/scripts/shell_task.sh
cli_argument: >
--restart_kwarg restart_value
--verbosity 2
--init {initial_conditions}
--test-flag
{data1}
data2
# Unfortunately not possible to add comments to multiline string, which was my original idea to make it more
# readable: https://stackoverflow.com/questions/20890445/yaml-comments-in-multi-line-strings
# --restart_kwarg restart_value # ? Keyword with normal str/int/float value
# --verbosity 2 # ? Keyword with normal int value
# --init {initial_conditions} # ? Keyword with reference to AiiDA node (available data) or AiiDA-WG socket
# --test-flag # ? Normal flag
# {data1} # ? Positional argument with reference
# data2 # ? Positional argument without AiiDA entity reference (even required)?
env_source_files:
- tests/files/data/dummy_source_file.sh
data:
available:
- initial_conditions:
type: file
src: tests/files/data/initial_conditions2
- data1:
type: file
src: tests/files/data/data-xyz
generated:
- restart:
type: file
src: restart
Empty file added tests/files/data/data-xyz
Empty file.
Empty file.
Empty file.
3 changes: 2 additions & 1 deletion tests/files/scripts/cleanup.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
#!/usr/bin/env python


def main():
# Main script execution continues here
print("Cleaning")


if __name__ == '__main__':
if __name__ == "__main__":
main()
31 changes: 15 additions & 16 deletions tests/files/scripts/icon.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,29 +15,29 @@

LOG_FILE = Path("icon.log")


def log(text: str):
print(text)
with LOG_FILE.open("a") as f:
f.write(text)

def main():
parser = argparse.ArgumentParser(description='A script mocking parts of icon in a form of a shell script.')
parser.add_argument('--init', nargs='?', type=str, help='The icon init file.')
parser.add_argument('namelist', nargs='?', default=None)
parser.add_argument('--restart', nargs='?', type=str, help='The icon restart file.')
parser.add_argument('--forcing', nargs='?', type=str, help='The icon forcing file.')

def main():
parser = argparse.ArgumentParser(description="A script mocking parts of icon in a form of a shell script.")
parser.add_argument("--init", nargs="?", type=str, help="The icon init file.")
parser.add_argument("namelist", nargs="?", default=None)
parser.add_argument("--restart", nargs="?", type=str, help="The icon restart file.")
parser.add_argument("--forcing", nargs="?", type=str, help="The icon forcing file.")

args = parser.parse_args()


output = Path('icon_output')
output = Path("icon_output")
output.write_text("")

if args.restart and args.init:
msg = "Cannot use '--init' and '--restart' option at the same time."
raise ValueError(msg)
elif args.restart:
if args.restart:
if args.init:
msg = "Cannot use '--init' and '--restart' option at the same time."
raise ValueError(msg)
if not Path(args.restart).exists():
msg = f"The icon restart file {args.restart!r} was not found."
raise FileNotFoundError(msg)
Expand All @@ -62,10 +62,9 @@ def main():
# Main script execution continues here
log("Script finished running calculations")

restart = Path('restart')
restart = Path("restart")
restart.write_text("")

if __name__ == '__main__':
main()


if __name__ == "__main__":
main()
Loading

0 comments on commit 0cd2ae5

Please sign in to comment.