Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GPT-J support reuse_cache #1094

Merged
merged 3 commits into from
Aug 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/text-generation/run_lm_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def __init__(self, tokenizer, model, args, options):
self.options = options
self._device = args.device
self.model_inputs = {"use_cache": self.options.use_cache}
if self.model.config.model_type in ["llama", "mistral", "falcon", "phi", "mixtral", "qwen2"]:
if self.model.config.model_type in ["llama", "mistral", "falcon", "phi", "mixtral", "qwen2", "gptj"]:
self.model_inputs.update(
{
"reuse_cache": self.options.reuse_cache,
Expand Down
5 changes: 3 additions & 2 deletions optimum/habana/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -858,7 +858,8 @@ def generate(
"mixtral",
"phi",
"qwen2",
], "reuse_cache only supported by llama, mistral, falcon, mixtral, phi and qwen2 at the moment"
"gptj",
], "reuse_cache only supported by llama, mistral, falcon, mixtral, phi, qwen2 and gptj at the moment"
if not generation_config.bucket_internal:
assert (
generation_config.bucket_size <= 0
Expand Down Expand Up @@ -1014,7 +1015,7 @@ def generate(
model_kwargs["kv_cache_len"] = calculated_max_length
model_kwargs["kv_cache_pad_len"] = generation_config.max_new_tokens

if self.config.model_type in ["llama", "falcon", "mistral", "qwen2"]:
if self.config.model_type in ["llama", "falcon", "mistral", "qwen2", "gptj"]:
if self.config.max_position_embeddings < calculated_max_length:
unwrap_deepspeed_model(self).update_sincos_cache(seq_len=calculated_max_length)

Expand Down
4 changes: 2 additions & 2 deletions optimum/habana/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
GaudiGPTJAttention,
GaudiGPTJBlock,
GaudiGPTJForCausalLM,
GaudiGPTJModel,
GaudiGPTNeoXForCausalLM,
GaudiLlamaAttention,
GaudiLlamaDecoderLayer,
Expand Down Expand Up @@ -137,7 +138,6 @@
gaudi_gpt_neox_layer_forward,
gaudi_gpt_neox_model_forward,
gaudi_gpt_neox_rotary_embedding_set_cos_sin_cache,
gaudi_gptj_model_forward,
gaudi_invert_attention_mask,
gaudi_llama_rmsnorm_forward,
gaudi_MambaForCausalLM_prepare_inputs_for_generation,
Expand Down Expand Up @@ -341,7 +341,7 @@ def adapt_transformers_to_gaudi():
transformers.models.gptj.modeling_gptj.GPTJAttention = GaudiGPTJAttention
transformers.models.gptj.modeling_gptj.GPTJForCausalLM = GaudiGPTJForCausalLM
transformers.models.gptj.modeling_gptj.GPTJBlock = GaudiGPTJBlock
transformers.models.gptj.modeling_gptj.GPTJModel.forward = gaudi_gptj_model_forward
transformers.models.gptj.modeling_gptj.GPTJModel = GaudiGPTJModel

# Optimization for GPTBigCode on Gaudi
transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeAttention.forward = (
Expand Down
2 changes: 1 addition & 1 deletion optimum/habana/transformers/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@
GaudiGPTJAttention,
GaudiGPTJBlock,
GaudiGPTJForCausalLM,
gaudi_gptj_model_forward,
GaudiGPTJModel,
)
from .llama import (
GaudiLlamaAttention,
Expand Down
4 changes: 3 additions & 1 deletion optimum/habana/transformers/models/gptj/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from transformers.models.gptj.configuration_gptj import GPTJConfig
atakaha marked this conversation as resolved.
Show resolved Hide resolved

from .modeling_gptj import (
GaudiGPTJAttention,
GaudiGPTJBlock,
GaudiGPTJForCausalLM,
gaudi_gptj_model_forward,
GaudiGPTJModel,
)
Loading
Loading