Merge branch 'main' into add_llamacpp
gabeweisz authored Oct 22, 2024
2 parents afbf197 + ff12b67 commit a18a259
Showing 10 changed files with 192 additions and 53 deletions.
1 change: 1 addition & 0 deletions .github/workflows/test_devices_plugin.yml
@@ -50,6 +50,7 @@ jobs:
conda install pylint
pip install pytest
pip install -e plugins/devices
pip install -e . # Required to test current tkml package instead of pypi version
pip install transformers timm
python -m pip check
- name: Lint with PyLint
2 changes: 1 addition & 1 deletion .github/workflows/test_turnkey.yml
@@ -35,9 +35,9 @@ jobs:
python -m pip install --upgrade pip
conda install pylint=3.2.7
pip install pytest
pip install -e .
pip install -e plugins/devices
pip install transformers timm
pip install -e . # Required to test current tkml package instead of pypi version
python -m pip check
- name: Lint with PyLint
shell: bash -el {0}
7 changes: 2 additions & 5 deletions plugins/devices/setup.py
@@ -33,18 +33,15 @@ def get_specific_version(plugin_name: str, version_key: str) -> str:
],
python_requires=">=3.8, <3.12",
install_requires=[
"turnkeyml==4.0.0",
"turnkeyml>=4.0.0",
"importlib_metadata",
"onnx_tool",
"numpy<2",
"gitpython",
"timm==0.9.10",
],
include_package_data=True,
package_data={
"turnkeyml_plugin_devices": [
]
},
package_data={"turnkeyml_plugin_devices": []},
extras_require={
"onnxrt": [],
"torchrt": [],
6 changes: 1 addition & 5 deletions plugins/devices/test/benchmark.py
@@ -287,12 +287,8 @@ def test_008_cli_timeout(self):
"""
Make sure that the --timeout option and its associated reporting features work.
timeout.py is designed to take a long time to export, which gives us the
timeout.py is designed to take 20s to discover, which gives us the
opportunity to kill it with a timeout.
NOTE: this test can become flakey if:
- exporting timeout.py takes less time than the timeout
- the timeout kills the process before it has a chance to create a stats.yaml file
"""

testargs = [
7 changes: 4 additions & 3 deletions src/turnkeyml/common/test_helpers.py
@@ -143,22 +143,23 @@ def forward(self, x):
"timeout.py": """
# labels: name::timeout author::turnkey license::mit test_group::a task::test
import torch
import time
torch.manual_seed(0)
class LinearTestModel(torch.nn.Module):
def __init__(self, input_features, output_features):
super(LinearTestModel, self).__init__()
time.sleep(20)
self.fc = torch.nn.Linear(input_features, output_features)
def forward(self, x):
output = self.fc(x)
return output
input_features = 500000
output_features = 1000
input_features = 50
output_features = 10
# Model and input configurations
model = LinearTestModel(input_features, output_features)
111 changes: 111 additions & 0 deletions src/turnkeyml/llm/docs/ort_genai_npu.md
@@ -0,0 +1,111 @@
# Introduction

onnxruntime-genai (aka OGA) is a new framework created by Microsoft for running ONNX LLMs: https://github.com/microsoft/onnxruntime-genai/tree/main?tab=readme-ov-file
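
For orientation, here is a minimal sketch of the OGA Python API that this commit's `oga.py` tool builds on (assuming the onnxruntime-genai 0.5.x API); the model folder path and prompt are placeholders:

```
# Minimal OGA generation loop (onnxruntime-genai ~0.5.x API, as used by oga.py).
# "path/to/model/folder" is a placeholder for a folder containing genai_config.json.
import onnxruntime_genai as og

model = og.Model("path/to/model/folder")
tokenizer = og.Tokenizer(model)
input_ids = tokenizer.encode("hello whats your name?")

params = og.GeneratorParams(model)
params.set_search_options(do_sample=False, max_length=len(input_ids) + 15)
params.input_ids = input_ids

# Generate one token at a time until the search is done, then decode the sequence.
generator = og.Generator(model, params)
while not generator.is_done():
    generator.compute_logits()
    generator.generate_next_token()

print(tokenizer.decode(generator.get_sequence(0)))
```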

## NPU instructions

### Warnings

- Users have experienced inconsistent results across models and machines. If one model isn't working well on your laptop, try one of the other models.
- The OGA wheels need to be installed in a specific order or you will end up with the wrong packages in your environment. If you see pip dependency errors, please delete your conda env and start over with a fresh environment.

### Installation

1. NOTE: ⚠️ DO THESE STEPS IN EXACTLY THIS ORDER ⚠️
1. Install `lemonade`:
    1. Create a conda environment: `conda create -n oga-npu python=3.10` (Python 3.10 is required)
    1. Activate: `conda activate oga-npu`
    1. `cd REPO_ROOT`
    1. `pip install -e .[oga-npu]`
1. Download the required OGA packages:
    1. Access the [AMD RyzenAI EA Lounge](https://account.amd.com/en/member/ryzenai-sw-ea.html#tabs-a5e122f973-item-4757898120-tab) and download `amd_oga_Oct4_2024.zip` from `Ryzen AI 1.3 Preview Release`.
    1. Unzip `amd_oga_Oct4_2024.zip`
1. Set up your folder structure:
    1. Copy all of the content inside `amd_oga` to lemonade's `REPO_ROOT\src\lemonade\tools\ort_genai\models\`
    1. Move all DLLs from `REPO_ROOT\src\lemonade\tools\ort_genai\models\libs` to `REPO_ROOT\src\lemonade\tools\ort_genai\models\`
1. Install the wheels:
    1. `cd amd_oga\wheels`
    1. `pip install onnxruntime_genai-0.5.0.dev0-cp310-cp310-win_amd64.whl`
    1. `pip install onnxruntime_vitisai-1.20.0-cp310-cp310-win_amd64.whl`
    1. `pip install voe-1.2.0-cp310-cp310-win_amd64.whl`
1. Ensure you have access to the models on Hugging Face:
    1. Ensure you can access the models under [quark-quantized-onnx-llms-for-ryzen-ai-13-ea](https://huggingface.co/collections/amd/quark-quantized-onnx-llms-for-ryzen-ai-13-ea-66fc8e24927ec45504381902) on Hugging Face. The models are gated, so you may have to request access (see the download sketch after this list).
    1. Create a Hugging Face Access Token [here](https://huggingface.co/settings/tokens). Ensure you select `Read access to contents of all public gated repos you can access` if creating a fine-grained token.
    1. Set your Hugging Face token as an environment variable: `set HF_TOKEN=<your token>`
1. Install the driver:
    1. Access the [AMD RyzenAI EA Lounge](https://account.amd.com/en/member/ryzenai-sw-ea.html#tabs-a5e122f973-item-4757898120-tab) and download `Win24AIDriver.zip` from `Ryzen AI 1.3 Preview Release`.
    1. Unzip `Win24AIDriver.zip`
    1. Right-click `kipudrv.inf` and select `Install`.
    1. Check under `Device Manager` to ensure that `NPU Compute Accelerator` is using version `32.0.203.219`.
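
With the token in place, gated checkpoints can be fetched programmatically. The sketch below mirrors the `snapshot_download` call that this commit adds to `oga.py`; the checkpoint name is one example from the collection above:

```
# Sketch of the gated-model download that OgaLoad performs, assuming HF_TOKEN is set.
import os
from huggingface_hub import login, snapshot_download

checkpoint = "amd/Llama-2-7b-hf-awq-g128-int4-asym-fp32-onnx-ryzen-strix"
login(os.getenv("HF_TOKEN"))  # authenticate with the token created above

local_dir = snapshot_download(
    repo_id=checkpoint,
    local_dir=os.path.join("models", checkpoint.split("amd/")[1]),
    ignore_patterns=["*.md", "*.txt"],  # skip docs; keep ONNX weights and configs
)
print("Model downloaded to", local_dir)
```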

### Runtime

To test basic functionality, point lemonade to any of the models under [quark-quantized-onnx-llms-for-ryzen-ai-13-ea](https://huggingface.co/collections/amd/quark-quantized-onnx-llms-for-ryzen-ai-13-ea-66fc8e24927ec45504381902):

```
lemonade -i amd/Llama-2-7b-hf-awq-g128-int4-asym-fp32-onnx-ryzen-strix oga-load --device npu --dtype int4 llm-prompt -p "hello whats your name?" --max-new-tokens 15
```

```
Building "amd_Llama-2-7b-hf-awq-g128-int4-asym-fp32-onnx-ryzen-strix"
[Vitis AI EP] No. of Operators : CPU 73 MATMULNBITS 99
[Vitis AI EP] No. of Subgraphs :MATMULNBITS 33
✓ Loading OnnxRuntime-GenAI model
✓ Prompting LLM
amd/Llama-2-7b-hf-awq-g128-int4-asym-fp32-onnx-ryzen-strix:
<built-in function input> (executed 1x)
Build dir: C:\Users\danie/.cache/lemonade\amd_Llama-2-7b-hf-awq-g128-int4-asym-fp32-onnx-ryzen-strix
Status: Successful build!
Dtype: int4
Device: npu
Response: hello whats your name?
Hi, I'm a 21 year old male from the
```

To test/use the websocket server:

```
lemonade -i amd/Llama-2-7b-hf-awq-g128-int4-asym-fp32-onnx-ryzen-strix oga-load --device npu --dtype int4 serve --max-new-tokens 50
```

Then open the address (http://localhost:8000) in a browser and chat with it.

```
Building "amd_Llama-2-7b-hf-awq-g128-int4-asym-fp32-onnx-ryzen-strix"
[Vitis AI EP] No. of Operators : CPU 73 MATMULNBITS 99
[Vitis AI EP] No. of Subgraphs :MATMULNBITS 33
✓ Loading OnnxRuntime-GenAI model
INFO: Started server process [27752]
INFO: Waiting for application startup.
INFO: Application startup complete.
INFO: Uvicorn running on http://localhost:8000 (Press CTRL+C to quit)
INFO: ::1:54973 - "GET / HTTP/1.1" 200 OK
INFO: ('::1', 54975) - "WebSocket /ws" [accepted]
INFO: connection open
I'm a newbie here. I'm looking for a good place to buy a domain name. I've been looking around and i've found a few good places.
```
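
If you prefer a scripted client over the browser page, here is a hypothetical sketch that connects to the `/ws` endpoint shown in the log above; the plain-text message framing is an assumption, so check the serve tool's source for the exact protocol:

```
# Hypothetical websocket client for the lemonade serve example above.
# Assumes the server accepts a plain-text prompt and streams back generated text.
import asyncio
import websockets

async def chat(prompt: str):
    async with websockets.connect("ws://localhost:8000/ws") as ws:
        await ws.send(prompt)
        try:
            while True:
                # Print streamed chunks until the server stops sending.
                chunk = await asyncio.wait_for(ws.recv(), timeout=30)
                print(chunk, end="", flush=True)
        except (asyncio.TimeoutError, websockets.exceptions.ConnectionClosed):
            pass

asyncio.run(chat("hello whats your name?"))
```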

To run a single MMLU test:

```
lemonade -i amd/Llama-2-7b-hf-awq-g128-int4-asym-fp32-onnx-ryzen-strix oga-load --device npu --dtype int4 accuracy-mmlu --tests management
```

```
Building "amd_Llama-2-7b-hf-awq-g128-int4-asym-fp32-onnx-ryzen-strix"
[Vitis AI EP] No. of Operators : CPU 73 MATMULNBITS 99
[Vitis AI EP] No. of Subgraphs :MATMULNBITS 33
✓ Loading OnnxRuntime-GenAI model
✓ Measuring accuracy with MMLU
amd/Llama-2-7b-hf-awq-g128-int4-asym-fp32-onnx-ryzen-strix:
<built-in function input> (executed 1x)
Build dir: C:\Users\danie/.cache/lemonade\amd_Llama-2-7b-hf-awq-g128-int4-asym-fp32-onnx-ryzen-strix
Status: Successful build!
Dtype: int4
Device: npu
Mmlu Management Accuracy: 56.31 %
```
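
The reported accuracy is simply the fraction of multiple-choice questions answered correctly for the selected test. A minimal, illustrative scoring sketch (the questions and answers below are placeholders, not the real MMLU management set):

```
# Illustrative MMLU-style scoring: accuracy = correct answers / total questions.
# The data below is a placeholder, not the actual MMLU management test.
questions = [
    {"prompt": "Example question 1?", "answer": "A"},
    {"prompt": "Example question 2?", "answer": "C"},
]
model_answers = ["A", "B"]  # letters predicted by the model

correct = sum(1 for q, pred in zip(questions, model_answers) if pred == q["answer"])
print(f"Accuracy: {100.0 * correct / len(questions):.2f} %")  # prints: Accuracy: 50.00 %
```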
2 changes: 1 addition & 1 deletion src/turnkeyml/llm/leap.py
@@ -90,7 +90,7 @@ def from_pretrained(

return state.model, state.tokenizer

if recipe == "hf-dgpu":
elif recipe == "hf-dgpu":
# Huggingface Transformers recipe for discrete GPU (Nvidia, Instinct, Radeon)

import torch
101 changes: 69 additions & 32 deletions src/turnkeyml/llm/tools/ort_genai/oga.py
@@ -6,7 +6,9 @@
import os
import time
import json
from fnmatch import fnmatch
from queue import Queue
from huggingface_hub import snapshot_download, login
import onnxruntime_genai as og
from turnkeyml.state import State
from turnkeyml.tools import FirstTool
@@ -18,7 +20,6 @@
)
from turnkeyml.llm.cache import Keys


class OrtGenaiTokenizer(TokenizerAdapter):
def __init__(self, model: og.Model):
# Initialize the tokenizer and produce the initial tokens.
@@ -74,9 +75,9 @@ def __init__(self, input_folder):
self.config = self.load_config(input_folder)

def load_config(self, input_folder):
config_path = os.path.join(input_folder, 'genai_config.json')
config_path = os.path.join(input_folder, "genai_config.json")
if os.path.exists(config_path):
with open(config_path, 'r', encoding='utf-8') as f:
with open(config_path, "r", encoding="utf-8") as f:
return json.load(f)
return None

@@ -99,21 +100,23 @@ def generate(
max_length = len(input_ids) + max_new_tokens

params.input_ids = input_ids
if self.config and 'search' in self.config:
search_config = self.config['search']
if self.config and "search" in self.config:
search_config = self.config["search"]
params.set_search_options(
do_sample=search_config.get('do_sample', do_sample),
top_k=search_config.get('top_k', top_k),
top_p=search_config.get('top_p', top_p),
temperature=search_config.get('temperature', temperature),
do_sample=search_config.get("do_sample", do_sample),
top_k=search_config.get("top_k", top_k),
top_p=search_config.get("top_p", top_p),
temperature=search_config.get("temperature", temperature),
max_length=max_length,
min_length=0,
early_stopping=search_config.get('early_stopping', False),
length_penalty=search_config.get('length_penalty', 1.0),
num_beams=search_config.get('num_beams', 1),
num_return_sequences=search_config.get('num_return_sequences', 1),
repetition_penalty=search_config.get('repetition_penalty', 1.0),
past_present_share_buffer=search_config.get('past_present_share_buffer', True),
early_stopping=search_config.get("early_stopping", False),
length_penalty=search_config.get("length_penalty", 1.0),
num_beams=search_config.get("num_beams", 1),
num_return_sequences=search_config.get("num_return_sequences", 1),
repetition_penalty=search_config.get("repetition_penalty", 1.0),
past_present_share_buffer=search_config.get(
"past_present_share_buffer", True
),
# Not currently supported by OGA
# diversity_penalty=search_config.get('diversity_penalty', 0.0),
# no_repeat_ngram_size=search_config.get('no_repeat_ngram_size', 0),
@@ -192,6 +195,7 @@ class OgaLoad(FirstTool):
llama_2 = "meta-llama/Llama-2-7b-chat-hf"
phi_3_mini_4k = "microsoft/Phi-3-mini-4k-instruct"
phi_3_mini_128k = "microsoft/Phi-3-mini-128k-instruct"
And models on Hugging Face that follow the "amd/**-onnx-ryzen-strix" pattern
Output:
state.model: handle to a Huggingface-style LLM loaded on DirectML device
Expand Down Expand Up @@ -244,7 +248,7 @@ def run(
checkpoint = input

# Map of models[device][dtype][checkpoint] to the name of the model folder on disk
supported_models = {
local_supported_models = {
"igpu": {
"int4": {
phi_3_mini_128k: os.path.join(
@@ -261,6 +265,7 @@
},
"npu": {
"int4": {
# Legacy RyzenAI 1.2 models for NPU
llama_2: "llama2-7b-int4",
llama_3: "llama3-8b-int4",
qwen_1dot5: "qwen1.5-7b-int4",
@@ -277,28 +282,60 @@ def run(
},
}

hf_supported_models = {"npu": {"int4": "amd/**-onnx-ryzen-strix"}}

supported_locally = True
try:
dir_name = supported_models[device][dtype][checkpoint]
dir_name = local_supported_models[device][dtype][checkpoint]
except KeyError as e:
raise ValueError(
"The device;dtype;checkpoint combination is not supported: "
f"{device};{dtype};{checkpoint}. The supported combinations "
f"are: {supported_models}"
) from e

model_dir = os.path.join(
os.path.dirname(os.path.realpath(__file__)),
"models",
dir_name,
)
supported_locally = False
hf_supported = (
device in hf_supported_models
and dtype in hf_supported_models[device]
and fnmatch(checkpoint, hf_supported_models[device][dtype])
)
if not hf_supported:
raise ValueError(
"The device;dtype;checkpoint combination is not supported: "
f"{device};{dtype};{checkpoint}. The supported combinations "
f"are: {local_supported_models} for local models and {hf_supported_models}"
" for models on Hugging Face."
) from e

# Create models dir if it doesn't exist
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
if not os.path.exists(models_dir):
os.makedirs(models_dir)

# If the model is supported through Hugging Face, download it
if not supported_locally:
hf_model_name = checkpoint.split("amd/")[1]
dir_name = "_".join(hf_model_name.split("-")[:6]).lower()
api_key = os.getenv("HF_TOKEN")
login(api_key)
snapshot_download(
repo_id=checkpoint,
local_dir=os.path.join(models_dir, dir_name),
ignore_patterns=["*.md", "*.txt"],
)

# The NPU requires the CWD to be in the model folder
current_cwd = os.getcwd()
if device == "npu":
os.chdir(model_dir)
# Required environment variable for NPU
os.environ["DOD_ROOT"] = ".\\bins"
# Change to the models directory
os.chdir(models_dir)

# Common environment variables for all NPU models
os.environ["DD_ROOT"] = ".\\bins"
os.environ["DEVICE"] = "stx"
os.environ["XLNX_ENABLE_CACHE"] = "0"

# Phi models require USE_AIE_RoPE=0
if "phi-" in checkpoint.lower():
os.environ["USE_AIE_RoPE"] = "0"
else:
os.environ["USE_AIE_RoPE"] = "1"

model_dir = os.path.join(models_dir, dir_name)
state.model = OrtGenaiModel(model_dir)
state.tokenizer = OrtGenaiTokenizer(state.model.model)
state.dtype = dtype
2 changes: 1 addition & 1 deletion src/turnkeyml/version.py
@@ -1 +1 @@
__version__ = "4.0.2"
__version__ = "4.0.3"
6 changes: 1 addition & 5 deletions test/cli.py
@@ -583,12 +583,8 @@ def test_018_cli_timeout(self):
"""
Make sure that the --timeout option and its associated reporting features work.
timeout.py is designed to take a long time to export, which gives us the
timeout.py is designed to take 20s to discover, which gives us the
opportunity to kill it with a timeout.
NOTE: this test can become flakey if:
- exporting timeout.py takes less time than the timeout
- the timeout kills the process before it has a chance to create a stats.yaml file
"""

testargs = [
