
Commit

update webui
hiyouga committed Jul 15, 2023
1 parent cfc6e6a commit 6de7678
Showing 13 changed files with 168 additions and 188 deletions.
2 changes: 1 addition & 1 deletion src/glmtuner/__init__.py
@@ -4,4 +4,4 @@
 from glmtuner.webui import create_ui


-__version__ = "0.1.0"
+__version__ = "0.1.1"
42 changes: 6 additions & 36 deletions src/glmtuner/extras/constants.py
@@ -6,41 +6,11 @@

 LAYERNORM_NAMES = ["layernorm"]

-SUPPORTED_MODEL_LIST = [
-    {
-        "name": "chatglm-6b",
-        "pretrained_model_name": "THUDM/chatglm-6b",
-        "local_model_path": None,
-        "provides": "ChatGLMLLMChain"
+SUPPORTED_MODELS = {
+    "ChatGLM-6B": {
+        "hf_path": "THUDM/chatglm-6b"
     },
-    {
-        "name": "chatglm2-6b",
-        "pretrained_model_name": "THUDM/chatglm2-6b",
-        "local_model_path": None,
-        "provides": "ChatGLMLLMChain"
-    },
-    {
-        "name": "chatglm-6b-int8",
-        "pretrained_model_name": "THUDM/chatglm-6b-int8",
-        "local_model_path": None,
-        "provides": "ChatGLMLLMChain"
-    },
-    {
-        "name": "chatglm2-6b-int8",
-        "pretrained_model_name": "THUDM/chatglm2-6b-int8",
-        "local_model_path": None,
-        "provides": "ChatGLMLLMChain"
-    },
-    {
-        "name": "chatglm-6b-int4",
-        "pretrained_model_name": "THUDM/chatglm-6b-int4",
-        "local_model_path": None,
-        "provides": "ChatGLMLLMChain"
-    },
-    {
-        "name": "chatglm2-6b-int4",
-        "pretrained_model_name": "THUDM/chatglm2-6b-int4",
-        "local_model_path": None,
-        "provides": "ChatGLMLLMChain"
+    "ChatGLM2-6B": {
+        "hf_path": "THUDM/chatglm2-6b"
     }
-]
+}
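
The registry is now keyed by display name and only records a Hugging Face Hub path. A small sketch of how a selected name is presumably resolved to something loadable (resolve_model_path is an illustrative helper, not part of this commit; webui/chat.py below performs the same lookup inline):

from glmtuner.extras.constants import SUPPORTED_MODELS

def resolve_model_path(model_name: str, model_path: str = "") -> str:
    # A user-supplied local path wins; otherwise fall back to the registered Hub path.
    if model_path:
        return model_path
    if model_name in SUPPORTED_MODELS:
        return SUPPORTED_MODELS[model_name]["hf_path"]
    raise ValueError("Invalid model: " + model_name)

print(resolve_model_path("ChatGLM2-6B"))  # THUDM/chatglm2-6b
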
6 changes: 1 addition & 5 deletions src/glmtuner/hparams/model_args.py
@@ -70,15 +70,11 @@ class ModelArguments:
     )

     def __post_init__(self):
-        if self.checkpoint_dir == "":
+        if not self.checkpoint_dir:
             self.checkpoint_dir = None
-        # if base model is already quantization version, ignore quantization_bit config
-        if self.quantization_bit == "" or "int" in self.model_name_or_path:
-            self.quantization_bit = None

         if self.checkpoint_dir is not None: # support merging lora weights
             self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]

         if self.quantization_bit is not None:
-            self.quantization_bit = int(self.quantization_bit)
             assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."
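
The normalization above can be exercised in isolation with a simplified stand-in dataclass (this sketch only reproduces the fields touched here; the real ModelArguments defines many more):

from dataclasses import dataclass
from typing import List, Optional, Union

@dataclass
class ModelArgsSketch:  # simplified stand-in for glmtuner.hparams.ModelArguments
    model_name_or_path: str
    checkpoint_dir: Optional[Union[str, List[str]]] = None
    quantization_bit: Optional[int] = None

    def __post_init__(self):
        if not self.checkpoint_dir:
            self.checkpoint_dir = None
        if self.checkpoint_dir is not None:  # support merging lora weights
            self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]
        if self.quantization_bit is not None:
            assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."

args = ModelArgsSketch("THUDM/chatglm2-6b", checkpoint_dir="saves/a, saves/b", quantization_bit=4)
print(args.checkpoint_dir)  # ['saves/a', 'saves/b']
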
39 changes: 22 additions & 17 deletions src/glmtuner/webui/chat.py
@@ -2,6 +2,7 @@
 from typing import List, Tuple

 from glmtuner.chat.stream_chat import ChatModel
+from glmtuner.extras.constants import SUPPORTED_MODELS
 from glmtuner.extras.misc import torch_gc
 from glmtuner.hparams import GeneratingArguments
 from glmtuner.tuner import get_infer_args
@@ -15,30 +16,34 @@ def __init__(self):
         self.tokenizer = None
         self.generating_args = GeneratingArguments()

-    def load_model(self, base_model: str, model_path: str, checkpoints: list, quantization_bit: str):
+    def load_model(self, model_name: str, model_path: str, checkpoints: list, quantization_bit: str):
         if self.model is not None:
             yield "You have loaded a model, please unload it first."
             return

-        if not base_model:
+        if not model_name:
             yield "Please select a model."
             return

-        if get_save_dir(base_model) and checkpoints:
-            checkpoint_dir = ",".join(
-                [os.path.join(get_save_dir(base_model), checkpoint) for checkpoint in checkpoints])
+        if model_path:
+            if not os.path.isdir(model_path):
+                return None, "Cannot find model directory in local disk.", None, None
+            model_name_or_path = model_path
+        elif model_name in SUPPORTED_MODELS: # TODO: use list in gr.State
+            model_name_or_path = SUPPORTED_MODELS[model_name]["hf_path"]
+        else:
+            return None, "Invalid model.", None, None
+
+        if checkpoints:
+            checkpoint_dir = ",".join([os.path.join(get_save_dir(model_name), checkpoint) for checkpoint in checkpoints])
         else:
             checkpoint_dir = None

         yield "Loading model..."
-        if model_path:
-            model_name_or_path = model_path
-        else:
-            model_name_or_path = base_model
         args = dict(
             model_name_or_path=model_name_or_path,
             checkpoint_dir=checkpoint_dir,
-            quantization_bit=quantization_bit
+            quantization_bit=int(quantization_bit) if quantization_bit else None
         )
         super().__init__(*get_infer_args(args))

@@ -52,13 +57,13 @@ def unload_model(self):
         yield "Model unloaded, please load a model first."

     def predict(
-            self,
-            chatbot: List[Tuple[str, str]],
-            query: str,
-            history: List[Tuple[str, str]],
-            max_length: int,
-            top_p: float,
-            temperature: float
+        self,
+        chatbot: List[Tuple[str, str]],
+        query: str,
+        history: List[Tuple[str, str]],
+        max_length: int,
+        top_p: float,
+        temperature: float
     ):
         chatbot.append([query, ""])
         response = ""
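
Before loading, load_model assembles a plain dict and hands it to get_infer_args. The helper below is an illustrative, standalone mirror of that assembly step (build_infer_args is not a function in the repository; the save-dir layout follows the web UI's convention):

import os

def build_infer_args(model_name_or_path: str, save_dir: str, checkpoints: list, quantization_bit: str) -> dict:
    # Join the selected checkpoints into the comma-separated string expected by ModelArguments.
    checkpoint_dir = ",".join(os.path.join(save_dir, ckpt) for ckpt in checkpoints) if checkpoints else None
    return dict(
        model_name_or_path=model_name_or_path,
        checkpoint_dir=checkpoint_dir,
        quantization_bit=int(quantization_bit) if quantization_bit else None
    )

print(build_infer_args("THUDM/chatglm2-6b", "saves/ChatGLM2-6B", ["checkpoint-100"], "4"))
# {'model_name_or_path': 'THUDM/chatglm2-6b', 'checkpoint_dir': 'saves/ChatGLM2-6B/checkpoint-100', 'quantization_bit': 4}
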
45 changes: 21 additions & 24 deletions src/glmtuner/webui/common.py
@@ -1,45 +1,42 @@
-import codecs
 import json
 import os
-from typing import List, Tuple
+from typing import Dict, List, Tuple

 import gradio as gr
 from peft.utils import WEIGHTS_NAME as PEFT_WEIGHTS_NAME
 from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME

-CACHE_DIR = "cache" # to save models
+CACHE_DIR = "cache"
 DATA_DIR = "data"
 SAVE_DIR = "saves"
-TEMP_USE_CONFIG = "tmp.use.config"
+USER_CONFIG = "user.config"


-def get_temp_use_config_path():
-    return os.path.join(SAVE_DIR, TEMP_USE_CONFIG)
+def get_config_path():
+    return os.path.join(CACHE_DIR, USER_CONFIG)


-def load_temp_use_config():
-    if not os.path.exists(get_temp_use_config_path()):
+def load_config() -> Dict[str, str]:
+    if not os.path.exists(get_config_path()):
         return {}
-    with codecs.open(get_temp_use_config_path()) as f:
+
+    with open(get_config_path(), "r", encoding="utf-8") as f:
         try:
             user_config = json.load(f)
             return user_config
-        except Exception as e:
+        except:
             return {}


-def save_temp_use_config(user_config: dict):
-    with codecs.open(get_temp_use_config_path(), "w", encoding="utf-8") as f:
-        json.dump(f, user_config, ensure_ascii=False)
-
-
-def save_model_config(model_name: str, model_path: str):
-    with codecs.open(get_temp_use_config_path(), "w", encoding="utf-8") as f:
-        json.dump({"model_name": model_name, "model_path": model_path}, f, ensure_ascii=False)
+def save_config(model_name: str, model_path: str) -> None:
+    os.makedirs(CACHE_DIR, exist_ok=True)
+    user_config = dict(model_name=model_name, model_path=model_path)
+    with open(get_config_path(), "w", encoding="utf-8") as f:
+        json.dump(user_config, f, ensure_ascii=False)


 def get_save_dir(model_name: str) -> str:
-    return os.path.join(SAVE_DIR, model_name.split("/")[-1])
+    return os.path.join(SAVE_DIR, os.path.split(model_name)[-1])


 def add_model(model_list: list, model_name: str, model_path: str) -> Tuple[list, str, str]:
@@ -62,11 +59,11 @@ def list_checkpoints(model_name: str) -> dict:
     if save_dir and os.path.isdir(save_dir):
         for checkpoint in os.listdir(save_dir):
             if (
-                    os.path.isdir(os.path.join(save_dir, checkpoint))
-                    and any([
-                        os.path.isfile(os.path.join(save_dir, checkpoint, name))
-                        for name in (WEIGHTS_NAME, WEIGHTS_INDEX_NAME, PEFT_WEIGHTS_NAME)
-                    ])
+                os.path.isdir(os.path.join(save_dir, checkpoint))
+                and any([
+                    os.path.isfile(os.path.join(save_dir, checkpoint, name))
+                    for name in (WEIGHTS_NAME, WEIGHTS_INDEX_NAME, PEFT_WEIGHTS_NAME)
+                ])
             ):
                 checkpoints.append(checkpoint)
     return gr.update(value=[], choices=checkpoints)
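
The user configuration now lives in cache/user.config as plain JSON. Assuming glmtuner is importable and the process runs from a writable project root, a minimal round trip through these helpers might look like this (paths shown POSIX-style):

from glmtuner.webui.common import save_config, load_config, get_save_dir

save_config(model_name="ChatGLM2-6B", model_path="")
print(load_config())                 # {'model_name': 'ChatGLM2-6B', 'model_path': ''}
print(get_save_dir("ChatGLM2-6B"))   # saves/ChatGLM2-6B
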
3 changes: 1 addition & 2 deletions src/glmtuner/webui/components/data.py
@@ -1,7 +1,6 @@
-from typing import Tuple
-
 import gradio as gr
 from gradio.components import Component
+from typing import Tuple


 def create_preview_box() -> Tuple[Component, Component, Component]:
30 changes: 20 additions & 10 deletions src/glmtuner/webui/components/eval.py
@@ -2,14 +2,24 @@
 from gradio.components import Component

 from glmtuner.webui.common import list_datasets
+from glmtuner.webui.components.data import create_preview_box
 from glmtuner.webui.runner import Runner
+from glmtuner.webui.utils import can_preview, get_preview


-def create_eval_tab(base_model: Component, model_path: Component, checkpoints: Component, runner: Runner) -> None:
+def create_eval_tab(model_name: Component, model_path: Component, checkpoints: Component, runner: Runner) -> None:
     with gr.Row():
-        dataset = gr.Dropdown(
-            label="Dataset", info="The name of dataset(s).", choices=list_datasets(), multiselect=True, interactive=True
-        )
+        dataset = gr.Dropdown(label="Dataset", choices=list_datasets(), multiselect=True, interactive=True, scale=4)
+        preview_btn = gr.Button("Preview", interactive=False, scale=1)
+
+    preview_box, preview_count, preview_samples = create_preview_box()
+
+    dataset.change(can_preview, [dataset], [preview_btn])
+    preview_btn.click(
+        get_preview, [dataset], [preview_count, preview_samples]
+    ).then(
+        lambda: gr.update(visible=True), outputs=[preview_box]
+    )

     with gr.Row():
         max_samples = gr.Textbox(
@@ -18,17 +28,17 @@ def create_eval_tab(base_model: Component, model_path: Component, checkpoints: C
         per_device_eval_batch_size = gr.Slider(
             label="Batch size", value=8, minimum=1, maximum=128, step=1, info="Eval batch size.", interactive=True
         )
-        quantization_bit = gr.Dropdown([8, 4], label="Quantization bit", info="Only support 4 bit or 8 bit")
+        quantization_bit = gr.Dropdown([8, 4], label="Quantization bit", info="Quantize model to 4/8-bit mode.")

     with gr.Row():
-        start = gr.Button("Start evaluation")
-        stop = gr.Button("Abort")
+        start_btn = gr.Button("Start evaluation")
+        stop_btn = gr.Button("Abort")

     output = gr.Markdown(value="Ready")

-    start.click(
+    start_btn.click(
         runner.run_eval,
-        [base_model, model_path, checkpoints, dataset, max_samples, per_device_eval_batch_size, quantization_bit],
+        [model_name, model_path, checkpoints, dataset, max_samples, per_device_eval_batch_size, quantization_bit],
         [output]
     )
-    stop.click(runner.set_abort, queue=False)
+    stop_btn.click(runner.set_abort, queue=False)
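
The preview wiring above relies on Gradio event chaining: changing the dataset dropdown toggles the preview button, and the button's click handler fetches samples and then reveals the preview box with a second callback. A self-contained sketch of the same pattern, using placeholder data and stand-in preview functions rather than the project's real helpers:

import gradio as gr

def can_preview(dataset: list) -> dict:
    # Enable the preview button only when at least one dataset is selected.
    return gr.update(interactive=bool(dataset))

def get_preview(dataset: list):
    samples = ["sample from " + name for name in dataset][:2]
    return len(samples), samples

with gr.Blocks() as demo:
    dataset = gr.Dropdown(["data_a", "data_b"], label="Dataset", multiselect=True, interactive=True)
    preview_btn = gr.Button("Preview", interactive=False)
    with gr.Column(visible=False) as preview_box:
        preview_count = gr.Number(label="Count")
        preview_samples = gr.JSON(label="Samples")

    dataset.change(can_preview, [dataset], [preview_btn])
    preview_btn.click(
        get_preview, [dataset], [preview_count, preview_samples]
    ).then(
        lambda: gr.update(visible=True), outputs=[preview_box]
    )

# demo.launch()
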
25 changes: 13 additions & 12 deletions src/glmtuner/webui/components/infer.py
@@ -16,25 +16,26 @@ def create_chat_box(chat_model: WebChatModel) -> Tuple[Component, Component, Com
                 query = gr.Textbox(show_label=False, placeholder="Input...", lines=10)

             with gr.Column(min_width=32, scale=1):
-                submit = gr.Button("Submit", variant="primary")
+                submit_btn = gr.Button("Submit", variant="primary")

         with gr.Column(scale=1):
-            clear = gr.Button("Clear History")
+            clear_btn = gr.Button("Clear History")
             max_length = gr.Slider(
-                10, 2048, value=chat_model.generating_args.max_length, step=1.0, label="Maximum length",
-                interactive=True
+                10, 2048, value=chat_model.generating_args.max_length, step=1.0,
+                label="Maximum length", interactive=True
             )
             top_p = gr.Slider(
-                0, 1, value=chat_model.generating_args.top_p, step=0.01, label="Top P", interactive=True
+                0, 1, value=chat_model.generating_args.top_p, step=0.01,
+                label="Top P", interactive=True
             )
             temperature = gr.Slider(
-                0, 1.5, value=chat_model.generating_args.temperature, step=0.01, label="Temperature",
-                interactive=True
+                0, 1.5, value=chat_model.generating_args.temperature, step=0.01,
+                label="Temperature", interactive=True
             )

     history = gr.State([])

-    submit.click(
+    submit_btn.click(
         chat_model.predict,
         [chatbot, query, history, max_length, top_p, temperature],
         [chatbot, history],
@@ -43,12 +44,12 @@ def create_chat_box(chat_model: WebChatModel) -> Tuple[Component, Component, Com
         lambda: gr.update(value=""), outputs=[query]
     )

-    clear.click(lambda: ([], []), outputs=[chatbot, history], show_progress=True)
+    clear_btn.click(lambda: ([], []), outputs=[chatbot, history], show_progress=True)

     return chat_box, chatbot, history


-def create_infer_tab(base_model: Component, model_path: Component, checkpoints: Component) -> None:
+def create_infer_tab(model_name: Component, model_path: Component, checkpoints: Component) -> None:
     info_box = gr.Markdown(value="Model unloaded, please load a model first.")

     chat_model = WebChatModel()
@@ -57,10 +58,10 @@ def create_infer_tab(base_model: Component, model_path: Component, checkpoints:
     with gr.Row():
         load_btn = gr.Button("Load model")
         unload_btn = gr.Button("Unload model")
-        quantization_bit = gr.Dropdown([8, 4], label="Quantization bit", info="Only support 4 bit or 8 bit")
+        quantization_bit = gr.Dropdown([8, 4], label="Quantization bit", info="Quantize model to 4/8-bit mode.")

     load_btn.click(
-        chat_model.load_model, [base_model, model_path, checkpoints, quantization_bit], [info_box]
+        chat_model.load_model, [model_name, model_path, checkpoints, quantization_bit], [info_box]
     ).then(
         lambda: gr.update(visible=(chat_model.model is not None)), outputs=[chat_box]
     )
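
Here the load button streams status text into the Markdown info box (load_model is a generator), and the chained .then() call reveals the chat box only once a model is actually present. A minimal, self-contained sketch of that load-then-reveal pattern; DummyChatModel is a placeholder rather than the project's WebChatModel, and generator callbacks need the Gradio queue enabled:

import time
import gradio as gr

class DummyChatModel:
    def __init__(self):
        self.model = None

    def load_model(self, quantization_bit):
        # Each yield streams a new status string into the info box.
        yield "Loading model..."
        time.sleep(1)  # stand-in for the real loading work
        self.model = object()
        yield "Model loaded, you can chat now."

chat_model = DummyChatModel()

with gr.Blocks() as demo:
    info_box = gr.Markdown("Model unloaded, please load a model first.")
    quantization_bit = gr.Dropdown([8, 4], label="Quantization bit")
    load_btn = gr.Button("Load model")
    with gr.Column(visible=False) as chat_box:
        gr.Markdown("chat components go here")

    load_btn.click(
        chat_model.load_model, [quantization_bit], [info_box]
    ).then(
        lambda: gr.update(visible=(chat_model.model is not None)), outputs=[chat_box]
    )

# demo.queue().launch()  # queue() is required for generator callbacks
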
33 changes: 18 additions & 15 deletions src/glmtuner/webui/components/model.py
@@ -3,28 +3,31 @@
 import gradio as gr
 from gradio.components import Component

-from glmtuner.extras.constants import SUPPORTED_MODEL_LIST
-from glmtuner.webui.common import list_checkpoints, load_temp_use_config, save_model_config
+from glmtuner.extras.constants import SUPPORTED_MODELS
+from glmtuner.webui.common import list_checkpoints, load_config, save_config


 def create_model_tab() -> Tuple[Component, Component, Component]:
-    user_config = load_temp_use_config()
-    gr_state = gr.State([]) # gr.State does not accept a dict
+    user_config = load_config()
+    available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]

     with gr.Row():
-        model_name = gr.Dropdown([model["pretrained_model_name"] for model in SUPPORTED_MODEL_LIST] + ["custom"],
-                                 label="Base Model", info="Model Version of ChatGLM",
-                                 value=user_config.get("model_name"))
-        model_path = gr.Textbox(lines=1, label="Local model path(Optional)",
-                                info="The absolute path of the directory where the local model file is located",
-                                value=user_config.get("model_path"))
+        model_name = gr.Dropdown(choices=available_models, label="Model", value=user_config.get("model_name", None))
+        model_path = gr.Textbox(
+            label="Local path (Optional)", value=user_config.get("model_path", None),
+            info="The absolute path of the directory where the local model file is located."
+        )

     with gr.Row():
         checkpoints = gr.Dropdown(label="Checkpoints", multiselect=True, interactive=True, scale=5)
-        refresh = gr.Button("Refresh checkpoints", scale=1)
-
-    model_name.change(list_checkpoints, [model_name], [checkpoints])
-    model_path.change(save_model_config, [model_name, model_path])
-    refresh.click(list_checkpoints, [model_name], [checkpoints])
+        refresh_btn = gr.Button("Refresh checkpoints", scale=1)
+
+    model_name.change(
+        list_checkpoints, [model_name], [checkpoints]
+    ).then( # TODO: save list
+        lambda: gr.update(value=""), outputs=[model_path]
+    )
+    model_path.change(save_config, [model_name, model_path])
+    refresh_btn.click(list_checkpoints, [model_name], [checkpoints])

     return model_name, model_path, checkpoints
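
With this wiring, switching the model refreshes the checkpoint list and then clears any manually entered path, while edits to the path box persist the user config. A condensed, runnable sketch of the same change/then pattern, with a stub standing in for list_checkpoints:

import gradio as gr

def list_checkpoints_stub(model_name: str) -> dict:
    # Stand-in for glmtuner.webui.common.list_checkpoints: repopulate the
    # checkpoint dropdown for whichever model was just selected.
    fake = {"ChatGLM-6B": ["checkpoint-100"], "ChatGLM2-6B": ["checkpoint-200", "checkpoint-400"]}
    return gr.update(value=[], choices=fake.get(model_name, []))

with gr.Blocks() as demo:
    model_name = gr.Dropdown(["ChatGLM-6B", "ChatGLM2-6B", "Custom"], label="Model")
    model_path = gr.Textbox(label="Local path (Optional)")
    checkpoints = gr.Dropdown(label="Checkpoints", multiselect=True, interactive=True)

    # Switching models refreshes checkpoints, then clears the manual path so the
    # two inputs cannot point at different models at the same time.
    model_name.change(
        list_checkpoints_stub, [model_name], [checkpoints]
    ).then(
        lambda: gr.update(value=""), outputs=[model_path]
    )

# demo.launch()
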