From 8333795fcb25728dc7146596869f06e5e9533e58 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Thu, 20 Jul 2023 22:14:54 +0800 Subject: [PATCH] fix api --- src/api_demo.py | 5 ++++- src/llmtuner/api/app.py | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/api_demo.py b/src/api_demo.py index 45f64faf16..d2d6119799 100644 --- a/src/api_demo.py +++ b/src/api_demo.py @@ -5,11 +5,14 @@ import uvicorn +from llmtuner import ChatModel from llmtuner.api.app import create_app +from llmtuner.tuner import get_infer_args def main(): - app = create_app() + chat_model = ChatModel(*get_infer_args()) + app = create_app(chat_model) uvicorn.run(app, host="0.0.0.0", port=8000, workers=1) diff --git a/src/llmtuner/api/app.py b/src/llmtuner/api/app.py index 4c68afd46c..26e41effae 100644 --- a/src/llmtuner/api/app.py +++ b/src/llmtuner/api/app.py @@ -30,7 +30,7 @@ async def lifespan(app: FastAPI): # collects GPU memory torch_gc() -def create_app(chat_model: ChatModel): +def create_app(chat_model: ChatModel) -> FastAPI: app = FastAPI(lifespan=lifespan) app.add_middleware(