Skip to content

Commit

Permalink
Fix async coroutine limit not respected and add s3/gcs chunk size (#3080
Browse files Browse the repository at this point in the history
) (#3083)

Signed-off-by: Yee Hing Tong <[email protected]>
  • Loading branch information
wild-endeavor authored Jan 24, 2025
1 parent 169843a commit 29d024c
Show file tree
Hide file tree
Showing 4 changed files with 161 additions and 18 deletions.
29 changes: 29 additions & 0 deletions flytekit/core/data_persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@

Uploadable = typing.Union[str, os.PathLike, pathlib.Path, bytes, io.BufferedReader, io.BytesIO, io.StringIO]

# This is the default chunk size flytekit will use for writing to S3 and GCS. This is set to 25MB by default and is
# configurable by the user if needed. This is used when put() is called on filesystems.
_WRITE_SIZE_CHUNK_BYTES = int(os.environ.get("_F_P_WRITE_CHUNK_SIZE", "26214400")) # 25 * 2**20


def s3_setup_args(s3_cfg: configuration.S3Config, anonymous: bool = False) -> Dict[str, Any]:
kwargs: Dict[str, Any] = {
Expand Down Expand Up @@ -108,6 +112,27 @@ def get_fsspec_storage_options(
return {}


def get_additional_fsspec_call_kwargs(protocol: typing.Union[str, tuple], method_name: str) -> Dict[str, Any]:
"""
These are different from the setup args functions defined above. Those kwargs are applied when asking fsspec
to create the filesystem. These kwargs returned here are for when the filesystem's methods are invoked.
:param protocol: s3, gcs, etc.
:param method_name: Pass in the __name__ of the fsspec.filesystem function. _'s will be ignored.
"""
kwargs = {}
method_name = method_name.replace("_", "")
if isinstance(protocol, tuple):
protocol = protocol[0]

# For s3fs and gcsfs, we feel the default chunksize of 50MB is too big.
# Re-evaluate these kwargs when we move off of s3fs to obstore.
if method_name == "put" and protocol in ["s3", "gs"]:
kwargs["chunksize"] = _WRITE_SIZE_CHUNK_BYTES

return kwargs


@decorator
def retry_request(func, *args, **kwargs):
# TODO: Remove this method once s3fs has a new release. https://github.com/fsspec/s3fs/pull/865
Expand Down Expand Up @@ -353,6 +378,10 @@ async def _put(self, from_path: str, to_path: str, recursive: bool = False, **kw
if "metadata" not in kwargs:
kwargs["metadata"] = {}
kwargs["metadata"].update(self._execution_metadata)

additional_kwargs = get_additional_fsspec_call_kwargs(file_system.protocol, file_system.put.__name__)
kwargs.update(additional_kwargs)

if isinstance(file_system, AsyncFileSystem):
dst = await file_system._put(from_path, to_path, recursive=recursive, **kwargs) # pylint: disable=W0212
else:
Expand Down
31 changes: 14 additions & 17 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@
DEFINITIONS = "definitions"
TITLE = "title"

_TYPE_ENGINE_COROS_BATCH_SIZE = int(os.environ.get("_F_TE_MAX_COROS", "10"))


# In Mashumaro, the default encoder uses strict_map_key=False, while the default decoder uses strict_map_key=True.
# This is relevant for cases like Dict[int, str].
Expand Down Expand Up @@ -1678,10 +1680,9 @@ async def async_to_literal(
raise TypeTransformerFailedError("Expected a list")

t = self.get_sub_type(python_type)
lit_list = [
asyncio.create_task(TypeEngine.async_to_literal(ctx, x, t, expected.collection_type)) for x in python_val
]
lit_list = await _run_coros_in_chunks(lit_list)
lit_list = [TypeEngine.async_to_literal(ctx, x, t, expected.collection_type) for x in python_val]

lit_list = await _run_coros_in_chunks(lit_list, batch_size=_TYPE_ENGINE_COROS_BATCH_SIZE)

return Literal(collection=LiteralCollection(literals=lit_list))

Expand All @@ -1703,7 +1704,7 @@ async def async_to_python_value( # type: ignore

st = self.get_sub_type(expected_python_type)
result = [TypeEngine.async_to_python_value(ctx, x, st) for x in lits]
result = await _run_coros_in_chunks(result)
result = await _run_coros_in_chunks(result, batch_size=_TYPE_ENGINE_COROS_BATCH_SIZE)
return result # type: ignore # should be a list, thinks its a tuple

def guess_python_type(self, literal_type: LiteralType) -> list: # type: ignore
Expand Down Expand Up @@ -2150,13 +2151,10 @@ async def async_to_literal(
else:
_, v_type = self.extract_types_or_metadata(python_type)

lit_map[k] = asyncio.create_task(
TypeEngine.async_to_literal(ctx, v, cast(type, v_type), expected.map_value_type)
)

await _run_coros_in_chunks([c for c in lit_map.values()])
for k, v in lit_map.items():
lit_map[k] = v.result()
lit_map[k] = TypeEngine.async_to_literal(ctx, v, cast(type, v_type), expected.map_value_type)
vals = await _run_coros_in_chunks([c for c in lit_map.values()], batch_size=_TYPE_ENGINE_COROS_BATCH_SIZE)
for idx, k in zip(range(len(vals)), lit_map.keys()):
lit_map[k] = vals[idx]

return Literal(map=LiteralMap(literals=lit_map))

Expand All @@ -2177,12 +2175,11 @@ async def async_to_python_value(self, ctx: FlyteContext, lv: Literal, expected_p
raise TypeError("TypeMismatch. Destination dictionary does not accept 'str' key")
py_map = {}
for k, v in lv.map.literals.items():
fut = asyncio.create_task(TypeEngine.async_to_python_value(ctx, v, cast(Type, tp[1])))
py_map[k] = fut
py_map[k] = TypeEngine.async_to_python_value(ctx, v, cast(Type, tp[1]))

await _run_coros_in_chunks([c for c in py_map.values()])
for k, v in py_map.items():
py_map[k] = v.result()
vals = await _run_coros_in_chunks([c for c in py_map.values()], batch_size=_TYPE_ENGINE_COROS_BATCH_SIZE)
for idx, k in zip(range(len(vals)), py_map.keys()):
py_map[k] = vals[idx]

return py_map

Expand Down
34 changes: 33 additions & 1 deletion tests/flytekit/unit/core/test_data_persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@
import mock
import pytest
from azure.identity import ClientSecretCredential, DefaultAzureCredential
from mock import AsyncMock

from flytekit.configuration import Config
from flytekit.core.data_persistence import FileAccessProvider
from flytekit.core.data_persistence import FileAccessProvider, get_additional_fsspec_call_kwargs
from flytekit.core.local_fsspec import FlyteLocalFileSystem


Expand Down Expand Up @@ -210,6 +211,37 @@ def __init__(self, *args, **kwargs):
fp.get_filesystem("testgetfs", test_arg="test_arg")


def test_get_additional_fsspec_call_kwargs():
with mock.patch("flytekit.core.data_persistence._WRITE_SIZE_CHUNK_BYTES", 12345):
kwargs = get_additional_fsspec_call_kwargs(("s3", "s3a"), "put")
assert kwargs == {"chunksize": 12345}

kwargs = get_additional_fsspec_call_kwargs("s3", "_put")
assert kwargs == {"chunksize": 12345}

kwargs = get_additional_fsspec_call_kwargs("s3", "get")
assert kwargs == {}


@pytest.mark.asyncio
@mock.patch("flytekit.core.data_persistence.FileAccessProvider.get_async_filesystem_for_path", new_callable=AsyncMock)
@mock.patch("flytekit.core.data_persistence.get_additional_fsspec_call_kwargs")
async def test_chunk_size(mock_call_kwargs, mock_get_fs):
mock_call_kwargs.return_value = {"chunksize": 1234}
mock_fs = mock.MagicMock()
mock_get_fs.return_value = mock_fs

mock_fs.protocol = ("s3", "s3a")
fp = FileAccessProvider("/tmp", "s3://container/path/within/container")

def put(*args, **kwargs):
assert "chunksize" in kwargs

mock_fs.put = put
upload_location = await fp._put("/tmp/foo", "s3://bar")
assert upload_location == "s3://bar"


@pytest.mark.sandbox_test
def test_put_raw_data_bytes():
dc = Config.for_sandbox().data_config
Expand Down
85 changes: 85 additions & 0 deletions tests/flytekit/unit/core/test_list.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import asyncio
import typing

import mock
import pytest

from flytekit.core.context_manager import FlyteContext
from flytekit.core.type_engine import (
AsyncTypeTransformer,
TypeEngine,
)
from flytekit.models.literals import (
Literal,
Primitive,
Scalar,
)
from flytekit.models.types import LiteralType, SimpleType


class MyInt:
def __init__(self, x: int):
self.val = x

def __eq__(self, other):
if not isinstance(other, MyInt):
return False
return other.val == self.val


class MyIntAsyncTransformer(AsyncTypeTransformer[MyInt]):
def __init__(self):
super().__init__(name="MyAsyncInt", t=MyInt)
self.my_lock = asyncio.Lock()
self.my_count = 0

def assert_type(self, t, v):
return

def get_literal_type(self, t: typing.Type[MyInt]) -> LiteralType:
return LiteralType(simple=SimpleType.INTEGER)

async def async_to_literal(
self,
ctx: FlyteContext,
python_val: MyInt,
python_type: typing.Type[MyInt],
expected: LiteralType,
) -> Literal:
async with self.my_lock:
self.my_count += 1
if self.my_count > 2:
raise ValueError("coroutine count exceeded")
await asyncio.sleep(0.1)
lit = Literal(scalar=Scalar(primitive=Primitive(integer=python_val.val)))

async with self.my_lock:
self.my_count -= 1

return lit

async def async_to_python_value(
self, ctx: FlyteContext, lv: Literal, expected_python_type: typing.Type[MyInt]
) -> MyInt:
return MyInt(lv.scalar.primitive.integer)

def guess_python_type(self, literal_type: LiteralType) -> typing.Type[MyInt]:
return MyInt


@pytest.mark.asyncio
async def test_coroutine_batching_of_list_transformer():
TypeEngine.register(MyIntAsyncTransformer())

lt = LiteralType(simple=SimpleType.INTEGER)
python_val = [MyInt(10), MyInt(11), MyInt(12), MyInt(13), MyInt(14)]
ctx = FlyteContext.current_context()

with mock.patch("flytekit.core.type_engine._TYPE_ENGINE_COROS_BATCH_SIZE", 2):
TypeEngine.to_literal(ctx, python_val, typing.List[MyInt], lt)

with mock.patch("flytekit.core.type_engine._TYPE_ENGINE_COROS_BATCH_SIZE", 5):
with pytest.raises(ValueError):
TypeEngine.to_literal(ctx, python_val, typing.List[MyInt], lt)

del TypeEngine._REGISTRY[MyInt]

0 comments on commit 29d024c

Please sign in to comment.