Skip to content

Commit

Permalink
Fix mypy warnings in api_server.py
Browse files Browse the repository at this point in the history
Signed-off-by: Fred Reiss <[email protected]>
  • Loading branch information
frreiss committed Jan 10, 2025
1 parent a4e2b26 commit ae27a6e
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions vllm/entrypoints/openai/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from contextlib import asynccontextmanager
from functools import partial
from http import HTTPStatus
from typing import AsyncIterator, Optional, Set, Tuple
from typing import AsyncIterator, Optional, Set, Tuple, Dict, Union

import uvloop
from fastapi import APIRouter, FastAPI, HTTPException, Request
Expand Down Expand Up @@ -419,6 +419,8 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
"use the Pooling API (`/pooling`) instead.")

res = await fallback_handler.create_pooling(request, raw_request)

generator: Union[ErrorResponse, EmbeddingResponse]
if isinstance(res, PoolingResponse):
generator = EmbeddingResponse(
id=res.id,
Expand Down Expand Up @@ -493,7 +495,7 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request):
return await create_score(request, raw_request)


TASK_HANDLERS = {
TASK_HANDLERS: Dict[str, Dict[str,tuple]] = {
"generate": {
"messages": (ChatCompletionRequest, create_chat_completion),
"default": (CompletionRequest, create_completion),
Expand Down Expand Up @@ -651,7 +653,7 @@ async def add_request_id(request: Request, call_next):
module_path, object_name = middleware.rsplit(".", 1)
imported = getattr(importlib.import_module(module_path), object_name)
if inspect.isclass(imported):
app.add_middleware(imported)
app.add_middleware(imported) # type: ignore[arg-type]
elif inspect.iscoroutinefunction(imported):
app.middleware("http")(imported)
else:
Expand Down

0 comments on commit ae27a6e

Please sign in to comment.