From 1c3601a17635aab2347a6a22cd4f5e252ef57b2f Mon Sep 17 00:00:00 2001
From: anon-contributor-0 <160194672+anon-contributor-0@users.noreply.github.com>
Date: Thu, 15 Feb 2024 20:26:11 -0500
Subject: [PATCH] Add an API endpoint to reload the last-used model

---
 extensions/openai/script.py | 17 ++++++++++++++++-
 modules/models.py           | 11 ++++++++++-
 modules/shared.py           |  1 +
 3 files changed, 27 insertions(+), 2 deletions(-)

diff --git a/extensions/openai/script.py b/extensions/openai/script.py
index 03d99e8ded..0d29464338 100644
--- a/extensions/openai/script.py
+++ b/extensions/openai/script.py
@@ -26,7 +26,7 @@
 from extensions.openai.utils import _start_cloudflared
 from modules import shared
 from modules.logging_colors import logger
-from modules.models import unload_model
+from modules.models import unload_model, load_last_model
 from modules.text_generation import stop_everything_event
 
 from .typing import (
@@ -325,6 +325,21 @@ async def handle_load_model(request_data: LoadModelRequest):
         return HTTPException(status_code=400, detail="Failed to load the model.")
 
 
+@app.post("/v1/internal/model/loadlast", dependencies=check_admin_key)
+async def handle_load_last_model():
+    '''
+    This endpoint is experimental and may change in the future.
+
+    Loads the last model used before it was unloaded.
+    '''
+    try:
+        load_last_model()
+        return JSONResponse(content="OK")
+    except Exception:
+        traceback.print_exc()
+        return HTTPException(status_code=400, detail="Failed to load the last-used model.")
+
+
 @app.post("/v1/internal/model/unload", dependencies=check_admin_key)
 async def handle_unload_model():
     unload_model()
diff --git a/modules/models.py b/modules/models.py
index 687af8ba2b..6d0899b1f9 100644
--- a/modules/models.py
+++ b/modules/models.py
@@ -428,12 +428,21 @@ def clear_torch_cache():
 
 def unload_model():
     shared.model = shared.tokenizer = None
+    # Remember which model was loaded so load_last_model() can restore it.
+    # Skip the update when nothing is loaded so a repeated unload doesn't clobber the record.
+    if shared.model_name != 'None':
+        shared.last_model_name = shared.model_name
     shared.model_name = 'None'
     shared.lora_names = []
     shared.model_dirty_from_training = False
     clear_torch_cache()
 
 
+def load_last_model():
+    # Restore the model that was active before the most recent unload_model().
+    shared.model, shared.tokenizer = load_model(shared.last_model_name)
+
+
 def reload_model():
     unload_model()
-    shared.model, shared.tokenizer = load_model(shared.model_name)
+    load_last_model()
diff --git a/modules/shared.py b/modules/shared.py
index a3ce584c85..50e3bcec86 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -13,6 +13,7 @@
 model = None
 tokenizer = None
 model_name = 'None'
+last_model_name = 'None'
 is_seq2seq = False
 model_dirty_from_training = False
 lora_names = []