Skip to content

Commit

Permalink
Add Granite to model builder (#1153)
Browse files Browse the repository at this point in the history
### Description

This PR adds [IBM's
Granite](https://www.ibm.com/new/ibm-granite-3-0-open-state-of-the-art-enterprise-models)
models to the model builder. It also adds the following improvements to
the model builder:

1. Always unpack any packed weights in the attention and MLP layers
2. Insert optional Add nodes if the MLP layer has a bias (the [Granite
code
models](https://huggingface.co/collections/ibm-granite/granite-code-models-6624c5cec322e4c148c8b330)
use the LLaMA architecture but with biases included)
3. Use `gate_up_proj` or `dense_h_to_4h` as the attribute name when
unpacking weights in the MLP layer

### Motivation and Context

Granite is a family of foundation models from IBM.
  • Loading branch information
kunal-vaishnavi authored Jan 9, 2025
1 parent 41c2543 commit 9143cfd
Show file tree
Hide file tree
Showing 5 changed files with 164 additions and 73 deletions.
16 changes: 8 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@ It implements the generative AI loop for ONNX models, including pre and post pro

See documentation at https://onnxruntime.ai/docs/genai.

|Support matrix|Supported now|Under development|On the roadmap|
|-|-|-|-|
|Model architectures| Gemma <br/> Llama * <br/> Mistral + <br/>Phi (language + vision)<br/>Qwen <br/>Nemotron <br/>|Whisper|Stable diffusion|
|API| Python <br/>C# <br/>C/C++ <br/> Java ^ |Objective-C||
|Platform| Linux <br/> Windows <br/>Mac ^ <br/>Android ^ ||iOS |||
|Architecture|x86 <br/> x64 <br/> Arm64 ~ ||||
|Hardware Acceleration|CUDA<br/>DirectML<br/>|QNN <br/> OpenVINO <br/> ROCm ||
|Features|| Interactive decoding <br/> Customization (fine-tuning)| Speculative decoding |
| Support matrix | Supported now | Under development | On the roadmap |
| -------------- | ------------- | ----------------- | -------------- |
| Model architectures | Gemma <br/> Llama * <br/> Mistral + <br/> Phi (language + vision) <br/> Qwen <br/> Nemotron <br/> Granite <br/> | Whisper | Stable diffusion |
| API | Python <br/> C# <br/> C/C++ <br/> Java ^ | Objective-C | |
| Platform | Linux <br/> Windows <br/> Mac ^ <br/> Android ^ | | iOS |
| Architecture | x86 <br/> x64 <br/> Arm64 ~ | | |
| Hardware Acceleration | CUDA <br/> DirectML <br/> | QNN <br/> OpenVINO <br/> ROCm | |
| Features | | Interactive decoding <br/> Customization (fine-tuning) | Speculative decoding |

\* The Llama model architecture supports similar model families such as CodeLlama, Vicuna, Yi, and more.

Expand Down
5 changes: 4 additions & 1 deletion src/models/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
//
// Modifications Copyright(C) 2024 Advanced Micro Devices, Inc. All rights reserved
#include <algorithm>
#include <set>
#include <string>
#include <thread>

#include "../generators.h"
Expand Down Expand Up @@ -580,9 +582,10 @@ std::shared_ptr<Model> CreateModel(OrtEnv& ort_env, const char* config_path, con
}

std::shared_ptr<Model> CreateModel(OrtEnv& ort_env, std::unique_ptr<Config> config) {
std::set<std::string> llm_types = {"chatglm", "decoder", "gemma", "gemma2", "granite", "llama", "mistral", "nemotron", "phi", "phimoe", "phi3", "phi3small", "qwen2"};
if (config->model.type == "gpt2")
return std::make_shared<Gpt_Model>(std::move(config), ort_env);
if (config->model.type == "llama" || config->model.type == "gemma" || config->model.type == "gemma2" || config->model.type == "mistral" || config->model.type == "phi" || config->model.type == "phi3" || config->model.type == "phi3small" || config->model.type == "phimoe" || config->model.type == "qwen2" || config->model.type == "nemotron" || config->model.type == "chatglm")
if (llm_types.find(config->model.type) != llm_types.end())
return std::make_shared<DecoderOnly_Model>(std::move(config), ort_env);
if (config->model.type == "whisper")
return std::make_shared<Whisper_Model>(std::move(config), ort_env);
Expand Down
1 change: 1 addition & 0 deletions src/python/py/models/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ The tool currently supports the following model architectures.

- ChatGLM
- Gemma
- Granite
- LLaMA
- Mistral
- Nemotron
Expand Down
Loading

0 comments on commit 9143cfd

Please sign in to comment.