diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 88110b29..c50d47f5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -18,4 +18,5 @@ repos: rev: v1.9.0 hooks: - id: mypy - args: [--strict, --ignore-missing-imports] + args: [--strict, --install-types, --non-interactive] + additional_dependencies: [sympy, attrs, pytest, click, dask] diff --git a/example/extra.py b/example/extra.py index 2ef1744c..053b2f66 100644 --- a/example/extra.py +++ b/example/extra.py @@ -7,26 +7,31 @@ JUMP: int = 10 + @task() def increase(num: int) -> int: """Add globally-configured integer JUMP to a number.""" return num + JUMP + @task() def increment(num: int) -> int: """Increment an integer.""" return num + 1 + @task() def double(num: int) -> int: """Double an integer.""" return 2 * num + @task() def mod10(num: int) -> int: """Calculate supplied integer modulo 10.""" return num % 10 + @task() def sum(left: int, right: int) -> int: """Add two integers.""" diff --git a/example/workflow_complex.py b/example/workflow_complex.py index 71762518..4f9f0182 100644 --- a/example/workflow_complex.py +++ b/example/workflow_complex.py @@ -3,7 +3,7 @@ Useful as an example of a workflow with a nested task. ```sh -$ python -m dewret workflow_complex.py --pretty run +$ python -m dewret workflow_complex.py --pretty nested_workflow ``` """ @@ -12,12 +12,10 @@ STARTING_NUMBER: int = 23 + @nested_task() def nested_workflow() -> int | float: """Creates a graph of task calls.""" left = double(num=increase(num=STARTING_NUMBER)) right = increase(num=increase(num=17)) - return sum( - left=left, - right=right - ) + return sum(left=left, right=right) diff --git a/src/dewret/__main__.py b/src/dewret/__main__.py index 04f54e83..059ec406 100644 --- a/src/dewret/__main__.py +++ b/src/dewret/__main__.py @@ -29,13 +29,28 @@ from .renderers.cwl import render as cwl_render from .tasks import Backend, construct + @click.command() -@click.option("--pretty", is_flag=True, show_default=True, default=False, help="Pretty-print output where possible.") -@click.option("--backend", type=click.Choice(list(Backend.__members__)), show_default=True, default=Backend.DASK.name, help="Backend to use for workflow evaluation.") +@click.option( + "--pretty", + is_flag=True, + show_default=True, + default=False, + help="Pretty-print output where possible.", +) +@click.option( + "--backend", + type=click.Choice(list(Backend.__members__)), + show_default=True, + default=Backend.DASK.name, + help="Backend to use for workflow evaluation.", +) @click.argument("workflow_py") @click.argument("task") @click.argument("arguments", nargs=-1) -def render(workflow_py: str, task: str, arguments: list[str], pretty: bool, backend: Backend) -> None: +def render( + workflow_py: str, task: str, arguments: list[str], pretty: bool, backend: Backend +) -> None: """Render a workflow. WORKFLOW_PY is the Python file containing workflow. 
@@ -49,14 +64,24 @@ def render(workflow_py: str, task: str, arguments: list[str], pretty: bool, back kwargs = {} for arg in arguments: if ":" not in arg: - raise RuntimeError("Arguments should be specified as key:val, where val is a JSON representation of the argument") + raise RuntimeError( + "Arguments should be specified as key:val, where val is a JSON representation of the argument" + ) key, val = arg.split(":", 1) kwargs[key] = json.loads(val) - cwl = cwl_render(construct(task_fn(**kwargs), simplify_ids=True)) - if pretty: - yaml.dump(cwl, sys.stdout, indent=2) + try: + cwl = cwl_render(construct(task_fn(**kwargs), simplify_ids=True)) + except Exception as exc: + import traceback + + print(exc, exc.__cause__, exc.__context__) + traceback.print_exc() else: - print(cwl) + if pretty: + yaml.dump(cwl, sys.stdout, indent=2) + else: + print(cwl) + render() diff --git a/src/dewret/backends/backend_dask.py b/src/dewret/backends/backend_dask.py index f95e7744..7e6b33da 100644 --- a/src/dewret/backends/backend_dask.py +++ b/src/dewret/backends/backend_dask.py @@ -21,6 +21,7 @@ from dewret.workflow import Workflow, Lazy, StepReference, Target from typing import Protocol, runtime_checkable, Any, cast + @runtime_checkable class Delayed(Protocol): """Description of a dask `delayed`. @@ -48,6 +49,7 @@ def compute(self, __workflow__: Workflow | None) -> StepReference[Any]: """ ... + def unwrap(task: Lazy) -> Target: """Unwraps a lazy-evaluated function to get the function. @@ -69,6 +71,7 @@ def unwrap(task: Lazy) -> Target: raise RuntimeError("Task is not actually a callable") return cast(Target, task._obj) + def is_lazy(task: Any) -> bool: """Checks if a task is really a lazy-evaluated function for this backend. @@ -80,7 +83,10 @@ def is_lazy(task: Any) -> bool: """ return isinstance(task, Delayed) + lazy = delayed + + def run(workflow: Workflow | None, task: Lazy) -> StepReference[Any]: """Execute a task as the output of a workflow. @@ -92,5 +98,7 @@ def run(workflow: Workflow | None, task: Lazy) -> StepReference[Any]: """ # We need isinstance to reassure type-checker. if not isinstance(task, Delayed) or not is_lazy(task): - raise RuntimeError("Cannot mix backends") + raise RuntimeError( + f"{task} is not a dask delayed, perhaps you tried to mix backends?" + ) return task.compute(__workflow__=workflow) diff --git a/src/dewret/tasks.py b/src/dewret/tasks.py index 1c9e4a8b..2b2dae3f 100644 --- a/src/dewret/tasks.py +++ b/src/dewret/tasks.py @@ -35,31 +35,39 @@ from functools import cached_property from collections.abc import Callable from typing import Any, ParamSpec, TypeVar, cast +from types import TracebackType from attrs import has as attrs_has from dataclasses import is_dataclass +import traceback -from .utils import is_raw +from .utils import is_raw, make_traceback from .workflow import ( StepReference, ParameterReference, Workflow, Lazy, + LazyEvaluation, Target, LazyFactory, StepExecution, merge_workflows, Parameter, param, - is_task + Task, + is_task, ) from .backends._base import BackendModule + class Backend(Enum): """Stringy enum representing available backends.""" + DASK = "dask" + DEFAULT_BACKEND = Backend.DASK + class TaskManager: """Overarching backend-agnostic task manager. 
@@ -106,7 +114,9 @@ def backend(self) -> BackendModule: if backend is None: backend = self.set_backend(DEFAULT_BACKEND) - backend_mod = importlib.import_module(f".backends.backend_{backend.value}", "dewret") + backend_mod = importlib.import_module( + f".backends.backend_{backend.value}", "dewret" + ) return backend_mod def make_lazy(self) -> LazyFactory: @@ -158,11 +168,13 @@ def ensure_lazy(self, task: Any) -> Lazy | None: Returns: Original task, cast to a Lazy, or None. """ - if (task := self.ensure_lazy(task)) is None: - raise RuntimeError(f"Task passed to be evaluated, must be lazy-evaluatable, not {type(task)}.") - return cast(task, Lazy) if self.backend.is_lazy(task) else None + if isinstance(task, LazyEvaluation): + return self.ensure_lazy(task._fn) + return task if self.backend.is_lazy(task) else None - def __call__(self, task: Any, simplify_ids: bool = False, **kwargs: Any) -> Workflow: + def __call__( + self, task: Any, simplify_ids: bool = False, **kwargs: Any + ) -> Workflow: """Execute the lazy evalution. Arguments: @@ -177,6 +189,7 @@ def __call__(self, task: Any, simplify_ids: bool = False, **kwargs: Any) -> Work result = self.evaluate(task, workflow, **kwargs) return Workflow.from_result(result, simplify_ids=simplify_ids) + _manager = TaskManager() lazy = _manager.make_lazy ensure_lazy = _manager.ensure_lazy @@ -184,6 +197,44 @@ def __call__(self, task: Any, simplify_ids: bool = False, **kwargs: Any) -> Work evaluate = _manager.evaluate construct = _manager + +class TaskException(Exception): + """Exception tied to a specific task. + + Primarily aimed at parsing issues, but this will ensure that + a message is shown with useful debug information for the + workflow writer. + """ + + def __init__( + self, + task: Task | Target, + dec_tb: TracebackType | None, + tb: TracebackType | None, + message: str, + *args: Any, + **kwargs: Any, + ): + """Create a TaskException for this exception. + + Args: + task: the Task causing the exception. + dec_tb: a traceback of the task declaration. + tb: a traceback of the original task call. + message: a message to show to the user. + *args: any other arguments accepted by Exception. + **kwargs: any other arguments accepted by Exception. + """ + if dec_tb: + frame = traceback.extract_tb(dec_tb)[-1] + self.add_note( + f"Task {task.__name__} declared in {frame.name} at {frame.filename}:{frame.lineno}\n" + f"{frame.line}" + ) + super().__init__(message) + self.__traceback__ = tb + + def nested_task() -> Callable[[Target], StepExecution]: """Shortcut for marking a task as nested. @@ -211,9 +262,14 @@ def nested_task() -> Callable[[Target], StepExecution]: """ return task(nested=True) + Param = ParamSpec("Param") RetType = TypeVar("RetType") -def task(nested: bool = False) -> Callable[[Callable[Param, RetType]], Callable[Param, RetType]]: + + +def task( + nested: bool = False, +) -> Callable[[Callable[Param, RetType]], Callable[Param, RetType]]: """Decorator factory abstracting backend's own task decorator. For example: @@ -244,56 +300,114 @@ def task(nested: bool = False) -> Callable[[Callable[Param, RetType]], Callable[ """ def _task(fn: Callable[Param, RetType]) -> Callable[Param, RetType]: - def _fn(*args: Any, __workflow__: Workflow | None = None, **kwargs: Param.kwargs) -> RetType: - # By marking any as the positional results list, we prevent unnamed results being - # passed at all. - if args: - raise TypeError( - f"Calling {fn.__name__}: Arguments must _always_ be named, e.g. 
my_task(num=1) not my_task(1)"
-                )
-
-            # Ensure that the passed arguments are, at least, a Python-match for the signature.
-            sig = inspect.signature(fn)
-            sig.bind(*args, **kwargs)
-
-            workflows = [
-                reference.__workflow__
-                for reference in kwargs.values()
-                if hasattr(reference, "__workflow__") and reference.__workflow__ is not None
-            ]
-            if __workflow__ is not None:
-                workflows.insert(0, __workflow__)
-            if workflows:
-                workflow = merge_workflows(*workflows)
-            else:
-                workflow = Workflow()
-            original_kwargs = dict(kwargs)
-            for var, value in inspect.getclosurevars(fn).globals.items():
-                if var in kwargs:
-                    raise TypeError("Captured parameter (global variable in task) shadows an argument")
-                if isinstance(value, Parameter):
-                    kwargs[var] = ParameterReference(workflow, value)
-                elif is_raw(value):
-                    parameter = param(var, value)
-                    kwargs[var] = ParameterReference(workflow, parameter)
-                elif is_task(value):
-                    if not nested:
-                        raise TypeError("You reference a task inside another task, but it is not a nested_task - this will not be found!")
-                elif attrs_has(value) or is_dataclass(value):
-                    ...
-                elif nested:
-                    raise NotImplementedError(f"Nested tasks must now only refer to global parameters, raw or tasks, not objects: {var}")
-            if nested:
-                lazy_fn = cast(Lazy, fn(**original_kwargs))
-                step_reference = evaluate(lazy_fn, __workflow__=workflow)
-                if isinstance(step_reference, StepReference):
-                    return cast(RetType, step_reference)
-                raise TypeError("Nested tasks must return a step reference, to ensure graph makes sense.")
-            return cast(RetType, workflow.add_step(fn, kwargs))
-        _fn.__step_expression__ = True  # type: ignore
-        return lazy()(_fn)
+        declaration_tb = make_traceback()
+
+        def _fn(
+            *args: Any,
+            __workflow__: Workflow | None = None,
+            __traceback__: TracebackType | None = None,
+            **kwargs: Param.kwargs,
+        ) -> RetType:
+            try:
+                # Ensure that all arguments are passed as keyword args;
+                # positional arguments are not permitted at all.
+                if args:
+                    raise TypeError(
+                        f"""
+                        Calling {fn.__name__}: Arguments must _always_ be named,
+                        e.g. my_task(num=1) not my_task(1)
+
+                        @task()
+                        def add_numbers(left: int, right: int):
+                            return left + right
+
+                        construct(add_numbers(left=3, right=5))
+                        """
+                    )
+
+                # Ensure that the passed arguments are, at least, a Python-match for the signature.
+                sig = inspect.signature(fn)
+                sig.bind(*args, **kwargs)
+
+                workflows = [
+                    reference.__workflow__
+                    for reference in kwargs.values()
+                    if hasattr(reference, "__workflow__")
+                    and reference.__workflow__ is not None
+                ]
+                if __workflow__ is not None:
+                    workflows.insert(0, __workflow__)
+                if workflows:
+                    workflow = merge_workflows(*workflows)
+                else:
+                    workflow = Workflow()
+                original_kwargs = dict(kwargs)
+                for var, value in inspect.getclosurevars(fn).globals.items():
+                    # This check is redundant, as the shadowing triggers a
+                    # SyntaxError in Python (a test duplicates that syntax error).
+                    # if var in kwargs:
+                    #     raise TypeError(
+                    #         "Captured parameter {var} (global variable in task) shadows an argument"
+                    #     )
+                    if isinstance(value, Parameter):
+                        kwargs[var] = ParameterReference(workflow, value)
+                    elif is_raw(value):
+                        parameter = param(var, value)
+                        kwargs[var] = ParameterReference(workflow, parameter)
+                    elif is_task(value):
+                        if not nested:
+                            raise TypeError(
+                                f"""
+                                You referenced a task {var} inside another task {fn.__name__}, but it is not a nested_task
+                                - this will not be found!
+
+                                @task()
+                                def {var}(...) -> ...:
+                                    ...
+
+                                @nested_task()  <<<--- likely what you want
+                                def {fn.__name__}(...) -> ...:
+                                    ...
+                                    {var}(...)
+                                    ...
+                                """
+                            )
+                    elif attrs_has(value) or is_dataclass(value):
+                        ...
+                    elif nested:
+                        raise NotImplementedError(
+                            f"Nested tasks must now only refer to global parameters, raw or tasks, not objects: {var}"
+                        )
+                if nested:
+                    output = fn(**original_kwargs)
+                    lazy_fn = ensure_lazy(output)
+                    if lazy_fn is None:
+                        raise TypeError(
+                            f"Task {fn.__name__} returned output of type {type(output)}, which is not a lazy function for this backend."
+                        )
+                    step_reference = evaluate(lazy_fn, __workflow__=workflow)
+                    if isinstance(step_reference, StepReference):
+                        return cast(RetType, step_reference)
+                    raise TypeError(
+                        f"Nested tasks must return a step reference, not {type(step_reference)}, to ensure the graph makes sense."
+                    )
+                return cast(RetType, workflow.add_step(fn, kwargs))
+            except TaskException as exc:
+                raise exc
+            except Exception as exc:
+                raise TaskException(
+                    fn,
+                    declaration_tb,
+                    __traceback__,
+                    exc.args[0] if exc.args else f"Could not call task {fn.__name__}",
+                ) from exc
+
+        _fn.__step_expression__ = True  # type: ignore
+        return LazyEvaluation(lazy()(_fn))
+
     return _task
 
+
 def set_backend(backend: Backend) -> None:
     """Choose a backend.
diff --git a/src/dewret/utils.py b/src/dewret/utils.py
index 74d272e7..cebb14bc 100644
--- a/src/dewret/utils.py
+++ b/src/dewret/utils.py
@@ -19,6 +19,8 @@
 
 import hashlib
 import json
+import sys
+from types import FrameType, TracebackType
 from typing import Any, cast, Union, Protocol, ClassVar
 from collections.abc import Sequence, Mapping
 
@@ -26,15 +28,34 @@
 RawType = Union[BasicType, list["RawType"], dict[str, "RawType"]]
 FirmType = BasicType | list["FirmType"] | dict[str, "FirmType"] | tuple["FirmType", ...]
 
+
 class DataclassProtocol(Protocol):
     """Format of a dataclass.
 
     Since dataclasses do not expose a proper type, we use this to represent them.
     """
+
     __dataclass_fields__: ClassVar[dict[str, Any]]
 
 
+def make_traceback(skip: int = 2) -> TracebackType | None:
+    """Creates a traceback for the current frame.
+
+    Necessary to allow tracebacks to be prepped for
+    potential errors in lazy-evaluated functions.
+
+    Args:
+        skip: number of frames to skip before starting traceback.
+    """
+    frame: FrameType | None = sys._getframe(skip)
+    tb = None
+    # Walk outwards through the calling frames, wrapping each in a
+    # synthetic traceback so the chain reads caller-to-callee.
+    while frame:
+        tb = TracebackType(tb, frame, frame.f_lasti, frame.f_lineno)
+        frame = frame.f_back
+    return tb
+
+
 def flatten(value: Any) -> RawType:
     """Takes a Raw-like structure and makes it RawType.
 
@@ -68,6 +89,7 @@ def is_raw(value: Any) -> bool:
     # but recursive types are problematic.
     return isinstance(value, str | float | bool | bytes | int | None | list | dict)
 
+
 def ensure_raw(value: Any) -> RawType | None:
     """Check if a variable counts as "raw".
 
@@ -79,6 +101,7 @@ def ensure_raw(value: Any) -> RawType | None:
     # isinstance(var, RawType | list[RawType] | dict[str, RawType])
     return cast(RawType, value) if is_raw(value) else None
 
+
 def hasher(construct: FirmType) -> str:
     """Consistently hash a RawType or tuple structure. 
@@ -102,4 +125,3 @@ def hasher(construct: FirmType) -> str:
     hsh = hashlib.md5()
     hsh.update(construct_as_string.encode())
     return hsh.hexdigest()
-
diff --git a/src/dewret/workflow.py b/src/dewret/workflow.py
index 48df4010..ffee8fea 100644
--- a/src/dewret/workflow.py
+++ b/src/dewret/workflow.py
@@ -25,13 +25,17 @@
 from dataclasses import is_dataclass, fields as dataclass_fields
 from collections import Counter
 from typing import Protocol, Any, TypeVar, Generic, cast
+
 from sympy import Symbol
 import logging
 
 logger = logging.getLogger(__name__)
 
+from .utils import hasher, RawType, is_raw, make_traceback
+
+T = TypeVar("T")
+RetType = TypeVar("RetType")
 
-from .utils import hasher, RawType, is_raw
 
 @define
 class Raw:
@@ -43,6 +47,7 @@ class Raw:
     Attributes:
         value: the real value, e.g. a `str`, `int`, ...
     """
+
     value: RawType
 
     def __hash__(self) -> int:
@@ -58,21 +63,49 @@ def __repr__(self) -> str:
             value = str(self.value)
         return f"{type(self.value).__name__}|{value}"
 
+
 class Lazy(Protocol):
     """Requirements for a lazy-evaluatable function."""
+
     __name__: str
 
     def __call__(self, *args: Any, **kwargs: Any) -> Any:
         """When called this should return a reference."""
         ...
 
+
+class LazyEvaluation(Lazy, Generic[RetType]):
+    """Tracks a single evaluation of a lazy function."""
+
+    def __init__(self, fn: Callable[..., RetType]):
+        """Initialize an evaluation.
+
+        Args:
+            fn: callable returning RetType, which this wrapper will also
+                return from its __call__ method, for consistency.
+        """
+        self._fn: Callable[..., RetType] = fn
+        self.__name__ = fn.__name__
+
+    def __call__(self, *args: Any, **kwargs: Any) -> RetType:
+        """Wrapper around a lazy execution.
+
+        Captures a traceback, for debugging if the lazy call fails.
+
+        WARNING: this is one of the few places that we would expect
+        dask distributed to break, if running outside a single process
+        is attempted.
+        """
+        tb = make_traceback()
+        return self._fn(*args, **kwargs, __traceback__=tb)
+
+
 Target = Callable[..., Any]
 StepExecution = Callable[..., Lazy]
 LazyFactory = Callable[[Target], Lazy]
 
-T = TypeVar("T")
-class Parameter(Generic[T], Symbol): # type: ignore[no-untyped-call]
+
+class Parameter(Generic[T], Symbol):  # type: ignore[no-untyped-call]
     """Global parameter.
 
     Independent parameter that will be used when a task is spotted
@@ -85,6 +118,7 @@ class Parameter(Generic[T], Symbol): # type: ignore[no-untyped-call]
         __name__: name of the parameter.
         __default__: captured default value from the original value.
     """
+
     __name__: str
     __default__: T
 
@@ -99,6 +133,7 @@ def __init__(self, name: str, default: T):
         self.__name__ = name
         self.__default__ = default
 
+
 def param(name: str, default: T) -> T:
     """Create a parameter.
 
@@ -109,6 +144,7 @@ def param(name: str, default: T) -> T:
     """
     return cast(T, Parameter(name, default=default))
 
+
 class Task:
     """Named wrapper of a lazy-evaluatable function.
 
@@ -133,6 +169,11 @@ def __init__(self, name: str, target: Lazy):
         self.name = name
         self.target = target
 
+    @property
+    def __name__(self) -> str:
+        """Name of the task."""
+        return self.name
+
     def __str__(self) -> str:
         """Stringify the Task, currently by returning the `name`."""
         return self.name
@@ -153,10 +194,8 @@ def __eq__(self, other: object) -> bool:
         """
         if not isinstance(other, Task):
             return False
-        return (
-            self.name == other.name and
-            self.target == other.target
-        )
+        return self.name == other.name and self.target == other.target
+
 
 class Workflow:
     """Overarching workflow concept.
 
@@ -172,6 +211,7 @@ class Workflow:
         `Task` wrappers they represent. 
result: target reference to evaluate, if yet present. """ + steps: list["Step"] tasks: MutableMapping[str, "Task"] result: StepReference[Any] | None @@ -193,10 +233,16 @@ def find_parameters(self) -> set[ParameterReference]: Returns: Set of all references to parameters across the steps. """ - return set().union(*({ - arg for arg in step.arguments.values() - if isinstance(arg, ParameterReference) - } for step in self.steps)) + return set().union( + *( + { + arg + for arg in step.arguments.values() + if isinstance(arg, ParameterReference) + } + for step in self.steps + ) + ) @property def _indexed_steps(self) -> dict[str, Step]: @@ -209,9 +255,7 @@ def _indexed_steps(self) -> dict[str, Step]: Returns: Mapping of steps by ID. """ - return { - step.id: step for step in self.steps - } + return {step.id: step for step in self.steps} @classmethod def assimilate(cls, left: Workflow, right: Workflow) -> "Workflow": @@ -231,13 +275,15 @@ def assimilate(cls, left: Workflow, right: Workflow) -> "Workflow": left_steps = left._indexed_steps right_steps = right._indexed_steps - for step_id in (left_steps.keys() & right_steps.keys()): + for step_id in left_steps.keys() & right_steps.keys(): left_steps[step_id].set_workflow(new) right_steps[step_id].set_workflow(new) if left_steps[step_id] != right_steps[step_id]: - raise RuntimeError(f"Two steps have same ID but do not match: {step_id}") + raise RuntimeError( + f"Two steps have same ID but do not match: {step_id}" + ) - for task_id in (left.tasks.keys() & right.tasks.keys()): + for task_id in left.tasks.keys() & right.tasks.keys(): if left.tasks[task_id] != right.tasks[task_id]: raise RuntimeError("Two tasks have same name but do not match") @@ -261,11 +307,7 @@ def remap(self, step_id: str) -> str: Returns: Same ID or a remapped name. """ - return ( - self._remapping.get(step_id, step_id) - if self._remapping else - step_id - ) + return self._remapping.get(step_id, step_id) if self._remapping else step_id def simplify_ids(self) -> None: """Work out mapping to simple ints from hashes. @@ -296,7 +338,9 @@ def register_task(self, fn: Lazy) -> Task: self.tasks[name] = task return task - def add_step(self, fn: Lazy, kwargs: dict[str, Raw | Reference]) -> StepReference[Any]: + def add_step( + self, fn: Lazy, kwargs: dict[str, Raw | Reference] + ) -> StepReference[Any]: """Append a step. Adds a step, for running a target with key-value arguments, @@ -307,11 +351,7 @@ def add_step(self, fn: Lazy, kwargs: dict[str, Raw | Reference]) -> StepReferenc kwargs: any key-value arguments to pass in the call. """ task = self.register_task(fn) - step = Step( - self, - task, - kwargs - ) + step = Step(self, task, kwargs) self.steps.append(step) return_type = step.return_type if return_type is inspect._empty: @@ -353,6 +393,7 @@ class WorkflowComponent: Attributes: __workflow__: the `Workflow` that this is tied to. """ + __workflow__: Workflow def __init__(self, workflow: Workflow): @@ -365,6 +406,7 @@ def __init__(self, workflow: Workflow): """ self.__workflow__ = workflow + class WorkflowLinkedComponent(Protocol): """Protocol for objects dynamically tied to a `Workflow`.""" @@ -380,6 +422,7 @@ def __workflow__(self) -> Workflow: """ ... + class Reference: """Superclass for all symbolic references to values.""" @@ -388,6 +431,7 @@ def name(self) -> str: """Referral name for this reference.""" raise NotImplementedError("Reference must provide a name") + class Step(WorkflowComponent): """Lazy-evaluated function call. 
@@ -398,11 +442,14 @@ class Step(WorkflowComponent): task: the `Task` being called in this step. arguments: key-value pairs of arguments to this step. """ + _id: str | None = None task: Task arguments: Mapping[str, Reference | Raw] - def __init__(self, workflow: Workflow, task: Task, arguments: Mapping[str, Reference | Raw]): + def __init__( + self, workflow: Workflow, task: Task, arguments: Mapping[str, Reference | Raw] + ): """Initialize a step. Args: @@ -414,17 +461,19 @@ def __init__(self, workflow: Workflow, task: Task, arguments: Mapping[str, Refer self.task = task self.arguments = {} for key, value in arguments.items(): - if ( - isinstance(value, Reference) or - isinstance(value, Raw) or - is_raw(value) - ): + if isinstance(value, Reference) or isinstance(value, Raw) or is_raw(value): # Avoid recursive type issues - if not isinstance(value, Reference) and not isinstance(value, Raw) and is_raw(value): + if ( + not isinstance(value, Reference) + and not isinstance(value, Raw) + and is_raw(value) + ): value = Raw(value) self.arguments[key] = value else: - raise RuntimeError(f"Non-references must be a serializable type: {key}>{value}") + raise RuntimeError( + f"Non-references must be a serializable type: {key}>{value}" + ) def __eq__(self, other: object) -> bool: """Is this the same step? @@ -435,9 +484,9 @@ def __eq__(self, other: object) -> bool: if not isinstance(other, Step): return False return ( - self.__workflow__ == other.__workflow__ and - self.task == other.task and - self.arguments == other.arguments + self.__workflow__ == other.__workflow__ + and self.task == other.task + and self.arguments == other.arguments ) def set_workflow(self, workflow: Workflow) -> None: @@ -487,7 +536,9 @@ def id(self) -> str: check_id = self._generate_id() if check_id != self._id: - raise RuntimeError(f"Cannot change a step after requesting its ID: {self.task}") + raise RuntimeError( + f"Cannot change a step after requesting its ID: {self.task}" + ) return self._id def _generate_id(self) -> str: @@ -500,6 +551,7 @@ def _generate_id(self) -> str: return f"{self.task}-{hasher(comp_tup)}" + class ParameterReference(Reference): """Reference to an individual `Parameter`. @@ -509,6 +561,7 @@ class ParameterReference(Reference): Attributes: parameter: `Parameter` referred to. """ + parameter: Parameter[RawType] workflow: Workflow @@ -569,11 +622,13 @@ def __eq__(self, other: object) -> bool: True if the other parameter reference is materially the same, otherwise False. """ return ( - isinstance(other, ParameterReference) and - self.parameter == other.parameter + isinstance(other, ParameterReference) and self.parameter == other.parameter ) + U = TypeVar("U") + + class StepReference(Generic[U], Reference): """Reference to an individual `Step`. @@ -583,6 +638,7 @@ class StepReference(Generic[U], Reference): Attributes: step: `Step` referred to. """ + step: Step _field: str | None typ: type[U] @@ -598,7 +654,9 @@ def field(self) -> str: """ return self._field or "out" - def __init__(self, workflow: Workflow, step: Step, typ: type[U], field: str | None = None): + def __init__( + self, workflow: Workflow, step: Step, typ: type[U], field: str | None = None + ): """Initialize the reference. 
Args: @@ -640,7 +698,9 @@ def __getattr__(self, attr: str) -> "StepReference"[Any]: resolve_types(self.typ) typ = getattr(attrs_fields(self.typ), attr).type elif is_dataclass(self.typ): - matched = [field for field in dataclass_fields(self.typ) if field.name == attr] + matched = [ + field for field in dataclass_fields(self.typ) if field.name == attr + ] if not matched: raise AttributeError(f"Field {attr} not present in dataclass") typ = matched[0].type @@ -649,12 +709,11 @@ def __getattr__(self, attr: str) -> "StepReference"[Any]: if typ: return self.__class__( - workflow=self.__workflow__, - step=self.step, - typ=typ, - field=attr + workflow=self.__workflow__, step=self.step, typ=typ, field=attr ) - raise RuntimeError("Can only get attribute of a StepReference representing an attrs-class or dataclass") + raise RuntimeError( + "Can only get attribute of a StepReference representing an attrs-class or dataclass" + ) @property def return_type(self) -> type[U]: @@ -683,6 +742,7 @@ def __workflow__(self) -> Workflow: """ return self.step.__workflow__ + def merge_workflows(*workflows: Workflow) -> Workflow: """Combine several workflows into one. @@ -699,6 +759,7 @@ def merge_workflows(*workflows: Workflow) -> Workflow: base = Workflow.assimilate(base, workflow) return base + def is_task(task: Lazy) -> bool: """Decide whether this is a task. @@ -712,11 +773,4 @@ def is_task(task: Lazy) -> bool: Returns: True if `task` is indeed a task. """ - from .tasks import unwrap - try: - func = unwrap(task) - if hasattr(func, "__step_expression__"): - return bool(func.__step_expression__) - except Exception: - ... - return False + return isinstance(task, LazyEvaluation) diff --git a/tests/test_errors.py b/tests/test_errors.py index bbfd29a2..3bbf3945 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -1,17 +1,51 @@ """Test for expected errors.""" import pytest -from dewret.tasks import construct, task, nested_task +from dewret.workflow import Task, Lazy +from dewret.tasks import construct, task, nested_task, TaskException -@task() + +@task() # This is expected to be the line number shown below. def add_task(left: int, right: int) -> int: """Adds two values and returns the result.""" return left + right + +ADD_TASK_LINE_NO = 8 + + @nested_task() def badly_add_task(left: int, right: int) -> int: """Badly attempts to add two numbers.""" - return add_task(left=left) # type: ignore + return add_task(left=left) # type: ignore + + +@task() +def badly_wrap_task() -> int: + """Sums two values but should not be calling a task.""" + return add_task(left=3, right=4) + + +class MyStrangeClass: + """Dummy class for tests.""" + + def __init__(self, task: Task): + """Dummy constructor for tests.""" + ... + + +@nested_task() +def unacceptable_object_usage() -> int: + """Invalid use of custom object within nested task.""" + return MyStrangeClass(add_task(left=3, right=4)) # type: ignore + + +@nested_task() +def unacceptable_nested_return(int_not_global: bool) -> int | Lazy: + """Bad nested_task that fails to return a task.""" + add_task(left=3, right=4) + return 7 if int_not_global else ADD_TASK_LINE_NO + def test_missing_arguments_throw_error() -> None: """Check whether omitting a required argument will give an error. @@ -22,13 +56,17 @@ def test_missing_arguments_throw_error() -> None: WARNING: in keeping with Python principles, this does not error if types mismatch, but `mypy` should. You **must** type-check your code to catch these. 
""" - result = add_task(left=3) # type: ignore - with pytest.raises(TypeError) as exc: + result = add_task(left=3) # type: ignore + with pytest.raises(TaskException) as exc: construct(result) + end_section = str(exc.getrepr())[-500:] assert str(exc.value) == "missing a required argument: 'right'" + assert "Task add_task declared in at " in end_section + assert f"test_errors.py:{ADD_TASK_LINE_NO}" in end_section + def test_missing_arguments_throw_error_in_nested_task() -> None: - """Check whether omitting a required argument will give an error. + """Check whether omitting a required argument within a nested_task will give an error. Since we do not run the original function, it is up to dewret to check that the signature is, at least, acceptable to Python. @@ -37,17 +75,87 @@ def test_missing_arguments_throw_error_in_nested_task() -> None: mismatch, but `mypy` should. You **must** type-check your code to catch these. """ result = badly_add_task(left=3, right=4) - with pytest.raises(TypeError) as exc: + with pytest.raises(TaskException) as exc: construct(result) + end_section = str(exc.getrepr())[-500:] assert str(exc.value) == "missing a required argument: 'right'" + assert "def badly_add_task" in end_section + assert "Task add_task declared in at " in end_section + assert f"test_errors.py:{ADD_TASK_LINE_NO}" in end_section + def test_positional_arguments_throw_error() -> None: - """Check whether we can produce simple CWL. + """Check whether unnamed (positional) arguments throw an error. We can use default and non-default arguments, but we expect them to _always_ be named. """ result = add_task(3, right=4) - with pytest.raises(TypeError) as exc: + with pytest.raises(TaskException) as exc: + construct(result) + assert ( + str(exc.value) + .strip() + .startswith("Calling add_task: Arguments must _always_ be named") + ) + + +def test_nesting_non_nested_tasks_throws_error() -> None: + """Ensure nesting is only allow in nested_tasks. + + Nested tasks must be evaluated at construction time, and there + is no concept of task calls that are not resolved during construction, so + a task should not be called inside a non-nested task. + """ + result = badly_wrap_task() + with pytest.raises(TaskException) as exc: + construct(result) + assert ( + str(exc.value) + .strip() + .startswith( + "You referenced a task add_task inside another task badly_wrap_task, but it is not a nested_task" + ) + ) + + +def test_normal_objects_cannot_be_used_in_nested_tasks() -> None: + """Most entities cannot appear in a nested_task, ensure we catch them. + + Since the logic in nested tasks has to be embedded explicitly in the workflow, + complex types are not necessarily representable, and in most cases, we would not + be able to guarantee that the libraries, versions, etc. match. + + Note: this may be mitigated with sympy support, to some extent. + """ + result = unacceptable_object_usage() + with pytest.raises(TaskException) as exc: + construct(result) + assert ( + str(exc.value) + == "Nested tasks must now only refer to global parameters, raw or tasks, not objects: MyStrangeClass" + ) + + +def test_nested_tasks_must_return_a_task() -> None: + """Ensure nested tasks are lazy-evaluatable. + + A graph only makes sense if the edges connect, and nested tasks must therefore chain. + As such, a nested task must represent a real subgraph, and return a node to pull it into + the main graph. 
+    """
+    result = unacceptable_nested_return(int_not_global=True)
+    with pytest.raises(TaskException) as exc:
+        construct(result)
+    assert (
+        str(exc.value)
+        == "Task unacceptable_nested_return returned output of type <class 'int'>, which is not a lazy function for this backend."
+    )
+
+    result = unacceptable_nested_return(int_not_global=False)
+    with pytest.raises(TaskException) as exc:
         construct(result)
-    assert str(exc.value) == "Calling add_task: Arguments must _always_ be named, e.g. my_task(num=1) not my_task(1)"
+    assert (
+        str(exc.value)
+        == "Task unacceptable_nested_return returned output of type <class 'int'>, which is not a lazy function for this backend."
+    )