Llama2 7b model C++ example #3666

Open: ototh-htec wants to merge 57 commits into develop from htec/mgx-llama2-7b-example
+2,226 −0

This view shows changes from 45 of the 57 commits.

Commits:
4b6c099  WIP mgx llama2-7b example (gyulaz-htec)
1edf198  Code works with offload copy (gyulaz-htec)
7cd6a0c  Add support to load onnx file (ototh-htec)
595771f  Add dockerization for mgx_llama2 example (ototh-htec)
fdd16ed  Rework buffer allocation so offload_copy can be turned on/off (gyulaz-htec)
1b78a1c  Use dedicated hipStream for synchronization (gyulaz-htec)
0731e26  Save onnx model to mxr file (ototh-htec)
2ab6659  Only copy changed data (gyulaz-htec)
e7f84c2  Extend model loading options with fast_math (gyulaz-htec)
0fe7924  Fix quant message (gyulaz-htec)
984e0dc  Basic tokens/sec counting (gyulaz-htec)
4134505  Add preprocess dataset script (ototh-htec)
6735534  Use dataset from numpy files if available (ototh-htec)
f1960f7  Support npy dataset with multiple samples (ototh-htec)
55e41f4  Add missing upload to device for multiple samples (ototh-htec)
4d25c49  Add accuracy calculation for mgx_llama2 example (ototh-htec)
3dd55b5  Use MIGraphX from develop branch in Dockerfile (ototh-htec)
2416786  Fix dataset loading (gyulaz-htec)
0877f72  Add README to C++ LLama2 example (gyulaz-htec)
513cc0d  Add buffers for llama2 7b quantized models (ototh-htec)
048e587  Fix Llama2-7b model file parse and input buffers (ototh-htec)
71b7522  Fix llama 7b quantized model evaluation step, use 2 models (ototh-htec)
9220cf1  Support llama 7b quantized model without offload copy (ototh-htec)
3842b54  Connect dataset to llama 7b quantized model (ototh-htec)
d8a85cd  Fix output buffer usage for new sample (ototh-htec)
37ed15b  Comment out past/present_key_value binding (ototh-htec)
d970b30  Add new migraphx public API functions: replace_return and get_last_in… (gyulaz-htec)
a7801b0  Add rocmProfileData to the example container (gyulaz-htec)
9152cd1  Update MGX branch and env variables in the example docker (gyulaz-htec)
14df5db  Fix example readme (gyulaz-htec)
f7d3fbf  Fix typo in example preproc script (gyulaz-htec)
58bc2c2  Update LLama2 example to target specific device and enable fast_math … (gyulaz-htec)
e2856a7  Add eval_accuracy script dependencies to Dockerfile (ototh-htec)
632c026  Add argmax program to lessen DToH copy + CPU computation overhead (gyulaz-htec)
1929b44  Update LLama2 example docker file ENVs (gyulaz-htec)
19f36b4  Disable fast_math (gyulaz-htec)
74c9f4d  Improve input_ids buffer handling, use different program arguments fo… (ototh-htec)
9f16a14  Move mgxllama2 components to multiple files (ototh-htec)
e9d0a81  Move one dim input ids to input class (ototh-htec)
10596c1  Refactor outputs to llama2outputs file and move main to MGXLlama2 struct (ototh-htec)
d9c2746  Refactor mgxllama2 (ototh-htec)
fbd46b8  Use batch size from config (ototh-htec)
583a439  Remove offload copy option (ototh-htec)
a95ee80  Implement batching for mgxllama2 example (ototh-htec)
2142a40  Fix last batch when there are not enough samples (ototh-htec)
b919b0f  Revert "Add new migraphx public API functions: replace_return and get… (ototh-htec)
41447b0  Remove unused python imports (ototh-htec)
91f368b  Format files with clang-format (ototh-htec)
e5c0d62  Format files with clang-format (ototh-htec)
7d0bb2f  Fix delete modifier ident (ototh-htec)
43364a1  Fix python files format issues (ototh-htec)
5f2190f  Fix python files format issues 2 (ototh-htec)
7468a3e  Merge branch 'develop' into htec/mgx-llama2-7b-example (ototh-htec)
0640809  Make output results const for cppcheck (ototh-htec)
20fb8bc  Pass GPU_TARGET from build_docker script to Dockerfile (ototh-htec)
9165b90  Merge branch 'develop' into htec/mgx-llama2-7b-example (ototh-htec)
b499638  Add license (ototh-htec)
**CMakeLists.txt** (new file, +33 lines):

```cmake
cmake_minimum_required(VERSION 3.22)
project(MGXLlama2)

set(TARGET_NAME mgxllama2)

set(HARNESS_DIR ${CMAKE_CURRENT_SOURCE_DIR}/harness)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CXX /opt/rocm/llvm/bin/clang++)

list(APPEND CMAKE_PREFIX_PATH /opt/rocm ${HIP_PATH})
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -O3 -W -Wall -pthread -D__HIP_PLATFORM_HCC__=1")

find_package(migraphx REQUIRED)
find_package(hip REQUIRED)

include_directories(${HARNESS_DIR})

add_executable(${TARGET_NAME}
    mgxllama2.cc
)

target_include_directories(${TARGET_NAME}
    PUBLIC ${HARNESS_DIR}
)

target_link_libraries(${TARGET_NAME}
    migraphx::c
    hip::device
    pthread
)
```
**Dockerfile** (new file, +47 lines):

```dockerfile
FROM rocm/dev-ubuntu-22.04:6.2

ENV DEBIAN_FRONTEND=noninteractive

SHELL ["/bin/bash", "-c"]

RUN apt-get update && apt-get install -y --allow-unauthenticated \
    apt-utils \
    cmake \
    half \
    sqlite3 \
    libsqlite3-dev \
    libfmt-dev \
    git && \
    apt-get clean && \
    rm -rf /var/lib/apt/lists/*

RUN mkdir /app && cd /app && git clone https://github.com/ROCm/rocmProfileData
WORKDIR /app/rocmProfileData
RUN make && make install

RUN mkdir /migraphx && cd /migraphx && git clone --branch develop https://github.com/ROCm/AMDMIGraphX.git
WORKDIR /migraphx/AMDMIGraphX

RUN ./tools/install_prereqs.sh
ENV MIGRAPHX_ENABLE_HIPBLASLT_GEMM=1
ENV MIGRAPHX_USE_HIPBLASLT=1
ENV MIGRAPHX_USE_MIOPEN=1

# TODO: use $(/opt/rocm/bin/rocminfo | grep -o -m1 'gfx.*') for GPU_TARGETS
RUN mkdir build && cd build && \
    CXX=/opt/rocm/llvm/bin/clang++ cmake .. -DGPU_TARGETS='gfx942' && \
    make -j$(nproc) && \
    make install

RUN mkdir /mgx_llama2

COPY . /mgx_llama2

RUN rm -rf /mgx_llama2/build && mkdir /mgx_llama2/build

WORKDIR /mgx_llama2/build

RUN CXX=/opt/rocm/llvm/bin/clang++ cmake .. && make

RUN pip install pandas evaluate nltk transformers sentencepiece rouge_score
```
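The TODO in the MIGraphX build step is picked up later in this series by commit 20fb8bc ("Pass GPU_TARGET from build_docker script to Dockerfile"). A minimal sketch of that idea, assuming the Dockerfile declares an `ARG GPU_TARGET` and forwards it to `-DGPU_TARGETS`; the names here are illustrative, not the exact final diff:

```bash
# Hypothetical sketch: detect the GPU architecture on the host and pass it
# into the image build instead of hard-coding gfx942. Requires a matching
# `ARG GPU_TARGET` in the Dockerfile's MIGraphX build step.
GPU_TARGET="$(/opt/rocm/bin/rocminfo | grep -o -m1 'gfx.*')"
docker build --build-arg GPU_TARGET="${GPU_TARGET}" \
    --tag mgx_llama2:v0.2 --file Dockerfile .
```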
**README.md** (new file, +93 lines):

## Getting the model

### Getting the pre-quantized model from HuggingFace

```bash
pip install -U "huggingface_hub[cli]"
huggingface-cli login --token YOUR_HF_TOKEN
huggingface-cli download amd/Llama-2-7b-chat-hf-awq-int4-asym-gs128-onnx
```
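If you want the snapshot in a predictable location (for example, to pass to `run_docker.sh` later as `MODEL_DIR_PATH`), `huggingface-cli download` accepts a `--local-dir` flag; the directory name below is illustrative:

```bash
# Illustrative target directory; not mandated by this PR.
huggingface-cli download amd/Llama-2-7b-chat-hf-awq-int4-asym-gs128-onnx \
    --local-dir ./llama2-7b-awq-int4-onnx
export MODEL_DIR_PATH="$(pwd)/llama2-7b-awq-int4-onnx"
```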
Alternatively, you can quantize the model yourself.
### Quantizing the model

**If you are using the pre-quantized model you can skip this section.**

Get the latest quark quantizer version from https://xcoartifactory/ui/native/uai-pip-local/com/amd/quark/main/nightly/. Downloading the zip is recommended because it contains the required scripts. The quark version used when this example was created: quark-1.0.0.dev20241028+eb46b7438 (2024-10-28).

We also need to install the onnxruntime-genai (OGA) tool to convert the quark_safetensors format to ONNX properly.

#### Installing quark and its dependencies

```bash
# Install the OGA tool
pip install onnxruntime-genai

# Quark dependencies according to https://quark.docs.amd.com/latest/install.html.
# PyTorch is assumed to be installed already; the following base docker image
# ships with torch: rocm/pytorch:rocm6.2.2_ubuntu20.04_py3.9_pytorch_release_2.2.1
pip install onnxruntime onnx

# Install the wheel
unzip quark-1.0.0.dev20241028+eb46b7438.zip -d quark
cd quark
pip install quark-1.0.0.dev20241028+eb46b7438-py3-none-any.whl
```
#### Quantizing the model and converting to ONNX

```bash
cd quark/examples/torch/language_modeling/llm_ptq

# MODEL_DIR can be a local model checkpoint folder, or a model id such as
# meta-llama/Llama-2-7b-chat-hf or meta-llama/Llama-2-70b-chat-hf.
export MODEL_DIR=meta-llama/Llama-2-7b-chat-hf
# QUANTIZED_MODEL_DIR is the output checkpoint folder.
export QUANTIZED_MODEL_DIR=$MODEL_DIR-awq-uint4-asym-g128-f16

python3 quantize_quark.py --model_dir $MODEL_DIR \
                          --data_type float16 \
                          --quant_scheme w_uint4_per_group_asym \
                          --num_calib_data 128 \
                          --quant_algo awq \
                          --dataset pileval_for_awq_benchmark \
                          --seq_len 1024 \
                          --output_dir $QUANTIZED_MODEL_DIR \
                          --model_export quark_safetensors \
                          --custom_mode awq

python3 -m onnxruntime_genai.models.builder \
        -i "$QUANTIZED_MODEL_DIR" \
        -o "$QUANTIZED_MODEL_DIR-onnx" \
        -p int4 \
        -e cpu
```
## Getting the dataset

Download the preprocessed open-orca dataset files following the instructions at https://github.com/mlcommons/inference/tree/master/language/llama2-70b#preprocessed. A sketch of that flow is shown below.
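At the time of writing, the linked MLCommons page fetches the file with rclone. The remote name, bucket path, and credentials below are placeholders standing in for the values given on that page, so treat this as illustrative only; the expected file name matches the default `DATASET_PATH` in `eval_accuracy.py`:

```bash
# Illustrative sketch of the MLCommons download flow; take <ACCESS_KEY>,
# <SECRET_KEY>, and <ENDPOINT> from the linked instructions.
rclone config create mlc-inference s3 provider=Cloudflare \
    access_key_id=<ACCESS_KEY> secret_access_key=<SECRET_KEY> \
    endpoint=<ENDPOINT>
rclone copy mlc-inference:mlcommons-inference-wg-public/open_orca/open_orca_gpt4_tokenized_llama.sampled_24576.pkl \
    ./open_orca_dataset -P
```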
## Running the example

### Starting the MIGraphX docker container

```bash
./build_docker.sh

export MODEL_DIR_PATH=path/to/quantized/llama2-7[0]b-model
export DATA_DIR_PATH=path/to/open_orca_dataset
./run_docker.sh
```
### Building and running the example

```bash
# Convert the dataset to numpy format
./preprocess_dataset.py

# Build the example
cd mgx_llama2
mkdir build && cd build
CXX=/opt/rocm/llvm/bin/clang++ cmake ..
make -j

# Run the example
export MIOPEN_FIND_ENFORCE=3
./mgxllama2

# Test the accuracy of the output
python3 eval_accuracy.py
```
**build_docker.sh** (new file, +3 lines):

```bash
#!/bin/bash

docker build --platform linux/amd64 --tag mgx_llama2:v0.2 --file Dockerfile .
```
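`run_docker.sh` is referenced by the README but is not part of the 45 commits shown in this view. A hypothetical sketch consistent with the surrounding files (the `/dataset` mount matches the default `DATASET_PATH` in `eval_accuracy.py`; the `/model` mount point, device flags, and image tag are assumptions):

```bash
#!/bin/bash
# Hypothetical run_docker.sh: expose the GPU to the container and mount the
# model/dataset paths exported in the README. /model is an assumed mount
# point; /dataset matches eval_accuracy.py's default dataset path.
docker run -it --rm \
    --device=/dev/kfd --device=/dev/dri \
    --group-add video \
    -v "${MODEL_DIR_PATH}:/model" \
    -v "${DATA_DIR_PATH}:/dataset" \
    mgx_llama2:v0.2
```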
**eval_accuracy.py** (new file, +118 lines):

```python
from argparse import ArgumentParser
import numpy as np
import pickle
from pathlib import Path
import evaluate
import nltk
from transformers import AutoTokenizer

MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"

G_MAX_TOK_LEN = 1024
G_LLAMA2_EOS = 2
SAMPLE_SIZE = 10

DATASET_PATH = "/dataset/open_orca_gpt4_tokenized_llama.sampled_24576.pkl"
RESULT_PATH = "build/result.txt"


def main(dataset_path, result_path, sample_size, sequence_size):
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME,
        model_max_length=sequence_size,
        padding_side="left",
        use_fast=False,
    )

    metric = evaluate.load("rouge")
    nltk.download("punkt_tab")

    _p = Path(dataset_path)
    if not _p.exists():
        raise FileNotFoundError(f"Dataset not found: {dataset_path}")
    with _p.open(mode="rb") as f:
        d = pickle.load(f)

    targets = d["output"].to_list()[0:sample_size]
    results, gen_tok_len = readResult(result_path)

    preds = tokenizer.batch_decode(results, skip_special_tokens=True)

    preds, targets = postprocess_text(preds, targets)

    result = metric.compute(
        predictions=preds, references=targets, use_stemmer=True, use_aggregator=False
    )

    result = {k: round(np.mean(v) * 100, 4) for k, v in result.items()}
    prediction_lens = [len(pred) for pred in preds]
    gen_num = len(preds)

    result = {
        **result,
        "gen_len": np.sum(prediction_lens),
        "gen_num": gen_num,
        "gen_tok_len": gen_tok_len,
        "tokens_per_sample": round(gen_tok_len / gen_num, 1),
    }

    print("\nResults\n")
    print(result)


def readResult(path):
    # Each line of the result file is a comma-separated list of token ids.
    results = []
    tok_len = 0
    with open(path, "r") as f:
        for res in f:
            res = res.strip()
            if not res:
                continue
            result = [int(num_res) for num_res in res.split(",")]
            results.append(result)
            tok_len += len(result)
    return results, tok_len


def postprocess_text(preds, targets):
    preds = [pred.strip() for pred in preds]
    targets = [target.strip() for target in targets]

    # rougeLSum expects a newline after each sentence
    preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
    targets = ["\n".join(nltk.sent_tokenize(target)) for target in targets]

    return preds, targets


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument(
        "-d",
        "--dataset-path",
        help="Path to the dataset pickle file",
        default=DATASET_PATH,
    )
    parser.add_argument(
        "-r",
        "--result_path",
        help="Path to output tokens result file",
        default=RESULT_PATH,
    )
    parser.add_argument(
        "-size",
        "--sample-size",
        help="Sample size of dataset",
        type=int,
        default=SAMPLE_SIZE,
    )
    parser.add_argument(
        "-seq_size",
        "--sequence_size",
        help="Size of sequence",
        type=int,
        default=G_MAX_TOK_LEN,
    )
    main(**vars(parser.parse_args()))
```
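For reference, here is the script invoked with its own defaults spelled out explicitly (run from the example root so that `build/result.txt` resolves):

```bash
python3 eval_accuracy.py \
    --dataset-path /dataset/open_orca_gpt4_tokenized_llama.sampled_24576.pkl \
    --result_path build/result.txt \
    --sample-size 10 \
    --sequence_size 1024
```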
**Review comment:** Is this needed? I don't believe we make any MIOpen calls for this model.