From c3bbed86fbd651ab8c03d1523adc898e8065dcdd Mon Sep 17 00:00:00 2001
From: Stanley Winata
Date: Thu, 1 Feb 2024 19:10:58 -0600
Subject: [PATCH] Query num_key_value_heads when available to support GQA models.

Models that implement GQA do not use the same number of heads for K/V as
for the query, so we need to query the `num_key_value_heads` attribute to
determine whether the KV cache needs a different head count. Not all
models expose `num_key_value_heads` in their config (e.g. QWEN), and Phi
has it but sets it to null, hence the getattr-plus-None-check structure
to handle both cases.
---
 models/turbine_models/custom_models/stateless_llama.py | 6 ++++--
 .../gen_external_params/gen_external_params.py         | 1 -
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/models/turbine_models/custom_models/stateless_llama.py b/models/turbine_models/custom_models/stateless_llama.py
index c907a3a2c..918a03dd7 100644
--- a/models/turbine_models/custom_models/stateless_llama.py
+++ b/models/turbine_models/custom_models/stateless_llama.py
@@ -128,8 +128,10 @@ def export_transformer_model(
     )
     # TODO: generate these values instead of magic numbers
     NUM_LAYERS = mod.config.num_hidden_layers
-    HEADS = mod.config.num_attention_heads
-    HIDDEN_DIM = int(mod.config.hidden_size / HEADS)
+    HEADS = getattr(mod.config, "num_key_value_heads", None)
+    if HEADS is None:
+        HEADS = mod.config.num_attention_heads
+    HIDDEN_DIM = int(mod.config.hidden_size / mod.config.num_attention_heads)
     BATCH_SIZE = 1
     MAX_STEP_SEQ = mod.config.max_position_embeddings - 1
     global_pkv = torch.zeros(
diff --git a/models/turbine_models/gen_external_params/gen_external_params.py b/models/turbine_models/gen_external_params/gen_external_params.py
index cff729208..7539d5f10 100644
--- a/models/turbine_models/gen_external_params/gen_external_params.py
+++ b/models/turbine_models/gen_external_params/gen_external_params.py
@@ -127,7 +127,6 @@ def gen_external_params(
         auto_model=AutoModelForCausalLM,
         hf_auth_token=hf_auth_token,
     )
-    model_builder.build_model()
 
     if precision == "f16":
         model = model_builder.model.half()
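
For illustration only (not part of the patch), below is a minimal sketch of the same
num_key_value_heads fallback applied to a standalone Hugging Face config. The LlamaConfig
values and the KV-cache tensor layout are assumptions chosen for the example; the actual
global_pkv shape built in stateless_llama.py is not reproduced here.

    # Minimal sketch of the num_key_value_heads fallback; all numbers are illustrative.
    import torch
    from transformers import LlamaConfig

    # A GQA-style config: 32 query heads sharing 8 key/value heads (hypothetical values).
    config = LlamaConfig(
        hidden_size=4096,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=8,
        max_position_embeddings=4096,
    )

    num_layers = config.num_hidden_layers
    # Some configs omit num_key_value_heads (e.g. QWEN) or set it to null (Phi),
    # so fall back to num_attention_heads when it is missing or None.
    kv_heads = getattr(config, "num_key_value_heads", None)
    if kv_heads is None:
        kv_heads = config.num_attention_heads
    # The per-head dimension is always derived from the query head count.
    head_dim = config.hidden_size // config.num_attention_heads

    batch_size = 1
    max_step_seq = config.max_position_embeddings - 1

    # Hypothetical cache layout: one K and one V slab per layer, sized by kv_heads.
    kv_cache = torch.zeros(num_layers * 2, batch_size, max_step_seq, kv_heads, head_dim)
    print(kv_cache.shape)  # torch.Size([64, 1, 4095, 8, 128])

With GQA the cache is allocated with kv_heads rather than the query head count, which is
the size reduction the patch enables, while head_dim still comes from num_attention_heads.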