Skip to content

Commit

Permalink
address review comments and some other fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
sauyon committed Oct 10, 2023
1 parent d9119e6 commit 25c5a67
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 69 deletions.
35 changes: 5 additions & 30 deletions src/bentoml/_internal/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,31 +107,7 @@ def from_url(
kind: t.Literal["auto", "http", "grpc"] | None = None,
**kwargs: t.Any,
) -> Client:
if kind is None or kind == "auto":
try:
from .http import HTTPClient

return HTTPClient.from_url(server_url, **kwargs)
except BadStatusLine:
from .grpc import GrpcClient

return GrpcClient.from_url(server_url, **kwargs)
except Exception as e: # pylint: disable=broad-except
raise BentoMLException(
f"Failed to create a BentoML client from given URL '{server_url}': {e} ({e.__class__.__name__})"
) from e
elif kind == "http":
from .http import HTTPClient

return HTTPClient.from_url(server_url, **kwargs)
elif kind == "grpc":
from .grpc import GrpcClient

return GrpcClient.from_url(server_url, **kwargs)
else:
raise BentoMLException(
f"Invalid client kind '{kind}'. Must be one of 'http', 'grpc', or 'auto'."
)
return SyncClient.from_url(server_url, kind=kind, **kwargs)

def __enter__(self):
return self
Expand Down Expand Up @@ -168,7 +144,6 @@ def __init__(self, svc: Service, server_url: str):
if len(svc.apis) == 0:
raise BentoMLException("No APIs were found when constructing client.")

self.endpoints = []
self.endpoints = []
for name, api in self._svc.apis.items():
self.endpoints.append(name)
Expand All @@ -177,7 +152,7 @@ def __init__(self, svc: Service, server_url: str):
setattr(
self,
name,
functools.partial(self.call, _bentoml_api=api),
functools.partial(self._call, _bentoml_api=api),
)

async def call(
Expand Down Expand Up @@ -282,7 +257,7 @@ async def from_url(
)


class SyncClient(ABC):
class SyncClient(Client):
server_url: str
_svc: Service
endpoints: list[str]
Expand All @@ -302,7 +277,7 @@ def __init__(self, svc: Service, server_url: str):
setattr(
self,
name,
functools.partial(self.call, _bentoml_api=api),
functools.partial(self._call, _bentoml_api=api),
)

def call(self, bentoml_api_name: str, inp: t.Any = None, **kwargs: t.Any) -> t.Any:
Expand Down Expand Up @@ -370,7 +345,7 @@ def from_url(

@classmethod
def from_url(
cls, server_url: str, *, kind: str | None = None, **kwargs: t.Any
cls, server_url: str, *, kind: t.Literal["auto", "http", "grpc"] | None = None, **kwargs: t.Any
) -> SyncClient:
if kind is None or kind == "auto":
try:
Expand Down
61 changes: 22 additions & 39 deletions src/bentoml/_internal/client/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,9 @@
import asyncio
import json
import logging
import socket
import time
import typing as t
import urllib.error
import urllib.request
from functools import cached_property
from urllib.parse import urlparse

import httpx
import starlette.datastructures
Expand Down Expand Up @@ -60,9 +56,9 @@ async def wait_until_server_ready(
else:
await asyncio.sleep(check_interval)
except (
ConnectionError,
urllib.error.URLError,
socket.timeout,
httpx.TimeoutException,
httpx.NetworkError,
httpx.HTTPStatusError,
):
logger.debug("Server is not ready. Retrying...")
await asyncio.sleep(check_interval)
Expand All @@ -76,10 +72,9 @@ async def wait_until_server_ready(
f"Timed out waiting {timeout} seconds for server at '{host}:{port}' to be ready."
)
except (
ConnectionError,
urllib.error.URLError,
socket.timeout,
TimeoutError,
httpx.TimeoutException,
httpx.NetworkError,
httpx.HTTPStatusError,
) as err:
logger.error("Timed out while connecting to %s:%s:", host, port)
logger.error(err)
Expand All @@ -92,16 +87,6 @@ async def health(self) -> httpx.Response:
async def from_url(cls, server_url: str, **kwargs: t.Any) -> AsyncHTTPClient:
server_url = server_url if "://" in server_url else "http://" + server_url

conn = HTTPConnection(url_parts.netloc)
conn.set_debuglevel(logging.DEBUG if get_debug_mode() else 0)

# we want to preserve as much of the user path as possible, so we don't really want to use
# a path join here.
if url_parts.path.endswith("/"):
url_parts.path + "docs.json"
else:
url_parts.path + "/docs.json"

async with httpx.AsyncClient(base_url=server_url) as session:
resp = await session.get("/docs.json")
if resp.status_code != 200:
Expand Down Expand Up @@ -175,7 +160,7 @@ async def _call(

resp = await self.client.post(
"/" + api.route if not api.route.startswith("/") else api.route,
data=req_body,
content=req_body,
headers={"content-type": fake_resp.headers["content-type"]},
)
if resp.status_code != 200:
Expand All @@ -200,9 +185,6 @@ async def close(self):
class SyncHTTPClient(SyncClient):
@cached_property
def client(self) -> httpx.Client:
server_url = urlparse(self.server_url)
if not server_url.netloc:
raise BentoMLException("Invalid API server URL: {self.server_url}. ")
return httpx.Client(base_url=self.server_url)

@staticmethod
Expand All @@ -225,9 +207,9 @@ def wait_until_server_ready(
else:
time.sleep(check_interval)
except (
ConnectionError,
urllib.error.URLError,
socket.timeout,
httpx.TimeoutException,
httpx.NetworkError,
httpx.HTTPStatusError,
):
logger.debug("Server is not ready. Retrying...")

Expand All @@ -239,10 +221,9 @@ def wait_until_server_ready(
f"Timed out waiting {timeout} seconds for server at '{host}:{port}' to be ready."
)
except (
ConnectionError,
urllib.error.URLError,
socket.timeout,
TimeoutError,
httpx.TimeoutException,
httpx.NetworkError,
httpx.HTTPStatusError,
) as err:
logger.error("Timed out while connecting to %s:%s:", host, port)
logger.error(err)
Expand All @@ -254,12 +235,14 @@ def health(self) -> httpx.Response:
@classmethod
def from_url(cls, server_url: str, **kwargs: t.Any) -> SyncHTTPClient:
server_url = server_url if "://" in server_url else "http://" + server_url
resp = httpx.get(f"{server_url}/docs.json")
if resp.status_code != 200:
raise RemoteException(
f"Failed to get OpenAPI schema from the server: {resp.status_code} {resp.reason_phrase}:\n{resp.content}"
)
openapi_spec = json.loads(resp.content)

with httpx.Client(base_url=server_url) as session:
resp = session.get("docs.json")
if resp.status_code != 200:
raise RemoteException(
f"Failed to get OpenAPI schema from the server: {resp.status_code} {resp.reason_phrase}:\n{resp.content}"
)
openapi_spec = json.loads(resp.content)

dummy_service = Service(openapi_spec["info"]["title"])

Expand Down Expand Up @@ -323,7 +306,7 @@ def _call(
self.server_url + "/" + api.route
if not api.route.startswith("/")
else api.route,
data=req_body,
content=req_body,
headers={"content-type": fake_resp.headers["content-type"]},
)
if resp.status_code != 200:
Expand Down

0 comments on commit 25c5a67

Please sign in to comment.