-
Notifications
You must be signed in to change notification settings - Fork 801
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Frost Ming <[email protected]>
- Loading branch information
Showing
16 changed files
with
574 additions
and
88 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,15 @@ | ||
from .http import AsyncHTTPClient | ||
from .http import SyncHTTPClient | ||
from .local import AsyncLocalClient | ||
from .local import SyncLocalClient | ||
from .manager import ClientManager | ||
from .testing import TestingClient | ||
|
||
__all__ = ["AsyncHTTPClient", "SyncHTTPClient", "TestingClient"] | ||
__all__ = [ | ||
"AsyncHTTPClient", | ||
"SyncHTTPClient", | ||
"TestingClient", | ||
"SyncLocalClient", | ||
"AsyncLocalClient", | ||
"ClientManager", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
from __future__ import annotations | ||
|
||
import asyncio | ||
import inspect | ||
import typing as t | ||
from functools import partial | ||
|
||
from .base import AbstractClient | ||
|
||
if t.TYPE_CHECKING: | ||
from ..server.service import Service | ||
|
||
T = t.TypeVar("T") | ||
|
||
|
||
class LocalClient(AbstractClient): | ||
def __init__(self, service: Service) -> None: | ||
self.service = service | ||
self.servable = service.init_servable() | ||
|
||
|
||
class SyncLocalClient(LocalClient): | ||
def __init__(self, service: Service): | ||
super().__init__(service) | ||
for name in self.servable.__servable_methods__: | ||
setattr(self, name, partial(self.call, name)) | ||
|
||
def call(self, name: str, *args: t.Any, **kwargs: t.Any) -> t.Any: | ||
if name not in self.servable.__servable_methods__: | ||
raise ValueError(f"Method {name} not found") | ||
result = getattr(self.servable, name)(*args, **kwargs) | ||
if inspect.iscoroutine(result): | ||
return asyncio.run(result) | ||
elif inspect.isasyncgen(result): | ||
from bentoml._internal.utils import async_gen_to_sync | ||
|
||
return async_gen_to_sync(result) | ||
return result | ||
|
||
def __enter__(self: T) -> T: | ||
return self | ||
|
||
|
||
class AsyncLocalClient(LocalClient): | ||
def __init__(self, service: Service): | ||
super().__init__(service) | ||
for name in self.servable.__servable_methods__: | ||
setattr(self, name, partial(self.call, name)) | ||
|
||
def call(self, name: str, *args: t.Any, **kwargs: t.Any) -> t.Any: | ||
from starlette.concurrency import run_in_threadpool | ||
|
||
from bentoml._internal.utils import is_async_callable | ||
from bentoml._internal.utils import sync_gen_to_async | ||
|
||
meth = getattr(self.servable, name) | ||
if inspect.isgeneratorfunction(meth): | ||
return sync_gen_to_async(meth(*args, **kwargs)) | ||
elif not is_async_callable(meth): | ||
return run_in_threadpool(meth, *args, **kwargs) | ||
return meth(*args, **kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
from __future__ import annotations | ||
|
||
import inspect | ||
import typing as t | ||
|
||
from simple_di import Provide | ||
from simple_di import inject | ||
|
||
from bentoml._internal.container import BentoMLContainer | ||
|
||
from .base import AbstractClient | ||
from .http import AsyncHTTPClient | ||
from .http import SyncHTTPClient | ||
from .local import AsyncLocalClient | ||
from .local import SyncLocalClient | ||
|
||
if t.TYPE_CHECKING: | ||
from ..servable import Servable | ||
from ..server import Service | ||
|
||
|
||
class ClientManager: | ||
@inject | ||
def __init__( | ||
self, | ||
service: Service, | ||
runner_map: dict[str, str] = Provide[BentoMLContainer.remote_runner_mapping], | ||
) -> None: | ||
self.service = service | ||
self._runner_map = runner_map | ||
self._sync_clients: dict[str, AbstractClient] = {} | ||
self._async_clients: dict[str, AbstractClient] = {} | ||
|
||
def get_client(self, name_or_class: str | type[Servable]) -> AbstractClient: | ||
caller_frame = inspect.currentframe().f_back # type: ignore | ||
assert caller_frame is not None | ||
is_async = bool( | ||
caller_frame.f_code.co_flags & inspect.CO_COROUTINE | ||
or caller_frame.f_code.co_flags & inspect.CO_ASYNC_GENERATOR | ||
) | ||
cache = self._async_clients if is_async else self._sync_clients | ||
name = name_or_class if isinstance(name_or_class, str) else name_or_class.name | ||
if name not in cache: | ||
dep = next( | ||
(dep for dep in self.service.dependencies if dep.name == name), None | ||
) | ||
if dep is None: | ||
raise ValueError( | ||
f"Dependency service {name} not found, please specify it in dependencies list" | ||
) | ||
if name in self._runner_map: | ||
client_cls = AsyncHTTPClient if is_async else SyncHTTPClient | ||
client = client_cls(self._runner_map[name], servable=dep.servable_cls) | ||
else: | ||
client_cls = AsyncLocalClient if is_async else SyncLocalClient | ||
client = client_cls(dep) | ||
cache[name] = client | ||
return cache[name] | ||
|
||
async def cleanup(self) -> None: | ||
for client in self._async_clients.values(): | ||
await client.__aexit__(None, None, None) | ||
for client in self._sync_clients.values(): | ||
await client.__aexit__(None, None, None) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.