Skip to content

Commit

Permalink
Merge pull request #49 from flaxandteal/fix/31-return-tuple
Browse files Browse the repository at this point in the history
Fix/31 return tuple
  • Loading branch information
KamenDimitrov97 authored Jul 25, 2024
2 parents 8006938 + 6b181cb commit 39695cb
Show file tree
Hide file tree
Showing 7 changed files with 236 additions and 50 deletions.
39 changes: 27 additions & 12 deletions src/dewret/renderers/cwl.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from dataclasses import dataclass, is_dataclass, fields as dataclass_fields
from collections.abc import Mapping
from contextvars import ContextVar
from typing import TypedDict, NotRequired, get_args, Union, cast, Any, Unpack
from typing import TypedDict, NotRequired, get_args, Union, cast, Any, Iterable, Unpack
from types import UnionType

from dewret.workflow import (
Expand All @@ -38,7 +38,9 @@
)
from dewret.utils import RawType, flatten, DataclassProtocol

InputSchemaType = Union[str, "CommandInputSchema", list[str], list["InputSchemaType"]]
InputSchemaType = Union[
str, "CommandInputSchema", list[str], list["InputSchemaType"], dict[str, str]
]


@dataclass
Expand Down Expand Up @@ -174,7 +176,7 @@ def render(self) -> dict[str, RawType]:
}


def cwl_type_from_value(val: RawType | Unset) -> str | list[str]:
def cwl_type_from_value(val: RawType | Unset) -> str | list[str] | dict[str, Any]:
"""Find a CWL type for a given (possibly Unset) value.
Args:
Expand All @@ -191,7 +193,7 @@ def cwl_type_from_value(val: RawType | Unset) -> str | list[str]:
return to_cwl_type(raw_type)


def to_cwl_type(typ: type) -> str | list[str]:
def to_cwl_type(typ: type) -> str | dict[str, Any] | list[str]:
"""Map Python types to CWL types.
Args:
Expand All @@ -201,24 +203,37 @@ def to_cwl_type(typ: type) -> str | list[str]:
CWL specification type name, or a list
if a union.
"""
if isinstance(typ, UnionType):
return [to_cwl_type(item) for item in get_args(typ)]

if typ == int:
return "int"
elif typ == bool:
return "boolean"
elif typ == dict or attrs_has(typ):
return "record"
elif typ == list:
return "array"
elif typ == float:
return "double"
return "float"
elif typ == str:
return "string"
elif typ == bytes:
return "bytes"
elif configuration("allow_complex_types"):
return typ if isinstance(typ, str) else typ.__name__
elif isinstance(typ, UnionType):
return [to_cwl_type(item) for item in get_args(typ)]
elif isinstance(typ, Iterable):
try:
basic_types = get_args(typ)
if len(basic_types) > 1:
return {
"type": "array",
"items": [{"type": to_cwl_type(t)} for t in basic_types],
}
else:
return {"type": "array", "items": to_cwl_type(basic_types[0])}
except IndexError as err:
raise TypeError(
f"Cannot render complex type ({typ}) to CWL, have you enabled allow_complex_types configuration?"
) from err
else:
if configuration("allow_complex_types"):
return typ if isinstance(typ, str) else typ.__name__
raise TypeError(f"Cannot render complex type ({typ}) to CWL")


Expand Down
52 changes: 37 additions & 15 deletions src/dewret/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# with WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Expand Down Expand Up @@ -465,7 +465,7 @@ def assimilate(cls, left: Workflow, right: Workflow) -> "Workflow":
new.tasks.update(right.tasks)

for step in new.steps:
step.__workflow__ = new
step.set_workflow(new, with_arguments=True)

# TODO: should we combine as a result array?
result = left.result or right.result
Expand Down Expand Up @@ -742,22 +742,25 @@ def __eq__(self, other: object) -> bool:
and self.arguments == other.arguments
)

def set_workflow(self, workflow: Workflow) -> None:
"""Move the step reference to another workflow.
def set_workflow(self, workflow: Workflow, with_arguments: bool = True) -> None:
"""Move the step reference to a different workflow.
Primarily intended to be called by its step, as a cascade.
It will attempt to update its arguments, similarly.
This method is primarily intended to be called by a step, allowing it to
switch to a new workflow. It also updates the workflow reference for any
arguments that are steps themselves, if specified.
Args:
workflow: the new target workflow.
workflow: The new target workflow to which the step should be moved.
with_arguments: If True, also update the workflow reference for the step's arguments.
"""
self.__workflow__ = workflow
for argument in self.arguments.values():
if hasattr(argument, "__workflow__"):
try:
argument.__workflow__ = workflow
except AttributeError:
...
if with_arguments:
for argument in self.arguments.values():
if hasattr(argument, "__workflow__"):
try:
argument.__workflow__ = workflow
except AttributeError:
...

@property
def return_type(self) -> Any:
Expand Down Expand Up @@ -1011,6 +1014,7 @@ class StepReference(Generic[U], Reference):
"""

step: BaseStep
_tethered_workflow: Workflow | None
_field: str | None
typ: type[U]

Expand Down Expand Up @@ -1039,6 +1043,7 @@ def __init__(
self.step = step
self._field = field
self.typ = typ
self._tethered_workflow = None

def __str__(self) -> str:
"""Global description of the reference."""
Expand All @@ -1048,7 +1053,7 @@ def __repr__(self) -> str:
"""Hashable reference to the step (and field)."""
return f"{self.step.id}/{self.field}"

def __getattr__(self, attr: str) -> "StepReference"[Any]:
def __getattr__(self, attr: str) -> "StepReference[Any]":
"""Reference to a field within this result, if possible.
If the result is an attrs-class or dataclass, this will pull out an individual
Expand Down Expand Up @@ -1119,7 +1124,24 @@ def __workflow__(self) -> Workflow:
Returns:
Workflow that the referee is related to.
"""
return self.step.__workflow__
return self._tethered_workflow or self.step.__workflow__

@__workflow__.setter
def __workflow__(self, workflow: Workflow) -> None:
"""Sets related workflow.
We update the tethered workflow. If the step is missing from
this workflow then, by construction, it should have at least
been through an indexing process once, so we should be able
to get it back by name.
Args:
workflow: workflow to update the step
"""
self._tethered_workflow = workflow
if self._tethered_workflow:
if self.step not in self._tethered_workflow.steps:
self.step = self._tethered_workflow._indexed_steps[self.step.id]

@__workflow__.setter
def __workflow__(self, workflow: Workflow) -> None:
Expand Down
8 changes: 7 additions & 1 deletion tests/_lib/extra.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def double(num: int | float) -> int | float:

@task()
def mod10(num: int) -> int:
"""Double an integer."""
"""Remainder of an integer divided by 10."""
return num % 10


Expand All @@ -37,3 +37,9 @@ def sum(left: int | float, right: int | float) -> int | float:
def triple_and_one(num: int | float) -> int | float:
"""Triple a number by doubling and adding again, then add 1."""
return sum(left=sum(left=double(num=num), right=num), right=1)


@task()
def tuple_float_return() -> tuple[float, float]:
"""Return a tuple of floats."""
return 48.856667, 2.351667
55 changes: 48 additions & 7 deletions tests/test_cwl.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,14 @@
from dewret.utils import hasher
from dewret.workflow import param

from ._lib.extra import increment, double, mod10, sum, triple_and_one
from ._lib.extra import (
increment,
double,
mod10,
sum,
triple_and_one,
tuple_float_return,
)


@task()
Expand Down Expand Up @@ -54,7 +61,7 @@ def test_basic_cwl() -> None:
out:
label: out
outputSource: pi-{hsh}/out
type: double
type: float
steps:
pi-{hsh}:
run: pi
Expand Down Expand Up @@ -272,7 +279,7 @@ def test_cwl_with_subworkflow() -> None:
outputSource: sum-1-2/out
type:
- int
- double
- float
steps:
double-1-1:
in:
Expand Down Expand Up @@ -328,7 +335,9 @@ def test_cwl_references() -> None:
out:
label: out
outputSource: double-{hsh_double}/out
type: [int, double]
type:
- int
- float
steps:
increment-{hsh_increment}:
run: increment
Expand Down Expand Up @@ -370,7 +379,9 @@ def test_complex_cwl_references() -> None:
out:
label: out
outputSource: sum-1/out
type: [int, double]
type:
- int
- float
steps:
increment-1:
run: increment
Expand Down Expand Up @@ -471,7 +482,7 @@ def test_cwl_with_subworkflow_and_raw_params() -> None:
label: num
type: [
int,
double
float
]
sum-1-2-right:
default: 1
Expand All @@ -483,7 +494,7 @@ def test_cwl_with_subworkflow_and_raw_params() -> None:
outputSource: sum-1-2/out
type:
- int
- double
- float
steps:
double-1-1:
in:
Expand Down Expand Up @@ -511,3 +522,33 @@ def test_cwl_with_subworkflow_and_raw_params() -> None:
- out
run: sum
""")


def test_tuple_floats() -> None:
"""Checks whether we can return a tuple.
Produces CWL that has a tuple of 2 values of type float.
"""
result = tuple_float_return()
workflow = construct(result, simplify_ids=True)
rendered = render(workflow)
print(yaml.dump(rendered))
assert rendered == yaml.safe_load("""
cwlVersion: 1.2
class: Workflow
inputs: {}
outputs:
out:
label: out
outputSource: tuple_float_return-1/out
type:
items:
- type: float
- type: float
type: array
steps:
tuple_float_return-1:
run: tuple_float_return
in: {}
out: [out]
""")
4 changes: 2 additions & 2 deletions tests/test_modularity.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_nested_task() -> None:
inputs:
JUMP:
label: JUMP
type: double
type: float
default: 1.0
increase-3-num:
default: 23
Expand All @@ -45,7 +45,7 @@ def test_nested_task() -> None:
out:
label: out
outputSource: sum-1/out
type: [int, double]
type: [int, float]
steps:
increase-1:
run: increase
Expand Down
Loading

0 comments on commit 39695cb

Please sign in to comment.