From cfc6e6abb178075ddcf00dcbe94148f775bdfc00 Mon Sep 17 00:00:00 2001
From: codingma
Date: Sat, 15 Jul 2023 23:43:18 +0800
Subject: [PATCH] update webui module (#291)

* add SUPPORTED_MODEL_LIST

* 1. save and reload model config edited by user
2. add predefined model configs
3. remove use_v2 config
4. add quantization config
5. simplify model configuration operations
6. fix plot bugs
---
 src/glmtuner/extras/constants.py          | 39 +++++++++++++++
 src/glmtuner/hparams/model_args.py        |  7 +++
 src/glmtuner/webui/chat.py                | 36 +++++++-------
 src/glmtuner/webui/common.py              | 49 ++++++++++++++-----
 src/glmtuner/webui/components/__init__.py |  4 +-
 src/glmtuner/webui/components/data.py     |  3 +-
 src/glmtuner/webui/components/eval.py     |  6 +--
 src/glmtuner/webui/components/infer.py    | 15 +++---
 src/glmtuner/webui/components/model.py    | 59 ++++++-----------------
 src/glmtuner/webui/components/sft.py      | 14 +++---
 src/glmtuner/webui/interface.py           |  5 +-
 src/glmtuner/webui/runner.py              | 53 +++++++++++---------
 src/glmtuner/webui/utils.py               | 18 ++++---
 src/train_web.py                          |  2 +-
 14 files changed, 185 insertions(+), 125 deletions(-)

diff --git a/src/glmtuner/extras/constants.py b/src/glmtuner/extras/constants.py
index eeb1144..c5d7c02 100644
--- a/src/glmtuner/extras/constants.py
+++ b/src/glmtuner/extras/constants.py
@@ -5,3 +5,42 @@
 FINETUNING_ARGS_NAME = "finetuning_args.json"
 
 LAYERNORM_NAMES = ["layernorm"]
+
+SUPPORTED_MODEL_LIST = [
+    {
+        "name": "chatglm-6b",
+        "pretrained_model_name": "THUDM/chatglm-6b",
+        "local_model_path": None,
+        "provides": "ChatGLMLLMChain"
+    },
+    {
+        "name": "chatglm2-6b",
+        "pretrained_model_name": "THUDM/chatglm2-6b",
+        "local_model_path": None,
+        "provides": "ChatGLMLLMChain"
+    },
+    {
+        "name": "chatglm-6b-int8",
+        "pretrained_model_name": "THUDM/chatglm-6b-int8",
+        "local_model_path": None,
+        "provides": "ChatGLMLLMChain"
+    },
+    {
+        "name": "chatglm2-6b-int8",
+        "pretrained_model_name": "THUDM/chatglm2-6b-int8",
+        "local_model_path": None,
+        "provides": "ChatGLMLLMChain"
+    },
+    {
+        "name": "chatglm-6b-int4",
+        "pretrained_model_name": "THUDM/chatglm-6b-int4",
+        "local_model_path": None,
+        "provides": "ChatGLMLLMChain"
+    },
+    {
+        "name": "chatglm2-6b-int4",
+        "pretrained_model_name": "THUDM/chatglm2-6b-int4",
+        "local_model_path": None,
+        "provides": "ChatGLMLLMChain"
+    }
+]
\ No newline at end of file
diff --git a/src/glmtuner/hparams/model_args.py b/src/glmtuner/hparams/model_args.py
index 4f3bd7f..5c880ce 100644
--- a/src/glmtuner/hparams/model_args.py
+++ b/src/glmtuner/hparams/model_args.py
@@ -70,8 +70,15 @@ class ModelArguments:
     )
 
     def __post_init__(self):
+        if self.checkpoint_dir == "":
+            self.checkpoint_dir = None
+        # if the base model is already a quantized version, ignore the quantization_bit config
+        if self.quantization_bit == "" or "int" in self.model_name_or_path:
+            self.quantization_bit = None
+
         if self.checkpoint_dir is not None: # support merging lora weights
             self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]
 
         if self.quantization_bit is not None:
+            self.quantization_bit = int(self.quantization_bit)
             assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."
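
The __post_init__ hook above is what lets the web UI pass raw strings straight into ModelArguments: empty fields collapse to None, a base model whose name already marks it as quantized (the "-int4"/"-int8" entries in SUPPORTED_MODEL_LIST) disables quantization_bit, and any remaining string value is cast to int and validated. A minimal standalone sketch of that flow (the ModelArgumentsSketch stub is illustrative, not part of the patch; only the __post_init__ body mirrors it):

    from dataclasses import dataclass
    from typing import List, Optional, Union

    @dataclass
    class ModelArgumentsSketch:
        # illustrative stub; the real ModelArguments has many more fields
        model_name_or_path: str
        checkpoint_dir: Optional[Union[str, List[str]]] = None
        quantization_bit: Optional[Union[str, int]] = None

        def __post_init__(self):
            if self.checkpoint_dir == "":  # empty UI field means "no checkpoint"
                self.checkpoint_dir = None
            # a base model that is already quantized ignores the quantization_bit config
            if self.quantization_bit == "" or "int" in self.model_name_or_path:
                self.quantization_bit = None
            if self.checkpoint_dir is not None:  # support merging lora weights
                self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]
            if self.quantization_bit is not None:
                self.quantization_bit = int(self.quantization_bit)
                assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."

    # ModelArgumentsSketch("THUDM/chatglm2-6b", checkpoint_dir="", quantization_bit="4")
    #   -> checkpoint_dir=None, quantization_bit=4
    # ModelArgumentsSketch("THUDM/chatglm2-6b-int4", quantization_bit="8")
    #   -> quantization_bit=None (an already-quantized base model wins)
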
diff --git a/src/glmtuner/webui/chat.py b/src/glmtuner/webui/chat.py
index 6f5f27e..6a160f9 100644
--- a/src/glmtuner/webui/chat.py
+++ b/src/glmtuner/webui/chat.py
@@ -1,8 +1,8 @@
 import os
 from typing import List, Tuple
 
-from glmtuner.extras.misc import torch_gc
 from glmtuner.chat.stream_chat import ChatModel
+from glmtuner.extras.misc import torch_gc
 from glmtuner.hparams import GeneratingArguments
 from glmtuner.tuner import get_infer_args
 from glmtuner.webui.common import get_save_dir
@@ -15,7 +15,7 @@ def __init__(self):
         self.tokenizer = None
         self.generating_args = GeneratingArguments()
 
-    def load_model(self, base_model: str, model_list: list, checkpoints: list, use_v2: bool):
+    def load_model(self, base_model: str, model_path: str, checkpoints: list, quantization_bit: str):
         if self.model is not None:
             yield "You have loaded a model, please unload it first."
             return
@@ -24,21 +24,21 @@ def load_model(self, base_model: str, model_list: list, checkpoints: list, use_v
             yield "Please select a model."
             return
 
-        if len(model_list) == 0:
-            yield "No model detected."
-            return
-
-        model_path = [path for name, path in model_list if name == base_model]
         if get_save_dir(base_model) and checkpoints:
-            checkpoint_dir = ",".join([os.path.join(get_save_dir(base_model), checkpoint) for checkpoint in checkpoints])
+            checkpoint_dir = ",".join(
+                [os.path.join(get_save_dir(base_model), checkpoint) for checkpoint in checkpoints])
         else:
             checkpoint_dir = None
 
         yield "Loading model..."
-
+        if model_path:
+            model_name_or_path = model_path
+        else:
+            model_name_or_path = base_model
         args = dict(
-            model_name_or_path=model_path[0],
-            checkpoint_dir=checkpoint_dir
+            model_name_or_path=model_name_or_path,
+            checkpoint_dir=checkpoint_dir,
+            quantization_bit=quantization_bit
         )
         super().__init__(*get_infer_args(args))
@@ -52,13 +52,13 @@ def unload_model(self):
         yield "Model unloaded, please load a model first."
 
     def predict(
-        self,
-        chatbot: List[Tuple[str, str]],
-        query: str,
-        history: List[Tuple[str, str]],
-        max_length: int,
-        top_p: float,
-        temperature: float
+            self,
+            chatbot: List[Tuple[str, str]],
+            query: str,
+            history: List[Tuple[str, str]],
+            max_length: int,
+            top_p: float,
+            temperature: float
     ):
         chatbot.append([query, ""])
         response = ""
diff --git a/src/glmtuner/webui/common.py b/src/glmtuner/webui/common.py
index 1c19234..54fcf70 100644
--- a/src/glmtuner/webui/common.py
+++ b/src/glmtuner/webui/common.py
@@ -1,18 +1,45 @@
-import os
+import codecs
 import json
-import gradio as gr
+import os
 from typing import List, Tuple
 
-from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
-from peft.utils import WEIGHTS_NAME as PEFT_WEIGHTS_NAME
+import gradio as gr
+from peft.utils import WEIGHTS_NAME as PEFT_WEIGHTS_NAME
+from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
 
-CACHE_DIR = "cache" # to save models
+CACHE_DIR = "cache"  # to save models
 DATA_DIR = "data"
 SAVE_DIR = "saves"
+TEMP_USE_CONFIG = "tmp.use.config"
+
+
+def get_temp_use_config_path():
+    return os.path.join(SAVE_DIR, TEMP_USE_CONFIG)
+
+
+def load_temp_use_config():
+    if not os.path.exists(get_temp_use_config_path()):
+        return {}
+    with codecs.open(get_temp_use_config_path(), encoding="utf-8") as f:
+        try:
+            user_config = json.load(f)
+            return user_config
+        except Exception:
+            return {}
+
+
+def save_temp_use_config(user_config: dict):
+    with codecs.open(get_temp_use_config_path(), "w", encoding="utf-8") as f:
+        json.dump(user_config, f, ensure_ascii=False)
+
+
+def save_model_config(model_name: str, model_path: str):
+    with codecs.open(get_temp_use_config_path(), "w", encoding="utf-8") as f:
+        json.dump({"model_name": model_name, "model_path": model_path}, f, ensure_ascii=False)
 
 
 def get_save_dir(model_name: str) -> str:
-    return os.path.join(SAVE_DIR, model_name)
+    return os.path.join(SAVE_DIR, model_name.split("/")[-1])
 
 
 def add_model(model_list: list, model_name: str, model_path: str) -> Tuple[list, str, str]:
@@ -35,11 +62,11 @@ def list_checkpoints(model_name: str) -> dict:
     if save_dir and os.path.isdir(save_dir):
         for checkpoint in os.listdir(save_dir):
             if (
-                os.path.isdir(os.path.join(save_dir, checkpoint))
-                and any([
-                    os.path.isfile(os.path.join(save_dir, checkpoint, name))
-                    for name in (WEIGHTS_NAME, WEIGHTS_INDEX_NAME, PEFT_WEIGHTS_NAME)
-                ])
+                    os.path.isdir(os.path.join(save_dir, checkpoint))
+                    and any([
+                        os.path.isfile(os.path.join(save_dir, checkpoint, name))
+                        for name in (WEIGHTS_NAME, WEIGHTS_INDEX_NAME, PEFT_WEIGHTS_NAME)
+                    ])
             ):
                 checkpoints.append(checkpoint)
     return gr.update(value=[], choices=checkpoints)
diff --git a/src/glmtuner/webui/components/__init__.py b/src/glmtuner/webui/components/__init__.py
index 6a24fab..129bb70 100644
--- a/src/glmtuner/webui/components/__init__.py
+++ b/src/glmtuner/webui/components/__init__.py
@@ -1,4 +1,4 @@
-from glmtuner.webui.components.model import create_model_tab
-from glmtuner.webui.components.sft import create_sft_tab
 from glmtuner.webui.components.eval import create_eval_tab
 from glmtuner.webui.components.infer import create_infer_tab
+from glmtuner.webui.components.model import create_model_tab
+from glmtuner.webui.components.sft import create_sft_tab
diff --git a/src/glmtuner/webui/components/data.py b/src/glmtuner/webui/components/data.py
index 56d304b..e6cc8fe 100644
--- a/src/glmtuner/webui/components/data.py
+++ b/src/glmtuner/webui/components/data.py
@@ -1,5 +1,6 @@
-import gradio as gr
 from typing import Tuple
+
+import gradio as gr
 from gradio.components import Component
diff --git a/src/glmtuner/webui/components/eval.py b/src/glmtuner/webui/components/eval.py
index 2b82b58..ebaa837 100644
--- a/src/glmtuner/webui/components/eval.py
+++ b/src/glmtuner/webui/components/eval.py
@@ -5,7 +5,7 @@
 from glmtuner.webui.runner import Runner
 
 
-def create_eval_tab(base_model: Component, model_list: Component, checkpoints: Component, runner: Runner) -> None:
+def create_eval_tab(base_model: Component, model_path: Component, checkpoints: Component, runner: Runner) -> None:
     with gr.Row():
         dataset = gr.Dropdown(
             label="Dataset", info="The name of dataset(s).", choices=list_datasets(), multiselect=True, interactive=True
         )
@@ -18,7 +18,7 @@ def create_eval_tab(base_model: Component, model_list: Component, checkpoints: C
         per_device_eval_batch_size = gr.Slider(
             label="Batch size", value=8, minimum=1, maximum=128, step=1, info="Eval batch size.", interactive=True
         )
-        use_v2 = gr.Checkbox(label="use ChatGLM2", value=True)
+        quantization_bit = gr.Dropdown([8, 4], label="Quantization bit", info="Only 4-bit and 8-bit quantization are supported")
 
     with gr.Row():
         start = gr.Button("Start evaluation")
@@ -28,7 +28,7 @@ def create_eval_tab(base_model: Component, model_list: Component, checkpoints: C
 
     start.click(
         runner.run_eval,
-        [base_model, model_list, checkpoints, dataset, max_samples, per_device_eval_batch_size, use_v2],
+        [base_model, model_path, checkpoints, dataset, max_samples, per_device_eval_batch_size, quantization_bit],
         [output]
     )
     stop.click(runner.set_abort, queue=False)
diff --git a/src/glmtuner/webui/components/infer.py b/src/glmtuner/webui/components/infer.py
index b5f0740..47bfb9f 100644
--- a/src/glmtuner/webui/components/infer.py
+++ b/src/glmtuner/webui/components/infer.py
@@ -1,5 +1,6 @@
-import gradio as gr
 from typing import Tuple
+
+import gradio as gr
 from gradio.components import Component
 
 from glmtuner.webui.chat import WebChatModel
@@ -20,13 +21,15 @@ def create_chat_box(chat_model: WebChatModel) -> Tuple[Component, Component, Com
             with gr.Column(scale=1):
                 clear = gr.Button("Clear History")
                 max_length = gr.Slider(
-                    10, 2048, value=chat_model.generating_args.max_length, step=1.0, label="Maximum length", interactive=True
+                    10, 2048, value=chat_model.generating_args.max_length, step=1.0, label="Maximum length",
+                    interactive=True
                 )
                 top_p = gr.Slider(
                     0, 1, value=chat_model.generating_args.top_p, step=0.01, label="Top P", interactive=True
                 )
                 temperature = gr.Slider(
-                    0, 1.5, value=chat_model.generating_args.temperature, step=0.01, label="Temperature", interactive=True
+                    0, 1.5, value=chat_model.generating_args.temperature, step=0.01, label="Temperature",
+                    interactive=True
                 )
 
     history = gr.State([])
@@ -45,7 +48,7 @@ def create_chat_box(chat_model: WebChatModel) -> Tuple[Component, Component, Com
     return chat_box, chatbot, history
 
 
-def create_infer_tab(base_model: Component, model_list: Component, checkpoints: Component) -> None:
+def create_infer_tab(base_model: Component, model_path: Component, checkpoints: Component) -> None:
     info_box = gr.Markdown(value="Model unloaded, please load a model first.")
 
     chat_model = WebChatModel()
@@ -54,10 +57,10 @@ def create_infer_tab(base_model: Component, model_list: Component, checkpoints:
     with gr.Row():
         load_btn = gr.Button("Load model")
         unload_btn = gr.Button("Unload model")
-        use_v2 = gr.Checkbox(label="use ChatGLM2", value=True)
+        quantization_bit = gr.Dropdown([8, 4], label="Quantization bit", info="Only 4-bit and 8-bit quantization are supported")
 
     load_btn.click(
-        chat_model.load_model, [base_model, model_list, checkpoints, use_v2], [info_box]
+        chat_model.load_model, [base_model, model_path, checkpoints, quantization_bit], [info_box]
     ).then(
         lambda: gr.update(visible=(chat_model.model is not None)), outputs=[chat_box]
     )
diff --git a/src/glmtuner/webui/components/model.py b/src/glmtuner/webui/components/model.py
index ee808e5..bcc9654 100644
--- a/src/glmtuner/webui/components/model.py
+++ b/src/glmtuner/webui/components/model.py
@@ -1,57 +1,30 @@
-import gradio as gr
 from typing import Tuple
 
-from gradio.components import Component
-
-from glmtuner.webui.common import add_model, del_model, list_models, list_checkpoints
-
-
-def create_model_manager(base_model: Component, model_list: Component) -> Component:
-    with gr.Box(visible=False, elem_classes="modal-box") as model_manager:
-        model_name = gr.Textbox(lines=1, label="Model name")
-        model_path = gr.Textbox(lines=1, label="Model path", info="The absolute path to your model.")
-
-        with gr.Row():
-            confirm = gr.Button("Save")
-            cancel = gr.Button("Cancel")
-
-        confirm.click(
-            add_model, [model_list, model_name, model_path], [model_list, model_name, model_path]
-        ).then(
-            lambda: gr.update(visible=False), outputs=[model_manager]
-        ).then(
-            list_models, [model_list], [base_model]
-        )
-
-        cancel.click(lambda: gr.update(visible=False), outputs=[model_manager])
+import gradio as gr
+from gradio.components import Component
 
-    return model_manager
+from glmtuner.extras.constants import SUPPORTED_MODEL_LIST
+from glmtuner.webui.common import list_checkpoints, load_temp_use_config, save_model_config
 
 
 def create_model_tab() -> Tuple[Component, Component, Component]:
-
-    model_list = gr.State([])  # gr.State does not accept a dict
+    user_config = load_temp_use_config()
+    gr_state = gr.State([])  # gr.State does not accept a dict
 
     with gr.Row():
-        base_model = gr.Dropdown(label="Model", interactive=True, scale=4)
-        add_btn = gr.Button("Add model", scale=1)
-        del_btn = gr.Button("Delete model", scale=1)
+        model_name = gr.Dropdown([model["pretrained_model_name"] for model in SUPPORTED_MODEL_LIST] + ["custom"],
+                                 label="Base Model", info="Model version of ChatGLM",
+                                 value=user_config.get("model_name"))
+        model_path = gr.Textbox(lines=1, label="Local model path (optional)",
+                                info="Absolute path to the directory that contains the local model files",
+                                value=user_config.get("model_path"))
 
     with gr.Row():
         checkpoints = gr.Dropdown(label="Checkpoints", multiselect=True, interactive=True, scale=5)
         refresh = gr.Button("Refresh checkpoints", scale=1)
 
-    model_manager = create_model_manager(base_model, model_list)
-
-    base_model.change(list_checkpoints, [base_model], [checkpoints])
-
-    add_btn.click(lambda: gr.update(visible=True), outputs=[model_manager]).then(
-        list_models, [model_list], [base_model]
-    )
-
-    del_btn.click(del_model, [model_list, base_model], [model_list]).then(
-        list_models, [model_list], [base_model]
-    )
-
-    refresh.click(list_checkpoints, [base_model], [checkpoints])
+    model_name.change(list_checkpoints, [model_name], [checkpoints])
+    model_path.change(save_model_config, [model_name, model_path])
+    refresh.click(list_checkpoints, [model_name], [checkpoints])
 
-    return base_model, model_list, checkpoints
+    return model_name, model_path, checkpoints
diff --git a/src/glmtuner/webui/components/sft.py b/src/glmtuner/webui/components/sft.py
index 1bc1a09..79b6362 100644
--- a/src/glmtuner/webui/components/sft.py
+++ b/src/glmtuner/webui/components/sft.py
@@ -2,13 +2,13 @@
 from gradio.components import Component
 from transformers.trainer_utils import SchedulerType
 
-from glmtuner.webui.components.data import create_preview_box
 from glmtuner.webui.common import list_datasets
+from glmtuner.webui.components.data import create_preview_box
 from glmtuner.webui.runner import Runner
 from glmtuner.webui.utils import can_preview, get_preview, get_time, gen_plot
 
 
-def create_sft_tab(base_model: Component, model_list: Component, checkpoints: Component, runner: Runner) -> None:
+def create_sft_tab(base_model: Component, model_path: Component, checkpoints: Component, runner: Runner) -> None:
     with gr.Row():
         finetuning_type = gr.Dropdown(
             label="Finetuning method", value="lora", choices=["full", "freeze", "p_tuning", "lora"], interactive=True
@@ -37,14 +37,16 @@ def create_sft_tab(base_model: Component, model_list: Component, checkpoints: Co
         max_samples = gr.Textbox(
             label="Max samples", value="100000", info="Number of samples for training.", interactive=True
         )
-        use_v2 = gr.Checkbox(label="use ChatGLM2", value=True)
+        quantization_bit = gr.Dropdown([8, 4], label="Quantization bit", info="Only 4-bit and 8-bit quantization are supported",
+                                       interactive=True)
 
     with gr.Row():
         per_device_train_batch_size = gr.Slider(
             label="Batch size", value=4, minimum=1, maximum=128, step=1, info="Train batch size.", interactive=True
         )
         gradient_accumulation_steps = gr.Slider(
-            label="Gradient accumulation", value=4, minimum=1, maximum=16, step=1, info='Accumulation steps.', interactive=True
+            label="Gradient accumulation", value=4, minimum=1, maximum=16, step=1, info='Accumulation steps.',
+            interactive=True
         )
         lr_scheduler_type = gr.Dropdown(
             label="LR Scheduler", value="cosine", info="Scheduler type.",
@@ -77,9 +79,9 @@ def create_sft_tab(base_model: Component, model_list: Component, checkpoints: Co
     start.click(
         runner.run_train,
         [
-            base_model, model_list, checkpoints, output_dir, finetuning_type,
+            base_model, model_path, checkpoints, output_dir, finetuning_type,
             dataset, learning_rate, num_train_epochs, max_samples,
-            fp16, use_v2, per_device_train_batch_size, gradient_accumulation_steps,
+            fp16, quantization_bit, per_device_train_batch_size, gradient_accumulation_steps,
             lr_scheduler_type, logging_steps, save_steps
         ],
         output_info
diff --git a/src/glmtuner/webui/interface.py b/src/glmtuner/webui/interface.py
index e5b862d..c83560e 100644
--- a/src/glmtuner/webui/interface.py
+++ b/src/glmtuner/webui/interface.py
@@ -1,15 +1,14 @@
 import gradio as gr
 from transformers.utils.versions import require_version
 
-from glmtuner.webui.css import CSS
-from glmtuner.webui.runner import Runner
 from glmtuner.webui.components import (
     create_model_tab,
     create_sft_tab,
     create_eval_tab,
     create_infer_tab
 )
-
+from glmtuner.webui.css import CSS
+from glmtuner.webui.runner import Runner
 
 require_version("gradio>=3.36.0", "To fix: pip install gradio>=3.36.0")
diff --git a/src/glmtuner/webui/runner.py b/src/glmtuner/webui/runner.py
index e1c8497..4824151 100644
--- a/src/glmtuner/webui/runner.py
+++ b/src/glmtuner/webui/runner.py
@@ -1,13 +1,14 @@
-import os
-import time
 import logging
+import os
 import threading
-import transformers
+import time
 from typing import Optional, Tuple
 
-from glmtuner.extras.misc import torch_gc
+import transformers
+
 from glmtuner.extras.callbacks import LogCallback
 from glmtuner.extras.logging import LoggerHandler
+from glmtuner.extras.misc import torch_gc
 from glmtuner.tuner import get_train_args, run_sft
 from glmtuner.webui.common import get_save_dir, DATA_DIR
 from glmtuner.webui.utils import format_info, get_eval_results
@@ -23,16 +24,13 @@ def set_abort(self):
         self.aborted = True
         self.running = False
 
-    def initialize(self, base_model: str, model_list: list, dataset: list) -> Tuple[str, LoggerHandler, LogCallback]:
+    def initialize(self, base_model: str, model_path: str, dataset: list) -> Tuple[str, LoggerHandler, LogCallback]:
         if self.running:
             return "A process is in running, please abort it firstly.", None, None
 
         if not base_model:
             return "Please select a model.", None, None
 
-        if len(model_list) == 0:
-            return "No model detected.", None, None
-
         if len(dataset) == 0:
             return "Please choose datasets.", None, None
@@ -56,24 +54,28 @@ def finalize(self, finish_info: Optional[str] = None) -> str:
         return finish_info if finish_info is not None else "Finished"
 
     def run_train(
-        self, base_model, model_list, checkpoints, output_dir, finetuning_type,
-        dataset, learning_rate, num_train_epochs, max_samples,
-        fp16, use_v2, per_device_train_batch_size, gradient_accumulation_steps,
-        lr_scheduler_type, logging_steps, save_steps
+            self, base_model, model_path, checkpoints, output_dir, finetuning_type,
+            dataset, learning_rate, num_train_epochs, max_samples,
+            fp16, quantization_bit, per_device_train_batch_size, gradient_accumulation_steps,
+            lr_scheduler_type, logging_steps, save_steps
     ):
-        error, logger_handler, trainer_callback = self.initialize(base_model, model_list, dataset)
+        error, logger_handler, trainer_callback = self.initialize(base_model, model_path, dataset)
         if error:
             yield error
             return
 
-        model_path = [path for name, path in model_list if name == base_model]
         if get_save_dir(base_model) and checkpoints:
-            checkpoint_dir = ",".join([os.path.join(get_save_dir(base_model), checkpoint) for checkpoint in checkpoints])
+            checkpoint_dir = ",".join(
+                [os.path.join(get_save_dir(base_model), checkpoint) for checkpoint in checkpoints])
         else:
             checkpoint_dir = None
 
+        if model_path:
+            model_name_or_path = model_path
+        else:
+            model_name_or_path = base_model
         args = dict(
-            model_name_or_path=model_path[0],
+            model_name_or_path=model_name_or_path,
             do_train=True,
             finetuning_type=finetuning_type,
             dataset=",".join(dataset),
@@ -90,7 +92,7 @@ def run_train(
             learning_rate=float(learning_rate),
             num_train_epochs=float(num_train_epochs),
             fp16=fp16,
-            use_v2=use_v2
+            quantization_bit=quantization_bit
         )
 
         model_args, data_args, training_args, finetuning_args, _ = get_train_args(args)
@@ -114,23 +116,28 @@ def run_train(
             yield self.finalize()
 
     def run_eval(
-        self, base_model, model_list, checkpoints, dataset, max_samples, per_device_eval_batch_size, use_v2
+            self, base_model, model_path, checkpoints, dataset, max_samples, per_device_eval_batch_size,
+            quantization_bit
     ):
-        error, logger_handler, trainer_callback = self.initialize(base_model, model_list, dataset)
+        error, logger_handler, trainer_callback = self.initialize(base_model, model_path, dataset)
         if error:
             yield error
             return
 
-        model_path = [path for name, path in model_list if name == base_model]
         if get_save_dir(base_model) and checkpoints:
-            checkpoint_dir = ",".join([os.path.join(get_save_dir(base_model), checkpoint) for checkpoint in checkpoints])
+            checkpoint_dir = ",".join(
+                [os.path.join(get_save_dir(base_model), checkpoint) for checkpoint in checkpoints])
             output_dir = os.path.join(get_save_dir(base_model), "eval_" + "_".join(checkpoints))
         else:
             checkpoint_dir = None
             output_dir = os.path.join(get_save_dir(base_model), "eval_base")
 
+        if model_path:
+            model_name_or_path = model_path
+        else:
+            model_name_or_path = base_model
         args = dict(
-            model_name_or_path=model_path[0],
+            model_name_or_path=model_name_or_path,
             do_eval=True,
             dataset=",".join(dataset),
             dataset_dir=DATA_DIR,
@@ -140,7 +147,7 @@ def run_eval(
             overwrite_cache=True,
             predict_with_generate=True,
             per_device_eval_batch_size=per_device_eval_batch_size,
-            use_v2=use_v2
+            quantization_bit=quantization_bit
         )
 
         model_args, data_args, training_args, finetuning_args, _ = get_train_args(args)
diff --git a/src/glmtuner/webui/utils.py b/src/glmtuner/webui/utils.py
index 206a17c..5621b84 100644
--- a/src/glmtuner/webui/utils.py
+++ b/src/glmtuner/webui/utils.py
@@ -1,10 +1,11 @@
-import os
 import json
+import os
+from datetime import datetime
+from typing import Tuple
+
 import gradio as gr
 import matplotlib.figure
 import matplotlib.pyplot as plt
-from typing import Tuple
-from datetime import datetime
 
 from glmtuner.extras.ploting import smooth
 from glmtuner.webui.common import get_save_dir, DATA_DIR
@@ -27,9 +28,9 @@ def can_preview(dataset: list) -> dict:
     with open(os.path.join(DATA_DIR, "dataset_info.json"), "r", encoding="utf-8") as f:
         dataset_info = json.load(f)
     if (
-        len(dataset) > 0
-        and "file_name" in dataset_info[dataset[0]]
-        and os.path.isfile(os.path.join(DATA_DIR, dataset_info[dataset[0]]["file_name"]))
+            len(dataset) > 0
+            and "file_name" in dataset_info[dataset[0]]
+            and os.path.isfile(os.path.join(DATA_DIR, dataset_info[dataset[0]]["file_name"]))
     ):
         return gr.update(visible=True)
     else:
@@ -63,8 +64,9 @@ def gen_plot(base_model: str, output_dir: str) -> matplotlib.figure.Figure:
     with open(log_file, "r", encoding="utf-8") as f:
         for line in f:
             log_info = json.loads(line)
-            steps.append(log_info["current_steps"])
-            losses.append(log_info["loss"])
+            if log_info.get("loss"):
+                steps.append(log_info["current_steps"])
+                losses.append(log_info["loss"])
     ax.plot(steps, losses, alpha=0.4, label="original")
     ax.plot(steps, smooth(losses), label="smoothed")
     ax.legend()
diff --git a/src/train_web.py b/src/train_web.py
index 92a1992..d048fb9 100644
--- a/src/train_web.py
+++ b/src/train_web.py
@@ -4,7 +4,7 @@
 def main():
     demo = create_ui()
     demo.queue()
-    demo.launch(server_name="0.0.0.0", share=True, inbrowser=True)
+    demo.launch(server_name="0.0.0.0", share=False, inbrowser=True)
 
 
 if __name__ == "__main__":