[Feature] Support ChatGLM (#62)
* no space for ChatGLM

* cast to inference during generate

* support chatglm tokenizer

* support chatglm qlora

* fix pre-commit

* update doc

* update

* Update README.md

* Update README_zh-CN.md

* add chatglm template

* fix chat bugs

* add round for template_map_fn

* add round for EvaluateChatHook
LZHgrla authored Aug 29, 2023
1 parent 9f03ddf commit 3bdeacb
Showing 10 changed files with 43 additions and 17 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -21,7 +21,7 @@ English | [简体中文](README_zh-CN.md)
XTuner is a toolkit for efficiently fine-tuning LLM, developed by the [MMRazor](https://github.com/open-mmlab/mmrazor) and [MMDeploy](https://github.com/open-mmlab/mmdeploy) teams.

- **Efficiency**: Support LLM fine-tuning on consumer-grade GPUs. The minimum GPU memory required for 7B LLM fine-tuning is only **8GB**, indicating that users can use nearly any GPU (even the free resource, *e.g.*, Colab) to fine-tune custom LLMs.
- - **Versatile**: Support various **LLMs** ([InternLM](https://github.com/InternLM/InternLM), [Llama2](https://github.com/facebookresearch/llama), [Qwen](https://github.com/QwenLM/Qwen-7B), [Baichuan](https://github.com/baichuan-inc), ...), **datasets** ([MOSS_003_SFT](https://huggingface.co/datasets/fnlp/moss-003-sft-data), [Colorist](https://huggingface.co/datasets/burkelibbey/colors), [Code Alpaca](https://huggingface.co/datasets/HuggingFaceH4/CodeAlpaca_20K), [Arxiv GenTitle](https://github.com/WangRongsheng/ChatGenTitle), [Chinese Law](https://github.com/LiuHC0428/LAW-GPT), [OpenOrca](https://huggingface.co/datasets/Open-Orca/OpenOrca), [Open-Platypus](https://huggingface.co/datasets/garage-bAInd/Open-Platypus), ...) and **algorithms** ([QLoRA](http://arxiv.org/abs/2305.14314), [LoRA](http://arxiv.org/abs/2106.09685)), allowing users to choose the most suitable solution for their requirements.
+ - **Versatile**: Support various **LLMs** ([InternLM](https://github.com/InternLM/InternLM), [Llama2](https://github.com/facebookresearch/llama), [ChatGLM2](https://huggingface.co/THUDM/chatglm2-6b), [Qwen](https://github.com/QwenLM/Qwen-7B), [Baichuan](https://github.com/baichuan-inc), ...), **datasets** ([MOSS_003_SFT](https://huggingface.co/datasets/fnlp/moss-003-sft-data), [Colorist](https://huggingface.co/datasets/burkelibbey/colors), [Code Alpaca](https://huggingface.co/datasets/HuggingFaceH4/CodeAlpaca_20K), [Arxiv GenTitle](https://github.com/WangRongsheng/ChatGenTitle), [Chinese Law](https://github.com/LiuHC0428/LAW-GPT), [OpenOrca](https://huggingface.co/datasets/Open-Orca/OpenOrca), [Open-Platypus](https://huggingface.co/datasets/garage-bAInd/Open-Platypus), ...) and **algorithms** ([QLoRA](http://arxiv.org/abs/2305.14314), [LoRA](http://arxiv.org/abs/2106.09685)), allowing users to choose the most suitable solution for their requirements.
- **Compatibility**: Compatible with [DeepSpeed](https://github.com/microsoft/DeepSpeed) 🚀 and [HuggingFace](https://huggingface.co) 🤗 training pipeline, enabling effortless integration and utilization.

## 🌟 Demos
@@ -56,6 +56,7 @@ XTuner is a toolkit for efficiently fine-tuning LLM, developed by the [MMRazor](
<li><a href="https://github.com/facebookresearch/llama">Llama</a></li>
<li><a href="https://github.com/facebookresearch/llama">Llama2</a></li>
<li><a href="https://github.com/facebookresearch/llama">Llama2-Chat</a></li>
<li><a href="https://huggingface.co/THUDM/chatglm2-6b">ChatGLM2</a></li>
<li><a href="https://github.com/QwenLM/Qwen-7B">Qwen</a></li>
<li><a href="https://github.com/QwenLM/Qwen-7B">Qwen-Chat</a></li>
<li><a href="https://github.com/baichuan-inc/Baichuan-7B">Baichuan-7B</a></li>
3 changes: 2 additions & 1 deletion README_zh-CN.md
@@ -21,7 +21,7 @@
XTuner is a lightweight toolkit for fine-tuning large language models, jointly developed by the [MMRazor](https://github.com/open-mmlab/mmrazor) and [MMDeploy](https://github.com/open-mmlab/mmdeploy) teams.

- **Lightweight**: Supports fine-tuning large language models on consumer-grade GPUs. For a 7B model, the minimum GPU memory required is only **8GB**, so users can fine-tune a custom LLM assistant on nearly any GPU (even free resources such as Colab).
- - **Versatile**: Supports a variety of **LLMs** ([InternLM](https://github.com/InternLM/InternLM), [Llama2](https://github.com/facebookresearch/llama), [Qwen](https://github.com/QwenLM/Qwen-7B), [Baichuan](https://github.com/baichuan-inc), ...), **datasets** ([MOSS_003_SFT](https://huggingface.co/datasets/fnlp/moss-003-sft-data), [Colorist](https://huggingface.co/datasets/burkelibbey/colors), [Code Alpaca](https://huggingface.co/datasets/HuggingFaceH4/CodeAlpaca_20K), [Arxiv GenTitle](https://github.com/WangRongsheng/ChatGenTitle), [Chinese Law](https://github.com/LiuHC0428/LAW-GPT), [OpenOrca](https://huggingface.co/datasets/Open-Orca/OpenOrca), [Open-Platypus](https://huggingface.co/datasets/garage-bAInd/Open-Platypus), ...) and **fine-tuning algorithms** ([QLoRA](http://arxiv.org/abs/2305.14314), [LoRA](http://arxiv.org/abs/2106.09685)), allowing users to pick the solution that best fits their needs.
+ - **Versatile**: Supports a variety of **LLMs** ([InternLM](https://github.com/InternLM/InternLM), [Llama2](https://github.com/facebookresearch/llama), [ChatGLM2](https://huggingface.co/THUDM/chatglm2-6b), [Qwen](https://github.com/QwenLM/Qwen-7B), [Baichuan](https://github.com/baichuan-inc), ...), **datasets** ([MOSS_003_SFT](https://huggingface.co/datasets/fnlp/moss-003-sft-data), [Colorist](https://huggingface.co/datasets/burkelibbey/colors), [Code Alpaca](https://huggingface.co/datasets/HuggingFaceH4/CodeAlpaca_20K), [Arxiv GenTitle](https://github.com/WangRongsheng/ChatGenTitle), [Chinese Law](https://github.com/LiuHC0428/LAW-GPT), [OpenOrca](https://huggingface.co/datasets/Open-Orca/OpenOrca), [Open-Platypus](https://huggingface.co/datasets/garage-bAInd/Open-Platypus), ...) and **fine-tuning algorithms** ([QLoRA](http://arxiv.org/abs/2305.14314), [LoRA](http://arxiv.org/abs/2106.09685)), allowing users to pick the solution that best fits their needs.
- **Compatibility**: Compatible with the [DeepSpeed](https://github.com/microsoft/DeepSpeed) 🚀 and [HuggingFace](https://huggingface.co) 🤗 training pipelines, enabling seamless integration and use.

## 🌟 Demos
@@ -56,6 +56,7 @@ XTuner is a lightweight toolkit for fine-tuning large language models, jointly developed by [MMRazor](https
<li><a href="https://github.com/facebookresearch/llama">Llama</a></li>
<li><a href="https://github.com/facebookresearch/llama">Llama2</a></li>
<li><a href="https://github.com/facebookresearch/llama">Llama2-Chat</a></li>
<li><a href="https://huggingface.co/THUDM/chatglm2-6b">ChatGLM2</a></li>
<li><a href="https://github.com/QwenLM/Qwen-7B">Qwen</a></li>
<li><a href="https://github.com/QwenLM/Qwen-7B">Qwen-Chat</a></li>
<li><a href="https://github.com/baichuan-inc/Baichuan-7B">Baichuan-7B</a></li>
6 changes: 3 additions & 3 deletions xtuner/dataset/map_fns/template_map_fn.py
@@ -8,11 +8,11 @@ def template_map_fn(example, template):
         input = single_turn_conversation['input']
         if i == 0:
             single_turn_conversation[
-                'input'] = template.INSTRUCTION_START.format(input=input)
+                'input'] = template.INSTRUCTION_START.format(
+                    input=input, round=i + 1)
         else:
             single_turn_conversation['input'] = template.INSTRUCTION.format(
-                input=input)
-
+                input=input, round=i + 1)
     return {'conversation': conversation}
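
For reference, a minimal sketch of what the round-aware map function now does over a multi-turn conversation. The template strings below are hypothetical stand-ins; the real ones live in xtuner/utils/templates.py (see the new chatglm entry later in this commit). Since str.format ignores unused keyword arguments, templates without a {round} placeholder are unaffected by the extra argument.

    # Hypothetical template: only the {round} placeholder matters here.
    template = dict(
        INSTRUCTION_START='Round {round}: {input}\n',
        INSTRUCTION='Round {round}: {input}\n')

    conversation = [{'input': 'hello'}, {'input': 'and then?'}]
    for i, turn in enumerate(conversation):
        key = 'INSTRUCTION_START' if i == 0 else 'INSTRUCTION'
        turn['input'] = template[key].format(input=turn['input'], round=i + 1)

    print([t['input'] for t in conversation])
    # ['Round 1: hello\n', 'Round 2: and then?\n']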


3 changes: 3 additions & 0 deletions xtuner/dataset/utils.py
@@ -39,6 +39,9 @@ def encode_fn(example, tokenizer, max_length, input_ids_with_output=True):
     if tokenizer.__class__.__name__ == 'QWenTokenizer':
         bos_token = ''
         eos_token = '<|endoftext|>'
+    elif tokenizer.__class__.__name__ == 'ChatGLMTokenizer':
+        bos_token = ''
+        eos_token = tokenizer.eos_token
     else:
         bos_token = tokenizer.bos_token
         eos_token = tokenizer.eos_token
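
In isolation, the dispatch above behaves like the sketch below. The rationale in the comments is an assumption (ChatGLM2's tokenizer manages its own prefix tokens, so no bos string is prepended), and the concrete '</s>' value is illustrative; the real code relies only on tokenizer.eos_token.

    # Sketch of the special-token selection in encode_fn.
    def pick_special_tokens(cls_name, bos='<s>', eos='</s>'):
        if cls_name == 'QWenTokenizer':
            return '', '<|endoftext|>'  # QWen exposes no usable bos token
        if cls_name == 'ChatGLMTokenizer':
            return '', eos  # no bos: ChatGLM handles its own prefix tokens
        return bos, eos

    assert pick_special_tokens('ChatGLMTokenizer') == ('', '</s>')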
15 changes: 14 additions & 1 deletion xtuner/engine/hooks/evaluate_chat_hook.py
@@ -37,8 +37,16 @@ def _generate_samples(self, runner, max_new_tokens=None):

         device = next(iter(model.parameters())).device

+        is_checkpointing = model.llm.is_gradient_checkpointing
+        use_cache = model.llm.config.use_cache
+
+        # Cast to inference mode
+        model.llm.gradient_checkpointing_disable()
+        model.llm.config.use_cache = True
+
         for sample_input in self.sample_inputs:
-            inputs = self.instruction.format(input=sample_input, **runner.cfg)
+            inputs = self.instruction.format(
+                input=sample_input, round=1, **runner.cfg)
             input_ids = self.tokenizer(
                 inputs, return_tensors='pt')['input_ids']
             input_ids = input_ids.to(device)
@@ -50,6 +58,11 @@
                 f'Sample output:\n'
                 f'{self.tokenizer.decode(generation_output[0])}\n')

+        # Cast to training mode
+        if is_checkpointing:
+            model.llm.gradient_checkpointing_enable()
+        model.llm.config.use_cache = use_cache
+
     def before_train(self, runner):
         runner.logger.info('before_train in EvaluateChatHook .')
         self._generate_samples(runner, max_new_tokens=50)
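
The save/restore pattern the hook now follows, shown in isolation as a minimal sketch. It assumes a HuggingFace-style PreTrainedModel (which exposes is_gradient_checkpointing, gradient_checkpointing_disable/enable, and config.use_cache); generate_fn is a hypothetical stand-in for the sampling loop above. Gradient checkpointing and the KV cache are incompatible, so generation turns the former off and the latter on, then restores both.

    def run_in_inference_mode(llm, generate_fn):
        # Remember the training-time settings.
        was_checkpointing = llm.is_gradient_checkpointing
        old_use_cache = llm.config.use_cache
        # Cast to inference mode: KV cache on, checkpointing off.
        llm.gradient_checkpointing_disable()
        llm.config.use_cache = True
        try:
            return generate_fn()
        finally:
            # Cast back to training mode.
            if was_checkpointing:
                llm.gradient_checkpointing_enable()
            llm.config.use_cache = old_use_cache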
2 changes: 2 additions & 0 deletions xtuner/model/utils.py
@@ -31,6 +31,8 @@ def find_all_linear_names(model):

     if 'lm_head' in lora_module_names:  # needed for 16-bit
         lora_module_names.remove('lm_head')
+    if 'output_layer' in lora_module_names:  # needed for 16-bit
+        lora_module_names.remove('output_layer')
     return list(lora_module_names)
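
For context, a sketch of how this helper typically looks in QLoRA codebases: every 4-bit linear module becomes a LoRA target except the output projection, which stays in 16-bit. Only the two output_layer lines are confirmed by this diff; the rest follows the usual bitsandbytes convention and is an assumption.

    import bitsandbytes as bnb
    import torch.nn as nn

    def find_all_linear_names(model: nn.Module):
        lora_module_names = set()
        for name, module in model.named_modules():
            if isinstance(module, bnb.nn.Linear4bit):
                # Keep only the leaf name, e.g. 'query_key_value'.
                lora_module_names.add(name.split('.')[-1])
        lora_module_names.discard('lm_head')       # Llama-style output head
        lora_module_names.discard('output_layer')  # ChatGLM's output head
        return list(lora_module_names)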


14 changes: 9 additions & 5 deletions xtuner/tools/chat.py
@@ -132,6 +132,10 @@ def main():
     cfg.merge_from_dict(args.cfg_options)

     model = BUILDER.build(cfg.model)
+    # Cast to inference mode
+    model.llm.gradient_checkpointing_disable()
+    model.llm.config.use_cache = True
+
     tokenizer = BUILDER.build(cfg.tokenizer)

     if args.adapter is not None:
@@ -169,9 +173,10 @@ def main():
            template = PROMPT_TEMPLATE[args.prompt_template]
            if 'INSTRUCTION_START' in template and n_turn == 0:
                prompt_text = template['INSTRUCTION_START'].format(
-                    input=text, **cfg)
+                    input=text, round=n_turn + 1, **cfg)
            else:
-                prompt_text = template['INSTRUCTION'].format(input=text, **cfg)
+                prompt_text = template['INSTRUCTION'].format(
+                    input=text, round=n_turn + 1, **cfg)
            if args.prompt_template == 'moss_sft':
                if not inner_thoughts_open:
                    prompt_text.replace('- Inner thoughts: enabled.',
@@ -192,8 +197,7 @@ def main():
                inputs += prompt_text
            else:
                inputs += text
-            ids = tokenizer.encode(
-                inputs, return_tensors='pt', add_special_tokens=n_turn == 0)
+            ids = tokenizer.encode(inputs, return_tensors='pt')
            streamer = Streamer(tokenizer) if Streamer is not None else None
            if args.with_plugins is not None:
                generate_output = model.generate(
@@ -241,7 +245,7 @@ def main():
                generate_output[0][len(ids[0]):])
            end = '' if output_text[-1] == '\n' else '\n'
            print(output_text, end=end)
-            inputs = tokenizer.decode(generate_output[0]) + '\n'
+            inputs = tokenizer.decode(generate_output[0])
            n_turn += 1
            if len(generate_output[0]) >= args.max_new_tokens:
                print('Remove the memory of history responses, since '
9 changes: 4 additions & 5 deletions xtuner/tools/chat_hf.py
@@ -150,10 +150,10 @@ def main():
            template = PROMPT_TEMPLATE[args.prompt_template]
            if 'INSTRUCTION_START' in template and n_turn == 0:
                prompt_text = template['INSTRUCTION_START'].format(
-                    input=text, bot_name=args.bot_name)
+                    input=text, round=n_turn + 1, bot_name=args.bot_name)
            else:
                prompt_text = template['INSTRUCTION'].format(
-                    input=text, bot_name=args.bot_name)
+                    input=text, round=n_turn + 1, bot_name=args.bot_name)
            if args.prompt_template == 'moss_sft':
                if not inner_thoughts_open:
                    prompt_text.replace('- Inner thoughts: enabled.',
@@ -174,8 +174,7 @@ def main():
                inputs += prompt_text
            else:
                inputs += text
-            ids = tokenizer.encode(
-                inputs, return_tensors='pt', add_special_tokens=n_turn == 0)
+            ids = tokenizer.encode(inputs, return_tensors='pt')
            streamer = Streamer(tokenizer) if Streamer is not None else None
            if args.with_plugins is not None:
                generate_output = model.generate(
@@ -223,7 +222,7 @@ def main():
                generate_output[0][len(ids[0]):])
            end = '' if output_text[-1] == '\n' else '\n'
            print(output_text, end=end)
-            inputs = tokenizer.decode(generate_output[0]) + '\n'
+            inputs = tokenizer.decode(generate_output[0])
            n_turn += 1
            if len(generate_output[0]) >= args.max_new_tokens:
                print('Remove the memory of history responses, since '
2 changes: 1 addition & 1 deletion xtuner/tools/utils.py
@@ -26,7 +26,7 @@ def get_chat_utils(model):
    is_internlm = 'InternLM' in base_model_name
    is_qwen = 'QWen' in base_model_name
    no_space = 'InternLM' in base_model_name or 'QWen' in base_model_name or \
-        'BaiChuan' in base_model_name
+        'BaiChuan' in base_model_name or 'ChatGLM' in base_model_name
    stop_criteria = StoppingCriteriaList()
    if is_internlm:
        stop_criteria.append(InternLMStoppingCriteria())
3 changes: 3 additions & 0 deletions xtuner/utils/templates.py
@@ -67,4 +67,7 @@
            'TABLE statement.\n'
            '### Question: {input}\n### Query: '),
        INSTRUCTION=('### Question: {input}\n### Query: ')),
+    chatglm=dict(
+        INSTRUCTION_START='[Round {round}]\n\n问:{input}\n\n答:',
+        INSTRUCTION='\n\n[Round {round}]\n\n问:{input}\n\n答:'),
)
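
Rendered across turns, the new entry reproduces ChatGLM's native round-numbered prompt. A quick sketch (the user inputs are made up; 'answer' stands in for generated text):

    chatglm = dict(
        INSTRUCTION_START='[Round {round}]\n\n问:{input}\n\n答:',
        INSTRUCTION='\n\n[Round {round}]\n\n问:{input}\n\n答:')

    prompt = chatglm['INSTRUCTION_START'].format(round=1, input='你好')
    prompt += 'answer'  # the model's reply is appended between rounds
    prompt += chatglm['INSTRUCTION'].format(round=2, input='继续')
    # '[Round 1]\n\n问:你好\n\n答:answer\n\n[Round 2]\n\n问:继续\n\n答:'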
