Skip to content

Commit

Permalink
fix: Response compatibility issue
Browse files Browse the repository at this point in the history
  • Loading branch information
xianml committed Oct 24, 2023
1 parent c272c64 commit ed2417a
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 9 deletions.
29 changes: 20 additions & 9 deletions src/bentoml/_internal/io_descriptors/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import typing as t

from starlette.requests import Request
from starlette.responses import Response
from starlette.responses import StreamingResponse

from ..service.openapi import SUCCESS_DESCRIPTION
Expand Down Expand Up @@ -161,20 +162,30 @@ async def from_http_request(self, request: Request) -> str:

async def to_http_response(
self, obj: str | t.AsyncGenerator[str, None], ctx: Context | None = None
) -> StreamingResponse:
) -> Response | StreamingResponse:
content_stream = iter([obj]) if isinstance(obj, str) else obj

if ctx is not None:
res = StreamingResponse(
content_stream,
media_type=self._mime_type,
headers=ctx.response.metadata, # type: ignore (bad starlette types)
status_code=ctx.response.status_code,
)
if isinstance(obj, str):
res = Response(
obj,
media_type=self._mime_type,
headers=ctx.response.metadata, # type: ignore (bad starlette types)
status_code=ctx.response.status_code,
)
else:
res = StreamingResponse(
content_stream,
media_type=self._mime_type,
headers=ctx.response.metadata, # type: ignore (bad starlette types)
status_code=ctx.response.status_code,
)
set_cookies(res, ctx.response.cookies)
return res
else:
return StreamingResponse(content_stream, media_type=self._mime_type)
if isinstance(obj, str):
return Response(obj, media_type=self._mime_type)
else:
return StreamingResponse(content_stream, media_type=self._mime_type)

async def from_proto(self, field: wrappers_pb2.StringValue | bytes) -> str:
if isinstance(field, bytes):
Expand Down
14 changes: 14 additions & 0 deletions tests/unit/_internal/io/test_multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import io
from typing import TYPE_CHECKING

import Exception
import pytest

from bentoml._internal.utils import LazyLoader
Expand All @@ -11,8 +12,10 @@
from bentoml.io import JSON
from bentoml.io import Image
from bentoml.io import Multipart
from bentoml.io import Text

example = Multipart(arg1=JSON(), arg2=Image(mime_type="image/bmp", pilmode="RGB"))
example2 = Multipart(arg1=Image(), arg2=Text())

if TYPE_CHECKING:
import PIL.Image as PILImage
Expand Down Expand Up @@ -98,3 +101,14 @@ async def test_multipart_from_to_proto(img_file: str):
)
assert isinstance(message, pb.Multipart)
assert message.fields["arg1"].json.struct_value.fields["asd"].string_value == "asd"


@pytest.mark.asyncio
async def test_multipart_to_http_response(img_file: str):
try:
res = await example2.to_http_response(
{"arg1": PILImage.open(img_file), "arg2": "test prompt"}
)
assert not res
except Exception as e:
pytest.fail(f"Unexpected exception: {e}")

0 comments on commit ed2417a

Please sign in to comment.