Skip to content

Commit

Permalink
Pass one multi-line string as cli_argument
Browse files Browse the repository at this point in the history
  • Loading branch information
GeigerJ2 committed Dec 19, 2024
1 parent ef662d4 commit f269adc
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 154 deletions.
66 changes: 33 additions & 33 deletions src/sirocco/parsing/_yaml_data_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,38 +86,37 @@ def convert_datetime(cls, value) -> datetime:
class _CliArgsBaseModel(BaseModel):
"""Base class for cli_arguments specifications"""

# TODO: Even allow for `str`, or always require list?
positional: str | list[str] | None = None
# Field needed for child class doing pydantic parsing
keyword: dict[str, str | int] | None = Field(default_factory=dict)
flags: str | list[str] | None = None
source_file: str | list[str] | None = None

# TODO: Should we allow users to pass it without the hyphen(s), and prepend them automatically?
# TODO: While convenient, it could be a bad idea, if users put in wrong things. Better to be explicit.
@field_validator("keyword", mode="before")
@classmethod
def validate_keyword_args(cls, value):
"""Ensure keyword arguments start with '-' or '--'."""
if value is not None:
invalid_keys = [key for key in value if not key.startswith(("-", "--"))]
if invalid_keys:
invalid_kwarg_exc = f"Invalid keyword arguments: {', '.join(invalid_keys)}"
raise ValueError(invalid_kwarg_exc)
return value

@field_validator("flags", mode="before")
@classmethod
def validate_flag_args(cls, value):
"""Ensure flag arguments start with '-' or '--'."""
if value is not None:
if isinstance(value, str):
value = [value]
invalid_flags = [arg for arg in value if not arg.startswith(("-", "--"))]
if invalid_flags:
invalid_flags_exc = f"Invalid positional arguments: {', '.join(invalid_flags)}"
raise ValueError(invalid_flags_exc)
return value
string: str = None
# # Field needed for child class doing pydantic parsing
# keyword: dict[str, str | int] | None = Field(default_factory=dict)
# flags: str | list[str] | None = None
source_files: str | list[str] | None = None

# # TODO: Should we allow users to pass it without the hyphen(s), and prepend them automatically?
# # TODO: While convenient, it could be a bad idea, if users put in wrong things. Better to be explicit.
# @field_validator("keyword", mode="before")
# @classmethod
# def validate_keyword_args(cls, value):
# """Ensure keyword arguments start with '-' or '--'."""
# if value is not None:
# invalid_keys = [key for key in value if not key.startswith(("-", "--"))]
# if invalid_keys:
# invalid_kwarg_exc = f"Invalid keyword arguments: {', '.join(invalid_keys)}"
# raise ValueError(invalid_kwarg_exc)
# return value

# @field_validator("flags", mode="before")
# @classmethod
# def validate_flag_args(cls, value):
# """Ensure flag arguments start with '-' or '--'."""
# if value is not None:
# if isinstance(value, str):
# value = [value]
# invalid_flags = [arg for arg in value if not arg.startswith(("-", "--"))]
# if invalid_flags:
# invalid_flags_exc = f"Invalid positional arguments: {', '.join(invalid_flags)}"
# raise ValueError(invalid_flags_exc)
# return value


class TargetNodesBaseModel(_NamedBaseModel):
Expand Down Expand Up @@ -311,7 +310,8 @@ class ConfigRootTask(ConfigBaseTask):
class ConfigShellTaskSpecs:
plugin: ClassVar[Literal["shell"]] = "shell"
command: str = ""
cli_arguments: _CliArgsBaseModel | None = None
cli_argument: str = ""
env_source_files: str | list[str] = None
src: str | None = None


Expand Down
117 changes: 12 additions & 105 deletions src/sirocco/workgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,6 @@ def _prepare_for_shell_task(task: dict, kwargs: dict) -> dict:
task_outputs = {task["outputs"][i]["name"] for i in range(len(task["outputs"]))}
task_outputs = task_outputs.union(set(kwargs.pop("outputs", [])))
missing_outputs = task_outputs.difference(default_outputs)
# ? kwargs['arguments'] = ['{initial_conditions}', '{data1}']
breakpoint()
return {
"code": code,
"nodes": nodes,
Expand Down Expand Up @@ -179,125 +177,33 @@ def _add_aiida_task_node(self, task: graph_items.Task):
raise ValueError(msg)

command = task.command
argument_str: str = ''

cli_arguments = task.cli_arguments

# ? Positional
# ? -> Seem to already be resolved in the WG creation, see `_link_input_to_task`
positional_args = cli_arguments.positional
positional_args = [positional_args] if isinstance(positional_args, str) else positional_args

def resolve_positional_and_keyword(input_args): # -> dict(str, str|dict):

argument_str = ''
node_dict = {}

for input_arg in input_args:
print(f'INPUT_ARG: {input_arg}')
if isinstance(input_arg, dict):
node_identifier = list(input_arg.values())[0]
# append_str = f' {list(input_arg.keys())[0]} {{{node_identifier}}}'
# print(f"APPEND_STR: {append_str}")
elif isinstance(input_arg, str):
node_identifier = input_arg
# append_str = f' {{{node_identifier}}}'
else:
raise TypeError("Something went wrong.")

aiida_data_node = self._aiida_data_nodes.get(node_identifier, None)
aiida_data_socket = self._aiida_socket_nodes.get(node_identifier, None)
# This will be either an existing aiida_data_node or an aiida_data_socket; otherwise None,
# since `or` falls through to aiida_data_socket regardless of whether it is None.
aiida_entity = aiida_data_node or aiida_data_socket
if aiida_entity is None:
if isinstance(input_arg, dict):
append_str = f' {list(input_arg.keys())[0]} {node_identifier}'
elif isinstance(input_arg, str):
append_str = f' {node_identifier}'
argument_str += append_str
# ! Don't resolve the nodes here, as this is done somewhere else, and otherwise excepts
# argument_str += f' {{{positional_arg}}}'
# node_dict[positional_arg] = aiida_entity

# TODO: Final cleanup of non-existing ones
# TODO: Check for absolute path, and if an absolute path is provided, directly use it as the input argument, as we
# cannot create SinglefileData from a file that doesn't live on the localhost -> Maybe create RemoteData instead
return {'arguments': argument_str, 'nodes': node_dict}

positional_resolved = resolve_positional_and_keyword(positional_args)
print(f"POSITIONAL_RESOLVED: {positional_resolved}")
argument_str = f"{argument_str} {positional_resolved['arguments']}"
# nodes = positional_resolved['nodes']

# ? Keywords
keyword_args = cli_arguments.keyword
print(f"KEYWORD_ARGS: {keyword_args}")
# ? Resolve to list of dictionaries with one key-value pair each (just for current implementation)
keyword_resolved = resolve_positional_and_keyword([{key: value} for key, value in keyword_args.items()])
print(f"KEYWORD_RESOLVED: {keyword_resolved}")
# argument_str = f"{argument_str} {keyword_resolved['arguments']}"
# nodes = positional_resolved['nodes']

# ? Flags
flags = cli_arguments.flags # append to string
flags = [flags] if isinstance(flags, str) else flags

argument_str = ' '.join([argument_str] + flags)
# command = ' '.join([command] + flags)

# ? Source file
source_files = cli_arguments.source_file # into prepend-text
source_files = [source_files] if isinstance(source_files, str) else source_files
prepend_text = '\n'.join([f"source {source_file}" for source_file in source_files])
env_source_files = task.env_source_files
env_source_files = [env_source_files] if isinstance(env_source_files, str) else env_source_files
prepend_text = '\n'.join([f"source {env_source_file}" for env_source_file in env_source_files])

# breakpoint()
# from aiida_shell.launch import prepare_computer, prepare_code
# breakpoint()
# TODO: Need access to the root task here, to get access to the host/computer
# TODO: Take care of proper code creation. Either a `PortableCode` or splitting the name, only using the
# TODO: actual script name, but resolving the path to the full path
# localhost = prepare_computer()
# command = task.command
# code = prepare_code(command=command, computer=localhost)

# TODO: Add the additional outputs here
breakpoint()
# ? argument_str = ' {data2} --test'
# ? `data1` and `initial_conditions` are added _somewhere_ else, either in `TaskCollection.new()`,
# ? `build_shelljob_task`, `build_task_from_AiiDA`, or `AiiDAWorkGraph._link_input_to_task()`, which is
# ? where the input/output nodes/sockets are being attached
workgraph_task = self._workgraph.tasks.new(
"ShellJob",
name=label,
# TODO: Currently use command to allow for flags. Ideally, passing a `Code` should also be possible, with the
# arguments passed through
command=command,
# TODO: Flags are not being passed to the submission script?
arguments=argument_str,
# arguments=argument_str.split(' ')[2:],
# nodes=nodes,
arguments=task.cli_argument,
# ! Do we still need to add nodes here, as in `aiida-shell`, or does WG do that automatically from the
# arguments if it finds them?
metadata={
'options': {
'prepend_text': prepend_text
}
}
)

# arguments='{file_a} {file_b}',
# nodes={
# 'file_a': SinglefileData.from_string('string a'),
# 'file_b': SinglefileData.from_string('string b'),
# }
# arguments_str = ''

# workgraph_task.set({"arguments": []})
workgraph_task.set({"nodes": {}})
# workgraph_task.set({"nodes": {}})
self._aiida_task_nodes[label] = workgraph_task

# elif isinstance(task, IconTask):
# exc = f"Task: {task.name} not implemented yet."
# raise NotImplementedError(exc)
elif isinstance(task, IconTask):
exc = "IconTask not implemented yet."
raise NotImplementedError(exc)
else:
exc = f"Task: {task.name} not implemented yet."
raise NotImplementedError(exc)
Expand Down Expand Up @@ -358,7 +264,8 @@ def _link_input_to_task(self, task: graph_items.Task, input_: graph_items.Data):
# workgraph_task_arguments.value.append(f"{arg_option}")
workgraph_task_arguments.value.append(f"{{{input_label}}}")
except Exception:
breakpoint()
pass
# breakpoint()

def _link_output_to_task(self, task: graph_items.Task, output: graph_items.Data):
"""
Expand Down
33 changes: 17 additions & 16 deletions tests/files/configs/test_config_small_no_icon.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,23 @@ tasks:
plugin: shell
# Currently full task, so running the script actually works
command: /home/geiger_j/aiida_projects/swiss-twins/git-repos/Sirocco/tests/files/scripts/shell_task.sh
cli_arguments:
# f"--restart_kwarg restart_value --init {initial_conditions} -a data1 -p 1"
# TODO: How to differentiate between string and data node/socket
keyword: # To be thought about
--restart_kwarg: restart_value # This doesn't map to a file, so is just an option
# --restart_kwarg: {restart_value} # This doesn't map to a file, so is just an option
# --verbosity: 2
--init: {initial_conditions} # This maps to a file, and should therefore be resolved to a SinglefileData
positional:
# - data1 # This maps to an actual existing file, specified in data section
- data2 # This is just an arbitrary positional argument
flags:
- --test-flag
# - --verbosity=2 # Didn't actually consider this
source_file:
- tests/files/data/dummy_source_file.sh
cli_argument: >
--restart_kwarg restart_value
--verbosity 2
--init {initial_conditions}
--test-flag
{data1}
data2
# Unfortunately, it is not possible to add comments inside a multi-line string, which was my original idea for making
# it more readable: https://stackoverflow.com/questions/20890445/yaml-comments-in-multi-line-strings
# --restart_kwarg restart_value # ? Keyword with normal str/int/float value
# --verbosity 2 # ? Keyword with normal int value
# --init {initial_conditions} # ? Keyword with reference to AiiDA node (available data) or AiiDA-WG socket
# --test-flag # ? Normal flag
# {data1} # ? Positional argument with reference
# data2 # ? Positional argument without AiiDA entity reference (even required)?
env_source_files:
- tests/files/data/dummy_source_file.sh
data:
available:
- initial_conditions:
Expand Down

0 comments on commit f269adc

Please sign in to comment.