
Custom serialization for non-user types and non-serializable types for Hera runner (Parameter/Artifact inputs and outputs) #1166

Open
elliotgunton opened this issue Aug 21, 2024 · 2 comments
Labels
type:enhancement A general enhancement

Comments

Collaborator

elliotgunton commented Aug 21, 2024

Is your feature request related to a problem? Please describe.
Tried to use pandas.DataFrame fields in an Output and got this error:

Traceback (most recent call last):
  File "/usr/local/lib/python3.12/site-packages/hera/workflows/_runner/util.py", line 222, in _runner
    output = _save_annotated_return_outputs(function(**kwargs), output_annotations)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/hera/workflows/_runner/script_annotations_util.py", line 250, in _save_annotated_return_outputs
    _write_to_path(path, value)
  File "/usr/local/lib/python3.12/site-packages/hera/workflows/_runner/script_annotations_util.py", line 326, in _write_to_path
    output_string = serialize(output_value)
                    ^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/hera/shared/serialization.py", line 51, in serialize
    if value == MISSING:
       ^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/pandas/core/generic.py", line 1577, in __nonzero__
    raise ValueError(
ValueError: The truth value of a DataFrame is ambiguous. Use a.empty, a.bool(), a.item(), a.any() or a.all().
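Why this fails: pandas overloads == to compare element-wise, so `value == MISSING` returns a DataFrame rather than a bool, and the `if` then calls bool() on it, which pandas refuses. The sketch below reproduces the pattern without pandas, using a hypothetical stand-in class (ElementwiseFrame) and a local MISSING sentinel (not Hera's object). It shows that an identity check (`is`), assuming the sentinel is a singleton, never calls __eq__ or __bool__ on the value.

```python
class ElementwiseFrame:
    """Hypothetical stand-in mimicking how pandas overloads == and bool()."""

    def __eq__(self, other):
        # pandas returns an element-wise result object instead of a bool
        return ElementwiseFrame()

    def __bool__(self):
        raise ValueError("The truth value of an ElementwiseFrame is ambiguous.")


MISSING = object()  # local sentinel, assumed to be a singleton like Hera's


def is_missing_eq(value) -> bool:
    # Mirrors `if value == MISSING:`; raises for element-wise types
    return bool(value == MISSING)


def is_missing_identity(value) -> bool:
    # Identity comparison never touches the value's __eq__ or __bool__
    return value is MISSING


frame = ElementwiseFrame()
print(is_missing_identity(frame))    # False
print(is_missing_identity(MISSING))  # True
try:
    is_missing_eq(frame)
except ValueError as err:
    print("== check failed:", err)
```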

Python:

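from pathlib import Path
from typing import Annotated

import pandas as pd
from sklearn.model_selection import train_test_split

from hera.shared import global_config
from hera.workflows import Output, S3Artifact, script

global_config.experimental_features["script_pydantic_io"] = True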


class Datasets(Output):
    X_train: pd.DataFrame
    X_test: pd.DataFrame
    y_train: pd.Series
    y_test: pd.Series

    class Config:
        arbitrary_types_allowed = True


# Load dataset
@script(constructor="runner")
def load_and_split_dataset(
    dataset_path: Annotated[
        Path,
        S3Artifact(...),
    ],
) -> Datasets:
    data = pd.read_csv(dataset_path)

    # Split into features and target
    X = data.drop("Outcome", axis=1)
    y = data["Outcome"]

    # Train-test split
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )

    return Datasets(
        X_train=X_train,
        X_test=X_test,
        y_train=y_train,
        y_test=y_test,
    )

pandas DataFrames have a to_json method that would make this easy, but I have no way to tell the serialize function in hera.shared.serialization how to handle DataFrames. I also can't change the class's code, hence "non-user" type (though I could subclass it?).

Describe the solution you'd like

An easy way to plug in the "how" for serializing custom types in the runner, e.g. as part of the type annotation, or a global setter such as global_config.serializer = my_serializer, or maybe in the RunnerScriptConstructor? (Needs some more thought)
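One possible shape for a global hook, sketched with the standard library's functools.singledispatch: a default JSON serializer that users extend per type without modifying the type itself. The serialize function and the Table class are hypothetical stand-ins (Table plays the role of DataFrame, with its own to_json method); this is not Hera's current API.

```python
import json
from functools import singledispatch


@singledispatch
def serialize(value) -> str:
    # Default: fall back to JSON for plain types
    return json.dumps(value)


@serialize.register
def _(value: str) -> str:
    # Strings pass through unchanged
    return value


class Table:
    """Hypothetical third-party type the user cannot modify."""

    def __init__(self, rows):
        self.rows = rows

    def to_json(self) -> str:
        return json.dumps({"rows": self.rows})


# User code plugs in the "how" without touching the Table class
@serialize.register
def _(value: Table) -> str:
    return value.to_json()


print(serialize(Table([[1, 2]])))  # {"rows": [[1, 2]]}
```

A registry like this would also avoid the `value == MISSING` comparison on arbitrary objects, since dispatch happens on the type rather than on the value.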

Describe alternatives you've considered

  • Just use strs and use the DataFrame.to_json method
  • Create a subclass and associated runtime files of RunnerScriptConstructor so I can use my own serialize function


elliotgunton added the type:enhancement label Aug 21, 2024
Collaborator Author

elliotgunton commented Aug 22, 2024

Also related, and a complete blocker to using the new decorators: I have no way to output a bytes Artifact from a template.

Using

class ModelTrainingInput(Input):
    X_train: Annotated[list, Artifact(name="X_train", loader=ArtifactLoader.json)]
    y_train: Annotated[dict, Artifact(name="y_train", loader=ArtifactLoader.json)]
    model: Annotated[Path, Artifact(name="model", output=True)]

gets the following error when building the workflow:

Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/Users/elliot/projects/ds-blog/ds_blog/__main__.py", line 14, in <module>
    from ds_blog.workflow import w
  File "/Users/elliot/projects/ds-blog/ds_blog/workflow.py", line 138, in <module>
    @w.dag()
     ^^^^^^^
  File "/Users/elliot/Library/Caches/pypoetry/virtualenvs/ds-blog-IWZkzs9u-py3.12/lib/python3.12/site-packages/hera/workflows/_meta_mixins.py", line 826, in decorator
    func_return = func(input_obj)
                  ^^^^^^^^^^^^^^^
  File "/Users/elliot/projects/ds-blog/ds_blog/workflow.py", line 144, in run_training
    model_training(
  File "/Users/elliot/Library/Caches/pypoetry/virtualenvs/ds-blog-IWZkzs9u-py3.12/lib/python3.12/site-packages/hera/workflows/_meta_mixins.py", line 670, in script_call_wrapper
    return self._create_subnode(subnode_name, func, script_template, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/elliot/Library/Caches/pypoetry/virtualenvs/ds-blog-IWZkzs9u-py3.12/lib/python3.12/site-packages/hera/workflows/_meta_mixins.py", line 550, in _create_subnode
    subnode_args = args[0]._get_as_arguments()
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/elliot/Library/Caches/pypoetry/virtualenvs/ds-blog-IWZkzs9u-py3.12/lib/python3.12/site-packages/hera/workflows/io/_io_mixins.py", line 152, in _get_as_arguments
    templated_value = serialize(self_dict[field])
                                ~~~~~~~~~^^^^^^^
KeyError: 'model'

And using

class ModelTrainingOutput(Output):
    model: Annotated[bytes, Artifact(name="model", archive=NoneArchiveStrategy())]

@w.script()
def model_training(model_training_input: ModelTrainingInput) -> ModelTrainingOutput:
    X_train = np.array(model_training_input.X_train)
    y_train = pd.Series(model_training_input.y_train)
    model = LogisticRegression(random_state=42)
    model.fit(X_train, y_train)
    return ModelTrainingOutput(model={"model": pickle.dumps(model)})

gets the following error when running on the cluster:

File "/usr/local/lib/python3.12/site-packages/hera/workflows/_runner/util.py", line 222, in _runner
    output = _save_annotated_return_outputs(function(**kwargs), output_annotations)
                                            ^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/hera/workflows/_runner/util.py", line 59, in inner
    return f(**filtered_kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/pydantic/validate_call_decorator.py", line 60, in wrapper_function
    return validate_call_wrapper(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/pydantic/_internal/_validate_call.py", line 96, in __call__
    res = self.__pydantic_validator__.validate_python(pydantic_core.ArgsKwargs(args, kwargs))
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/app/ds_blog/workflow.py", line 134, in model_training
    return ModelTrainingOutput(model=json.dumps(pickle.dumps(model)))
                                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/json/__init__.py", line 231, in dumps
    return _default_encoder.encode(obj)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/json/encoder.py", line 200, in encode
    chunks = self.iterencode(o, _one_shot=True)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/json/encoder.py", line 258, in iterencode
    return _iterencode(o, 0)
           ^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/json/encoder.py", line 180, in default
    raise TypeError(f'Object of type {o.__class__.__name__} '
TypeError: Object of type bytes is not JSON serializable
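json.dumps only accepts str, numbers, bools, None, lists and dicts, so raw bytes such as a pickled model have to be turned into text before they can travel as JSON. A minimal sketch, assuming base64 text is an acceptable encoding; bytes_to_json and bytes_from_json are hypothetical helpers, not something Hera provides, and the dict stands in for a fitted estimator.

```python
import base64
import json
import pickle


def bytes_to_json(payload: bytes) -> str:
    # Encode arbitrary bytes as ASCII text so json.dumps accepts them
    return json.dumps({"b64": base64.b64encode(payload).decode("ascii")})


def bytes_from_json(text: str) -> bytes:
    # Reverse of bytes_to_json
    return base64.b64decode(json.loads(text)["b64"])


model = {"coef": [0.5, -1.25]}  # stand-in for a fitted estimator
raw = pickle.dumps(model)

try:
    json.dumps(raw)
except TypeError as err:
    print(err)  # Object of type bytes is not JSON serializable

restored = pickle.loads(bytes_from_json(bytes_to_json(raw)))
print(restored == model)  # True
```

base64 inflates the payload by about a third, so for large binaries a raw-bytes artifact file is usually the better fit.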

The workaround is to use the old syntax, with an "output" artifact in the function inputs, i.e.:

@script(constructor="runner")
def model_training(
    X_train: Annotated[list, Artifact(name="X_train", loader=ArtifactLoader.json)],
    y_train: Annotated[dict, Artifact(name="y_train", loader=ArtifactLoader.json)],
    model_path: Annotated[Path, Artifact(name="model", archive=NoneArchiveStrategy(), output=True)],
):
    ...  # train the model, then write it to model_path

And doing

    model_path.write_bytes(pickle.dumps(model))
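A downstream step can load that artifact by reversing the write. A stdlib-only sketch of the round trip, assuming a temporary directory in place of the artifact path that Argo mounts and a dict in place of the fitted LogisticRegression; save_model and load_model are hypothetical names.

```python
import pickle
import tempfile
from pathlib import Path


def save_model(model, model_path: Path) -> None:
    # Producer step: write raw bytes, no JSON involved
    model_path.write_bytes(pickle.dumps(model))


def load_model(model_path: Path):
    # Consumer step: read the bytes back and unpickle
    return pickle.loads(model_path.read_bytes())


with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "model"
    save_model({"coef": [0.5, -1.25]}, path)
    print(load_model(path))  # {'coef': [0.5, -1.25]}
```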

elliotgunton changed the title Custom serialization for non-user types for Hera runner Custom serialization for non-user types and non-serializable for Hera runner Aug 22, 2024
elliotgunton changed the title Custom serialization for non-user types and non-serializable for Hera runner Custom serialization for non-user types and non-serializable types for Hera runner Aug 22, 2024
elliotgunton changed the title Custom serialization for non-user types and non-serializable types for Hera runner Custom serialization for non-user types and non-serializable types for Hera runner (Parameter/Artifact inputs and outputs) Oct 1, 2024
Collaborator Author

An easy way to plug in the "how" for serializing custom types in the runner

Good idea from #903: if the loader is a pydantic BaseModel type, we can just use its parse_raw function:

def fan_in(*, responses: Annotated[list[Magic], Parameter(loader=Magic)]): ...

Otherwise the loader could be any Callable[[str], X] where X is the user's class.
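The dispatch described above could look roughly like this for a single value. load_value is a hypothetical helper, not Hera code; to stay dependency-free it duck-types the pydantic check (model_validate_json is the pydantic v2 classmethod, parse_raw the v1 one, deprecated in v2), and Magic is a minimal stand-in exposing only model_validate_json.

```python
import json
from typing import Any, Callable, Union


def load_value(raw: str, loader: Union[type, Callable[[str], Any]]) -> Any:
    # Pydantic v2 models expose model_validate_json
    if isinstance(loader, type) and hasattr(loader, "model_validate_json"):
        return loader.model_validate_json(raw)
    # Pydantic v1 models expose parse_raw
    if isinstance(loader, type) and hasattr(loader, "parse_raw"):
        return loader.parse_raw(raw)
    # Otherwise treat the loader as any Callable[[str], X]
    return loader(raw)


class Magic:
    """Hypothetical stand-in for a pydantic model."""

    def __init__(self, a: int):
        self.a = a

    @classmethod
    def model_validate_json(cls, raw: str) -> "Magic":
        return cls(**json.loads(raw))


print(load_value('{"a": 3}', Magic).a)  # 3
print(load_value("[1, 2]", json.loads))  # [1, 2]
```

Ordering the checks this way means plain classes such as int, which have neither method, still work as Callable[[str], X] loaders.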
