From ff3fdf6e7f8ace06f4fb7bcea02479af4bcddc47 Mon Sep 17 00:00:00 2001 From: Alice Purcell Date: Tue, 3 Sep 2024 12:52:36 +0100 Subject: [PATCH] Enable Pydantic I/O types in workflow context Extend experimental Pydantic I/O support to allow passing Pydantic types into `@script`-decorated functions when inside a `with` workflow context block, and using fields on the returned Pydantic output as shorthand for the associated Hera template in subsequent steps. Signed-off-by: Alice Purcell --- .../pydantic_io_in_dag_context.md | 154 ++++++++++++++++++ .../pydantic_io_in_steps_context.md | 152 +++++++++++++++++ .../pydantic-io-in-dag-context.yaml | 84 ++++++++++ .../pydantic-io-in-steps-context.yaml | 82 ++++++++++ .../pydantic_io_in_dag_context.py | 53 ++++++ .../pydantic_io_in_steps_context.py | 53 ++++++ src/hera/workflows/_meta_mixins.py | 6 + src/hera/workflows/_mixins.py | 4 +- src/hera/workflows/io/_io_mixins.py | 35 ++-- src/hera/workflows/script.py | 18 +- 10 files changed, 624 insertions(+), 17 deletions(-) create mode 100644 docs/examples/workflows/experimental/pydantic_io_in_dag_context.md create mode 100644 docs/examples/workflows/experimental/pydantic_io_in_steps_context.md create mode 100644 examples/workflows/experimental/pydantic-io-in-dag-context.yaml create mode 100644 examples/workflows/experimental/pydantic-io-in-steps-context.yaml create mode 100644 examples/workflows/experimental/pydantic_io_in_dag_context.py create mode 100644 examples/workflows/experimental/pydantic_io_in_steps_context.py diff --git a/docs/examples/workflows/experimental/pydantic_io_in_dag_context.md b/docs/examples/workflows/experimental/pydantic_io_in_dag_context.md new file mode 100644 index 000000000..2b0d97d6e --- /dev/null +++ b/docs/examples/workflows/experimental/pydantic_io_in_dag_context.md @@ -0,0 +1,154 @@ +# Pydantic Io In Dag Context + + + + + + +=== "Hera" + + ```python linenums="1" + import sys + from typing import List + + if sys.version_info >= (3, 9): + from typing import Annotated + else: + from typing_extensions import Annotated + + + from hera.shared import global_config + from hera.workflows import DAG, Parameter, WorkflowTemplate, script + from hera.workflows.io.v1 import Input, Output + + global_config.experimental_features["script_pydantic_io"] = True + + + class CutInput(Input): + cut_after: Annotated[int, Parameter(name="cut-after")] + strings: List[str] + + + class CutOutput(Output): + first_strings: Annotated[List[str], Parameter(name="first-strings")] + remainder: List[str] + + + class JoinInput(Input): + strings: List[str] + joiner: str + + + class JoinOutput(Output): + joined_string: Annotated[str, Parameter(name="joined-string")] + + + @script(constructor="runner") + def cut(input: CutInput) -> CutOutput: + return CutOutput( + first_strings=input.strings[: input.cut_after], + remainder=input.strings[input.cut_after :], + exit_code=1 if len(input.strings) <= input.cut_after else 0, + ) + + + @script(constructor="runner") + def join(input: JoinInput) -> JoinOutput: + return JoinOutput(joined_string=input.joiner.join(input.strings)) + + + with WorkflowTemplate(generate_name="pydantic-io-in-steps-context-v1-", entrypoint="d") as w: + with DAG(name="d"): + cut_result = cut(CutInput(strings=["hello", "world", "it's", "hera"], cut_after=1)) + join(JoinInput(strings=cut_result.first_strings, joiner=" ")) + ``` + +=== "YAML" + + ```yaml linenums="1" + apiVersion: argoproj.io/v1alpha1 + kind: WorkflowTemplate + metadata: + generateName: pydantic-io-in-steps-context-v1- + spec: + entrypoint: d + templates: + - dag: + tasks: + - arguments: + parameters: + - name: cut-after + value: '1' + - name: strings + value: '["hello", "world", "it''s", "hera"]' + name: cut + template: cut + - arguments: + parameters: + - name: strings + value: '{{tasks.cut.outputs.parameters.first-strings}}' + - name: joiner + value: ' ' + depends: cut + name: join + template: join + name: d + - inputs: + parameters: + - name: cut-after + - name: strings + name: cut + outputs: + parameters: + - name: first-strings + valueFrom: + path: /tmp/hera-outputs/parameters/first-strings + - name: remainder + valueFrom: + path: /tmp/hera-outputs/parameters/remainder + script: + args: + - -m + - hera.workflows.runner + - -e + - examples.workflows.experimental.pydantic_io_in_dag_context:cut + command: + - python + env: + - name: hera__script_annotations + value: '' + - name: hera__outputs_directory + value: /tmp/hera-outputs + - name: hera__script_pydantic_io + value: '' + image: python:3.8 + source: '{{inputs.parameters}}' + - inputs: + parameters: + - name: strings + - name: joiner + name: join + outputs: + parameters: + - name: joined-string + valueFrom: + path: /tmp/hera-outputs/parameters/joined-string + script: + args: + - -m + - hera.workflows.runner + - -e + - examples.workflows.experimental.pydantic_io_in_dag_context:join + command: + - python + env: + - name: hera__script_annotations + value: '' + - name: hera__outputs_directory + value: /tmp/hera-outputs + - name: hera__script_pydantic_io + value: '' + image: python:3.8 + source: '{{inputs.parameters}}' + ``` + diff --git a/docs/examples/workflows/experimental/pydantic_io_in_steps_context.md b/docs/examples/workflows/experimental/pydantic_io_in_steps_context.md new file mode 100644 index 000000000..d0cf9b977 --- /dev/null +++ b/docs/examples/workflows/experimental/pydantic_io_in_steps_context.md @@ -0,0 +1,152 @@ +# Pydantic Io In Steps Context + + + + + + +=== "Hera" + + ```python linenums="1" + import sys + from typing import List + + if sys.version_info >= (3, 9): + from typing import Annotated + else: + from typing_extensions import Annotated + + + from hera.shared import global_config + from hera.workflows import Parameter, Steps, WorkflowTemplate, script + from hera.workflows.io.v1 import Input, Output + + global_config.experimental_features["script_pydantic_io"] = True + + + class CutInput(Input): + cut_after: Annotated[int, Parameter(name="cut-after")] + strings: List[str] + + + class CutOutput(Output): + first_strings: Annotated[List[str], Parameter(name="first-strings")] + remainder: List[str] + + + class JoinInput(Input): + strings: List[str] + joiner: str + + + class JoinOutput(Output): + joined_string: Annotated[str, Parameter(name="joined-string")] + + + @script(constructor="runner") + def cut(input: CutInput) -> CutOutput: + return CutOutput( + first_strings=input.strings[: input.cut_after], + remainder=input.strings[input.cut_after :], + exit_code=1 if len(input.strings) <= input.cut_after else 0, + ) + + + @script(constructor="runner") + def join(input: JoinInput) -> JoinOutput: + return JoinOutput(joined_string=input.joiner.join(input.strings)) + + + with WorkflowTemplate(generate_name="pydantic-io-in-steps-context-v1-", entrypoint="d") as w: + with Steps(name="d"): + cut_result = cut(CutInput(strings=["hello", "world", "it's", "hera"], cut_after=1)) + join(JoinInput(strings=cut_result.first_strings, joiner=" ")) + ``` + +=== "YAML" + + ```yaml linenums="1" + apiVersion: argoproj.io/v1alpha1 + kind: WorkflowTemplate + metadata: + generateName: pydantic-io-in-steps-context-v1- + spec: + entrypoint: d + templates: + - name: d + steps: + - - arguments: + parameters: + - name: cut-after + value: '1' + - name: strings + value: '["hello", "world", "it''s", "hera"]' + name: cut + template: cut + - - arguments: + parameters: + - name: strings + value: '{{steps.cut.outputs.parameters.first-strings}}' + - name: joiner + value: ' ' + name: join + template: join + - inputs: + parameters: + - name: cut-after + - name: strings + name: cut + outputs: + parameters: + - name: first-strings + valueFrom: + path: /tmp/hera-outputs/parameters/first-strings + - name: remainder + valueFrom: + path: /tmp/hera-outputs/parameters/remainder + script: + args: + - -m + - hera.workflows.runner + - -e + - examples.workflows.experimental.pydantic_io_in_steps_context:cut + command: + - python + env: + - name: hera__script_annotations + value: '' + - name: hera__outputs_directory + value: /tmp/hera-outputs + - name: hera__script_pydantic_io + value: '' + image: python:3.8 + source: '{{inputs.parameters}}' + - inputs: + parameters: + - name: strings + - name: joiner + name: join + outputs: + parameters: + - name: joined-string + valueFrom: + path: /tmp/hera-outputs/parameters/joined-string + script: + args: + - -m + - hera.workflows.runner + - -e + - examples.workflows.experimental.pydantic_io_in_steps_context:join + command: + - python + env: + - name: hera__script_annotations + value: '' + - name: hera__outputs_directory + value: /tmp/hera-outputs + - name: hera__script_pydantic_io + value: '' + image: python:3.8 + source: '{{inputs.parameters}}' + ``` + diff --git a/examples/workflows/experimental/pydantic-io-in-dag-context.yaml b/examples/workflows/experimental/pydantic-io-in-dag-context.yaml new file mode 100644 index 000000000..a5e4a2e8d --- /dev/null +++ b/examples/workflows/experimental/pydantic-io-in-dag-context.yaml @@ -0,0 +1,84 @@ +apiVersion: argoproj.io/v1alpha1 +kind: WorkflowTemplate +metadata: + generateName: pydantic-io-in-steps-context-v1- +spec: + entrypoint: d + templates: + - dag: + tasks: + - arguments: + parameters: + - name: cut-after + value: '1' + - name: strings + value: '["hello", "world", "it''s", "hera"]' + name: cut + template: cut + - arguments: + parameters: + - name: strings + value: '{{tasks.cut.outputs.parameters.first-strings}}' + - name: joiner + value: ' ' + depends: cut + name: join + template: join + name: d + - inputs: + parameters: + - name: cut-after + - name: strings + name: cut + outputs: + parameters: + - name: first-strings + valueFrom: + path: /tmp/hera-outputs/parameters/first-strings + - name: remainder + valueFrom: + path: /tmp/hera-outputs/parameters/remainder + script: + args: + - -m + - hera.workflows.runner + - -e + - examples.workflows.experimental.pydantic_io_in_dag_context:cut + command: + - python + env: + - name: hera__script_annotations + value: '' + - name: hera__outputs_directory + value: /tmp/hera-outputs + - name: hera__script_pydantic_io + value: '' + image: python:3.8 + source: '{{inputs.parameters}}' + - inputs: + parameters: + - name: strings + - name: joiner + name: join + outputs: + parameters: + - name: joined-string + valueFrom: + path: /tmp/hera-outputs/parameters/joined-string + script: + args: + - -m + - hera.workflows.runner + - -e + - examples.workflows.experimental.pydantic_io_in_dag_context:join + command: + - python + env: + - name: hera__script_annotations + value: '' + - name: hera__outputs_directory + value: /tmp/hera-outputs + - name: hera__script_pydantic_io + value: '' + image: python:3.8 + source: '{{inputs.parameters}}' diff --git a/examples/workflows/experimental/pydantic-io-in-steps-context.yaml b/examples/workflows/experimental/pydantic-io-in-steps-context.yaml new file mode 100644 index 000000000..e9f176668 --- /dev/null +++ b/examples/workflows/experimental/pydantic-io-in-steps-context.yaml @@ -0,0 +1,82 @@ +apiVersion: argoproj.io/v1alpha1 +kind: WorkflowTemplate +metadata: + generateName: pydantic-io-in-steps-context-v1- +spec: + entrypoint: d + templates: + - name: d + steps: + - - arguments: + parameters: + - name: cut-after + value: '1' + - name: strings + value: '["hello", "world", "it''s", "hera"]' + name: cut + template: cut + - - arguments: + parameters: + - name: strings + value: '{{steps.cut.outputs.parameters.first-strings}}' + - name: joiner + value: ' ' + name: join + template: join + - inputs: + parameters: + - name: cut-after + - name: strings + name: cut + outputs: + parameters: + - name: first-strings + valueFrom: + path: /tmp/hera-outputs/parameters/first-strings + - name: remainder + valueFrom: + path: /tmp/hera-outputs/parameters/remainder + script: + args: + - -m + - hera.workflows.runner + - -e + - examples.workflows.experimental.pydantic_io_in_steps_context:cut + command: + - python + env: + - name: hera__script_annotations + value: '' + - name: hera__outputs_directory + value: /tmp/hera-outputs + - name: hera__script_pydantic_io + value: '' + image: python:3.8 + source: '{{inputs.parameters}}' + - inputs: + parameters: + - name: strings + - name: joiner + name: join + outputs: + parameters: + - name: joined-string + valueFrom: + path: /tmp/hera-outputs/parameters/joined-string + script: + args: + - -m + - hera.workflows.runner + - -e + - examples.workflows.experimental.pydantic_io_in_steps_context:join + command: + - python + env: + - name: hera__script_annotations + value: '' + - name: hera__outputs_directory + value: /tmp/hera-outputs + - name: hera__script_pydantic_io + value: '' + image: python:3.8 + source: '{{inputs.parameters}}' diff --git a/examples/workflows/experimental/pydantic_io_in_dag_context.py b/examples/workflows/experimental/pydantic_io_in_dag_context.py new file mode 100644 index 000000000..8066d4eca --- /dev/null +++ b/examples/workflows/experimental/pydantic_io_in_dag_context.py @@ -0,0 +1,53 @@ +import sys +from typing import List + +if sys.version_info >= (3, 9): + from typing import Annotated +else: + from typing_extensions import Annotated + + +from hera.shared import global_config +from hera.workflows import DAG, Parameter, WorkflowTemplate, script +from hera.workflows.io.v1 import Input, Output + +global_config.experimental_features["script_pydantic_io"] = True + + +class CutInput(Input): + cut_after: Annotated[int, Parameter(name="cut-after")] + strings: List[str] + + +class CutOutput(Output): + first_strings: Annotated[List[str], Parameter(name="first-strings")] + remainder: List[str] + + +class JoinInput(Input): + strings: List[str] + joiner: str + + +class JoinOutput(Output): + joined_string: Annotated[str, Parameter(name="joined-string")] + + +@script(constructor="runner") +def cut(input: CutInput) -> CutOutput: + return CutOutput( + first_strings=input.strings[: input.cut_after], + remainder=input.strings[input.cut_after :], + exit_code=1 if len(input.strings) <= input.cut_after else 0, + ) + + +@script(constructor="runner") +def join(input: JoinInput) -> JoinOutput: + return JoinOutput(joined_string=input.joiner.join(input.strings)) + + +with WorkflowTemplate(generate_name="pydantic-io-in-steps-context-v1-", entrypoint="d") as w: + with DAG(name="d"): + cut_result = cut(CutInput(strings=["hello", "world", "it's", "hera"], cut_after=1)) + join(JoinInput(strings=cut_result.first_strings, joiner=" ")) diff --git a/examples/workflows/experimental/pydantic_io_in_steps_context.py b/examples/workflows/experimental/pydantic_io_in_steps_context.py new file mode 100644 index 000000000..2517b20f8 --- /dev/null +++ b/examples/workflows/experimental/pydantic_io_in_steps_context.py @@ -0,0 +1,53 @@ +import sys +from typing import List + +if sys.version_info >= (3, 9): + from typing import Annotated +else: + from typing_extensions import Annotated + + +from hera.shared import global_config +from hera.workflows import Parameter, Steps, WorkflowTemplate, script +from hera.workflows.io.v1 import Input, Output + +global_config.experimental_features["script_pydantic_io"] = True + + +class CutInput(Input): + cut_after: Annotated[int, Parameter(name="cut-after")] + strings: List[str] + + +class CutOutput(Output): + first_strings: Annotated[List[str], Parameter(name="first-strings")] + remainder: List[str] + + +class JoinInput(Input): + strings: List[str] + joiner: str + + +class JoinOutput(Output): + joined_string: Annotated[str, Parameter(name="joined-string")] + + +@script(constructor="runner") +def cut(input: CutInput) -> CutOutput: + return CutOutput( + first_strings=input.strings[: input.cut_after], + remainder=input.strings[input.cut_after :], + exit_code=1 if len(input.strings) <= input.cut_after else 0, + ) + + +@script(constructor="runner") +def join(input: JoinInput) -> JoinOutput: + return JoinOutput(joined_string=input.joiner.join(input.strings)) + + +with WorkflowTemplate(generate_name="pydantic-io-in-steps-context-v1-", entrypoint="d") as w: + with Steps(name="d"): + cut_result = cut(CutInput(strings=["hello", "world", "it's", "hera"], cut_after=1)) + join(JoinInput(strings=cut_result.first_strings, joiner=" ")) diff --git a/src/hera/workflows/_meta_mixins.py b/src/hera/workflows/_meta_mixins.py index 992936e82..30d8e588f 100644 --- a/src/hera/workflows/_meta_mixins.py +++ b/src/hera/workflows/_meta_mixins.py @@ -372,6 +372,12 @@ def __call__(self, *args, **kwargs) -> Union[None, Step, Task]: return Step(template=self, **kwargs) if isinstance(_context.pieces[-1], DAG): + # Add dependencies based on context if not explicitly provided + current_task_depends = _context.pieces[-1]._current_task_depends + if current_task_depends and "depends" not in kwargs: + kwargs["depends"] = " && ".join(sorted(current_task_depends)) + current_task_depends.clear() + return Task(template=self, **kwargs) raise InvalidTemplateCall( diff --git a/src/hera/workflows/_mixins.py b/src/hera/workflows/_mixins.py index 89bcc1ab2..d1c549a74 100644 --- a/src/hera/workflows/_mixins.py +++ b/src/hera/workflows/_mixins.py @@ -721,12 +721,12 @@ def __getattribute__(self, name: str) -> Any: except AttributeError: build_obj = None - if build_obj and _context.declaring: + if build_obj and _context.active: fields = get_fields(build_obj.output_class) annotations = get_field_annotations(build_obj.output_class) if name in fields: # If the attribute name is in the build_obj's output class fields, then - # as we are in a declaring context, the access is for a Task/Step output + # as we are in an active context, the access is for a Task/Step output subnode_name = object.__getattribute__(self, "name") subnode_type = object.__getattribute__(self, "_subtype") diff --git a/src/hera/workflows/io/_io_mixins.py b/src/hera/workflows/io/_io_mixins.py index 912f9f6c4..c81e82b78 100644 --- a/src/hera/workflows/io/_io_mixins.py +++ b/src/hera/workflows/io/_io_mixins.py @@ -1,6 +1,7 @@ import sys import warnings -from typing import TYPE_CHECKING, List, Optional, Union +from contextlib import contextmanager +from typing import TYPE_CHECKING, Iterator, List, Optional, Union if sys.version_info >= (3, 11): from typing import Self @@ -39,21 +40,29 @@ BaseModel = object # type: ignore +@contextmanager +def no_active_context() -> Iterator[None]: + pieces = _context.pieces + _context.pieces = [] + try: + yield + finally: + _context.pieces = pieces + + class InputMixin(BaseModel): def __new__(cls, **kwargs): - if _context.declaring: + if _context.active: # Intercept the declaration to avoid validation on the templated strings - # We must then turn off declaring mode to be able to "construct" an instance + # We must then disable the active context to be able to "construct" an instance # of the InputMixin subclass. - _context.declaring = False - instance = cls.construct(**kwargs) - _context.declaring = True - return instance + with no_active_context(): + return cls.construct(**kwargs) else: return super(InputMixin, cls).__new__(cls) def __init__(self, /, **kwargs): - if _context.declaring: + if _context.active: # Return in order to skip validation of `construct`ed instance return @@ -159,17 +168,15 @@ def _get_as_arguments(self) -> ModelArguments: class OutputMixin(BaseModel): def __new__(cls, **kwargs): - if _context.declaring: + if _context.active: # Intercept the declaration to avoid validation on the templated strings - _context.declaring = False - instance = cls.construct(**kwargs) - _context.declaring = True - return instance + with no_active_context(): + return cls.construct(**kwargs) else: return super(OutputMixin, cls).__new__(cls) def __init__(self, /, **kwargs): - if _context.declaring: + if _context.active: # Return in order to skip validation of `construct`ed instance return diff --git a/src/hera/workflows/script.py b/src/hera/workflows/script.py index 2c817a09d..3bb324c89 100644 --- a/src/hera/workflows/script.py +++ b/src/hera/workflows/script.py @@ -48,7 +48,7 @@ from hera.shared._type_util import get_workflow_annotation, is_subscripted, origin_type_issubclass from hera.shared.serialization import serialize from hera.workflows._context import _context -from hera.workflows._meta_mixins import CallableTemplateMixin +from hera.workflows._meta_mixins import CallableTemplateMixin, HeraBuildObj from hera.workflows._mixins import ( ArgumentsT, ContainerMixin, @@ -755,6 +755,22 @@ def script_wrapper(func: Callable[FuncIns, FuncR]) -> Callable: def task_wrapper(*args, **kwargs) -> Union[FuncR, Step, Task, None]: """Invokes a `Script` object's `__call__` method using the given SubNode (Step or Task) args/kwargs.""" if _context.active: + if len(args) == 1 and isinstance(args[0], (InputV1, InputV2)): + signature = inspect.signature(func) + output_class = signature.return_annotation + _assert_pydantic_io_enabled(output_class) + + arguments = args[0]._get_as_arguments() + arguments_list = [ + *(arguments.artifacts or []), + *(arguments.parameters or []), + ] + + subnode = s.__call__(arguments=arguments_list, **kwargs) + + if subnode: + subnode._build_obj = HeraBuildObj(subnode._subtype, output_class) + return subnode return s.__call__(*args, **kwargs) return func(*args, **kwargs)