Inference streaming support (#1750)
* Included generate, generate_stream, and infer_stream endpoints.

* Implemented REST infer_stream (a client sketch follows this list).

* Included adaptive batching hooks for predict_stream.

* Included gRPC stream proto.

* Implemented ModelInferStream as a stream-stream method.

* Included lazy fixtures as a dependency.

* Included tests for infer_stream endpoint and ModelInferStream.

* Introduced gzip_enabled flag.

* Included gRPC stream error handling.
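The new REST streaming path can be exercised with any HTTP client that supports streamed responses. Below is a minimal client sketch against the text-model defined later in this diff. The `/v2/models/{name}/infer_stream` path mirrors the endpoint name in the bullets above, and the request body follows the V2 inference protocol; the framing of the streamed chunks (NDJSON vs. SSE) is not shown in this diff, so the sketch only prints raw lines.

```python
# Hypothetical client sketch for the new REST streaming endpoint.
# Assumptions (not taken from this diff): the endpoint path
# /v2/models/text-model/infer_stream and a locally running server on :8080.
import httpx

payload = {
    "inputs": [
        {
            "name": "prompt",
            "datatype": "BYTES",
            "shape": [1],
            "data": ["hello world streaming"],
        }
    ]
}

with httpx.stream(
    "POST", "http://localhost:8080/v2/models/text-model/infer_stream", json=payload
) as response:
    for line in response.iter_lines():
        if line.strip():
            # Exact framing (NDJSON vs. SSE "data:" lines) is not shown in this
            # diff; print raw lines and parse according to the server's framing.
            print(line)
```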

---------

Co-authored-by: Adrian Gonzalez-Martin <[email protected]>
RobertSamoilescu and Adrian Gonzalez-Martin authored May 22, 2024
1 parent aad4a5a commit 54cd47e
Showing 38 changed files with 1,624 additions and 174 deletions.
22 changes: 22 additions & 0 deletions benchmarking/testserver/models/text-model/model-settings.json
@@ -0,0 +1,22 @@
{
  "name": "text-model",

  "implementation": "text_model.TextModel",

  "versions": ["text-model/v1.2.3"],
  "platform": "mlserver",
  "inputs": [
    {
      "datatype": "BYTES",
      "name": "prompt",
      "shape": [1]
    }
  ],
  "outputs": [
    {
      "datatype": "BYTES",
      "name": "output",
      "shape": [1]
    }
  ]
}
6 changes: 6 additions & 0 deletions benchmarking/testserver/models/text-model/settings.json
@@ -0,0 +1,6 @@
{
  "debug": false,
  "parallel_workers": 0,
  "gzip_enabled": false,
  "metrics_endpoint": null
}
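`gzip_enabled` is the server-level flag introduced by this PR; the benchmark keeps it off, since compression typically buffers response bodies and would interfere with measuring streaming latency. As a rough illustration of what such a flag usually gates, the sketch below conditionally registers Starlette's `GZipMiddleware` on a FastAPI app. This mirrors the common pattern rather than quoting MLServer's actual wiring, and the `Settings` dataclass is an assumed stand-in.

```python
# Illustrative sketch only: how a gzip_enabled flag is commonly wired into a
# FastAPI/Starlette app. MLServer's actual REST setup is not shown in this diff.
from dataclasses import dataclass

from fastapi import FastAPI
from starlette.middleware.gzip import GZipMiddleware


@dataclass
class Settings:  # assumed stand-in for the server settings object
    gzip_enabled: bool = False


def create_app(settings: Settings) -> FastAPI:
    app = FastAPI()
    if settings.gzip_enabled:
        # Compression buffers responses, which is why streaming benchmarks
        # typically leave it disabled.
        app.add_middleware(GZipMiddleware, minimum_size=1000)
    return app


app = create_app(Settings(gzip_enabled=False))
```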
45 changes: 45 additions & 0 deletions benchmarking/testserver/models/text-model/text_model.py
@@ -0,0 +1,45 @@
import asyncio
from typing import AsyncIterator
from mlserver import MLModel
from mlserver.types import InferenceRequest, InferenceResponse
from mlserver.codecs import StringCodec


class TextModel(MLModel):

    async def predict(self, payload: InferenceRequest) -> InferenceResponse:
        text = StringCodec.decode_input(payload.inputs[0])[0]
        return InferenceResponse(
            model_name=self._settings.name,
            outputs=[
                StringCodec.encode_output(
                    name="output",
                    payload=[text],
                    use_bytes=True,
                ),
            ],
        )

    async def predict_stream(
        self, payloads: AsyncIterator[InferenceRequest]
    ) -> AsyncIterator[InferenceResponse]:
        payload = [_ async for _ in payloads][0]
        text = StringCodec.decode_input(payload.inputs[0])[0]
        words = text.split(" ")

        split_text = []
        for i, word in enumerate(words):
            split_text.append(word if i == 0 else " " + word)

        for word in split_text:
            await asyncio.sleep(0.5)
            yield InferenceResponse(
                model_name=self._settings.name,
                outputs=[
                    StringCodec.encode_output(
                        name="output",
                        payload=[word],
                        use_bytes=True,
                    ),
                ],
            )
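The streaming contract above (an `AsyncIterator` of requests in, an `AsyncIterator` of responses out) can be exercised without a server. A minimal local driver sketch follows, reusing the codecs the model already imports; the `ModelSettings` construction is kept deliberately small and assumes the implementation class can be passed directly, which may vary by MLServer version.

```python
# Minimal local driver for TextModel.predict_stream (sketch; the ModelSettings
# fields used here are assumed to be sufficient).
import asyncio
from typing import AsyncIterator

from mlserver.codecs import StringCodec
from mlserver.settings import ModelSettings
from mlserver.types import InferenceRequest

from text_model import TextModel


async def main() -> None:
    settings = ModelSettings(name="text-model", implementation=TextModel)
    model = TextModel(settings)

    async def requests() -> AsyncIterator[InferenceRequest]:
        yield InferenceRequest(
            inputs=[
                StringCodec.encode_input(
                    "prompt", ["hello streaming world"], use_bytes=True
                )
            ]
        )

    async for response in model.predict_stream(requests()):
        # Each response carries one whitespace-delimited token of the prompt.
        print(StringCodec.decode_output(response.outputs[0]))


asyncio.run(main())
```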
64 changes: 48 additions & 16 deletions mlserver/batching/hooks.py
@@ -1,5 +1,5 @@
 from functools import wraps
-from typing import Awaitable, Callable, Optional
+from typing import Awaitable, Callable, Optional, AsyncIterator
 
 from ..errors import MLServerError
 from ..model import MLModel
@@ -20,6 +20,26 @@ def __init__(self, method_name: str, reason: Optional[str] = None):
         super().__init__(msg)
 
 
+def _get_batcher(f: Callable) -> AdaptiveBatcher:
+    wrapped_f = get_wrapped_method(f)
+    model = _get_model(f)
+
+    if not hasattr(model, _AdaptiveBatchingAttr):
+        raise InvalidBatchingMethod(
+            wrapped_f.__name__, reason="adaptive batching has not been loaded"
+        )
+
+    return getattr(model, _AdaptiveBatchingAttr)
+
+
+def _get_model(f: Callable) -> MLModel:
+    wrapped_f = get_wrapped_method(f)
+    if not hasattr(wrapped_f, "__self__"):
+        raise InvalidBatchingMethod(wrapped_f.__name__, reason="method is not bound")
+
+    return getattr(wrapped_f, "__self__")
+
+
 def adaptive_batching(f: Callable[[InferenceRequest], Awaitable[InferenceResponse]]):
     """
     Decorator for the `predict()` method which will ensure it uses the
@@ -28,24 +48,36 @@ def adaptive_batching(f: Callable[[InferenceRequest], Awaitable[InferenceResponse]]):
 
     @wraps(f)
     async def _inner(payload: InferenceRequest) -> InferenceResponse:
-        wrapped_f = get_wrapped_method(f)
-        if not hasattr(wrapped_f, "__self__"):
-            raise InvalidBatchingMethod(
-                wrapped_f.__name__, reason="method is not bound"
-            )
-
-        model = getattr(wrapped_f, "__self__")
-        if not hasattr(model, _AdaptiveBatchingAttr):
-            raise InvalidBatchingMethod(
-                wrapped_f.__name__, reason="adaptive batching has not been loaded"
-            )
-
-        batcher = getattr(model, _AdaptiveBatchingAttr)
+        batcher = _get_batcher(f)
         return await batcher.predict(payload)
 
     return _inner
 
 
+def not_implemented_warning(
+    f: Callable[[AsyncIterator[InferenceRequest]], AsyncIterator[InferenceResponse]],
+):
+    """
+    Decorator that lets users know that adaptive batching is not supported on
+    method `f`.
+    """
+    model = _get_model(f)
+    logger.warning(
+        f"Adaptive Batching is enabled for model '{model.name}'"
+        " but not supported for inference streaming."
+        " Falling back to non-batched inference streaming."
+    )
+
+    @wraps(f)
+    async def _inner_stream(
+        payload: AsyncIterator[InferenceRequest],
+    ) -> AsyncIterator[InferenceResponse]:
+        async for response in f(payload):
+            yield response
+
+    return _inner_stream
+
+
 async def load_batching(model: MLModel) -> MLModel:
     if model.settings.max_batch_size <= 1:
         return model
@@ -64,7 +96,7 @@ async def load_batching(model: MLModel) -> MLModel:
     batcher = AdaptiveBatcher(model)
     setattr(model, _AdaptiveBatchingAttr, batcher)
 
-    # Decorate predict method
+    # Decorate predict methods
     setattr(model, "predict", adaptive_batching(model.predict))
-
+    setattr(model, "predict_stream", not_implemented_warning(model.predict_stream))
     return model
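Taken together, `load_batching()` now decorates both entry points: `predict()` is routed through the `AdaptiveBatcher`, while `predict_stream()` is wrapped by `not_implemented_warning()`, which only logs the fallback warning at wrap time and then delegates unchanged. A sketch of wiring this onto the benchmark model follows; the `max_batch_size`/`max_batch_time` values are illustrative, and passing the implementation class directly to `ModelSettings` is assumed to be supported.

```python
# Sketch: enabling adaptive batching on the TextModel above. Values and the
# ModelSettings construction are illustrative assumptions, not from this diff.
import asyncio

from mlserver.batching.hooks import load_batching
from mlserver.settings import ModelSettings

from text_model import TextModel


async def main() -> None:
    settings = ModelSettings(
        name="text-model",
        implementation=TextModel,
        max_batch_size=8,    # > 1, so load_batching() decorates the model
        max_batch_time=0.1,
    )
    model = TextModel(settings)

    model = await load_batching(model)
    # model.predict now goes through AdaptiveBatcher.predict(), while
    # model.predict_stream has logged the fallback warning and streams unbatched.


asyncio.run(main())
```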
5 changes: 3 additions & 2 deletions mlserver/grpc/dataplane_pb2.py

(Generated protobuf file; diff not rendered.)
