[Llama] Use rocm ukernel when available + use num_layer for pkv. #381

Merged 1 commit on Feb 1, 2024
python/turbine_models/custom_models/stateless_llama.py (51 changes: 30 additions & 21 deletions)
@@ -1,6 +1,7 @@
 import os
 import sys
 import re
+import json
 
 os.environ["TORCH_LOGS"] = "dynamic"
 from transformers import AutoTokenizer, AutoModelForCausalLM
@@ -61,19 +62,26 @@
help="Compile LLM with StreamingLLM optimizations",
)

# TODO (Dan): replace this with a file once I figure out paths on windows exe
json_schema_64 = """
[1, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, 
{"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}]}]
"""

json_schema_16 = """
[1, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}]}]
"""
+def generate_schema(num_layers):
+    null = None
+    schema = [1, {"type": "builtins.tuple", "context": "null", "children_spec": []}]
+    kv_schema_per_layer = {
+        "type": "builtins.tuple",
+        "context": "null",
+        "children_spec": [
+            {"type": null, "context": null, "children_spec": []},
+            {"type": null, "context": null, "children_spec": []},
+        ],
+    }
+    for i in range(num_layers):
+        schema[1]["children_spec"].append(kv_schema_per_layer)
+    return json.dumps(schema)


-def slice_up_to_step(global_pkv, seq_step, heads, hidden_dim):
+def slice_up_to_step(global_pkv, seq_step, heads, hidden_dim, num_layers):
     all_pkv_tensors = []
-    for i in range(heads * 2):
+    for i in range(num_layers * 2):
         # Numpy semantic: sliced = global_pkv[i, 0, 0:seq_step, 0:heads, 0:hidden_dim]
         # Generates tensor<1 x 1 x seq_step x heads x hidden_dim>
         sliced = IREE.tensor_slice(
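
Context for the change above: the deleted json_schema_16/json_schema_64 strings hardcoded the pytree spec for 8- and 32-layer models only, while generate_schema builds the same spec for any layer count. A minimal sanity check, illustrative and not part of this diff, that the generated string round-trips through torch's pytree serializer:

```python
# Illustrative check (not in this diff): an 8-layer schema should load into
# a treespec with 16 leaves, one key and one value tensor per layer.
from torch.utils import _pytree as pytree

state_schema = pytree.treespec_loads(generate_schema(8))
assert state_schema.num_leaves == 16
```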
@@ -105,10 +113,8 @@ def export_transformer_model(
         torch_dtype=torch.float,
         token=hf_auth_token,
     )
-    if mod.config.num_attention_heads == 8:
-        state_schema = pytree.treespec_loads(json_schema_16)
-    else:
-        state_schema = pytree.treespec_loads(json_schema_64)
+    schema_json = generate_schema(mod.config.num_hidden_layers)
+    state_schema = pytree.treespec_loads(schema_json)
     if streaming_llm:
         enable_llama_pos_shift_attention(mod)
     dtype = torch.float32
@@ -121,12 +127,13 @@
         token=hf_auth_token,
     )
     # TODO: generate these values instead of magic numbers
+    NUM_LAYERS = mod.config.num_hidden_layers
     HEADS = mod.config.num_attention_heads
     HIDDEN_DIM = int(mod.config.hidden_size / HEADS)
     BATCH_SIZE = 1
     MAX_STEP_SEQ = mod.config.max_position_embeddings - 1
     global_pkv = torch.zeros(
-        size=(HEADS * 2, BATCH_SIZE, MAX_STEP_SEQ, HEADS, HIDDEN_DIM),
+        size=(NUM_LAYERS * 2, BATCH_SIZE, MAX_STEP_SEQ, HEADS, HIDDEN_DIM),
         dtype=dtype,
     )
 
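For a sense of scale, the cache buffer is now keyed by layer count rather than head count; the old HEADS * 2 sizing only happened to work for models whose head count equaled their layer count. A worked sketch, using Llama-2-7B config values as an assumption for illustration:

```python
# Llama-2-7B config values, assumed here purely for illustration.
NUM_LAYERS = 32                 # mod.config.num_hidden_layers
HEADS = 32                      # mod.config.num_attention_heads
HIDDEN_DIM = 4096 // HEADS      # mod.config.hidden_size / HEADS = 128
MAX_STEP_SEQ = 4096 - 1         # mod.config.max_position_embeddings - 1
# global_pkv shape: (NUM_LAYERS * 2, 1, MAX_STEP_SEQ, HEADS, HIDDEN_DIM)
#                 = (64, 1, 4095, 32, 128)
```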
@@ -161,7 +168,7 @@ def run_initialize(self, x=AbstractTensor(BATCH_SIZE, None, dtype=torch.int64)):
             self.global_seq_step = IREE.tensor_dim(
                 state[0], 1
             )  # ? dimension of arbitrarily 0th kv tensor
-            for i in range(HEADS * 2):
+            for i in range(NUM_LAYERS * 2):
                 slice_of_state = IREE.tensor_reshape(
                     state[i], 1, 1, self.global_seq_step, HEADS, HIDDEN_DIM
                 )
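
Stripped of the IREE tracing API, the loop above writes each of the 2 * NUM_LAYERS prompt-phase kv tensors into its own plane of the global cache. A torch-level sketch of the same semantics (illustrative only; the traced code must stay in IREE tensor ops):

```python
# Torch-level sketch of run_initialize's cache write (not the traced code).
seq_len = state[0].shape[1]     # mirrors IREE.tensor_dim(state[0], 1)
for i in range(NUM_LAYERS * 2):
    update = state[i].reshape(1, 1, seq_len, HEADS, HIDDEN_DIM)
    global_pkv[i : i + 1, :, :seq_len] = update
```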
@@ -172,7 +179,7 @@ def run_initialize(self, x=AbstractTensor(BATCH_SIZE, None, dtype=torch.int64)):
 
         def run_forward(self, x=AbstractTensor(1, 1, dtype=torch.int64)):
             state_arg = slice_up_to_step(
-                self.global_state, self.global_seq_step, HEADS, HIDDEN_DIM
+                self.global_state, self.global_seq_step, HEADS, HIDDEN_DIM, NUM_LAYERS
             )
             forw_const = (
                 [state_arg[0].dynamic_dim(1) < MAX_STEP_SEQ]
@@ -183,7 +190,7 @@ def run_forward(self, x=AbstractTensor(1, 1, dtype=torch.int64)):
                 + [x.dynamic_dim(1) < MAX_STEP_SEQ for x in state_arg[1:]]
             )
             token, *state_update = self.forward(x, *state_arg, constraints=forw_const)
-            for i in range(HEADS * 2):
+            for i in range(NUM_LAYERS * 2):
                 update = IREE.tensor_reshape(
                     state_update[i], 1, 1, 1, HEADS, HIDDEN_DIM
                 )
@@ -226,7 +233,7 @@ def run_cached_initialize(
             self, x=AbstractTensor(BATCH_SIZE, None, dtype=torch.int64)
         ):
             state_arg = slice_up_to_step(
-                self.global_state, self.global_seq_step, HEADS, HIDDEN_DIM
+                self.global_state, self.global_seq_step, HEADS, HIDDEN_DIM, NUM_LAYERS
             )
             forw_const = (
                 [x.dynamic_dim(1) < MAX_STEP_SEQ]
@@ -243,7 +250,7 @@
             len_of_new_tokens = IREE.tensor_dim(
                 state[0], 1
             )  # ? dimension of arbitrarily 0th kv tensor
-            for i in range(HEADS * 2):
+            for i in range(NUM_LAYERS * 2):
                 slice_of_state = IREE.tensor_reshape(
                     state[i], 1, 1, len_of_new_tokens, HEADS, HIDDEN_DIM
                 )
@@ -278,7 +285,7 @@ def evict_kvcache_space(self):
             sink_size = 4
             window_size = 252
             most_recent_window = self.global_seq_step + (-window_size)
-            for i in range(HEADS * 2):
+            for i in range(NUM_LAYERS * 2):
                 update_window_state = IREE.tensor_slice(
                     self.global_state,
                     i,
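
The eviction above implements the StreamingLLM policy: keep the first sink_size attention-sink tokens plus the most recent window_size tokens of every kv plane. Worked numbers, with the current step count chosen hypothetically:

```python
# Hypothetical cache state to show the index arithmetic.
sink_size, window_size = 4, 252
global_seq_step = 300
most_recent_window = global_seq_step + (-window_size)   # 48
# kept: positions [0:4] (sink) and [48:300] (recent window),
# compacted to sink_size + window_size = 256 live positions
```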
@@ -339,12 +346,14 @@ def evict_kvcache_space(self):
             [
                 "--iree-rocm-target-chip=" + target_triple,
                 "--iree-rocm-link-bc=true",
-                "--iree-rocm-bc-dir=/opt/rocm/amdgcn/bitcode",
                 "--iree-vm-bytecode-module-strip-source-map=true",
                 "--iree-opt-strip-assertions=true",
                 "--iree-vm-target-truncate-unsupported-floats",
             ]
         )
+        ukernel_supported_arch = {"gfx90a", "gfx940", "gfx1030", "gfx1100"}
+        if target_triple in ukernel_supported_arch:
+            flags.extend(["--iree-rocm-enable-ukernels=argmax"])
     elif device == "cuda":
         flags.extend(
             [
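
These flag lists are eventually handed to the IREE compiler when the module is built. A hedged sketch of a typical invocation, assuming the public iree.compiler.compile_str entry point; module_str and the result name are illustrative, since the actual call site is elsewhere in this file:

```python
from iree.compiler import compile_str

# Sketch only: compile the exported MLIR with the flags assembled above.
flatbuffer_blob = compile_str(
    module_str,                 # MLIR text from the export step
    target_backends=["rocm"],
    extra_args=flags,           # includes --iree-rocm-enable-ukernels=argmax
)                               # when target_triple supports ukernels
```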