Llama2 7b model C++ example #3666

Open · wants to merge 57 commits into develop from htec/mgx-llama2-7b-example (this view shows changes from 45 commits)

Commits (57)
4b6c099
WIP mgx llama2-7b example
gyulaz-htec Oct 22, 2024
1edf198
Code works with offload copy
gyulaz-htec Oct 22, 2024
7cd6a0c
Add support to load onnx file
ototh-htec Oct 24, 2024
595771f
Add dockerization for mgx_llama2 example
ototh-htec Oct 24, 2024
fdd16ed
Rework buffer allocation so offload_copy can be turned on/off
gyulaz-htec Oct 24, 2024
1b78a1c
Use dedicated hipStream for synchronization
gyulaz-htec Oct 24, 2024
0731e26
Save onnx model to mxr file
ototh-htec Oct 25, 2024
2ab6659
Only copy changed data
gyulaz-htec Oct 25, 2024
e7f84c2
Extend model loading options with fast_math
gyulaz-htec Oct 25, 2024
0fe7924
Fix quant message
gyulaz-htec Oct 25, 2024
984e0dc
Basic tokens/sec counting
gyulaz-htec Oct 28, 2024
4134505
Add preprocess dataset script
ototh-htec Oct 28, 2024
6735534
Use dataset from numpy files if available
ototh-htec Oct 28, 2024
f1960f7
Support npy dataset with multiple samples
ototh-htec Oct 29, 2024
55e41f4
Add missing upload to device for multiple samples
ototh-htec Oct 29, 2024
4d25c49
Add accuracy calculation for mgx_llama2 example
ototh-htec Nov 4, 2024
3dd55b5
Use MIGraphX from develop branch in Dockerfile
ototh-htec Nov 5, 2024
2416786
Fix dataset loading
gyulaz-htec Nov 7, 2024
0877f72
Add README to C++ LLama2 example
gyulaz-htec Nov 7, 2024
513cc0d
Add buffers for llama2 7b quantized models
ototh-htec Nov 7, 2024
048e587
Fix Llama2-7b model file parse and input buffers
ototh-htec Nov 8, 2024
71b7522
Fix llama 7b quantized model evaluation step, use 2 models
ototh-htec Nov 13, 2024
9220cf1
Support llama 7b quantized model without offload copy
ototh-htec Nov 13, 2024
3842b54
Connect dataset to llama 7b quantized model
ototh-htec Nov 14, 2024
d8a85cd
Fix output buffer usage for new sample
ototh-htec Nov 14, 2024
37ed15b
Comment out past/present_key_value binding
ototh-htec Nov 14, 2024
d970b30
Add new migraphx public API functions: replace_return and get_last_in…
gyulaz-htec Nov 15, 2024
a7801b0
Add romxProfileData to the example container
gyulaz-htec Nov 15, 2024
9152cd1
Update MGX branch and en variables in the example docker
gyulaz-htec Nov 15, 2024
14df5db
Fix example readme
gyulaz-htec Nov 15, 2024
f7d3fbf
Fix typo in example preproc script
gyulaz-htec Nov 15, 2024
58bc2c2
Update LLama2 example to target specific device and enable fast_math …
gyulaz-htec Nov 15, 2024
e2856a7
Add eval_accuracy script dependencies to Dockerfile
ototh-htec Nov 15, 2024
632c026
Add argmax program to lessen DToH copy + CPU computation overhead
gyulaz-htec Nov 15, 2024
1929b44
Update LLama2 example docker file ENVs
gyulaz-htec Nov 17, 2024
19f36b4
Disable fast_math
gyulaz-htec Nov 17, 2024
74c9f4d
Improve input_ids buffer handling, use different program arguments fo…
ototh-htec Nov 19, 2024
9f16a14
Move mgxllama2 components to multiple files
ototh-htec Nov 19, 2024
e9d0a81
Move one dim input ids to input class
ototh-htec Nov 20, 2024
10596c1
Refactor outputs to llama2outputs file and move main to MGXLlama2 struct
ototh-htec Nov 21, 2024
d9c2746
Refactor mgxllama2
ototh-htec Nov 25, 2024
fbd46b8
Use batch size from config
ototh-htec Nov 27, 2024
583a439
Remove offload copy option
ototh-htec Nov 27, 2024
a95ee80
Implement batching for mgxllama2 example
ototh-htec Nov 28, 2024
2142a40
Fix last batch when there is not enough sample
ototh-htec Nov 29, 2024
b919b0f
Revert "Add new migraphx public API functions: replace_return and get…
ototh-htec Dec 4, 2024
41447b0
Remove unused python imports
ototh-htec Dec 4, 2024
91f368b
Format files with clang-format
ototh-htec Dec 4, 2024
e5c0d62
Format files with clang-format
ototh-htec Dec 4, 2024
7d0bb2f
Fix delete modifier ident
ototh-htec Dec 4, 2024
43364a1
Fix python files format issues
ototh-htec Dec 4, 2024
5f2190f
Fix python files format issues 2
ototh-htec Dec 4, 2024
7468a3e
Merge branch 'develop' into htec/mgx-llama2-7b-example
ototh-htec Dec 5, 2024
0640809
Make output results const for cppcheck
ototh-htec Dec 5, 2024
20fb8bc
Pass GPU_TARGET from build_docker script to Dockerfile
ototh-htec Dec 5, 2024
9165b90
Merge branch 'develop' into htec/mgx-llama2-7b-example
ototh-htec Dec 5, 2024
b499638
Add license
ototh-htec Dec 6, 2024
33 changes: 33 additions & 0 deletions examples/transformers/mgx_llama2/CMakeLists.txt
@@ -0,0 +1,33 @@
cmake_minimum_required(VERSION 3.22)
project(MGXLlama2)

set(TARGET_NAME mgxllama2)

set(HARNESS_DIR ${CMAKE_CURRENT_SOURCE_DIR}/harness)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CXX /opt/rocm/llvm/bin/clang++)

list (APPEND CMAKE_PREFIX_PATH /opt/rocm ${HIP_PATH})
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -O3 -W -Wall -pthread -D__HIP_PLATFORM_HCC__=1")

find_package(migraphx REQUIRED)
find_package(hip REQUIRED)

include_directories(${HARNESS_DIR})

add_executable(${TARGET_NAME}
mgxllama2.cc
)

target_include_directories(${TARGET_NAME}
PUBLIC ${HARNESS_DIR}
)

target_link_libraries(${TARGET_NAME}
migraphx::c
hip::device
pthread
)

47 changes: 47 additions & 0 deletions examples/transformers/mgx_llama2/Dockerfile
@@ -0,0 +1,47 @@
FROM rocm/dev-ubuntu-22.04:6.2

ENV DEBIAN_FRONTEND=noninteractive

SHELL ["/bin/bash", "-c"]

RUN apt-get update && apt-get install -y --allow-unauthenticated \
apt-utils \
cmake \
half \
sqlite3 \
libsqlite3-dev \
libfmt-dev \
git && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*

RUN mkdir /app && cd /app && git clone https://github.com/ROCm/rocmProfileData
WORKDIR /app/rocmProfileData
RUN make && make install

RUN mkdir /migraphx && cd /migraphx && git clone --branch develop https://github.com/ROCm/AMDMIGraphX.git
WORKDIR /migraphx/AMDMIGraphX

RUN ./tools/install_prereqs.sh
ENV MIGRAPHX_ENABLE_HIPBLASLT_GEMM=1
ENV MIGRAPHX_USE_HIPBLASLT=1
ENV MIGRAPHX_USE_MIOPEN=1

#TODO: use $(/opt/rocm/bin/rocminfo | grep -o -m1 'gfx.*') for GPU_TARGETS
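# A possible approach (sketch, not part of this change): declare `ARG GPU_TARGETS=gfx942`
# above this step and let build_docker.sh supply it, e.g.
#   docker build --build-arg GPU_TARGETS=$(/opt/rocm/bin/rocminfo | grep -o -m1 'gfx.*') ...
# then pass -DGPU_TARGETS="${GPU_TARGETS}" to cmake below. rocminfo has to run on the host,
# because the GPU is typically not visible during `docker build`.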
RUN mkdir build && cd build && \
CXX=/opt/rocm/llvm/bin/clang++ cmake .. -DGPU_TARGETS='gfx942' && \
make -j$(nproc) && \
make install

RUN mkdir /mgx_llama2

COPY . /mgx_llama2

RUN rm -rf /mgx_llama2/build && mkdir /mgx_llama2/build

WORKDIR /mgx_llama2/build

RUN CXX=/opt/rocm/llvm/bin/clang++ cmake .. && make

RUN pip install pandas evaluate nltk transformers sentencepiece rouge_score

93 changes: 93 additions & 0 deletions examples/transformers/mgx_llama2/README.md
@@ -0,0 +1,93 @@
## Getting the model

### Getting the pre-quantized model from HuggingFace
```bash
pip install -U "huggingface_hub[cli]"
huggingface-cli login --token YOUR_HF_TOKEN
huggingface-cli download amd/Llama-2-7b-chat-hf-awq-int4-asym-gs128-onnx
```
Alternatively you can quantize the model yourself.

### Quantizing the model

**If you are using the pre-quantized model you can skip this section.**

Get the latest Quark quantizer version from https://xcoartifactory/ui/native/uai-pip-local/com/amd/quark/main/nightly/. Downloading the zip is recommended because it contains the required scripts. The Quark version used when this example was created was quark-1.0.0.dev20241028+eb46b7438 (2024-10-28).

We also need to install the onnxruntime-genai (OGA) tool to properly convert the quark_safetensors export to ONNX format.

#### Installing Quark and its dependencies:
```bash
# install OGA tool
pip install onnxruntime-genai

# Quark dependencies according to https://quark.docs.amd.com/latest/install.html; we assume PyTorch is already installed. You can use the following base docker image, which has torch installed: rocm/pytorch:rocm6.2.2_ubuntu20.04_py3.9_pytorch_release_2.2.1
pip install onnxruntime onnx

# Install the whl
unzip quark-1.0.0.dev20241028+eb46b7438.zip -d quark
cd quark
pip install quark-1.0.0.dev20241028+eb46b7438-py3-none-any.whl
```

#### Quantizing the model and converting to ONNX
```bash
cd quark/examples/torch/language_modeling/llm_ptq

export MODEL_DIR=[local model checkpoint folder]  # or meta-llama/Llama-2-7b-chat-hf or meta-llama/Llama-2-70b-chat-hf
export QUANTIZED_MODEL_DIR=[output model checkpoint folder]

python3 quantize_quark.py --model_dir $MODEL_DIR \
--data_type float16 \
--quant_scheme w_uint4_per_group_asym \
--num_calib_data 128 \
--quant_algo awq \
--dataset pileval_for_awq_benchmark \
--seq_len 1024 \
--output_dir $MODEL_DIR-awq-uint4-asym-g128-f16 \
--model_export quark_safetensors \
--custom_mode awq

python3 -m onnxruntime_genai.models.builder \
-i "$QUANTIZED_MODEL_DIR" \
-o "$QUANTIZED_MODEL_DIR-onnx" \
-p int4 \
-e cpu
```

## Getting the dataset

Download the preprocessed open-orca dataset files using the instructions in https://github.com/mlcommons/inference/tree/master/language/llama2-70b#preprocessed
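
As a rough sketch only (the authoritative download steps are in the MLCommons repository linked above, and we assume here that `run_docker.sh` mounts `DATA_DIR_PATH` into the container as `/dataset`), the file that `eval_accuracy.py` expects by default is:

```bash
# Sketch: obtain the pickle via the MLCommons instructions, then check it is in place.
export DATA_DIR_PATH=/path/to/open_orca_dataset
ls "$DATA_DIR_PATH/open_orca_gpt4_tokenized_llama.sampled_24576.pkl"
```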

## Running the example

### Starting migraphx docker

```bash

./build_docker.sh

export MODEL_DIR_PATH=path/to/quantized/llama2-7[0]b-model
export DATA_DIR_PATH=path/to/open_orca_dataset
./run_docker.sh
```

### Building and running the example

```bash
# Convert dataset to numpy format
./preprocess_dataset.py

# Building the example
cd mgx_llama2
mkdir build && cd build
CXX=/opt/rocm/llvm/bin/clang++ cmake ..
make -j

# Running the example
export MIOPEN_FIND_ENFORCE=3
[Review comment from a Contributor on `export MIOPEN_FIND_ENFORCE=3`]: Is this needed? I don't believe we make any MIOpen calls for this model.

./mgxllama2

# Test the accuracy of the output
python3 eval_accuracy.py
```
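
`eval_accuracy.py` also accepts explicit arguments; the invocation below is only an illustrative sketch that restates the defaults hard-coded in the script:

```bash
# Optional: the defaults baked into eval_accuracy.py, spelled out explicitly
python3 eval_accuracy.py \
    -d /dataset/open_orca_gpt4_tokenized_llama.sampled_24576.pkl \
    -r build/result.txt \
    -size 10 \
    -seq_size 1024
```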
3 changes: 3 additions & 0 deletions examples/transformers/mgx_llama2/build_docker.sh
@@ -0,0 +1,3 @@
#!/bin/bash

docker build --platform linux/amd64 --tag mgx_llama2:v0.2 --file Dockerfile .
118 changes: 118 additions & 0 deletions examples/transformers/mgx_llama2/eval_accuracy.py
@@ -0,0 +1,118 @@
from argparse import ArgumentParser
import numpy as np
import pickle
from pathlib import Path
import os
import evaluate
import nltk
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM


MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"

G_MAX_TOK_LEN = 1024
G_LLAMA2_EOS = 2
SAMPLE_SIZE = 10

DATASET_PATH = "/dataset/open_orca_gpt4_tokenized_llama.sampled_24576.pkl"
RESULT_PATH = "build/result.txt"
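
# Decodes the token ids that mgxllama2 writes to result.txt, scores them against the
# reference outputs from the OpenOrca pickle with ROUGE, and prints aggregate
# generation statistics.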

def main(dataset_path, result_path, sample_size, sequence_size):
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME,
        model_max_length=sequence_size,
        padding_side="left",
        use_fast=False,)

    metric = evaluate.load("rouge")
    nltk.download("punkt_tab")

    _p = Path(dataset_path)
    if _p.exists():
        with _p.open(mode="rb") as f:
            d = pickle.load(f)

    target = d['output'].to_list()
    targets = target[0:sample_size]
    results, gen_tok_len = readResult(result_path)

    preds = tokenizer.batch_decode(
        results, skip_special_tokens=True
    )

    preds, targets = postprocess_text(preds, targets)

    result = metric.compute(
        predictions=preds, references=targets, use_stemmer=True, use_aggregator=False
    )

    result = {k: round(np.mean(v) * 100, 4) for k, v in result.items()}
    prediction_lens = [len(pred) for pred in preds]
    gen_num = len(preds)

    result = {
        **result,
        "gen_len": np.sum(prediction_lens),
        "gen_num": gen_num,
        "gen_tok_len": gen_tok_len,
        "tokens_per_sample": round(gen_tok_len / gen_num, 1),
    }

    print("\nResults\n")
    print(result)

def readResult(path):
    # result.txt holds one generated sample per line, as comma-separated token ids
    results = []
    tok_len = 0
    with open(path, "r") as f:
        for res in f:
            result = res.split(",")
            result = [int(num_res) for num_res in result]
            results.append(result)
            tok_len += len(result)
    return results, tok_len

def postprocess_text(preds, targets):
    preds = [pred.strip() for pred in preds]
    targets = [target.strip() for target in targets]

    # rougeLSum expects newline after each sentence
    preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
    targets = ["\n".join(nltk.sent_tokenize(target)) for target in targets]

    return preds, targets

if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument(
        "-d",
        "--dataset-path",
        help="Path to the dataset pickle file",
        default=DATASET_PATH
    )

    parser.add_argument(
        "-r",
        "--result_path",
        help="Path to output tokens result file",
        default=RESULT_PATH
    )

    parser.add_argument(
        "-size",
        "--sample-size",
        help="Sample size of dataset",
        type=int,
        default=SAMPLE_SIZE
    )

    parser.add_argument(
        "-seq_size",
        "--sequence_size",
        help="Size of sequence",
        type=int,
        default=G_MAX_TOK_LEN
    )

    main(**vars(parser.parse_args()))