Skip to content

Commit

Permalink
Fix issues from hatch fmt
Browse files Browse the repository at this point in the history
  • Loading branch information
GeigerJ2 committed Dec 19, 2024
1 parent 4e99d2e commit c56c2ed
Show file tree
Hide file tree
Showing 7 changed files with 48 additions and 42 deletions.
2 changes: 1 addition & 1 deletion src/sirocco/core/graph_items.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def from_config(cls, config: ConfigBaseData, coordinates: dict) -> Self:

@property
def path(self) -> Path:
# TODO yaml level?
# TODO: yaml level?
return Path(expandvars(self.src))


Expand Down
8 changes: 4 additions & 4 deletions src/sirocco/parsing/_yaml_data_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
import time
from dataclasses import dataclass
from datetime import datetime
from os.path import expandvars
from pathlib import Path
from typing import Annotated, Any, ClassVar, Literal
from os.path import expandvars

from isoduration import parse_duration
from isoduration.types import Duration # pydantic needs type # noqa: TCH002
Expand Down Expand Up @@ -280,15 +280,14 @@ class ConfigShellTaskSpecs:


class ConfigShellTask(ConfigBaseTask, ConfigShellTaskSpecs):

command: str = ""

# PR(COMMENT) tmp hack to make script work, need to find better solution than PWD for tests
@field_validator("command", "src")
@classmethod
def expand_var(cls, value: str) -> str:
"""Expand environemnt variables"""
# TODO this might be not intended if we want to use environment variables on remote HPC
# TODO: this might be not intended if we want to use environment variables on remote HPC
return expandvars(value)


Expand All @@ -297,6 +296,7 @@ class ConfigIconTaskSpecs:
plugin: ClassVar[Literal["icon"]] = "icon"
namelists: dict[str, str] | None = None


class ConfigIconTask(ConfigBaseTask, ConfigIconTaskSpecs):
pass

Expand Down Expand Up @@ -331,7 +331,7 @@ def is_file_or_dir(cls, value: str) -> str:
@classmethod
def expand_var(cls, value: str | None) -> str | None:
"""Expand environemnt variables"""
# TODO this might be not intended if we want to use environment variables on remote HPC
# TODO: this might be not intended if we want to use environment variables on remote HPC
return None if value is None else expandvars(value)


Expand Down
23 changes: 10 additions & 13 deletions src/sirocco/workgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@
import aiida_workgraph.engine.utils # type: ignore[import-untyped]
from aiida_workgraph import WorkGraph # type: ignore[import-untyped]

from sirocco import core
from sirocco.core import graph_items
from sirocco.core._tasks.icon_task import IconTask
from sirocco.core._tasks.shell_task import ShellTask

if TYPE_CHECKING:
from aiida_workgraph.socket import TaskSocket # type: ignore[import-untyped]

from sirocco import core
from sirocco.core import graph_items


# This is hack to aiida-workgraph, merging this into aiida-workgraph properly would require
# some major refactor see issue https://github.com/aiidateam/aiida-workgraph/issues/168
Expand Down Expand Up @@ -69,7 +70,7 @@ def __init__(self, core_workflow: core.Workflow):

self._validate_workflow()

self._workgraph = WorkGraph() # core_workflow.name TODO use filename
self._workgraph = WorkGraph() # core_workflow.name TODO: use filename

# stores the input data available on initialization
self._aiida_data_nodes: dict[str, aiida_workgraph.orm.Data] = {}
Expand Down Expand Up @@ -136,7 +137,7 @@ def get_aiida_label_from_unrolled_data(obj: graph_items.GraphItem) -> str:
@staticmethod
def get_aiida_label_from_unrolled_task(obj: graph_items.GraphItem) -> str:
""" """
# TODO task is not anymore using cycle name because information is not there
# TODO: task is not anymore using cycle name because information is not there
# so do we check somewhere that a task is not used in multiple cycles?
# Otherwise the label is not unique
# --> task name + date + parameters
Expand Down Expand Up @@ -164,7 +165,7 @@ def _add_aiida_task_nodes(self):
for task in cycle.tasks:
self._add_aiida_task_node(task)
# after creation we can link the wait_on tasks
# TODO check where this is now
# TODO: check where this is now
# for cycle in self._core_workflow.cycles:
# for task in cycle.tasks:
# self._link_wait_on_to_task(task)
Expand All @@ -182,7 +183,7 @@ def _add_aiida_task_node(self, task: graph_items.Task):
# ? Source file
env_source_files = task.env_source_files
env_source_files = [env_source_files] if isinstance(env_source_files, str) else env_source_files
prepend_text = '\n'.join([f"source {env_source_file}" for env_source_file in env_source_files])
prepend_text = "\n".join([f"source {env_source_file}" for env_source_file in env_source_files])

# Note: We don't pass the `nodes` dictionary here, as then we would need to have the sockets available when
# we create the task. Instead, they are being updated via the WG internals when linking inputs/outputs to
Expand All @@ -193,11 +194,7 @@ def _add_aiida_task_node(self, task: graph_items.Task):
name=label,
command=command,
arguments=argument_list,
metadata={
'options': {
'prepend_text': prepend_text
}
}
metadata={"options": {"prepend_text": prepend_text}},
)

self._aiida_task_nodes[label] = workgraph_task
Expand All @@ -210,7 +207,7 @@ def _add_aiida_task_node(self, task: graph_items.Task):
raise NotImplementedError(exc)

def _link_wait_on_to_task(self, task: graph_items.Task):
# TODO
# TODO: to be done
msg = ""
raise NotImplementedError(msg)
label = AiidaWorkGraph.get_aiida_label_from_unrolled_task(task)
Expand Down Expand Up @@ -259,7 +256,7 @@ def _link_input_to_task(self, task: graph_items.Task, input_: graph_items.Data):
if (workgraph_task_arguments := workgraph_task.inputs.get("arguments")) is None:
msg = f"Workgraph task {workgraph_task.name!r} did not initialize arguments nodes in the workgraph before linking. This is a bug in the code, please contact devevlopers."
raise ValueError(msg)
# TODO think about that the yaml file should have aiida valid labels
# TODO: think about that the yaml file should have aiida valid labels

# Avoid appending the same argument twice
argument_placeholder = f"{{{input_label}}}"
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
pytest_plugins = ['aiida.tools.pytest_fixtures']
pytest_plugins = ["aiida.tools.pytest_fixtures"]
3 changes: 2 additions & 1 deletion tests/files/scripts/cleanup.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
#!/usr/bin/env python


def main():
# Main script execution continues here
print("Cleaning")


if __name__ == '__main__':
if __name__ == "__main__":
main()
31 changes: 15 additions & 16 deletions tests/files/scripts/icon.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,29 +15,29 @@

LOG_FILE = Path("icon.log")


def log(text: str):
print(text)
with LOG_FILE.open("a") as f:
f.write(text)

def main():
parser = argparse.ArgumentParser(description='A script mocking parts of icon in a form of a shell script.')
parser.add_argument('--init', nargs='?', type=str, help='The icon init file.')
parser.add_argument('namelist', nargs='?', default=None)
parser.add_argument('--restart', nargs='?', type=str, help='The icon restart file.')
parser.add_argument('--forcing', nargs='?', type=str, help='The icon forcing file.')

def main():
parser = argparse.ArgumentParser(description="A script mocking parts of icon in a form of a shell script.")
parser.add_argument("--init", nargs="?", type=str, help="The icon init file.")
parser.add_argument("namelist", nargs="?", default=None)
parser.add_argument("--restart", nargs="?", type=str, help="The icon restart file.")
parser.add_argument("--forcing", nargs="?", type=str, help="The icon forcing file.")

args = parser.parse_args()


output = Path('icon_output')
output = Path("icon_output")
output.write_text("")

if args.restart and args.init:
msg = "Cannot use '--init' and '--restart' option at the same time."
raise ValueError(msg)
elif args.restart:
if args.restart:
if args.init:
msg = "Cannot use '--init' and '--restart' option at the same time."
raise ValueError(msg)
if not Path(args.restart).exists():
msg = f"The icon restart file {args.restart!r} was not found."
raise FileNotFoundError(msg)
Expand All @@ -62,10 +62,9 @@ def main():
# Main script execution continues here
log("Script finished running calculations")

restart = Path('restart')
restart = Path("restart")
restart.write_text("")

if __name__ == '__main__':
main()


if __name__ == "__main__":
main()
21 changes: 15 additions & 6 deletions tests/test_wc_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,19 @@
from sirocco.workgraph import AiidaWorkGraph


@pytest.mark.parametrize("config_path", [
"tests/files/configs/test_config_small.yml",
"tests/files/configs/test_config_parameters.yml",
])
@pytest.mark.parametrize(
"config_path",
[
"tests/files/configs/test_config_small.yml",
"tests/files/configs/test_config_parameters.yml",
],
)
def test_run_workgraph(config_path):
core_workflow = Workflow.from_yaml(config_path)
aiida_workflow = AiidaWorkGraph(core_workflow)
out = aiida_workflow.run()
assert out.get('execution_count', None).value == 0 # TODO should be 1 but we need to update workgraph for this
assert out.get("execution_count", None).value == 0 # TODO: should be 1 but we need to update workgraph for this


# configs that are tested only tested parsing
config_test_files = [
Expand All @@ -25,6 +29,7 @@ def test_run_workgraph(config_path):
"tests/files/configs/test_config_parameters.yml",
]


@pytest.fixture(params=config_test_files)
def config_paths(request):
config_path = Path(request.param)
Expand All @@ -34,17 +39,21 @@ def config_paths(request):
"svg": (config_path.parent.parent / "svgs" / config_path.name).with_suffix(".svg"),
}


@pytest.fixture
def pprinter():
return PrettyPrinter()


def test_parse_config_file(config_paths, pprinter):
reference_str = config_paths["txt"].read_text()
test_str = pprinter.format(Workflow.from_yaml(config_paths["yml"]))
if test_str != reference_str:
new_path = Path(config_paths["txt"]).with_suffix(".new.txt")
new_path.write_text(test_str)
assert reference_str == test_str, f"Workflow graph doesn't match serialized data. New graph string dumped to {new_path}."
assert (
reference_str == test_str
), f"Workflow graph doesn't match serialized data. New graph string dumped to {new_path}."


@pytest.mark.skip(reason="don't run it each time, uncomment to regenerate serilaized data")
Expand Down

0 comments on commit c56c2ed

Please sign in to comment.