Skip to content

Commit

Permalink
fix: log errors inside a streaming response and detect service name c…
Browse files Browse the repository at this point in the history
…onflicts (#4767)

fix: log errors inside a streaming response

Signed-off-by: Frost Ming <[email protected]>
  • Loading branch information
frostming authored May 31, 2024
1 parent c03f2d9 commit 1cb5781
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 53 deletions.
42 changes: 25 additions & 17 deletions src/_bentoml_sdk/io_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import inspect
import io
import logging
import pathlib
import sys
import typing as t
Expand Down Expand Up @@ -34,6 +35,7 @@


DEFAULT_TEXT_MEDIA_TYPE = "text/plain"
logger = logging.getLogger("bentoml.serve")


def is_file_type(type_: type) -> bool:
Expand Down Expand Up @@ -180,29 +182,35 @@ async def to_http_response(cls, obj: t.Any, serde: Serde) -> Response:
if inspect.isasyncgen(obj):

async def async_stream() -> t.AsyncGenerator[str | bytes, None]:
async for item in obj:
if isinstance(item, (str, bytes)):
yield item
else:
obj_item = cls(item) if issubclass(cls, RootModel) else item
for chunk in serde.serialize_model(
t.cast(IODescriptor, obj_item)
).data:
yield chunk
try:
async for item in obj:
if isinstance(item, (str, bytes)):
yield item
else:
obj_item = cls(item) if issubclass(cls, RootModel) else item
for chunk in serde.serialize_model(
t.cast(IODescriptor, obj_item)
).data:
yield chunk
except Exception:
logger.exception("Error while streaming response")

return StreamingResponse(async_stream(), media_type=cls.mime_type())

elif inspect.isgenerator(obj):

def content_stream() -> t.Generator[str | bytes, None, None]:
for item in obj:
if isinstance(item, (str, bytes)):
yield item
else:
obj_item = cls(item) if issubclass(cls, RootModel) else item
yield from serde.serialize_model(
t.cast(IODescriptor, obj_item)
).data
try:
for item in obj:
if isinstance(item, (str, bytes)):
yield item
else:
obj_item = cls(item) if issubclass(cls, RootModel) else item
yield from serde.serialize_model(
t.cast(IODescriptor, obj_item)
).data
except Exception:
logger.exception("Error while streaming response")

return StreamingResponse(content_stream(), media_type=cls.mime_type())
elif not issubclass(cls, RootModel):
Expand Down
16 changes: 15 additions & 1 deletion src/_bentoml_sdk/service/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from bentoml._internal.context import ServiceContext
from bentoml._internal.models import Model
from bentoml._internal.utils import dict_filter_none
from bentoml.exceptions import BentoMLConfigException
from bentoml.exceptions import BentoMLException

from ..method import APIMethod
Expand Down Expand Up @@ -139,7 +140,20 @@ def all_services(self) -> dict[str, Service[t.Any]]:
"""Get a map of the service and all recursive dependencies"""
services: dict[str, Service[t.Any]] = {self.name: self}
for dependency in self.dependencies.values():
services.update(dependency.on.all_services())
dependents = dependency.on.all_services()
conflict = next(
(
k
for k in dependents
if k in services and dependents[k] is not services[k]
),
None,
)
if conflict:
raise BentoMLConfigException(
f"Dependency conflict: {conflict} is already defined by {services[conflict].inner}"
)
services.update(dependents)
return services

@property
Expand Down
35 changes: 0 additions & 35 deletions src/bentoml/_internal/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,40 +535,5 @@ def is_async_callable(obj: t.Any) -> t.TypeGuard[t.Callable[..., t.Awaitable[t.A
)


def async_gen_to_sync(
gen: t.AsyncGenerator[T, None], *, loop: asyncio.AbstractEventLoop | None = None
) -> t.Generator[T, None, None]:
"""
Convert an async generator to a sync generator
"""
if loop is None:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
while True:
yield loop.run_until_complete(gen.__anext__())
except StopAsyncIteration:
pass
finally:
loop.close()
asyncio.set_event_loop(None)


async def sync_gen_to_async(
gen: t.Generator[T, None, None],
) -> t.AsyncGenerator[T, None]:
"""
Convert a sync generator to an async generator
"""
from starlette.concurrency import run_in_threadpool

while True:
try:
rv = await run_in_threadpool(gen.__next__)
yield rv
except StopIteration:
break


def dict_filter_none(d: dict[str, t.Any]) -> dict[str, t.Any]:
return {k: v for k, v in d.items() if v is not None}

0 comments on commit 1cb5781

Please sign in to comment.