diff --git a/models/turbine_models/custom_models/stateless_llama.py b/models/turbine_models/custom_models/stateless_llama.py
index c907a3a2c..918a03dd7 100644
--- a/models/turbine_models/custom_models/stateless_llama.py
+++ b/models/turbine_models/custom_models/stateless_llama.py
@@ -128,8 +128,10 @@ def export_transformer_model(
     )
     # TODO: generate these values instead of magic numbers
     NUM_LAYERS = mod.config.num_hidden_layers
-    HEADS = mod.config.num_attention_heads
-    HIDDEN_DIM = int(mod.config.hidden_size / HEADS)
+    HEADS = getattr(mod.config, "num_key_value_heads", None)
+    if HEADS is None:
+        HEADS = mod.config.num_attention_heads
+    HIDDEN_DIM = int(mod.config.hidden_size / mod.config.num_attention_heads)
     BATCH_SIZE = 1
     MAX_STEP_SEQ = mod.config.max_position_embeddings - 1
     global_pkv = torch.zeros(
diff --git a/models/turbine_models/gen_external_params/gen_external_params.py b/models/turbine_models/gen_external_params/gen_external_params.py
index cff729208..7539d5f10 100644
--- a/models/turbine_models/gen_external_params/gen_external_params.py
+++ b/models/turbine_models/gen_external_params/gen_external_params.py
@@ -127,7 +127,6 @@ def gen_external_params(
         auto_model=AutoModelForCausalLM,
         hf_auth_token=hf_auth_token,
     )
-    model_builder.build_model()
     if precision == "f16":
         model = model_builder.model.half()
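
Note on the first hunk: for grouped-query-attention (GQA) models the KV cache holds num_key_value_heads heads, not num_attention_heads, while the per-head dimension is still derived from the query head count; the getattr fallback keeps older multi-head-attention configs (which lack num_key_value_heads) working. A minimal standalone sketch of the patched shape math, using illustrative config values (not taken from this diff) and SimpleNamespace as a stand-in for a HuggingFace model config:

    from types import SimpleNamespace

    # Illustrative, roughly Llama-2-70B-like values; assumptions, not from the patch.
    config = SimpleNamespace(
        num_hidden_layers=80,
        num_attention_heads=64,
        num_key_value_heads=8,  # GQA: fewer KV heads than query heads
        hidden_size=8192,
        max_position_embeddings=4096,
    )

    # Mirrors the patched logic in export_transformer_model():
    heads = getattr(config, "num_key_value_heads", None)
    if heads is None:  # plain-MHA configs have no num_key_value_heads
        heads = config.num_attention_heads
    # Per-head dim is always hidden_size divided by the *query* head count.
    hidden_dim = config.hidden_size // config.num_attention_heads

    print(heads, hidden_dim)  # 8 512 -- KV cache is 8x smaller than with 64 heads
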