Skip to content

Commit

Permalink
Add an API endpoint to reload the last-used model
Browse files Browse the repository at this point in the history
  • Loading branch information
anon-contributor-0 committed Feb 21, 2024
1 parent 080f713 commit 9f0fe76
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 2 deletions.
17 changes: 16 additions & 1 deletion extensions/openai/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from extensions.openai.utils import _start_cloudflared
from modules import shared
from modules.logging_colors import logger
from modules.models import unload_model
from modules.models import unload_model, load_last_model
from modules.text_generation import stop_everything_event

from .typing import (
Expand Down Expand Up @@ -314,6 +314,21 @@ async def handle_load_model(request_data: LoadModelRequest):
return HTTPException(status_code=400, detail="Failed to load the model.")


@app.post("/v1/internal/model/loadlast", dependencies=check_admin_key)
async def handle_load_last_model():
'''
This endpoint is experimental and may change in the future.
Loads the last model used before it was unloaded.
'''
try:
load_last_model()
return JSONResponse(content="OK")
except:
traceback.print_exc()
return HTTPException(status_code=400, detail="Failed to load the last-used model.")


@app.post("/v1/internal/model/unload", dependencies=check_admin_key)
async def handle_unload_model():
unload_model()
Expand Down
7 changes: 6 additions & 1 deletion modules/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,12 +433,17 @@ def clear_torch_cache():

def unload_model():
shared.model = shared.tokenizer = None
shared.last_model_name = shared.model_name
shared.model_name = 'None'
shared.lora_names = []
shared.model_dirty_from_training = False
clear_torch_cache()


def load_last_model():
shared.model, shared.tokenizer = load_model(shared.last_model_name)


def reload_model():
unload_model()
shared.model, shared.tokenizer = load_model(shared.model_name)
load_last_model()
1 change: 1 addition & 0 deletions modules/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
model = None
tokenizer = None
model_name = 'None'
last_model_name = 'None'
is_seq2seq = False
model_dirty_from_training = False
lora_names = []
Expand Down

0 comments on commit 9f0fe76

Please sign in to comment.