From 6bea9b08d9b9fc64c8559060725cfb49543aba8a Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Mon, 15 Apr 2024 18:19:40 +0100 Subject: [PATCH 01/11] feat(errors): manipulate traceback to show error at original call of lazy task --- .pre-commit-config.yaml | 3 +- example/extra.py | 5 + example/workflow_complex.py | 8 +- src/dewret/__main__.py | 41 ++++++-- src/dewret/tasks.py | 191 ++++++++++++++++++++++++++---------- src/dewret/utils.py | 24 ++++- src/dewret/workflow.py | 168 ++++++++++++++++++++----------- tests/test_errors.py | 34 +++++-- 8 files changed, 342 insertions(+), 132 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 88110b29..0123bf7c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -18,4 +18,5 @@ repos: rev: v1.9.0 hooks: - id: mypy - args: [--strict, --ignore-missing-imports] + args: [--strict, --install-types, --non-interactive] + additional_dependencies: [sympy, attrs, pytest, click] diff --git a/example/extra.py b/example/extra.py index 2ef1744c..053b2f66 100644 --- a/example/extra.py +++ b/example/extra.py @@ -7,26 +7,31 @@ JUMP: int = 10 + @task() def increase(num: int) -> int: """Add globally-configured integer JUMP to a number.""" return num + JUMP + @task() def increment(num: int) -> int: """Increment an integer.""" return num + 1 + @task() def double(num: int) -> int: """Double an integer.""" return 2 * num + @task() def mod10(num: int) -> int: """Calculate supplied integer modulo 10.""" return num % 10 + @task() def sum(left: int, right: int) -> int: """Add two integers.""" diff --git a/example/workflow_complex.py b/example/workflow_complex.py index 71762518..4f9f0182 100644 --- a/example/workflow_complex.py +++ b/example/workflow_complex.py @@ -3,7 +3,7 @@ Useful as an example of a workflow with a nested task. 
```sh -$ python -m dewret workflow_complex.py --pretty run +$ python -m dewret workflow_complex.py --pretty nested_workflow ``` """ @@ -12,12 +12,10 @@ STARTING_NUMBER: int = 23 + @nested_task() def nested_workflow() -> int | float: """Creates a graph of task calls.""" left = double(num=increase(num=STARTING_NUMBER)) right = increase(num=increase(num=17)) - return sum( - left=left, - right=right - ) + return sum(left=left, right=right) diff --git a/src/dewret/__main__.py b/src/dewret/__main__.py index 04f54e83..059ec406 100644 --- a/src/dewret/__main__.py +++ b/src/dewret/__main__.py @@ -29,13 +29,28 @@ from .renderers.cwl import render as cwl_render from .tasks import Backend, construct + @click.command() -@click.option("--pretty", is_flag=True, show_default=True, default=False, help="Pretty-print output where possible.") -@click.option("--backend", type=click.Choice(list(Backend.__members__)), show_default=True, default=Backend.DASK.name, help="Backend to use for workflow evaluation.") +@click.option( + "--pretty", + is_flag=True, + show_default=True, + default=False, + help="Pretty-print output where possible.", +) +@click.option( + "--backend", + type=click.Choice(list(Backend.__members__)), + show_default=True, + default=Backend.DASK.name, + help="Backend to use for workflow evaluation.", +) @click.argument("workflow_py") @click.argument("task") @click.argument("arguments", nargs=-1) -def render(workflow_py: str, task: str, arguments: list[str], pretty: bool, backend: Backend) -> None: +def render( + workflow_py: str, task: str, arguments: list[str], pretty: bool, backend: Backend +) -> None: """Render a workflow. WORKFLOW_PY is the Python file containing workflow. 
@@ -49,14 +64,24 @@ def render(workflow_py: str, task: str, arguments: list[str], pretty: bool, back kwargs = {} for arg in arguments: if ":" not in arg: - raise RuntimeError("Arguments should be specified as key:val, where val is a JSON representation of the argument") + raise RuntimeError( + "Arguments should be specified as key:val, where val is a JSON representation of the argument" + ) key, val = arg.split(":", 1) kwargs[key] = json.loads(val) - cwl = cwl_render(construct(task_fn(**kwargs), simplify_ids=True)) - if pretty: - yaml.dump(cwl, sys.stdout, indent=2) + try: + cwl = cwl_render(construct(task_fn(**kwargs), simplify_ids=True)) + except Exception as exc: + import traceback + + print(exc, exc.__cause__, exc.__context__) + traceback.print_exc() else: - print(cwl) + if pretty: + yaml.dump(cwl, sys.stdout, indent=2) + else: + print(cwl) + render() diff --git a/src/dewret/tasks.py b/src/dewret/tasks.py index 1c9e4a8b..d9680a74 100644 --- a/src/dewret/tasks.py +++ b/src/dewret/tasks.py @@ -35,31 +35,39 @@ from functools import cached_property from collections.abc import Callable from typing import Any, ParamSpec, TypeVar, cast +from types import TracebackType from attrs import has as attrs_has from dataclasses import is_dataclass +import traceback -from .utils import is_raw +from .utils import is_raw, make_traceback from .workflow import ( StepReference, ParameterReference, Workflow, Lazy, + LazyEvaluation, Target, LazyFactory, StepExecution, merge_workflows, Parameter, param, - is_task + Task, + is_task, ) from .backends._base import BackendModule + class Backend(Enum): """Stringy enum representing available backends.""" + DASK = "dask" + DEFAULT_BACKEND = Backend.DASK + class TaskManager: """Overarching backend-agnostic task manager. 
@@ -106,7 +114,9 @@ def backend(self) -> BackendModule: if backend is None: backend = self.set_backend(DEFAULT_BACKEND) - backend_mod = importlib.import_module(f".backends.backend_{backend.value}", "dewret") + backend_mod = importlib.import_module( + f".backends.backend_{backend.value}", "dewret" + ) return backend_mod def make_lazy(self) -> LazyFactory: @@ -159,10 +169,14 @@ def ensure_lazy(self, task: Any) -> Lazy | None: Original task, cast to a Lazy, or None. """ if (task := self.ensure_lazy(task)) is None: - raise RuntimeError(f"Task passed to be evaluated, must be lazy-evaluatable, not {type(task)}.") + raise RuntimeError( + f"Task passed to be evaluated, must be lazy-evaluatable, not {type(task)}." + ) return cast(task, Lazy) if self.backend.is_lazy(task) else None - def __call__(self, task: Any, simplify_ids: bool = False, **kwargs: Any) -> Workflow: + def __call__( + self, task: Any, simplify_ids: bool = False, **kwargs: Any + ) -> Workflow: """Execute the lazy evalution. Arguments: @@ -177,6 +191,7 @@ def __call__(self, task: Any, simplify_ids: bool = False, **kwargs: Any) -> Work result = self.evaluate(task, workflow, **kwargs) return Workflow.from_result(result, simplify_ids=simplify_ids) + _manager = TaskManager() lazy = _manager.make_lazy ensure_lazy = _manager.ensure_lazy @@ -184,6 +199,44 @@ def __call__(self, task: Any, simplify_ids: bool = False, **kwargs: Any) -> Work evaluate = _manager.evaluate construct = _manager + +class TaskException(Exception): + """Exception tied to a specific task. + + Primarily aimed at parsing issues, but this will ensure that + a message is shown with useful debug information for the + workflow writer. + """ + + def __init__( + self, + task: Task | Target, + dec_tb: TracebackType | None, + tb: TracebackType | None, + message: str, + *args: Any, + **kwargs: Any, + ): + """Create a TaskException for this exception. + + Args: + task: the Task causing the exception. + dec_tb: a traceback of the task declaration. 
+ tb: a traceback of the original task call. + message: a message to show to the user. + *args: any other arguments accepted by Exception. + **kwargs: any other arguments accepted by Exception. + """ + if dec_tb: + frame = traceback.extract_tb(dec_tb)[-1] + self.add_note( + f"Task {task.__name__} declared in {frame.name} at {frame.filename}:{frame.lineno}\n" + f"{frame.line}" + ) + super().__init__(message) + self.__traceback__ = tb + + def nested_task() -> Callable[[Target], StepExecution]: """Shortcut for marking a task as nested. @@ -211,9 +264,14 @@ def nested_task() -> Callable[[Target], StepExecution]: """ return task(nested=True) + Param = ParamSpec("Param") RetType = TypeVar("RetType") -def task(nested: bool = False) -> Callable[[Callable[Param, RetType]], Callable[Param, RetType]]: + + +def task( + nested: bool = False, +) -> Callable[[Callable[Param, RetType]], Callable[Param, RetType]]: """Decorator factory abstracting backend's own task decorator. For example: @@ -244,56 +302,85 @@ def task(nested: bool = False) -> Callable[[Callable[Param, RetType]], Callable[ """ def _task(fn: Callable[Param, RetType]) -> Callable[Param, RetType]: - def _fn(*args: Any, __workflow__: Workflow | None = None, **kwargs: Param.kwargs) -> RetType: + declaration_tb = make_traceback() + + def _fn( + *args: Any, + __workflow__: Workflow | None = None, + __traceback__: TracebackType | None = None, + **kwargs: Param.kwargs, + ) -> RetType: # By marking any as the positional results list, we prevent unnamed results being # passed at all. - if args: - raise TypeError( - f"Calling {fn.__name__}: Arguments must _always_ be named, e.g. my_task(num=1) not my_task(1)" - ) - - # Ensure that the passed arguments are, at least, a Python-match for the signature. 
- sig = inspect.signature(fn) - sig.bind(*args, **kwargs) - - workflows = [ - reference.__workflow__ - for reference in kwargs.values() - if hasattr(reference, "__workflow__") and reference.__workflow__ is not None - ] - if __workflow__ is not None: - workflows.insert(0, __workflow__) - if workflows: - workflow = merge_workflows(*workflows) - else: - workflow = Workflow() - original_kwargs = dict(kwargs) - for var, value in inspect.getclosurevars(fn).globals.items(): - if var in kwargs: - raise TypeError("Captured parameter (global variable in task) shadows an argument") - if isinstance(value, Parameter): - kwargs[var] = ParameterReference(workflow, value) - elif is_raw(value): - parameter = param(var, value) - kwargs[var] = ParameterReference(workflow, parameter) - elif is_task(value): - if not nested: - raise TypeError("You reference a task inside another task, but it is not a nested_task - this will not be found!") - elif attrs_has(value) or is_dataclass(value): - ... - elif nested: - raise NotImplementedError(f"Nested tasks must now only refer to global parameters, raw or tasks, not objects: {var}") - if nested: - lazy_fn = cast(Lazy, fn(**original_kwargs)) - step_reference = evaluate(lazy_fn, __workflow__=workflow) - if isinstance(step_reference, StepReference): - return cast(RetType, step_reference) - raise TypeError("Nested tasks must return a step reference, to ensure graph makes sense.") - return cast(RetType, workflow.add_step(fn, kwargs)) - _fn.__step_expression__ = True # type: ignore - return lazy()(_fn) + try: + if args: + raise TypeError( + f"Calling {fn.__name__}: Arguments must _always_ be named, e.g. my_task(num=1) not my_task(1)" + ) + + # Ensure that the passed arguments are, at least, a Python-match for the signature. 
+ sig = inspect.signature(fn) + sig.bind(*args, **kwargs) + + workflows = [ + reference.__workflow__ + for reference in kwargs.values() + if hasattr(reference, "__workflow__") + and reference.__workflow__ is not None + ] + if __workflow__ is not None: + workflows.insert(0, __workflow__) + if workflows: + workflow = merge_workflows(*workflows) + else: + workflow = Workflow() + original_kwargs = dict(kwargs) + for var, value in inspect.getclosurevars(fn).globals.items(): + if var in kwargs: + raise TypeError( + "Captured parameter (global variable in task) shadows an argument" + ) + if isinstance(value, Parameter): + kwargs[var] = ParameterReference(workflow, value) + elif is_raw(value): + parameter = param(var, value) + kwargs[var] = ParameterReference(workflow, parameter) + elif is_task(value): + if not nested: + raise TypeError( + "You reference a task inside another task, but it is not a nested_task - this will not be found!" + ) + elif attrs_has(value) or is_dataclass(value): + ... + elif nested: + raise NotImplementedError( + f"Nested tasks must now only refer to global parameters, raw or tasks, not objects: {var}" + ) + if nested: + lazy_fn = cast(Lazy, fn(**original_kwargs)) + step_reference = evaluate(lazy_fn, __workflow__=workflow) + if isinstance(step_reference, StepReference): + return cast(RetType, step_reference) + raise TypeError( + "Nested tasks must return a step reference, to ensure graph makes sense." + ) + return cast(RetType, workflow.add_step(fn, kwargs)) + except TaskException as exc: + raise exc + except Exception as exc: + raise TaskException( + fn, + declaration_tb, + __traceback__, + exc.args[0] if exc.args else "Could not call task", + ) from exc + + _fn.__step_expression__ = True # type: ignore + return LazyEvaluation(lazy()(_fn)) + return _task + def set_backend(backend: Backend) -> None: """Choose a backend. 
diff --git a/src/dewret/utils.py b/src/dewret/utils.py index 74d272e7..cebb14bc 100644 --- a/src/dewret/utils.py +++ b/src/dewret/utils.py @@ -19,6 +19,8 @@ import hashlib import json +import sys +from types import FrameType, TracebackType from typing import Any, cast, Union, Protocol, ClassVar from collections.abc import Sequence, Mapping @@ -26,15 +28,34 @@ RawType = Union[BasicType, list["RawType"], dict[str, "RawType"]] FirmType = BasicType | list["FirmType"] | dict[str, "FirmType"] | tuple["FirmType", ...] + class DataclassProtocol(Protocol): """Format of a dataclass. Since dataclasses do not expose a proper type, we use this to represent them. """ + __dataclass_fields__: ClassVar[dict[str, Any]] +def make_traceback(skip: int = 2) -> TracebackType | None: + """Creates a traceback for the current frame. + + Necessary to allow tracebacks to be prepped for + potential errors in lazy-evaluated functions. + + Args: + skip: number of frames to skip before starting traceback. + """ + frame: FrameType | None = sys._getframe(skip) + tb = None + while frame: + tb = TracebackType(tb, frame, frame.f_lasti, frame.f_lineno) + frame = frame.f_back + return tb + + def flatten(value: Any) -> RawType: """Takes a Raw-like structure and makes it RawType. @@ -68,6 +89,7 @@ def is_raw(value: Any) -> bool: # but recursive types are problematic. return isinstance(value, str | float | bool | bytes | int | None | list | dict) + def ensure_raw(value: Any) -> RawType | None: """Check if a variable counts as "raw". @@ -79,6 +101,7 @@ def ensure_raw(value: Any) -> RawType | None: # isinstance(var, RawType | list[RawType] | dict[str, RawType]) return cast(RawType, value) if is_raw(value) else None + def hasher(construct: FirmType) -> str: """Consistently hash a RawType or tuple structure. 
@@ -102,4 +125,3 @@ def hasher(construct: FirmType) -> str: hsh = hashlib.md5() hsh.update(construct_as_string.encode()) return hsh.hexdigest() - diff --git a/src/dewret/workflow.py b/src/dewret/workflow.py index 48df4010..ffee8fea 100644 --- a/src/dewret/workflow.py +++ b/src/dewret/workflow.py @@ -25,13 +25,17 @@ from dataclasses import is_dataclass, fields as dataclass_fields from collections import Counter from typing import Protocol, Any, TypeVar, Generic, cast + from sympy import Symbol import logging logger = logging.getLogger(__name__) +from .utils import hasher, RawType, is_raw, make_traceback + +T = TypeVar("T") +RetType = TypeVar("RetType") -from .utils import hasher, RawType, is_raw @define class Raw: @@ -43,6 +47,7 @@ class Raw: Attributes: value: the real value, e.g. a `str`, `int`, ... """ + value: RawType def __hash__(self) -> int: @@ -58,21 +63,49 @@ def __repr__(self) -> str: value = str(self.value) return f"{type(self.value).__name__}|{value}" + class Lazy(Protocol): """Requirements for a lazy-evaluatable function.""" + __name__: str def __call__(self, *args: Any, **kwargs: Any) -> Any: """When called this should return a reference.""" ... + +class LazyEvaluation(Lazy, Generic[RetType]): + """Tracks a single evaluation of a lazy function.""" + + def __init__(self, fn: Callable[..., RetType]): + """Initialize an evaluation. + + Args: + fn: callable returning RetType, which this will return + also from it's __call__ method for consistency. + """ + self._fn: Callable[..., RetType] = fn + self.__name__ = fn.__name__ + + def __call__(self, *args: Any, **kwargs: Any) -> RetType: + """Wrapper around a lazy execution. + + Captures a traceback, for debugging if this does not work. + + WARNING: this is one of the few places that we would expect + dask distributed to break, if running outside a single process + is attempted. 
+ """ + tb = make_traceback() + return self._fn(*args, **kwargs, __traceback__=tb) + + Target = Callable[..., Any] StepExecution = Callable[..., Lazy] LazyFactory = Callable[[Target], Lazy] -T = TypeVar("T") -class Parameter(Generic[T], Symbol): # type: ignore[no-untyped-call] +class Parameter(Generic[T], Symbol): # type: ignore[no-untyped-call] """Global parameter. Independent parameter that will be used when a task is spotted @@ -85,6 +118,7 @@ class Parameter(Generic[T], Symbol): # type: ignore[no-untyped-call] __name__: name of the parameter. __default__: captured default value from the original value. """ + __name__: str __default__: T @@ -99,6 +133,7 @@ def __init__(self, name: str, default: T): self.__name__ = name self.__default__ = default + def param(name: str, default: T) -> T: """Create a parameter. @@ -109,6 +144,7 @@ def param(name: str, default: T) -> T: """ return cast(T, Parameter(name, default=default)) + class Task: """Named wrapper of a lazy-evaluatable function. @@ -133,6 +169,11 @@ def __init__(self, name: str, target: Lazy): self.name = name self.target = target + @property + def __name__(self) -> str: + """Name of the task.""" + return self.name + def __str__(self) -> str: """Stringify the Task, currently by returning the `name`.""" return self.name @@ -153,10 +194,8 @@ def __eq__(self, other: object) -> bool: """ if not isinstance(other, Task): return False - return ( - self.name == other.name and - self.target == other.target - ) + return self.name == other.name and self.target == other.target + class Workflow: """Overarching workflow concept. @@ -172,6 +211,7 @@ class Workflow: `Task` wrappers they represent. result: target reference to evaluate, if yet present. """ + steps: list["Step"] tasks: MutableMapping[str, "Task"] result: StepReference[Any] | None @@ -193,10 +233,16 @@ def find_parameters(self) -> set[ParameterReference]: Returns: Set of all references to parameters across the steps. 
""" - return set().union(*({ - arg for arg in step.arguments.values() - if isinstance(arg, ParameterReference) - } for step in self.steps)) + return set().union( + *( + { + arg + for arg in step.arguments.values() + if isinstance(arg, ParameterReference) + } + for step in self.steps + ) + ) @property def _indexed_steps(self) -> dict[str, Step]: @@ -209,9 +255,7 @@ def _indexed_steps(self) -> dict[str, Step]: Returns: Mapping of steps by ID. """ - return { - step.id: step for step in self.steps - } + return {step.id: step for step in self.steps} @classmethod def assimilate(cls, left: Workflow, right: Workflow) -> "Workflow": @@ -231,13 +275,15 @@ def assimilate(cls, left: Workflow, right: Workflow) -> "Workflow": left_steps = left._indexed_steps right_steps = right._indexed_steps - for step_id in (left_steps.keys() & right_steps.keys()): + for step_id in left_steps.keys() & right_steps.keys(): left_steps[step_id].set_workflow(new) right_steps[step_id].set_workflow(new) if left_steps[step_id] != right_steps[step_id]: - raise RuntimeError(f"Two steps have same ID but do not match: {step_id}") + raise RuntimeError( + f"Two steps have same ID but do not match: {step_id}" + ) - for task_id in (left.tasks.keys() & right.tasks.keys()): + for task_id in left.tasks.keys() & right.tasks.keys(): if left.tasks[task_id] != right.tasks[task_id]: raise RuntimeError("Two tasks have same name but do not match") @@ -261,11 +307,7 @@ def remap(self, step_id: str) -> str: Returns: Same ID or a remapped name. """ - return ( - self._remapping.get(step_id, step_id) - if self._remapping else - step_id - ) + return self._remapping.get(step_id, step_id) if self._remapping else step_id def simplify_ids(self) -> None: """Work out mapping to simple ints from hashes. 
@@ -296,7 +338,9 @@ def register_task(self, fn: Lazy) -> Task: self.tasks[name] = task return task - def add_step(self, fn: Lazy, kwargs: dict[str, Raw | Reference]) -> StepReference[Any]: + def add_step( + self, fn: Lazy, kwargs: dict[str, Raw | Reference] + ) -> StepReference[Any]: """Append a step. Adds a step, for running a target with key-value arguments, @@ -307,11 +351,7 @@ def add_step(self, fn: Lazy, kwargs: dict[str, Raw | Reference]) -> StepReferenc kwargs: any key-value arguments to pass in the call. """ task = self.register_task(fn) - step = Step( - self, - task, - kwargs - ) + step = Step(self, task, kwargs) self.steps.append(step) return_type = step.return_type if return_type is inspect._empty: @@ -353,6 +393,7 @@ class WorkflowComponent: Attributes: __workflow__: the `Workflow` that this is tied to. """ + __workflow__: Workflow def __init__(self, workflow: Workflow): @@ -365,6 +406,7 @@ def __init__(self, workflow: Workflow): """ self.__workflow__ = workflow + class WorkflowLinkedComponent(Protocol): """Protocol for objects dynamically tied to a `Workflow`.""" @@ -380,6 +422,7 @@ def __workflow__(self) -> Workflow: """ ... + class Reference: """Superclass for all symbolic references to values.""" @@ -388,6 +431,7 @@ def name(self) -> str: """Referral name for this reference.""" raise NotImplementedError("Reference must provide a name") + class Step(WorkflowComponent): """Lazy-evaluated function call. @@ -398,11 +442,14 @@ class Step(WorkflowComponent): task: the `Task` being called in this step. arguments: key-value pairs of arguments to this step. """ + _id: str | None = None task: Task arguments: Mapping[str, Reference | Raw] - def __init__(self, workflow: Workflow, task: Task, arguments: Mapping[str, Reference | Raw]): + def __init__( + self, workflow: Workflow, task: Task, arguments: Mapping[str, Reference | Raw] + ): """Initialize a step. 
Args: @@ -414,17 +461,19 @@ def __init__(self, workflow: Workflow, task: Task, arguments: Mapping[str, Refer self.task = task self.arguments = {} for key, value in arguments.items(): - if ( - isinstance(value, Reference) or - isinstance(value, Raw) or - is_raw(value) - ): + if isinstance(value, Reference) or isinstance(value, Raw) or is_raw(value): # Avoid recursive type issues - if not isinstance(value, Reference) and not isinstance(value, Raw) and is_raw(value): + if ( + not isinstance(value, Reference) + and not isinstance(value, Raw) + and is_raw(value) + ): value = Raw(value) self.arguments[key] = value else: - raise RuntimeError(f"Non-references must be a serializable type: {key}>{value}") + raise RuntimeError( + f"Non-references must be a serializable type: {key}>{value}" + ) def __eq__(self, other: object) -> bool: """Is this the same step? @@ -435,9 +484,9 @@ def __eq__(self, other: object) -> bool: if not isinstance(other, Step): return False return ( - self.__workflow__ == other.__workflow__ and - self.task == other.task and - self.arguments == other.arguments + self.__workflow__ == other.__workflow__ + and self.task == other.task + and self.arguments == other.arguments ) def set_workflow(self, workflow: Workflow) -> None: @@ -487,7 +536,9 @@ def id(self) -> str: check_id = self._generate_id() if check_id != self._id: - raise RuntimeError(f"Cannot change a step after requesting its ID: {self.task}") + raise RuntimeError( + f"Cannot change a step after requesting its ID: {self.task}" + ) return self._id def _generate_id(self) -> str: @@ -500,6 +551,7 @@ def _generate_id(self) -> str: return f"{self.task}-{hasher(comp_tup)}" + class ParameterReference(Reference): """Reference to an individual `Parameter`. @@ -509,6 +561,7 @@ class ParameterReference(Reference): Attributes: parameter: `Parameter` referred to. 
""" + parameter: Parameter[RawType] workflow: Workflow @@ -569,11 +622,13 @@ def __eq__(self, other: object) -> bool: True if the other parameter reference is materially the same, otherwise False. """ return ( - isinstance(other, ParameterReference) and - self.parameter == other.parameter + isinstance(other, ParameterReference) and self.parameter == other.parameter ) + U = TypeVar("U") + + class StepReference(Generic[U], Reference): """Reference to an individual `Step`. @@ -583,6 +638,7 @@ class StepReference(Generic[U], Reference): Attributes: step: `Step` referred to. """ + step: Step _field: str | None typ: type[U] @@ -598,7 +654,9 @@ def field(self) -> str: """ return self._field or "out" - def __init__(self, workflow: Workflow, step: Step, typ: type[U], field: str | None = None): + def __init__( + self, workflow: Workflow, step: Step, typ: type[U], field: str | None = None + ): """Initialize the reference. Args: @@ -640,7 +698,9 @@ def __getattr__(self, attr: str) -> "StepReference"[Any]: resolve_types(self.typ) typ = getattr(attrs_fields(self.typ), attr).type elif is_dataclass(self.typ): - matched = [field for field in dataclass_fields(self.typ) if field.name == attr] + matched = [ + field for field in dataclass_fields(self.typ) if field.name == attr + ] if not matched: raise AttributeError(f"Field {attr} not present in dataclass") typ = matched[0].type @@ -649,12 +709,11 @@ def __getattr__(self, attr: str) -> "StepReference"[Any]: if typ: return self.__class__( - workflow=self.__workflow__, - step=self.step, - typ=typ, - field=attr + workflow=self.__workflow__, step=self.step, typ=typ, field=attr ) - raise RuntimeError("Can only get attribute of a StepReference representing an attrs-class or dataclass") + raise RuntimeError( + "Can only get attribute of a StepReference representing an attrs-class or dataclass" + ) @property def return_type(self) -> type[U]: @@ -683,6 +742,7 @@ def __workflow__(self) -> Workflow: """ return self.step.__workflow__ + def 
merge_workflows(*workflows: Workflow) -> Workflow: """Combine several workflows into one. @@ -699,6 +759,7 @@ def merge_workflows(*workflows: Workflow) -> Workflow: base = Workflow.assimilate(base, workflow) return base + def is_task(task: Lazy) -> bool: """Decide whether this is a task. @@ -712,11 +773,4 @@ def is_task(task: Lazy) -> bool: Returns: True if `task` is indeed a task. """ - from .tasks import unwrap - try: - func = unwrap(task) - if hasattr(func, "__step_expression__"): - return bool(func.__step_expression__) - except Exception: - ... - return False + return isinstance(task, LazyEvaluation) diff --git a/tests/test_errors.py b/tests/test_errors.py index bbfd29a2..69dd327d 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -1,17 +1,23 @@ """Test for expected errors.""" import pytest -from dewret.tasks import construct, task, nested_task +from dewret.tasks import construct, task, nested_task, TaskException -@task() + +@task() # This is expected to be the line number shown below. def add_task(left: int, right: int) -> int: """Adds two values and returns the result.""" return left + right + +ADD_TASK_LINE_NO = 7 + + @nested_task() def badly_add_task(left: int, right: int) -> int: """Badly attempts to add two numbers.""" - return add_task(left=left) # type: ignore + return add_task(left=left) # type: ignore + def test_missing_arguments_throw_error() -> None: """Check whether omitting a required argument will give an error. @@ -22,10 +28,14 @@ def test_missing_arguments_throw_error() -> None: WARNING: in keeping with Python principles, this does not error if types mismatch, but `mypy` should. You **must** type-check your code to catch these. 
""" - result = add_task(left=3) # type: ignore - with pytest.raises(TypeError) as exc: + result = add_task(left=3) # type: ignore + with pytest.raises(TaskException) as exc: construct(result) + end_section = str(exc.getrepr())[-500:] assert str(exc.value) == "missing a required argument: 'right'" + assert "Task add_task declared in at " in end_section + assert f"test_errors.py:{ADD_TASK_LINE_NO}" in end_section + def test_missing_arguments_throw_error_in_nested_task() -> None: """Check whether omitting a required argument will give an error. @@ -37,9 +47,14 @@ def test_missing_arguments_throw_error_in_nested_task() -> None: mismatch, but `mypy` should. You **must** type-check your code to catch these. """ result = badly_add_task(left=3, right=4) - with pytest.raises(TypeError) as exc: + with pytest.raises(TaskException) as exc: construct(result) + end_section = str(exc.getrepr())[-500:] assert str(exc.value) == "missing a required argument: 'right'" + assert "def badly_add_task" in end_section + assert "Task add_task declared in at " in end_section + assert f"test_errors.py:{ADD_TASK_LINE_NO}" in end_section + def test_positional_arguments_throw_error() -> None: """Check whether we can produce simple CWL. @@ -48,6 +63,9 @@ def test_positional_arguments_throw_error() -> None: to _always_ be named. """ result = add_task(3, right=4) - with pytest.raises(TypeError) as exc: + with pytest.raises(TaskException) as exc: construct(result) - assert str(exc.value) == "Calling add_task: Arguments must _always_ be named, e.g. my_task(num=1) not my_task(1)" + assert ( + str(exc.value) + == "Calling add_task: Arguments must _always_ be named, e.g. 
my_task(num=1) not my_task(1)" + ) From 256767c79e17cad78aa982f8a1fe1330cd169fbf Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Mon, 15 Apr 2024 19:25:13 +0100 Subject: [PATCH 02/11] feat(errors): improving error messages for clarity, and adding tests --- .pre-commit-config.yaml | 2 +- src/dewret/backends/backend_dask.py | 10 ++- src/dewret/tasks.py | 55 ++++++++++++----- tests/test_errors.py | 94 ++++++++++++++++++++++++++++- 4 files changed, 141 insertions(+), 20 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0123bf7c..c50d47f5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -19,4 +19,4 @@ repos: hooks: - id: mypy args: [--strict, --install-types, --non-interactive] - additional_dependencies: [sympy, attrs, pytest, click] + additional_dependencies: [sympy, attrs, pytest, click, dask] diff --git a/src/dewret/backends/backend_dask.py b/src/dewret/backends/backend_dask.py index f95e7744..7e6b33da 100644 --- a/src/dewret/backends/backend_dask.py +++ b/src/dewret/backends/backend_dask.py @@ -21,6 +21,7 @@ from dewret.workflow import Workflow, Lazy, StepReference, Target from typing import Protocol, runtime_checkable, Any, cast + @runtime_checkable class Delayed(Protocol): """Description of a dask `delayed`. @@ -48,6 +49,7 @@ def compute(self, __workflow__: Workflow | None) -> StepReference[Any]: """ ... + def unwrap(task: Lazy) -> Target: """Unwraps a lazy-evaluated function to get the function. @@ -69,6 +71,7 @@ def unwrap(task: Lazy) -> Target: raise RuntimeError("Task is not actually a callable") return cast(Target, task._obj) + def is_lazy(task: Any) -> bool: """Checks if a task is really a lazy-evaluated function for this backend. @@ -80,7 +83,10 @@ def is_lazy(task: Any) -> bool: """ return isinstance(task, Delayed) + lazy = delayed + + def run(workflow: Workflow | None, task: Lazy) -> StepReference[Any]: """Execute a task as the output of a workflow. 
@@ -92,5 +98,7 @@ def run(workflow: Workflow | None, task: Lazy) -> StepReference[Any]: """ # We need isinstance to reassure type-checker. if not isinstance(task, Delayed) or not is_lazy(task): - raise RuntimeError("Cannot mix backends") + raise RuntimeError( + f"{task} is not a dask delayed, perhaps you tried to mix backends?" + ) return task.compute(__workflow__=workflow) diff --git a/src/dewret/tasks.py b/src/dewret/tasks.py index d9680a74..408949bd 100644 --- a/src/dewret/tasks.py +++ b/src/dewret/tasks.py @@ -168,11 +168,9 @@ def ensure_lazy(self, task: Any) -> Lazy | None: Returns: Original task, cast to a Lazy, or None. """ - if (task := self.ensure_lazy(task)) is None: - raise RuntimeError( - f"Task passed to be evaluated, must be lazy-evaluatable, not {type(task)}." - ) - return cast(task, Lazy) if self.backend.is_lazy(task) else None + if isinstance(task, LazyEvaluation): + return self.ensure_lazy(task._fn) + return task if self.backend.is_lazy(task) else None def __call__( self, task: Any, simplify_ids: bool = False, **kwargs: Any @@ -310,12 +308,21 @@ def _fn( __traceback__: TracebackType | None = None, **kwargs: Param.kwargs, ) -> RetType: - # By marking any as the positional results list, we prevent unnamed results being - # passed at all. try: + # By marking any as the positional results list, we prevent unnamed results being + # passed at all. if args: raise TypeError( - f"Calling {fn.__name__}: Arguments must _always_ be named, e.g. my_task(num=1) not my_task(1)" + f""" + Calling {fn.__name__}: Arguments must _always_ be named, + e.g. my_task(num=1) not my_task(1)\n" + + @task() + def add_numbers(left: int, right: int): + return left + right + + construct(add_numbers(left=3, right=5)) + """ ) # Ensure that the passed arguments are, at least, a Python-match for the signature. 
@@ -336,10 +343,8 @@ def _fn( workflow = Workflow() original_kwargs = dict(kwargs) for var, value in inspect.getclosurevars(fn).globals.items(): - if var in kwargs: - raise TypeError( - "Captured parameter (global variable in task) shadows an argument" - ) + # This error is redundant as it triggers a SyntaxError in Python. + # "Captured parameter {var} (global variable in task) shadows an argument" if isinstance(value, Parameter): kwargs[var] = ParameterReference(workflow, value) elif is_raw(value): @@ -348,7 +353,20 @@ def _fn( elif is_task(value): if not nested: raise TypeError( - "You reference a task inside another task, but it is not a nested_task - this will not be found!" + f""" + You reference a task {var} inside another task {fn.__name__}, but it is not a nested_task + - this will not be found! + + @task + def {var}(...) -> ...: + ... + + @nested_task <<<--- likely what you want + def {fn.__name__}(...) -> ...: + ... + {var}(...) + ... + """ ) elif attrs_has(value) or is_dataclass(value): ... @@ -357,12 +375,17 @@ def _fn( f"Nested tasks must now only refer to global parameters, raw or tasks, not objects: {var}" ) if nested: - lazy_fn = cast(Lazy, fn(**original_kwargs)) + output = fn(**original_kwargs) + lazy_fn = ensure_lazy(output) + if lazy_fn is None: + raise TypeError( + f"Task {fn.__name__} returned output of type {type(output)}, which is not a lazy function for this backend." + ) step_reference = evaluate(lazy_fn, __workflow__=workflow) if isinstance(step_reference, StepReference): return cast(RetType, step_reference) raise TypeError( - "Nested tasks must return a step reference, to ensure graph makes sense." + f"Nested tasks must return a step reference, not {type(step_reference)} to ensure graph makes sense." 
                 )
             return cast(RetType, workflow.add_step(fn, kwargs))
         except TaskException as exc:
@@ -372,7 +395,7 @@ def _fn(
                 fn,
                 declaration_tb,
                 __traceback__,
-                exc.args[0] if exc.args else "Could not call task",
+                exc.args[0] if exc.args else f"Could not call task {fn.__name__}",
             ) from exc
 
     _fn.__step_expression__ = True  # type: ignore
diff --git a/tests/test_errors.py b/tests/test_errors.py
index 69dd327d..d6f75774 100644
--- a/tests/test_errors.py
+++ b/tests/test_errors.py
@@ -1,6 +1,7 @@
 """Test for expected errors."""
 
 import pytest
+from dewret.workflow import Task, Lazy
 from dewret.tasks import construct, task, nested_task, TaskException
 
 
@@ -10,7 +11,7 @@ def add_task(left: int, right: int) -> int:
     return left + right
 
 
-ADD_TASK_LINE_NO = 7
+ADD_TASK_LINE_NO = 8
 
 
 @nested_task()
@@ -19,6 +20,33 @@ def badly_add_task(left: int, right: int) -> int:
     return add_task(left=left)  # type: ignore
 
 
+@task()  # This is expected to be the line number shown below.
+def badly_wrap_task() -> int:
+    """Sums two values but should not be calling a task."""
+    return add_task(left=3, right=4)
+
+
+class MyStrangeClass:
+    """Dummy class for tests."""
+
+    def __init__(self, task: Task):
+        """Dummy constructor for tests."""
+        ...
+
+
+@nested_task()  # This is expected to be the line number shown below.
+def unacceptable_object_usage() -> int:
+    """Sums two values but should not be calling a task."""
+    return MyStrangeClass(add_task(left=3, right=4))  # type: ignore
+
+
+@nested_task()  # This is expected to be the line number shown below.
+def unacceptable_nested_return(int_not_global: bool) -> int | Lazy:
+    """Sums two values but should not be calling a task."""
+    add_task(left=3, right=4)
+    return 7 if int_not_global else ADD_TASK_LINE_NO
+
+
 def test_missing_arguments_throw_error() -> None:
     """Check whether omitting a required argument will give an error.
@@ -67,5 +95,67 @@ def test_positional_arguments_throw_error() -> None:
         construct(result)
     assert (
         str(exc.value)
-        == "Calling add_task: Arguments must _always_ be named, e.g. my_task(num=1) not my_task(1)"
+        .strip()
+        .startswith("Calling add_task: Arguments must _always_ be named")
+    )
+
+
+def test_nesting_non_nested_tasks_throws_error() -> None:
+    """Ensure nesting is only allowed in nested_tasks.
+
+    Nested tasks must be evaluated at construction time, and there
+    is no concept of task calls that are not resolved during construction, so
+    a task should not be called inside a non-nested task.
+    """
+    result = badly_wrap_task()
+    with pytest.raises(TaskException) as exc:
+        construct(result)
+    assert (
+        str(exc.value)
+        .strip()
+        .startswith(
+            "You reference a task add_task inside another task badly_wrap_task, but it is not a nested_task"
+        )
+    )
+
+
+def test_normal_objects_cannot_be_used_in_nested_tasks() -> None:
+    """Most entities cannot appear in a nested_task, ensure we catch them.
+
+    Since the logic in nested tasks has to be embedded explicitly in the workflow,
+    complex types are not necessarily representable, and in most cases, we would not
+    be able to guarantee that the libraries, versions, etc. match.
+
+    Note: this may be mitigated with sympy support, to some extent.
+    """
+    result = unacceptable_object_usage()
+    with pytest.raises(TaskException) as exc:
+        construct(result)
+    assert (
+        str(exc.value)
+        == "Nested tasks must now only refer to global parameters, raw or tasks, not objects: MyStrangeClass"
+    )
+
+
+def test_nested_tasks_must_return_a_task() -> None:
+    """Ensure nested tasks are lazy-evaluatable.
+
+    A graph only makes sense if the edges connect, and nested tasks must therefore chain.
+    As such, a nested task must represent a real subgraph, and return a node to pull it into
+    the main graph.
+ """ + result = unacceptable_nested_return(int_not_global=True) + with pytest.raises(TaskException) as exc: + construct(result) + assert ( + str(exc.value) + == "Task unacceptable_nested_return returned output of type , which is not a lazy function for this backend." + ) + + result = unacceptable_nested_return(int_not_global=False) + with pytest.raises(TaskException) as exc: + construct(result) + assert ( + str(exc.value) + == "Task unacceptable_nested_return returned output of type , which is not a lazy function for this backend." ) From 26af40ade9249c7cf10f81d0123a95b277383402 Mon Sep 17 00:00:00 2001 From: P T Weir Date: Fri, 19 Apr 2024 22:04:12 +0100 Subject: [PATCH 03/11] fix(comments): remove copypasta regarding line numbers --- tests/test_errors.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_errors.py b/tests/test_errors.py index d6f75774..5168d006 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -20,7 +20,7 @@ def badly_add_task(left: int, right: int) -> int: return add_task(left=left) # type: ignore -@task() # This is expected to be the line number shown below. +@task() def badly_wrap_task() -> int: """Sums two values but should not be calling a task.""" return add_task(left=3, right=4) @@ -34,13 +34,13 @@ def __init__(self, task: Task): ... -@nested_task() # This is expected to be the line number shown below. +@nested_task() def unacceptable_object_usage() -> int: """Sums two values but should not be calling a task.""" return MyStrangeClass(add_task(left=3, right=4)) # type: ignore -@nested_task() # This is expected to be the line number shown below. 
+@nested_task() def unacceptable_nested_return(int_not_global: bool) -> int | Lazy: """Sums two values but should not be calling a task.""" add_task(left=3, right=4) From 54c47c8a7fdeab51dca90032ef8bbbe7335891e4 Mon Sep 17 00:00:00 2001 From: P T Weir Date: Fri, 19 Apr 2024 22:08:00 +0100 Subject: [PATCH 04/11] fix(docstrings): correct to match test function --- tests/test_errors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_errors.py b/tests/test_errors.py index 5168d006..615a5bab 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -85,7 +85,7 @@ def test_missing_arguments_throw_error_in_nested_task() -> None: def test_positional_arguments_throw_error() -> None: - """Check whether we can produce simple CWL. + """Check whether unnamed (positional) arguments throw an error. We can use default and non-default arguments, but we expect them to _always_ be named. From 7965472e147d874a240c954f9756d5db51e1d0ee Mon Sep 17 00:00:00 2001 From: P T Weir Date: Fri, 19 Apr 2024 23:21:37 +0100 Subject: [PATCH 05/11] fix(docstrings): Update src/dewret/tasks.py Co-authored-by: Ellery Ames <61203509+elleryames@users.noreply.github.com> --- src/dewret/tasks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dewret/tasks.py b/src/dewret/tasks.py index 408949bd..45f5c05e 100644 --- a/src/dewret/tasks.py +++ b/src/dewret/tasks.py @@ -354,7 +354,7 @@ def add_numbers(left: int, right: int): if not nested: raise TypeError( f""" - You reference a task {var} inside another task {fn.__name__}, but it is not a nested_task + You referenced a task {var} inside another task {fn.__name__}, but it is not a nested_task - this will not be found! 
@task From b5e41df7f43cec991e0dd9765e93621f614d12eb Mon Sep 17 00:00:00 2001 From: P T Weir Date: Fri, 19 Apr 2024 23:22:11 +0100 Subject: [PATCH 06/11] fix(docstrings): Update tests/test_errors.py Co-authored-by: Ellery Ames <61203509+elleryames@users.noreply.github.com> --- tests/test_errors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_errors.py b/tests/test_errors.py index 615a5bab..632bc063 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -114,7 +114,7 @@ def test_nesting_non_nested_tasks_throws_error() -> None: str(exc.value) .strip() .startswith( - "You reference a task add_task inside another task badly_wrap_task, but it is not a nested_task" + "You referenced a task add_task inside another task badly_wrap_task, but it is not a nested_task" ) ) From 38e27d7bce2ea9e7d2041758000f9ae4e5ec022e Mon Sep 17 00:00:00 2001 From: P T Weir Date: Fri, 19 Apr 2024 23:48:40 +0100 Subject: [PATCH 07/11] fix(docstrings): test text correct Co-authored-by: Ellery Ames <61203509+elleryames@users.noreply.github.com> --- tests/test_errors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_errors.py b/tests/test_errors.py index 632bc063..4ef97550 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -42,7 +42,7 @@ def unacceptable_object_usage() -> int: @nested_task() def unacceptable_nested_return(int_not_global: bool) -> int | Lazy: - """Sums two values but should not be calling a task.""" + """Bad nested_task that fails to return a task.""" add_task(left=3, right=4) return 7 if int_not_global else ADD_TASK_LINE_NO From ef1a8bfa801a11b8e2d57bd626eb98b9745ad52e Mon Sep 17 00:00:00 2001 From: P T Weir Date: Fri, 19 Apr 2024 23:49:21 +0100 Subject: [PATCH 08/11] fix(docstrings): test text correct Co-authored-by: Ellery Ames <61203509+elleryames@users.noreply.github.com> --- tests/test_errors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_errors.py 
b/tests/test_errors.py index 4ef97550..7e1912e3 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -36,7 +36,7 @@ def __init__(self, task: Task): @nested_task() def unacceptable_object_usage() -> int: - """Sums two values but should not be calling a task.""" + """Invalid use of custom object within nested task.""" return MyStrangeClass(add_task(left=3, right=4)) # type: ignore From 58e6a81b51a7dd107be5f76c956c83c69e251798 Mon Sep 17 00:00:00 2001 From: P T Weir Date: Sat, 20 Apr 2024 14:04:12 +0100 Subject: [PATCH 09/11] Apply suggestions from code review Co-authored-by: Ellery Ames <61203509+elleryames@users.noreply.github.com> --- src/dewret/tasks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dewret/tasks.py b/src/dewret/tasks.py index 45f5c05e..d835c09c 100644 --- a/src/dewret/tasks.py +++ b/src/dewret/tasks.py @@ -309,7 +309,7 @@ def _fn( **kwargs: Param.kwargs, ) -> RetType: try: - # By marking any as the positional results list, we prevent unnamed results being + # Ensure that all arguments are passed as keyword args and prevent positional args. # passed at all. if args: raise TypeError( From 6fc123d839ea98f6101856d3b74387c2a575f4e9 Mon Sep 17 00:00:00 2001 From: P T Weir Date: Tue, 23 Apr 2024 01:31:15 +0100 Subject: [PATCH 10/11] fix(docstrings): more precision in tests/test_errors.py Co-authored-by: Ellery Ames <61203509+elleryames@users.noreply.github.com> --- tests/test_errors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_errors.py b/tests/test_errors.py index 7e1912e3..3bbf3945 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -66,7 +66,7 @@ def test_missing_arguments_throw_error() -> None: def test_missing_arguments_throw_error_in_nested_task() -> None: - """Check whether omitting a required argument will give an error. + """Check whether omitting a required argument within a nested_task will give an error. 
Since we do not run the original function, it is up to dewret to check that the signature is, at least, acceptable to Python. From 96bd6af2cad790323c7dbefcf5600e9b7ab003f4 Mon Sep 17 00:00:00 2001 From: P T Weir Date: Tue, 23 Apr 2024 09:29:19 +0100 Subject: [PATCH 11/11] fix(comments): make redundant error explicit --- src/dewret/tasks.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/dewret/tasks.py b/src/dewret/tasks.py index d835c09c..2b2dae3f 100644 --- a/src/dewret/tasks.py +++ b/src/dewret/tasks.py @@ -344,7 +344,11 @@ def add_numbers(left: int, right: int): original_kwargs = dict(kwargs) for var, value in inspect.getclosurevars(fn).globals.items(): # This error is redundant as it triggers a SyntaxError in Python. - # "Captured parameter {var} (global variable in task) shadows an argument" + # Note: the following test duplicates a syntax error. + # if var in kwargs: + # raise TypeError( + # "Captured parameter {var} (global variable in task) shadows an argument" + # ) if isinstance(value, Parameter): kwargs[var] = ParameterReference(workflow, value) elif is_raw(value):