Commit 2dfebda

Merge branch 'main' into fix_peft_ds_recursion

yafshar committed Oct 18, 2024
2 parents 1d7ae3a + f98688d
Showing 7 changed files with 27 additions and 22 deletions.
2 changes: 1 addition & 1 deletion examples/text-generation/README.md
@@ -576,7 +576,7 @@ python ../gaudi_spawn.py --use_deepspeed --world_size 8 run_generation.py \
 --attn_softmax_bf16 \
 --bucket_size=128 \
 --bucket_internal \
---batch_size 10 \
+--batch_size 8 \
 --max_input_tokens 40960 \
 --max_new_tokens 5120 \
 --use_flash_attention \
6 changes: 0 additions & 6 deletions examples/text-generation/run_generation.py
@@ -316,12 +316,6 @@ def setup_parser(parser):
         help="Run the inference with dataset for specified --n_iterations(default:5)",
     )

-    parser.add_argument(
-        "--run_partial_dataset",
-        action="store_true",
-        help="Run the inference with dataset for specified --n_iterations(default:5)",
-    )
-
     quant_parser_group = parser.add_mutually_exclusive_group()
     quant_parser_group.add_argument(
         "--load_quantized_model_with_autogptq",
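
The block removed above registered `--run_partial_dataset` with a help string identical to that of the argument kept just before it. argparse does not tolerate two registrations of the same option string, so a duplicated definition has to be dropped rather than left in place; a minimal, self-contained sketch of that failure mode (only the option name is taken from the diff, everything else is illustrative):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--run_partial_dataset", action="store_true")

# Registering the same option string a second time raises at
# parser-construction time, before any command line is even parsed.
try:
    parser.add_argument("--run_partial_dataset", action="store_true")
except argparse.ArgumentError as err:
    print(err)  # argument --run_partial_dataset: conflicting option string: --run_partial_dataset
```
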
19 changes: 11 additions & 8 deletions examples/text-generation/utils.py
@@ -177,7 +177,7 @@ def patch_scoped_linear_all_reduce(model):


 def get_torch_compiled_model(model):
-    if model.config.model_type in ["gpt_bigcode", "mpt", "bloom"]:
+    if model.config.model_type in ["gpt_bigcode", "mpt", "bloom", "gpt2"]:
         model.transformer = torch.compile(
             model.transformer, backend="hpu_backend", options={"keep_input_mutations": True}
         )
@@ -245,12 +245,14 @@ def setup_model(args, model_dtype, model_kwargs, logger):
             args.model_name_or_path, torch_dtype=model_dtype, quantization_config=quantization_config, **model_kwargs
         )
     elif args.load_quantized_model_with_inc:
-        #TODO: This will be removed in v1.19 Synapse release
-        #Override neural_compressor _load_remaining_pretrained_weight for the Transformer 4.45 release.
+        # TODO: This will be removed in v1.19 Synapse release
+        # Override neural_compressor _load_remaining_pretrained_weight for the Transformer 4.45 release.
         import neural_compressor.torch.algorithms.weight_only.save_load as nc_sl
+
         nc_sl.WOQModelLoader._load_remaining_pretrained_weight = local_load_remaining_pretrained_weight
+
         from neural_compressor.torch.quantization import load

         model = load(model_name_or_path=args.model_name_or_path, format="huggingface", device="hpu", **model_kwargs)
     elif args.local_quantized_inc_model_path:
         org_model = AutoModelForCausalLM.from_pretrained(
@@ -667,9 +669,10 @@ def initialize_model(args, logger):
     logger.info(f"Model initialization took {(init_end - init_start):.3f}s")
     return model, assistant_model, tokenizer, generation_config

-#TODO:This will be removed from Synapse v1.19 release.
-#This is to override _load_remaining_pretrained_weight for Transformer 4.45 release.
-def local_load_remaining_pretrained_weight(self,model):
+
+# TODO:This will be removed from Synapse v1.19 release.
+# This is to override _load_remaining_pretrained_weight for Transformer 4.45 release.
+def local_load_remaining_pretrained_weight(self, model):
     from transformers.modeling_utils import _load_state_dict_into_meta_model, load_state_dict

     resolved_archive_file = self.kwargs.pop("resolved_archive_file", None)
@@ -687,11 +690,11 @@ def local_load_remaining_pretrained_weight(self,model):
     for shard_file in resolved_archive_file:
         state_dict = load_state_dict(shard_file)

-        params_dict={
+        params_dict = {
             "model": model,
             "state_dict": state_dict,
             "start_prefix": "",
-            "expected_keys": list(state_dict.keys()),
+            "expected_keys": self.loaded_state_dict_keys,
             "device_map": {"": self.device},
             "offload_folder": offload_folder,
             "state_dict_folder": tempfile.mkdtemp() if offload_state_dict else None,
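
For readers unfamiliar with the override used above: assigning a plain function to a class attribute replaces that method for every instance, which is how `utils.py` swaps `WOQModelLoader._load_remaining_pretrained_weight` for `local_load_remaining_pretrained_weight` before calling `load()`. A self-contained sketch of the pattern, using hypothetical stand-in names (`Loader`, `patched_load`) rather than the real neural_compressor classes:

```python
class Loader:
    def load_remaining_weights(self):
        return "original behavior"


def patched_load(self):
    return "patched behavior"


# Method lookup goes through the class, so replacing the class attribute
# changes the behavior of existing and future instances alike.
Loader.load_remaining_weights = patched_load

print(Loader().load_remaining_weights())  # -> "patched behavior"
```
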
1 change: 0 additions & 1 deletion examples/text-to-speech/README.md
@@ -36,5 +36,4 @@ python3 run_pipeline.py \
 ```
 Models that have been validated:
 - [microsoft/speecht5_tts](https://huggingface.co/microsoft/speecht5_tts)
-- [facebook/hf-seamless-m4t-medium](https://huggingface.co/facebook/hf-seamless-m4t-medium)
 - [facebook/mms-tts-eng](https://huggingface.co/facebook/mms-tts-eng)
2 changes: 0 additions & 2 deletions optimum/habana/transformers/generation/utils.py
@@ -99,14 +99,12 @@
     "starcoder2",
     "persimmon",
     "qwen2",
-    "starcoder2",
     "llava",
     "llava_next",
     "stablelm",
     "mamba",
     "deci",
     "qwen2_moe",
-    "gemma",
     "whisper",
 ]

17 changes: 14 additions & 3 deletions optimum/habana/transformers/models/gemma/modeling_gemma.py
@@ -41,6 +41,7 @@
 from ...modeling_attn_mask_utils import (
     _gaudi_prepare_4d_causal_attention_mask,
 )
+from ...modeling_rope_utils import GaudiRotaryEmbedding


 try:
@@ -141,6 +142,7 @@ def __init__(self, config: GemmaConfig, layer_idx: Optional[int] = None):
         self.inp_seq_len = -1
         self.norm_factor = 1.0 / math.sqrt(self.head_dim)
         self.block_size = 4096
+        self.rotary_emb = GaudiRotaryEmbedding(config=self.config)

     def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len):
         cache_shape = (batch_size, self.num_key_value_heads, max_seq_len, self.head_dim)
@@ -155,7 +157,7 @@ def update_sincos_cache(self, seq_len):
         # reduce memory consumption and improve performance.
         if seq_len > self.max_position_embeddings:
             self.max_position_embeddings = seq_len
-            _, _ = self.rotary_emb(self.k_proj.weight, seq_len=seq_len)
+            self.rotary_emb._set_cos_sin_cache(seq_len, self.k_proj.weight.device, self.k_proj.weight.dtype)

     def reorder(self, tensor, beam_idx, dim_a, dim_b):
         updated = tensor.index_select(0, beam_idx)
@@ -252,8 +254,8 @@ def pre_attn_forward(
         else:
             kv_seq_len = past_key_value[0].shape[-2]

-        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos[position_ids], sin[position_ids])

         if use_cache:
             # reuse k, v, self_attention
@@ -697,6 +699,15 @@ def forward(


 class GaudiGemmaForCausalLM(GemmaForCausalLM):
+    def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len):
+        self.model.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len)
+
+    def reorder_kv_cache(self, beam_idx: torch.LongTensor):
+        return self.model.reorder_kv_cache(beam_idx)
+
+    def update_sincos_cache(self, seq_len):
+        self.model.update_sincos_cache(seq_len)
+
     def forward(
         self,
         input_ids: torch.LongTensor = None,
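
The Gemma changes above move sin/cos generation into a shared `GaudiRotaryEmbedding` (imported from `...modeling_rope_utils`), grow its cache explicitly through `_set_cos_sin_cache(seq_len, device, dtype)`, and index the returned tables per token with `cos[position_ids]` / `sin[position_ids]` before applying the rotation. The class itself is not part of this diff; the following is only a hedged sketch of what such a cached rotary embedding typically looks like, not the actual implementation:

```python
import torch


class RotaryEmbeddingSketch(torch.nn.Module):
    """Illustrative stand-in for a cached rotary embedding with a _set_cos_sin_cache hook."""

    def __init__(self, dim, max_position_embeddings=2048, base=10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._set_cos_sin_cache(max_position_embeddings, inv_freq.device, torch.float32)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        # Precompute cos/sin tables once; re-registering overwrites the old buffers.
        t = torch.arange(seq_len, device=device, dtype=torch.float32)
        freqs = torch.outer(t, self.inv_freq.to(device))
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # Grow the cache on demand, mirroring what update_sincos_cache() does above.
        if seq_len is not None and seq_len > self.cos_cached.shape[0]:
            self._set_cos_sin_cache(seq_len, x.device, x.dtype)
        return self.cos_cached[:seq_len], self.sin_cached[:seq_len]


rope = RotaryEmbeddingSketch(dim=64)
cos, sin = rope(torch.zeros(1, 8, 64), seq_len=8)  # each shaped [8, 64], ready for cos[position_ids] indexing
```
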
2 changes: 1 addition & 1 deletion optimum/habana/transformers/models/llama/modeling_llama.py
@@ -119,7 +119,7 @@ def __init__(
         self.rope_type = "default"
         self.max_seq_len_cached = config.max_position_embeddings
         # Truncate the cached max sequence length to 8k to limit cached register buffer size
-        if config.max_position_embeddings >= 8192:
+        if config.max_position_embeddings > 8192 and self.rope_type == "llama3":
             self.max_seq_len_cached = 8192
         self.original_max_seq_len = config.max_position_embeddings

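
The condition change above narrows the 8k truncation of the cached rotary buffer: previously any model with `max_position_embeddings >= 8192` was capped, now only the `llama3` rope type with more than 8192 positions is. A tiny sketch of how the cached length resolves under the new rule (the config values are illustrative):

```python
def resolve_cached_len(max_position_embeddings: int, rope_type: str) -> int:
    # Mirrors the condition shown in the diff above.
    max_seq_len_cached = max_position_embeddings
    if max_position_embeddings > 8192 and rope_type == "llama3":
        max_seq_len_cached = 8192
    return max_seq_len_cached


print(resolve_cached_len(131072, "llama3"))   # 8192: llama3 rope still truncates the cache
print(resolve_cached_len(131072, "default"))  # 131072: other rope types keep the full length
```
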
