Skip to content

Commit

Permalink
feat: remote client impl
Browse files Browse the repository at this point in the history
Signed-off-by: Frost Ming <[email protected]>
  • Loading branch information
frostming committed Oct 24, 2023
1 parent b16a2fa commit 2911853
Show file tree
Hide file tree
Showing 16 changed files with 574 additions and 88 deletions.
1 change: 0 additions & 1 deletion src/bentoml/_internal/configuration/containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,6 @@ def cloud_config(bentoml_home: str = Provide[bentoml_home]) -> Path:
serialization_strategy: providers.Static[SerializationStrategy] = providers.Static(
"EXPORT_BENTO"
)
worker_index: providers.Static[int] = providers.Static(0)

@providers.SingletonFactory
@staticmethod
Expand Down
5 changes: 4 additions & 1 deletion src/bentoml/_internal/service/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,10 @@ def import_service(
import_service("fraud_detector.py")
import_service("fraud_detector")
"""
from bentoml import Service
if BentoMLContainer.new_io:
from bentoml_io.server import Service
else:
from bentoml import Service

prev_cwd = None
sys_path_modified = False
Expand Down
16 changes: 16 additions & 0 deletions src/bentoml/_internal/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,3 +509,19 @@ def async_gen_to_sync(
finally:
loop.close()
asyncio.set_event_loop(None)


async def sync_gen_to_async(
    gen: t.Generator[T, None, None]
) -> t.AsyncGenerator[T, None]:
    """
    Convert a sync generator to an async generator.

    Every item is pulled from ``gen`` in a worker thread through starlette's
    ``run_in_threadpool`` so that a blocking generator does not stall the
    event loop.

    Args:
        gen: the synchronous generator to drain.

    Yields:
        The items produced by ``gen``, in order.
    """
    from starlette.concurrency import run_in_threadpool

    # A ``StopIteration`` raised inside the worker thread cannot be propagated
    # back into a coroutine as-is (it is converted to another error on the
    # way), so an ``except StopIteration`` around the await never fires.
    # Signal exhaustion with a private sentinel instead of an exception.
    exhausted = object()

    while True:
        rv: t.Any = await run_in_threadpool(next, gen, exhausted)
        if rv is exhausted:
            break
        yield rv
2 changes: 0 additions & 2 deletions src/bentoml_cli/worker/http_api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,6 @@ def main(
# worker ID is not set; this server is running in standalone mode
# and should not be concerned with the status of its runners
BentoMLContainer.config.runner_probe.enabled.set(False)
else:
BentoMLContainer.worker_index.set(worker_id)

BentoMLContainer.development_mode.set(development_mode)
if prometheus_dir is not None:
Expand Down
12 changes: 11 additions & 1 deletion src/bentoml_io/client/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
from .http import AsyncHTTPClient
from .http import SyncHTTPClient
from .local import AsyncLocalClient
from .local import SyncLocalClient
from .manager import ClientManager
from .testing import TestingClient

# Names re-exported as the public API of the client package.
__all__ = [
    "AsyncHTTPClient",
    "SyncHTTPClient",
    "TestingClient",
    "SyncLocalClient",
    "AsyncLocalClient",
    "ClientManager",
]
8 changes: 8 additions & 0 deletions src/bentoml_io/client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,18 @@
import abc
import typing as t

T = t.TypeVar("T")


class AbstractClient(abc.ABC):
    """Common interface shared by every client implementation.

    Subclasses must implement :meth:`call`. The async context-manager hooks
    defined here are no-ops, so a client can be used with ``async with``
    unless a subclass overrides them.
    """

    @abc.abstractmethod
    def call(self, name: str, *args: t.Any, **kwargs: t.Any) -> t.Any:
        """Invoke the service method called ``name``.

        The positional and keyword arguments are the ones the service method
        itself accepts.
        """

    async def __aenter__(self: T) -> T:
        # Entering the context hands back the client itself.
        return self

    async def __aexit__(self, *args: t.Any) -> None:
        # Nothing to release by default.
        return None
11 changes: 4 additions & 7 deletions src/bentoml_io/client/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from bentoml._internal.utils.uri import uri_to_path
from bentoml.exceptions import BentoMLException

from ..servable import Servable
from .base import AbstractClient

if t.TYPE_CHECKING:
Expand All @@ -24,6 +23,7 @@
from aiohttp import ClientSession

from ..models import IODescriptor
from ..servable import Servable

T = t.TypeVar("T", bound="HTTPClient")

Expand Down Expand Up @@ -275,6 +275,9 @@ async def close(self) -> None:
if self._client is not None and not self._client.closed:
await self._client.close()

async def __aexit__(self, *args: t.Any) -> None:
return await self.close()


class SyncHTTPClient(HTTPClient):
"""A synchronous client for BentoML service.
Expand Down Expand Up @@ -367,9 +370,3 @@ async def _get_stream(
assert inspect.isasyncgen(resp)
async for data in resp:
yield data

async def __aenter__(self) -> HTTPClient:
return self

async def __aexit__(self, exc_type: t.Any, exc: t.Any, tb: t.Any) -> None:
await self.close()
61 changes: 61 additions & 0 deletions src/bentoml_io/client/local.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from __future__ import annotations

import asyncio
import inspect
import typing as t
from functools import partial

from .base import AbstractClient

if t.TYPE_CHECKING:
from ..server.service import Service

T = t.TypeVar("T")


class LocalClient(AbstractClient):
    """Base for clients that talk to a servable created in this process."""

    def __init__(self, service: Service) -> None:
        # Keep the service around and ask it for the servable instance that
        # the concrete clients dispatch calls to.
        self.service = service
        servable_instance = service.init_servable()
        self.servable = servable_instance


class SyncLocalClient(LocalClient):
    """Blocking client that calls the servable's methods in-process.

    Every name in the servable's ``__servable_methods__`` is also exposed as
    an attribute of the client, bound to :meth:`call` with that name.
    """

    def __init__(self, service: Service):
        super().__init__(service)
        for name in self.servable.__servable_methods__:
            setattr(self, name, partial(self.call, name))

    def call(self, name: str, *args: t.Any, **kwargs: t.Any) -> t.Any:
        """Call the servable method ``name`` and return a synchronous result.

        Args:
            name: name of a method listed in ``__servable_methods__``.
            *args: positional arguments forwarded to the method.
            **kwargs: keyword arguments forwarded to the method.

        Returns:
            The method's result. A coroutine result is run to completion with
            ``asyncio.run``; an async generator result is wrapped with
            ``async_gen_to_sync``; anything else is returned unchanged.

        Raises:
            ValueError: if ``name`` is not a servable method.
        """
        if name not in self.servable.__servable_methods__:
            raise ValueError(f"Method {name} not found")
        result = getattr(self.servable, name)(*args, **kwargs)
        if inspect.iscoroutine(result):
            # NOTE: asyncio.run() cannot be used from a thread that already
            # has a running event loop, so this client is for sync callers.
            return asyncio.run(result)
        elif inspect.isasyncgen(result):
            from bentoml._internal.utils import async_gen_to_sync

            return async_gen_to_sync(result)
        return result

    def __enter__(self: T) -> T:
        return self

    def __exit__(self, *args: t.Any) -> None:
        # ``with`` requires both halves of the context-manager protocol;
        # without __exit__ the statement fails before the body runs.
        # There is nothing to release, and exceptions are not suppressed.
        return None


class AsyncLocalClient(LocalClient):
    """Client for async callers that calls the servable's methods in-process.

    Every name in the servable's ``__servable_methods__`` is also exposed as
    an attribute of the client, bound to :meth:`call` with that name.
    """

    def __init__(self, service: Service):
        super().__init__(service)
        for name in self.servable.__servable_methods__:
            setattr(self, name, partial(self.call, name))

    def call(self, name: str, *args: t.Any, **kwargs: t.Any) -> t.Any:
        """Call the servable method ``name`` without blocking the event loop.

        This method is itself synchronous; it returns an object the caller
        awaits or iterates asynchronously.

        Args:
            name: name of a method listed in ``__servable_methods__``.
            *args: positional arguments forwarded to the method.
            **kwargs: keyword arguments forwarded to the method.

        Returns:
            * for a sync generator function: the generator wrapped with
              ``sync_gen_to_async``;
            * for any other non-async callable: the awaitable returned by
              ``run_in_threadpool``;
            * otherwise: whatever the async method itself returns.

        Raises:
            ValueError: if ``name`` is not a servable method.
        """
        from starlette.concurrency import run_in_threadpool

        from bentoml._internal.utils import is_async_callable
        from bentoml._internal.utils import sync_gen_to_async

        # Same validation (and error) as SyncLocalClient.call, so both local
        # clients reject unknown or non-servable attribute names alike.
        if name not in self.servable.__servable_methods__:
            raise ValueError(f"Method {name} not found")
        meth = getattr(self.servable, name)
        if inspect.isgeneratorfunction(meth):
            return sync_gen_to_async(meth(*args, **kwargs))
        elif not is_async_callable(meth):
            return run_in_threadpool(meth, *args, **kwargs)
        return meth(*args, **kwargs)
64 changes: 64 additions & 0 deletions src/bentoml_io/client/manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from __future__ import annotations

import inspect
import typing as t

from simple_di import Provide
from simple_di import inject

from bentoml._internal.container import BentoMLContainer

from .base import AbstractClient
from .http import AsyncHTTPClient
from .http import SyncHTTPClient
from .local import AsyncLocalClient
from .local import SyncLocalClient

if t.TYPE_CHECKING:
from ..servable import Servable
from ..server import Service


class ClientManager:
    """Hands out, and caches, clients for a service's dependency services.

    Sync and async callers get separate caches, keyed by dependency name.
    """

    @inject
    def __init__(
        self,
        service: Service,
        runner_map: dict[str, str] = Provide[BentoMLContainer.remote_runner_mapping],
    ) -> None:
        self.service = service
        self._runner_map = runner_map
        self._sync_clients: dict[str, AbstractClient] = {}
        self._async_clients: dict[str, AbstractClient] = {}

    def get_client(self, name_or_class: str | type[Servable]) -> AbstractClient:
        """Return the client for a dependency, creating it on first use.

        The flavour (sync or async) is picked from the code flags of the
        immediate caller: coroutine functions and async generators get an
        async client, everything else a sync one. A dependency whose name is
        in the runner map gets an HTTP client for the mapped address; any
        other dependency gets a local client.

        Raises:
            ValueError: if the name is not among ``service.dependencies``.
        """
        # The frame lookup must stay directly in this method: ``f_back`` has
        # to be the frame of whoever called get_client.
        caller_frame = inspect.currentframe().f_back  # type: ignore
        assert caller_frame is not None
        async_flags = inspect.CO_COROUTINE | inspect.CO_ASYNC_GENERATOR
        is_async = bool(caller_frame.f_code.co_flags & async_flags)

        if is_async:
            cache = self._async_clients
        else:
            cache = self._sync_clients

        if isinstance(name_or_class, str):
            name = name_or_class
        else:
            name = name_or_class.name

        if name in cache:
            return cache[name]

        # First dependency with a matching name, if any.
        dep = None
        for candidate in self.service.dependencies:
            if candidate.name == name:
                dep = candidate
                break
        if dep is None:
            raise ValueError(
                f"Dependency service {name} not found, please specify it in dependencies list"
            )

        if name in self._runner_map:
            client_cls = AsyncHTTPClient if is_async else SyncHTTPClient
            client = client_cls(self._runner_map[name], servable=dep.servable_cls)
        else:
            client_cls = AsyncLocalClient if is_async else SyncLocalClient
            client = client_cls(dep)
        cache[name] = client
        return client

    async def cleanup(self) -> None:
        """Run ``__aexit__`` on every cached client, async ones first."""
        for clients in (self._async_clients, self._sync_clients):
            for client in clients.values():
                await client.__aexit__(None, None, None)
14 changes: 10 additions & 4 deletions src/bentoml_io/client/testing.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
from __future__ import annotations

import typing as t

from ..server.service import Service
from .base import AbstractClient
from bentoml_io.client.local import LocalClient

if t.TYPE_CHECKING:
from ..server.service import Service

T = t.TypeVar("T")


class TestingClient(AbstractClient):
class TestingClient(LocalClient):
def __init__(self, service: Service):
self.servable = service.get_servable()
super().__init__(service)
for name in self.servable.__servable_methods__:
setattr(self, name, getattr(self.servable, name))

Expand Down
23 changes: 14 additions & 9 deletions src/bentoml_io/servable.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,13 @@
from .client.base import AbstractClient


class DependencySpec(t.NamedTuple):
    """Immutable pair describing a dependency on another servable."""

    # NOTE(review): presumably the address used to reach the dependency, with
    # None meaning no remote endpoint is configured; confirm against callers.
    connect_string: str | None
    # The Servable subclass that the dependency refers to.
    servable_cls: type[Servable]


class Servable:
__servable_methods__: dict[str, APIMethod[..., t.Any]] = {}
__clients_cache: dict[str, AbstractClient]
# User defined attributes
name: str
SUPPORTED_RESOURCES: tuple[str, ...] = ("cpu",)
Expand All @@ -22,17 +26,12 @@ def __init_subclass__(cls) -> None:
new_servable_methods: dict[str, APIMethod[..., t.Any]] = {}
for attr in vars(cls).values():
if isinstance(attr, APIMethod):
new_servable_methods[attr.name] = attr
new_servable_methods[attr.name] = attr # type: ignore
cls.__servable_methods__ = {**cls.__servable_methods__, **new_servable_methods}

def get_client(self, name_or_class: str | type[Servable]) -> AbstractClient:
if not hasattr(self, "__clients_cache"):
self.__clients_cache = {}
name = name_or_class if isinstance(name_or_class, str) else name_or_class.name
if name not in self.__clients_cache:
# TODO: create client
...
return self.__clients_cache[name]
# To be injected by service
raise NotImplementedError

def schema(self) -> dict[str, t.Any]:
return {
Expand All @@ -50,3 +49,9 @@ def call(self, method_name: str, input_data: dict[str, t.Any]) -> t.Any:
input_model = method.input_spec(**input_data)
args = {k: getattr(input_model, k) for k in input_model.model_fields}
return method.func(self, **args)

def test(self):
return self.get_client("abc")

async def async_test(self):
return self.get_client("abc")
Loading

0 comments on commit 2911853

Please sign in to comment.