From 4d4e50161dae9100d2f8197796f9e529bb4ad0fa Mon Sep 17 00:00:00 2001 From: Matthieu Leclair Date: Mon, 2 Dec 2024 13:11:16 +0100 Subject: [PATCH] Fuse classes for wait_on tasks and input data (#56) `ConfigCycleTaskInput` and `ConfigCycleTaskWaitOn` classes were gathering the same required information for targeting other nodes. No need to have 2. --- src/sirocco/core.py | 12 +- src/sirocco/parsing/_yaml_data_models.py | 191 +++++++++++------------ 2 files changed, 92 insertions(+), 111 deletions(-) diff --git a/src/sirocco/core.py b/src/sirocco/core.py index 9a8d343..18483e1 100644 --- a/src/sirocco/core.py +++ b/src/sirocco/core.py @@ -7,8 +7,6 @@ from sirocco.parsing._yaml_data_models import ( ConfigCycleTask, - ConfigCycleTaskInput, - ConfigCycleTaskWaitOn, ConfigTask, ConfigWorkflow, load_workflow_config, @@ -18,9 +16,7 @@ from collections.abc import Iterator from datetime import datetime - from sirocco.parsing._yaml_data_models import ConfigCycle, DataBaseModel - - type ConfigCycleSpec = ConfigCycleTaskWaitOn | ConfigCycleTaskInput + from sirocco.parsing._yaml_data_models import ConfigCycle, DataBaseModel, TargetNodesBaseModel logging.basicConfig() logger = logging.getLogger(__name__) @@ -185,7 +181,7 @@ def __getitem__(self, coordinates: dict) -> GraphItem: key = tuple(coordinates[dim] for dim in self._dims) return self._dict[key] - def iter_from_cycle_spec(self, spec: ConfigCycleSpec, reference: dict) -> Iterator[GraphItem]: + def iter_from_cycle_spec(self, spec: TargetNodesBaseModel, reference: dict) -> Iterator[GraphItem]: # Check date references if "date" not in self._dims and (spec.lag or spec.date): msg = f"Array {self._name} has no date dimension, cannot be referenced by dates" @@ -197,7 +193,7 @@ def iter_from_cycle_spec(self, spec: ConfigCycleSpec, reference: dict) -> Iterat for key in product(*(self._resolve_target_dim(spec, dim, reference) for dim in self._dims)): yield self._dict[key] - def _resolve_target_dim(self, spec: ConfigCycleSpec, dim: str, reference: Any) -> Iterator[Any]: + def _resolve_target_dim(self, spec: TargetNodesBaseModel, dim: str, reference: Any) -> Iterator[Any]: if dim == "date": if not spec.lag and not spec.date: yield reference["date"] @@ -239,7 +235,7 @@ def __getitem__(self, key: tuple[str, dict]) -> GraphItem: raise KeyError(msg) return self._dict[name][coordinates] - def iter_from_cycle_spec(self, spec: ConfigCycleSpec, reference: dict) -> Iterator[GraphItem]: + def iter_from_cycle_spec(self, spec: TargetNodesBaseModel, reference: dict) -> Iterator[GraphItem]: # Check if target items should be querried at all if (when := spec.when) is not None: if (ref_date := reference.get("date")) is None: diff --git a/src/sirocco/parsing/_yaml_data_models.py b/src/sirocco/parsing/_yaml_data_models.py index 226d2dd..67c00ce 100644 --- a/src/sirocco/parsing/_yaml_data_models.py +++ b/src/sirocco/parsing/_yaml_data_models.py @@ -70,10 +70,13 @@ def convert_datetime(cls, value) -> datetime: return datetime.fromisoformat(value) -# TODO: Change class name, does not fit anymore wit hthe addition of `when` and `parameters` -# find something more related to graph specification in general like _GraphTargetBaseModel -class _LagDateBaseModel(BaseModel): - """Base class for all classes containg a list of dates or time lags.""" +class TargetNodesBaseModel(_NamedBaseModel): + """class for targeting other task or data nodes in the graph + + When specifying cycle tasks, this class gathers the required information for + targeting other nodes, either input data or wait on tasks. + + """ model_config = ConfigDict(arbitrary_types_allowed=True) date: list[datetime] = [] # this is safe in pydantic @@ -117,111 +120,14 @@ def check_dict_single_item(cls, params: dict) -> dict: return params -class ConfigTask(_NamedBaseModel): - """ - To create an instance of a task defined in a workflow file - """ - - # TODO: This list is too large. We should start with the set of supported - # keywords and extend it as we support more - command: str - command_option: str | None = None - input_arg_options: dict[str, str] | None = None - parameters: list[str] = [] - host: str | None = None - account: str | None = None - plugin: str | None = None - config: str | None = None - uenv: dict | None = None - nodes: int | None = None - walltime: str | None = None - src: str | None = None - conda_env: str | None = None - - def __init__(self, /, **data): - # We have to treat root special as it does not typically define a command - if "ROOT" in data and "command" not in data["ROOT"]: - data["ROOT"]["command"] = "ROOT_PLACEHOLDER" - super().__init__(**data) - - @field_validator("command") - @classmethod - def expand_env_vars(cls, value: str) -> str: - """Expands any environment variables in the value""" - return expandvars(value) - - @field_validator("walltime") - @classmethod - def convert_to_struct_time(cls, value: str | None) -> time.struct_time | None: - """Converts a string of form "%H:%M:%S" to a time.time_struct""" - return None if value is None else time.strptime(value, "%H:%M:%S") - - -class DataBaseModel(_NamedBaseModel): - """ - To create an instance of a data defined in a workflow file. - """ - - type: str - src: str - format: str | None = None - parameters: list[str] = [] - - @field_validator("type") - @classmethod - def is_file_or_dir(cls, value: str) -> str: - """.""" - if value not in ["file", "dir"]: - msg = "Must be one of 'file' or 'dir'." - raise ValueError(msg) - return value - - @property - def available(self) -> bool: - return isinstance(self, ConfigAvailableData) - - -class ConfigAvailableData(DataBaseModel): +class ConfigCycleTaskInput(TargetNodesBaseModel): pass -class ConfigGeneratedData(DataBaseModel): +class ConfigCycleTaskWaitOn(TargetNodesBaseModel): pass -class ConfigData(BaseModel): - """To create the container of available and generated data""" - - available: list[ConfigAvailableData] = [] - generated: list[ConfigGeneratedData] = [] - - -class ConfigCycleTaskWaitOn(_NamedBaseModel, _LagDateBaseModel): - """ - To create an instance of a input or output in a task in a cycle defined in a workflow file. - """ - - # TODO: Move to "wait_on" keyword in yaml instead of "depend" - name: str # name of the task it waits on - cycle_name: str | None = None - - -class ConfigCycleTaskInput(_NamedBaseModel, _LagDateBaseModel): - """ - To create an instance of an input in a task in a cycle defined in a workflow file. - - For example: - - .. yaml - - - my_input: - date: ... - lag: ... - """ - - arg_option: str | None = None - - class ConfigCycleTaskOutput(_NamedBaseModel): """ To create an instance of an output in a task in a cycle defined in a workflow file. @@ -324,6 +230,85 @@ def check_period_is_not_negative_or_zero(self) -> ConfigCycle: return self +class ConfigTask(_NamedBaseModel): + """ + To create an instance of a task defined in a workflow file + """ + + # TODO: This list is too large. We should start with the set of supported + # keywords and extend it as we support more + command: str + command_option: str | None = None + input_arg_options: dict[str, str] | None = None + parameters: list[str] = [] + host: str | None = None + account: str | None = None + plugin: str | None = None + config: str | None = None + uenv: dict | None = None + nodes: int | None = None + walltime: str | None = None + src: str | None = None + conda_env: str | None = None + + def __init__(self, /, **data): + # We have to treat root special as it does not typically define a command + if "ROOT" in data and "command" not in data["ROOT"]: + data["ROOT"]["command"] = "ROOT_PLACEHOLDER" + super().__init__(**data) + + @field_validator("command") + @classmethod + def expand_env_vars(cls, value: str) -> str: + """Expands any environment variables in the value""" + return expandvars(value) + + @field_validator("walltime") + @classmethod + def convert_to_struct_time(cls, value: str | None) -> time.struct_time | None: + """Converts a string of form "%H:%M:%S" to a time.time_struct""" + return None if value is None else time.strptime(value, "%H:%M:%S") + + +class DataBaseModel(_NamedBaseModel): + """ + To create an instance of a data defined in a workflow file. + """ + + type: str + src: str + format: str | None = None + parameters: list[str] = [] + + @field_validator("type") + @classmethod + def is_file_or_dir(cls, value: str) -> str: + """.""" + if value not in ["file", "dir"]: + msg = "Must be one of 'file' or 'dir'." + raise ValueError(msg) + return value + + @property + def available(self) -> bool: + return isinstance(self, ConfigAvailableData) + + +class ConfigAvailableData(DataBaseModel): + pass + + +class ConfigGeneratedData(DataBaseModel): + pass + + +class ConfigData(BaseModel): + """To create the container of available and generated data""" + + available: list[ConfigAvailableData] = [] + generated: list[ConfigGeneratedData] = [] + + class ConfigWorkflow(BaseModel): name: str | None = None cycles: list[ConfigCycle]