diff --git a/src/api_demo.py b/src/api_demo.py
index 45f64faf16..d2d6119799 100644
--- a/src/api_demo.py
+++ b/src/api_demo.py
@@ -5,11 +5,14 @@
 import uvicorn
 
+from llmtuner import ChatModel
 from llmtuner.api.app import create_app
+from llmtuner.tuner import get_infer_args
 
 
 def main():
-    app = create_app()
+    chat_model = ChatModel(*get_infer_args())
+    app = create_app(chat_model)
     uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
 
diff --git a/src/llmtuner/api/app.py b/src/llmtuner/api/app.py
index 4c68afd46c..26e41effae 100644
--- a/src/llmtuner/api/app.py
+++ b/src/llmtuner/api/app.py
@@ -30,7 +30,7 @@ async def lifespan(app: FastAPI): # collects GPU memory
     torch_gc()
 
 
-def create_app(chat_model: ChatModel):
+def create_app(chat_model: ChatModel) -> FastAPI:
     app = FastAPI(lifespan=lifespan)
 
     app.add_middleware(
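
For orientation, a minimal sketch of src/api_demo.py after the first hunk is applied: the caller now builds a ChatModel from get_infer_args() and passes it into create_app() instead of letting the app construct the model itself. The trailing `if __name__ == "__main__"` guard is an assumption, since it lies outside the hunk shown above.

```python
import uvicorn

from llmtuner import ChatModel
from llmtuner.api.app import create_app
from llmtuner.tuner import get_infer_args


def main():
    # Parse inference arguments once and hand the resulting chat model to the app factory.
    chat_model = ChatModel(*get_infer_args())
    app = create_app(chat_model)
    uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)


if __name__ == "__main__":  # assumed entry point, not shown in the hunk above
    main()
```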