This repository contains fine-tuning scripts for both supervised fine-tuning (SFT) and alignment scripts. Our goal is to create a model-agnostic fine-tuning pipeline and evaluation scripts focusing on the usability of the Thai language. The repository consists of three training scripts: (i) supervised fine-tuning (SFT), (ii) direct preference optimization (DPO), and (iii) odds ratio preference optimization (ORPO).
- Supported base LLMs
- Released Models
- Evaluation
- Installation
- Prepare Dataset (Optional)
- Fine-tuning
- Inference
- Deployment
- Retrieval Augmented Generation (RAG)
- Acknowledgements
- Future Plans
- Citation
Here is the list of supported base LLMs that we have tested on our scripts.
- LLaMa3
- SeaLLMs
- PolyLM
- Typhoon
- SEA-LION (Please refer to GitHub: vistec-AI/WangchanLion for the full detail)
- Gemma 2
We apply our fine-tuning pipeline to various open-source models and publish their weights as follows:
The models that trained on small instruction datasets
The models that trained on large instruction datasets. For reproducibility, we provide the scripts for dataset collection and preprocessing in this repository.
We evaluate LLMs using the Benchmark Suite for Southeast Asian Languages. For detailed information on our evaluation methodology and benchmarking process, visit the SEACrowd project repository.
- Please install all dependencies in
requirements.txt
using pip install as
pip3 install -r requirements.txt
- Please install Flash Attention 2 using pip install as
pip3 install flash-attn --no-build-isolation
- Go to the
Fine-tuning
section and select the training strategy that is suitable for your constraints.
- If you want to use a custom dataset, you need to reformat the file by editing it.
python3 reformat.py
- If you want to use the demo dataset, you can download it from this.
This dataset includes 6 datasets:
- pythainlp/han-instruct-dataset-v2.0
- databricks/databricks-dolly-15k
- databricks/databricks-dolly-15k (translated English to Thai by Gemini)
- math_14k
- math_14k (translated English to Thai by Gemini)
- iapp_wiki_qa_squad
-
Creating the Dataset:
-
Go to the create dataset script page.
-
Download the script provided there.
-
Run the following command in your terminal:
python main.py --output_dir /<path>/flan_dataset
This will create the full dataset in a directory called
flan_dataset
.
-
-
Updating the Configuration:
-
Find the configuration file for your specific model and training mode.
-
The file will be located at:
recipes/<model_name>/<mode>/config_<method>.yaml
-
For example, if you're using the LLaMA3-8b model for supervised fine-tuning (sft), the file would be:
recipe/llama3-8b/sft/config_full.yaml
-
Open this file and update the
dataset_mixer
section to point to your newly created dataset:# Data training arguments chat_template: "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}" dataset_mixer: /<path>/flan_dataset: 1.0 # <- This is the path to your newly created dataset dataset_splits: - train preprocessing_num_workers: 12
The key change is in the
dataset_mixer
section, where/<path>/flan_dataset
should be the path to your created dataset. -
By following these steps, you'll have prepared the full dataset and updated your configuration file to use it for training your model.
To start fine-tuning your own LLM, we recommend using QLoRa fine-tuning because it consumes much fewer resources compared to fully fine-tuning the LLM. Please note that the provided examples are all LLaMa3. The main template for the script is structured as
{RUNNER} scripts/run_{MODE}.py {RECIPE}
The main parameters are
Parameter | Description |
---|---|
RUNNER |
Can be python for single-GPU fine-tuning or accelerate with the argument --config_file {ACCELERATION_CONFIG} for multi-GPU training. |
ACCELERATION_CONFIG |
The mode to launch the trainer in multiple setups. Mainly, there are vanilla multi-GPU and ZeRO3 offloading for lower GPU memory usage with IO overhead. Available configurations are in recipes/accelerate_configs . |
MODE |
Can be sft (supervised fine-tuning) or dpo (direct preference optimization). |
RECIPE |
Based on the model types in the recipes folder. |
QLoRa fine-tuning example
The simplest way to start fine-tuning your LLM is to use plain Python on a single GPU. You can do the supervised fine-tuning (SFT) and direct preference optimization (DPO) as in the following step.
# Step 1 - SFT
python scripts/run_sft.py recipes/llama3-8b/sft/config_qlora.yaml
# Step 2 - DPO (optional)
python scripts/run_dpo.py recipes/llama3-8b/dpo/config_qlora.yaml
# Step 1 - SFT
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes=4 scripts/run_sft.py recipes/llama3-8b/sft/config_qlora.yaml
# Step 2 - DPO
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes=4 scripts/run_dpo.py recipes/llama3-8b/dpo/config_qlora.yaml
Full fine-tuning example
You can fine-tune the whole model using the following scripts.
# Step 1 - SFT
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/multi_gpu.yaml scripts/run_sft.py recipes/llama3-8b/sft/config_full.yaml
# Step 2 - DPO
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/multi_gpu.yaml scripts/run_dpo.py recipes/llama3-8b/dpo/config_full.yaml
config_file
argument, you are good to go!
# Step 1 - SFT
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml scripts/run_sft.py recipes/llama3-8b/sft/config_full.yaml
# Step 2 - DPO
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml scripts/run_dpo.py recipes/llama3-8b/dpo/config_full.yaml
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# Model path
path = "airesearch/LLaMa3-8b-WangchanX-sft-Full"
# Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(path, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(path, device_map="auto")
messages = [
{"role": "user", "content": "ลิเก กับ งิ้ว ต่างกันอย่างไร"},
]
tokenized_chat = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt").to(device)
print(tokenizer.decode(tokenized_chat[0]))
Output:
<|user|>
ลิเก กับ งิ้ว ต่างกันอย่างไร<|end_of_text|>
<|assistant|>
outputs = model.generate(tokenized_chat, max_length=2048)
print(tokenizer.decode(outputs[0]))
Output:
<|user|>
ลิเก กับ งิ้ว ต่างกันอย่างไร<|end_of_text|>
<|assistant|>
ก่อนอื่นเราต้องรู้ความหมายของคำทั้งสอง คำว่า ลิเก เป็นศิลปะการแสดงแบบดั้งเดิมในประเทศไทย ส่วนคำว่า งิ้วน่าจะเป็นการนำภาษาไทยมาแปลจากคำว่า อินโดปีเลีย (indoplea) ซึ่งเป็นชื่อเรียกดนตรีที่มีต้นกำเนิดจากรัฐอุตตาร์ประเทศ ในอินเดีย และได้แพร่หลายไปยังเอเชียตะวันออกเฉียงใต้ โดยเฉพาะสาธารณรัฐประชาชนจีนและเวียดนาม จึงทำให้เกิดคำว่า งิ้วด้วย แต่ทุกคนไม่รู้ว่ามันก็คืออะไรจริง ๆ แล้ว มันมีความแตกต่างกันมาก เพราะถ้าไปถามชาวบ้านบางแห่งอาจจะบอกว่าเป็นอีกประเภทหนึ่งของเพลงโบราณหรือเพลงพื้นเมือง หรือถ้าพูดตามหลักทางประวัติศาสตร์ก็จะกล่าวว่านั่นคือ การขับร้องเพลงที่ใช้รูปแบบการประสานเสียงแบบฮินดู-ซิกห์วัล ที่ผสมผสานระหว่างภาษาอังกฤษ ภาษาจีนกลาง ภาษาพม่า และภาษาทางเหนือกับภาษาลาว รวมถึงภาษากลุ่มออสเตรโลไนว์ในอดีต ดังนั้นตอนนี้คุณสามารถสรุปได้อย่างแม่นยำว่าสองอย่างเหล่านี้แตกต่างกันอย่างไร: ลิเก คือ ศิลปะการแสดงที่มีมายาวนานกว่า 100 ปีในประเทศไทย เช่น ลิเกล้านนา, ลิเกตลุง, ลิเกล้อ ฯลฯ ขณะที่ งิ้ว หมายถึง เพลงประสานเสียงที่มีรากเหง้าของวงการเพลงคลาสสิคในอินเดีย และแพร่กระจายในเอเชียตะวันตกเฉียงใต้เป็นสิ่งแรกๆ หลังจากการเผยแผ่ศาสนายุคแรกๆ นอกจากนี้ ยังมีการรวมแนวเพลงเพื่อรวมเข้ากับการเต้นร่วมสมัยและบทละครที่มีอิทธิพลจากวรรณกรรมจีน<|end_of_text|>
See Deployments.md for details on deploying pre-trained Large Language Models (LLMs) using Text Generation Inference (TGI), LocalAI, and Ollama frameworks.
See RAG.md for details on setting up a Retrieval Augmented Generation system using Flowise, LocalAI, and Ollama frameworks for enhancing language model generation with retrieved knowledge.
We would like to thank all codes and structures from alignment-handbook. This project is sponsored by VISTEC, PTT, SCBX, and SCB.
Here are some future plans and what we are doing:
- Adding model and codes for ORPO. Currently, we have codes and preliminary models from the ORPO technique. We are planning to release them soon.
- Thai LLMs benchmark. We are planning to create a machine reading comprehension leaderboard for Thai LLMs. We are happy for any ideas or contributions from everyone.
If you use WangchanX or WangchanX Eval in your project or publication, please cite the library as follows
@misc{phatthiyaphaibun2024wangchanlion,
title={WangchanLion and WangchanX MRC Eval},
author={Wannaphong Phatthiyaphaibun and Surapon Nonesung and Patomporn Payoungkhamdee and Peerat Limkonchotiwat and Can Udomcharoenchaikit and Jitkapat Sawatphol and Chompakorn Chaksangchaichot and Ekapol Chuangsuwanich and Sarana Nutanong},
year={2024},
eprint={2403.16127},
archivePrefix={arXiv},
primaryClass={cs.CL}
}