diff --git a/examples/audio-classification/run_audio_classification.py b/examples/audio-classification/run_audio_classification.py
index b05e6dfb51..9a23428866 100644
--- a/examples/audio-classification/run_audio_classification.py
+++ b/examples/audio-classification/run_audio_classification.py
@@ -46,7 +46,7 @@ def check_optimum_habana_min_version(*a, **b):
 logger = logging.getLogger(__name__)

 # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks.
-check_min_version("4.43.0")
+check_min_version("4.45.0")
 check_optimum_habana_min_version("1.14.0.dev0")

 require_version("datasets>=1.14.0", "To fix: pip install -r examples/pytorch/audio-classification/requirements.txt")
diff --git a/examples/contrastive-image-text/run_bridgetower.py b/examples/contrastive-image-text/run_bridgetower.py
index 11ff5a55b0..42b9e8a468 100644
--- a/examples/contrastive-image-text/run_bridgetower.py
+++ b/examples/contrastive-image-text/run_bridgetower.py
@@ -56,7 +56,7 @@ def check_optimum_habana_min_version(*a, **b):
 logger = logging.getLogger(__name__)

 # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks.
-check_min_version("4.43.0")
+check_min_version("4.45.0")
 check_optimum_habana_min_version("1.14.0.dev0")

 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/contrastive-image-text/requirements.txt")
diff --git a/examples/contrastive-image-text/run_clip.py b/examples/contrastive-image-text/run_clip.py
index 941dade8f9..6a8ca235e1 100644
--- a/examples/contrastive-image-text/run_clip.py
+++ b/examples/contrastive-image-text/run_clip.py
@@ -61,7 +61,7 @@ def check_optimum_habana_min_version(*a, **b):
 logger = logging.getLogger(__name__)

 # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks.
-check_min_version("4.43.0")
+check_min_version("4.45.0")
 check_optimum_habana_min_version("1.14.0.dev0")

 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/contrastive-image-text/requirements.txt")
diff --git a/examples/image-classification/run_image_classification.py b/examples/image-classification/run_image_classification.py
index 4d2e229db1..b2694665a3 100644
--- a/examples/image-classification/run_image_classification.py
+++ b/examples/image-classification/run_image_classification.py
@@ -63,7 +63,7 @@ def check_optimum_habana_min_version(*a, **b):
 logger = logging.getLogger(__name__)

 # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks.
-check_min_version("4.43.0")
+check_min_version("4.45.0")
 check_optimum_habana_min_version("1.14.0.dev0")

 require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")
diff --git a/examples/language-modeling/run_clm.py b/examples/language-modeling/run_clm.py
index ec6b345d89..5a8d25b0ed 100644
--- a/examples/language-modeling/run_clm.py
+++ b/examples/language-modeling/run_clm.py
@@ -62,7 +62,7 @@ def check_optimum_habana_min_version(*a, **b):
 logger = logging.getLogger(__name__)

 # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks.
-check_min_version("4.43.0") +check_min_version("4.45.0") check_optimum_habana_min_version("1.14.0.dev0") require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") diff --git a/examples/language-modeling/run_mlm.py b/examples/language-modeling/run_mlm.py index 7a660447b8..30315bfc84 100644 --- a/examples/language-modeling/run_mlm.py +++ b/examples/language-modeling/run_mlm.py @@ -61,7 +61,7 @@ def check_optimum_habana_min_version(*a, **b): logger = logging.getLogger(__name__) # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks. -check_min_version("4.43.0") +check_min_version("4.45.0") check_optimum_habana_min_version("1.14.0.dev0") require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") diff --git a/examples/language-modeling/run_multitask_prompt_tuning.py b/examples/language-modeling/run_multitask_prompt_tuning.py index 48f9cefcb7..1d81bcc496 100644 --- a/examples/language-modeling/run_multitask_prompt_tuning.py +++ b/examples/language-modeling/run_multitask_prompt_tuning.py @@ -60,7 +60,7 @@ def check_optimum_habana_min_version(*a, **b): logger = logging.getLogger(__name__) # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risk. -check_min_version("4.43.0") +check_min_version("4.45.0") check_optimum_habana_min_version("1.14.0.dev0") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") diff --git a/examples/language-modeling/run_prompt_tuning_clm.py b/examples/language-modeling/run_prompt_tuning_clm.py index 2d2b9c4c3e..e263c0c1b6 100644 --- a/examples/language-modeling/run_prompt_tuning_clm.py +++ b/examples/language-modeling/run_prompt_tuning_clm.py @@ -62,7 +62,7 @@ def check_optimum_habana_min_version(*a, **b): logger = logging.getLogger(__name__) # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks. -check_min_version("4.43.0") +check_min_version("4.45.0") check_optimum_habana_min_version("1.14.0.dev0") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") diff --git a/examples/question-answering/run_qa.py b/examples/question-answering/run_qa.py index b983055f31..d22949c076 100644 --- a/examples/question-answering/run_qa.py +++ b/examples/question-answering/run_qa.py @@ -60,7 +60,7 @@ def check_optimum_habana_min_version(*a, **b): logger = logging.getLogger(__name__) # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks. -check_min_version("4.43.0") +check_min_version("4.45.0") check_optimum_habana_min_version("1.14.0.dev0") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt") diff --git a/examples/question-answering/run_seq2seq_qa.py b/examples/question-answering/run_seq2seq_qa.py index 8249e659a1..1f045552bd 100644 --- a/examples/question-answering/run_seq2seq_qa.py +++ b/examples/question-answering/run_seq2seq_qa.py @@ -56,7 +56,7 @@ def check_optimum_habana_min_version(*a, **b): logger = logging.getLogger(__name__) # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks. 
-check_min_version("4.43.0") +check_min_version("4.45.0") check_optimum_habana_min_version("1.14.0.dev0") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt") diff --git a/examples/speech-recognition/run_speech_recognition_ctc.py b/examples/speech-recognition/run_speech_recognition_ctc.py index c1367e0668..83865556d1 100644 --- a/examples/speech-recognition/run_speech_recognition_ctc.py +++ b/examples/speech-recognition/run_speech_recognition_ctc.py @@ -59,7 +59,7 @@ def check_optimum_habana_min_version(*a, **b): logger = logging.getLogger(__name__) # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks. -check_min_version("4.43.0") +check_min_version("4.45.0") check_optimum_habana_min_version("1.14.0.dev0") require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt") diff --git a/examples/speech-recognition/run_speech_recognition_seq2seq.py b/examples/speech-recognition/run_speech_recognition_seq2seq.py index e9abca3b92..4dcf0b498b 100755 --- a/examples/speech-recognition/run_speech_recognition_seq2seq.py +++ b/examples/speech-recognition/run_speech_recognition_seq2seq.py @@ -55,7 +55,7 @@ def check_optimum_habana_min_version(*a, **b): # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.43.0") +check_min_version("4.45.0") check_optimum_habana_min_version("1.14.0.dev0") require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt") @@ -580,7 +580,8 @@ def compute_metrics(pred): # save feature extractor, tokenizer and config feature_extractor.save_pretrained(training_args.output_dir) tokenizer.save_pretrained(training_args.output_dir) - config.save_pretrained(training_args.output_dir) + # TODO: uncomment the line below when this is fixed in Transformers + # config.save_pretrained(training_args.output_dir) processor = AutoProcessor.from_pretrained(training_args.output_dir) diff --git a/examples/stable-diffusion/unconditional_image_generation.py b/examples/stable-diffusion/unconditional_image_generation.py index df0575c0a7..baca71b6ba 100644 --- a/examples/stable-diffusion/unconditional_image_generation.py +++ b/examples/stable-diffusion/unconditional_image_generation.py @@ -19,7 +19,7 @@ def check_optimum_habana_min_version(*a, **b): return () -check_min_version("4.43.0") +check_min_version("4.45.0") check_optimum_habana_min_version("1.14.0.dev0") # Setup logging diff --git a/examples/summarization/run_summarization.py b/examples/summarization/run_summarization.py index ea5e002450..8715c4e75f 100755 --- a/examples/summarization/run_summarization.py +++ b/examples/summarization/run_summarization.py @@ -65,7 +65,7 @@ def check_optimum_habana_min_version(*a, **b): logger = logging.getLogger(__name__) # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks. 
-check_min_version("4.43.0") +check_min_version("4.45.0") check_optimum_habana_min_version("1.14.0.dev0") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt") diff --git a/examples/text-classification/run_glue.py b/examples/text-classification/run_glue.py index 9dfd2adcfc..57bf7cbb05 100755 --- a/examples/text-classification/run_glue.py +++ b/examples/text-classification/run_glue.py @@ -57,7 +57,7 @@ def check_optimum_habana_min_version(*a, **b): logger = logging.getLogger(__name__) # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks. -check_min_version("4.43.0") +check_min_version("4.45.0") check_optimum_habana_min_version("1.14.0.dev0") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt") diff --git a/examples/translation/run_translation.py b/examples/translation/run_translation.py index 8d13b39923..c2def132a7 100644 --- a/examples/translation/run_translation.py +++ b/examples/translation/run_translation.py @@ -62,7 +62,7 @@ def check_optimum_habana_min_version(*a, **b): logger = logging.getLogger(__name__) # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks. -check_min_version("4.43.0") +check_min_version("4.45.0") check_optimum_habana_min_version("1.14.0.dev0") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt") diff --git a/optimum/habana/transformers/generation/__init__.py b/optimum/habana/transformers/generation/__init__.py index 6b43ee2ae3..09f85a5451 100644 --- a/optimum/habana/transformers/generation/__init__.py +++ b/optimum/habana/transformers/generation/__init__.py @@ -3,7 +3,6 @@ from .stopping_criteria import ( gaudi_EosTokenCriteria_call, gaudi_MaxLengthCriteria_call, - gaudi_MaxNewTokensCriteria_call, gaudi_MaxTimeCriteria_call, gaudi_StoppingCriteriaList_call, ) diff --git a/optimum/habana/transformers/generation/candidate_generator.py b/optimum/habana/transformers/generation/candidate_generator.py index 171161074f..6688553459 100644 --- a/optimum/habana/transformers/generation/candidate_generator.py +++ b/optimum/habana/transformers/generation/candidate_generator.py @@ -8,8 +8,8 @@ if TYPE_CHECKING: + from transformers.generation.logits_process import LogitsProcessorList from transformers.modeling_utils import PreTrainedModel - from transfromers.generation.logits_process import LogitsProcessorList from .configuration_utils import GaudiGenerationConfig diff --git a/optimum/habana/transformers/generation/stopping_criteria.py b/optimum/habana/transformers/generation/stopping_criteria.py index dac7aadd92..69325ab7b3 100644 --- a/optimum/habana/transformers/generation/stopping_criteria.py +++ b/optimum/habana/transformers/generation/stopping_criteria.py @@ -52,18 +52,6 @@ def gaudi_MaxLengthCriteria_call( return create_return_const_tensor(input_ids, is_done) -def gaudi_MaxNewTokensCriteria_call( - self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs -) -> Union[torch.BoolTensor, bool]: - token_idx = kwargs.get("token_idx", None) - if token_idx is not None: - assert not kwargs["needs_tensor_output"] - return token_idx >= self.max_length - else: - is_done = input_ids.shape[-1] >= self.max_length - return create_return_const_tensor(input_ids, is_done) - - def gaudi_MaxTimeCriteria_call( self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs ) -> 
Union[torch.BoolTensor, bool]: diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index c22fedc57b..25c454b4d5 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -22,7 +22,7 @@ import torch import torch.distributed as dist -from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache, QuantizedCacheConfig +from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache, OffloadedCache, QuantizedCacheConfig from transformers.generation.beam_constraints import DisjunctiveConstraint, PhrasalConstraint from transformers.generation.beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer from transformers.generation.candidate_generator import ( @@ -32,8 +32,10 @@ _prepare_attention_mask, _prepare_token_type_ids, ) +from transformers.generation.configuration_utils import NEED_SETUP_CACHE_CLASSES_MAPPING, QUANT_BACKEND_CLASSES_MAPPING from transformers.generation.logits_process import LogitsProcessorList from transformers.generation.stopping_criteria import ( + ConfidenceCriteria, EosTokenCriteria, MaxLengthCriteria, MaxTimeCriteria, @@ -41,8 +43,6 @@ StopStringCriteria, ) from transformers.generation.utils import ( - NEED_SETUP_CACHE_CLASSES_MAPPING, - QUANT_BACKEND_CLASSES_MAPPING, GenerateBeamDecoderOnlyOutput, GenerateBeamEncoderDecoderOutput, GenerateBeamOutput, @@ -52,7 +52,6 @@ GenerateOutput, GenerationMixin, GenerationMode, - _ranking_fast, _split_model_inputs, _split_model_outputs, stack_model_outputs, @@ -291,6 +290,10 @@ def _expand_inputs_for_generation( Copied from Transformers: https://github.com/huggingface/transformers/blob/527ab894e59b6582578008e3b47648a65063f73d/src/transformers/generation/utils.py#L704 The tensor `token_idx` is not expanded. 
""" + # Do not call torch.repeat_interleave if expand_size is 1 because it clones + # the input tensor and thus requires more memory although no change is applied + if expand_size == 1: + return input_ids, model_kwargs def _expand_dict_for_generation(dict_to_expand): for key in dict_to_expand: @@ -365,7 +368,6 @@ def _update_model_kwargs_for_generation( outputs: ModelOutput, model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False, - standardize_cache_format: bool = False, num_new_tokens: int = 1, ) -> Dict[str, Any]: """ @@ -377,9 +379,7 @@ def _update_model_kwargs_for_generation( model_kwargs["first_token"] = False if not model_kwargs.get("pad_done", False): # update past_key_values keeping its naming used in model code - cache_name, cache = self._extract_past_from_model_output( - outputs, standardize_cache_format=standardize_cache_format - ) + cache_name, cache = self._extract_past_from_model_output(outputs) model_kwargs[cache_name] = cache if getattr(outputs, "state", None) is not None: model_kwargs["state"] = outputs.state @@ -521,6 +521,7 @@ def _get_candidate_generator( ) -> CandidateGenerator: if generation_config.prompt_lookup_num_tokens is not None: candidate_generator = PromptLookupCandidateGenerator( + eos_token_id=generation_config._eos_token_tensor, num_output_tokens=generation_config.prompt_lookup_num_tokens, max_matching_ngram_size=generation_config.max_matching_ngram_size, max_length=generation_config.max_length, @@ -564,6 +565,14 @@ def _get_stopping_criteria( criteria.append(StopStringCriteria(stop_strings=generation_config.stop_strings, tokenizer=tokenizer)) if not generation_config.ignore_eos and generation_config._eos_token_tensor is not None: criteria.append(EosTokenCriteria(eos_token_id=generation_config._eos_token_tensor)) + if ( + generation_config.is_assistant + and generation_config.assistant_confidence_threshold is not None + and generation_config.assistant_confidence_threshold > 0 + ): + criteria.append( + ConfidenceCriteria(assistant_confidence_threshold=generation_config.assistant_confidence_threshold) + ) criteria = self._merge_criteria_processor_list(criteria, stopping_criteria) return criteria @@ -641,42 +650,36 @@ def _prepare_generation_config( using_model_generation_config = False if generation_config is None: # legacy: users may modify the model configuration to control generation. To trigger this legacy behavior, - # three conditions must be met + # the following conditions must be met # 1) the generation config must have been created from the model config (`_from_model_config` field); # 2) the generation config must have seen no modification since its creation (the hash is the same); - # 3) the user must have set generation parameters in the model config. + # 3) there are non-default generation parameters in the model config. + # 4) the user must have set new generation parameters in the model config. # NOTE: `torch.compile` can't compile `hash`, this legacy support is disabled with compilation. 
             if (
                 not is_torchdynamo_compiling()
-                and self.generation_config._from_model_config
-                and self.generation_config._original_object_hash == hash(self.generation_config)
-                and self.config._has_non_default_generation_parameters()
+                and self.generation_config._from_model_config  # 1)
+                and self.generation_config._original_object_hash == hash(self.generation_config)  # 2)
+                and len(self.config._get_non_default_generation_parameters()) > 0  # 3)
             ):
                 new_generation_config = GaudiGenerationConfig.from_model_config(self.config)
-                if new_generation_config != self.generation_config:
+                if new_generation_config != self.generation_config:  # 4)
                     warnings.warn(
                         "You have modified the pretrained model configuration to control generation. This is a"
-                        " deprecated strategy to control generation and will be removed soon, in a future version."
+                        " deprecated strategy to control generation and will be removed in v5."
                         " Please use and modify the model generation configuration (see"
-                        " https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )"
+                        " https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )",
+                        UserWarning,
                     )
                     self.generation_config = new_generation_config
-                    using_model_generation_config = True
+            generation_config = self.generation_config
+            using_model_generation_config = True

         # `torch.compile` can't compile `copy.deepcopy`, arguments in `kwargs` that are part of `generation_config`
-        # will mutate the object with `.update`. As such, passing these arguments through `kwargs` is disabled.
-        if is_torchdynamo_compiling():
-            model_kwargs = kwargs
-            generate_attributes_in_kwargs = [
-                key for key, value in kwargs.items() if getattr(generation_config, key, None) != value
-            ]
-            if len(generate_attributes_in_kwargs) > 0:
-                raise ValueError(
-                    "`torch.compile` exception: all generation configuration attributes must be passed within a "
-                    f"`generation_config` instance passed to `generate` (found: {generate_attributes_in_kwargs})."
-                )
-        else:
+        # will mutate the object with `.update`. As such, passing these arguments through `kwargs` is disabled -- an
+        # exception will be raised in `_validate_model_kwargs`
+        if not is_torchdynamo_compiling():
             generation_config = copy.deepcopy(generation_config)

             if generation_config.static_shapes is None:
                 generation_config.static_shapes = self.config.model_type in MODELS_OPTIMIZED_WITH_STATIC_SHAPES
@@ -702,9 +705,124 @@ def _prepare_generation_config(
                 generation_config.pad_token_id = self.generation_config.pad_token_id
             if generation_config.decoder_start_token_id is None:
                 generation_config.decoder_start_token_id = self.generation_config.decoder_start_token_id
+        else:
+            model_kwargs = kwargs

         return generation_config, model_kwargs

+    def _prepare_cache_for_generation(
+        self,
+        generation_config: GaudiGenerationConfig,
+        model_kwargs: Dict,
+        assistant_model: "PreTrainedModel",
+        batch_size: int,
+        max_cache_length: int,
+        device: torch.device,
+    ) -> bool:
+        """
+        Copied from: https://github.com/huggingface/transformers/blob/65bb28444849976f853063edb958b3ef3dd59d12/src/transformers/generation/utils.py#L1467
+
+        Changes:
+        - change the default from DynamicCache to tuples
+        """
+
+        cache_name = "past_key_values" if "mamba" not in self.__class__.__name__.lower() else "cache_params"
+        requires_cross_attention_cache = (
+            self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None
+        )
+
+        # Quick escape route 1: if the user specifies a cache, we only need to:
+        # a) check for conflicting `generate` arguments
+        # b) convert to the new cache format (if the user passes a legacy cache and model supports it)
+        user_defined_cache = model_kwargs.get(cache_name)
+        if user_defined_cache is not None:
+            if generation_config.cache_implementation is not None:
+                raise ValueError(
+                    f"Passing both `cache_implementation` (used to initialize certain caches) and `{cache_name}` (a "
+                    "Cache object) is unsupported. Please use only one of the two."
+                )
+            if isinstance(user_defined_cache, tuple) and self._supports_default_dynamic_cache():
+                model_kwargs[cache_name] = (
+                    DynamicCache.from_legacy_cache(user_defined_cache)
+                    if not requires_cross_attention_cache
+                    else EncoderDecoderCache.from_legacy_cache(user_defined_cache)
+                )
+            return
+
+        # Quick escape route 2: if the user specifies no cache is to be used. (conflicting arguments are handled in
+        # `generation_config.validate()`)
+        if generation_config.use_cache is False:
+            return
+
+        # Quick escape route 3: model that only supports legacy caches = nothing to prepare
+        if not self._supports_default_dynamic_cache():
+            if generation_config.cache_implementation is not None:
+                warnings.warn(
+                    "This model does not support `Cache` instances, it only supports the legacy cache format (tuple "
+                    f"of tuples). `cache_implementation` (set to {generation_config.cache_implementation}) will be "
+                    "ignored.",
+                    UserWarning,
+                )
+            return
+
+        # Otherwise we NEED to prepare a cache, based on `generation_config.cache_implementation`
+
+        # TODO(joao): support static caches in assisted generation. assisted generation needs to roll back caches,
+        # which is only supported in dynamic caches atm
+        if assistant_model is not None and generation_config.cache_implementation is not None:
+            logger.warning_once(
+                "An assistant model is provided, using a dynamic cache instead of a cache of type="
+                f"'{generation_config.cache_implementation}'."
+            )
+            generation_config.cache_implementation = None
+
+        if generation_config.cache_implementation is not None:
+            if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
+                if generation_config.cache_implementation == "static" and not self._supports_static_cache:
+                    raise ValueError(
+                        "This model does not support `cache_implementation='static'`. Please check the following "
+                        "issue: https://github.com/huggingface/transformers/issues/28981"
+                    )
+                model_kwargs[cache_name] = self._get_cache(
+                    cache_implementation=generation_config.cache_implementation,
+                    batch_size=max(generation_config.num_beams, generation_config.num_return_sequences) * batch_size,
+                    max_cache_len=max_cache_length,
+                    device=device,
+                    model_kwargs=model_kwargs,
+                )
+            elif generation_config.cache_implementation == "quantized":
+                if not self._supports_quantized_cache:
+                    raise ValueError(
+                        "This model does not support the quantized cache. If you want your model to support quantized "
+                        "cache, please open an issue and tag @zucchini-nlp."
+                    )
+
+                cache_config = (
+                    generation_config.cache_config
+                    if generation_config.cache_config is not None
+                    else QuantizedCacheConfig()
+                )
+                cache_class = QUANT_BACKEND_CLASSES_MAPPING[cache_config.backend]
+
+                if cache_config.backend == "quanto" and not is_quanto_available():
+                    raise ImportError(
+                        "You need to install `quanto` in order to use KV cache quantization with quanto backend. "
+                        "Please install it via with `pip install quanto`"
+                    )
+                elif cache_config.backend == "HQQ" and not is_hqq_available():
+                    raise ImportError(
+                        "You need to install `HQQ` in order to use KV cache quantization with HQQ backend. "
+                        "Please install it via with `pip install hqq`"
+                    )
+
+                model_kwargs[cache_name] = cache_class(cache_config)
+            elif generation_config.cache_implementation == "offloaded":
+                model_kwargs[cache_name] = OffloadedCache()
+
+        # Use tuples by default (.i.e. legacy format).
+        else:
+            return
+
     @torch.no_grad()
     def generate(
         self,
@@ -891,6 +1009,10 @@ def generate(
             model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
                 inputs_tensor, generation_config._pad_token_tensor, generation_config._eos_token_tensor
             )
+        elif kwargs_has_attention_mask:
+            # TODO (joao): generalize this check with other types of inputs
+            if model_input_name == "input_ids" and len(model_kwargs["attention_mask"].shape) > 2:
+                raise ValueError("`attention_mask` passed to `generate` must be 2D.")

         is_greedy_or_beam_and_bucket = (
             not generation_config.bucket_internal
@@ -1041,76 +1163,11 @@ def generate(
             has_token_idx="token_idx" in model_kwargs,
         )

-        use_dynamic_cache_by_default = False
-        if "mamba" in self.__class__.__name__.lower():
-            cache_name = "cache_params"
-        else:
-            cache_name = "past_key_values"
-        if generation_config.cache_implementation is not None and (model_kwargs.get(cache_name) is not None):
-            raise ValueError(
-                f"Passing both `cache_implementation` (used to initialize certain caches) and `{cache_name}` (a "
-                "Cache object) is unsupported. Please use only one of the two."
-            )
-        elif generation_config.cache_implementation is not None:
-            if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
-                if generation_config.cache_implementation == "static" and not self._supports_static_cache:
-                    raise ValueError(
-                        "This model does not support `cache_implementation='static'`. Please check the following "
-                        "issue: https://github.com/huggingface/transformers/issues/28981"
-                    )
-                model_kwargs[cache_name] = self._get_cache(
-                    generation_config.cache_implementation,
-                    getattr(generation_config, "num_beams", 1) * batch_size,
-                    generation_config.max_length,
-                    model_kwargs,
-                )
-            elif generation_config.cache_implementation == "quantized":
-                if not self._supports_quantized_cache:
-                    raise ValueError(
-                        "This model does not support the quantized cache. If you want your model to support quantized "
-                        "cache, please open an issue."
-                    )
-
-                cache_config = (
-                    generation_config.cache_config
-                    if generation_config.cache_config is not None
-                    else QuantizedCacheConfig()
-                )
-                cache_class = QUANT_BACKEND_CLASSES_MAPPING[cache_config.backend]
-
-                if cache_config.backend == "quanto" and not is_quanto_available():
-                    raise ImportError(
-                        "You need to install `quanto` in order to use KV cache quantization with quanto backend. "
-                        "Please install it via with `pip install quanto`"
-                    )
-                elif cache_config.backend == "HQQ" and not is_hqq_available():
-                    raise ImportError(
-                        "You need to install `HQQ` in order to use KV cache quantization with HQQ backend. "
-                        "Please install it via with `pip install hqq`"
-                    )
-
-                model_kwargs[cache_name] = cache_class(cache_config)
-        # Use DynamicCache() instance by default. This will avoid back and forth from legacy format that
-        # keeps copying the cache thus using much more memory
-        # elif generation_config.cache_implementation is None and self._supports_default_dynamic_cache():
-        #     past = model_kwargs.get(cache_name, None)
-        #     requires_cross_attention_cache = (
-        #         self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None
-        #     )
-        #     if past is None:
-        #         model_kwargs[cache_name] = (
-        #             DynamicCache()
-        #             if not requires_cross_attention_cache
-        #             else EncoderDecoderCache(DynamicCache(), DynamicCache())
-        #         )
-        #         use_dynamic_cache_by_default = True
-        #     elif isinstance(past, tuple):
-        #         model_kwargs[cache_name] = (
-        #             DynamicCache.from_legacy_cache(past)
-        #             if not requires_cross_attention_cache
-        #             else EncoderDecoderCache.from_legacy_cache(past)
-        #         )
-        #         use_dynamic_cache_by_default = True
+        # If the model supports `num_logits_to_keep` in forward(), set it to 1 to avoid computing the whole
+        # logit matrix. This can save a lot of memory during the first forward pass. Note that assisted decoding
+        # dynamically overrides this value as it can need more than the last token logits
+        if self._supports_num_logits_to_keep() and "num_logits_to_keep" not in model_kwargs:
+            model_kwargs["num_logits_to_keep"] = 1

         self._validate_generated_length(
             generation_config,
@@ -1118,6 +1175,24 @@ def generate(
             has_default_max_length,
         )

+        # 7. Prepare the cache.
+        # - `model_kwargs` may be updated in place with a cache as defined by the parameters in `generation_config`.
+        # - different models have a different cache name expected by the model (default = "past_key_values")
+        # - `max_length`, prepared above, is used to determine the maximum cache length
+        # TODO (joao): remove `user_defined_cache` after v4.47 (remove default conversion to legacy format)
+        cache_name = "past_key_values" if "mamba" not in self.__class__.__name__.lower() else "cache_params"
+        user_defined_cache = model_kwargs.get(cache_name)
+        max_cache_length = generation_config.max_length
+        if (
+            inputs_tensor.shape[1] != input_ids_length
+            and model_input_name == "inputs_embeds"
+            and not self.config.is_encoder_decoder
+        ):
+            max_cache_length += inputs_tensor.shape[1]
+        self._prepare_cache_for_generation(
+            generation_config, model_kwargs, assistant_model, batch_size, max_cache_length, device
+        )
+
         # determine whether introduce trim_logits feature
         model_kwargs["trim_logits"] = generation_config.trim_logits
@@ -1158,7 +1233,7 @@ def generate(
             if self.config.max_position_embeddings < calculated_max_length:
                 unwrap_deepspeed_model(self).update_sincos_cache(seq_len=calculated_max_length)

-        # 7. determine generation mode
+        # 8. determine generation mode
         generation_mode = generation_config.get_generation_mode(assistant_model)

         if generation_config.bucket_size > 0:
@@ -1178,7 +1253,7 @@ def generate(
                 "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1."
             )

-        if self.device.type != input_ids.device.type:
+        if not is_torchdynamo_compiling() and self.device.type != input_ids.device.type:
             warnings.warn(
                 (
                     "You are calling .generate() with the `input_ids` being on a device type different"
@@ -1191,7 +1266,7 @@ def generate(
                 UserWarning,
             )

-        # 8. prepare distribution pre_processing samplers
+        # 9. prepare logits processors and stopping criteria
         prepared_logits_processor = self._get_logits_processor(
             generation_config=generation_config,
             input_ids_seq_length=input_ids_length,
@@ -1203,8 +1278,6 @@ def generate(
             negative_prompt_ids=negative_prompt_ids,
             negative_prompt_attention_mask=negative_prompt_attention_mask,
         )
-
-        # 9. prepare stopping criteria
         self.generation_config.generation_mode = generation_mode
         prepared_stopping_criteria = self._get_stopping_criteria(
             generation_config=generation_config,
@@ -1230,8 +1303,8 @@ def generate(
                 raise ValueError("assisted generate is only supported for batch_size = 1")
             if not model_kwargs["use_cache"]:
                 raise ValueError("assisted generate requires `use_cache=True`")
-            if generation_config.cache_implementation == "static":
-                raise ValueError("assisted generate is not supported with `static_cache`")
+            if generation_config.cache_implementation in ["static", "hybrid", "sliding_window"]:
+                raise ValueError("assisted generate is not supported with Static cache classes`")
             if self._is_stateful:
                 # In assisted generation we need the ability to confirm whether the model would pick certain tokens,
                 # which is not possible with stateful models (they can't reset to a previous subset of generated text)
@@ -1249,22 +1322,11 @@ def generate(
                 model_kwargs=model_kwargs,
             )

-            # 12. prepare logits warper (if `do_sample` is `True`)
-            prepared_logits_warper = (
-                self._get_logits_warper(
-                    generation_config,
-                    device=input_ids.device,
-                )
-                if generation_config.do_sample
-                else None
-            )
-
-            # 13. run assisted generate
+            # 12. run assisted generate
             result = self._assisted_decoding(
                 input_ids,
                 candidate_generator=candidate_generator,
                 logits_processor=prepared_logits_processor,
-                logits_warper=prepared_logits_warper,
                 stopping_criteria=prepared_stopping_criteria,
                 generation_config=generation_config,
                 synced_gpus=synced_gpus,
@@ -1282,16 +1344,10 @@ def generate(
                 raise ValueError(
                     f"dola decoding is not supported with stateful models, such as {self.__class__.__name__}"
                 )
-            prepared_logits_warper = (
-                self._get_logits_warper(generation_config, device=input_ids.device)
-                if generation_config.do_sample
-                else None
-            )
             result = self._dola_decoding(
                 input_ids,
                 dola_layers=generation_config.dola_layers,
                 logits_processor=prepared_logits_processor,
-                logits_warper=prepared_logits_warper,
                 stopping_criteria=prepared_stopping_criteria,
                 generation_config=generation_config,
                 synced_gpus=synced_gpus,
@@ -1325,26 +1381,18 @@ def generate(
             )

         elif generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
-            # 11. prepare logits warper
-            prepared_logits_warper = (
-                self._get_logits_warper(generation_config, device=input_ids.device)
-                if generation_config.do_sample
-                else None
+            # 11. expand input_ids with `num_return_sequences` additional sequences per batch
+            input_ids, model_kwargs = self._expand_inputs_for_generation(
+                input_ids=input_ids,
+                expand_size=generation_config.num_return_sequences,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+                **model_kwargs,
             )

-            if generation_mode == GenerationMode.SAMPLE:
-                # 12. expand input_ids with `num_return_sequences` additional sequences per batch
-                input_ids, model_kwargs = self._expand_inputs_for_generation(
-                    input_ids=input_ids,
-                    expand_size=generation_config.num_return_sequences,
-                    is_encoder_decoder=self.config.is_encoder_decoder,
-                    **model_kwargs,
-                )
-
-            # 13. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
+            # 12. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
             result = self._sample(
                 input_ids,
                 logits_processor=prepared_logits_processor,
-                logits_warper=prepared_logits_warper,
                 stopping_criteria=prepared_stopping_criteria,
                 generation_config=generation_config,
                 synced_gpus=synced_gpus,
@@ -1359,14 +1407,7 @@ def generate(
             )

         elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
-            # 11. prepare logits warper
-            prepared_logits_warper = (
-                self._get_logits_warper(generation_config, device=input_ids.device)
-                if generation_config.do_sample
-                else None
-            )
-
-            # 12. prepare beam search scorer
+            # 11. prepare beam search scorer
             beam_scorer = BeamSearchScorer(
                 batch_size=batch_size,
                 num_beams=generation_config.num_beams,
@@ -1377,7 +1418,7 @@ def generate(
                 max_length=generation_config.max_length,
             )

-            # 13. interleave input_ids with `num_beams` additional sequences per batch
+            # 12. interleave input_ids with `num_beams` additional sequences per batch
             input_ids, model_kwargs = self._expand_inputs_for_generation(
                 input_ids=input_ids,
                 expand_size=generation_config.num_beams,
@@ -1385,12 +1426,11 @@ def generate(
                 **model_kwargs,
             )

-            # 14. run beam sample
+            # 13. run beam sample
             result = self._beam_search(
                 input_ids,
                 beam_scorer,
                 logits_processor=prepared_logits_processor,
-                logits_warper=prepared_logits_warper,
                 stopping_criteria=prepared_stopping_criteria,
                 generation_config=generation_config,
                 synced_gpus=synced_gpus,
@@ -1512,11 +1552,34 @@ def typeerror():
             **model_kwargs,
         )

-        # Convert to legacy cache if needed
-        if use_dynamic_cache_by_default and generation_config.return_legacy_cache:
-            if isinstance(result, ModelOutput) and hasattr(result, "past_key_values"):
-                if isinstance(result.past_key_values, (DynamicCache, EncoderDecoderCache)):
-                    result.past_key_values = result.past_key_values.to_legacy_cache()
+        # Convert to legacy cache format if requested
+        if (
+            generation_config.return_legacy_cache is not False  # Should check for `True` after v4.47
+            and not is_torchdynamo_compiling()
+            and hasattr(result, "past_key_values")
+            and hasattr(result.past_key_values, "to_legacy_cache")
+            and result.past_key_values.to_legacy_cache is not None
+        ):
+            # handle BC (convert by default if he user hasn't passed a cache AND the cache is of the default type)
+            should_convert_cache = generation_config.return_legacy_cache
+            is_user_defined_cache = user_defined_cache is not None
+            is_default_cache_type = (
+                type(result.past_key_values) == DynamicCache  # noqa E721
+                or (
+                    isinstance(result.past_key_values, EncoderDecoderCache)
+                    and type(result.past_key_values.self_attention_cache) == DynamicCache  # noqa E721
+                    and type(result.past_key_values.cross_attention_cache) == DynamicCache  # noqa E721
+                )
+            )
+            if not is_user_defined_cache and is_default_cache_type:
+                logger.warning_once(
+                    "From v4.47 onwards, when a model cache is to be returned, `generate` will return a `Cache` "
+                    "instance instead by default (as opposed to the legacy tuple of tuples format). If you want to "
+                    "keep returning the legacy format, please set `return_legacy_cache=True`."
+                )
+                should_convert_cache = True
+            if should_convert_cache:
+                result.past_key_values = result.past_key_values.to_legacy_cache()

         return result
@@ -1529,7 +1592,6 @@ def _dola_decoding(
         generation_config: GaudiGenerationConfig,
         synced_gpus: bool,
         streamer: "BaseStreamer",
-        logits_warper: Optional[LogitsProcessorList],
         **model_kwargs,
     ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
         r"""
@@ -1558,10 +1620,6 @@ def _dola_decoding(
             streamer (`BaseStreamer`, *optional*):
                 Streamer object that will be used to stream the generated sequences. Generated tokens are passed
                 through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
-            logits_warper (`LogitsProcessorList`, *optional*):
-                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
-                to warp the prediction score distribution of the language modeling head applied before multinomial
-                sampling at each generation step.
             model_kwargs:
                 Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
                 If model is an encoder-decoder model the kwargs should include `encoder_outputs`.
@@ -1754,15 +1812,13 @@ def _contrastive_search(
                 else:
                     logit_for_next_step = torch.index_select(outputs.logits, -2, token_idx - 1).squeeze(-2)
             else:
-                # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for this first iteration
-                # (the clone itself is always small)
-                logit_for_next_step = outputs.logits[:, -1, :].clone()
+                # .float() is needed to retain precision for later logits manipulations
+                logit_for_next_step = outputs.logits[:, -1, :].float()

             model_kwargs = self._update_model_kwargs_for_generation(
                 outputs,
                 model_kwargs,
                 is_encoder_decoder=self.config.is_encoder_decoder,
-                standardize_cache_format=True,
             )

             if not sequential:
@@ -1892,7 +1948,7 @@ def _contrastive_search(
                     model_kwargs["past_key_values"].crop(-1)

                     all_outputs.append(outputs)
-                outputs = stack_model_outputs(all_outputs)
+                outputs = stack_model_outputs(all_outputs, self.config.get_text_config())

             else:
                 # compute the candidate tokens by the language model and collect their hidden_states
@@ -1923,7 +1979,8 @@ def _contrastive_search(
             next_hidden = outputs.hidden_states[-1]
             full_hidden_states = outputs.hidden_states

-            logits = outputs.logits[:, -1, :]
+            # .float() is needed to retain precision for later logits manipulations
+            logits = outputs.logits[:, -1, :].float()
             context_hidden = last_hidden_states.repeat_interleave(top_k, dim=0)

             # compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the
@@ -1974,7 +2031,7 @@ def _contrastive_search(
                 next_past_key_values = selected_outputs["past_key_values"]

             else:
-                _, next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True)
+                _, next_past_key_values = self._extract_past_from_model_output(outputs)
                 # Do it in-place layer per layer to save memory
                 if isinstance(next_past_key_values, DynamicCache) or (
                     isinstance(next_past_key_values, EncoderDecoderCache)
@@ -2182,7 +2239,6 @@ def _sample(
         generation_config: GaudiGenerationConfig,
         synced_gpus: bool,
         streamer: Optional["BaseStreamer"],
-        logits_warper: Optional[LogitsProcessorList],
        lazy_mode: Optional[bool] = False,
        ignore_eos: Optional[bool] = False,
        profiling_warmup_steps: Optional[int] = 0,
@@ -2211,11 +2267,6 @@ def _sample(
             streamer (`BaseStreamer`, *optional*):
                 Streamer object that will be used to stream the generated sequences. Generated tokens are passed
                 through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
-            logits_warper (`LogitsProcessorList`, *optional*):
-                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
-                to warp the prediction score distribution of the language modeling head applied before multinomial
-                sampling at each generation step. Only required with sampling strategies (i.e. `do_sample` is set in
-                `generation_config`)
             lazy_mode (`bool`, *optional*, defaults to `False`):
                 Whether the run is executed in lazy mode or not (i.e. eager mode).
             ignore_eos (`bool`, *optional*, defaults to `False`):
@@ -2245,13 +2296,9 @@ def _sample(
         output_scores = generation_config.output_scores
         output_logits = generation_config.output_logits
         return_dict_in_generate = generation_config.return_dict_in_generate
+        max_length = generation_config.max_length
         has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
         do_sample = generation_config.do_sample
-        if do_sample is True and not isinstance(logits_warper, LogitsProcessorList):
-            raise ValueError(
-                "`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is "
-                f"{logits_warper})."
-            )

         # init attention / hidden states / scores tuples
         scores = () if (return_dict_in_generate and output_scores) else None
@@ -2299,7 +2346,9 @@ def _sample(
             model_kwargs["pad_done"] = False
             model_kwargs["mqa_model"] = False
             model_kwargs["lazy_mode"] = lazy_mode
-        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
+        while self._has_unfinished_sequences(
+            this_peer_finished, synced_gpus, device=input_ids.device, cur_len=cur_len, max_length=max_length
+        ):
             if lazy_mode:
                 self.htcore_generation.mark_step()

@@ -2333,7 +2382,7 @@ def _sample(
             if token_idx is not None and outputs.logits.shape[-2] > 1:
                 # case1 (w/o KV caching): outputs.logits.shape: [batch_size, max_length, vocab_size]
                 if self.config.is_encoder_decoder:
-                    next_token_logits = outputs.logits[:, token_idx - 1, :]
+                    next_token_logits = outputs.logits[:, token_idx - 1, :].float()
                     next_token_scores = logits_processor(input_ids[:, :token_idx], next_token_logits)
                 else:
                     if model_kwargs.get("num_virtual_tokens", 0) > 0:
@@ -2347,7 +2396,8 @@ def _sample(
                         next_token_logits = torch.index_select(outputs.logits, -2, token_idx - 1).squeeze(-2)
                     next_token_scores = logits_processor(input_ids, next_token_logits)
             else:
-                next_token_logits = outputs.logits[:, -1, :]
+                # .float() is needed to retain precision for later logits manipulations
+                next_token_logits = outputs.logits[:, -1, :].float()
                 if token_idx is not None and self.config.is_encoder_decoder:
                     # case2 (with KV caching): outputs.logits.shape: [batch_size, 1, vocab_size]
                     next_token_scores = logits_processor(input_ids[:, :token_idx], next_token_logits)
@@ -2355,10 +2405,6 @@ def _sample(
                 else:
                     # case3 (default case): token_idx is None
                     next_token_scores = logits_processor(input_ids, next_token_logits)
-            # pre-process distribution
-            if do_sample:
-                next_token_scores = logits_warper(input_ids, next_token_scores)
-
             # Store scores, attentions and hidden_states when required
             if return_dict_in_generate:
                 if output_scores:
@@ -2382,6 +2428,7 @@ def _sample(
             # token selection
             if do_sample:
                 probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
+                # TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution
                 next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
             else:
                 next_tokens = torch.argmax(next_token_scores, dim=-1)
@@ -2531,7 +2578,6 @@ def _beam_search(
         stopping_criteria: StoppingCriteriaList,
         generation_config: GaudiGenerationConfig,
         synced_gpus: bool,
-        logits_warper: Optional[LogitsProcessorList],
         lazy_mode: Optional[bool] = False,
         profiling_warmup_steps: Optional[int] = 0,
         profiling_steps: Optional[int] = 0,
@@ -2559,11 +2605,6 @@ def _beam_search(
                 The generation configuration to be used as parametrization of the decoding method.
             synced_gpus (`bool`):
                 Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
-            logits_warper (`LogitsProcessorList`, *optional*):
-                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
-                to warp the prediction score distribution of the language modeling head applied before multinomial
-                sampling at each generation step. Only required with sampling strategies (i.e. `do_sample` is set in
-                `generation_config`)
             lazy_mode (`bool`, *optional*, defaults to `False`):
                 Whether the run is executed in lazy mode or not (i.e. eager mode).
             profiling_warmup_steps (`int`, *optional*, defaults to 0):
@@ -2593,11 +2634,6 @@ def _beam_search(
         return_dict_in_generate = generation_config.return_dict_in_generate
         sequential = generation_config.low_memory
         do_sample = generation_config.do_sample
-        if do_sample is True and not isinstance(logits_warper, LogitsProcessorList):
-            raise ValueError(
-                "`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is "
-                f"{logits_warper})."
-            )

         batch_size = len(beam_scorer._beam_hyps)
         num_beams = beam_scorer.num_beams
@@ -2769,7 +2805,6 @@ def expand_if_needed(tensor, new_size, value, dim=-1):
                     for model_name in [
                         "fsmt",
                         "reformer",
-                        "bloom",
                         "ctrl",
                         "gpt_bigcode",
                         "transo_xl",
@@ -2784,13 +2819,16 @@ def expand_if_needed(tensor, new_size, value, dim=-1):
                 )

                 inputs_per_sub_batches = _split_model_inputs(
-                    model_inputs, split_size=batch_size, full_batch_size=batch_beam_size
+                    model_inputs,
+                    split_size=batch_size,
+                    full_batch_size=batch_beam_size,
+                    config=self.config.get_text_config(),
                 )
                 outputs_per_sub_batch = [
                     self(**inputs_per_sub_batch, return_dict=True) for inputs_per_sub_batch in inputs_per_sub_batches
                 ]
-                outputs = stack_model_outputs(outputs_per_sub_batch)
+                outputs = stack_model_outputs(outputs_per_sub_batch, self.config.get_text_config())

             else:
                 hpu_graphs_kwargs = self._get_hpu_graphs_kwargs(model_kwargs)
                 outputs = self(
@@ -2815,9 +2853,7 @@ def expand_if_needed(tensor, new_size, value, dim=-1):
                 else:
                     next_token_logits = torch.index_select(outputs.logits, -2, token_idx - 1).squeeze(-2)
             else:
-                # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
-                # (the clone itself is always small)
-                next_token_logits = outputs.logits[:, -1, :].clone()
+                next_token_logits = outputs.logits[:, -1, :].float()

             next_token_scores = torch.nn.functional.log_softmax(
                 next_token_logits, dim=-1
@@ -2828,8 +2864,6 @@ def expand_if_needed(tensor, new_size, value, dim=-1):
                 next_token_scores_processed = logits_processor(input_ids[:, :idx], next_token_scores)
             else:
                 next_token_scores_processed = logits_processor(input_ids, next_token_scores)
-            if do_sample:
-                next_token_scores_processed = logits_warper(input_ids, next_token_scores_processed)
             next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(
                 next_token_scores_processed
             )
@@ -3148,10 +3182,6 @@ def _constrained_beam_search(
             stopping_criteria (`StoppingCriteriaList`):
                 An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
                 used to tell if the generation loop should stop.
-            logits_warper (`LogitsProcessorList`):
-                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
-                to warp the prediction score distribution of the language modeling head applied before multinomial
-                sampling at each generation step.
             generation_config ([`GaudiGenerationConfig`]):
                 The generation configuration to be used as parametrization of the decoding method.
             synced_gpus (`bool`):
@@ -3269,9 +3299,7 @@ def _constrained_beam_search(
                 else:
                     next_token_logits = torch.index_select(outputs.logits, -2, token_idx - 1).squeeze(-2)
             else:
-                # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
-                # (the clone itself is always small)
-                next_token_logits = outputs.logits[:, -1, :].clone()
+                next_token_logits = outputs.logits[:, -1, :].float()

             next_token_scores = torch.nn.functional.log_softmax(
                 next_token_logits, dim=-1
@@ -3428,7 +3456,6 @@ def _assisted_decoding(
         input_ids: torch.LongTensor,
         candidate_generator: "GaudiCandidateGenerator",
         logits_processor: LogitsProcessorList,
-        logits_warper: LogitsProcessorList,
         stopping_criteria: StoppingCriteriaList,
         generation_config: GaudiGenerationConfig,
         synced_gpus: bool,
@@ -3456,10 +3483,6 @@ def _assisted_decoding(
             logits_processor (`LogitsProcessorList`):
                 An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
                 used to modify the prediction scores of the language modeling head applied at each generation step.
-            logits_warper (`LogitsProcessorList`):
-                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
-                to warp the prediction score distribution of the language modeling head applied before multinomial
-                sampling at each generation step. Only used if sampling is active.
             stopping_criteria (`StoppingCriteriaList`):
                 An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
                 used to tell if the generation loop should stop.
@@ -3490,7 +3513,7 @@ def _assisted_decoding(
             `model.config.is_encoder_decoder=True`.
         """
         # init values
-        do_sample = logits_warper is not None
+        do_sample = generation_config.do_sample
         output_attentions = generation_config.output_attentions
         output_hidden_states = generation_config.output_hidden_states
         output_scores = generation_config.output_scores
@@ -3548,7 +3571,6 @@ def _assisted_decoding(
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

             # 1. Fetch candidate sequences from a `CandidateGenerator`
-
             candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids[:, :cur_len])
             candidate_input_ids = candidate_input_ids.to(self.device)
             if candidate_logits is not None:
@@ -3596,14 +3618,12 @@ def _assisted_decoding(
                 )

             # 2.3. Process the new logits
-            new_logits = outputs.logits[:, -candidate_length - 1 :]  # excludes the input prompt if present
+            # .float() is needed to retain precision for later logits manipulations
+            new_logits = outputs.logits[:, -candidate_length - 1 :].float()  # excludes the input prompt if present
             next_token_logits = new_logits.clone()
             if len(logits_processor) > 0:
                 for i in range(candidate_length + 1):
                     new_logits[:, i, :] = logits_processor(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :])
-            if do_sample and len(logits_warper) > 0:
-                for i in range(candidate_length + 1):
-                    new_logits[:, i, :] = logits_warper(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :])

             # 3. Select the accepted tokens. There are two possible cases:
             # Case 1: `do_sample=True` and we have logits for the candidates (originally from speculative decoding)
@@ -3775,3 +3795,27 @@ def _assisted_decoding(
                 )
         else:
             return input_ids
+
+
+def _ranking_fast(
+    context_hidden: torch.FloatTensor,
+    next_hidden: torch.FloatTensor,
+    next_top_k_probs: torch.FloatTensor,
+    alpha: float,
+    beam_width: int,
+) -> torch.FloatTensor:
+    """
+    Reranks the top_k candidates based on a degeneration penalty (cosine similarity with previous tokens), as described
+    in the paper "A Contrastive Framework for Neural Text Generation". Returns the index of the best candidate for each
+    row in the batch.
+    """
+    norm_context_hidden = context_hidden / context_hidden.norm(dim=2, keepdim=True)
+    norm_next_hidden = next_hidden / next_hidden.norm(dim=2, keepdim=True)
+    cosine_matrix = torch.matmul(norm_context_hidden, norm_next_hidden.transpose(1, 2)).squeeze(-1)  # [B*K, S]
+
+    degeneration_penalty, _ = torch.max(cosine_matrix, dim=-1)  # [B*K]
+    next_top_k_probs = next_top_k_probs.view(-1)  # [B*K]
+    contrastive_score = (1.0 - alpha) * next_top_k_probs - alpha * degeneration_penalty
+    contrastive_score = torch.stack(torch.split(contrastive_score, beam_width))  # [B, K]
+    _, selected_idx = contrastive_score.max(dim=-1)  # [B]
+    return selected_idx
diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py
index ca6e69a53a..a7475db08d 100644
--- a/optimum/habana/transformers/modeling_utils.py
+++ b/optimum/habana/transformers/modeling_utils.py
@@ -23,7 +23,6 @@
     GaudiGenerationMixin,
     gaudi_EosTokenCriteria_call,
     gaudi_MaxLengthCriteria_call,
-    gaudi_MaxNewTokensCriteria_call,
     gaudi_MaxTimeCriteria_call,
     gaudi_StoppingCriteriaList_call,
 )
@@ -276,10 +275,13 @@ def adapt_transformers_to_gaudi():
     transformers.generation.GenerationMixin._contrastive_search = GaudiGenerationMixin._contrastive_search
     transformers.generation.GenerationMixin._assisted_decoding = GaudiGenerationMixin._assisted_decoding
     transformers.generation.GenerationMixin._get_candidate_generator = GaudiGenerationMixin._get_candidate_generator
+    transformers.generation.GenerationMixin._prepare_cache_for_generation = (
+        GaudiGenerationMixin._prepare_cache_for_generation
+    )
     transformers.generation.GenerationConfig = GaudiGenerationConfig
+    transformers.generation.configuration_utils.GenerationConfig = GaudiGenerationConfig
     transformers.modeling_utils.GenerationConfig = GaudiGenerationConfig
     transformers.generation.MaxLengthCriteria.__call__ = gaudi_MaxLengthCriteria_call
-    transformers.generation.MaxNewTokensCriteria.__call__ = gaudi_MaxNewTokensCriteria_call
     transformers.generation.MaxTimeCriteria.__call__ = gaudi_MaxTimeCriteria_call
     transformers.generation.EosTokenCriteria.__call__ = gaudi_EosTokenCriteria_call
     transformers.generation.StoppingCriteriaList.__call__ = gaudi_StoppingCriteriaList_call
diff --git a/optimum/habana/transformers/models/bloom/modeling_bloom.py b/optimum/habana/transformers/models/bloom/modeling_bloom.py
index 7873bb3ead..5b0a770451 100644
--- a/optimum/habana/transformers/models/bloom/modeling_bloom.py
+++ b/optimum/habana/transformers/models/bloom/modeling_bloom.py
@@ -23,6 +23,7 @@
 import torch
 from torch.nn import CrossEntropyLoss
 from torch.nn import functional as F
+from transformers.cache_utils import Cache
 from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
 from transformers.models.bloom.modeling_bloom import BloomForCausalLM, BloomMLP, dropout_add
 from transformers.utils import logging
@@ -124,22 +125,21 @@ def gaudi_bloom_attention_forward(
     residual: torch.Tensor,
     alibi: torch.Tensor,
     attention_mask: torch.Tensor,
-    layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    layer_past: Optional[Cache] = None,
     head_mask: Optional[torch.Tensor] = None,
     use_cache: bool = False,
     output_attentions: bool = False,
+    cache_position: Optional[torch.LongTensor] = None,
     token_idx: Optional[torch.Tensor] = None,
 ):
+    batch_size, q_length, _ = hidden_states.shape
     fused_qkv = self.query_key_value(hidden_states)  # [batch_size, seq_length, 3 x hidden_size]
+    # 3 x [batch_size, num_heads, seq_length, head_dim]
+    query_layer, key_layer, value_layer = self._reshape(fused_qkv)

-    # 3 x [batch_size, seq_length, num_heads, head_dim]
-    (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
-
-    batch_size, q_length, _, _ = query_layer.shape
-
-    query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
-    key_layer = key_layer.permute(0, 2, 3, 1).reshape(batch_size * self.num_heads, self.head_dim, q_length)
-    value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
+    query_layer = query_layer.reshape(batch_size * self.num_heads, -1, self.head_dim)
+    key_layer = key_layer.reshape(batch_size * self.num_heads, -1, self.head_dim).transpose(1, 2)
+    value_layer = value_layer.reshape(batch_size * self.num_heads, -1, self.head_dim)

     # Collapse views to improve performance on HPU
     query_layer = query_layer.contiguous()
@@ -162,8 +162,7 @@ def gaudi_bloom_attention_forward(
         present = None

     # [batch_size * num_heads, q_length, kv_length]
-    # we use `torch.Tensor.baddbmm` instead of `torch.baddbmm` as the latter isn't supported by TorchScript v1.11
-    matmul_result = alibi.baddbmm(
+    attention_scores = alibi.baddbmm(
         batch1=query_layer,
         batch2=key_layer,
         beta=self.beta,
@@ -171,7 +170,7 @@ def gaudi_bloom_attention_forward(
     )

     # change view to [batch_size, num_heads, q_length, kv_length]
-    attention_scores = matmul_result.view(batch_size, self.num_heads, q_length, kv_length)
+    attention_scores = attention_scores.view(batch_size, self.num_heads, q_length, -1)

     # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
     input_dtype = attention_scores.dtype
@@ -185,7 +184,7 @@ def gaudi_bloom_attention_forward(
         attention_probs = attention_probs * head_mask

     # change view [batch_size x num_heads, q_length, kv_length]
-    attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, kv_length)
+    attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, -1)

     # matmul: [batch_size * num_heads, q_length, head_dim]
     context_layer = torch.bmm(attention_probs_reshaped, value_layer)
@@ -225,10 +224,11 @@ def gaudi_bloom_block_forward(
     hidden_states: torch.Tensor,
     alibi: torch.Tensor,
     attention_mask: torch.Tensor,
-    layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    layer_past: Optional[Cache] = None,
     head_mask: Optional[torch.Tensor] = None,
     use_cache: bool = False,
     output_attentions: bool = False,
+    cache_position: Optional[torch.LongTensor] = None,
     token_idx: Optional[torch.Tensor] = None,
 ):
     # hidden_states: [batch_size, seq_length, hidden_size]
@@ -252,6 +252,7 @@ def gaudi_bloom_block_forward(
         head_mask=head_mask,
         use_cache=use_cache,
         output_attentions=output_attentions,
+        cache_position=cache_position,
         token_idx=token_idx,
) @@ -326,7 +327,7 @@ def gaudi_bloom_convert_to_bloom_cache( def gaudi_bloom_model_forward( self, input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]] = None, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.LongTensor] = None, @@ -334,6 +335,7 @@ def gaudi_bloom_model_forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, token_idx: Optional[torch.Tensor] = None, **deprecated_arguments, ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: @@ -429,6 +431,7 @@ def gaudi_bloom_model_forward( head_mask[i], use_cache, output_attentions, + cache_position, None, ) else: @@ -440,6 +443,7 @@ def gaudi_bloom_model_forward( use_cache=use_cache, output_attentions=output_attentions, alibi=alibi, + cache_position=cache_position, token_idx=token_idx, ) @@ -477,10 +481,12 @@ def set_tp_for_inference(tp_for_inference: int): def prepare_inputs_for_generation( self, - input_ids: torch.LongTensor, - past_key_values: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + use_cache=True, token_idx: Optional[torch.Tensor] = None, **kwargs, ) -> dict: @@ -498,14 +504,18 @@ def prepare_inputs_for_generation( # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} else: - model_inputs = {"input_ids": input_ids} + # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the + # input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in + # the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. 
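A minimal sketch of the behaviour the `clone` comment above relies on (tensor name and shape are illustrative, not taken from the patch): on a tensor that is already contiguous, `.contiguous()` is a no-op that returns the same storage, whereas `.clone(memory_format=torch.contiguous_format)` always materialises a freshly allocated, densely strided copy, which is what keeps the input layout stable across decoding steps.

    import torch

    position_ids = torch.arange(8).unsqueeze(0)  # [1, 8], already contiguous
    noop = position_ids.contiguous()             # returns the very same tensor
    copy = position_ids.clone(memory_format=torch.contiguous_format)  # fresh buffer, dense strides

    assert noop.data_ptr() == position_ids.data_ptr()   # no copy was made
    assert copy.data_ptr() != position_ids.data_ptr()   # a new, stable allocation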
+ model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} model_inputs.update( { + "cache_position": cache_position, "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), + "use_cache": use_cache, "attention_mask": attention_mask, "token_idx": token_idx, } @@ -515,7 +525,7 @@ def prepare_inputs_for_generation( def forward( self, input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]] = None, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, @@ -524,6 +534,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, token_idx: Optional[torch.Tensor] = None, **deprecated_arguments, ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]: @@ -555,6 +566,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, token_idx=token_idx, ) hidden_states = transformer_outputs[0] diff --git a/optimum/habana/transformers/models/codegen/modeling_codegen.py b/optimum/habana/transformers/models/codegen/modeling_codegen.py index 573ec0f606..a7f15d32d4 100644 --- a/optimum/habana/transformers/models/codegen/modeling_codegen.py +++ b/optimum/habana/transformers/models/codegen/modeling_codegen.py @@ -3,6 +3,7 @@ import torch import torch.utils.checkpoint from torch.nn import CrossEntropyLoss +from transformers.cache_utils import Cache from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.models.codegen.modeling_codegen import ( CodeGenAttention, @@ -16,12 +17,13 @@ class GaudiCodeGenAttention(CodeGenAttention): def forward( self, hidden_states: Optional[torch.FloatTensor], - layer_past: Optional[Tuple[torch.Tensor]] = None, + layer_past: Optional[Cache] = None, attention_mask: Optional[torch.FloatTensor] = None, position_ids: Optional[torch.LongTensor] = None, head_mask: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = False, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, token_idx: Optional[torch.Tensor] = None, ) -> Union[ Tuple[torch.Tensor, Tuple[torch.Tensor]], @@ -106,12 +108,13 @@ def forward( def gaudi_codegen_block_forward( self, hidden_states: Optional[torch.FloatTensor], - layer_past: Optional[Tuple[torch.Tensor]] = None, + layer_past: Optional[Cache] = None, attention_mask: Optional[torch.FloatTensor] = None, position_ids: Optional[torch.LongTensor] = None, head_mask: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = False, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, token_idx: Optional[torch.Tensor] = None, ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]: """ @@ -129,6 +132,7 @@ def gaudi_codegen_block_forward( head_mask=head_mask, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, token_idx=token_idx, ) attn_output = attn_outputs[0] # output_attn: a, present, (attentions) @@ -148,7 +152,7 @@ def gaudi_codegen_block_forward( def gaudi_codegen_model_forward( self, input_ids: 
Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor]]]] = None, attention_mask: Optional[torch.FloatTensor] = None, token_type_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, @@ -158,6 +162,7 @@ def gaudi_codegen_model_forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, token_idx: Optional[torch.Tensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: """ @@ -229,14 +234,16 @@ def gaudi_codegen_model_forward( if inputs_embeds is None: inputs_embeds = self.wte(input_ids) + seq_length = inputs_embeds.shape[1] + hidden_states = inputs_embeds if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, seq_length) token_type_embeds = self.wte(token_type_ids) hidden_states = hidden_states + token_type_embeds hidden_states = self.drop(hidden_states) - output_shape = input_shape + (hidden_states.size(-1),) if self.gradient_checkpointing and self.training: @@ -264,6 +271,7 @@ def gaudi_codegen_model_forward( head_mask[i], use_cache, output_attentions, + cache_position, None, ) else: @@ -275,6 +283,7 @@ def gaudi_codegen_model_forward( head_mask=head_mask[i], use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, token_idx=token_idx, ) @@ -314,7 +323,17 @@ class GaudiCodeGenForCausalLM(CodeGenForCausalLM): """ def prepare_inputs_for_generation( - self, input_ids, inputs_embeds=None, past_key_values=None, token_idx=None, **kwargs + self, + input_ids, + attention_mask=None, + token_type_ids=None, + position_ids=None, + past_key_values=None, + inputs_embeds=None, + cache_position=None, + use_cache=True, + token_idx=None, + **kwargs, ): token_type_ids = kwargs.get("token_type_ids", None) # Omit tokens covered by past_key_values @@ -330,9 +349,6 @@ def prepare_inputs_for_generation( if token_type_ids is not None: token_type_ids = token_type_ids[:, -1] - attention_mask = kwargs.get("attention_mask", None) - position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 @@ -343,17 +359,21 @@ def prepare_inputs_for_generation( else: position_ids = position_ids[:, -1] + # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. 
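A few lines above, `position_ids` are rebuilt on the fly from the attention mask; the same `cumsum` construction recurs in the other `prepare_inputs_for_generation` overrides in this patch. A small sketch with made-up mask values: padded slots (mask == 0) end up filled with 1, while real tokens get 0, 1, 2, ... regardless of left padding.

    import torch

    attention_mask = torch.tensor([[0, 0, 1, 1, 1],   # left-padded sequence
                                   [1, 1, 1, 1, 1]])  # full-length sequence
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)
    # tensor([[1, 1, 0, 1, 2],
    #         [0, 1, 2, 3, 4]])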
+ position_ids = position_ids.clone(memory_format=torch.contiguous_format) + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} else: - model_inputs = {"input_ids": input_ids.contiguous()} + model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format)} model_inputs.update( { - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), "position_ids": position_ids, + "cache_position": cache_position, + "past_key_values": past_key_values, + "use_cache": use_cache, "attention_mask": attention_mask, "token_type_ids": token_type_ids, "token_idx": token_idx, @@ -364,7 +384,7 @@ def prepare_inputs_for_generation( def forward( self, input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor]]]] = None, attention_mask: Optional[torch.FloatTensor] = None, token_type_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, @@ -375,6 +395,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, token_idx: Optional[torch.Tensor] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" @@ -397,6 +418,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, token_idx=token_idx, ) hidden_states = transformer_outputs[0] diff --git a/optimum/habana/transformers/models/falcon/modeling_falcon.py b/optimum/habana/transformers/models/falcon/modeling_falcon.py index 398266254d..0277668422 100644 --- a/optimum/habana/transformers/models/falcon/modeling_falcon.py +++ b/optimum/habana/transformers/models/falcon/modeling_falcon.py @@ -30,6 +30,7 @@ from torch import nn from torch.nn import CrossEntropyLoss from torch.nn import functional as F +from transformers.cache_utils import Cache from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask_for_sdpa from transformers.modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, @@ -253,7 +254,7 @@ class GaudiFalconAttention(FalconAttention): 4. not use_flash_attention, bf16: F.scaled_dot_product_attention. 
Slowest option """ - def __init__(self, config: FalconConfig): + def __init__(self, config: FalconConfig, layer_idx=None): super().__init__(config) self.is_fp8 = os.getenv("QUANT_CONFIG", "") != "" @@ -337,10 +338,12 @@ def pre_attn_forward( alibi: Optional[torch.Tensor], attention_mask: torch.Tensor, position_ids: Optional[torch.LongTensor] = None, - layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + layer_past: Optional[Cache] = None, head_mask: Optional[torch.Tensor] = None, use_cache: bool = False, output_attentions: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 token_idx: Optional[torch.Tensor] = None, reuse_cache: Optional[bool] = False, cache_idx: int = None, @@ -609,9 +612,9 @@ class GaudiFalconDecoderLayer(FalconDecoderLayer): - add new arg flash_attention_causal_mask """ - def __init__(self, config: FalconConfig): + def __init__(self, config: FalconConfig, layer_idx=None): super().__init__(config) - self.self_attention = GaudiFalconAttention(config) + self.self_attention = GaudiFalconAttention(config, layer_idx) def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): self.self_attention.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) @@ -625,10 +628,12 @@ def forward( alibi: Optional[torch.Tensor], attention_mask: torch.Tensor, position_ids: Optional[torch.LongTensor] = None, - layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + layer_past: Optional[Union[Cache, Tuple[torch.Tensor, torch.Tensor]]] = None, head_mask: Optional[torch.Tensor] = None, use_cache: bool = False, output_attentions: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 token_idx: Optional[torch.Tensor] = None, reuse_cache: Optional[bool] = False, cache_idx: int = None, @@ -654,6 +659,8 @@ def forward( head_mask=head_mask, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, + position_embeddings=position_embeddings, token_idx=token_idx, reuse_cache=reuse_cache, cache_idx=cache_idx, @@ -711,6 +718,8 @@ def pre_attn( head_mask: Optional[torch.Tensor] = None, use_cache: bool = False, output_attentions: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 token_idx: Optional[torch.Tensor] = None, reuse_cache: Optional[bool] = False, cache_idx: int = None, @@ -735,6 +744,8 @@ def pre_attn( head_mask=head_mask, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, + position_embeddings=position_embeddings, token_idx=token_idx, reuse_cache=reuse_cache, cache_idx=cache_idx, @@ -769,7 +780,7 @@ def update_sincos_cache(self, seq_len): def forward( self, input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, head_mask: Optional[torch.LongTensor] = None, @@ -778,6 +789,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, token_idx: 
Optional[torch.Tensor] = None, reuse_cache: Optional[bool] = False, cache_idx: int = None, @@ -898,6 +910,8 @@ def forward( # head_mask has shape n_layer x batch x num_heads x N x N head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + position_embeddings = None + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -913,6 +927,8 @@ def forward( layer_past, use_cache, output_attentions, + cache_position, + position_embeddings, None, use_flash_attention, flash_attention_recompute, @@ -928,6 +944,8 @@ def forward( use_cache=use_cache, output_attentions=output_attentions, alibi=alibi, + cache_position=cache_position, + position_embeddings=position_embeddings, token_idx=token_idx, reuse_cache=reuse_cache, cache_idx=cache_idx, @@ -984,10 +1002,12 @@ def update_sincos_cache(self, seq_len): def prepare_inputs_for_generation( self, input_ids: torch.LongTensor, - past_key_values: Optional[torch.Tensor] = None, + past_key_values: Optional[Union[Cache, torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, + cache_position: Optional[torch.LongTensor] = None, + use_cache: bool = True, token_idx: Optional[torch.Tensor] = None, **kwargs, ) -> dict: @@ -1030,16 +1050,20 @@ def prepare_inputs_for_generation( else: position_ids = position_ids[:, -input_ids.shape[1] :] + # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. 
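The widened `past_key_values` annotations above (`Union[Cache, Tuple[...]]`) admit both the legacy tuple-of-tuples cache and the `Cache` objects from `transformers.cache_utils`. A hedged sketch of the two forms, assuming the `DynamicCache` conversion helpers behave as in recent Transformers releases:

    import torch
    from transformers.cache_utils import DynamicCache

    # Legacy layout: one (key, value) pair per layer, each [batch, num_heads, seq_len, head_dim]
    legacy = ((torch.zeros(1, 4, 3, 8), torch.zeros(1, 4, 3, 8)),)

    cache = DynamicCache.from_legacy_cache(legacy)   # wrap the tuples into a Cache object
    roundtrip = cache.to_legacy_cache()              # and convert back to the tuple form
    assert roundtrip[0][0].shape == legacy[0][0].shape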
+ position_ids = position_ids.clone(memory_format=torch.contiguous_format) + if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} else: - model_inputs = {"input_ids": input_ids} + model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format)} model_inputs.update( { "position_ids": position_ids, + "cache_position": cache_position, "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), + "use_cache": use_cache, "attention_mask": attention_mask, "token_idx": token_idx, "reuse_cache": reuse_cache, @@ -1054,7 +1078,7 @@ def prepare_inputs_for_generation( def forward( self, input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, head_mask: Optional[torch.Tensor] = None, @@ -1064,6 +1088,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, token_idx: Optional[torch.Tensor] = None, reuse_cache: Optional[bool] = False, trim_logits: Optional[bool] = False, @@ -1097,6 +1122,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, token_idx=token_idx, reuse_cache=reuse_cache, cache_idx=cache_idx, diff --git a/optimum/habana/transformers/models/gemma/modeling_gemma.py b/optimum/habana/transformers/models/gemma/modeling_gemma.py index cf47a1df2e..9f82fae821 100755 --- a/optimum/habana/transformers/models/gemma/modeling_gemma.py +++ b/optimum/habana/transformers/models/gemma/modeling_gemma.py @@ -36,7 +36,7 @@ GemmaModel, apply_rotary_pos_emb, ) -from transformers.utils import logging +from transformers.utils import is_torchdynamo_compiling, logging from ...modeling_attn_mask_utils import ( _gaudi_prepare_4d_causal_attention_mask, @@ -710,6 +710,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, token_idx: Optional[torch.Tensor] = None, use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, @@ -745,10 +746,18 @@ def forward( ) hidden_states = outputs[0] - logits = self.lm_head(hidden_states) - logits = logits.float() + if labels is None and not is_torchdynamo_compiling(): + logger.warning_once( + "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + ) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + # TODO: remove the float() operation in v4.46 + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() + loss = None if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() @@ -781,6 +790,7 @@ def prepare_inputs_for_generation( cache_position=None, position_ids=None, use_cache=True, + num_logits_to_keep=0, **kwargs, ): """ @@ -816,6 +826,8 @@ def prepare_inputs_for_generation( position_ids = 
torch.index_select(position_ids, 1, token_idx - 1) else: position_ids = position_ids[:, -input_ids.shape[1] :] + # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. + position_ids = position_ids.clone(memory_format=torch.contiguous_format) if token_idx is None: if past_key_value := getattr(self.model.layers[0].self_attn, "past_key_value", None): @@ -828,7 +840,7 @@ def prepare_inputs_for_generation( if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} else: - model_inputs = {"input_ids": input_ids.contiguous()} + model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format)} model_inputs.update( { @@ -837,6 +849,7 @@ def prepare_inputs_for_generation( "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, + "num_logits_to_keep": num_logits_to_keep, "token_idx": token_idx, } ) diff --git a/optimum/habana/transformers/models/gpt_neox/modeling_gpt_neox.py b/optimum/habana/transformers/models/gpt_neox/modeling_gpt_neox.py index 330f36f5b6..66b2a51dfd 100644 --- a/optimum/habana/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/optimum/habana/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -2,6 +2,7 @@ import torch from torch.nn import CrossEntropyLoss +from transformers.cache_utils import Cache from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.models.gpt_neox.modeling_gpt_neox import ( GPTNeoXAttention, @@ -29,9 +30,11 @@ def gaudi_gpt_neox_attention_forward( attention_mask: torch.FloatTensor, position_ids: torch.LongTensor, head_mask: Optional[torch.FloatTensor] = None, - layer_past: Optional[Tuple[torch.Tensor]] = None, + layer_past: Optional[Cache] = None, use_cache: Optional[bool] = False, output_attentions: Optional[bool] = False, + padding_mask: Optional[torch.Tensor] = None, + cache_position: Optional[torch.LongTensor] = None, token_idx: Optional[torch.Tensor] = None, ): """ @@ -103,14 +106,14 @@ def gaudi_gpt_neox_attention_forward( class GaudiGPTNeoXLayer(GPTNeoXLayer): - def __init__(self, config): + def __init__(self, config, layer_idx): super(GPTNeoXLayer, self).__init__() self.use_parallel_residual = config.use_parallel_residual self.input_layernorm = torch.nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.post_attention_layernorm = torch.nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.post_attention_dropout = torch.nn.Dropout(config.hidden_dropout) self.post_mlp_dropout = torch.nn.Dropout(config.hidden_dropout) - self.attention = GPTNeoXAttention(config) + self.attention = GPTNeoXAttention(config, layer_idx) self.mlp = GPTNeoXMLP(config) def forward( @@ -120,8 +123,9 @@ def forward( position_ids: Optional[torch.LongTensor] = None, head_mask: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = False, - layer_past: Optional[Tuple[torch.Tensor]] = None, + layer_past: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, token_idx: Optional[torch.Tensor] = None, ): """ @@ -137,6 +141,7 @@ def forward( head_mask=head_mask, use_cache=use_cache, output_attentions=output_attentions, + 
cache_position=cache_position, token_idx=token_idx, ) attn_output = attention_layer_outputs[0] # output_attn: attn_output, present, (attn_weights) @@ -173,11 +178,12 @@ def gaudi_gpt_neox_model_forward( position_ids: Optional[torch.LongTensor] = None, head_mask: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.FloatTensor]]]] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, token_idx: Optional[torch.Tensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: """ @@ -260,6 +266,7 @@ def gaudi_gpt_neox_model_forward( use_cache, None, output_attentions, + cache_position, None, ) else: @@ -271,6 +278,7 @@ def gaudi_gpt_neox_model_forward( layer_past=layer_past, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, token_idx=token_idx, ) hidden_states = outputs[0] @@ -322,12 +330,13 @@ def forward( position_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.FloatTensor]]]] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, token_idx: Optional[torch.Tensor] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -343,6 +352,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, token_idx=token_idx, ) @@ -372,7 +382,16 @@ def forward( ) def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, token_idx=None, **kwargs + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + token_idx=None, + **kwargs, ): input_shape = input_ids.shape @@ -393,7 +412,6 @@ def prepare_inputs_for_generation( input_ids = input_ids[:, remove_prefix_length:] - position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 @@ -403,6 +421,8 @@ def prepare_inputs_for_generation( position_ids = torch.index_select(position_ids, 1, token_idx - 1) else: position_ids = position_ids[:, -input_ids.shape[1] :] + # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. 
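The `token_idx`-based selection shown above can be illustrated with a small sketch (values invented). Instead of relying on the trailing slice of a growing tensor, `torch.index_select` picks the column at `token_idx - 1` out of a buffer padded to a fixed maximum length, which is presumably what keeps shapes static for HPU graph execution.

    import torch

    position_ids = torch.arange(16).unsqueeze(0)   # [1, 16], padded to a fixed max length
    token_idx = torch.tensor([5])                  # decoding is currently writing position 5

    current = torch.index_select(position_ids, 1, token_idx - 1)
    # shape [1, 1], value 4 -- the active position; position_ids[:, -1] would land on padding here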
+ position_ids = position_ids.clone(memory_format=torch.contiguous_format) # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly if attention_mask is None: @@ -412,13 +432,15 @@ def prepare_inputs_for_generation( if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} else: - model_inputs = {"input_ids": input_ids} + model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format)} + model_inputs.update( { - "attention_mask": attention_mask, - "past_key_values": past_key_values, "position_ids": position_ids, - "use_cache": kwargs.get("use_cache"), + "cache_position": cache_position, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, "token_idx": token_idx, } ) diff --git a/optimum/habana/transformers/models/gptj/modeling_gptj.py b/optimum/habana/transformers/models/gptj/modeling_gptj.py index 8fbb3441a4..3927e1feb9 100644 --- a/optimum/habana/transformers/models/gptj/modeling_gptj.py +++ b/optimum/habana/transformers/models/gptj/modeling_gptj.py @@ -4,6 +4,7 @@ import torch from torch import nn from torch.nn import CrossEntropyLoss +from transformers.cache_utils import Cache from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.models.gptj.configuration_gptj import GPTJConfig from transformers.models.gptj.modeling_gptj import ( @@ -68,10 +69,18 @@ def forward(self, cur, dim, idx): class GaudiGPTJAttention(GPTJAttention): - def __init__(self, config: GPTJConfig): + def __init__(self, config: GPTJConfig, layer_idx=None): super().__init__(config) self.config = config + max_positions = config.max_position_embeddings + self.register_buffer( + "bias", + torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view( + 1, 1, max_positions, max_positions + ), + persistent=False, + ) self.matmul_qk = Matmul() self.matmul_av = Matmul() self.k_cache = KVCache() @@ -155,12 +164,13 @@ def _attn( def forward( self, hidden_states: torch.FloatTensor, - layer_past: Optional[Tuple[torch.Tensor]] = None, + layer_past: Optional[Cache] = None, attention_mask: Optional[torch.FloatTensor] = None, position_ids: Optional[torch.LongTensor] = None, head_mask: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = False, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, token_idx: Optional[torch.Tensor] = None, sin: Optional[torch.Tensor] = None, cos: Optional[torch.Tensor] = None, @@ -265,11 +275,11 @@ class GaudiGPTJBlock(GPTJBlock): Inherits from GPTJBlock: https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/gptj/modeling_gptj.py#291 """ - def __init__(self, config: GPTJConfig): - super().__init__(config) + def __init__(self, config: GPTJConfig, layer_idx=None): + super().__init__(config, layer_idx=None) inner_dim = config.n_inner if config.n_inner is not None else 4 * config.n_embd self.ln_1 = torch.nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) - self.attn = GaudiGPTJAttention(config) + self.attn = GaudiGPTJAttention(config, layer_idx) self.mlp = GPTJMLP(inner_dim, config) def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): @@ -284,12 +294,13 @@ def update_sincos_cache(self, seq_len): def forward( self, hidden_states: Optional[torch.FloatTensor], - layer_past: Optional[Tuple[torch.Tensor]] = None, + layer_past: Optional[Cache] = None, attention_mask: 
Optional[torch.FloatTensor] = None, position_ids: Optional[torch.LongTensor] = None, head_mask: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = False, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, token_idx: Optional[torch.Tensor] = None, sin: Optional[torch.Tensor] = None, cos: Optional[torch.Tensor] = None, @@ -312,6 +323,7 @@ def forward( head_mask=head_mask, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, token_idx=token_idx, reuse_cache=reuse_cache, cache_idx=cache_idx, @@ -351,7 +363,7 @@ def update_sincos_cache(self, seq_len): def forward( self, input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor]]]] = None, attention_mask: Optional[torch.FloatTensor] = None, token_type_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, @@ -361,6 +373,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, token_idx: Optional[torch.Tensor] = None, reuse_cache: Optional[bool] = False, cache_idx: Optional[int] = None, @@ -489,6 +502,7 @@ def forward( head_mask[i], use_cache, output_attentions, + cache_position, None, sin, cos, @@ -502,6 +516,7 @@ def forward( head_mask=head_mask[i], use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, token_idx=token_idx, reuse_cache=reuse_cache, cache_idx=cache_idx, @@ -555,11 +570,19 @@ def update_sincos_cache(self, seq_len): self.transformer.update_sincos_cache(seq_len) def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, inputs_embeds=None, token_idx=None, **kwargs + self, + input_ids, + attention_mask=None, + token_type_ids=None, + position_ids=None, + past_key_values=None, + inputs_embeds=None, + cache_position=None, + use_cache=True, + token_idx=None, + **kwargs, ): reuse_cache = kwargs.get("reuse_cache") - token_type_ids = kwargs.get("token_type_ids", None) - attention_mask = kwargs.get("attention_mask", None) # Omit tokens covered by past_key_values if past_key_values: if token_idx is not None: @@ -587,8 +610,6 @@ def prepare_inputs_for_generation( input_ids = input_ids[:, :token_idx] attention_mask = attention_mask[:, :token_idx] - position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 @@ -598,18 +619,21 @@ def prepare_inputs_for_generation( position_ids = torch.index_select(position_ids, 1, token_idx - 1) else: position_ids = position_ids[:, -input_ids.shape[1] :] + # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. 
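The `bias` buffer re-registered in `GaudiGPTJAttention.__init__` above is the usual lower-triangular causal mask; a short sketch of its shape and meaning with a toy `max_positions`:

    import torch

    max_positions = 6
    bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
        1, 1, max_positions, max_positions
    )
    # bias[0, 0, i, j] is True only for j <= i: a query token may attend to itself
    # and to earlier key positions, never to future ones.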
+ position_ids = position_ids.clone(memory_format=torch.contiguous_format) # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} else: - model_inputs = {"input_ids": input_ids} + model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format)} model_inputs.update( { - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), "position_ids": position_ids, + "cache_position": cache_position, + "past_key_values": past_key_values, + "use_cache": use_cache, "attention_mask": attention_mask, "token_type_ids": token_type_ids, "token_idx": token_idx, @@ -623,7 +647,7 @@ def prepare_inputs_for_generation( def forward( self, input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor]]]] = None, attention_mask: Optional[torch.FloatTensor] = None, token_type_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, @@ -634,6 +658,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, token_idx: Optional[torch.Tensor] = None, reuse_cache: Optional[bool] = False, cache_idx: Optional[int] = None, @@ -658,6 +683,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, token_idx=token_idx, reuse_cache=reuse_cache, cache_idx=cache_idx, diff --git a/optimum/habana/transformers/models/llama/configuration_llama.py b/optimum/habana/transformers/models/llama/configuration_llama.py index ce754dadb5..fb159cfc48 100644 --- a/optimum/habana/transformers/models/llama/configuration_llama.py +++ b/optimum/habana/transformers/models/llama/configuration_llama.py @@ -25,6 +25,7 @@ def __init__( attention_bias=False, attention_dropout=0.0, mlp_bias=False, + head_dim=None, fused_qkv=False, parallel_strategy=None, **kwargs, diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index c4a8dc6c85..f59b048684 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -20,6 +20,7 @@ apply_rotary_pos_emb, logger, ) +from transformers.utils import is_torchdynamo_compiling from .... import distributed from ....distributed.strategy import DistributedStrategy, NoOpStrategy @@ -99,7 +100,7 @@ def __init__( if config is None: logger.warning_once( "`LlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the " - "`config` argument. All other arguments will be removed in v4.45" + "`config` argument. All other arguments will be removed in v4.46" ) self.rope_kwargs = { "rope_type": rope_type, @@ -185,7 +186,7 @@ def forward(self, x, seq_len=None): class GaudiLlamaLinearScalingRotaryEmbedding(GaudiLlamaRotaryEmbedding): def __init__(self, *args, **kwargs): logger.warning_once( - "`LlamaLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.45. Please use " + "`LlamaLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use " "`LlamaRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)." 
) kwargs["rope_type"] = "linear" @@ -206,7 +207,7 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): class GaudiLlamaDynamicNTKScalingRotaryEmbedding(GaudiLlamaRotaryEmbedding): def __init__(self, *args, **kwargs): logger.warning_once( - "`LlamaDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.45. Please use " + "`LlamaDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use " "`LlamaRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to " "__init__)." ) @@ -480,7 +481,7 @@ def pre_attn_forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 token_idx: Optional[torch.Tensor] = None, attn_softmax_bf16: Optional[bool] = False, reuse_cache: Optional[bool] = False, @@ -562,7 +563,7 @@ def pre_attn_forward( # logger.warning_once( # "The attention layers in this model are transitioning from computing the RoPE embeddings internally " # "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - # "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be " + # "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " # "removed and `position_embeddings` will be mandatory." # ) # cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) @@ -829,7 +830,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 token_idx: Optional[torch.Tensor] = None, attn_softmax_bf16: Optional[bool] = False, reuse_cache: Optional[bool] = False, @@ -1245,6 +1246,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, token_idx: Optional[torch.Tensor] = None, trim_logits: Optional[bool] = False, attn_softmax_bf16: Optional[bool] = False, @@ -1302,11 +1304,18 @@ def forward( logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] logits = torch.cat(logits, dim=-1) else: - logits = self.lm_head(hidden_states) - logits = logits.float() + if labels is None and not is_torchdynamo_compiling(): + logger.warning_once( + "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + ) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + # TODO: remove the float() operation in v4.46 + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() loss = None if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() @@ -1339,6 +1348,7 @@ def prepare_inputs_for_generation( cache_position=None, position_ids=None, use_cache=True, + 
num_logits_to_keep=None, token_idx=None, **kwargs, ): @@ -1370,6 +1380,8 @@ def prepare_inputs_for_generation( position_ids = torch.index_select(position_ids, 1, token_idx - 1) else: position_ids = position_ids[:, -input_ids.shape[1] :] + # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. + position_ids = position_ids.clone(memory_format=torch.contiguous_format) # keep cache_position implementation as None for HPU cache_position = None @@ -1378,7 +1390,10 @@ def prepare_inputs_for_generation( if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} else: - model_inputs = {"input_ids": input_ids.contiguous()} + model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format)} + + if num_logits_to_keep is not None: + model_inputs["num_logits_to_keep"] = num_logits_to_keep model_inputs.update( { diff --git a/optimum/habana/transformers/models/llava/modeling_llava.py b/optimum/habana/transformers/models/llava/modeling_llava.py index 8119f442c5..997c16d700 100644 --- a/optimum/habana/transformers/models/llava/modeling_llava.py +++ b/optimum/habana/transformers/models/llava/modeling_llava.py @@ -120,6 +120,8 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, token_idx: Optional[torch.Tensor] = None, image_offset: Optional[int] = None, tokens_pos: Optional[torch.LongTensor] = None, @@ -152,6 +154,7 @@ def forward( # 1. Extra the input embeddings inputs_embeds = self.get_input_embeddings()(input_ids) + image_features = None # 2. 
Merge text and images if pixel_values is not None and input_ids.shape[1] != 1: image_outputs = self.vision_tower( @@ -186,6 +189,9 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, + # TODO: from Transformers v4.45, `generate` sets `num_logits_to_keep` to 1 if not given, which we don't want here + # num_logits_to_keep=num_logits_to_keep, token_idx=token_idx + image_offset, use_flash_attention=use_flash_attention, flash_attention_recompute=flash_attention_recompute, @@ -210,6 +216,7 @@ def forward( past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, ) else: @@ -230,7 +237,15 @@ def forward( ) def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, inputs_embeds=None, pixel_values=None, attention_mask=None, **kwargs + self, + input_ids, + past_key_values=None, + inputs_embeds=None, + pixel_values=None, + attention_mask=None, + cache_position=None, + num_logits_to_keep=None, + **kwargs, ): """ Inherits from LlavaForConditionalGeneration: https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llava/modeling_llava.py @@ -301,6 +316,10 @@ def prepare_inputs_for_generation( model_inputs = {"input_ids": input_ids} use_flash_attention = kwargs.get("use_flash_attention", False) flash_attention_recompute = kwargs.get("flash_attention_recompute", False) + + if num_logits_to_keep is not None: + model_inputs["num_logits_to_keep"] = num_logits_to_keep + model_inputs.update( { "position_ids": position_ids, diff --git a/optimum/habana/transformers/models/llava_next/modeling_llava_next.py b/optimum/habana/transformers/models/llava_next/modeling_llava_next.py index 4670469e9e..6cf728d014 100644 --- a/optimum/habana/transformers/models/llava_next/modeling_llava_next.py +++ b/optimum/habana/transformers/models/llava_next/modeling_llava_next.py @@ -53,6 +53,8 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, token_idx: Optional[torch.Tensor] = None, use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, @@ -84,6 +86,9 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, + # TODO: from Transformers v4.45, `generate` sets `num_logits_to_keep` to 1 if not given, which we don't want here + # num_logits_to_keep=num_logits_to_keep, token_idx=token_idx + self.image_offset, use_flash_attention=use_flash_attention, flash_attention_recompute=flash_attention_recompute, @@ -142,6 +147,8 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, + num_logits_to_keep=num_logits_to_keep, ) # Copied from https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava_next/modeling_llava_next.py#L356 @@ -230,6 +237,8 @@ def prepare_inputs_for_generation( pixel_values=None, image_sizes=None, attention_mask=None, + cache_position=None, + num_logits_to_keep=None, **kwargs, ): """ @@ -247,6 +256,8 @@ def prepare_inputs_for_generation( pixel_values=pixel_values, image_sizes=image_sizes, attention_mask=attention_mask, + 
cache_position=cache_position, + num_logits_to_keep=num_logits_to_keep, **kwargs, ) else: @@ -386,6 +397,9 @@ def prepare_inputs_for_generation( else: model_inputs = {"input_ids": input_ids} + if num_logits_to_keep is not None: + model_inputs["num_logits_to_keep"] = num_logits_to_keep + model_inputs.update( { "position_ids": position_ids, diff --git a/optimum/habana/transformers/models/mamba/modeling_mamba.py b/optimum/habana/transformers/models/mamba/modeling_mamba.py index 3f13114557..e23ce65dd8 100644 --- a/optimum/habana/transformers/models/mamba/modeling_mamba.py +++ b/optimum/habana/transformers/models/mamba/modeling_mamba.py @@ -24,10 +24,18 @@ def gaudi_MambaForCausalLM_update_model_kwargs_for_generation( and model_kwargs["cache_position"] is not None ): model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens + + if "attention_mask" in model_kwargs: + attention_mask = model_kwargs["attention_mask"] + model_kwargs["attention_mask"] = torch.cat( + [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 + ) + if token_idx is not None: token_idx.add_(1) if "token_idx_cpu" in model_kwargs: model_kwargs["token_idx_cpu"] += 1 + return model_kwargs @@ -38,7 +46,7 @@ def gaudi_MambaForCausalLM_prepare_inputs_for_generation( use_cache=None, cache_params: Optional[MambaCache] = None, cache_position: Optional[torch.LongTensor] = None, - attention_mask=None, + attention_mask: Optional[torch.LongTensor] = None, **kwargs, ): token_idx = kwargs.get("token_idx", None) @@ -54,6 +62,10 @@ def gaudi_MambaForCausalLM_prepare_inputs_for_generation( ) if cache_position[0] > 0: input_ids = input_ids[:, -1].unsqueeze(-1) + + if attention_mask is not None: + attention_mask = None + else: # we initialize the `cache_position` to full size of `conv_states` at prefill stage # considering padding will be applied when input length is shorter, and truncation @@ -63,6 +75,8 @@ def gaudi_MambaForCausalLM_prepare_inputs_for_generation( else: idx = token_idx + kwargs.get("inputs_embeds_offset", 0) - 1 input_ids = torch.index_select(input_ids, 1, idx) + if attention_mask is not None: + attention_mask = None else: if token_idx is not None: input_ids = torch.index_select(input_ids, 1, torch.arange(token_idx_cpu, device=input_ids.device)) @@ -76,6 +90,7 @@ def gaudi_MambaForCausalLM_prepare_inputs_for_generation( "cache_params": cache_params, "use_cache": use_cache, "cache_position": cache_position, + "attention_mask": attention_mask, } ) return model_inputs diff --git a/optimum/habana/transformers/models/mistral/modeling_mistral.py b/optimum/habana/transformers/models/mistral/modeling_mistral.py index 1c5d108b72..822c9ea646 100644 --- a/optimum/habana/transformers/models/mistral/modeling_mistral.py +++ b/optimum/habana/transformers/models/mistral/modeling_mistral.py @@ -39,7 +39,7 @@ MistralRMSNorm, apply_rotary_pos_emb, ) -from transformers.utils import logging +from transformers.utils import is_torchdynamo_compiling, logging from ...modeling_attn_mask_utils import ( _gaudi_prepare_4d_causal_attention_mask, @@ -696,6 +696,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, token_idx: Optional[torch.Tensor] = None, reuse_cache: Optional[bool] = False, trim_logits: Optional[bool] = False, @@ -750,11 +751,18 @@ def forward( hidden_states = hidden_states.index_select(1, token_idx - 1) else: hidden_states = hidden_states[:, -1, :] - logits = 
self.lm_head(hidden_states) - logits = logits.float() + if labels is None and not is_torchdynamo_compiling(): + logger.warning_once( + "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + ) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + # TODO: remove the float() operation in v4.46 + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() loss = None if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() @@ -787,6 +795,7 @@ def prepare_inputs_for_generation( cache_position=None, position_ids=None, use_cache=True, + num_logits_to_keep=None, **kwargs, ): """ @@ -826,6 +835,8 @@ def prepare_inputs_for_generation( position_ids = torch.index_select(position_ids, 1, token_idx - 1) else: position_ids = position_ids[:, -input_ids.shape[1] :] + # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. + position_ids = position_ids.clone(memory_format=torch.contiguous_format) # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: @@ -833,6 +844,9 @@ def prepare_inputs_for_generation( else: model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases + if num_logits_to_keep is not None: + model_inputs["num_logits_to_keep"] = num_logits_to_keep + model_inputs.update( { "position_ids": position_ids, diff --git a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py index cfbd4dc3b0..208f4099a5 100644 --- a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py +++ b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py @@ -44,7 +44,7 @@ apply_rotary_pos_emb, load_balancing_loss_func, ) -from transformers.utils import logging +from transformers.utils import is_torchdynamo_compiling, logging from ..llama.modeling_llama import ( GaudiLlamaDynamicNTKScalingRotaryEmbedding, @@ -745,6 +745,7 @@ def forward( output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, token_idx: Optional[torch.Tensor] = None, reuse_cache: Optional[bool] = None, flash_attention_recompute: Optional[bool] = False, @@ -780,11 +781,18 @@ def forward( ) hidden_states = outputs[0] - logits = self.lm_head(hidden_states) - logits = logits.float() + if labels is None and not is_torchdynamo_compiling(): + logger.warning_once( + "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + ) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + # TODO: remove the float() operation in v4.46 + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() loss = None if labels is not None: + # Upcast to float if we need 
to compute the loss to avoid potential precision issues + logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() @@ -833,6 +841,7 @@ def prepare_inputs_for_generation( output_router_logits=False, position_ids=None, use_cache=True, + num_logits_to_keep=None, **kwargs, ): reuse_cache = kwargs.get("reuse_cache") @@ -871,6 +880,9 @@ def prepare_inputs_for_generation( else: model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases + if num_logits_to_keep is not None: + model_inputs["num_logits_to_keep"] = num_logits_to_keep + model_inputs.update( { "position_ids": position_ids, @@ -878,6 +890,7 @@ def prepare_inputs_for_generation( "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, + "output_router_logits": output_router_logits, "token_idx": token_idx, "reuse_cache": reuse_cache, "flash_attention_recompute": kwargs.get("flash_attention_recompute"), diff --git a/optimum/habana/transformers/models/modeling_all_models.py b/optimum/habana/transformers/models/modeling_all_models.py index c9eb95524e..90aa2d5e0f 100644 --- a/optimum/habana/transformers/models/modeling_all_models.py +++ b/optimum/habana/transformers/models/modeling_all_models.py @@ -115,7 +115,7 @@ def gaudi_conv1d_forward(self, x): @classmethod def gaudi_check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> PretrainedConfig: # This model doesn't support SDPA in Gaudi yet, fallback to original code. - MODELS_ATTN_IMPLEMENTATION_EAGER = ["bart", "gpt_bigcode", "mistral", "mixtral", "wav2vec2"] + MODELS_ATTN_IMPLEMENTATION_EAGER = ["bart", "gpt_bigcode", "mistral", "mixtral", "wav2vec2", "roberta"] if config.model_type in MODELS_ATTN_IMPLEMENTATION_EAGER: config._attn_implementation = "eager" diff --git a/optimum/habana/transformers/models/persimmon/modeling_persimmon.py b/optimum/habana/transformers/models/persimmon/modeling_persimmon.py index 936f130762..c1fb019d66 100644 --- a/optimum/habana/transformers/models/persimmon/modeling_persimmon.py +++ b/optimum/habana/transformers/models/persimmon/modeling_persimmon.py @@ -68,12 +68,12 @@ def gaudi_persimmon_attention_forward( # Partial rotary embedding query_rot, query_pass = ( - query_states[..., : self.rotary_emb.dim], - query_states[..., self.rotary_emb.dim :], + query_states[..., : self.rotary_ndims], + query_states[..., self.rotary_ndims :], ) key_rot, key_pass = ( - key_states[..., : self.rotary_emb.dim], - key_states[..., self.rotary_emb.dim :], + key_states[..., : self.rotary_ndims], + key_states[..., self.rotary_ndims :], ) # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor] query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids) @@ -97,7 +97,7 @@ def gaudi_persimmon_attention_forward( cache_kwargs = { "sin": sin, "cos": cos, - "partial_rotation_size": self.rotary_emb.dim, + "partial_rotation_size": self.rotary_ndims, "cache_position": cache_position, } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) @@ -339,6 +339,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, token_idx: Optional[torch.Tensor] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: """ @@ -369,7 +370,8 @@ def forward( ) hidden_states = outputs[0] - logits = 
self.lm_head(hidden_states) + # No upscaling to float was ever done for Persimmon + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) loss = None if labels is not None: @@ -405,6 +407,7 @@ def prepare_inputs_for_generation( cache_position=None, position_ids=None, use_cache=True, + num_logits_to_keep=None, **kwargs, ): """ @@ -437,12 +440,19 @@ def prepare_inputs_for_generation( position_ids = torch.index_select(position_ids, 1, token_idx - 1) else: position_ids = position_ids[:, -input_ids.shape[1] :] + # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. + position_ids = position_ids.clone(memory_format=torch.contiguous_format) # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} else: - model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases + model_inputs = { + "input_ids": input_ids.clone(memory_format=torch.contiguous_format) + } # `contiguous()` needed for compilation use cases + + if num_logits_to_keep is not None: + model_inputs["num_logits_to_keep"] = num_logits_to_keep model_inputs.update( { diff --git a/optimum/habana/transformers/models/phi/modeling_phi.py b/optimum/habana/transformers/models/phi/modeling_phi.py index b04ac5af1e..53a4b1f73a 100644 --- a/optimum/habana/transformers/models/phi/modeling_phi.py +++ b/optimum/habana/transformers/models/phi/modeling_phi.py @@ -35,7 +35,7 @@ PhiModel, apply_rotary_pos_emb, ) -from transformers.utils import logging +from transformers.utils import is_torchdynamo_compiling, logging from ...modeling_attn_mask_utils import ( _gaudi_prepare_4d_causal_attention_mask, @@ -202,12 +202,12 @@ def forward( # Partial rotary embedding query_rot, query_pass = ( - query_states[..., : self.rotary_emb.dim], - query_states[..., self.rotary_emb.dim :], + query_states[..., : self.rotary_ndims], + query_states[..., self.rotary_ndims :], ) key_rot, key_pass = ( - key_states[..., : self.rotary_emb.dim], - key_states[..., self.rotary_emb.dim :], + key_states[..., : self.rotary_ndims], + key_states[..., self.rotary_ndims :], ) # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor] query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids) @@ -532,6 +532,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, token_idx: Optional[torch.Tensor] = None, reuse_cache: Optional[bool] = False, trim_logits: Optional[bool] = False, @@ -575,11 +576,18 @@ def forward( hidden_states = hidden_states.index_select(1, token_idx - 1) else: hidden_states = hidden_states[:, -1, :] - logits = self.lm_head(hidden_states) - logits = logits.float() + if labels is None and not is_torchdynamo_compiling(): + logger.warning_once( + "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + ) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + # TODO: remove the float() operation 
in v4.46 + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() loss = None if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() @@ -612,6 +620,7 @@ def prepare_inputs_for_generation( cache_position=None, position_ids=None, use_cache=True, + num_logits_to_keep=None, token_idx=None, **kwargs, ): @@ -650,12 +659,19 @@ def prepare_inputs_for_generation( position_ids = torch.index_select(position_ids, 1, token_idx - 1) else: position_ids = position_ids[:, -input_ids.shape[1] :] + # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. + position_ids = position_ids.clone(memory_format=torch.contiguous_format) # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} else: - model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases + model_inputs = { + "input_ids": input_ids.clone(memory_format=torch.contiguous_format) + } # `contiguous()` needed for compilation use cases + + if num_logits_to_keep is not None: + model_inputs["num_logits_to_keep"] = num_logits_to_keep model_inputs.update( { @@ -670,4 +686,5 @@ def prepare_inputs_for_generation( "cache_idx": kwargs.get("cache_idx"), } ) + return model_inputs diff --git a/optimum/habana/transformers/models/qwen2/modeling_qwen2.py b/optimum/habana/transformers/models/qwen2/modeling_qwen2.py index 67f969ea37..e25f59c16d 100644 --- a/optimum/habana/transformers/models/qwen2/modeling_qwen2.py +++ b/optimum/habana/transformers/models/qwen2/modeling_qwen2.py @@ -35,6 +35,7 @@ apply_rotary_pos_emb, logger, ) +from transformers.utils import is_torchdynamo_compiling from ...modeling_attn_mask_utils import ( _gaudi_prepare_4d_causal_attention_mask, @@ -763,6 +764,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, token_idx: Optional[torch.Tensor] = None, trim_logits: Optional[bool] = False, attn_softmax_bf16: Optional[bool] = False, @@ -816,10 +818,18 @@ def forward( else: hidden_states = hidden_states[:, -1, :] - logits = self.lm_head(hidden_states).float() + if labels is None and not is_torchdynamo_compiling(): + logger.warning_once( + "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + ) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + # TODO: remove the float() operation in v4.46 + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() loss = None if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() @@ -852,6 +862,7 @@ def prepare_inputs_for_generation( 
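
Illustrative sketch, not part of the diff above: the `num_logits_to_keep` convention used in these forward passes relies on Python slicing, where `-0` is `0`, so the default of 0 keeps every position while 1 keeps only the last one during decoding. Dimensions below are made up for the example.

```python
import torch

hidden_states = torch.randn(2, 16, 64)            # (batch, seq_len, hidden)
lm_head = torch.nn.Linear(64, 32000, bias=False)  # hypothetical vocab size

for num_logits_to_keep in (0, 1):
    logits = lm_head(hidden_states[:, -num_logits_to_keep:, :])
    print(num_logits_to_keep, tuple(logits.shape))
# 0 -> (2, 16, 32000): "-0:" slices from 0, i.e. the full sequence
# 1 -> (2, 1, 32000): only the last position, enough for decoding
```

Note that the hunks above still call `.float()` unconditionally and only carry a TODO to drop it in v4.46; the warning is emitted so callers can prepare for logits keeping the model dtype when no labels are passed.
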
cache_position=None, position_ids=None, use_cache=True, + num_logits_to_keep=None, token_idx=None, **kwargs, ): @@ -883,6 +894,8 @@ def prepare_inputs_for_generation( position_ids = torch.index_select(position_ids, 1, token_idx - 1) else: position_ids = position_ids[:, -input_ids.shape[1] :] + # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. + position_ids = position_ids.clone(memory_format=torch.contiguous_format) cache_position = None @@ -890,7 +903,12 @@ def prepare_inputs_for_generation( if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} else: - model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases + model_inputs = { + "input_ids": input_ids.clone(memory_format=torch.contiguous_format) + } # `contiguous()` needed for compilation use cases + + if num_logits_to_keep is not None: + model_inputs["num_logits_to_keep"] = num_logits_to_keep model_inputs.update( { diff --git a/optimum/habana/transformers/models/stablelm/modeling_stablelm.py b/optimum/habana/transformers/models/stablelm/modeling_stablelm.py index 8acfd1ece1..22eca3c9da 100644 --- a/optimum/habana/transformers/models/stablelm/modeling_stablelm.py +++ b/optimum/habana/transformers/models/stablelm/modeling_stablelm.py @@ -72,12 +72,12 @@ def gaudi_stablelm_attention_forward( # Partial rotary embedding query_rot, query_pass = ( - query_states[..., : self.rotary_emb.dim], - query_states[..., self.rotary_emb.dim :], + query_states[..., : self.rotary_ndims], + query_states[..., self.rotary_ndims :], ) key_rot, key_pass = ( - key_states[..., : self.rotary_emb.dim], - key_states[..., self.rotary_emb.dim :], + key_states[..., : self.rotary_ndims], + key_states[..., self.rotary_ndims :], ) # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor] query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids) @@ -101,7 +101,7 @@ def gaudi_stablelm_attention_forward( cache_kwargs = { "sin": sin, "cos": cos, - "partial_rotation_size": self.rotary_emb.dim, + "partial_rotation_size": self.rotary_ndims, "cache_position": cache_position, } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) @@ -370,6 +370,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, token_idx: Optional[torch.Tensor] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: """ @@ -398,7 +399,8 @@ def forward( ) hidden_states = outputs[0] - logits = self.lm_head(hidden_states) + # No upscaling to float was ever done for StableLm + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) loss = None if labels is not None: @@ -434,6 +436,7 @@ def prepare_inputs_for_generation( cache_position=None, position_ids=None, use_cache=True, + num_logits_to_keep=None, **kwargs, ): """ @@ -466,12 +469,19 @@ def prepare_inputs_for_generation( position_ids = torch.index_select(position_ids, 1, token_idx - 1) else: position_ids = position_ids[:, -input_ids.shape[1] :] + # This `clone` call is needed to avoid recapturing cuda graphs with 
`torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. + position_ids = position_ids.clone(memory_format=torch.contiguous_format) # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} else: - model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases + model_inputs = { + "input_ids": input_ids.clone(memory_format=torch.contiguous_format) + } # `contiguous()` needed for compilation use cases + + if num_logits_to_keep is not None: + model_inputs["num_logits_to_keep"] = num_logits_to_keep model_inputs.update( { diff --git a/optimum/habana/transformers/models/starcoder2/modeling_starcoder2.py b/optimum/habana/transformers/models/starcoder2/modeling_starcoder2.py index f15b006d6b..0fe6ea8c51 100644 --- a/optimum/habana/transformers/models/starcoder2/modeling_starcoder2.py +++ b/optimum/habana/transformers/models/starcoder2/modeling_starcoder2.py @@ -33,7 +33,7 @@ Starcoder2Model, apply_rotary_pos_emb, ) -from transformers.utils import logging +from transformers.utils import is_torchdynamo_compiling, logging from ...modeling_attn_mask_utils import ( _gaudi_prepare_4d_causal_attention_mask, @@ -734,6 +734,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, token_idx: Optional[torch.Tensor] = None, trim_logits: Optional[bool] = False, attn_softmax_bf16: Optional[bool] = False, @@ -785,10 +786,18 @@ def forward( else: hidden_states = hidden_states[:, -1, :] - logits = self.lm_head(hidden_states).float() + if labels is None and not is_torchdynamo_compiling(): + logger.warning_once( + "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + ) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + # TODO: remove the float() operation in v4.46 + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() loss = None if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() @@ -821,6 +830,7 @@ def prepare_inputs_for_generation( cache_position=None, position_ids=None, use_cache=True, + num_logits_to_keep=None, token_idx=None, **kwargs, ): @@ -850,6 +860,8 @@ def prepare_inputs_for_generation( position_ids = torch.index_select(position_ids, 1, token_idx - 1) else: position_ids = position_ids[:, -input_ids.shape[1] :] + # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. 
+ position_ids = position_ids.clone(memory_format=torch.contiguous_format) cache_position = None @@ -857,7 +869,12 @@ def prepare_inputs_for_generation( if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} else: - model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases + model_inputs = { + "input_ids": input_ids.clone(memory_format=torch.contiguous_format) + } # `contiguous()` needed for compilation use cases + + if num_logits_to_keep is not None: + model_inputs["num_logits_to_keep"] = num_logits_to_keep model_inputs.update( { diff --git a/optimum/habana/transformers/trainer.py b/optimum/habana/transformers/trainer.py index 01bb3b01b4..a9c9a1c923 100644 --- a/optimum/habana/transformers/trainer.py +++ b/optimum/habana/transformers/trainer.py @@ -482,7 +482,7 @@ def train( # do_train is not a reliable argument, as it might not be set and .train() still called, so # the following is a workaround: - if (args.fp16_full_eval or args.bf16_full_eval) and not args.do_train: + if args.bf16_full_eval and not args.do_train and not self.is_model_parallel: self._move_model_to_device(self.model, args.device) if "model_path" in kwargs: @@ -675,11 +675,6 @@ def _inner_training_loop( # Activate gradient checkpointing if needed if args.gradient_checkpointing: - if args.gradient_checkpointing_kwargs is None: - gradient_checkpointing_kwargs = {} - else: - gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs - import transformers.modeling_utils if args.deepspeed: @@ -713,7 +708,7 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args, use_reentrant: Optio torch.utils.checkpoint.checkpoint = lazy_mode_checkpointing transformers.modeling_utils.checkpoint = lazy_mode_checkpointing - self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs) + self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=args.gradient_checkpointing_kwargs) # Wrap `_gradient_checkpointing_func` in the model with `transformer_engine` `activation_checkpointing` context. if self.accelerator.state.is_fp8_enabled: @@ -817,7 +812,7 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args, use_reentrant: Optio self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)) self.compare_trainer_and_checkpoint_args(self.args, self.state) self._load_callback_state() - epochs_trained = self.state.global_step // num_update_steps_per_epoch + epochs_trained = int(self.state.global_step // num_update_steps_per_epoch) if not args.ignore_data_skip: steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch) steps_trained_in_current_epoch *= args.gradient_accumulation_steps @@ -1041,6 +1036,8 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args, use_reentrant: Optio args.max_grad_norm, ) + self.control = self.callback_handler.on_pre_optimizer_step(args, self.state, self.control) + optimizer_was_run = True self.optimizer.step() @@ -1068,7 +1065,7 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args, use_reentrant: Optio break if step < 0: logger.warning( - "There seems to be not a single sample in your epoch_iterator, stopping training at step" + "There seems not to be a single sample in your epoch_iterator, stopping training at step" f" {self.state.global_step}! This is expected if you're using an IterableDataset and set" f" num_steps ({max_steps}) higher than the number of available samples." 
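
Illustrative sketch, not part of the diff above: the `clone(memory_format=torch.contiguous_format)` calls address exactly the situation described in the accompanying comment. With batch size 1, the last-token slice of `position_ids` is already reported as contiguous, yet its stride still tracks the parent tensor's length, which changes at every decoding step. A standalone, CPU-only illustration:

```python
import torch

prompt_len = 128
position_ids = torch.arange(prompt_len).reshape(1, prompt_len)  # stride (128, 1)
last = position_ids[:, -1:]                                     # shape (1, 1)

print(last.is_contiguous())        # True: size-1 dims never break contiguity
print(last.stride())               # (128, 1) -> depends on the prompt length
print(last.contiguous().stride())  # still (128, 1): contiguous() is a no-op here

stable = last.clone(memory_format=torch.contiguous_format)
print(stable.stride())             # (1, 1): identical at every decoding step
```

Cloning into contiguous format therefore hands the graph-capture machinery a tensor with a stable stride no matter how long the prompt has grown.
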
) @@ -1366,8 +1363,16 @@ def _save_checkpoint(self, model, trial, metrics=None): # Save the Trainer state if self.args.should_save: - # Update the `TrainerControl` state to where we are currently - self.state.stateful_callbacks["TrainerControl"] = self.control.state() + # Update `ExportableState` callbacks and `TrainerControl` state to where we are currently + for cb in [ + cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState) + ]: + cb_name = cb.__class__.__name__ + cb_state = cb.state() + if isinstance(self.state.stateful_callbacks[cb_name], list): + self.state.stateful_callbacks[cb_name].append(cb_state) + else: + self.state.stateful_callbacks[cb_name] = cb_state self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME)) if self.args.push_to_hub: @@ -1579,6 +1584,9 @@ def training_step(self, model: torch.nn.Module, inputs: Dict[str, Union[torch.Te `torch.Tensor`: The tensor with training loss on this batch. """ model.train() + # if hasattr(self.optimizer, "train") and callable(self.optimizer.train): + # self.optimizer.train() + inputs = self._prepare_inputs(inputs) with self.compute_loss_context_manager(): @@ -1816,6 +1824,8 @@ def evaluation_loop( self.deepspeed = self.model_wrapped model.eval() + # if hasattr(self.optimizer, "eval") and callable(self.optimizer.eval): + # self.optimizer.eval() # Do not use HPU graphs if the training is ongoing because it detaches gradients if args.use_hpu_graphs_for_inference and not self.is_in_train: @@ -2223,6 +2233,8 @@ def prediction_loop( if self.is_deepspeed_enabled: self.deepspeed = self.model_wrapped model.eval() + # if hasattr(self.optimizer, "eval") and callable(self.optimizer.eval): + # self.optimizer.eval() # Do not use HPU graphs if the training is ongoing because it detaches gradients if args.use_hpu_graphs_for_inference and not self.is_in_train: @@ -2439,24 +2451,21 @@ def create_accelerator_and_postprocess(self): self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None - # post accelerator creation setup - # copy of https://github.com/huggingface/transformers/blob/b71f20a7c9f3716d30f6738501559acf863e2c5c/src/transformers/trainer.py#L3991 # post accelerator creation setup if self.is_fsdp_enabled: fsdp_plugin = self.accelerator.state.fsdp_plugin fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get( "limit_all_gathers", fsdp_plugin.limit_all_gathers ) - if is_accelerate_available("0.23.0"): - fsdp_plugin.activation_checkpointing = self.args.fsdp_config.get( - "activation_checkpointing", fsdp_plugin.activation_checkpointing + fsdp_plugin.activation_checkpointing = self.args.fsdp_config.get( + "activation_checkpointing", fsdp_plugin.activation_checkpointing + ) + if fsdp_plugin.activation_checkpointing and self.args.gradient_checkpointing: + raise ValueError( + "The activation_checkpointing in FSDP config and the gradient_checkpointing in training arg " + "can't be set to True simultaneously. Please use FSDP's activation_checkpointing logic " + "when using FSDP." ) - if fsdp_plugin.activation_checkpointing and self.args.gradient_checkpointing: - raise ValueError( - "The activation_checkpointing in FSDP config and the gradient_checkpointing in training arg " - "can't be set to True simultaneously. Please use FSDP's activation_checkpointing logic " - "when using FSDP." 
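
Illustrative sketch, not part of the diff above: the `_save_checkpoint` hunk now persists every callback that implements `ExportableState`, not just `TrainerControl`, so stateful callbacks survive checkpointing and are rebuilt by `self._load_callback_state()` on resume. A hypothetical callback showing the expected `state()` layout (the class name and counter are invented for illustration):

```python
from transformers.trainer_callback import ExportableState, TrainerCallback

class StepCounterCallback(TrainerCallback, ExportableState):
    """Hypothetical callback whose state is written into
    TrainerState.stateful_callbacks at each checkpoint."""

    def __init__(self, seen_steps: int = 0):
        self.seen_steps = seen_steps

    def on_step_end(self, args, state, control, **kwargs):
        self.seen_steps += 1

    def state(self) -> dict:
        # Layout mirrors TrainerControl.state(): "args" feed __init__ on
        # resume, "attributes" are set back onto the rebuilt instance.
        return {"args": {"seen_steps": self.seen_steps}, "attributes": {}}
```
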
- ) if self.is_deepspeed_enabled and getattr(self.args, "hf_deepspeed_config", None) is None: self.propagate_args_to_deepspeed() @@ -2470,10 +2479,15 @@ def create_accelerator_and_postprocess(self): wrapper = "DeepSpeed" if self.is_deepspeed_enabled else "FSDP" raise ValueError(f"{wrapper} can't be used with `save_only_model` along with `load_best_model_at_end`.") - # `auto_find_batch_size` isn't yet supported with DeepSpeed/FSDP - if (self.is_deepspeed_enabled or self.is_fsdp_enabled) and self.args.auto_find_batch_size: - wrapper = "DeepSpeed" if self.is_deepspeed_enabled else "FSDP" - raise NotImplementedError(f"`{wrapper}` doesn't support `auto_find_batch_size`.") + # `auto_find_batch_size` isn't supported yet with DeepSpeed Zero-3 + if ( + self.is_deepspeed_enabled + and self.accelerator.state.deepspeed_plugin.zero_stage == 3 + and self.args.auto_find_batch_size + ): + raise ValueError( + "`auto_find_batch_size` isn't supported yet with DeepSpeed Zero-3. Please consider using Zero-2, Zero-1, or FSDP" + ) def propagate_args_to_deepspeed(self, auto_find_batch_size=False): """ diff --git a/optimum/habana/transformers/trainer_seq2seq.py b/optimum/habana/transformers/trainer_seq2seq.py index 0873d9e7e9..7a327b5a7b 100644 --- a/optimum/habana/transformers/trainer_seq2seq.py +++ b/optimum/habana/transformers/trainer_seq2seq.py @@ -85,7 +85,7 @@ def load_generation_config(gen_config_arg: Union[str, GaudiGenerationConfig]) -> Loads a `~generation.GaudiGenerationConfig` from the `GaudiSeq2SeqTrainingArguments.generation_config` arguments. Args: - gen_config_arg (`str` or [`~generation.GaudiGenerationConfig`]): + gen_config_arg (`str` or [`~generation.GaudiGenerationConfig]`): `GaudiSeq2SeqTrainingArguments.generation_config` argument. Returns: diff --git a/optimum/habana/transformers/training_args.py b/optimum/habana/transformers/training_args.py index 5a65074fc9..3a71d46506 100644 --- a/optimum/habana/transformers/training_args.py +++ b/optimum/habana/transformers/training_args.py @@ -581,8 +581,8 @@ def __post_init__(self): " during training" ) - if not isinstance(self.warmup_steps, int) or self.warmup_steps < 0 or 0 < self.warmup_steps <= 1: - raise ValueError("warmup_steps must be either 0 or > 1") + if not isinstance(self.warmup_steps, int) or self.warmup_steps < 0: + raise ValueError("warmup_steps must be of type int and must be 0 or a positive integer.") # Copy of https://github.com/huggingface/transformers/blob/b71f20a7c9f3716d30f6738501559acf863e2c5c/src/transformers/training_args.py#L1563 # except following changes, (1) Remove XLA specific code & (2) change fsdp_backward_prefetch to backward_prefetch @@ -654,7 +654,7 @@ def __post_init__(self): self.fsdp_config["xla_fsdp_grad_ckpt"] = self.fsdp_config.get("xla_fsdp_grad_ckpt", False) # accelerate integration for FSDP - if len(self.fsdp) > 0 and not self.fsdp_config["xla"]: + if len(self.fsdp) > 0: os.environ["ACCELERATE_USE_FSDP"] = "true" from accelerate.utils.constants import ( FSDP_AUTO_WRAP_POLICY, diff --git a/setup.py b/setup.py index cea680353e..37c16d8e2f 100644 --- a/setup.py +++ b/setup.py @@ -29,7 +29,7 @@ INSTALL_REQUIRES = [ - "transformers >= 4.43.0, < 4.44.0", + "transformers >= 4.45.0, < 4.46.0", "optimum", "torch", "accelerate >= 0.33.0, < 0.34.0", diff --git a/tests/example_diff/run_audio_classification.txt b/tests/example_diff/run_audio_classification.txt index 4f1e657ec1..19687459ed 100644 --- a/tests/example_diff/run_audio_classification.txt +++ b/tests/example_diff/run_audio_classification.txt @@ -33,7 +33,7 @@ < 
check_min_version("4.46.0.dev0") --- > # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks. -> check_min_version("4.43.0") +> check_min_version("4.45.0") > check_optimum_habana_min_version("1.14.0.dev0") 174,176d175 < freeze_feature_extractor: Optional[bool] = field( diff --git a/tests/example_diff/run_clip.txt b/tests/example_diff/run_clip.txt index 92b3cbbb63..ce6b6fd9ba 100644 --- a/tests/example_diff/run_clip.txt +++ b/tests/example_diff/run_clip.txt @@ -28,7 +28,7 @@ < check_min_version("4.46.0.dev0") --- > # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks. -> check_min_version("4.43.0") +> check_min_version("4.45.0") > check_optimum_habana_min_version("1.14.0.dev0") 181a190,192 > mediapipe_dataloader: bool = field( diff --git a/tests/example_diff/run_clm.txt b/tests/example_diff/run_clm.txt index 162953c477..22fbddee9d 100644 --- a/tests/example_diff/run_clm.txt +++ b/tests/example_diff/run_clm.txt @@ -38,7 +38,7 @@ > 63a64,69 > # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks. -> check_min_version("4.43.0") +> check_min_version("4.45.0") > check_optimum_habana_min_version("1.14.0.dev0") > > require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") diff --git a/tests/example_diff/run_glue.txt b/tests/example_diff/run_glue.txt index 5b0aef3cbf..52be944799 100644 --- a/tests/example_diff/run_glue.txt +++ b/tests/example_diff/run_glue.txt @@ -27,7 +27,7 @@ > logger = logging.getLogger(__name__) > > # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks. -> check_min_version("4.43.0") +> check_min_version("4.45.0") > check_optimum_habana_min_version("1.14.0.dev0") 67,68d76 < logger = logging.getLogger(__name__) diff --git a/tests/example_diff/run_image_classification.txt b/tests/example_diff/run_image_classification.txt index 0e27153ed8..4e8112afe0 100644 --- a/tests/example_diff/run_image_classification.txt +++ b/tests/example_diff/run_image_classification.txt @@ -28,7 +28,7 @@ < check_min_version("4.46.0.dev0") --- > # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks. -> check_min_version("4.43.0") +> check_min_version("4.45.0") > check_optimum_habana_min_version("1.14.0.dev0") 184c192 < parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) diff --git a/tests/example_diff/run_mlm.txt b/tests/example_diff/run_mlm.txt index 3eefb5b434..f69c33bdbc 100644 --- a/tests/example_diff/run_mlm.txt +++ b/tests/example_diff/run_mlm.txt @@ -34,7 +34,7 @@ 61a62,69 > > # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks. -> check_min_version("4.43.0") +> check_min_version("4.45.0") > check_optimum_habana_min_version("1.14.0.dev0") > > require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") diff --git a/tests/example_diff/run_qa.txt b/tests/example_diff/run_qa.txt index c95011a858..bfbe594bcf 100644 --- a/tests/example_diff/run_qa.txt +++ b/tests/example_diff/run_qa.txt @@ -32,7 +32,7 @@ > 58a62,67 > # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks. 
-> check_min_version("4.43.0") +> check_min_version("4.45.0") > check_optimum_habana_min_version("1.14.0.dev0") > > require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt") diff --git a/tests/example_diff/run_seq2seq_qa.txt b/tests/example_diff/run_seq2seq_qa.txt index afdc7c1502..3923456cb0 100644 --- a/tests/example_diff/run_seq2seq_qa.txt +++ b/tests/example_diff/run_seq2seq_qa.txt @@ -24,7 +24,7 @@ > 54a58,63 > # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks. -> check_min_version("4.43.0") +> check_min_version("4.45.0") > check_optimum_habana_min_version("1.14.0.dev0") > > require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt") diff --git a/tests/example_diff/run_speech_recognition_ctc.txt b/tests/example_diff/run_speech_recognition_ctc.txt index 84684c39ac..d365943d86 100644 --- a/tests/example_diff/run_speech_recognition_ctc.txt +++ b/tests/example_diff/run_speech_recognition_ctc.txt @@ -25,7 +25,7 @@ > return () 59a61,66 > # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks. -> check_min_version("4.43.0") +> check_min_version("4.45.0") > check_optimum_habana_min_version("1.14.0.dev0") > > require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt") diff --git a/tests/example_diff/run_speech_recognition_seq2seq.txt b/tests/example_diff/run_speech_recognition_seq2seq.txt index 6ec5e2fa13..4edffaea00 100644 --- a/tests/example_diff/run_speech_recognition_seq2seq.txt +++ b/tests/example_diff/run_speech_recognition_seq2seq.txt @@ -22,7 +22,7 @@ 51c58,59 < check_min_version("4.46.0.dev0") --- -> check_min_version("4.43.0") +> check_min_version("4.45.0") > check_optimum_habana_min_version("1.14.0.dev0") 230a239,242 > label_features_max_length: int = field( diff --git a/tests/example_diff/run_summarization.txt b/tests/example_diff/run_summarization.txt index cddaf91241..2cb562bc11 100644 --- a/tests/example_diff/run_summarization.txt +++ b/tests/example_diff/run_summarization.txt @@ -36,7 +36,7 @@ > 60a67,72 > # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks. -> check_min_version("4.43.0") +> check_min_version("4.45.0") > check_optimum_habana_min_version("1.14.0.dev0") > > require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt") diff --git a/tests/example_diff/run_translation.txt b/tests/example_diff/run_translation.txt index 8fefa1ffb7..2ea57200b6 100644 --- a/tests/example_diff/run_translation.txt +++ b/tests/example_diff/run_translation.txt @@ -28,7 +28,7 @@ > 60a64,69 > # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks. 
-> check_min_version("4.43.0") +> check_min_version("4.45.0") > check_optimum_habana_min_version("1.14.0.dev0") > > require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt") diff --git a/tests/test_trainer.py b/tests/test_trainer.py index ba78bbd2cc..eddb82b500 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -28,7 +28,7 @@ from typing import Dict, List, Optional, Union import numpy as np -from huggingface_hub import HfFolder, ModelCard, delete_repo, list_repo_commits, list_repo_files +from huggingface_hub import HfFolder, ModelCard, create_branch, delete_repo, list_repo_commits, list_repo_files from parameterized import parameterized from pytest import mark from requests.exceptions import HTTPError @@ -108,6 +108,21 @@ adapt_transformers_to_gaudi() +class MockOOMCallback(TrainerCallback): + """ + Simple callback to simulate CUDA OOM error if + the batch size is >= to `batch_size_limit`. + """ + + def __init__(self, batch_size_limit=16): + self.batch_size_limit = batch_size_limit + + def on_step_end(self, args, state, control, **kwargs): + # simulate OOM on the first step + if state.train_batch_size >= self.batch_size_limit: + raise RuntimeError("Out of memory.") + + class RegressionDataset: def __init__(self, a=2, b=3, length=64, seed=42, label_names=None): np.random.seed(seed) @@ -1855,45 +1870,73 @@ def test_resume_training_with_randomness(self): self.assertAlmostEqual(a, a1, delta=1e-5) self.assertAlmostEqual(b, b1, delta=1e-5) - def test_auto_batch_size_with_resume_from_checkpoint(self): - train_dataset = RegressionDataset(length=128) + # @require_deepspeed + # def test_auto_batch_size_with_deepspeed(self): + # train_dataset = RegressionDataset(length=128) + + # config = RegressionModelConfig(a=0, b=2) + # model = RegressionRandomPreTrainedModel(config) + + # tmp_dir = self.get_auto_remove_tmp_dir() + + # for stage in [1, 2]: + # deepspeed = { + # "zero_optimization": { + # "stage": stage, + # }, + # "train_batch_size": "auto", + # "train_micro_batch_size_per_gpu": "auto", + # } + + # args = RegressionGaudiTrainingArguments( + # tmp_dir, + # do_train=True, + # max_steps=2, + # save_strategy="no", + # per_device_train_batch_size=16, + # auto_find_batch_size=True, + # deepspeed=deepspeed, + # use_habana=True, + # use_lazy_mode=True, + # ) + # gaudi_config = get_gaudi_config() + # trainer = GaudiTrainer(model, gaudi_config, args, train_dataset=train_dataset, callbacks=[MockOOMCallback()]) + # trainer.train() + # self.assertEqual(trainer._train_batch_size, 8) - config = RegressionModelConfig(a=0, b=2) - model = RegressionRandomPreTrainedModel(config) + # def test_auto_batch_size_with_resume_from_checkpoint(self): + # train_dataset = RegressionDataset(length=128) - tmp_dir = self.get_auto_remove_tmp_dir() + # config = RegressionModelConfig(a=0, b=2) + # model = RegressionRandomPreTrainedModel(config) - class MockCudaOOMCallback(TrainerCallback): - def on_step_end(self, args, state, control, **kwargs): - # simulate OOM on the first step - if state.train_batch_size >= 16: - raise RuntimeError("CUDA out of memory.") + # tmp_dir = self.get_auto_remove_tmp_dir() - args = RegressionGaudiTrainingArguments( - tmp_dir, - do_train=True, - max_steps=2, - save_steps=1, - per_device_train_batch_size=16, - auto_find_batch_size=True, - use_habana=True, - use_lazy_mode=True, - ) - gaudi_config = get_gaudi_config() - trainer = GaudiTrainer( - model, gaudi_config, args, train_dataset=train_dataset, callbacks=[MockCudaOOMCallback()] - ) - 
trainer.train() - # After `auto_find_batch_size` is ran we should now be at 8 - self.assertEqual(trainer._train_batch_size, 8) + # args = RegressionGaudiTrainingArguments( + # tmp_dir, + # do_train=True, + # max_steps=2, + # save_steps=1, + # per_device_train_batch_size=16, + # auto_find_batch_size=True, + # use_habana=True, + # use_lazy_mode=True, + # ) + # gaudi_config = get_gaudi_config() + # trainer = GaudiTrainer( + # model, gaudi_config, args, train_dataset=train_dataset, callbacks=[MockOOMCallback()] + # ) + # trainer.train() + # # After `auto_find_batch_size` is ran we should now be at 8 + # self.assertEqual(trainer._train_batch_size, 8) - # We can then make a new Trainer - trainer = GaudiTrainer(model, gaudi_config, args, train_dataset=train_dataset) - # Check we are at 16 to start - self.assertEqual(trainer._train_batch_size, 16 * max(trainer.args.n_gpu, 1)) - trainer.train(resume_from_checkpoint=True) - # We should be back to 8 again, picking up based upon the last ran Trainer - self.assertEqual(trainer._train_batch_size, 8) + # # We can then make a new Trainer + # trainer = GaudiTrainer(model, gaudi_config, args, train_dataset=train_dataset) + # # Check we are at 16 to start + # self.assertEqual(trainer._train_batch_size, 16 * max(trainer.args.n_gpu, 1)) + # trainer.train(resume_from_checkpoint=True) + # # We should be back to 8 again, picking up based upon the last ran Trainer + # self.assertEqual(trainer._train_batch_size, 8) # regression for this issue: https://github.com/huggingface/transformers/issues/12970 def test_training_with_resume_from_checkpoint_false(self): @@ -2903,6 +2946,25 @@ def test_push_to_hub_tags(self): model_card = ModelCard.load(repo_name) self.assertTrue("test-trainer-tags" in model_card.data.tags) + def test_push_to_hub_with_revision(self): + # Checks if `trainer.push_to_hub()` works correctly by adding revision + with tempfile.TemporaryDirectory() as tmp_dir: + trainer = get_regression_trainer( + output_dir=os.path.join(tmp_dir, "test-trainer-revision"), + push_to_hub=True, + hub_token=self._token, + ) + branch = "v1.0" + create_branch(repo_id=trainer.hub_model_id, branch=branch, token=self._token, exist_ok=True) + url = trainer.push_to_hub(revision=branch) + + # Extract branch from the url + re_search = re.search(r"tree/([^/]+)/", url) + self.assertIsNotNone(re_search) + + branch_name = re_search.groups()[0] + self.assertEqual(branch_name, branch) + @require_torch @require_optuna diff --git a/tests/transformers/tests/generation/test_stopping_criteria.py b/tests/transformers/tests/generation/test_stopping_criteria.py index 0ce7838eee..9f177f9630 100644 --- a/tests/transformers/tests/generation/test_stopping_criteria.py +++ b/tests/transformers/tests/generation/test_stopping_criteria.py @@ -27,7 +27,6 @@ from transformers.generation import ( EosTokenCriteria, MaxLengthCriteria, - MaxNewTokensCriteria, MaxTimeCriteria, StoppingCriteriaList, validate_stopping_criteria, @@ -74,21 +73,6 @@ def test_max_length_criteria(self): input_ids, scores = self._get_tensors(10) self.assertTrue(all(criteria(input_ids, scores))) - def test_max_new_tokens_criteria(self): - criteria = MaxNewTokensCriteria(start_length=5, max_new_tokens=5) - - input_ids, scores = self._get_tensors(5) - self.assertFalse(all(criteria(input_ids, scores))) - - input_ids, scores = self._get_tensors(9) - self.assertFalse(all(criteria(input_ids, scores))) - - input_ids, scores = self._get_tensors(10) - self.assertTrue(all(criteria(input_ids, scores))) - - criteria_list = 
StoppingCriteriaList([criteria]) - self.assertEqual(criteria_list.max_length, 10) - def test_max_time_criteria(self): input_ids, scores = self._get_tensors(5)
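
Illustrative sketch, not part of the diff above: `MaxNewTokensCriteria` was deprecated upstream in favor of `MaxLengthCriteria` and no longer exists in Transformers 4.45, which is presumably why the import and its test are dropped here. If equivalent coverage is ever needed, the same stopping condition can be expressed as follows (tensor shapes are assumptions matching the removed test):

```python
import torch
from transformers.generation import MaxLengthCriteria

# MaxNewTokensCriteria(start_length=5, max_new_tokens=5) stopped at 10 tokens;
# MaxLengthCriteria(max_length=10) encodes the same limit.
criteria = MaxLengthCriteria(max_length=5 + 5)

scores = torch.rand(2, 100)  # dummy scores, ignored by this criterion
print(all(criteria(torch.zeros(2, 9, dtype=torch.long), scores)))   # False: 9 < 10
print(all(criteria(torch.zeros(2, 10, dtype=torch.long), scores)))  # True: limit reached
```
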