From ab59b8d33c21ea3b613420020366c8c67286860d Mon Sep 17 00:00:00 2001 From: Sauyon Lee <2347889+sauyon@users.noreply.github.com> Date: Thu, 12 Oct 2023 02:32:32 -0700 Subject: [PATCH] fix(client): type fixes (#4182) * type fixes * fix imports * ci: auto fixes from pre-commit.ci For more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- src/bentoml/_internal/client/__init__.py | 6 +- src/bentoml/server.py | 27 ++++--- .../e2e/bento_server_http/tests/test_serve.py | 71 ++++++++++++------- 3 files changed, 64 insertions(+), 40 deletions(-) diff --git a/src/bentoml/_internal/client/__init__.py b/src/bentoml/_internal/client/__init__.py index 81d3a9b850b..4250e42b2da 100644 --- a/src/bentoml/_internal/client/__init__.py +++ b/src/bentoml/_internal/client/__init__.py @@ -30,7 +30,7 @@ def __init__(self, svc: Service, server_url: str): self._svc = svc self.server_url = server_url - if svc is not None and len(svc.apis) == 0: + if len(svc.apis) == 0: raise BentoMLException("No APIs were found when constructing client.") self.endpoints = [] @@ -82,9 +82,7 @@ def wait_until_server_ready( @t.overload @staticmethod - def from_url( - server_url: str, *, kind: None | t.Literal["auto"] = ... - ) -> GrpcClient | HTTPClient: + def from_url(server_url: str, *, kind: None | t.Literal["auto"] = ...) -> Client: ... @t.overload diff --git a/src/bentoml/server.py b/src/bentoml/server.py index a4fe1c22cb1..48dd3101401 100644 --- a/src/bentoml/server.py +++ b/src/bentoml/server.py @@ -16,6 +16,9 @@ from simple_di import inject from ._internal.bento import Bento +from ._internal.client import Client +from ._internal.client.grpc import GrpcClient +from ._internal.client.http import HTTPClient from ._internal.configuration.containers import BentoMLContainer from ._internal.service import Service from ._internal.tag import Tag @@ -27,10 +30,6 @@ if TYPE_CHECKING: from types import TracebackType - from ._internal.client import Client - from ._internal.client import GrpcClient - from ._internal.client import HTTPClient - _FILE: t.TypeAlias = None | int | t.IO[t.Any] @@ -40,7 +39,10 @@ __all__ = ["Server", "GrpcServer", "HTTPServer"] -class Server(ABC): +ClientType = t.TypeVar("ClientType", bound=Client) + + +class Server(ABC, t.Generic[ClientType]): servable: str | Bento | Tag | Service host: str port: int @@ -134,7 +136,7 @@ def start( stdout: _FILE = None, stderr: _FILE = None, text: bool | None = None, - ): + ) -> t.ContextManager[ClientType]: """Start the server programmatically. To get the client, use the context manager. @@ -182,7 +184,7 @@ def __init__(__inner_self): except KeyboardInterrupt: pass - def __enter__(__inner_self): + def __enter__(__inner_self) -> ClientType: return self.get_client() def __exit__( @@ -195,7 +197,7 @@ def __exit__( return _Manager() - def get_client(self) -> Client: + def get_client(self) -> ClientType: if self.process is None: # NOTE: if the process is None, we reset this envvar del os.environ[BENTOML_SERVE_FROM_SERVER_API] @@ -229,7 +231,7 @@ def get_client(self) -> Client: return self._get_client() @abstractmethod - def _get_client(self) -> Client: + def _get_client(self) -> ClientType: pass def stop(self) -> None: @@ -296,7 +298,7 @@ def __exit__( logger.error(f"Error stopping server: {e}", exc_info=e) -class HTTPServer(Server): +class HTTPServer(Server[HTTPClient]): _client: HTTPClient | None = None @inject @@ -352,6 +354,9 @@ def __init__( self.args.extend(construct_ssl_args(**ssl_args)) + def get_client(self) -> HTTPClient: + return super().get_client() + def client(self) -> HTTPClient | None: warn( "'Server.client()' is deprecated, use 'Server.get_client()' instead.", @@ -371,7 +376,7 @@ def _get_client(self) -> HTTPClient: return self._client -class GrpcServer(Server): +class GrpcServer(Server[GrpcClient]): _client: GrpcClient | None = None @inject diff --git a/tests/e2e/bento_server_http/tests/test_serve.py b/tests/e2e/bento_server_http/tests/test_serve.py index cbcd6a4dc9a..82162cdf696 100644 --- a/tests/e2e/bento_server_http/tests/test_serve.py +++ b/tests/e2e/bento_server_http/tests/test_serve.py @@ -11,7 +11,8 @@ from bentoml.testing.utils import async_request -def test_http_server(bentoml_home: str): +@pytest.mark.usefixtures("bentoml_home") +def test_http_server(): server = bentoml.HTTPServer("service.py:svc", port=12345) server.start() @@ -21,54 +22,64 @@ def test_http_server(bentoml_home: str): assert resp.status == 200 - res = client.echo_json_sync({"test": "json"}) + res = client.call("echo_json", {"test": "json"}) assert res == {"test": "json"} server.stop() + assert server.process is not None # process should not be removed + timeout = 10 start_time = time.time() while time.time() - start_time < timeout: retcode = server.process.poll() if retcode is not None and retcode <= 0: break + + retcode = server.process.poll() + assert retcode is not None + if sys.platform == "win32": # on Windows, because of the way that terminate is run, it seems the exit code is set. - assert isinstance(server.process.poll(), int) + pass else: - # on POSIX negative return codes mean the process was terminated; since we will be terminating + # negative return codes mean the process was terminated; since we will be terminating # the process, it should be negative. - # on all other platforms, this should be 0. - assert server.process.poll() <= 0 + assert retcode <= 0 -def test_http_server_ctx(bentoml_home: str): +@pytest.mark.usefixtures("bentoml_home") +def test_http_server_ctx(): server = bentoml.HTTPServer("service.py:svc", port=12346) with server.start() as client: resp = client.health() - assert resp.status == 200 - res = client.echo_json_sync({"more_test": "and more json"}) + res = client.call("echo_json", {"more_test": "and more json"}) assert res == {"more_test": "and more json"} + assert server.process is not None # process should not be removed + timeout = 10 start_time = time.time() while time.time() - start_time < timeout: retcode = server.process.poll() if retcode is not None and retcode <= 0: break + + retcode = server.process.poll() + assert retcode is not None + if sys.platform == "win32": # on Windows, because of the way that terminate is run, it seems the exit code is set. - assert isinstance(server.process.poll(), int) + pass else: - # on POSIX negative return codes mean the process was terminated; since we will be terminating + # negative return codes mean the process was terminated; since we will be terminating # the process, it should be negative. - # on all other platforms, this should be 0. - assert server.process.poll() <= 0 + assert retcode <= 0 def test_serve_from_svc(): @@ -81,23 +92,29 @@ def test_serve_from_svc(): assert resp.status == 200 server.stop() - timeout = 60 + assert server.process is not None # process should not be removed + + timeout = 10 start_time = time.time() while time.time() - start_time < timeout: retcode = server.process.poll() if retcode is not None and retcode <= 0: break + + retcode = server.process.poll() + assert retcode is not None + if sys.platform == "win32": # on Windows, because of the way that terminate is run, it seems the exit code is set. - assert isinstance(server.process.poll(), int) + pass else: - # on POSIX negative return codes mean the process was terminated; since we will be terminating + # negative return codes mean the process was terminated; since we will be terminating # the process, it should be negative. - # on all other platforms, this should be 0. - assert server.process.poll() <= 0 + assert retcode <= 0 -def test_serve_with_timeout(bentoml_home: str): +@pytest.mark.usefixtures("bentoml_home") +def test_serve_with_timeout(): server = bentoml.HTTPServer("service.py:svc", port=12349) config_file = os.path.abspath("configs/timeout.yml") env = os.environ.copy() @@ -108,11 +125,12 @@ def test_serve_with_timeout(bentoml_home: str): BentoMLException, match="504: b'Not able to process the request in 1 seconds'", ): - client.echo_delay({}) + client.call("echo_delay", {}) @pytest.mark.asyncio -async def test_serve_with_api_max_concurrency(bentoml_home: str): +@pytest.mark.usefixtures("bentoml_home") +async def test_serve_with_api_max_concurrency(): server = bentoml.HTTPServer("service.py:svc", port=12350, api_workers=1) config_file = os.path.abspath("configs/max_concurrency.yml") env = os.environ.copy() @@ -120,11 +138,13 @@ async def test_serve_with_api_max_concurrency(bentoml_home: str): with server.start(env=env) as client: tasks = [ - asyncio.create_task(client.async_echo_delay({"delay": 0.5})), - asyncio.create_task(client.async_echo_delay({"delay": 0.5})), + asyncio.create_task(client.async_call("echo_delay", {"delay": 0.5})), + asyncio.create_task(client.async_call("echo_delay", {"delay": 0.5})), ] await asyncio.sleep(0.1) - tasks.append(asyncio.create_task(client.async_echo_delay({"delay": 0.5}))) + tasks.append( + asyncio.create_task(client.async_call("echo_delay", {"delay": 0.5})) + ) results = await asyncio.gather(*tasks, return_exceptions=True) for i in range(2): @@ -138,7 +158,8 @@ async def test_serve_with_api_max_concurrency(bentoml_home: str): reason="Windows runner doesn't have enough cores to run this test", ) @pytest.mark.asyncio -async def test_serve_with_lifecycle_hooks(bentoml_home: str, tmp_path: Path): +@pytest.mark.usefixtures("bentoml_home") +async def test_serve_with_lifecycle_hooks(tmp_path: Path): server = bentoml.HTTPServer("service.py:svc", port=12351, api_workers=4) env = os.environ.copy() env["BENTOML_TEST_DATA"] = str(tmp_path)