From 7eedc687876d636dd6c0f1d8d4305a608d7ec8db Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Mon, 16 Dec 2024 11:13:30 -0800 Subject: [PATCH] Backport changes for 14.1 (#3006) Signed-off-by: taieeuu Signed-off-by: Thomas Newton Signed-off-by: Samhita Alla Signed-off-by: Yee Hing Tong Signed-off-by: Eduardo Apolinario --- .github/workflows/pythonbuild.yml | 3 ++ flytekit/core/data_persistence.py | 39 ++++++---------- flytekit/remote/remote.py | 20 ++++++--- plugins/flytekit-dbt/setup.py | 5 +-- .../flytekitplugins/inference/ollama/serve.py | 2 +- .../flytekitplugins/polars/sd_transformers.py | 3 +- .../tests/test_polars_plugin_sd.py | 27 ++++++++++- .../unit/core/test_data_persistence.py | 18 +++++++- tests/flytekit/unit/remote/test_remote.py | 45 ++++++++++++++++++- 9 files changed, 121 insertions(+), 41 deletions(-) diff --git a/.github/workflows/pythonbuild.yml b/.github/workflows/pythonbuild.yml index 7dda4f5588..d8347a798d 100644 --- a/.github/workflows/pythonbuild.yml +++ b/.github/workflows/pythonbuild.yml @@ -299,6 +299,9 @@ jobs: FLYTEKIT_IMAGE: localhost:30000/flytekit:dev FLYTEKIT_CI: 1 PYTEST_OPTS: -n2 + AWS_ENDPOINT_URL: 'http://localhost:30002' + AWS_ACCESS_KEY_ID: minio + AWS_SECRET_ACCESS_KEY: miniostorage run: | make ${{ matrix.makefile-cmd }} - name: Codecov diff --git a/flytekit/core/data_persistence.py b/flytekit/core/data_persistence.py index 7035147016..0640bc2eb5 100644 --- a/flytekit/core/data_persistence.py +++ b/flytekit/core/data_persistence.py @@ -423,47 +423,34 @@ async def async_put_raw_data( r = await self._put(from_path, to_path, **kwargs) return r or to_path + # See https://github.com/fsspec/s3fs/issues/871 for more background and pending work on the fsspec side to + # support effectively async open(). For now these use-cases below will revert to sync calls. # raw bytes if isinstance(lpath, bytes): - fs = await self.get_async_filesystem_for_path(to_path) - if isinstance(fs, AsyncFileSystem): - async with fs.open_async(to_path, "wb", **kwargs) as s: - s.write(lpath) - else: - with fs.open(to_path, "wb", **kwargs) as s: - s.write(lpath) - + fs = self.get_filesystem_for_path(to_path) + with fs.open(to_path, "wb", **kwargs) as s: + s.write(lpath) return to_path # If lpath is a buffered reader of some kind if isinstance(lpath, io.BufferedReader) or isinstance(lpath, io.BytesIO): if not lpath.readable(): raise FlyteAssertion("Buffered reader must be readable") - fs = await self.get_async_filesystem_for_path(to_path) + fs = self.get_filesystem_for_path(to_path) lpath.seek(0) - if isinstance(fs, AsyncFileSystem): - async with fs.open_async(to_path, "wb", **kwargs) as s: - while data := lpath.read(read_chunk_size_bytes): - s.write(data) - else: - with fs.open(to_path, "wb", **kwargs) as s: - while data := lpath.read(read_chunk_size_bytes): - s.write(data) + with fs.open(to_path, "wb", **kwargs) as s: + while data := lpath.read(read_chunk_size_bytes): + s.write(data) return to_path if isinstance(lpath, io.StringIO): if not lpath.readable(): raise FlyteAssertion("Buffered reader must be readable") - fs = await self.get_async_filesystem_for_path(to_path) + fs = self.get_filesystem_for_path(to_path) lpath.seek(0) - if isinstance(fs, AsyncFileSystem): - async with fs.open_async(to_path, "wb", **kwargs) as s: - while data_str := lpath.read(read_chunk_size_bytes): - s.write(data_str.encode(encoding)) - else: - with fs.open(to_path, "wb", **kwargs) as s: - while data_str := lpath.read(read_chunk_size_bytes): - s.write(data_str.encode(encoding)) + with fs.open(to_path, "wb", **kwargs) as s: + while data_str := lpath.read(read_chunk_size_bytes): + s.write(data_str.encode(encoding)) return to_path raise FlyteAssertion(f"Unsupported lpath type {type(lpath)}") diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index 549f0045d3..190143334b 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -524,7 +524,9 @@ def get_launch_plan_from_then_node( if node.branch_node: get_launch_plan_from_branch(node.branch_node, node_launch_plans) - return FlyteWorkflow.promote_from_closure(compiled_wf, node_launch_plans) + flyte_workflow = FlyteWorkflow.promote_from_closure(compiled_wf, node_launch_plans) + flyte_workflow.template._id = workflow_id + return flyte_workflow def _upgrade_launchplan(self, lp: launch_plan_models.LaunchPlan) -> FlyteLaunchPlan: """ @@ -863,13 +865,17 @@ async def _serialize_and_register( cp_task_entity_map = OrderedDict(filter(lambda x: isinstance(x[1], task_models.TaskSpec), m.items())) tasks = [] loop = asyncio.get_running_loop() - for entity, cp_entity in cp_task_entity_map.items(): + for task_entity, cp_entity in cp_task_entity_map.items(): tasks.append( loop.run_in_executor( None, - functools.partial(self.raw_register, cp_entity, serialization_settings, version, og_entity=entity), + functools.partial( + self.raw_register, cp_entity, serialization_settings, version, og_entity=task_entity + ), ) ) + if task_entity == entity: + registered_entity = await tasks[-1] identifiers_or_exceptions = [] identifiers_or_exceptions.extend(await asyncio.gather(*tasks, return_exceptions=True)) @@ -882,15 +888,17 @@ async def _serialize_and_register( raise ie # serial register cp_other_entities = OrderedDict(filter(lambda x: not isinstance(x[1], task_models.TaskSpec), m.items())) - for entity, cp_entity in cp_other_entities.items(): + for non_task_entity, cp_entity in cp_other_entities.items(): try: identifiers_or_exceptions.append( - self.raw_register(cp_entity, serialization_settings, version, og_entity=entity) + self.raw_register(cp_entity, serialization_settings, version, og_entity=non_task_entity) ) except RegistrationSkipped as e: logger.info(f"Skipping registration... {e}") continue - return identifiers_or_exceptions[-1] + if non_task_entity == entity: + registered_entity = identifiers_or_exceptions[-1] + return registered_entity def register_task( self, diff --git a/plugins/flytekit-dbt/setup.py b/plugins/flytekit-dbt/setup.py index 08899d42ce..f683449fea 100644 --- a/plugins/flytekit-dbt/setup.py +++ b/plugins/flytekit-dbt/setup.py @@ -4,10 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = [ - "flytekit>=1.3.0b2", - "dbt-core<1.8.0", -] +plugin_requires = ["flytekit>=1.3.0b2", "dbt-core>=1.6.0,<1.8.0", "networkx>=2.5"] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-inference/flytekitplugins/inference/ollama/serve.py b/plugins/flytekit-inference/flytekitplugins/inference/ollama/serve.py index 81e68618ca..c8f93c585e 100644 --- a/plugins/flytekit-inference/flytekitplugins/inference/ollama/serve.py +++ b/plugins/flytekit-inference/flytekitplugins/inference/ollama/serve.py @@ -164,7 +164,7 @@ def __init__(self, *args, **kwargs): name=container_name, image="python:3.11-slim", command=["/bin/sh", "-c"], - args=[f"pip install requests && pip install ollama && {command}"], + args=[f"pip install requests && pip install ollama==0.3.3 && {command}"], resources=V1ResourceRequirements( requests={ "cpu": self._model_cpu, diff --git a/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py b/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py index 474901544d..e6359641ca 100644 --- a/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py +++ b/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py @@ -69,9 +69,10 @@ def encode( df.to_parquet(output_bytes) if structured_dataset.uri is not None: + output_bytes.seek(0) fs = ctx.file_access.get_filesystem_for_path(path=structured_dataset.uri) with fs.open(structured_dataset.uri, "wb") as s: - s.write(output_bytes) + s.write(output_bytes.read()) output_uri = structured_dataset.uri else: remote_fn = "00000" # 00000 is our default unnamed parquet filename diff --git a/plugins/flytekit-polars/tests/test_polars_plugin_sd.py b/plugins/flytekit-polars/tests/test_polars_plugin_sd.py index c2d4a39be7..9acae1c274 100644 --- a/plugins/flytekit-polars/tests/test_polars_plugin_sd.py +++ b/plugins/flytekit-polars/tests/test_polars_plugin_sd.py @@ -5,7 +5,7 @@ import pytest from flytekitplugins.polars.sd_transformers import PolarsDataFrameRenderer from typing_extensions import Annotated -from packaging import version +import numpy as np from polars.testing import assert_frame_equal from flytekit import kwtypes, task, workflow @@ -134,3 +134,28 @@ def consume_sd_return_sd(sd: StructuredDataset) -> StructuredDataset: opened_sd = opened_sd.collect() assert_frame_equal(opened_sd, polars_df) + + +def test_with_uri(): + temp_file = tempfile.mktemp() + + @task + def random_dataframe(num_rows: int) -> StructuredDataset: + feature_1_list = np.random.randint(low=100, high=999, size=(num_rows,)) + feature_2_list = np.random.normal(loc=0, scale=1, size=(num_rows, )) + pl_df = pl.DataFrame({'protein_length': feature_1_list, + 'protein_feature': feature_2_list}) + sd = StructuredDataset(dataframe=pl_df, uri=temp_file) + return sd + + @task + def consume(df: pd.DataFrame): + print(df.head(5)) + print(df.describe()) + + @workflow + def my_wf(num_rows: int): + pl = random_dataframe(num_rows=num_rows) + consume(pl) + + my_wf(num_rows=100) diff --git a/tests/flytekit/unit/core/test_data_persistence.py b/tests/flytekit/unit/core/test_data_persistence.py index d992ed1fa5..116717b92d 100644 --- a/tests/flytekit/unit/core/test_data_persistence.py +++ b/tests/flytekit/unit/core/test_data_persistence.py @@ -1,16 +1,17 @@ import io import os -import fsspec import pathlib import random import string import sys import tempfile +import fsspec import mock import pytest from azure.identity import ClientSecretCredential, DefaultAzureCredential +from flytekit.configuration import Config from flytekit.core.data_persistence import FileAccessProvider from flytekit.core.local_fsspec import FlyteLocalFileSystem @@ -207,3 +208,18 @@ def __init__(self, *args, **kwargs): fp = FileAccessProvider("/tmp", "s3://my-bucket") fp.get_filesystem("testgetfs", test_arg="test_arg") + + +@pytest.mark.sandbox_test +def test_put_raw_data_bytes(): + dc = Config.for_sandbox().data_config + raw_output = f"s3://my-s3-bucket/" + provider = FileAccessProvider(local_sandbox_dir="/tmp/unittest", raw_output_prefix=raw_output, data_config=dc) + prefix = provider.get_random_string() + provider.put_raw_data(lpath=b"hello", upload_prefix=prefix, file_name="hello_bytes") + provider.put_raw_data(lpath=io.BytesIO(b"hello"), upload_prefix=prefix, file_name="hello_bytes_io") + provider.put_raw_data(lpath=io.StringIO("hello"), upload_prefix=prefix, file_name="hello_string_io") + + fs = provider.get_filesystem("s3") + listing = fs.ls(f"{raw_output}{prefix}/") + assert len(listing) == 3 diff --git a/tests/flytekit/unit/remote/test_remote.py b/tests/flytekit/unit/remote/test_remote.py index 85b51ca2fd..98b1139ac6 100644 --- a/tests/flytekit/unit/remote/test_remote.py +++ b/tests/flytekit/unit/remote/test_remote.py @@ -35,7 +35,7 @@ from flytekit.models.core.identifier import Identifier, ResourceType, WorkflowExecutionIdentifier from flytekit.models.execution import Execution from flytekit.models.task import Task -from flytekit.remote import FlyteTask +from flytekit.remote import FlyteTask, FlyteWorkflow from flytekit.remote.lazy_entity import LazyEntity from flytekit.remote.remote import FlyteRemote, _get_git_repo_url, _get_pickled_target_dict from flytekit.tools.translator import Options, get_serializable, get_serializable_launch_plan @@ -811,3 +811,46 @@ def wf() -> int: # the second one should rr.register_launch_plan(lp2, version="1", serialization_settings=ss) mock_client.update_launch_plan.assert_called() + + +@mock.patch("flytekit.remote.remote.FlyteRemote.client") +def test_register_task_with_node_dependency_hints(mock_client): + @task + def task0(): + return None + + @workflow + def workflow0(): + return task0() + + @dynamic(node_dependency_hints=[workflow0]) + def dynamic0(): + return workflow0() + + @workflow + def workflow1(): + return dynamic0() + + rr = FlyteRemote( + Config.for_sandbox(), + default_project="flytesnacks", + default_domain="development", + ) + + ss = SerializationSettings( + image_config=ImageConfig.from_images("docker.io/abc:latest"), + version="dummy_version", + ) + + registered_task = rr.register_task(dynamic0, ss) + assert isinstance(registered_task, FlyteTask) + assert registered_task.id.resource_type == ResourceType.TASK + assert registered_task.id.project == "flytesnacks" + assert registered_task.id.domain == "development" + # When running via `make unit_test` there is a `__-channelexec__` prefix added to the name. + assert registered_task.id.name.endswith("tests.flytekit.unit.remote.test_remote.dynamic0") + assert registered_task.id.version == "dummy_version" + + registered_workflow = rr.register_workflow(workflow1, ss) + assert isinstance(registered_workflow, FlyteWorkflow) + assert registered_workflow.id == Identifier(ResourceType.WORKFLOW, "flytesnacks", "development", "tests.flytekit.unit.remote.test_remote.workflow1", "dummy_version")