
Commit

update webui module (#291)
* add SUPPORTED_MODEL_LIST

* 1. save and reload the model config edited by the user
2. add predefined model configs
3. remove the use_v2 config
4. add a quantization config
5. simplify model configuration operations
6. fix plot bugs
codemayq authored Jul 15, 2023
1 parent c9d7bca commit cfc6e6a
Showing 14 changed files with 185 additions and 125 deletions.
39 changes: 39 additions & 0 deletions src/glmtuner/extras/constants.py
@@ -5,3 +5,42 @@
FINETUNING_ARGS_NAME = "finetuning_args.json"

LAYERNORM_NAMES = ["layernorm"]

SUPPORTED_MODEL_LIST = [
    {
        "name": "chatglm-6b",
        "pretrained_model_name": "THUDM/chatglm-6b",
        "local_model_path": None,
        "provides": "ChatGLMLLMChain"
    },
    {
        "name": "chatglm2-6b",
        "pretrained_model_name": "THUDM/chatglm2-6b",
        "local_model_path": None,
        "provides": "ChatGLMLLMChain"
    },
    {
        "name": "chatglm-6b-int8",
        "pretrained_model_name": "THUDM/chatglm-6b-int8",
        "local_model_path": None,
        "provides": "ChatGLMLLMChain"
    },
    {
        "name": "chatglm2-6b-int8",
        "pretrained_model_name": "THUDM/chatglm2-6b-int8",
        "local_model_path": None,
        "provides": "ChatGLMLLMChain"
    },
    {
        "name": "chatglm-6b-int4",
        "pretrained_model_name": "THUDM/chatglm-6b-int4",
        "local_model_path": None,
        "provides": "ChatGLMLLMChain"
    },
    {
        "name": "chatglm2-6b-int4",
        "pretrained_model_name": "THUDM/chatglm2-6b-int4",
        "local_model_path": None,
        "provides": "ChatGLMLLMChain"
    }
]
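Each entry maps a short model name to its Hugging Face repository id, with an optional local-path override. As a quick illustration, a minimal lookup sketch (the helper name resolve_model_path is hypothetical, not part of the module):

from typing import Optional

from glmtuner.extras.constants import SUPPORTED_MODEL_LIST


def resolve_model_path(name: str) -> Optional[str]:
    # Hypothetical helper: prefer a configured local path, otherwise
    # fall back to the pretrained model name on the Hugging Face Hub.
    for entry in SUPPORTED_MODEL_LIST:
        if entry["name"] == name:
            return entry["local_model_path"] or entry["pretrained_model_name"]
    return None


print(resolve_model_path("chatglm2-6b-int4"))  # -> "THUDM/chatglm2-6b-int4"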
7 changes: 7 additions & 0 deletions src/glmtuner/hparams/model_args.py
@@ -70,8 +70,15 @@ class ModelArguments:
)

    def __post_init__(self):
        if self.checkpoint_dir == "":
            self.checkpoint_dir = None
        # if the base model is already a quantized version, ignore the quantization_bit config
        if self.quantization_bit == "" or "int" in self.model_name_or_path:
            self.quantization_bit = None

        if self.checkpoint_dir is not None:  # support merging lora weights
            self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]

        if self.quantization_bit is not None:
            self.quantization_bit = int(self.quantization_bit)
            assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."
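Taken together, the new __post_init__ normalizes the raw string values coming from the web UI. The sketch below reproduces that logic in a trimmed stand-in dataclass (not the full ModelArguments) to show the resulting types:

from dataclasses import dataclass
from typing import List, Optional, Union


@dataclass
class _ArgsSketch:
    # Stand-in for ModelArguments, reduced to the fields touched above.
    model_name_or_path: str
    checkpoint_dir: Optional[Union[str, List[str]]] = None
    quantization_bit: Optional[Union[str, int]] = None

    def __post_init__(self):
        if self.checkpoint_dir == "":
            self.checkpoint_dir = None
        # A base model whose name contains "int" is already quantized, so quantization_bit is dropped.
        if self.quantization_bit == "" or "int" in self.model_name_or_path:
            self.quantization_bit = None
        if self.checkpoint_dir is not None:  # support merging LoRA weights
            self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]
        if self.quantization_bit is not None:
            self.quantization_bit = int(self.quantization_bit)
            assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."


args = _ArgsSketch("THUDM/chatglm2-6b", checkpoint_dir="saves/a, saves/b", quantization_bit="4")
print(args.checkpoint_dir)    # ['saves/a', 'saves/b']
print(args.quantization_bit)  # 4

quantized = _ArgsSketch("THUDM/chatglm2-6b-int4", quantization_bit="8")
print(quantized.quantization_bit)  # None -- the int4 model is already quantized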
36 changes: 18 additions & 18 deletions src/glmtuner/webui/chat.py
@@ -1,8 +1,8 @@
import os
from typing import List, Tuple

from glmtuner.extras.misc import torch_gc
from glmtuner.chat.stream_chat import ChatModel
from glmtuner.extras.misc import torch_gc
from glmtuner.hparams import GeneratingArguments
from glmtuner.tuner import get_infer_args
from glmtuner.webui.common import get_save_dir
@@ -15,7 +15,7 @@ def __init__(self):
        self.tokenizer = None
        self.generating_args = GeneratingArguments()

    def load_model(self, base_model: str, model_list: list, checkpoints: list, use_v2: bool):
    def load_model(self, base_model: str, model_path: str, checkpoints: list, quantization_bit: str):
        if self.model is not None:
            yield "You have loaded a model, please unload it first."
            return
@@ -24,21 +24,21 @@ def load_model(self, base_model: str, model_list: list, checkpoints: list, use_v
yield "Please select a model."
return

if len(model_list) == 0:
yield "No model detected."
return

model_path = [path for name, path in model_list if name == base_model]
if get_save_dir(base_model) and checkpoints:
checkpoint_dir = ",".join([os.path.join(get_save_dir(base_model), checkpoint) for checkpoint in checkpoints])
checkpoint_dir = ",".join(
[os.path.join(get_save_dir(base_model), checkpoint) for checkpoint in checkpoints])
else:
checkpoint_dir = None

yield "Loading model..."

if model_path:
model_name_or_path = model_path
else:
model_name_or_path = base_model
args = dict(
model_name_or_path=model_path[0],
checkpoint_dir=checkpoint_dir
model_name_or_path=model_name_or_path,
checkpoint_dir=checkpoint_dir,
quantization_bit=quantization_bit
)
super().__init__(*get_infer_args(args))

@@ -52,13 +52,13 @@ def unload_model(self):
yield "Model unloaded, please load a model first."

def predict(
self,
chatbot: List[Tuple[str, str]],
query: str,
history: List[Tuple[str, str]],
max_length: int,
top_p: float,
temperature: float
self,
chatbot: List[Tuple[str, str]],
query: str,
history: List[Tuple[str, str]],
max_length: int,
top_p: float,
temperature: float
):
chatbot.append([query, ""])
response = ""
49 changes: 38 additions & 11 deletions src/glmtuner/webui/common.py
@@ -1,18 +1,45 @@
import os
import codecs
import json
import gradio as gr
import os
from typing import List, Tuple
from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
from peft.utils import WEIGHTS_NAME as PEFT_WEIGHTS_NAME

import gradio as gr
from peft.utils import WEIGHTS_NAME as PEFT_WEIGHTS_NAME
from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME

CACHE_DIR = "cache" # to save models
CACHE_DIR = "cache" # to save models
DATA_DIR = "data"
SAVE_DIR = "saves"
TEMP_USE_CONFIG = "tmp.use.config"


def get_temp_use_config_path():
    return os.path.join(SAVE_DIR, TEMP_USE_CONFIG)


def load_temp_use_config():
    if not os.path.exists(get_temp_use_config_path()):
        return {}
    with codecs.open(get_temp_use_config_path()) as f:
        try:
            user_config = json.load(f)
            return user_config
        except Exception:
            return {}


def save_temp_use_config(user_config: dict):
    with codecs.open(get_temp_use_config_path(), "w", encoding="utf-8") as f:
        json.dump(user_config, f, ensure_ascii=False)


def save_model_config(model_name: str, model_path: str):
    with codecs.open(get_temp_use_config_path(), "w", encoding="utf-8") as f:
        json.dump({"model_name": model_name, "model_path": model_path}, f, ensure_ascii=False)


def get_save_dir(model_name: str) -> str:
    return os.path.join(SAVE_DIR, model_name)
    return os.path.join(SAVE_DIR, model_name.split("/")[-1])


def add_model(model_list: list, model_name: str, model_path: str) -> Tuple[list, str, str]:
@@ -35,11 +62,11 @@ def list_checkpoints(model_name: str) -> dict:
if save_dir and os.path.isdir(save_dir):
for checkpoint in os.listdir(save_dir):
if (
os.path.isdir(os.path.join(save_dir, checkpoint))
and any([
os.path.isfile(os.path.join(save_dir, checkpoint, name))
for name in (WEIGHTS_NAME, WEIGHTS_INDEX_NAME, PEFT_WEIGHTS_NAME)
])
os.path.isdir(os.path.join(save_dir, checkpoint))
and any([
os.path.isfile(os.path.join(save_dir, checkpoint, name))
for name in (WEIGHTS_NAME, WEIGHTS_INDEX_NAME, PEFT_WEIGHTS_NAME)
])
):
checkpoints.append(checkpoint)
return gr.update(value=[], choices=checkpoints)
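The selected model name and optional local path are persisted as a small JSON file (saves/tmp.use.config) and read back on the next launch. A minimal round-trip sketch using the functions above, with a placeholder local path:

import os

from glmtuner.webui.common import SAVE_DIR, load_temp_use_config, save_model_config

os.makedirs(SAVE_DIR, exist_ok=True)  # the config file lives at saves/tmp.use.config

save_model_config("THUDM/chatglm2-6b", "/data/models/chatglm2-6b")  # placeholder local path
print(load_temp_use_config())
# {'model_name': 'THUDM/chatglm2-6b', 'model_path': '/data/models/chatglm2-6b'}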
4 changes: 2 additions & 2 deletions src/glmtuner/webui/components/__init__.py
@@ -1,4 +1,4 @@
from glmtuner.webui.components.model import create_model_tab
from glmtuner.webui.components.sft import create_sft_tab
from glmtuner.webui.components.eval import create_eval_tab
from glmtuner.webui.components.infer import create_infer_tab
from glmtuner.webui.components.model import create_model_tab
from glmtuner.webui.components.sft import create_sft_tab
3 changes: 2 additions & 1 deletion src/glmtuner/webui/components/data.py
@@ -1,5 +1,6 @@
import gradio as gr
from typing import Tuple

import gradio as gr
from gradio.components import Component


6 changes: 3 additions & 3 deletions src/glmtuner/webui/components/eval.py
@@ -5,7 +5,7 @@
from glmtuner.webui.runner import Runner


def create_eval_tab(base_model: Component, model_list: Component, checkpoints: Component, runner: Runner) -> None:
def create_eval_tab(base_model: Component, model_path: Component, checkpoints: Component, runner: Runner) -> None:
with gr.Row():
dataset = gr.Dropdown(
label="Dataset", info="The name of dataset(s).", choices=list_datasets(), multiselect=True, interactive=True
@@ -18,7 +18,7 @@ def create_eval_tab(base_model: Component, model_list: Component, checkpoints: C
per_device_eval_batch_size = gr.Slider(
label="Batch size", value=8, minimum=1, maximum=128, step=1, info="Eval batch size.", interactive=True
)
use_v2 = gr.Checkbox(label="use ChatGLM2", value=True)
quantization_bit = gr.Dropdown([8, 4], label="Quantization bit", info="Only support 4 bit or 8 bit")

with gr.Row():
start = gr.Button("Start evaluation")
@@ -28,7 +28,7 @@ def create_eval_tab(base_model: Component, model_list: Component, checkpoints: C

start.click(
runner.run_eval,
[base_model, model_list, checkpoints, dataset, max_samples, per_device_eval_batch_size, use_v2],
[base_model, model_path, checkpoints, dataset, max_samples, per_device_eval_batch_size, quantization_bit],
[output]
)
stop.click(runner.set_abort, queue=False)
15 changes: 9 additions & 6 deletions src/glmtuner/webui/components/infer.py
@@ -1,5 +1,6 @@
import gradio as gr
from typing import Tuple

import gradio as gr
from gradio.components import Component

from glmtuner.webui.chat import WebChatModel
@@ -20,13 +21,15 @@ def create_chat_box(chat_model: WebChatModel) -> Tuple[Component, Component, Com
with gr.Column(scale=1):
clear = gr.Button("Clear History")
max_length = gr.Slider(
10, 2048, value=chat_model.generating_args.max_length, step=1.0, label="Maximum length", interactive=True
10, 2048, value=chat_model.generating_args.max_length, step=1.0, label="Maximum length",
interactive=True
)
top_p = gr.Slider(
0, 1, value=chat_model.generating_args.top_p, step=0.01, label="Top P", interactive=True
)
temperature = gr.Slider(
0, 1.5, value=chat_model.generating_args.temperature, step=0.01, label="Temperature", interactive=True
0, 1.5, value=chat_model.generating_args.temperature, step=0.01, label="Temperature",
interactive=True
)

history = gr.State([])
@@ -45,7 +48,7 @@ def create_chat_box(chat_model: WebChatModel) -> Tuple[Component, Component, Com
return chat_box, chatbot, history


def create_infer_tab(base_model: Component, model_list: Component, checkpoints: Component) -> None:
def create_infer_tab(base_model: Component, model_path: Component, checkpoints: Component) -> None:
info_box = gr.Markdown(value="Model unloaded, please load a model first.")

chat_model = WebChatModel()
@@ -54,10 +57,10 @@ def create_infer_tab(base_model: Component, model_list: Component, checkpoints:
with gr.Row():
load_btn = gr.Button("Load model")
unload_btn = gr.Button("Unload model")
use_v2 = gr.Checkbox(label="use ChatGLM2", value=True)
quantization_bit = gr.Dropdown([8, 4], label="Quantization bit", info="Only support 4 bit or 8 bit")

load_btn.click(
chat_model.load_model, [base_model, model_list, checkpoints, use_v2], [info_box]
chat_model.load_model, [base_model, model_path, checkpoints, quantization_bit], [info_box]
).then(
lambda: gr.update(visible=(chat_model.model is not None)), outputs=[chat_box]
)
59 changes: 16 additions & 43 deletions src/glmtuner/webui/components/model.py
@@ -1,57 +1,30 @@
import gradio as gr
from typing import Tuple
from gradio.components import Component

from glmtuner.webui.common import add_model, del_model, list_models, list_checkpoints


def create_model_manager(base_model: Component, model_list: Component) -> Component:
with gr.Box(visible=False, elem_classes="modal-box") as model_manager:
model_name = gr.Textbox(lines=1, label="Model name")
model_path = gr.Textbox(lines=1, label="Model path", info="The absolute path to your model.")

with gr.Row():
confirm = gr.Button("Save")
cancel = gr.Button("Cancel")

confirm.click(
add_model, [model_list, model_name, model_path], [model_list, model_name, model_path]
).then(
lambda: gr.update(visible=False), outputs=[model_manager]
).then(
list_models, [model_list], [base_model]
)

cancel.click(lambda: gr.update(visible=False), outputs=[model_manager])
import gradio as gr
from gradio.components import Component

return model_manager
from glmtuner.extras.constants import SUPPORTED_MODEL_LIST
from glmtuner.webui.common import list_checkpoints, load_temp_use_config, save_model_config


def create_model_tab() -> Tuple[Component, Component, Component]:

model_list = gr.State([]) # gr.State does not accept a dict
user_config = load_temp_use_config()
gr_state = gr.State([]) # gr.State does not accept a dict

with gr.Row():
base_model = gr.Dropdown(label="Model", interactive=True, scale=4)
add_btn = gr.Button("Add model", scale=1)
del_btn = gr.Button("Delete model", scale=1)
model_name = gr.Dropdown([model["pretrained_model_name"] for model in SUPPORTED_MODEL_LIST] + ["custom"],
label="Base Model", info="Model Version of ChatGLM",
value=user_config.get("model_name"))
model_path = gr.Textbox(lines=1, label="Local model path(Optional)",
info="The absolute path of the directory where the local model file is located",
value=user_config.get("model_path"))

with gr.Row():
checkpoints = gr.Dropdown(label="Checkpoints", multiselect=True, interactive=True, scale=5)
refresh = gr.Button("Refresh checkpoints", scale=1)

model_manager = create_model_manager(base_model, model_list)

base_model.change(list_checkpoints, [base_model], [checkpoints])

add_btn.click(lambda: gr.update(visible=True), outputs=[model_manager]).then(
list_models, [model_list], [base_model]
)

del_btn.click(del_model, [model_list, base_model], [model_list]).then(
list_models, [model_list], [base_model]
)

refresh.click(list_checkpoints, [base_model], [checkpoints])
model_name.change(list_checkpoints, [model_name], [checkpoints])
model_path.change(save_model_config, [model_name, model_path])
refresh.click(list_checkpoints, [model_name], [checkpoints])

return base_model, model_list, checkpoints
return model_name, model_path, checkpoints
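The rewritten tab keeps the last model selection across sessions: the dropdown and textbox are pre-filled from load_temp_use_config() on launch, and edits to the local path are written straight back through save_model_config. A stripped-down sketch of that wiring in isolation (choices and labels abbreviated; assumes the Gradio 3.x Blocks API the project already uses):

import gradio as gr

from glmtuner.webui.common import load_temp_use_config, save_model_config

user_config = load_temp_use_config()  # {} on the first launch

with gr.Blocks() as demo:
    model_name = gr.Dropdown(
        ["THUDM/chatglm-6b", "THUDM/chatglm2-6b", "custom"],
        label="Base Model",
        value=user_config.get("model_name"),
    )
    model_path = gr.Textbox(
        label="Local model path (optional)",
        value=user_config.get("model_path"),
    )
    # Persist the user's choice whenever the path is edited.
    model_path.change(save_model_config, [model_name, model_path])

if __name__ == "__main__":
    demo.launch()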
