diff --git a/src/finch/_files.py b/src/finch/_files.py index b6e8af8b..0d2022ae 100644 --- a/src/finch/_files.py +++ b/src/finch/_files.py @@ -13,12 +13,17 @@ FileContent, RequestFiles, HttpxFileTypes, + Base64FileInput, HttpxFileContent, HttpxRequestFiles, ) from ._utils import is_tuple_t, is_mapping_t, is_sequence_t +def is_base64_file_input(obj: object) -> TypeGuard[Base64FileInput]: + return isinstance(obj, io.IOBase) or isinstance(obj, os.PathLike) + + def is_file_content(obj: object) -> TypeGuard[FileContent]: return ( isinstance(obj, bytes) or isinstance(obj, tuple) or isinstance(obj, io.IOBase) or isinstance(obj, os.PathLike) diff --git a/src/finch/_types.py b/src/finch/_types.py index 1ccf4a5e..037cd3e9 100644 --- a/src/finch/_types.py +++ b/src/finch/_types.py @@ -41,8 +41,10 @@ ProxiesDict = Dict["str | URL", Union[None, str, URL, Proxy]] ProxiesTypes = Union[str, Proxy, ProxiesDict] if TYPE_CHECKING: + Base64FileInput = Union[IO[bytes], PathLike[str]] FileContent = Union[IO[bytes], bytes, PathLike[str]] else: + Base64FileInput = Union[IO[bytes], PathLike] FileContent = Union[IO[bytes], bytes, PathLike] # PathLike is not subscriptable in Python 3.8. FileTypes = Union[ # file (or bytes) diff --git a/src/finch/_utils/_transform.py b/src/finch/_utils/_transform.py index 9c769306..1bd1330c 100644 --- a/src/finch/_utils/_transform.py +++ b/src/finch/_utils/_transform.py @@ -1,9 +1,13 @@ from __future__ import annotations +import io +import base64 +import pathlib from typing import Any, Mapping, TypeVar, cast from datetime import date, datetime from typing_extensions import Literal, get_args, override, get_type_hints +import anyio import pydantic from ._utils import ( @@ -11,6 +15,7 @@ is_mapping, is_iterable, ) +from .._files import is_base64_file_input from ._typing import ( is_list_type, is_union_type, @@ -29,7 +34,7 @@ # TODO: ensure works correctly with forward references in all cases -PropertyFormat = Literal["iso8601", "custom"] +PropertyFormat = Literal["iso8601", "base64", "custom"] class PropertyInfo: @@ -201,6 +206,22 @@ def _format_data(data: object, format_: PropertyFormat, format_template: str | N if format_ == "custom" and format_template is not None: return data.strftime(format_template) + if format_ == "base64" and is_base64_file_input(data): + binary: str | bytes | None = None + + if isinstance(data, pathlib.Path): + binary = data.read_bytes() + elif isinstance(data, io.IOBase): + binary = data.read() + + if isinstance(binary, str): # type: ignore[unreachable] + binary = binary.encode() + + if not isinstance(binary, bytes): + raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}") + + return base64.b64encode(binary).decode("ascii") + return data @@ -323,6 +344,22 @@ async def _async_format_data(data: object, format_: PropertyFormat, format_templ if format_ == "custom" and format_template is not None: return data.strftime(format_template) + if format_ == "base64" and is_base64_file_input(data): + binary: str | bytes | None = None + + if isinstance(data, pathlib.Path): + binary = await anyio.Path(data).read_bytes() + elif isinstance(data, io.IOBase): + binary = data.read() + + if isinstance(binary, str): # type: ignore[unreachable] + binary = binary.encode() + + if not isinstance(binary, bytes): + raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}") + + return base64.b64encode(binary).decode("ascii") + return data diff --git a/tests/sample_file.txt b/tests/sample_file.txt new file mode 100644 index 00000000..af5626b4 --- /dev/null +++ b/tests/sample_file.txt @@ -0,0 +1 @@ +Hello, world! diff --git a/tests/test_transform.py b/tests/test_transform.py index da9a5f15..187d119e 100644 --- a/tests/test_transform.py +++ b/tests/test_transform.py @@ -1,11 +1,14 @@ from __future__ import annotations +import io +import pathlib from typing import Any, List, Union, TypeVar, Iterable, Optional, cast from datetime import date, datetime from typing_extensions import Required, Annotated, TypedDict import pytest +from finch._types import Base64FileInput from finch._utils import ( PropertyInfo, transform as _transform, @@ -17,6 +20,8 @@ _T = TypeVar("_T") +SAMPLE_FILE_PATH = pathlib.Path(__file__).parent.joinpath("sample_file.txt") + async def transform( data: _T, @@ -377,3 +382,27 @@ async def test_iterable_union_str(use_async: bool) -> None: assert cast(Any, await transform(iter([{"foo_baz": "bar"}]), Union[str, Iterable[Baz8]], use_async)) == [ {"fooBaz": "bar"} ] + + +class TypedDictBase64Input(TypedDict): + foo: Annotated[Union[str, Base64FileInput], PropertyInfo(format="base64")] + + +@parametrize +@pytest.mark.asyncio +async def test_base64_file_input(use_async: bool) -> None: + # strings are left as-is + assert await transform({"foo": "bar"}, TypedDictBase64Input, use_async) == {"foo": "bar"} + + # pathlib.Path is automatically converted to base64 + assert await transform({"foo": SAMPLE_FILE_PATH}, TypedDictBase64Input, use_async) == { + "foo": "SGVsbG8sIHdvcmxkIQo=" + } # type: ignore[comparison-overlap] + + # io instances are automatically converted to base64 + assert await transform({"foo": io.StringIO("Hello, world!")}, TypedDictBase64Input, use_async) == { + "foo": "SGVsbG8sIHdvcmxkIQ==" + } # type: ignore[comparison-overlap] + assert await transform({"foo": io.BytesIO(b"Hello, world!")}, TypedDictBase64Input, use_async) == { + "foo": "SGVsbG8sIHdvcmxkIQ==" + } # type: ignore[comparison-overlap]