
update web UI, add accuracy
hiyouga committed Jul 18, 2023
1 parent d4d6250 commit c170351
Showing 9 changed files with 117 additions and 43 deletions.
22 changes: 12 additions & 10 deletions src/glmtuner/api/app.py
@@ -57,7 +57,9 @@ async def create_chat_completion(request: ChatCompletionRequest):

prev_messages = request.messages[:-1]
if len(prev_messages) > 0 and prev_messages[0].role == Role.SYSTEM:
query = prev_messages.pop(0).content + query
prefix = prev_messages.pop(0).content
else:
prefix = None

history = []
if len(prev_messages) % 2 == 0:
@@ -66,17 +68,17 @@ async def create_chat_completion(request: ChatCompletionRequest):
history.append([prev_messages[i].content, prev_messages[i+1].content])

if request.stream:
generate = predict(query, history, request)
generate = predict(query, history, prefix, request)
return EventSourceResponse(generate, media_type="text/event-stream")

response = chat_model.chat(
query, history, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens
response, (prompt_length, response_length) = chat_model.chat(
query, history, prefix, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens
)

usage = ChatCompletionResponseUsage( # too complex to compute
prompt_tokens=1,
completion_tokens=1,
total_tokens=2
usage = ChatCompletionResponseUsage(
prompt_tokens=prompt_length,
completion_tokens=response_length,
total_tokens=prompt_length+response_length
)

choice_data = ChatCompletionResponseChoice(
@@ -87,7 +89,7 @@ async def create_chat_completion(request: ChatCompletionRequest):

return ChatCompletionResponse(model=request.model, choices=[choice_data], usage=usage)

async def predict(query: str, history: List[Tuple[str, str]], request: ChatCompletionRequest):
async def predict(query: str, history: List[Tuple[str, str]], prefix: str, request: ChatCompletionRequest):
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(role=Role.ASSISTANT),
@@ -97,7 +99,7 @@ async def predict(query: str, history: List[Tuple[str, str]], request: ChatCompletionRequest):
yield json.dumps(chunk, ensure_ascii=False)

for new_text in chat_model.stream_chat(
query, history, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens
query, history, prefix, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens
):
if len(new_text) == 0:
continue
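With this change the OpenAI-style endpoint pops a leading system message into a prompt prefix and reports real token counts in the usage field. A minimal client sketch follows; the host, port, and model name are assumptions for illustration, not part of this commit, and the response fields follow the OpenAI-style schema this API mimics.

import requests  # any HTTP client works; requests is used here for brevity

payload = {
    "model": "chatglm2-6b",  # illustrative model name
    "messages": [
        {"role": "system", "content": "You are a concise assistant."},  # becomes the prompt prefix
        {"role": "user", "content": "Summarize LoRA in one sentence."},
    ],
    "temperature": 0.7,
    "top_p": 0.9,
    "max_tokens": 256,
    "stream": False,
}

resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload).json()
print(resp["choices"][0]["message"]["content"])
# usage now carries real counts instead of the old placeholders (1, 1, 2)
print(resp["usage"])  # {"prompt_tokens": ..., "completion_tokens": ..., "total_tokens": ...}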
67 changes: 49 additions & 18 deletions src/glmtuner/chat/stream_chat.py
@@ -1,8 +1,11 @@
import torch
from typing import Any, Dict, Generator, List, Optional, Tuple
from threading import Thread
from transformers import TextIteratorStreamer

from glmtuner.extras.misc import get_logits_processor
from glmtuner.extras.misc import auto_configure_device_map
from glmtuner.hparams import ModelArguments, FinetuningArguments, GeneratingArguments
from glmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
from glmtuner.tuner import load_model_and_tokenizer


@@ -11,6 +14,7 @@ class ChatModel:
def __init__(
self,
model_args: ModelArguments,
data_args: DataArguments,
finetuning_args: FinetuningArguments,
generating_args: GeneratingArguments
) -> None:
@@ -23,10 +27,30 @@ def __init__(
else:
self.model = self.model.cuda()

self.model.eval()
self.source_prefix = data_args.source_prefix if data_args.source_prefix else ""
self.generating_args = generating_args

def process_args(self, **input_kwargs) -> Dict[str, Any]:
def get_prompt(
self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None
) -> str:
prefix = prefix + "\n" if prefix else "" # add separator for non-empty prefix
history = history if history else []
prompt = ""
for i, (old_query, response) in enumerate(history):
prompt += "[Round {}]\n\n问:{}\n\n答:{}\n\n".format(i+1, old_query, response)
prompt += "[Round {}]\n\n问:{}\n\n答:".format(len(history)+1, query)
prompt = prefix + prompt
return prompt

def process_args(
self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
) -> Tuple[Dict[str, Any], int]:
prefix = prefix if prefix else self.source_prefix

inputs = self.tokenizer([self.get_prompt(query, history, prefix)], return_tensors="pt")
inputs = inputs.to(self.model.device)
prompt_length = len(inputs["input_ids"][0])

temperature = input_kwargs.pop("temperature", None)
top_p = input_kwargs.pop("top_p", None)
top_k = input_kwargs.pop("top_k", None)
@@ -36,10 +60,12 @@ def process_args(self, **input_kwargs) -> Dict[str, Any]:

gen_kwargs = self.generating_args.to_dict()
gen_kwargs.update(dict(
input_ids=inputs["input_ids"],
temperature=temperature or gen_kwargs["temperature"],
top_p=top_p or gen_kwargs["top_p"],
top_k=top_k or gen_kwargs["top_k"],
repetition_penalty=repetition_penalty or gen_kwargs["repetition_penalty"]
repetition_penalty=repetition_penalty or gen_kwargs["repetition_penalty"],
logits_processor=get_logits_processor()
))

if max_length:
@@ -50,24 +76,29 @@ def process_args(self, **input_kwargs) -> Dict[str, Any]:
gen_kwargs.pop("max_length", None)
gen_kwargs["max_new_tokens"] = max_new_tokens

return gen_kwargs
return gen_kwargs, prompt_length

@torch.inference_mode()
def chat(self, query: str, history: Optional[List[Tuple[str, str]]] = None, **input_kwargs) -> str:
gen_kwargs = self.process_args(**input_kwargs)
response = self.model.chat(self.tokenizer, query, history, **gen_kwargs)
return response
def chat(
self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
) -> Tuple[str, Tuple[int, int]]:
gen_kwargs, prompt_length = self.process_args(query, history, prefix, **input_kwargs)
generation_output = self.model.generate(**gen_kwargs)
outputs = generation_output.tolist()[0][prompt_length:]
response = self.tokenizer.decode(outputs, skip_special_tokens=True)
response_length = len(outputs)
return response, (prompt_length, response_length)

@torch.inference_mode()
def stream_chat(
self, query: str, history: Optional[List[Tuple[str, str]]] = None, **input_kwargs
self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
) -> Generator[str, None, None]:
gen_kwargs = self.process_args(**input_kwargs)
current_length = 0
for new_response, _ in self.model.stream_chat(self.tokenizer, query, history, **gen_kwargs):
if len(new_response) == current_length:
continue

new_text = new_response[current_length:]
current_length = len(new_response)
gen_kwargs, _ = self.process_args(query, history, prefix, **input_kwargs)
streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
gen_kwargs["streamer"] = streamer

thread = Thread(target=self.model.generate, kwargs=gen_kwargs)
thread.start()

for new_text in streamer:
yield new_text
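A usage sketch of the reworked ChatModel, assuming chat_model is an already-constructed instance; the prompt layout follows get_prompt above, and the streaming loop consumes text produced by the TextIteratorStreamer thread. The query, history, and prefix values are illustrative.

history = [("你好", "你好,请问有什么可以帮您?")]
prefix = "你是一名耐心的助教。"

# get_prompt prepends the prefix (plus a newline) and renders ChatGLM-style rounds
prompt = chat_model.get_prompt("什么是 LoRA?", history, prefix)
# -> "你是一名耐心的助教。\n[Round 1]\n\n问:你好\n\n答:你好,请问有什么可以帮您?\n\n[Round 2]\n\n问:什么是 LoRA?\n\n答:"

# chat() now returns the text together with (prompt_length, response_length)
response, (prompt_len, resp_len) = chat_model.chat("什么是 LoRA?", history, prefix, max_new_tokens=256)

# stream_chat() yields incremental text from the background generation thread
for piece in chat_model.stream_chat("什么是 LoRA?", history, prefix, max_new_tokens=256):
    print(piece, end="", flush=True)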
14 changes: 7 additions & 7 deletions src/glmtuner/tuner/core/parser.py
@@ -102,20 +102,20 @@ def get_train_args(

def get_infer_args(
args: Optional[Dict[str, Any]] = None
) -> Tuple[ModelArguments, FinetuningArguments, GeneratingArguments]:
) -> Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]:

parser = HfArgumentParser((ModelArguments, FinetuningArguments, GeneratingArguments))
parser = HfArgumentParser((ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments))

if args is not None:
model_args, finetuning_args, generating_args = parser.parse_dict(args)
model_args, data_args, finetuning_args, generating_args = parser.parse_dict(args)
elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
model_args, finetuning_args, generating_args = parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
model_args, data_args, finetuning_args, generating_args = parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
model_args, finetuning_args, generating_args = parser.parse_json_file(os.path.abspath(sys.argv[1]))
model_args, data_args, finetuning_args, generating_args = parser.parse_json_file(os.path.abspath(sys.argv[1]))
else:
model_args, finetuning_args, generating_args = parser.parse_args_into_dataclasses()
model_args, data_args, finetuning_args, generating_args = parser.parse_args_into_dataclasses()

assert model_args.checkpoint_dir is None or finetuning_args.finetuning_type == "lora" \
or len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."

return model_args, finetuning_args, generating_args
return model_args, data_args, finetuning_args, generating_args
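A sketch of consuming the new four-tuple together with the updated ChatModel constructor; the argument values are illustrative, and the import paths are assumed from the file locations in this diff.

from glmtuner.chat.stream_chat import ChatModel
from glmtuner.tuner.core.parser import get_infer_args

model_args, data_args, finetuning_args, generating_args = get_infer_args({
    "model_name_or_path": "THUDM/chatglm2-6b",    # illustrative
    "checkpoint_dir": "path_to_lora_checkpoint",  # illustrative
    "source_prefix": "你是一个乐于助人的助手。",      # now parsed via DataArguments
})
chat_model = ChatModel(model_args, data_args, finetuning_args, generating_args)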
5 changes: 2 additions & 3 deletions src/glmtuner/tuner/sft/metric.py
@@ -14,8 +14,6 @@
class ComputeMetrics:
r"""
Wraps the tokenizer into metric functions, used in Seq2SeqTrainerForChatGLM.
Borrowed from: https://github.com/THUDM/ChatGLM-6B/blob/0c2806fea82683349194e21996dd6b3acc3c265b/ptuning/main.py#L307
"""

tokenizer: PreTrainedTokenizer
@@ -25,7 +23,7 @@ def __call__(self, eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -
Uses the model predictions to compute metrics.
"""
preds, labels = eval_preds
score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
score_dict = {"accuracy": [], "rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}

preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)
labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id)
@@ -49,5 +47,6 @@ def __call__(self, eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -

bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3)
score_dict["bleu-4"].append(round(bleu_score * 100, 4))
score_dict["accuracy"].append(float(len(label) != 0 and pred[:len(label)] == label))

return {k: float(np.mean(v)) for k, v in score_dict.items()}
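A worked sketch of the new accuracy term: assuming pred and label are the decoded strings compared inside the loop, the score is 1.0 exactly when the prediction starts with the full, non-empty label. The strings below are illustrative.

label = "是的。"
pred_hit  = "是的。没有问题。"  # label is a prefix of the prediction
pred_miss = "不是的。"          # mismatch within the label span

print(float(len(label) != 0 and pred_hit[:len(label)] == label))   # 1.0
print(float(len(label) != 0 and pred_miss[:len(label)] == label))  # 0.0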
4 changes: 3 additions & 1 deletion src/glmtuner/webui/chat.py
@@ -76,7 +76,9 @@ def predict(
):
chatbot.append([query, ""])
response = ""
for new_text in self.stream_chat(query, history, max_length=max_length, top_p=top_p, temperature=temperature):
for new_text in self.stream_chat(
query, history, max_length=max_length, top_p=top_p, temperature=temperature
):
response += new_text
new_history = history + [(query, response)]
chatbot[-1] = [query, response]
8 changes: 7 additions & 1 deletion src/glmtuner/webui/components/eval.py
@@ -21,8 +21,10 @@ def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str
preview_btn.click(get_preview, [dataset_dir, dataset], [preview_count, preview_samples, preview_box])

with gr.Row():
max_source_length = gr.Slider(value=512, minimum=3, maximum=4096, step=1)
max_target_length = gr.Slider(value=512, minimum=3, maximum=4096, step=1)
max_samples = gr.Textbox(value="100000")
batch_size = gr.Slider(value=8, minimum=1, maximum=128, step=1)
batch_size = gr.Slider(value=8, minimum=1, maximum=512, step=1)
predict = gr.Checkbox(value=True)

with gr.Row():
@@ -42,6 +44,8 @@ def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str
top_elems["source_prefix"],
dataset_dir,
dataset,
max_source_length,
max_target_length,
max_samples,
batch_size,
predict
@@ -57,6 +61,8 @@ def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str
preview_count=preview_count,
preview_samples=preview_samples,
close_btn=close_btn,
max_source_length=max_source_length,
max_target_length=max_target_length,
max_samples=max_samples,
batch_size=batch_size,
predict=predict,
12 changes: 9 additions & 3 deletions src/glmtuner/webui/components/sft.py
@@ -23,21 +23,23 @@ def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str,
preview_btn.click(get_preview, [dataset_dir, dataset], [preview_count, preview_samples, preview_box])

with gr.Row():
max_source_length = gr.Slider(value=512, minimum=3, maximum=4096, step=1)
max_target_length = gr.Slider(value=512, minimum=3, maximum=4096, step=1)
learning_rate = gr.Textbox(value="5e-5")
num_train_epochs = gr.Textbox(value="3.0")
max_samples = gr.Textbox(value="100000")

with gr.Row():
batch_size = gr.Slider(value=4, minimum=1, maximum=128, step=1)
gradient_accumulation_steps = gr.Slider(value=4, minimum=1, maximum=32, step=1)
batch_size = gr.Slider(value=4, minimum=1, maximum=512, step=1)
gradient_accumulation_steps = gr.Slider(value=4, minimum=1, maximum=512, step=1)
lr_scheduler_type = gr.Dropdown(
value="cosine", choices=[scheduler.value for scheduler in SchedulerType]
)
fp16 = gr.Checkbox(value=True)

with gr.Row():
logging_steps = gr.Slider(value=5, minimum=5, maximum=1000, step=5)
save_steps = gr.Slider(value=100, minimum=10, maximum=2000, step=10)
save_steps = gr.Slider(value=100, minimum=10, maximum=5000, step=10)

with gr.Row():
start_btn = gr.Button()
@@ -62,6 +64,8 @@ def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str,
top_elems["source_prefix"],
dataset_dir,
dataset,
max_source_length,
max_target_length,
learning_rate,
num_train_epochs,
max_samples,
@@ -88,6 +92,8 @@ def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str,
preview_count=preview_count,
preview_samples=preview_samples,
close_btn=close_btn,
max_source_length=max_source_length,
max_target_length=max_target_length,
learning_rate=learning_rate,
num_train_epochs=num_train_epochs,
max_samples=max_samples,
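The two new sliders correspond to the max_source_length / max_target_length entries added to locales.py below. A standalone sketch of equivalent components is shown here; the Blocks wrapper and the explicit label/info strings are only for illustration, since the web UI normally attaches labels through its locale machinery.

import gradio as gr

with gr.Blocks() as demo:
    with gr.Row():
        max_source_length = gr.Slider(
            value=512, minimum=3, maximum=4096, step=1,
            label="Max source length", info="Max tokens in source sequence.",
        )
        max_target_length = gr.Slider(
            value=512, minimum=3, maximum=4096, step=1,
            label="Max target length", info="Max tokens in target sequence.",
        )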
20 changes: 20 additions & 0 deletions src/glmtuner/webui/locales.py
@@ -129,6 +129,26 @@
"value": "关闭"
}
},
"max_source_length": {
"en": {
"label": "Max source length",
"info": "Max tokens in source sequence."
},
"zh": {
"label": "输入序列最大长度",
"info": "输入序列分词后的最大长度。"
}
},
"max_target_length": {
"en": {
"label": "Max target length",
"info": "Max tokens in target sequence."
},
"zh": {
"label": "输出序列最大长度",
"info": "输出序列分词后的最大长度。"
}
},
"learning_rate": {
"en": {
"label": "Learning rate",
8 changes: 8 additions & 0 deletions src/glmtuner/webui/runner.py
@@ -67,6 +67,8 @@ def run_train(
source_prefix: str,
dataset_dir: str,
dataset: List[str],
max_source_length: int,
max_target_length: int,
learning_rate: str,
num_train_epochs: str,
max_samples: str,
@@ -100,6 +102,8 @@ def run_train(
source_prefix=source_prefix,
dataset_dir=dataset_dir,
dataset=",".join(dataset),
max_source_length=max_source_length,
max_target_length=max_target_length,
learning_rate=float(learning_rate),
num_train_epochs=float(num_train_epochs),
max_samples=int(max_samples),
@@ -142,6 +146,8 @@ def run_eval(
source_prefix: str,
dataset_dir: str,
dataset: List[str],
max_source_length: int,
max_target_length: int,
max_samples: str,
batch_size: int,
predict: bool
@@ -171,6 +177,8 @@ def run_eval(
source_prefix=source_prefix,
dataset_dir=dataset_dir,
dataset=",".join(dataset),
max_source_length=max_source_length,
max_target_length=max_target_length,
max_samples=int(max_samples),
per_device_eval_batch_size=batch_size,
output_dir=output_dir
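In both run_train and run_eval, the new UI values are forwarded verbatim into the argument dict that drives fine-tuning. A condensed sketch of the train-side mapping, with surrounding keys abbreviated and variable names taken from the signatures above:

# Inside Runner.run_train (abridged sketch): the two new slider values flow
# straight from the web UI into the training arguments.
args = dict(
    dataset_dir=dataset_dir,
    dataset=",".join(dataset),
    max_source_length=max_source_length,
    max_target_length=max_target_length,
    learning_rate=float(learning_rate),
    num_train_epochs=float(num_train_epochs),
    max_samples=int(max_samples),
)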

1 comment on commit c170351

@hiyouga (Owner, Author) commented:


fix #317
