feat: multipart input
Signed-off-by: Frost Ming <[email protected]>
frostming committed Oct 24, 2023
1 parent 05b76ae commit a40f045
Showing 5 changed files with 162 additions and 22 deletions.
23 changes: 23 additions & 0 deletions src/bentoml_io/_pydantic.py
@@ -7,6 +7,7 @@

from bentoml._internal.utils.pkg import pkg_version_info

from .fields import DataframeSchema
from .fields import TensorSchema

if (ver := pkg_version_info("pydantic")) < (2,):
@@ -23,6 +24,7 @@
[t.Any, t.Iterable[t.Any], ConfigDict], tuple[t.Any, list[t.Any]] | None
]
import numpy as np
import pandas as pd
import tensorflow as tf
import torch
else:
@@ -31,6 +33,7 @@
np = LazyLoader("np", globals(), "numpy")
tf = LazyLoader("tf", globals(), "tensorflow")
torch = LazyLoader("torch", globals(), "torch")
pd = LazyLoader("pd", globals(), "pandas")


def numpy_prepare_pydantic_annotations(
@@ -97,6 +100,24 @@ def tf_prepare_pydantic_annotations(
return origin, remaining_annotations


def pandas_prepare_pydantic_annotations(
source: t.Any, annotations: t.Iterable[t.Any], config: ConfigDict
) -> tuple[t.Any, list[t.Any]] | None:
if not getattr(source, "__module__", "").startswith("pandas"):
return None

origin = get_origin(source) or source
if not issubclass(origin, pd.DataFrame):
return None

_, remaining_annotations = _known_annotated_metadata.collect_known_metadata(
annotations
)
if not any(isinstance(a, DataframeSchema) for a in remaining_annotations):
remaining_annotations.insert(0, DataframeSchema())
return origin, remaining_annotations


def add_custom_preparers():
try:
from pydantic._internal import _std_types_schema
@@ -108,4 +129,6 @@ def add_custom_preparers():
numpy_prepare_pydantic_annotations,
torch_prepare_pydantic_annotations,
tf_prepare_pydantic_annotations,
# dataframe
pandas_prepare_pydantic_annotations,
)
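
With the pandas preparer registered, a bare pd.DataFrame annotation on a pydantic model picks up a default DataframeSchema. A minimal sketch of the effect, assuming pydantic v2 and pandas are installed (the model and field names are illustrative, not part of this commit):

import pandas as pd
from pydantic import BaseModel

from bentoml_io._pydantic import add_custom_preparers

add_custom_preparers()  # registers the numpy/torch/tf/pandas preparers

class Query(BaseModel):
    df: pd.DataFrame  # implicitly treated as Annotated[pd.DataFrame, DataframeSchema()]

# Validation builds a DataFrame from "records"-oriented data.
q = Query.model_validate({"df": [{"a": 1, "b": 2}, {"a": 3, "b": 4}]})
assert isinstance(q.df, pd.DataFrame)
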
65 changes: 64 additions & 1 deletion src/bentoml_io/fields.py
@@ -12,6 +12,7 @@

if t.TYPE_CHECKING:
import numpy as np
import pandas as pd
import tensorflow as tf
import torch
from pydantic import GetCoreSchemaHandler
@@ -26,6 +27,7 @@
tf = LazyLoader("tf", globals(), "tensorflow")
torch = LazyLoader("torch", globals(), "torch")
pa = LazyLoader("pa", globals(), "pyarrow")
pd = LazyLoader("pd", globals(), "pandas")

T = t.TypeVar("T")
# This is an internal global state that is True when the model is being serialized for arrow
@@ -63,7 +65,18 @@ def __get_pydantic_core_schema__(
)


-File = t.Annotated[t.BinaryIO, FileEncoder(io.BytesIO, lambda x: x.getvalue())]
+def _get_file(obj: bytes | t.BinaryIO) -> t.BinaryIO:
+    if hasattr(obj, "read"):
+        return obj
+    return io.BytesIO(obj)
+
+
+def _get_file_bytes(obj: t.BinaryIO) -> bytes:
+    obj.seek(0)
+    return obj.read()
+
+
+File = t.Annotated[t.BinaryIO, FileEncoder(_get_file, _get_file_bytes)]
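
The File alias now accepts either raw bytes or an already-open binary stream. A hedged sketch of the normalization behavior, using the two helpers defined above:

import io

raw = b"hello"
f1 = _get_file(raw)               # bytes are wrapped in a fresh io.BytesIO
f2 = _get_file(io.BytesIO(raw))   # file-like objects pass through unchanged
assert _get_file_bytes(f1) == _get_file_bytes(f2) == raw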

# `slots` is available on Python >= 3.10
if sys.version_info >= (3, 10):
@@ -150,6 +163,49 @@ def _validate(self, obj: t.Any) -> t.Any:
return arr


@dataclass(unsafe_hash=True, **slots_true)
class DataframeSchema:
orient: str = "records"
columns: list[str] | None = None

def __get_pydantic_json_schema__(
self, schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
) -> dict[str, t.Any]:
return _dict_filter_none(
{
"type": "dataframe",
"orient": self.orient,
"columns": self.columns,
"media_type": "application/json",
}
)

def __get_pydantic_core_schema__(
self, source_type: t.Any, handler: GetCoreSchemaHandler
) -> core_schema.CoreSchema:
return core_schema.no_info_after_validator_function(
self._validate,
core_schema.list_schema(core_schema.dict_schema())
if self.orient == "records"
else core_schema.dict_schema(
keys_schema=core_schema.str_schema(),
values_schema=core_schema.list_schema(),
),
serialization=core_schema.plain_serializer_function_ser_schema(self.encode),
)

def encode(self, df: pd.DataFrame) -> list | dict:
if self.orient == "records":
return df.to_dict(orient="records")
elif self.orient == "columns":
return df.to_dict(orient="list")
else:
raise ValueError("Only 'records' and 'columns' are supported for orient")

def _validate(self, obj: t.Any) -> pd.DataFrame:
return pd.DataFrame(obj, columns=self.columns)


@t.overload
def Tensor(
format: Literal["numpy-array"], dtype: str, shape: tuple[int, ...]
@@ -183,3 +239,10 @@ def Tensor(
else:
annotation = tf.Tensor
return t.Annotated[annotation, TensorSchema(format, dtype, shape)]


def Dataframe(
orient: t.Literal["records", "columns"] = "records",
columns: list[str] | None = None,
) -> t.Type[pd.DataFrame]:
return t.Annotated[pd.DataFrame, DataframeSchema(orient, columns)]
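
A hedged usage sketch for the new Dataframe annotation, assuming pydantic v2 and pandas are installed; a TypeAdapter round-trips a records-oriented payload through validation and serialization:

import pandas as pd
from pydantic import TypeAdapter

from bentoml_io.fields import Dataframe

adapter = TypeAdapter(Dataframe(orient="records", columns=["a", "b"]))
df = adapter.validate_python([{"a": 1, "b": 2}, {"a": 3, "b": 4}])  # -> pd.DataFrame
payload = adapter.dump_python(df)  # -> [{"a": 1, "b": 2}, {"a": 3, "b": 4}]
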
42 changes: 33 additions & 9 deletions src/bentoml_io/models.py
@@ -11,6 +11,7 @@
from typing_extensions import get_args

from .typing_utils import is_iterator_type
from .typing_utils import is_list_type

if t.TYPE_CHECKING:
from starlette.requests import Request
@@ -20,11 +21,34 @@


class IOMixin:
multipart_fields: t.ClassVar[list[str]]

@classmethod
def __pydantic_init_subclass__(cls) -> None:
from pydantic._internal._typing_extra import is_annotated

from .fields import FileEncoder

cls.multipart_fields = []
for k, field in cls.model_fields.items():
is_multipart = False
if any(isinstance(d, FileEncoder) for d in field.metadata):
is_multipart = True
else:
annotation = field.annotation
if is_list_type(annotation):
annotation = get_args(annotation)[0]
if is_annotated(annotation) and any(
isinstance(d, FileEncoder) for d in get_args(annotation)[1:]
):
is_multipart = True
if is_multipart:
cls.multipart_fields.append(k)

@classmethod
-    async def from_http_request(cls, request: Request, serde: Serde) -> BaseModel:
+    async def from_http_request(cls, request: Request, serde: Serde) -> IODescriptor:
         """Parse an input model from HTTP request"""
-        json_str = await request.body()
-        return serde.deserialize_model(json_str, t.cast(t.Type[BaseModel], cls))
+        return await serde.parse_request(request, t.cast(t.Type[IODescriptor], cls))

@classmethod
async def to_http_response(cls, obj: t.Any, serde: Serde) -> Response:
Expand All @@ -35,7 +59,7 @@ async def to_http_response(cls, obj: t.Any, serde: Serde) -> Response:

if not issubclass(cls, RootModel):
return Response(
content=serde.serialize_model(t.cast("BaseModel", obj)),
content=serde.serialize_model(t.cast(IODescriptor, obj)),
media_type=serde.media_type,
)
if inspect.isasyncgen(obj):
Expand All @@ -45,7 +69,7 @@ async def async_stream() -> t.AsyncGenerator[str | bytes, None]:
if isinstance(item, (str, bytes)):
yield item
else:
-                    yield serde.serialize_model(cls(item))
+                    yield serde.serialize_model(t.cast(IODescriptor, cls(item)))

return StreamingResponse(async_stream(), media_type="text/plain")

Expand All @@ -56,14 +80,14 @@ def content_stream() -> t.Generator[str | bytes, None, None]:
if isinstance(item, (str, bytes)):
yield item
else:
-                    yield serde.serialize_model(cls(item))
+                    yield serde.serialize_model(t.cast(IODescriptor, cls(item)))

return StreamingResponse(content_stream(), media_type="text/plain")
else:
if not isinstance(obj, RootModel):
-                ins = cls(obj)
+                ins: IODescriptor = t.cast(IODescriptor, cls(obj))
             else:
-                ins = obj
+                ins = t.cast(IODescriptor, obj)
if isinstance(rendered := ins.model_dump(), (str, bytes)):
media_type = cls.model_json_schema().get("media_type", "text/plain")
return Response(content=rendered, media_type=media_type)
Expand All @@ -73,7 +97,7 @@ def content_stream() -> t.Generator[str | bytes, None, None]:
)


-class IODescriptor(BaseModel, IOMixin):
+class IODescriptor(IOMixin, BaseModel):
@classmethod
def from_input(
cls, func: t.Callable[..., t.Any], *, skip_self: bool = False
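
A hedged sketch of how the new multipart_fields classification behaves, using the File annotation from fields.py (the model and field names are illustrative):

from bentoml_io.fields import File
from bentoml_io.models import IODescriptor

class GenerateInput(IODescriptor):
    prompt: str
    image: File             # FileEncoder in the field metadata -> multipart
    documents: list[File]   # list of File-annotated items -> multipart

assert GenerateInput.multipart_fields == ["image", "documents"]
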
43 changes: 36 additions & 7 deletions src/bentoml_io/serde.py
@@ -1,20 +1,27 @@
from __future__ import annotations

import abc
import io
import json
import pickle
import typing as t

from pydantic import BaseModel
from pydantic import RootModel
from starlette.datastructures import UploadFile

if t.TYPE_CHECKING:
from starlette.requests import Request

T = t.TypeVar("T", bound=BaseModel)
from .models import IODescriptor

T = t.TypeVar("T", bound="IODescriptor")


class Serde(abc.ABC):
media_type: str

@abc.abstractmethod
-    def serialize_model(self, model: BaseModel) -> bytes:
+    def serialize_model(self, model: IODescriptor) -> bytes:
...

@abc.abstractmethod
@@ -29,11 +36,16 @@ def serialize(self, obj: t.Any) -> bytes:
def deserialize(self, obj_bytes: bytes) -> t.Any:
...

async def parse_request(self, request: Request, cls: type[T]) -> T:
"""Parse a input model from HTTP request"""
json_str = await request.body()
return self.deserialize_model(json_str, cls)


class JSONSerde(Serde):
media_type = "application/json"

-    def serialize_model(self, model: BaseModel) -> bytes:
+    def serialize_model(self, model: IODescriptor) -> bytes:
return model.model_dump_json().encode("utf-8")

def deserialize_model(self, model_bytes: bytes, cls: type[T]) -> T:
@@ -46,10 +58,27 @@ def deserialize(self, obj_bytes: bytes) -> t.Any:
return json.loads(obj_bytes)


class MultipartSerde(JSONSerde):
media_type = "multipart/form-data"

async def parse_request(self, request: Request, cls: type[T]) -> T:
async with request.form() as form:
data: dict[str, t.Any] = json.loads(t.cast(str, form.get("data")))
for k in form:
if k == "data" or k not in cls.multipart_fields:
continue
value = form.getlist(k)
if not all(isinstance(v, UploadFile) for v in value):
raise ValueError("Unable to parse multipart request")
files = [v.file for v in value]
data[k] = files[0] if len(files) == 1 else files
return cls.model_validate(data)


class PickleSerde(Serde):
media_type = "application/vnd.bentoml+pickle"

-    def serialize_model(self, model: BaseModel) -> bytes:
+    def serialize_model(self, model: IODescriptor) -> bytes:
if isinstance(model, RootModel):
model_data: t.Any = model.root
else:
@@ -72,7 +101,7 @@ def deserialize(self, obj_bytes: bytes) -> t.Any:
class ArrowSerde(Serde):
media_type = "application/vnd.bentoml+arrow"

-    def serialize_model(self, model: BaseModel) -> bytes:
+    def serialize_model(self, model: IODescriptor) -> bytes:
from .arrow import serialize_to_arrow

buffer = io.BytesIO()
@@ -97,5 +126,5 @@ def deserialize(self, obj_bytes: bytes) -> t.Any:


ALL_SERDE: t.Mapping[str, type[Serde]] = {
-    s.media_type: s for s in [JSONSerde, PickleSerde, ArrowSerde]
+    s.media_type: s for s in [JSONSerde, PickleSerde, ArrowSerde, MultipartSerde]
}
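
On the wire, MultipartSerde expects a JSON-encoded "data" form field carrying the non-file payload, plus one upload part per multipart field. A hedged client-side sketch (the URL, route, and field names are illustrative, not part of this commit):

import json
import httpx

with open("cat.png", "rb") as f:
    httpx.post(
        "http://localhost:3000/predict",
        data={"data": json.dumps({"prompt": "describe this image"})},
        files={"image": ("cat.png", f, "image/png")},
    )
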
11 changes: 6 additions & 5 deletions src/bentoml_io/server/app.py
@@ -156,12 +156,13 @@ async def api_endpoint(self, name: str, request: Request) -> Response:
func = getattr(servable, name)

with self.service.context.in_request(request) as ctx:
-            input_data = await method.input_spec.from_http_request(request, serde)
-            input_params = {k: getattr(input_data, k) for k in input_data.model_fields}
-            if method.ctx_param is not None:
-                input_params[method.ctx_param] = ctx
 
             try:
+                input_data = await method.input_spec.from_http_request(request, serde)
+                input_params = {
+                    k: getattr(input_data, k) for k in input_data.model_fields
+                }
+                if method.ctx_param is not None:
+                    input_params[method.ctx_param] = ctx
if is_async_callable(func):
output = await func(**input_params)
elif inspect.isasyncgenfunction(func):
