Skip to content

Commit

Permalink
take yaml config in model load endpoint
Browse files: browse the repository at this point in the history
  • Loading branch information
AlpinDale committed Dec 30, 2024
1 parent 118bbfe commit 3391557
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 53 deletions.
65 changes: 58 additions & 7 deletions aphrodite/endpoints/openai/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
from http import HTTPStatus
from typing import AsyncGenerator, AsyncIterator, List, Optional, Set, Tuple

from fastapi import APIRouter, FastAPI, Request
import yaml
from fastapi import APIRouter, FastAPI, Form, Request, UploadFile
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import (HTMLResponse, JSONResponse, Response,
Expand All @@ -33,7 +34,6 @@
random_uuid)
from aphrodite.endpoints.logger import RequestLogger
from aphrodite.endpoints.openai.args import make_arg_parser
from aphrodite.endpoints.openai.model_management import ModelLoadRequest
from aphrodite.endpoints.openai.protocol import (ChatCompletionRequest,
ChatCompletionResponse,
CompletionRequest,
Expand Down Expand Up @@ -321,8 +321,13 @@ async def unload_model(raw_request: Request):
)

@router.post("/v1/model/load")
async def load_model(request: ModelLoadRequest, raw_request: Request):
"""Load a new model after unloading the previous one."""
async def load_model(
raw_request: Request,
config_file: Optional[UploadFile] = None,
request: Optional[str] = Form(None)
):
"""Load a new model after unloading the previous one.
Accept either a config file, a JSON request body, or both."""
if raw_request.app.state.model_is_loaded:
return JSONResponse(
content={
Expand All @@ -347,9 +352,53 @@ async def load_model(request: ModelLoadRequest, raw_request: Request):
if hasattr(original_args, param):
setattr(new_args, param, getattr(original_args, param))

for key, value in request.model_dump().items():
if hasattr(new_args, key):
setattr(new_args, key, value)
if config_file:
yaml_content = await config_file.read()
config_args = yaml.safe_load(yaml_content)
if config_args:
for key, value in config_args.items():
if hasattr(new_args, key):
setattr(new_args, key, value)

json_args = None
if request:
try:
json_args = json.loads(request)
except json.JSONDecodeError:
return JSONResponse(
content={
"status": "error",
"message": "Invalid JSON in request form field."
},
status_code=400
)
else:
try:
json_args = await raw_request.json()
except Exception:
if not config_file:
return JSONResponse(
content={
"status": "error",
"message": "Must provide either config_file or "
"valid JSON request body."
},
status_code=400
)

if json_args:
for key, value in json_args.items():
if hasattr(new_args, key):
setattr(new_args, key, value)

if not hasattr(new_args, 'model') or not new_args.model:
return JSONResponse(
content={
"status": "error",
"message": "No model specified in config or request body."
},
status_code=400
)

engine_args = AsyncEngineArgs.from_cli_args(new_args)

Expand Down Expand Up @@ -1055,6 +1104,8 @@ def signal_handler(*_) -> None:
host_name = args.host if args.host else "localhost"
port_str = str(args.port)

app.state.model_is_loaded = True


if SERVE_KOBOLD_LITE_UI:
ui_url = f"{protocol}://{host_name}:{port_str}{root_path}/"
Expand Down
46 changes: 0 additions & 46 deletions aphrodite/endpoints/openai/model_management.py

This file was deleted.

0 comments on commit 3391557

Please sign in to comment.