Commit 258a77e: add

plusbang committed Nov 7, 2024
1 parent a7b6668
Showing 4 changed files with 127 additions and 28 deletions.
16 changes: 10 additions & 6 deletions python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py
@@ -124,11 +124,15 @@ def __init__(
         if self.cached_cos is None:
             if mode == "prefill":
                 position_ids = self.create_input_op((self.batch_size, self.seq_len), dtype=np.int64)
-                self.cos = self.create_input_op((self.batch_size, self.cos_len, self.head_dim))
-                self.sin = self.create_input_op((self.batch_size, self.cos_len, self.head_dim))
+                cos = self.create_input_op((self.batch_size, self.cos_len, self.head_dim), dtype=np.float32)
+                self.cos = self.convert_to_fp16(cos)
+                sin = self.create_input_op((self.batch_size, self.cos_len, self.head_dim), dtype=np.float32)
+                self.sin = self.convert_to_fp16(sin)
             else:
-                self.cos = self.create_input_op((self.batch_size, self.cos_len, self.head_dim))
-                self.sin = self.create_input_op((self.batch_size, self.cos_len, self.head_dim))
+                cos = self.create_input_op((self.batch_size, self.cos_len, self.head_dim), dtype=np.float32)
+                self.cos = self.convert_to_fp16(cos)
+                sin = self.create_input_op((self.batch_size, self.cos_len, self.head_dim), dtype=np.float32)
+                self.sin = self.convert_to_fp16(sin)
         else:
             position_ids = self.create_input_op((self.batch_size, self.seq_len), dtype=np.int64)
             cos = self.constant(self.cached_cos)
@@ -367,7 +371,7 @@ def forward(
         )

         if self.cached_cos is None:
-            inputs += (cos.to(torch.float16), sin.to(torch.float16))
+            inputs += (cos.to(torch.float32), sin.to(torch.float32))
         else:
             inputs += (position_ids.to(torch.int64),)

@@ -496,7 +500,7 @@ def forward(
                       attention_mask.to(torch.int64),
                       position_ids.to(torch.int64))
             if self.cached_cos is None:
-                inputs += (cos.to(torch.float16), sin.to(torch.float16),)
+                inputs += (cos.to(torch.float32), sin.to(torch.float32),)
             inputs += (self.layer_norm_0, self.layer_norm_1)
             hidden_states, past_key, past_value = run_model(
                 inputs, self.op_parameters, backend_cls, self.op_id, replica=2
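A reader's note on the llama_mp.py hunks above: when there is no cached cos/sin table, the rotary cos/sin tensors are now declared as float32 inputs of the NPU graph and narrowed to fp16 inside the graph via convert_to_fp16, so the host side keeps them in float32 instead of casting before the call. A minimal sketch of that host-side dtype contract, with hypothetical shapes and assuming nothing beyond what the diff shows:

    import torch

    batch_size, cos_len, head_dim = 1, 16, 64            # hypothetical sizes
    cos = torch.randn(batch_size, cos_len, head_dim)      # rotary output, fp32
    sin = torch.randn(batch_size, cos_len, head_dim)

    # before this commit the host cast to fp16; now the graph input is fp32
    # and the fp16 conversion happens inside the compiled graph
    inputs_old = (cos.to(torch.float16), sin.to(torch.float16))
    inputs_new = (cos.to(torch.float32), sin.to(torch.float32))
    assert all(t.dtype == torch.float32 for t in inputs_new)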
@@ -54,8 +54,7 @@ def run_model(

     # Reshape input
     input_dtype = x[0].dtype
-    x_np = [set_contiguous(elem).numpy() if elem.dtype == torch.int64 else
-            set_contiguous(elem).to(torch.float16).numpy() for elem in x]
+    x_np = [set_contiguous(elem).numpy() for elem in x]
     op_args = []
     op_args_flatten = []
     for w in weights:
@@ -651,8 +650,7 @@ def set_weights_async(self, op_id, weights):

     @staticmethod
     def run_decoders(inputs, decoders, models_ptr=None):
-        x_np = [elem.numpy() if elem.dtype == torch.int64 else
-                elem.to(torch.float16).numpy() for elem in inputs]
+        x_np = [elem.numpy() for elem in inputs]

         num_decoders = len(decoders)
         num_inputs = len(x_np)
@@ -233,8 +233,12 @@ def convert_llm(model: torch.nn.Module,
     model.num_layers = layer_num
     model.transpose_value_cache = transpose_value_cache

+    if hasattr(model.model.layers[0].self_attn.rotary_emb, "cos_cached"):
+        model_type = "llama"
+    else:
+        model_type = "llama_32"
     try:
-        res = InitLLMPipeline("llama", kv_len, model.num_head, model.head_dim, layer_num,
+        res = InitLLMPipeline(model_type, kv_len, model.num_head, model.head_dim, layer_num,
                               model.vocab_size, weight_dir, "model",
                               first_blob_path, last_blob_path,
                               os.path.join(temp_dir, "decoder_layer"), layernorm_const)
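For context on the pipeline dispatch added above: newer Llama checkpoints such as Llama 3.2 no longer expose a per-layer rotary_emb.cos_cached buffer, so the pipeline variant is picked from that attribute. A minimal restatement of the rule (hypothetical helper name, not part of the commit):

    def detect_pipeline_type(model) -> str:
        # Llama-2-7B / Llama-3-8B still carry cached cos/sin on each attention layer;
        # Llama-3.2-1B / 3B compute cos/sin at runtime, hence the "llama_32" path.
        rotary_emb = model.model.layers[0].self_attn.rotary_emb
        return "llama" if hasattr(rotary_emb, "cos_cached") else "llama_32"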
127 changes: 110 additions & 17 deletions python/llm/src/ipex_llm/transformers/npu_pipeline_model/llama.py
@@ -19,6 +19,68 @@
 import numpy as np
 import os
 from .common import update_names_of_IR_and_export_blob, LLMEmbedding, LowBitLLMLMHead
+from intel_npu_acceleration_library.backend.factory import NNFactory
+
+
+class Llama32Embedding(NNFactory):
+    def __init__(
+        self,
+        vocab_size,
+        embedding_dim,
+        embedding_weight,
+        padding_idx,
+        inv_freq,
+        attention_scaling,
+        dtype,  # fp16
+        device: str = "NPU",
+    ):
+        super().__init__(False, device)
+        self.vocab_size = vocab_size
+        self.embedding_dim = embedding_dim
+        self.padding_idx = padding_idx
+        self.attention_scaling = attention_scaling
+        self.dtype = dtype
+
+        # define input
+        weight = self.constant(embedding_weight)
+        input = self.parameter((1, 1), dtype=np.int32)
+        position_ids = self.parameter((1, 1), dtype=np.int64)
+        inv_freq = self.constant(inv_freq)
+
+        # embed_tokens module
+        if padding_idx == -1:
+            padding_idx += vocab_size
+
+        axis_node = self.constant(np.array([0], dtype=np.int64))
+        if padding_idx is not None:
+            masked_embeddings = np.ones(weight.shape, dtype=np.float16)
+            masked_embeddings[padding_idx, :] = 0.0  # mask
+
+            node_mask = self.constant(masked_embeddings)
+            node_masked_w = self.eltwise_mul(weight, node_mask)
+            res = self.gather(node_masked_w, input, axis_node, 0)
+        else:
+            res = self.gather(weight, input, axis_node, 0)
+
+        # rotary_emb module
+        inv_freq = self.reshape(inv_freq, (1, inv_freq.shape[0], 1))
+        position_ids = self.reshape(position_ids, (1, 1, 1))
+        freqs = self.eltwise_mul(self.convert_to_fp32(inv_freq),
+                                 self.convert_to_fp32(position_ids))
+        freqs = self.transpose(freqs, [0, 2, 1])
+        emb = self.concat(freqs, freqs, axis=2)
+        cos = self.cos(emb)
+        sin = self.sin(emb)
+        cos = cos * self.attention_scaling
+        sin = sin * self.attention_scaling
+
+        # define outputs
+        res = self.convert_to_fp16(res)
+        cos = self.convert_to_fp32(cos)
+        sin = self.convert_to_fp32(sin)
+
+        print("start compiling")
+        self.compile()
+
+
 def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir):
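As a sanity check on the Llama32Embedding graph above, its rotary part follows the usual Llama 3.2 recipe: freqs = inv_freq * position_id, duplicated along the last axis, then cos/sin scaled by attention_scaling. A plain-PyTorch reference sketch (an illustration, not code from this commit; the 500000.0 rope base, head_dim=64 and attention_scaling=1.0 are assumptions):

    import torch

    def rotary_cos_sin_reference(inv_freq, position_ids, attention_scaling):
        # inv_freq: (head_dim // 2,), position_ids: (1, 1) -- the shapes the graph assumes
        freqs = inv_freq.float().reshape(1, -1, 1) * position_ids.float().reshape(1, 1, 1)
        freqs = freqs.transpose(1, 2)               # (1, 1, head_dim // 2)
        emb = torch.cat((freqs, freqs), dim=-1)     # (1, 1, head_dim)
        return emb.cos() * attention_scaling, emb.sin() * attention_scaling

    inv_freq = 1.0 / (500000.0 ** (torch.arange(0, 64, 2).float() / 64))
    cos, sin = rotary_cos_sin_reference(inv_freq, torch.tensor([[5]]), attention_scaling=1.0)
    print(cos.shape, sin.shape)  # torch.Size([1, 1, 64]) torch.Size([1, 1, 64])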
@@ -71,14 +133,27 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir):
         bin_file = os.path.join(weight_dir, f"model_lm_head_input_{1+idx}.bin")
         weight.tofile(bin_file)

-    embedding_layer = model.model.embed_tokens
-    new_embedding = LLMEmbedding(
-        vocab_size=model.config.vocab_size,
-        embedding_dim=model.config.hidden_size,
-        embedding_weight=embedding_layer.weight.to(torch.float16).detach().numpy(),
-        padding_idx=model.config.pad_token_id,
-        dtype=np.float16,
-    )
+    if hasattr(model.model.layers[0].self_attn.rotary_emb, "cos_cached"):
+        # llama-2-7B & llama-3-8B
+        embedding_layer = model.model.embed_tokens
+        new_embedding = LLMEmbedding(
+            vocab_size=model.config.vocab_size,
+            embedding_dim=model.config.hidden_size,
+            embedding_weight=embedding_layer.weight.to(torch.float16).detach().numpy(),
+            padding_idx=model.config.pad_token_id,
+            dtype=np.float16,
+        )
+    else:
+        # llama-3.2-3B & llama-3.2-1B
+        new_embedding = Llama32Embedding(
+            vocab_size=model.config.vocab_size,
+            embedding_dim=model.config.hidden_size,
+            embedding_weight=model.model.embed_tokens.weight.to(torch.float16).detach().numpy(),
+            padding_idx=model.config.pad_token_id,
+            inv_freq=model.model.rotary_emb.inv_freq.to(torch.float16),
+            attention_scaling=model.model.rotary_emb.attention_scaling,
+            dtype=np.float16,
+        )
     first_blob_path = update_names_of_IR_and_export_blob(new_embedding, "embedding",
                                                          temp_dir)
     return first_blob_path, last_blob_path
@@ -135,8 +210,14 @@ def convert_llama_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
             scales.append(l.scale)
         weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

-    cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
-    cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
+    if hasattr(curr_layer.self_attn.rotary_emb, "cos_cached"):
+        # llama-2-7B & llama-3-8B
+        cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
+        cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
+    else:
+        # llama-3.2-3B & llama-3.2-1B
+        cached_cos = None
+        cached_sin = None
     layer_norm_0 = curr_layer.input_layernorm.weight.to(torch.float16)
     layer_norm_1 = curr_layer.post_attention_layernorm.weight.to(torch.float16)

@@ -168,14 +249,26 @@ def convert_llama_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
                                               f"decoder_layer_{layer_idx}",
                                               temp_dir)

-    if layernorm_const:
-        st_idx = 5
+    if hasattr(curr_layer.self_attn.rotary_emb, "cos_cached"):
+        # llama-2-7B & llama-3-8B
+        if layernorm_const:
+            st_idx = 5
+        else:
+            input_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_3.bin")
+            post_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_4.bin")
+            layer_norm_0.data.numpy().tofile(input_lm_bin_file)
+            layer_norm_1.data.numpy().tofile(post_lm_bin_file)
+            st_idx = 7
     else:
-        input_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_3.bin")
-        post_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_4.bin")
-        layer_norm_0.data.numpy().tofile(input_lm_bin_file)
-        layer_norm_1.data.numpy().tofile(post_lm_bin_file)
-        st_idx = 7
+        # llama-3.2-3B & llama-3.2-1B
+        if layernorm_const:
+            st_idx = 6
+        else:
+            input_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_4.bin")
+            post_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_5.bin")
+            layer_norm_0.data.numpy().tofile(input_lm_bin_file)
+            layer_norm_1.data.numpy().tofile(post_lm_bin_file)
+            st_idx = 8
     for idx, (weight, scale) in enumerate(weights):
         bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*2}.bin")
         weight.numpy().tofile(bin_file)
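One detail worth calling out in the weight-export hunk above: without cached cos/sin, the decoder layer receives cos and sin as runtime inputs in place of the single position_ids input, a net gain of one input slot. The layernorm bins therefore move from input_3/input_4 to input_4/input_5 and the quantized-weight start index shifts from 5/7 to 6/8. A hedged restatement of that indexing rule, based only on this diff:

    def weight_start_index(has_cached_cos: bool, layernorm_const: bool) -> int:
        # cached cos/sin (llama-2-7B & llama-3-8B): weights start at input 5 or 7;
        # runtime cos/sin (llama-3.2-1B & 3B): one extra input slot, so 6 or 8.
        if has_cached_cos:
            return 5 if layernorm_const else 7
        return 6 if layernorm_const else 8

    print(weight_start_index(True, True), weight_start_index(False, False))  # 5 8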
