diff --git a/src/sirocco/parsing/_yaml_data_models.py b/src/sirocco/parsing/_yaml_data_models.py index ce5f32b..d866968 100644 --- a/src/sirocco/parsing/_yaml_data_models.py +++ b/src/sirocco/parsing/_yaml_data_models.py @@ -86,38 +86,37 @@ def convert_datetime(cls, value) -> datetime: class _CliArgsBaseModel(BaseModel): """Base class for cli_arguments specifications""" - # TODO: Even allow for `str`, or always require list? - positional: str | list[str] | None = None - # Field needed for child class doing pydantic parsing - keyword: dict[str, str | int] | None = Field(default_factory=dict) - flags: str | list[str] | None = None - source_file: str | list[str] | None = None - - # TODO: Should we allow users to pass it without the hyphen(s), and prepend them automatically? - # TODO: While convenient, it could be a bad idea, if users put in wrong things. Better to be explicit. - @field_validator("keyword", mode="before") - @classmethod - def validate_keyword_args(cls, value): - """Ensure keyword arguments start with '-' or '--'.""" - if value is not None: - invalid_keys = [key for key in value if not key.startswith(("-", "--"))] - if invalid_keys: - invalid_kwarg_exc = f"Invalid keyword arguments: {', '.join(invalid_keys)}" - raise ValueError(invalid_kwarg_exc) - return value - - @field_validator("flags", mode="before") - @classmethod - def validate_flag_args(cls, value): - """Ensure positional arguments start with '-' or '--'.""" - if value is not None: - if isinstance(value, str): - value = [value] - invalid_flags = [arg for arg in value if not arg.startswith(("-", "--"))] - if invalid_flags: - invalid_flags_exc = f"Invalid positional arguments: {', '.join(invalid_flags)}" - raise ValueError(invalid_flags_exc) - return value + string: str = None + # # Field needed for child class doing pydantic parsing + # keyword: dict[str, str | int] | None = Field(default_factory=dict) + # flags: str | list[str] | None = None + source_files: str | list[str] | None = None + + # # TODO: Should we allow users to pass it without the hyphen(s), and prepend them automatically? + # # TODO: While convenient, it could be a bad idea, if users put in wrong things. Better to be explicit. + # @field_validator("keyword", mode="before") + # @classmethod + # def validate_keyword_args(cls, value): + # """Ensure keyword arguments start with '-' or '--'.""" + # if value is not None: + # invalid_keys = [key for key in value if not key.startswith(("-", "--"))] + # if invalid_keys: + # invalid_kwarg_exc = f"Invalid keyword arguments: {', '.join(invalid_keys)}" + # raise ValueError(invalid_kwarg_exc) + # return value + + # @field_validator("flags", mode="before") + # @classmethod + # def validate_flag_args(cls, value): + # """Ensure positional arguments start with '-' or '--'.""" + # if value is not None: + # if isinstance(value, str): + # value = [value] + # invalid_flags = [arg for arg in value if not arg.startswith(("-", "--"))] + # if invalid_flags: + # invalid_flags_exc = f"Invalid positional arguments: {', '.join(invalid_flags)}" + # raise ValueError(invalid_flags_exc) + # return value class TargetNodesBaseModel(_NamedBaseModel): @@ -311,7 +310,8 @@ class ConfigRootTask(ConfigBaseTask): class ConfigShellTaskSpecs: plugin: ClassVar[Literal["shell"]] = "shell" command: str = "" - cli_arguments: _CliArgsBaseModel | None = None + cli_argument: str = "" + env_source_files: str | list[str] = None src: str | None = None diff --git a/src/sirocco/workgraph.py b/src/sirocco/workgraph.py index 171adca..53d6f99 100644 --- a/src/sirocco/workgraph.py +++ b/src/sirocco/workgraph.py @@ -48,8 +48,6 @@ def _prepare_for_shell_task(task: dict, kwargs: dict) -> dict: task_outputs = {task["outputs"][i]["name"] for i in range(len(task["outputs"]))} task_outputs = task_outputs.union(set(kwargs.pop("outputs", []))) missing_outputs = task_outputs.difference(default_outputs) - # ? kwargs['arguments'] = ['{initial_conditions}', '{data1}'] - breakpoint() return { "code": code, "nodes": nodes, @@ -179,104 +177,19 @@ def _add_aiida_task_node(self, task: graph_items.Task): raise ValueError(msg) command = task.command - argument_str: str = '' - - cli_arguments = task.cli_arguments - - # ? Positional - # ? -> Seem to already be resolved in the WG creation, see `_link_input_to_task` - positional_args = cli_arguments.positional - positional_args = [positional_args] if isinstance(positional_args, str) else positional_args - - def resolve_positional_and_keyword(input_args): # -> dict(str, str|dict): - - argument_str = '' - node_dict = {} - - for input_arg in input_args: - print(f'INPUT_ARG: {input_arg}') - if isinstance(input_arg, dict): - node_identifier = list(input_arg.values())[0] - # append_str = f' {list(input_arg.keys())[0]} {{{node_identifier}}}' - # print(f"APPEND_STR: {append_str}") - elif isinstance(input_arg, str): - node_identifier = input_arg - # append_str = f' {{{node_identifier}}}' - else: - raise TypeError("Something went wrong.") - - aiida_data_node = self._aiida_data_nodes.get(node_identifier, None) - aiida_data_socket = self._aiida_socket_nodes.get(node_identifier, None) - # THis will be either an existing aiida_data_node, or aiida_data_socket, otherwise None - # as aiida_data_socket will be assigned to the new variable, irrespective if None or not - aiida_entity = aiida_data_node or aiida_data_socket - if aiida_entity is None: - if isinstance(input_arg, dict): - append_str = f' {list(input_arg.keys())[0]} {node_identifier}' - elif isinstance(input_arg, str): - append_str = f' {node_identifier}' - argument_str += append_str - # ! Don't resolve the nodes here, as this is done somewhere else, and otherwise excepts - # argument_str += f' {{{positional_arg}}}' - # node_dict[positional_arg] = aiida_entity - - # TODO: Final cleanup of non-existing ones - # TODO: Check for absolute path, and if absolute path provided, directly use as input argument, as we - # cannot create Singlefiledata from a file that doesn't live on the localhost -> Maybe create RemoteData instead - return {'arguments': argument_str, 'nodes': node_dict} - - positional_resolved = resolve_positional_and_keyword(positional_args) - print(f"POSITIONAL_RESOLVED: {positional_resolved}") - argument_str = f"{argument_str} {positional_resolved['arguments']}" - # nodes = positional_resolved['nodes'] - - # ? Keywords - keyword_args = cli_arguments.keyword - print(f"KEYWORD_ARGS: {keyword_args}") - # ? Resolve to list of dictionaries with one key-value pair each (just for current implementation) - keyword_resolved = resolve_positional_and_keyword([{key: value} for key, value in keyword_args.items()]) - print(f"KEYWORD_RESOLVED: {keyword_resolved}") - # argument_str = f"{argument_str} {keyword_resolved['arguments']}" - # nodes = positional_resolved['nodes'] - - # ? Flags - flags = cli_arguments.flags # append to string - flags = [flags] if isinstance(flags, str) else flags - - argument_str = ' '.join([argument_str] + flags) - # command = ' '.join([command] + flags) # ? Source file - source_files = cli_arguments.source_file # into prepend-text - source_files = [source_files] if isinstance(source_files, str) else source_files - prepend_text = '\n'.join([f"source {source_file}" for source_file in source_files]) + env_source_files = task.env_source_files + env_source_files = [env_source_files] if isinstance(env_source_files, str) else env_source_files + prepend_text = '\n'.join([f"source {env_source_file}" for env_source_file in env_source_files]) - # breakpoint() - # from aiida_shell.launch import prepare_computer, prepare_code - # breakpoint() - # TODO: Need access to the root task here, to get access to the host/computer - # TODO: Take care of proper code creation. Either a `PortableCode` or splitting the name, only using the - # TODO: actual script name, but resolving the path to the full path - # localhost = prepare_computer() - # command = task.command - # code = prepare_code(command=command, computer=localhost) - - # TODO: Add the additional outputs here - breakpoint() - # ? argument_str = ' {data2} --test' - # ? `data1` and `initial_conditions` are added _somewhere_ else, either in `TaskCollection.new()`, - # ? `build_shelljob_task`, `build_task_from_AiiDA`, or `AiiDAWorkGraph._link_input_to_task()`, which is - # ? where the input/output nodes/sockets are being attached workgraph_task = self._workgraph.tasks.new( "ShellJob", name=label, - # TODO: Currently use command to allow for flags. Ideally, should be also `Code` possible, and arguments - # passed through command=command, - # TODO: Flags are not being passed to the submission script? - arguments=argument_str, - # arguments=argument_str.split(' ')[2:], - # nodes=nodes, + arguments=task.cli_argument, + # ! Do we still need to add nodes here, as in `aiida-shell`, or WG does that automatically from the + # argument if it finds them? metadata={ 'options': { 'prepend_text': prepend_text @@ -284,20 +197,13 @@ def resolve_positional_and_keyword(input_args): # -> dict(str, str|dict): } ) - # arguments='{file_a} {file_b}', - # nodes={ - # 'file_a': SinglefileData.from_string('string a'), - # 'file_b': SinglefileData.from_string('string b'), - # } - # arguments_str = '' - # workgraph_task.set({"arguments": []}) - workgraph_task.set({"nodes": {}}) + # workgraph_task.set({"nodes": {}}) self._aiida_task_nodes[label] = workgraph_task - # elif isinstance(task, IconTask): - # exc = f"Task: {task.name} not implemented yet." - # raise NotImplementedError(exc) + elif isinstance(task, IconTask): + exc = "IconTask not implemented yet." + raise NotImplementedError(exc) else: exc = f"Task: {task.name} not implemented yet." raise NotImplementedError(exc) @@ -358,7 +264,8 @@ def _link_input_to_task(self, task: graph_items.Task, input_: graph_items.Data): # workgraph_task_arguments.value.append(f"{arg_option}") workgraph_task_arguments.value.append(f"{{{input_label}}}") except Exception: - breakpoint() + pass + # breakpoint() def _link_output_to_task(self, task: graph_items.Task, output: graph_items.Data): """ diff --git a/tests/files/configs/test_config_small_no_icon.yml b/tests/files/configs/test_config_small_no_icon.yml index f28e02a..6503a39 100644 --- a/tests/files/configs/test_config_small_no_icon.yml +++ b/tests/files/configs/test_config_small_no_icon.yml @@ -21,22 +21,23 @@ tasks: plugin: shell # Currently full task, so running the script actually works command: /home/geiger_j/aiida_projects/swiss-twins/git-repos/Sirocco/tests/files/scripts/shell_task.sh - cli_arguments: - # f"--restart_kwarg restart_value --init {initial_conditions} -a data1 -p 1" - # TODO: How to differentiate between string and data node/socket - keyword: # To be thought about - --restart_kwarg: restart_value # This doesn't map to a file, so is just an option - # --restart_kwarg: {restart_value} # This doesn't map to a file, so is just an option - # --verbosity: 2 - --init: {initial_conditions} # This maps to a file, and should therefore be resolved to a SinglefileData - positional: - # - data1 # This maps to an actual existing file, specified in data section - - data2 # This is just an arbitrary positional argument - flags: - - --test-flag - # - --verbosity=2 # Didn't actually consider this - source_file: - - tests/files/data/dummy_source_file.sh + cli_argument: > + --restart_kwarg restart_value + --verbosity 2 + --init {initial_conditions} + --test-flag + {data1} + data2 + # Unfortunately not possible to add comments to multiline string, which was my original idea to make it more + # readable: https://stackoverflow.com/questions/20890445/yaml-comments-in-multi-line-strings + # --restart_kwarg restart_value # ? Keyword with normal str/int/float value + # --verbosity 2 # ? Keyword with normal int value + # --init {initial_conditions} # ? Keyword with reference to AiiDA node (available data) or AiiDA-WG socket + # --test-flag # ? Normal flag + # {data1} # ? Positional argument with reference + # data2 # ? Positional argument without AiiDA entity reference (even required)? + env_source_files: + - tests/files/data/dummy_source_file.sh data: available: - initial_conditions: