Skip to content

Commit

Permalink
fix(client): type fixes (#4182)
Browse files Browse the repository at this point in the history
* type fixes

* fix imports

* ci: auto fixes from pre-commit.ci

For more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
sauyon and pre-commit-ci[bot] authored Oct 12, 2023
1 parent 032edf8 commit ab59b8d
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 40 deletions.
6 changes: 2 additions & 4 deletions src/bentoml/_internal/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def __init__(self, svc: Service, server_url: str):
self._svc = svc
self.server_url = server_url

if svc is not None and len(svc.apis) == 0:
if len(svc.apis) == 0:
raise BentoMLException("No APIs were found when constructing client.")

self.endpoints = []
Expand Down Expand Up @@ -82,9 +82,7 @@ def wait_until_server_ready(

@t.overload
@staticmethod
def from_url(
server_url: str, *, kind: None | t.Literal["auto"] = ...
) -> GrpcClient | HTTPClient:
def from_url(server_url: str, *, kind: None | t.Literal["auto"] = ...) -> Client:
...

@t.overload
Expand Down
27 changes: 16 additions & 11 deletions src/bentoml/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
from simple_di import inject

from ._internal.bento import Bento
from ._internal.client import Client
from ._internal.client.grpc import GrpcClient
from ._internal.client.http import HTTPClient
from ._internal.configuration.containers import BentoMLContainer
from ._internal.service import Service
from ._internal.tag import Tag
Expand All @@ -27,10 +30,6 @@
if TYPE_CHECKING:
from types import TracebackType

from ._internal.client import Client
from ._internal.client import GrpcClient
from ._internal.client import HTTPClient

_FILE: t.TypeAlias = None | int | t.IO[t.Any]


Expand All @@ -40,7 +39,10 @@
__all__ = ["Server", "GrpcServer", "HTTPServer"]


class Server(ABC):
ClientType = t.TypeVar("ClientType", bound=Client)


class Server(ABC, t.Generic[ClientType]):
servable: str | Bento | Tag | Service
host: str
port: int
Expand Down Expand Up @@ -134,7 +136,7 @@ def start(
stdout: _FILE = None,
stderr: _FILE = None,
text: bool | None = None,
):
) -> t.ContextManager[ClientType]:
"""Start the server programmatically.
To get the client, use the context manager.
Expand Down Expand Up @@ -182,7 +184,7 @@ def __init__(__inner_self):
except KeyboardInterrupt:
pass

def __enter__(__inner_self):
def __enter__(__inner_self) -> ClientType:
return self.get_client()

def __exit__(
Expand All @@ -195,7 +197,7 @@ def __exit__(

return _Manager()

def get_client(self) -> Client:
def get_client(self) -> ClientType:
if self.process is None:
# NOTE: if the process is None, we reset this envvar
del os.environ[BENTOML_SERVE_FROM_SERVER_API]
Expand Down Expand Up @@ -229,7 +231,7 @@ def get_client(self) -> Client:
return self._get_client()

@abstractmethod
def _get_client(self) -> Client:
def _get_client(self) -> ClientType:
pass

def stop(self) -> None:
Expand Down Expand Up @@ -296,7 +298,7 @@ def __exit__(
logger.error(f"Error stopping server: {e}", exc_info=e)


class HTTPServer(Server):
class HTTPServer(Server[HTTPClient]):
_client: HTTPClient | None = None

@inject
Expand Down Expand Up @@ -352,6 +354,9 @@ def __init__(

self.args.extend(construct_ssl_args(**ssl_args))

def get_client(self) -> HTTPClient:
return super().get_client()

def client(self) -> HTTPClient | None:
warn(
"'Server.client()' is deprecated, use 'Server.get_client()' instead.",
Expand All @@ -371,7 +376,7 @@ def _get_client(self) -> HTTPClient:
return self._client


class GrpcServer(Server):
class GrpcServer(Server[GrpcClient]):
_client: GrpcClient | None = None

@inject
Expand Down
71 changes: 46 additions & 25 deletions tests/e2e/bento_server_http/tests/test_serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
from bentoml.testing.utils import async_request


def test_http_server(bentoml_home: str):
@pytest.mark.usefixtures("bentoml_home")
def test_http_server():
server = bentoml.HTTPServer("service.py:svc", port=12345)

server.start()
Expand All @@ -21,54 +22,64 @@ def test_http_server(bentoml_home: str):

assert resp.status == 200

res = client.echo_json_sync({"test": "json"})
res = client.call("echo_json", {"test": "json"})

assert res == {"test": "json"}

server.stop()

assert server.process is not None # process should not be removed

timeout = 10
start_time = time.time()
while time.time() - start_time < timeout:
retcode = server.process.poll()
if retcode is not None and retcode <= 0:
break

retcode = server.process.poll()
assert retcode is not None

if sys.platform == "win32":
# on Windows, because of the way that terminate is run, it seems the exit code is set.
assert isinstance(server.process.poll(), int)
pass
else:
# on POSIX negative return codes mean the process was terminated; since we will be terminating
# negative return codes mean the process was terminated; since we will be terminating
# the process, it should be negative.
# on all other platforms, this should be 0.
assert server.process.poll() <= 0
assert retcode <= 0


def test_http_server_ctx(bentoml_home: str):
@pytest.mark.usefixtures("bentoml_home")
def test_http_server_ctx():
server = bentoml.HTTPServer("service.py:svc", port=12346)

with server.start() as client:
resp = client.health()

assert resp.status == 200

res = client.echo_json_sync({"more_test": "and more json"})
res = client.call("echo_json", {"more_test": "and more json"})

assert res == {"more_test": "and more json"}

assert server.process is not None # process should not be removed

timeout = 10
start_time = time.time()
while time.time() - start_time < timeout:
retcode = server.process.poll()
if retcode is not None and retcode <= 0:
break

retcode = server.process.poll()
assert retcode is not None

if sys.platform == "win32":
# on Windows, because of the way that terminate is run, it seems the exit code is set.
assert isinstance(server.process.poll(), int)
pass
else:
# on POSIX negative return codes mean the process was terminated; since we will be terminating
# negative return codes mean the process was terminated; since we will be terminating
# the process, it should be negative.
# on all other platforms, this should be 0.
assert server.process.poll() <= 0
assert retcode <= 0


def test_serve_from_svc():
Expand All @@ -81,23 +92,29 @@ def test_serve_from_svc():
assert resp.status == 200
server.stop()

timeout = 60
assert server.process is not None # process should not be removed

timeout = 10
start_time = time.time()
while time.time() - start_time < timeout:
retcode = server.process.poll()
if retcode is not None and retcode <= 0:
break

retcode = server.process.poll()
assert retcode is not None

if sys.platform == "win32":
# on Windows, because of the way that terminate is run, it seems the exit code is set.
assert isinstance(server.process.poll(), int)
pass
else:
# on POSIX negative return codes mean the process was terminated; since we will be terminating
# negative return codes mean the process was terminated; since we will be terminating
# the process, it should be negative.
# on all other platforms, this should be 0.
assert server.process.poll() <= 0
assert retcode <= 0


def test_serve_with_timeout(bentoml_home: str):
@pytest.mark.usefixtures("bentoml_home")
def test_serve_with_timeout():
server = bentoml.HTTPServer("service.py:svc", port=12349)
config_file = os.path.abspath("configs/timeout.yml")
env = os.environ.copy()
Expand All @@ -108,23 +125,26 @@ def test_serve_with_timeout(bentoml_home: str):
BentoMLException,
match="504: b'Not able to process the request in 1 seconds'",
):
client.echo_delay({})
client.call("echo_delay", {})


@pytest.mark.asyncio
async def test_serve_with_api_max_concurrency(bentoml_home: str):
@pytest.mark.usefixtures("bentoml_home")
async def test_serve_with_api_max_concurrency():
server = bentoml.HTTPServer("service.py:svc", port=12350, api_workers=1)
config_file = os.path.abspath("configs/max_concurrency.yml")
env = os.environ.copy()
env.update(BENTOML_CONFIG=config_file)

with server.start(env=env) as client:
tasks = [
asyncio.create_task(client.async_echo_delay({"delay": 0.5})),
asyncio.create_task(client.async_echo_delay({"delay": 0.5})),
asyncio.create_task(client.async_call("echo_delay", {"delay": 0.5})),
asyncio.create_task(client.async_call("echo_delay", {"delay": 0.5})),
]
await asyncio.sleep(0.1)
tasks.append(asyncio.create_task(client.async_echo_delay({"delay": 0.5})))
tasks.append(
asyncio.create_task(client.async_call("echo_delay", {"delay": 0.5}))
)
results = await asyncio.gather(*tasks, return_exceptions=True)

for i in range(2):
Expand All @@ -138,7 +158,8 @@ async def test_serve_with_api_max_concurrency(bentoml_home: str):
reason="Windows runner doesn't have enough cores to run this test",
)
@pytest.mark.asyncio
async def test_serve_with_lifecycle_hooks(bentoml_home: str, tmp_path: Path):
@pytest.mark.usefixtures("bentoml_home")
async def test_serve_with_lifecycle_hooks(tmp_path: Path):
server = bentoml.HTTPServer("service.py:svc", port=12351, api_workers=4)
env = os.environ.copy()
env["BENTOML_TEST_DATA"] = str(tmp_path)
Expand Down

0 comments on commit ab59b8d

Please sign in to comment.