From c1703510a691a3deb253693424a42f25f263dbf4 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Tue, 18 Jul 2023 20:07:55 +0800 Subject: [PATCH] update web UI, add accuracy --- src/glmtuner/api/app.py | 22 +++++---- src/glmtuner/chat/stream_chat.py | 67 ++++++++++++++++++++------- src/glmtuner/tuner/core/parser.py | 14 +++--- src/glmtuner/tuner/sft/metric.py | 5 +- src/glmtuner/webui/chat.py | 4 +- src/glmtuner/webui/components/eval.py | 8 +++- src/glmtuner/webui/components/sft.py | 12 +++-- src/glmtuner/webui/locales.py | 20 ++++++++ src/glmtuner/webui/runner.py | 8 ++++ 9 files changed, 117 insertions(+), 43 deletions(-) diff --git a/src/glmtuner/api/app.py b/src/glmtuner/api/app.py index 63dbd72..ba1fb7d 100644 --- a/src/glmtuner/api/app.py +++ b/src/glmtuner/api/app.py @@ -57,7 +57,9 @@ async def create_chat_completion(request: ChatCompletionRequest): prev_messages = request.messages[:-1] if len(prev_messages) > 0 and prev_messages[0].role == Role.SYSTEM: - query = prev_messages.pop(0).content + query + prefix = prev_messages.pop(0).content + else: + prefix = None history = [] if len(prev_messages) % 2 == 0: @@ -66,17 +68,17 @@ async def create_chat_completion(request: ChatCompletionRequest): history.append([prev_messages[i].content, prev_messages[i+1].content]) if request.stream: - generate = predict(query, history, request) + generate = predict(query, history, prefix, request) return EventSourceResponse(generate, media_type="text/event-stream") - response = chat_model.chat( - query, history, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens + response, (prompt_length, response_length) = chat_model.chat( + query, history, prefix, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens ) - usage = ChatCompletionResponseUsage( # too complex to compute - prompt_tokens=1, - completion_tokens=1, - total_tokens=2 + usage = ChatCompletionResponseUsage( + prompt_tokens=prompt_length, + completion_tokens=response_length, + total_tokens=prompt_length+response_length ) choice_data = ChatCompletionResponseChoice( @@ -87,7 +89,7 @@ async def create_chat_completion(request: ChatCompletionRequest): return ChatCompletionResponse(model=request.model, choices=[choice_data], usage=usage) - async def predict(query: str, history: List[Tuple[str, str]], request: ChatCompletionRequest): + async def predict(query: str, history: List[Tuple[str, str]], prefix: str, request: ChatCompletionRequest): choice_data = ChatCompletionResponseStreamChoice( index=0, delta=DeltaMessage(role=Role.ASSISTANT), @@ -97,7 +99,7 @@ async def predict(query: str, history: List[Tuple[str, str]], request: ChatCompl yield json.dumps(chunk, ensure_ascii=False) for new_text in chat_model.stream_chat( - query, history, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens + query, history, prefix, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens ): if len(new_text) == 0: continue diff --git a/src/glmtuner/chat/stream_chat.py b/src/glmtuner/chat/stream_chat.py index d2d3f11..ae56df2 100644 --- a/src/glmtuner/chat/stream_chat.py +++ b/src/glmtuner/chat/stream_chat.py @@ -1,8 +1,11 @@ import torch from typing import Any, Dict, Generator, List, Optional, Tuple +from threading import Thread +from transformers import TextIteratorStreamer +from glmtuner.extras.misc import get_logits_processor from glmtuner.extras.misc import auto_configure_device_map -from glmtuner.hparams import ModelArguments, FinetuningArguments, GeneratingArguments +from glmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments from glmtuner.tuner import load_model_and_tokenizer @@ -11,6 +14,7 @@ class ChatModel: def __init__( self, model_args: ModelArguments, + data_args: DataArguments, finetuning_args: FinetuningArguments, generating_args: GeneratingArguments ) -> None: @@ -23,10 +27,30 @@ def __init__( else: self.model = self.model.cuda() - self.model.eval() + self.source_prefix = data_args.source_prefix if data_args.source_prefix else "" self.generating_args = generating_args - def process_args(self, **input_kwargs) -> Dict[str, Any]: + def get_prompt( + self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None + ) -> str: + prefix = prefix + "\n" if prefix else "" # add separator for non-empty prefix + history = history if history else [] + prompt = "" + for i, (old_query, response) in enumerate(history): + prompt += "[Round {}]\n\n问:{}\n\n答:{}\n\n".format(i+1, old_query, response) + prompt += "[Round {}]\n\n问:{}\n\n答:".format(len(history)+1, query) + prompt = prefix + prompt + return prompt + + def process_args( + self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs + ) -> Tuple[Dict[str, Any], int]: + prefix = prefix if prefix else self.source_prefix + + inputs = self.tokenizer([self.get_prompt(query, history, prefix)], return_tensors="pt") + inputs = inputs.to(self.model.device) + prompt_length = len(inputs["input_ids"][0]) + temperature = input_kwargs.pop("temperature", None) top_p = input_kwargs.pop("top_p", None) top_k = input_kwargs.pop("top_k", None) @@ -36,10 +60,12 @@ def process_args(self, **input_kwargs) -> Dict[str, Any]: gen_kwargs = self.generating_args.to_dict() gen_kwargs.update(dict( + input_ids=inputs["input_ids"], temperature=temperature or gen_kwargs["temperature"], top_p=top_p or gen_kwargs["top_p"], top_k=top_k or gen_kwargs["top_k"], - repetition_penalty=repetition_penalty or gen_kwargs["repetition_penalty"] + repetition_penalty=repetition_penalty or gen_kwargs["repetition_penalty"], + logits_processor=get_logits_processor() )) if max_length: @@ -50,24 +76,29 @@ def process_args(self, **input_kwargs) -> Dict[str, Any]: gen_kwargs.pop("max_length", None) gen_kwargs["max_new_tokens"] = max_new_tokens - return gen_kwargs + return gen_kwargs, prompt_length @torch.inference_mode() - def chat(self, query: str, history: Optional[List[Tuple[str, str]]] = None, **input_kwargs) -> str: - gen_kwargs = self.process_args(**input_kwargs) - response = self.model.chat(self.tokenizer, query, history, **gen_kwargs) - return response + def chat( + self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs + ) -> Tuple[str, Tuple[int, int]]: + gen_kwargs, prompt_length = self.process_args(query, history, prefix, **input_kwargs) + generation_output = self.model.generate(**gen_kwargs) + outputs = generation_output.tolist()[0][prompt_length:] + response = self.tokenizer.decode(outputs, skip_special_tokens=True) + response_length = len(outputs) + return response, (prompt_length, response_length) @torch.inference_mode() def stream_chat( - self, query: str, history: Optional[List[Tuple[str, str]]] = None, **input_kwargs + self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs ) -> Generator[str, None, None]: - gen_kwargs = self.process_args(**input_kwargs) - current_length = 0 - for new_response, _ in self.model.stream_chat(self.tokenizer, query, history, **gen_kwargs): - if len(new_response) == current_length: - continue - - new_text = new_response[current_length:] - current_length = len(new_response) + gen_kwargs, _ = self.process_args(query, history, prefix, **input_kwargs) + streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True) + gen_kwargs["streamer"] = streamer + + thread = Thread(target=self.model.generate, kwargs=gen_kwargs) + thread.start() + + for new_text in streamer: yield new_text diff --git a/src/glmtuner/tuner/core/parser.py b/src/glmtuner/tuner/core/parser.py index 0fa8a8a..409c63d 100644 --- a/src/glmtuner/tuner/core/parser.py +++ b/src/glmtuner/tuner/core/parser.py @@ -102,20 +102,20 @@ def get_train_args( def get_infer_args( args: Optional[Dict[str, Any]] = None -) -> Tuple[ModelArguments, FinetuningArguments, GeneratingArguments]: +) -> Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]: - parser = HfArgumentParser((ModelArguments, FinetuningArguments, GeneratingArguments)) + parser = HfArgumentParser((ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments)) if args is not None: - model_args, finetuning_args, generating_args = parser.parse_dict(args) + model_args, data_args, finetuning_args, generating_args = parser.parse_dict(args) elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"): - model_args, finetuning_args, generating_args = parser.parse_yaml_file(os.path.abspath(sys.argv[1])) + model_args, data_args, finetuning_args, generating_args = parser.parse_yaml_file(os.path.abspath(sys.argv[1])) elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"): - model_args, finetuning_args, generating_args = parser.parse_json_file(os.path.abspath(sys.argv[1])) + model_args, data_args, finetuning_args, generating_args = parser.parse_json_file(os.path.abspath(sys.argv[1])) else: - model_args, finetuning_args, generating_args = parser.parse_args_into_dataclasses() + model_args, data_args, finetuning_args, generating_args = parser.parse_args_into_dataclasses() assert model_args.checkpoint_dir is None or finetuning_args.finetuning_type == "lora" \ or len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints." - return model_args, finetuning_args, generating_args + return model_args, data_args, finetuning_args, generating_args diff --git a/src/glmtuner/tuner/sft/metric.py b/src/glmtuner/tuner/sft/metric.py index 541c45c..284c935 100644 --- a/src/glmtuner/tuner/sft/metric.py +++ b/src/glmtuner/tuner/sft/metric.py @@ -14,8 +14,6 @@ class ComputeMetrics: r""" Wraps the tokenizer into metric functions, used in Seq2SeqTrainerForChatGLM. - - Borrowed from: https://github.com/THUDM/ChatGLM-6B/blob/0c2806fea82683349194e21996dd6b3acc3c265b/ptuning/main.py#L307 """ tokenizer: PreTrainedTokenizer @@ -25,7 +23,7 @@ def __call__(self, eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) - Uses the model predictions to compute metrics. """ preds, labels = eval_preds - score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []} + score_dict = {"accuracy": [], "rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []} preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id) labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id) @@ -49,5 +47,6 @@ def __call__(self, eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) - bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3) score_dict["bleu-4"].append(round(bleu_score * 100, 4)) + score_dict["accuracy"].append(float(len(label) != 0 and pred[:len(label)] == label)) return {k: float(np.mean(v)) for k, v in score_dict.items()} diff --git a/src/glmtuner/webui/chat.py b/src/glmtuner/webui/chat.py index 5c7d449..2c9e868 100644 --- a/src/glmtuner/webui/chat.py +++ b/src/glmtuner/webui/chat.py @@ -76,7 +76,9 @@ def predict( ): chatbot.append([query, ""]) response = "" - for new_text in self.stream_chat(query, history, max_length=max_length, top_p=top_p, temperature=temperature): + for new_text in self.stream_chat( + query, history, max_length=max_length, top_p=top_p, temperature=temperature + ): response += new_text new_history = history + [(query, response)] chatbot[-1] = [query, response] diff --git a/src/glmtuner/webui/components/eval.py b/src/glmtuner/webui/components/eval.py index f681ac3..7fe16e5 100644 --- a/src/glmtuner/webui/components/eval.py +++ b/src/glmtuner/webui/components/eval.py @@ -21,8 +21,10 @@ def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str preview_btn.click(get_preview, [dataset_dir, dataset], [preview_count, preview_samples, preview_box]) with gr.Row(): + max_source_length = gr.Slider(value=512, minimum=3, maximum=4096, step=1) + max_target_length = gr.Slider(value=512, minimum=3, maximum=4096, step=1) max_samples = gr.Textbox(value="100000") - batch_size = gr.Slider(value=8, minimum=1, maximum=128, step=1) + batch_size = gr.Slider(value=8, minimum=1, maximum=512, step=1) predict = gr.Checkbox(value=True) with gr.Row(): @@ -42,6 +44,8 @@ def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str top_elems["source_prefix"], dataset_dir, dataset, + max_source_length, + max_target_length, max_samples, batch_size, predict @@ -57,6 +61,8 @@ def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str preview_count=preview_count, preview_samples=preview_samples, close_btn=close_btn, + max_source_length=max_source_length, + max_target_length=max_target_length, max_samples=max_samples, batch_size=batch_size, predict=predict, diff --git a/src/glmtuner/webui/components/sft.py b/src/glmtuner/webui/components/sft.py index 7bbed7a..a5d83bb 100644 --- a/src/glmtuner/webui/components/sft.py +++ b/src/glmtuner/webui/components/sft.py @@ -23,13 +23,15 @@ def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, preview_btn.click(get_preview, [dataset_dir, dataset], [preview_count, preview_samples, preview_box]) with gr.Row(): + max_source_length = gr.Slider(value=512, minimum=3, maximum=4096, step=1) + max_target_length = gr.Slider(value=512, minimum=3, maximum=4096, step=1) learning_rate = gr.Textbox(value="5e-5") num_train_epochs = gr.Textbox(value="3.0") max_samples = gr.Textbox(value="100000") with gr.Row(): - batch_size = gr.Slider(value=4, minimum=1, maximum=128, step=1) - gradient_accumulation_steps = gr.Slider(value=4, minimum=1, maximum=32, step=1) + batch_size = gr.Slider(value=4, minimum=1, maximum=512, step=1) + gradient_accumulation_steps = gr.Slider(value=4, minimum=1, maximum=512, step=1) lr_scheduler_type = gr.Dropdown( value="cosine", choices=[scheduler.value for scheduler in SchedulerType] ) @@ -37,7 +39,7 @@ def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, with gr.Row(): logging_steps = gr.Slider(value=5, minimum=5, maximum=1000, step=5) - save_steps = gr.Slider(value=100, minimum=10, maximum=2000, step=10) + save_steps = gr.Slider(value=100, minimum=10, maximum=5000, step=10) with gr.Row(): start_btn = gr.Button() @@ -62,6 +64,8 @@ def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, top_elems["source_prefix"], dataset_dir, dataset, + max_source_length, + max_target_length, learning_rate, num_train_epochs, max_samples, @@ -88,6 +92,8 @@ def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, preview_count=preview_count, preview_samples=preview_samples, close_btn=close_btn, + max_source_length=max_source_length, + max_target_length=max_target_length, learning_rate=learning_rate, num_train_epochs=num_train_epochs, max_samples=max_samples, diff --git a/src/glmtuner/webui/locales.py b/src/glmtuner/webui/locales.py index 398f3c7..cf41496 100644 --- a/src/glmtuner/webui/locales.py +++ b/src/glmtuner/webui/locales.py @@ -129,6 +129,26 @@ "value": "关闭" } }, + "max_source_length": { + "en": { + "label": "Max source length", + "info": "Max tokens in source sequence." + }, + "zh": { + "label": "输入序列最大长度", + "info": "输入序列分词后的最大长度。" + } + }, + "max_target_length": { + "en": { + "label": "Max target length", + "info": "Max tokens in target sequence." + }, + "zh": { + "label": "输出序列最大长度", + "info": "输出序列分词后的最大长度。" + } + }, "learning_rate": { "en": { "label": "Learning rate", diff --git a/src/glmtuner/webui/runner.py b/src/glmtuner/webui/runner.py index 8023631..e355f3d 100644 --- a/src/glmtuner/webui/runner.py +++ b/src/glmtuner/webui/runner.py @@ -67,6 +67,8 @@ def run_train( source_prefix: str, dataset_dir: str, dataset: List[str], + max_source_length: int, + max_target_length: int, learning_rate: str, num_train_epochs: str, max_samples: str, @@ -100,6 +102,8 @@ def run_train( source_prefix=source_prefix, dataset_dir=dataset_dir, dataset=",".join(dataset), + max_source_length=max_source_length, + max_target_length=max_target_length, learning_rate=float(learning_rate), num_train_epochs=float(num_train_epochs), max_samples=int(max_samples), @@ -142,6 +146,8 @@ def run_eval( source_prefix: str, dataset_dir: str, dataset: List[str], + max_source_length: int, + max_target_length: int, max_samples: str, batch_size: int, predict: bool @@ -171,6 +177,8 @@ def run_eval( source_prefix=source_prefix, dataset_dir=dataset_dir, dataset=",".join(dataset), + max_source_length=max_source_length, + max_target_length=max_target_length, max_samples=int(max_samples), per_device_eval_batch_size=batch_size, output_dir=output_dir