Skip to content

Commit

Permalink
Plugin task subclasses (#59)
Browse files Browse the repository at this point in the history
Refactor `core.py` in order to allow for `Task` subclassing.
- `core.py` becomes a package containing `graph_items.py`, `workflow.py` and the `_tasks` subpackage
- `Task` has a metaclass called `Plugin` to register all its subclasses
- `Task` subclasses are stored in `core._tasks`
- `Task.from_config` generates the right instance of the `Task` subclass
- The logic for iterating over dates and parameters when constructing the IR has been moved to `workflow.py`
- Pydantic checking for `Task` subclasses is not yet implemented. It will go either into the parsing part, as is done for the `Task` base class, or into the class itself, to avoid code duplication.
  • Loading branch information
leclairm authored Dec 5, 2024
1 parent 602c3fe commit 65ea029
Show file tree
Hide file tree
Showing 9 changed files with 202 additions and 169 deletions.
4 changes: 2 additions & 2 deletions src/sirocco/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from . import parsing
from . import core, parsing

__all__ = ["parsing"]
__all__ = ["parsing", "core"]

__version__ = "0.0.0-dev0"
4 changes: 4 additions & 0 deletions src/sirocco/core/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .graph_items import Cycle, Data, GraphItem, Task
from .workflow import Workflow

__all__ = ["Workflow", "GraphItem", "Data", "Task", "Cycle"]
3 changes: 3 additions & 0 deletions src/sirocco/core/_tasks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from . import icon_task, shell_task

__all__ = ["icon_task", "shell_task"]
13 changes: 13 additions & 0 deletions src/sirocco/core/_tasks/icon_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import ClassVar

from sirocco.core.graph_items import Task


@dataclass
class IconTask(Task):
    """Task subclass registered under the ``icon`` plugin name."""

    # Plugin name for this class; declared as a ClassVar, so it is a class
    # attribute and not a dataclass field.
    plugin: ClassVar[str] = "icon"

    # Each instance gets its own empty dict unless a value is passed in.
    namelists: dict = field(default_factory=dict)
16 changes: 16 additions & 0 deletions src/sirocco/core/_tasks/shell_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import ClassVar

from sirocco.core.graph_items import Task


@dataclass
class ShellTask(Task):
    """Task subclass registered under the ``shell`` plugin name."""

    # Plugin name for this class; declared as a ClassVar, so it is a class
    # attribute and not a dataclass field.
    plugin: ClassVar[str] = "shell"

    # Shell-specific settings; every one is optional and defaults to None.
    command: str | None = None
    command_option: str | None = None
    input_arg_options: dict[str, str] | None = None
    src: str | None = None
211 changes: 70 additions & 141 deletions src/sirocco/core.py → src/sirocco/core/graph_items.py
Original file line number Diff line number Diff line change
@@ -1,135 +1,120 @@
from __future__ import annotations

import logging
from dataclasses import dataclass, field
from itertools import chain, product
from typing import TYPE_CHECKING, Any, Self

from sirocco.parsing._yaml_data_models import (
ConfigCycleTask,
ConfigTask,
ConfigWorkflow,
load_workflow_config,
)
from typing import TYPE_CHECKING, Any, ClassVar, Self

if TYPE_CHECKING:
from collections.abc import Iterator
from datetime import datetime

from sirocco.parsing._yaml_data_models import ConfigCycle, DataBaseModel, TargetNodesBaseModel

logging.basicConfig()
logger = logging.getLogger(__name__)
from sirocco.parsing._yaml_data_models import ConfigCycleTask, ConfigTask, DataBaseModel, TargetNodesBaseModel


@dataclass
class GraphItem:
"""base class for Data Tasks and Cycles"""

color: ClassVar[str]

name: str
color: str
coordinates: dict = field(default_factory=dict)

@classmethod
def iter_coordinates(cls, param_refs: list, parameters: dict, date: datetime) -> Iterator[dict]:
space = ({} if date is None else {"date": [date]}) | {k: parameters[k] for k in param_refs}
yield from (dict(zip(space.keys(), x)) for x in product(*space.values()))

class Plugin(type):
"""Metaclass for plugin tasks inheriting from Task
Used to register all plugin task classes"""

classes: dict[str, type] | None = None

def __new__(cls, name, bases, dct):
if cls.classes is None:
cls.classes = {}
plugin = dct["plugin"]
if plugin in cls.classes:
msg = f"Task for plugin {plugin} already set"
raise ValueError(msg)
return_cls = super().__new__(cls, name, bases, dct)
cls.classes[plugin] = return_cls
return return_cls


@dataclass
class Task(GraphItem):
class Task(GraphItem, metaclass=Plugin):
"""Internal representation of a task node"""

color: str = "light_red"
workflow: Workflow | None = None
outputs: list[Data] = field(default_factory=list)
plugin: ClassVar[str] = "_BASE_TASK"
color: ClassVar[str] = "light_red"

inputs: list[Data] = field(default_factory=list)
outputs: list[Data] = field(default_factory=list)
wait_on: list[Task] = field(default_factory=list)
# TODO: This list is too long. We should start with the set of supported
# keywords and extend it as we support more
command: str | None = None
command_option: str | None = None
input_arg_options: dict[str, str] | None = None
host: str | None = None
account: str | None = None
plugin: str | None = None
config: str | None = None
uenv: dict | None = None
nodes: int | None = None
walltime: str | None = None
src: str | None = None
conda_env: str | None = None

# use classmethod instead of custom init
@classmethod
def from_config(
cls,
config: ConfigTask,
workflow_parameters: dict[str, list],
coordinates: dict[str, Any],
datastore: Store,
graph_spec: ConfigCycleTask,
workflow: Workflow,
*,
date: datetime | None = None,
) -> Iterator[Self]:
for coordinates in cls.iter_coordinates(config.parameters, workflow_parameters, date):
inputs = list(
chain(
*(workflow.data.iter_from_cycle_spec(input_spec, coordinates) for input_spec in graph_spec.inputs)
) -> Task:
inputs = list(
chain(*(datastore.iter_from_cycle_spec(input_spec, coordinates) for input_spec in graph_spec.inputs))
)
outputs = [datastore[output_spec.name, coordinates] for output_spec in graph_spec.outputs]
# use the fact that pydantic models can be turned into dicts easily
cls_config = dict(config)
del cls_config["parameters"]
del cls_config["plugin"]
new = Plugin.classes[config.plugin](
coordinates=coordinates,
inputs=inputs,
outputs=outputs,
**cls_config,
) # this works because dataclass has generated this init for us

# Store for actual linking in link_wait_on_tasks() once all tasks are created
new._wait_on_specs = graph_spec.wait_on # noqa: SLF001 we don't have access to self in a dataclass
# and setting an underscored attribute from
# the class itself raises SLF001

return new

def link_wait_on_tasks(self, taskstore: Store):
self.wait_on = list(
chain(
*(
taskstore.iter_from_cycle_spec(wait_on_spec, self.coordinates)
for wait_on_spec in self._wait_on_specs
)
)

outputs = [workflow.data[output_spec.name, coordinates] for output_spec in graph_spec.outputs]

# use the fact that pydantic models can be turned into dicts easily
cls_config = dict(config)
del cls_config["parameters"]

new = cls(
coordinates=coordinates,
inputs=inputs,
outputs=outputs,
workflow=workflow,
**cls_config,
) # this works because dataclass has generated this init for us

# Store for actual linking in link_wait_on_tasks() once all tasks are created
new._wait_on_specs = graph_spec.wait_on # noqa: SLF001 we don't have access to self in a dataclass
# and setting an underscored attribute from
# the class itself raises SLF001

yield new

def link_wait_on_tasks(self):
self.wait_on: list[Task] = []
for wait_on_spec in self._wait_on_specs:
self.wait_on.extend(
task
for task in self.workflow.tasks.iter_from_cycle_spec(wait_on_spec, self.coordinates)
if task is not None
)
)


@dataclass(kw_only=True)
class Data(GraphItem):
"""Internal representation of a data node"""

color: str = "light_blue"
color: ClassVar[str] = "light_blue"

type: str
src: str
available: bool

@classmethod
def from_config(
cls, config: DataBaseModel, workflow_parameters: dict[str, list], *, date: datetime | None = None
) -> Iterator[Self]:
for coordinates in cls.iter_coordinates(config.parameters, workflow_parameters, date):
yield cls(
name=config.name,
type=config.type,
src=config.src,
available=config.available,
coordinates=coordinates,
)
def from_config(cls, config: DataBaseModel, coordinates: dict) -> Self:
return cls(
name=config.name,
type=config.type,
src=config.src,
available=config.available,
coordinates=coordinates,
)


@dataclass(kw_only=True)
Expand Down Expand Up @@ -186,7 +171,7 @@ def iter_from_cycle_spec(self, spec: TargetNodesBaseModel, reference: dict) -> I
if "date" not in self._dims and (spec.lag or spec.date):
msg = f"Array {self._name} has no date dimension, cannot be referenced by dates"
raise ValueError(msg)
if "date" in self._dims and reference.get("date") is None and spec.date is []:
if "date" in self._dims and reference.get("date") is None and len(spec.date) == 0:
msg = f"Array {self._name} has a date dimension, must be referenced by dates"
raise ValueError(msg)

Expand Down Expand Up @@ -256,59 +241,3 @@ def __iter__(self) -> Iterator[GraphItem]:
yield from item
else:
yield item


class Workflow:
    """Internal representation of a workflow"""

    def __init__(self, workflow_config: ConfigWorkflow) -> None:
        """Fill the task, data and cycle stores from ``workflow_config``."""
        self.name = workflow_config.name
        self.tasks = Store()
        self.data = Store()
        self.cycles = Store()

        # Order matters: data nodes are created before the tasks, and the
        # wait_on links are resolved only once every task exists.
        self._add_available_data(workflow_config)
        self._add_output_data(workflow_config)
        self._add_cycles_and_tasks(workflow_config)
        self._link_wait_on_tasks()

    def _add_available_data(self, workflow_config: ConfigWorkflow) -> None:
        """Step 1 - create available data nodes (no date)."""
        for available_config in workflow_config.data.available:
            for node in Data.from_config(available_config, workflow_config.parameters, date=None):
                self.data.add(node)

    def _add_output_data(self, workflow_config: ConfigWorkflow) -> None:
        """Step 2 - create output data nodes for every cycle date."""
        for cycle in workflow_config.cycles:
            for date in self.cycle_dates(cycle):
                for task_ref in cycle.tasks:
                    for output_ref in task_ref.outputs:
                        output_config = workflow_config.data_dict[output_ref.name]
                        for node in Data.from_config(output_config, workflow_config.parameters, date=date):
                            self.data.add(node)

    def _add_cycles_and_tasks(self, workflow_config: ConfigWorkflow) -> None:
        """Step 3 - create cycles and tasks."""
        for cycle in workflow_config.cycles:
            cycle_name = cycle.name
            for date in self.cycle_dates(cycle):
                members = []
                for spec in cycle.tasks:
                    created = Task.from_config(
                        workflow_config.task_dict[spec.name],
                        workflow_config.parameters,
                        spec,
                        workflow=self,
                        date=date,
                    )
                    for task in created:
                        self.tasks.add(task)
                        members.append(task)
                if date is None:
                    coordinates = {}
                else:
                    coordinates = {"date": date}
                self.cycles.add(Cycle(name=cycle_name, tasks=members, coordinates=coordinates))

    def _link_wait_on_tasks(self) -> None:
        """Step 4 - link wait_on tasks."""
        for task in self.tasks:
            task.link_wait_on_tasks()

    @staticmethod
    def cycle_dates(cycle_config: ConfigCycle) -> Iterator[datetime]:
        """Yield the start date, then step by ``period`` while strictly before ``end_date``.

        When ``period`` is None only the start date is yielded.
        """
        current = cycle_config.start_date
        yield current
        if cycle_config.period is None:
            return
        current = current + cycle_config.period
        while current < cycle_config.end_date:
            yield current
            current = current + cycle_config.period

    @classmethod
    def from_yaml(cls, config_path: str):
        """Build a workflow from the config returned by ``load_workflow_config(config_path)``."""
        workflow_config = load_workflow_config(config_path)
        return cls(workflow_config)
Loading

0 comments on commit 65ea029

Please sign in to comment.