Skip to content

Commit

Permalink
Split parsing classes between declaration and validation (#63)
Browse files Browse the repository at this point in the history
* split Task and Data classes in parsing in 2 classes:
   - one for declaration of fields that are needed in core classes
   - one for pydantic validation
  This way core classes can inherit from the declaration classes and we avoid duplication.

* remove metaclass for core.graph_items.Task
  this was overkill and has been replaced by an __init_subclass__ class method
----
Co-authored-by: Alexander Goscinski <[email protected]>
  • Loading branch information
leclairm authored Dec 13, 2024
1 parent 9d63123 commit 7ee6861
Show file tree
Hide file tree
Showing 6 changed files with 134 additions and 167 deletions.
11 changes: 4 additions & 7 deletions src/sirocco/core/_tasks/icon_task.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import ClassVar, Literal
from dataclasses import dataclass

from sirocco.core.graph_items import Task
from sirocco.parsing import ConfigIconTask
from sirocco.parsing._yaml_data_models import ConfigIconTaskSpecs


@dataclass
class IconTask(Task):
plugin: ClassVar[Literal[ConfigIconTask.plugin]] = ConfigIconTask.plugin

namelists: dict = field(default_factory=dict)
class IconTask(ConfigIconTaskSpecs, Task):
pass
13 changes: 3 additions & 10 deletions src/sirocco/core/_tasks/shell_task.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,11 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import ClassVar, Literal

from sirocco.core.graph_items import Task
from sirocco.parsing import ConfigShellTask
from sirocco.parsing._yaml_data_models import ConfigShellTaskSpecs


@dataclass
class ShellTask(Task):
# plugin: ClassVar[str] = "shell"
plugin: ClassVar[Literal[ConfigShellTask.plugin]] = ConfigShellTask.plugin

command: str | None = None
command_option: str | None = None
input_arg_options: dict[str, str] | None = None
src: str | None = None
class ShellTask(ConfigShellTaskSpecs, Task):
pass
84 changes: 31 additions & 53 deletions src/sirocco/core/graph_items.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,18 @@

from dataclasses import dataclass, field
from itertools import chain, product
from typing import TYPE_CHECKING, Any, ClassVar, Literal, Self
from typing import TYPE_CHECKING, Any, ClassVar, Self

from sirocco.parsing import ConfigBaseTask
from sirocco.parsing._yaml_data_models import (
ConfigAvailableData,
ConfigBaseDataSpecs,
ConfigBaseTaskSpecs,
)

if TYPE_CHECKING:
from collections.abc import Iterator

from sirocco.parsing._yaml_data_models import ConfigCycleTask, ConfigTask, DataBaseModel, TargetNodesBaseModel
from sirocco.parsing._yaml_data_models import ConfigBaseData, ConfigCycleTask, ConfigTask, TargetNodesBaseModel


@dataclass
Expand All @@ -19,47 +23,26 @@ class GraphItem:
color: ClassVar[str]

name: str
coordinates: dict = field(default_factory=dict)


class TaskPlugin(type):
"""Metaclass for plugin tasks inheriting from Task
Used to register all plugin task classes"""

classes: ClassVar[dict[str, type[Task]]] = {}

def __new__(cls, name: str, bases: tuple, attr: dict):
"""Invoked on class definition when used as metaclass.
name: The name of the class
bases: The base classes from the class
attr: The attributes of the class
"""
plugin = attr["plugin"]
if plugin in cls.classes:
msg = f"Task for plugin {plugin} already set"
raise ValueError(msg)
return_cls = super().__new__(cls, name, bases, attr)
cls.classes[plugin] = return_cls
return return_cls
coordinates: dict


@dataclass
class Task(GraphItem, metaclass=TaskPlugin):
class Task(ConfigBaseTaskSpecs, GraphItem):
"""Internal representation of a task node"""

plugin: ClassVar[Literal[ConfigBaseTask.plugin]] = ConfigBaseTask.plugin
plugin_classes: ClassVar[dict[str, type]] = field(default={}, repr=False)
color: ClassVar[str] = field(default="light_red", repr=False)

inputs: list[Data] = field(default_factory=list)
outputs: list[Data] = field(default_factory=list)
wait_on: list[Task] = field(default_factory=list)
host: str | None = None
account: str | None = None
uenv: dict | None = None
nodes: int | None = None
walltime: str | None = None

def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if cls.plugin in Task.plugin_classes:
msg = f"Task for plugin {cls.plugin} already set"
raise ValueError(msg)
Task.plugin_classes[cls.plugin] = cls

@classmethod
def from_config(
Expand All @@ -76,8 +59,8 @@ def from_config(
# use the fact that pydantic models can be turned into dicts easily
cls_config = dict(config)
del cls_config["parameters"]
if (plugin_cls := TaskPlugin.classes.get(type(config).plugin, None)) is None:
msg = f"Plugin {config.plugin!r} is not supported."
if (plugin_cls := Task.plugin_classes.get(type(config).plugin, None)) is None:
msg = f"Plugin {type(config).plugin!r} is not supported."
raise ValueError(msg)

new = plugin_cls(
Expand Down Expand Up @@ -105,32 +88,31 @@ def link_wait_on_tasks(self, taskstore: Store):
)


@dataclass(kw_only=True)
class Data(GraphItem):
@dataclass
class Data(ConfigBaseDataSpecs, GraphItem):
"""Internal representation of a data node"""

color: ClassVar[str] = "light_blue"
color: ClassVar[str] = field(default="light_blue", repr=False)

type: str
src: str
available: bool
available: bool | None = None # must get a default value because of dataclass inheritence

@classmethod
def from_config(cls, config: DataBaseModel, coordinates: dict) -> Self:
def from_config(cls, config: ConfigBaseData, coordinates: dict) -> Self:
return cls(
name=config.name,
type=config.type,
src=config.src,
available=config.available,
available=isinstance(config, ConfigAvailableData),
coordinates=coordinates,
)


@dataclass(kw_only=True)
@dataclass
class Cycle(GraphItem):
"""Internal reprenstation of a cycle"""

color: str = "light_green"
color: ClassVar[str] = field(default="light_green", repr=False)

tasks: list[Task]


Expand Down Expand Up @@ -206,10 +188,10 @@ def __iter__(self) -> Iterator[GraphItem]:


class Store:
"""Container for Array or unique items"""
"""Container for GraphItem Arrays"""

def __init__(self):
self._dict: dict[str, Array | GraphItem] = {}
self._dict: dict[str, Array] = {}

def add(self, item) -> None:
if not isinstance(item, GraphItem):
Expand Down Expand Up @@ -245,8 +227,4 @@ def iter_from_cycle_spec(self, spec: TargetNodesBaseModel, reference: dict) -> I
yield from self._dict[spec.name].iter_from_cycle_spec(spec, reference)

def __iter__(self) -> Iterator[GraphItem]:
for item in self._dict.values():
if isinstance(item, Array):
yield from item
else:
yield item
yield from chain(*(self._dict.values()))
6 changes: 0 additions & 6 deletions src/sirocco/parsing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,7 @@
from ._yaml_data_models import (
ConfigBaseTask,
ConfigIconTask,
ConfigShellTask,
load_workflow_config,
)

__all__ = [
"load_workflow_config",
"ConfigBaseTask",
"ConfigShellTask",
"ConfigIconTask",
]
81 changes: 43 additions & 38 deletions src/sirocco/parsing/_yaml_data_models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import time
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Annotated, Any, ClassVar, Literal
Expand All @@ -9,9 +10,6 @@
from isoduration.types import Duration # pydantic needs type # noqa: TCH002
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, field_validator, model_validator

# from sirocco.core._tasks.icon_task import IconTask
# from sirocco.core._tasks.shell_task import ShellTask
# from sirocco.core.graph_items import Task
from sirocco.parsing._utils import TimeUtils


Expand Down Expand Up @@ -244,27 +242,21 @@ def check_period_is_not_negative_or_zero(self) -> ConfigCycle:
return self


class ConfigBaseTask(_NamedBaseModel):
"""
config for genric task, no plugin specifics
"""

# this class could be used for constructing a root task we therefore need a
# default value for the plugin as it is not required
# plugin: Literal[Task.plugin] | None = None
plugin: ClassVar[Literal["_BASE_TASK_"]] = "_BASE_TASK_"
parameters: list[str] = Field(default_factory=list)
@dataclass
class ConfigBaseTaskSpecs:
host: str | None = None
account: str | None = None
uenv: dict | None = None
nodes: int | None = None
walltime: str | None = None

def __init__(self, /, **data):
# We have to treat root special as it does not typically define a command
if "ROOT" in data and "command" not in data["ROOT"]:
data["ROOT"]["command"] = "ROOT_PLACEHOLDER"
super().__init__(**data)

class ConfigBaseTask(_NamedBaseModel, ConfigBaseTaskSpecs):
"""
config for genric task, no plugin specifics
"""

parameters: list[str] = Field(default_factory=list)

@field_validator("walltime")
@classmethod
Expand All @@ -273,29 +265,46 @@ def convert_to_struct_time(cls, value: str | None) -> time.struct_time | None:
return None if value is None else time.strptime(value, "%H:%M:%S")


class ConfigShellTask(ConfigBaseTask):
# plugin: Literal[ShellTask.plugin]
class ConfigRootTask(ConfigBaseTask):
plugin: ClassVar[Literal["_root"]] = "_root"


@dataclass
class ConfigShellTaskSpecs:
plugin: ClassVar[Literal["shell"]] = "shell"
command: str
command: str = ""
command_option: str = ""
input_arg_options: dict[str, str] = Field(default_factory=dict)
input_arg_options: dict[str, str] = Field(default_factory=dict) # noqa: RUF009 Field needed
# for child class doing pydantic parsing
src: str | None = None


class ConfigIconTask(ConfigBaseTask):
# plugin: Literal[IconTask.plugin]
class ConfigShellTask(ConfigBaseTask, ConfigShellTaskSpecs):
pass


@dataclass
class ConfigIconTaskSpecs:
plugin: ClassVar[Literal["icon"]] = "icon"
namelists: dict[str, Any]
namelists: dict[str, str] | None = None


class ConfigIconTask(ConfigBaseTask, ConfigIconTaskSpecs):
pass


class DataBaseModel(_NamedBaseModel):
@dataclass
class ConfigBaseDataSpecs:
type: str | None = None
src: str | None = None
format: str | None = None


class ConfigBaseData(_NamedBaseModel, ConfigBaseDataSpecs):
"""
To create an instance of a data defined in a workflow file.
"""

type: str
src: str
format: str | None = None
parameters: list[str] = []

@field_validator("type")
Expand All @@ -307,16 +316,12 @@ def is_file_or_dir(cls, value: str) -> str:
raise ValueError(msg)
return value

@property
def available(self) -> bool:
return isinstance(self, ConfigAvailableData)


class ConfigAvailableData(DataBaseModel):
class ConfigAvailableData(ConfigBaseData):
pass


class ConfigGeneratedData(DataBaseModel):
class ConfigGeneratedData(ConfigBaseData):
pass


Expand All @@ -330,16 +335,16 @@ class ConfigData(BaseModel):
def get_plugin_from_named_base_model(data: dict) -> str:
name_and_specs = _NamedBaseModel.merge_name_and_specs(data)
if name_and_specs.get("name", None) == "ROOT":
return ConfigBaseTask.plugin
return ConfigRootTask.plugin
plugin = name_and_specs.get("plugin", None)
if plugin is None:
msg = "Could not find plugin name in {data}"
msg = f"Could not find plugin name in {data}"
raise ValueError(msg)
return plugin


ConfigTask = Annotated[
Annotated[ConfigBaseTask, Tag(ConfigBaseTask.plugin)]
Annotated[ConfigRootTask, Tag(ConfigRootTask.plugin)]
| Annotated[ConfigIconTask, Tag(ConfigIconTask.plugin)]
| Annotated[ConfigShellTask, Tag(ConfigShellTask.plugin)],
Discriminator(get_plugin_from_named_base_model),
Expand Down
Loading

0 comments on commit 7ee6861

Please sign in to comment.