From c5755320787740351f26824293157f42fd301ac0 Mon Sep 17 00:00:00 2001 From: Shiv Kaul Date: Thu, 27 Jul 2023 17:00:03 -0700 Subject: [PATCH] Merge pull request #306 from huggingface/upgrade_transformers Upgrade to Transformers v4.31 --- Makefile | 3 +- .../run_audio_classification.py | 2 +- .../contrastive-image-text/run_bridgetower.py | 2 +- examples/contrastive-image-text/run_clip.py | 2 +- .../run_image_classification.py | 3 +- examples/language-modeling/run_clm.py | 9 +- examples/language-modeling/run_mlm.py | 9 +- examples/question-answering/run_qa.py | 8 +- examples/question-answering/run_seq2seq_qa.py | 2 +- .../run_speech_recognition_ctc.py | 4 +- examples/summarization/run_summarization.py | 2 +- examples/text-classification/run_glue.py | 2 +- examples/translation/run_translation.py | 2 +- optimum/habana/accelerate/__init__.py | 2 + optimum/habana/accelerate/accelerator.py | 786 ++++++++++++ optimum/habana/accelerate/state.py | 221 ++++ optimum/habana/accelerate/utils/__init__.py | 1 + .../habana/accelerate/utils/dataclasses.py | 34 + .../habana/distributed/distributed_runner.py | 4 +- optimum/habana/transformers/deepspeed.py | 87 +- .../habana/transformers/generation/utils.py | 315 +++-- .../models/albert/modeling_albert.py | 6 +- .../models/bloom/modeling_bloom.py | 2 +- .../transformers/models/gpt2/modeling_gpt2.py | 3 +- .../models/gpt_neox/modeling_gpt_neox.py | 30 +- .../transformers/models/gptj/modeling_gptj.py | 2 + .../models/llama/modeling_llama.py | 96 +- .../transformers/models/opt/modeling_opt.py | 19 +- .../transformers/models/t5/modeling_t5.py | 1 + optimum/habana/transformers/trainer.py | 1093 +++++++++-------- .../habana/transformers/trainer_seq2seq.py | 19 +- optimum/habana/transformers/training_args.py | 214 ++-- optimum/habana/utils.py | 2 + pyproject.toml | 2 +- setup.py | 4 +- tests/create_diff_file_for_example.py | 22 +- .../example_diff/run_audio_classification.txt | 24 +- tests/example_diff/run_clip.txt | 75 +- tests/example_diff/run_clm.txt | 22 +- tests/example_diff/run_generation.txt | 599 +++++---- tests/example_diff/run_glue.txt | 18 +- .../example_diff/run_image_classification.txt | 32 +- tests/example_diff/run_mlm.txt | 22 +- tests/example_diff/run_qa.txt | 16 +- tests/example_diff/run_seq2seq_qa.txt | 16 +- .../run_speech_recognition_ctc.txt | 30 +- tests/example_diff/run_summarization.txt | 82 +- tests/example_diff/run_translation.txt | 20 +- tests/test_trainer.py | 187 +-- tests/test_trainer_seq2seq.py | 8 +- 50 files changed, 2820 insertions(+), 1346 deletions(-) create mode 100644 optimum/habana/accelerate/__init__.py create mode 100644 optimum/habana/accelerate/accelerator.py create mode 100644 optimum/habana/accelerate/state.py create mode 100644 optimum/habana/accelerate/utils/__init__.py create mode 100644 optimum/habana/accelerate/utils/dataclasses.py diff --git a/Makefile b/Makefile index 2fb4184746..b53f3c32a8 100644 --- a/Makefile +++ b/Makefile @@ -53,14 +53,12 @@ slow_tests_deepspeed: test_installs python -m pytest tests/test_examples.py -v -s -k "deepspeed" slow_tests_diffusers: test_installs - python -m pip install git+https://github.com/huggingface/transformers.git python -m pip install git+https://github.com/huggingface/diffusers.git python -m pip install ftfy python -m pytest tests/test_diffusers.py -v -s -k "test_no_" # Check if examples are up to date with the Transformers library example_diff_tests: test_installs - python -m pip install git+https://github.com/huggingface/transformers.git python -m pytest 
tests/test_examples_match_transformers.py # Utilities to release to PyPi @@ -103,3 +101,4 @@ clean: test_installs: python -m pip install .[tests] + python -m pip install git+https://github.com/huggingface/transformers.git diff --git a/examples/audio-classification/run_audio_classification.py b/examples/audio-classification/run_audio_classification.py index 46d9c6f45f..0d2d6667ca 100644 --- a/examples/audio-classification/run_audio_classification.py +++ b/examples/audio-classification/run_audio_classification.py @@ -38,7 +38,7 @@ logger = logging.getLogger(__name__) # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.28.0") +check_min_version("4.31.0") require_version("datasets>=1.14.0", "To fix: pip install -r examples/pytorch/audio-classification/requirements.txt") diff --git a/examples/contrastive-image-text/run_bridgetower.py b/examples/contrastive-image-text/run_bridgetower.py index e53e75c46b..7c214879ab 100644 --- a/examples/contrastive-image-text/run_bridgetower.py +++ b/examples/contrastive-image-text/run_bridgetower.py @@ -48,7 +48,7 @@ logger = logging.getLogger(__name__) # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.27.0") +check_min_version("4.31.0") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/contrastive-image-text/requirements.txt") diff --git a/examples/contrastive-image-text/run_clip.py b/examples/contrastive-image-text/run_clip.py index a0d0369a75..c6ab117af9 100644 --- a/examples/contrastive-image-text/run_clip.py +++ b/examples/contrastive-image-text/run_clip.py @@ -53,7 +53,7 @@ logger = logging.getLogger(__name__) # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.28.0") +check_min_version("4.31.0") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/contrastive-image-text/requirements.txt") diff --git a/examples/image-classification/run_image_classification.py b/examples/image-classification/run_image_classification.py index f3ef617028..c4de9814e9 100644 --- a/examples/image-classification/run_image_classification.py +++ b/examples/image-classification/run_image_classification.py @@ -12,6 +12,7 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and +# limitations under the License. import logging import os @@ -54,7 +55,7 @@ logger = logging.getLogger(__name__) # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.28.0") +check_min_version("4.31.0") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt") diff --git a/examples/language-modeling/run_clm.py b/examples/language-modeling/run_clm.py index 0bc97040a8..55d54c51a4 100644 --- a/examples/language-modeling/run_clm.py +++ b/examples/language-modeling/run_clm.py @@ -52,7 +52,7 @@ # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.28.0") +check_min_version("4.31.0") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") @@ -509,10 +509,9 @@ def group_texts(examples): # Concatenate all texts. 
concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} total_length = len(concatenated_examples[list(examples.keys())[0]]) - # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can - # customize this part to your needs. - if total_length >= block_size: - total_length = (total_length // block_size) * block_size + # We drop the small remainder, and if the total_length < block_size we exclude this batch and return an empty dict. + # We could add padding if the model supported it instead of this drop, you can customize this part to your needs. + total_length = (total_length // block_size) * block_size # Split by chunks of max_len. result = { k: [t[i : i + block_size] for i in range(0, total_length, block_size)] diff --git a/examples/language-modeling/run_mlm.py b/examples/language-modeling/run_mlm.py index fa53e93630..c9df6b3aa3 100644 --- a/examples/language-modeling/run_mlm.py +++ b/examples/language-modeling/run_mlm.py @@ -50,7 +50,7 @@ # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.28.0") +check_min_version("4.31.0") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") @@ -511,10 +511,9 @@ def group_texts(examples): # Concatenate all texts. concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} total_length = len(concatenated_examples[list(examples.keys())[0]]) - # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can - # customize this part to your needs. - if total_length >= max_seq_length: - total_length = (total_length // max_seq_length) * max_seq_length + # We drop the small remainder, and if the total_length < max_seq_length we exclude this batch and return an empty dict. + # We could add padding if the model supported it instead of this drop, you can customize this part to your needs. + total_length = (total_length // max_seq_length) * max_seq_length # Split by chunks of max_len. result = { k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)] diff --git a/examples/question-answering/run_qa.py b/examples/question-answering/run_qa.py index 117f2c620d..068d117312 100644 --- a/examples/question-answering/run_qa.py +++ b/examples/question-answering/run_qa.py @@ -49,7 +49,7 @@ # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.28.0") +check_min_version("4.31.0") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt") @@ -599,12 +599,12 @@ def post_processing_function(examples, features, predictions, stage="eval"): # Format the result to the format the metric expects. 
if data_args.version_2_with_negative: formatted_predictions = [ - {"id": k, "prediction_text": v, "no_answer_probability": 0.0} for k, v in predictions.items() + {"id": str(k), "prediction_text": v, "no_answer_probability": 0.0} for k, v in predictions.items() ] else: - formatted_predictions = [{"id": k, "prediction_text": v} for k, v in predictions.items()] + formatted_predictions = [{"id": str(k), "prediction_text": v} for k, v in predictions.items()] - references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in examples] + references = [{"id": str(ex["id"]), "answers": ex[answer_column_name]} for ex in examples] return EvalPrediction(predictions=formatted_predictions, label_ids=references) metric = evaluate.load("squad_v2" if data_args.version_2_with_negative else "squad") diff --git a/examples/question-answering/run_seq2seq_qa.py b/examples/question-answering/run_seq2seq_qa.py index f6a448d9e9..c6a4ed3c85 100644 --- a/examples/question-answering/run_seq2seq_qa.py +++ b/examples/question-answering/run_seq2seq_qa.py @@ -46,7 +46,7 @@ # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.28.0") +check_min_version("4.31.0") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt") diff --git a/examples/speech-recognition/run_speech_recognition_ctc.py b/examples/speech-recognition/run_speech_recognition_ctc.py index e8dd117773..d3f63346d4 100644 --- a/examples/speech-recognition/run_speech_recognition_ctc.py +++ b/examples/speech-recognition/run_speech_recognition_ctc.py @@ -50,7 +50,7 @@ # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.28.0") +check_min_version("4.31.0") require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt") @@ -720,7 +720,7 @@ def compute_metrics(pred): compute_metrics=compute_metrics, train_dataset=vectorized_datasets["train"] if training_args.do_train else None, eval_dataset=vectorized_datasets["eval"] if training_args.do_eval else None, - tokenizer=feature_extractor, + tokenizer=processor, ) # 8. Finally, we can start training diff --git a/examples/summarization/run_summarization.py b/examples/summarization/run_summarization.py index 8449e56188..3446af38a2 100644 --- a/examples/summarization/run_summarization.py +++ b/examples/summarization/run_summarization.py @@ -53,7 +53,7 @@ # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.28.0") +check_min_version("4.31.0") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt") diff --git a/examples/text-classification/run_glue.py b/examples/text-classification/run_glue.py index 6234a2a118..c40f1640dc 100755 --- a/examples/text-classification/run_glue.py +++ b/examples/text-classification/run_glue.py @@ -47,7 +47,7 @@ # Will error if the minimal version of Transformers is not installed. Remove at your own risks. 
-check_min_version("4.28.0") +check_min_version("4.31.0") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt") diff --git a/examples/translation/run_translation.py b/examples/translation/run_translation.py index 6b1ff6c031..a61cbddaba 100644 --- a/examples/translation/run_translation.py +++ b/examples/translation/run_translation.py @@ -51,7 +51,7 @@ # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.28.0") +check_min_version("4.31.0") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt") diff --git a/optimum/habana/accelerate/__init__.py b/optimum/habana/accelerate/__init__.py new file mode 100644 index 0000000000..7045124d9a --- /dev/null +++ b/optimum/habana/accelerate/__init__.py @@ -0,0 +1,2 @@ +from .accelerator import GaudiAccelerator +from .state import GaudiAcceleratorState, GaudiPartialState diff --git a/optimum/habana/accelerate/accelerator.py b/optimum/habana/accelerate/accelerator.py new file mode 100644 index 0000000000..0bccd8d140 --- /dev/null +++ b/optimum/habana/accelerate/accelerator.py @@ -0,0 +1,786 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import contextlib +import math +import os +import sys +import warnings +from collections import OrderedDict +from contextlib import contextmanager +from dataclasses import make_dataclass +from types import MethodType + +import torch +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.scheduler import AcceleratedScheduler +from accelerate.state import GradientState +from accelerate.tracking import GeneralTracker, filter_trackers +from accelerate.utils import ( + DeepSpeedPlugin, + DistributedDataParallelKwargs, + DistributedType, + DynamoBackend, + FP8RecipeKwargs, + FullyShardedDataParallelPlugin, + GradientAccumulationPlugin, + GradScalerKwargs, + InitProcessGroupKwargs, + KwargsHandler, + LoggerType, + MegatronLMPlugin, + PrecisionType, + ProjectConfiguration, + RNGType, + TorchDynamoPlugin, + convert_outputs_to_fp32, + is_deepspeed_available, + parse_choice_from_env, +) +from accelerate.utils.operations import _gpu_gather +from torch.optim.lr_scheduler import LRScheduler + + +if is_deepspeed_available(): + import deepspeed + from accelerate.utils import ( + DeepSpeedEngineWrapper, + DeepSpeedOptimizerWrapper, + DeepSpeedSchedulerWrapper, + DummyOptim, + DummyScheduler, + ) + +from .state import GaudiAcceleratorState, GaudiPartialState +from .utils import GaudiDistributedType + + +logger = get_logger(__name__) + + +# We pass cloned tensors to torch.save() to avoid checkpoint bloat that occurs when torch.save() +# saves the underlying storage rather than the slice of the storage corresponding to individual tensors. 
+# This is a problem in DeepSpeed because we often allocate tensors using slices of large flattened buffers. +# Tensor cloning helps to avoid this problem because the storage of cloned tensors are closer to the true size. +# It is expected that the garbage collector will reclaim the cloned tensor storage to avoid memory bloat. +# See https://pytorch.org/docs/stable/notes/serialization.html#preserve-storage-sharing +# TODO: remove this method when it is available in Habana's DeepSpeed fork +def clone_tensors_for_torch_save(item, device=torch.device("cpu")): + """ + Taken from: https://github.com/microsoft/DeepSpeed/blob/09601bb811b28fb0db92b6dcb2b737873e6677e8/deepspeed/checkpoint/utils.py#L41 + + Returns a copy of `item` with all enclosed tensors replaced by clones on a specified device. + Works on individual tensors, and tensors contained/nested in lists, tuples, and dicts. + + Parameters: + - `item`: tensor to clone or (possibly nested) container of tensors to clone. + - `device`: target device (defaults to `cpu`) + + Returns: + - copy of `item` with cloned tensors on target device + """ + if torch.is_tensor(item): + return item.detach().clone().to(device) + elif isinstance(item, list): + return [clone_tensors_for_torch_save(v, device) for v in item] + elif isinstance(item, tuple): + return tuple([clone_tensors_for_torch_save(v, device) for v in item]) + elif isinstance(item, dict): + return type(item)({k: clone_tensors_for_torch_save(v, device) for k, v in item.items()}) + else: + return item + + +class GaudiAccelerator(Accelerator): + """ + Adapted from: https://github.com/huggingface/accelerate/blob/8514c35192ac9762920f1ab052e5cea4c0e46eeb/src/accelerate/accelerator.py#L145 + + Creates an instance of an accelerator for distributed training (on multi-GPU, TPU) or mixed precision training. + + Args: + device_placement (`bool`, *optional*, defaults to `True`): + Whether or not the accelerator should put objects on device (tensors yielded by the dataloader, model, + etc...). + split_batches (`bool`, *optional*, defaults to `False`): + Whether or not the accelerator should split the batches yielded by the dataloaders across the devices. If + `True` the actual batch size used will be the same on any kind of distributed processes, but it must be a + round multiple of the `num_processes` you are using. If `False`, actual batch size used will be the one set + in your script multiplied by the number of processes. + mixed_precision (`str`, *optional*): + Whether or not to use mixed precision training. Choose from 'no','fp16','bf16 or 'fp8'. Will default to the + value in the environment variable `ACCELERATE_MIXED_PRECISION`, which will use the default value in the + accelerate config of the current system or the flag passed with the `accelerate.launch` command. 'fp16' + requires pytorch 1.6 or higher. 'bf16' requires pytorch 1.10 or higher. 'fp8' requires the installation of + transformers-engine. + gradient_accumulation_steps (`int`, *optional*, default to 1): + The number of steps that should pass before gradients are accumulated. A number > 1 should be combined with + `Accelerator.accumulate`. If not passed, will default to the value in the environment variable + `ACCELERATE_GRADIENT_ACCUMULATION_STEPS`. Can also be configured through a `GradientAccumulationPlugin`. + cpu (`bool`, *optional*): + Whether or not to force the script to execute on CPU. Will ignore GPU available if set to `True` and force + the execution on one process only. 
+ deepspeed_plugin (`DeepSpeedPlugin`, *optional*): + Tweak your DeepSpeed related args using this argument. This argument is optional and can be configured + directly using *accelerate config* + fsdp_plugin (`FullyShardedDataParallelPlugin`, *optional*): + Tweak your FSDP related args using this argument. This argument is optional and can be configured directly + using *accelerate config* + megatron_lm_plugin (`MegatronLMPlugin`, *optional*): + Tweak your MegatronLM related args using this argument. This argument is optional and can be configured + directly using *accelerate config* + rng_types (list of `str` or [`~utils.RNGType`]): + The list of random number generators to synchronize at the beginning of each iteration in your prepared + dataloaders. Should be one or several of: + + - `"torch"`: the base torch random number generator + - `"cuda"`: the CUDA random number generator (GPU only) + - `"xla"`: the XLA random number generator (TPU only) + - `"generator"`: the `torch.Generator` of the sampler (or batch sampler if there is no sampler in your + dataloader) or of the iterable dataset (if it exists) if the underlying dataset is of that type. + + Will default to `["torch"]` for PyTorch versions <=1.5.1 and `["generator"]` for PyTorch versions >= 1.6. + log_with (list of `str`, [`~utils.LoggerType`] or [`~tracking.GeneralTracker`], *optional*): + A list of loggers to be setup for experiment tracking. Should be one or several of: + + - `"all"` + - `"tensorboard"` + - `"wandb"` + - `"comet_ml"` + If `"all"` is selected, will pick up all available trackers in the environment and initialize them. Can + also accept implementations of `GeneralTracker` for custom trackers, and can be combined with `"all"`. + project_config (`ProjectConfiguration`, *optional*): + A configuration for how saving the state can be handled. + project_dir (`str`, `os.PathLike`, *optional*): + A path to a directory for storing data such as logs of locally-compatible loggers and potentially saved + checkpoints. + dispatch_batches (`bool`, *optional*): + If set to `True`, the dataloader prepared by the Accelerator is only iterated through on the main process + and then the batches are split and broadcast to each process. Will default to `True` for `DataLoader` whose + underlying dataset is an `IterableDataset`, `False` otherwise. + even_batches (`bool`, *optional*, defaults to `True`): + If set to `True`, in cases where the total batch size across all processes does not exactly divide the + dataset, samples at the start of the dataset will be duplicated so the batch can be divided equally among + all workers. + step_scheduler_with_optimizer (`bool`, *optional`, defaults to `True`): + Set `True` if the learning rate scheduler is stepped at the same time as the optimizer, `False` if only + done under certain circumstances (at the end of each epoch, for instance). + kwargs_handlers (`list[KwargHandler]`, *optional*) + A list of `KwargHandler` to customize how the objects related to distributed training or mixed precision + are created. See [kwargs](kwargs) for more information. + dynamo_backend (`str` or `DynamoBackend`, *optional*, defaults to `"no"`): + Set to one of the possible dynamo backends to optimize your training with torch dynamo. + gradient_accumulation_plugin (`GradientAccumulationPlugin`, *optional*): + A configuration for how gradient accumulation should be handled, if more tweaking than just the + `gradient_accumulation_steps` is needed. 
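A minimal usage sketch (illustrative only, not taken from this patch; it assumes an HPU-enabled environment with `habana_frameworks` installed and pre-built `model`, `optimizer`, and `dataloader` objects):

```python
from optimum.habana.accelerate import GaudiAccelerator

# bf16 is the supported mixed-precision mode on Gaudi; passing "fp16" raises a ValueError.
accelerator = GaudiAccelerator(mixed_precision="bf16", gradient_accumulation_steps=4)
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
```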
+ + **Available attributes:** + + - **device** (`torch.device`) -- The device to use. + - **distributed_type** ([`~utils.DistributedType`]) -- The distributed training configuration. + - **local_process_index** (`int`) -- The process index on the current machine. + - **mixed_precision** (`str`) -- The configured mixed precision mode. + - **num_processes** (`int`) -- The total number of processes used for training. + - **optimizer_step_was_skipped** (`bool`) -- Whether or not the optimizer update was skipped (because of + gradient overflow in mixed precision), in which + case the learning rate should not be changed. + - **process_index** (`int`) -- The overall index of the current process among all processes. + - **state** ([`~state.AcceleratorState`]) -- The distributed setup state. + - **sync_gradients** (`bool`) -- Whether the gradients are currently being synced across all processes. + - **use_distributed** (`bool`) -- Whether the current configuration is for distributed training. + """ + + def __init__( + self, + device_placement: bool = True, + split_batches: bool = False, + mixed_precision: PrecisionType | str | None = None, + gradient_accumulation_steps: int = 1, + cpu: bool = False, + deepspeed_plugin: DeepSpeedPlugin | None = None, + fsdp_plugin: FullyShardedDataParallelPlugin | None = None, + megatron_lm_plugin: MegatronLMPlugin | None = None, + rng_types: list[str | RNGType] | None = None, + log_with: str | LoggerType | GeneralTracker | list[str | LoggerType | GeneralTracker] | None = None, + project_dir: str | os.PathLike | None = None, + project_config: ProjectConfiguration | None = None, + gradient_accumulation_plugin: GradientAccumulationPlugin | None = None, + dispatch_batches: bool | None = None, + even_batches: bool = True, + step_scheduler_with_optimizer: bool = True, + kwargs_handlers: list[KwargsHandler] | None = None, + dynamo_backend: DynamoBackend | str | None = None, + ): + if project_config is not None: + self.project_configuration = project_config + else: + self.project_configuration = ProjectConfiguration(project_dir=project_dir) + if project_dir is not None and self.project_dir is None: + self.project_configuration.set_directories(project_dir) + if mixed_precision is not None: + mixed_precision = str(mixed_precision) + if mixed_precision not in PrecisionType: + raise ValueError( + f"Unknown mixed_precision mode: {mixed_precision}. Choose between {PrecisionType.list()}" + ) + elif mixed_precision == "fp16": + raise ValueError("fp16 is not supported on Habana Gaudi.") + + dynamo_plugin = TorchDynamoPlugin() if dynamo_backend is None else TorchDynamoPlugin(backend=dynamo_backend) + + if deepspeed_plugin is None: # init from env variables + deepspeed_plugin = ( + DeepSpeedPlugin() if os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true" else None + ) + else: + if not isinstance(deepspeed_plugin, DeepSpeedPlugin): + raise TypeError("`deepspeed_plugin` must be an `accelerate.utils.DeepSpeedPlugin` object.") + os.environ["ACCELERATE_USE_DEEPSPEED"] = "true" # use DeepSpeed if plugin is provided + if deepspeed_plugin: + if not is_deepspeed_available(): + raise ImportError( + "DeepSpeed is not installed => run `pip install git+https://github.com/HabanaAI/DeepSpeed.git@1.10.0`." 
+ ) + + mixed_precision = ( + os.environ.get("ACCELERATE_MIXED_PRECISION", "no") if mixed_precision is None else mixed_precision + ) + deepspeed_plugin.set_mixed_precision(mixed_precision) + deepspeed_plugin.set_deepspeed_weakref() + + # Kwargs handlers + self.ddp_handler = None + self.scaler_handler = None + self.init_handler = None + self.fp8_recipe_handler = None + if kwargs_handlers is not None: + for handler in kwargs_handlers: + assert isinstance( + handler, KwargsHandler + ), f"Unsupported kwargs handler passed: {handler}, must be one that inherits `accelerate.utils.KwargsHandler`." + if isinstance(handler, DistributedDataParallelKwargs): + if self.ddp_handler is not None: + raise ValueError("You can only pass one `DistributedDataParallelKwargs` in `kwargs_handler`.") + else: + self.ddp_handler = handler + elif isinstance(handler, GradScalerKwargs): + if self.scaler_handler is not None: + raise ValueError("You can only pass one `GradScalerKwargs` in `kwargs_handler`.") + else: + self.scaler_handler = handler + elif isinstance(handler, InitProcessGroupKwargs): + if self.init_handler is not None: + raise ValueError("You can only pass one `InitProcessGroupKwargs` in `kwargs_handler`.") + else: + self.init_handler = handler + elif isinstance(handler, FP8RecipeKwargs): + if self.fp8_recipe_handler is not None: + raise ValueError("You can only pass one `FP8RecipeKwargs` in `kwargs_handler`.") + else: + self.fp8_recipe_handler = handler + + kwargs = self.init_handler.to_kwargs() if self.init_handler is not None else {} + self.state = GaudiAcceleratorState( + mixed_precision=mixed_precision, + cpu=cpu, + dynamo_plugin=dynamo_plugin, + deepspeed_plugin=deepspeed_plugin, + fsdp_plugin=fsdp_plugin, + megatron_lm_plugin=megatron_lm_plugin, + _from_accelerator=True, + **kwargs, + ) + + trackers = filter_trackers(log_with, self.logging_dir) + if len(trackers) < 1 and log_with is not None: + warnings.warn(f"`log_with={log_with}` was passed but no supported trackers are currently installed.") + self.log_with = trackers + + if ( + (mixed_precision != "bf16") + and getattr(self.state, "downcast_bfloat", False) + and (self.state.distributedType != DistributedType.TPU) + ): + raise ValueError("Can only use `downcast_bf16` when using `mixed_precision='bf16'` and on a TPU") + + if gradient_accumulation_plugin is not None: + if gradient_accumulation_steps != 1: + raise ValueError( + "You can only pass one of `gradient_accumulation_steps` and `gradient_accumulation_plugin`. Please only pass in the created `GradientAccumulationPlugin` object." 
+ ) + else: + gradient_accumulation_steps = int( + parse_choice_from_env("ACCELERATE_GRADIENT_ACCUMULATION_STEPS", gradient_accumulation_steps) + ) + gradient_accumulation_plugin = GradientAccumulationPlugin(num_steps=gradient_accumulation_steps) + self.gradient_state = GradientState( + gradient_accumulation_plugin=gradient_accumulation_plugin, + ) + + self.device_placement = device_placement + self.split_batches = split_batches + self.dispatch_batches = dispatch_batches + self.even_batches = even_batches + self.step_scheduler_with_optimizer = step_scheduler_with_optimizer + + # Mixed precision attributes + self.scaler = None + self.native_amp = self.state.mixed_precision == "bf16" + + # Start of internal step tracking + self.step = 0 + + # Internal references to the training objects + self._optimizers = [] + self._models = [] + self._schedulers = [] + self._dataloaders = [] + self._custom_objects = [] + + # Hooks + self._load_model_state_pre_hook = OrderedDict() + self._save_model_state_pre_hook = OrderedDict() + + # RNG Types + self.rng_types = rng_types + if self.rng_types is None: + self.rng_types = ["generator"] + + @property + def use_fp16(self): + raise ValueError("fp16 is not supported on Habana Gaudi.") + + def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, evaluation_mode: bool = False): + """ + Prepares a PyTorch model for training in any distributed setup. It is recommended to use + [`Accelerator.prepare`] instead. + + Args: + model (`torch.nn.Module`): + A PyTorch model to prepare. You don't need to prepare a model if it is used only for inference without + any kind of mixed precision + device_placement (`bool`, *optional*): + Whether or not to place the model on the proper device. Will default to `self.device_placement`. + evaluation_mode (`bool`, *optional*, defaults to `False`): + Whether or not to set the model for evaluation only, by just applying mixed precision and + `torch.compile` (if configured in the `Accelerator` object). + + Example: + + ```python + >>> from accelerate import Accelerator + + >>> accelerator = Accelerator() + >>> # Assume a model is defined + >>> model = accelerator.prepare_model(model) + ``` + """ + if device_placement is None: + device_placement = self.device_placement and self.distributed_type != DistributedType.FSDP + self._models.append(model) + # We check only for models loaded with `accelerate` + # Checks if any of the child module has the attribute `hf_device_map`. + has_hf_device_map = False + for m in model.modules(): + if hasattr(m, "hf_device_map"): + has_hf_device_map = True + break + + if (getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False)) and getattr( + model, "hf_device_map", False + ): + model_devices = set(model.hf_device_map.values()) + if len(model_devices) > 1 and self.distributed_type != DistributedType.NO: + raise ValueError( + "You can't train a model that has been loaded in 8-bit precision on multiple devices in any distributed mode." + " In order to use 8-bit models that have been loaded across multiple GPUs the solution is to use Naive Pipeline Parallelism." + " Therefore you should not specify that you are under any distributed regime in your accelerate config." 
+ ) + current_device = list(model_devices)[0] + current_device_index = current_device.index if isinstance(current_device, torch.device) else current_device + + if torch.device(current_device_index) != self.device: + # if on the first device (GPU 0) we don't care + if (self.device.index is not None) or (current_device_index != 0): + raise ValueError( + "You can't train a model that has been loaded in 8-bit precision on a different device than the one " + "you're training on. Make sure you loaded the model on the correct device using for example `device_map={'':torch.cuda.current_device()}" + "you're training on. Make sure you loaded the model on the correct device using for example `device_map={'':torch.cuda.current_device() or device_map={'':torch.xpu.current_device()}" + ) + + if "cpu" in model_devices or "disk" in model_devices: + raise ValueError( + "You can't train a model that has been loaded in 8-bit precision with CPU or disk offload." + ) + elif device_placement and not has_hf_device_map: + model = model.to(self.device) + + if self.native_amp: + model._original_forward = model.forward + model_forward_func = model.forward.__func__ if hasattr(model.forward, "__func__") else model.forward + if self.mixed_precision == "bf16": + new_forward = torch.autocast(device_type=self.state.device.type, dtype=torch.bfloat16)( + model_forward_func + ) + + if hasattr(model.forward, "__func__"): + model.forward = MethodType(new_forward, model) + model.forward = MethodType(convert_outputs_to_fp32(model.forward.__func__), model) + else: + model.forward = convert_outputs_to_fp32(new_forward) + # elif self.mixed_precision == "fp8": + # if not has_transformer_engine_layers(model): + # with torch.no_grad(): + # convert_model(model) + # model._converted_to_transformer_engine = True + # model._original_forward = model.forward + + # kwargs = self.fp8_recipe_handler.to_kwargs() if self.fp8_recipe_handler is not None else {} + # if "fp8_format" in kwargs: + # kwargs["fp8_format"] = getattr(te_recipe.Format, kwargs["fp8_format"]) + # fp8_recipe = te_recipe.DelayedScaling(**kwargs) + # cuda_device_capacity = torch.cuda.get_device_capability() + # fp8_enabled = cuda_device_capacity[0] >= 9 or ( + # cuda_device_capacity[0] == 8 and cuda_device_capacity[1] >= 9 + # ) + # if not fp8_enabled: + # logger.warn( + # f"The current device has compute capability of {cuda_device_capacity} which is " + # "insufficient for FP8 mixed precision training (requires a GPU Hopper/Ada Lovelace " + # "or higher, compute capability of 8.9 or higher). Will use FP16 instead." + # ) + # model.forward = fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe)(model.forward) + if not evaluation_mode: + if self.distributed_type == GaudiDistributedType.MULTI_HPU: + if any(p.requires_grad for p in model.parameters()): + kwargs = self.ddp_handler.to_kwargs() if self.ddp_handler is not None else {} + model = torch.nn.parallel.DistributedDataParallel(model, **kwargs) + # torch.compile should be called last. 
+ if self.state.dynamo_plugin.backend != DynamoBackend.NO: + model = torch.compile(model, **self.state.dynamo_plugin.to_kwargs()) + return model + + def _prepare_deepspeed(self, *args): + deepspeed_plugin = self.state.deepspeed_plugin + + is_dataloader_present = any(isinstance(obj, torch.utils.data.DataLoader) for obj in args) + if deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] == "auto" or is_dataloader_present: + result = [ + self._prepare_one(obj, first_pass=True) if isinstance(obj, torch.utils.data.DataLoader) else obj + for obj in args + ] + + batch_sizes = [obj.batch_size for obj in args if hasattr(obj, "batch_size")] + if self.split_batches: + batch_sizes = [batch_size // self.num_processes for batch_size in batch_sizes] + + if any(bs is None for bs in batch_sizes): + raise ValueError( + "At least one of the dataloaders passed to `accelerate.prepare()` has `None` as batch size." + "Please set an integer value in `train_micro_batch_size_per_gpu` in the deepspeed config file" + "or assign integer value to `AcceleratorState().deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu']`." + ) + if len(batch_sizes) == 0: + raise ValueError( + "When using DeepSpeed `accelerate.prepare()` requires you to pass at least one of training or evaluation dataloaders " + "or alternatively set an integer value in `train_micro_batch_size_per_gpu` in the deepspeed config file" + "or assign integer value to `AcceleratorState().deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu']`." + ) + + batch_size_per_device = min(batch_sizes) if deepspeed_plugin.is_train_batch_min else max(batch_sizes) + if len(batch_sizes) > 1: + logger.info( + "Since you passed both train and evaluation dataloader, `is_train_batch_min` (here " + f"{deepspeed_plugin.is_train_batch_min} will decide the `train_batch_size` ({batch_size_per_device})." + ) + else: + batch_size_per_device = deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] + result = list(args) + + if self.gradient_accumulation_steps != deepspeed_plugin.deepspeed_config["gradient_accumulation_steps"]: + logger.info( + f"Updating DeepSpeed's gradient accumulation steps to {self.gradient_accumulation_steps} from " + f"{deepspeed_plugin.deepspeed_config['gradient_accumulation_steps']}." + ) + deepspeed_plugin.deepspeed_config["gradient_accumulation_steps"] = self.gradient_accumulation_steps + config_kwargs = { + "train_micro_batch_size_per_gpu": batch_size_per_device, + "train_batch_size": batch_size_per_device + * deepspeed_plugin.deepspeed_config["gradient_accumulation_steps"] + * self.num_processes, + "gradient_clipping": 1.0, + "zero_optimization.stage3_gather_16bit_weights_on_model_save": False, + } + + model = None + optimizer = None + scheduler = None + for obj in result: + if isinstance(obj, torch.nn.Module): + model = obj + elif isinstance(obj, (torch.optim.Optimizer, DummyOptim)): + optimizer = obj + elif (isinstance(obj, (LRScheduler, DummyScheduler))) or ( + type(obj).__name__ in deepspeed.runtime.lr_schedules.VALID_LR_SCHEDULES + ): + scheduler = obj + + if optimizer is not None: + if "optimizer" in deepspeed_plugin.deepspeed_config and not isinstance(optimizer, (DummyOptim)): + raise ValueError( + "You cannot specify an optimizer in the config file and in the code at the same time. " + "Please remove the optimizer from the config file or " + "create `accelerate.utils.DummyOptim` in the code." 
+ ) + elif "optimizer" not in deepspeed_plugin.deepspeed_config and isinstance(optimizer, (DummyOptim)): + raise ValueError( + "You cannot create a `DummyOptim` without specifying an optimizer in the config file." + ) + + if isinstance(optimizer, (torch.optim.Optimizer)): + deepspeed_plugin.deepspeed_config["zero_allow_untested_optimizer"] = True + + if scheduler is not None: + if "scheduler" in deepspeed_plugin.deepspeed_config and not isinstance(scheduler, (DummyScheduler)): + raise ValueError( + "You cannot specify a scheduler in the config file and in the code at the same time. " + "Please remove the scheduler from the config file or " + "create `accelerate.utils.DummyScheduler` in the code." + ) + elif "scheduler" not in deepspeed_plugin.deepspeed_config and isinstance(scheduler, (DummyScheduler)): + raise ValueError( + "You cannot create a `DummyScheduler` without specifying a scheduler in the config file." + ) + + if optimizer is not None and scheduler is not None: + if isinstance(optimizer, (DummyOptim)) and not isinstance(scheduler, (DummyScheduler)): + raise ValueError( + "You can only specify `accelerate.utils.DummyScheduler` in the code when using " + "`accelerate.utils.DummyOptim`." + ) + + if model is not None: + if hasattr(model, "config"): + hidden_size = ( + max(model.config.hidden_sizes) + if getattr(model.config, "hidden_sizes", None) + else getattr(model.config, "hidden_size", None) + ) + if hidden_size is not None: + config_kwargs.update( + { + "zero_optimization.reduce_bucket_size": hidden_size * hidden_size, + "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size, + "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size, + } + ) + + if isinstance(optimizer, (DummyOptim)): + config_kwargs.update( + {"optimizer.params.lr": optimizer.lr, "optimizer.params.weight_decay": optimizer.weight_decay} + ) + if isinstance(scheduler, (DummyScheduler)): + max_lr = ( + getattr(scheduler.optimizer, "lr", None) + if getattr(scheduler.optimizer, "defaults", None) is None + else scheduler.optimizer.defaults["lr"] + ) + config_kwargs.update( + { + "scheduler.params.warmup_min_lr": 0, + "scheduler.params.warmup_max_lr": max_lr, + "scheduler.params.warmup_num_steps": scheduler.warmup_num_steps, + } + ) + if scheduler.total_num_steps is not None: + config_kwargs["scheduler.params.total_num_steps"] = ( + math.ceil(scheduler.total_num_steps / self.num_processes) + if not self.split_batches + else scheduler.total_num_steps + ) + deepspeed_plugin.deepspeed_config_process(must_match=False, **config_kwargs) + self.deepspeed_config = deepspeed_plugin.deepspeed_config + kwargs = {"model": model, "config_params": self.deepspeed_config} + if optimizer is not None: + if isinstance(optimizer, (DummyOptim)): + kwargs["model_parameters"] = optimizer.params + else: + if self.deepspeed_config["zero_optimization"].get("offload_optimizer", {}).get( + "device", "none" + ) != "none" and self.deepspeed_config.get("zero_force_ds_cpu_optimizer", True): + from deepspeed.ops.adam import DeepSpeedCPUAdam + + defaults = {k: v for k, v in optimizer.defaults.items() if k in ["lr", "weight_decay"]} + optimizer = DeepSpeedCPUAdam(optimizer.param_groups, **defaults) + kwargs["optimizer"] = optimizer + if scheduler is not None: + if type(scheduler).__name__ in deepspeed.runtime.lr_schedules.VALID_LR_SCHEDULES: + kwargs["lr_scheduler"] = scheduler + + HabanaArgs = make_dataclass("HabanaArgs", [("use_hpu", bool), ("no_cuda", bool)]) + habana_args = HabanaArgs( + use_hpu=True if 
self.device.type == "hpu" else False, + no_cuda=True if self.device.type == "cpu" else False, + ) + if habana_args.use_hpu: + # This env variable is initialized here to make sure it is set to "true" + # It should be done by the launcher but it does not work for multi-node runs + os.environ["DEEPSPEED_USE_HPU"] = "true" + + engine, optimizer, _, lr_scheduler = deepspeed.initialize(**kwargs) + if optimizer is not None: + optimizer = DeepSpeedOptimizerWrapper(optimizer) + if scheduler is not None: + if lr_scheduler is None: + scheduler = AcceleratedScheduler( + scheduler, + optimizer, + step_with_optimizer=self.step_scheduler_with_optimizer, + split_batches=self.split_batches, + ) + else: + scheduler = DeepSpeedSchedulerWrapper(lr_scheduler, optimizer) + + for i in range(len(result)): + if isinstance(result[i], torch.nn.Module): + result[i] = engine + elif isinstance(result[i], (torch.optim.Optimizer, DummyOptim)): + result[i] = optimizer + elif (isinstance(result[i], (LRScheduler, DummyScheduler))) or ( + type(result[i]).__name__ in deepspeed.runtime.lr_schedules.VALID_LR_SCHEDULES + ): + result[i] = scheduler + # pointing for deepspeed_engine_wrapped.backward() + self.deepspeed_engine_wrapped = DeepSpeedEngineWrapper(engine) + self._models.append(engine) + if optimizer is not None: + self._optimizers.append(optimizer) + if scheduler is not None: + self._schedulers.append(scheduler) + if len(self._models) > 1: + raise AssertionError( + "You can't use same `Accelerator()` instance with multiple models when using DeepSpeed" + ) + return tuple(result) + + def gather(self, tensor): + """ + Gather the values in *tensor* across all processes and concatenate them on the first dimension. Useful to + regroup the predictions from all processes when doing evaluation. + + Note: + This gather happens in all processes. + + Args: + tensor (`torch.Tensor`, or a nested tuple/list/dictionary of `torch.Tensor`): + The tensors to gather across all processes. + + Returns: + `torch.Tensor`, or a nested tuple/list/dictionary of `torch.Tensor`: The gathered tensor(s). Note that the + first dimension of the result is *num_processes* multiplied by the first dimension of the input tensors. + + Example: + + ```python + >>> # Assuming four processes + >>> import torch + >>> from accelerate import Accelerator + + >>> accelerator = Accelerator() + >>> process_tensor = torch.tensor([accelerator.process_index]) + >>> gathered_tensor = accelerator.gather(process_tensor) + >>> gathered_tensor + tensor([0, 1, 2, 3]) + ``` + """ + if GaudiPartialState().distributed_type in [GaudiDistributedType.MULTI_HPU, GaudiDistributedType.DEEPSPEED]: + return _gpu_gather(tensor) + else: + return tensor + + def get_state_dict(self, model, unwrap=True): + """ + Returns the state dictionary of a model sent through [`Accelerator.prepare`] potentially without full + precision. + + Args: + model (`torch.nn.Module`): + A PyTorch model sent through [`Accelerator.prepare`] + unwrap (`bool`, *optional*, defaults to `True`): + Whether to return the original underlying state_dict of `model` or to return the wrapped state_dict + + Returns: + `dict`: The state dictionary of the model potentially without full precision. 
+ + Example: + + ```python + >>> import torch + >>> from accelerate import Accelerator + + >>> accelerator = Accelerator() + >>> net = torch.nn.Linear(2, 2) + >>> net = accelerator.prepare(net) + >>> state_dict = accelerator.get_state_dict(net) + ``` + """ + + if self.distributed_type == DistributedType.DEEPSPEED: + if self.deepspeed_config["zero_optimization"]["stage"] == 3: + if model.zero_gather_16bit_weights_on_model_save(): + state_dict = model._zero3_consolidated_16bit_state_dict() + else: + raise ValueError( + "Cannot get 16bit model weights because `stage3_gather_16bit_weights_on_model_save` in DeepSpeed config is False. " + "To save the model weights in 16bit, set `stage3_gather_16bit_weights_on_model_save` to True in DeepSpeed config file or " + "set `zero3_save_16bit_model` to True when using `accelerate config`. " + "To save the full checkpoint, run `model.save_checkpoint(save_dir)` and use `zero_to_fp32.py` to recover weights." + ) + else: + # from deepspeed.checkpoint.utils import clone_tensors_for_torch_save + state_dict = clone_tensors_for_torch_save(self.unwrap_model(model).state_dict()) + else: + if unwrap: + model = self.unwrap_model(model) + state_dict = model.state_dict() + + return state_dict + + @contextmanager + def autocast(self, cache_enabled: bool = False): + """ + Will apply automatic mixed-precision inside the block inside this context manager, if it is enabled. Nothing + different will happen otherwise. + + Example: + + ```python + >>> from accelerate import Accelerator + + >>> accelerator = Accelerator(mixed_precision="fp16") + >>> with accelerator.autocast(): + ... train() + ``` + """ + if self.native_amp: + autocast_context = torch.autocast(device_type=self.state.device.type, dtype=torch.bfloat16) + else: + autocast_context = contextlib.nullcontext() + + autocast_context.__enter__() + yield + autocast_context.__exit__(*sys.exc_info()) diff --git a/optimum/habana/accelerate/state.py b/optimum/habana/accelerate/state.py new file mode 100644 index 0000000000..9972db827d --- /dev/null +++ b/optimum/habana/accelerate/state.py @@ -0,0 +1,221 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import torch +from accelerate.state import AcceleratorState, PartialState +from accelerate.utils import is_deepspeed_available, parse_choice_from_env, parse_flag_from_env + +from optimum.utils import logging + +from .utils import GaudiDistributedType + + +logger = logging.get_logger() + + +class GaudiPartialState(PartialState): + """ + Adapted from: https://github.com/huggingface/accelerate/blob/8514c35192ac9762920f1ab052e5cea4c0e46eeb/src/accelerate/state.py#L96 + + Singleton class that has information about the current training environment and functions to help with process + control. Designed to be used when only process control and device execution states are needed. Does *not* need to + be initialized from `Accelerator`. 
+ + **Available attributes:** + + - **device** (`torch.device`) -- The device to use. + - **distributed_type** ([`GaudiDistributedType`]) -- The type of distributed environment currently + in use. + - **local_process_index** (`int`) -- The index of the current process on the current server. + - **mixed_precision** (`str`) -- Whether or not the current script will use mixed precision, and if so the type + of mixed precision being performed. + - **num_processes** (`int`) -- The number of processes currently launched in parallel. + - **process_index** (`int`) -- The index of the current process. + - **is_last_process** (`bool`) -- Whether or not the current process is the last one. + - **is_main_process** (`bool`) -- Whether or not the current process is the main one. + - **is_local_main_process** (`bool`) -- Whether or not the current process is the main one on the local node. + """ + + def __init__(self, cpu: bool = False, **kwargs): + self.__dict__ = self._shared_state + if not self.initialized: + self._cpu = cpu + self.backend = None + env_device = os.environ.get("ACCELERATE_TORCH_DEVICE", None) + self.device = torch.device(env_device) if env_device is not None else None + + if int(os.environ.get("LOCAL_RANK", -1)) != -1 and not cpu: + from habana_frameworks.torch.distributed.hccl import initialize_distributed_hpu + + world_size, rank, local_rank = initialize_distributed_hpu() + self.backend = kwargs.pop("backend", "hccl") + + if os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true": + if not is_deepspeed_available(): + raise ImportError( + "DeepSpeed is not available, install it with: `pip install" + " git+https://github.com/HabanaAI/DeepSpeed.git@1.10.0`." + ) + self.distributed_type = GaudiDistributedType.DEEPSPEED + if not torch.distributed.is_initialized(): + import deepspeed + + if world_size > 1: + os.environ["HLS_MODULE_ID"] = str(local_rank) + os.environ["ID"] = str(rank) + + deepspeed.init_distributed(dist_backend=self.backend, **kwargs) + logger.info("DeepSpeed is enabled.") + self._mixed_precision = "no" # deepspeed handles mixed_precision using deepspeed_config + else: + self.distributed_type = GaudiDistributedType.MULTI_HPU + if not torch.distributed.is_initialized(): + torch.distributed.init_process_group(backend=self.backend, rank=rank, world_size=world_size) + logger.info("Enabled distributed run.") + self.num_processes = world_size + self.process_index = rank + self.local_process_index = local_rank + if self.device is None: + self.device = torch.device("hpu", self.local_process_index) + else: + self.distributed_type = GaudiDistributedType.NO + self.num_processes = 1 + self.process_index = self.local_process_index = 0 + logger.info("Single-device run.") + + if self.device is None: + self.device = torch.device("cpu") if cpu else self.default_device + + self.fork_launched = parse_flag_from_env("FORK_LAUNCHED", 0) + + def wait_for_everyone(self): + """ + Will stop the execution of the current process until every other process has reached that point (so this does + nothing when the script is only run in one process). Useful to do before saving a model. + + Example: + + ```python + >>> # Assuming two GPU processes + >>> import time + >>> from accelerate.state import PartialState + + >>> state = PartialState() + >>> if state.is_main_process: + ... time.sleep(2) + >>> else: + ... 
print("I'm waiting for the main process to finish its sleep...") + >>> state.wait_for_everyone() + >>> # Should print on every process at the same time + >>> print("Everyone is here") + ``` + """ + if self.distributed_type in ( + GaudiDistributedType.MULTI_CPU, + GaudiDistributedType.DEEPSPEED, + GaudiDistributedType.MULTI_HPU, + ): + torch.distributed.barrier() + + @property + def default_device(self) -> torch.device: + """ + Returns the default device which is: + - HPU if it is available + - CPU otherwise + """ + import habana_frameworks.torch.hpu as hthpu + + if hthpu.is_available(): + return torch.device("hpu") + else: + return torch.device("cpu") + + +class GaudiAcceleratorState(AcceleratorState): + """ + Adapted from: https://github.com/huggingface/accelerate/blob/8514c35192ac9762920f1ab052e5cea4c0e46eeb/src/accelerate/state.py#L683 + + Singleton class that has information about the current training environment. + + **Available attributes:** + + - **device** (`torch.device`) -- The device to use. + - **distributed_type** ([`GaudiDistributedType`]) -- The type of distributed environment currently + in use. + - **initialized** (`bool`) -- Whether or not the `AcceleratorState` has been initialized from `Accelerator`. + - **local_process_index** (`int`) -- The index of the current process on the current server. + - **mixed_precision** (`str`) -- Whether or not the current script will use mixed precision, and if so the type + of mixed precision being performed. + - **num_processes** (`int`) -- The number of processes currently launched in parallel. + - **process_index** (`int`) -- The index of the current process. + - **is_last_process** (`bool`) -- Whether or not the current process is the last one. + - **is_main_process** (`bool`) -- Whether or not the current process is the main one. + - **is_local_main_process** (`bool`) -- Whether or not the current process is the main one on the local node. 
+ """ + + def __init__( + self, + mixed_precision: str = None, + cpu: bool = False, + dynamo_plugin=None, + deepspeed_plugin=None, + fsdp_plugin=None, + megatron_lm_plugin=None, + _from_accelerator: bool = False, + **kwargs, + ): + self.__dict__ = self._shared_state + if parse_flag_from_env("ACCELERATE_USE_CPU"): + cpu = True + if GaudiPartialState._shared_state == {}: + GaudiPartialState(cpu, **kwargs) + self.__dict__.update(GaudiPartialState._shared_state) + self._check_initialized(mixed_precision, cpu) + if not self.initialized: + self.deepspeed_plugin = None + mixed_precision = ( + parse_choice_from_env("ACCELERATE_MIXED_PRECISION", "no") + if mixed_precision is None + else mixed_precision.lower() + ) + self.dynamo_plugin = dynamo_plugin + # deepspeed handles mixed_precision using deepspeed_config + self._mixed_precision = ( + "no" if self.distributed_type == GaudiDistributedType.DEEPSPEED else mixed_precision + ) + if os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true" and not cpu: + self.deepspeed_plugin = deepspeed_plugin + GaudiPartialState._shared_state["distributed_type"] = self.distributed_type + + @property + def mixed_precision(self): + if self.distributed_type == GaudiDistributedType.DEEPSPEED: + config = self.deepspeed_plugin.deepspeed_config + if config.get("fp16", {}).get("enabled", False): + mixed_precision = "fp16" + elif config.get("bf16", {}).get("enabled", False): + mixed_precision = "bf16" + else: + mixed_precision = "no" + else: + mixed_precision = self._mixed_precision + + if mixed_precision == "fp16": + raise ValueError("fp16 is not supported on Habana Gaudi.") + + return mixed_precision diff --git a/optimum/habana/accelerate/utils/__init__.py b/optimum/habana/accelerate/utils/__init__.py new file mode 100644 index 0000000000..17eb21be03 --- /dev/null +++ b/optimum/habana/accelerate/utils/__init__.py @@ -0,0 +1 @@ +from .dataclasses import GaudiDistributedType diff --git a/optimum/habana/accelerate/utils/dataclasses.py b/optimum/habana/accelerate/utils/dataclasses.py new file mode 100644 index 0000000000..9024a7ea7b --- /dev/null +++ b/optimum/habana/accelerate/utils/dataclasses.py @@ -0,0 +1,34 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from enum import Enum + + +class GaudiDistributedType(str, Enum): + """ + Represents a type of distributed environment. + Adapted from: https://github.com/huggingface/accelerate/blob/8514c35192ac9762920f1ab052e5cea4c0e46eeb/src/accelerate/utils/dataclasses.py#L176 + + Values: + + - **NO** -- Not a distributed environment, just a single process. + - **MULTI_HPU** -- Distributed on multiple HPUs. + - **DEEPSPEED** -- Using DeepSpeed. + """ + + # Subclassing str as well as Enum allows the `GaudiDistributedType` to be JSON-serializable out of the box. 
+ NO = "NO" + MULTI_HPU = "MULTI_HPU" + DEEPSPEED = "DEEPSPEED" diff --git a/optimum/habana/distributed/distributed_runner.py b/optimum/habana/distributed/distributed_runner.py index 4fe7f311a8..251a8f3fc4 100644 --- a/optimum/habana/distributed/distributed_runner.py +++ b/optimum/habana/distributed/distributed_runner.py @@ -113,8 +113,8 @@ def __init__( self.create_single_card_setup(use_deepspeed) def get_peval(self): - cmd1 = "lscpu 2>/dev/null | awk '/Socket\(s\)/ { print $2 }'" - cmd2 = "lscpu 2>/dev/null | awk '/Core\(s\) per socket/ { print $4 }'" + cmd1 = r"lscpu 2>/dev/null | awk '/Socket\(s\)/ { print $2 }'" + cmd2 = r"lscpu 2>/dev/null | awk '/Core\(s\) per socket/ { print $4 }'" with subprocess.Popen( cmd1, shell=True, executable="/bin/bash", stdout=subprocess.PIPE, stderr=subprocess.STDOUT ) as proc: diff --git a/optimum/habana/transformers/deepspeed.py b/optimum/habana/transformers/deepspeed.py index 532f1d86c6..feb0eb0de2 100644 --- a/optimum/habana/transformers/deepspeed.py +++ b/optimum/habana/transformers/deepspeed.py @@ -15,9 +15,6 @@ """ Integration with Deepspeed """ -import os -from copy import deepcopy -from dataclasses import make_dataclass import torch from transformers.deepspeed import ( @@ -36,8 +33,11 @@ class GaudiTrainerDeepSpeedConfig(HfTrainerDeepSpeedConfig): """ - The `HfTrainerDeepSpeedConfig` object is meant to be created during `TrainingArguments` object creation and has the - same lifespan as the latter. + Adapted from: https://github.com/huggingface/transformers/blob/e42587f596181396e1c4b63660abf0c736b10dae/src/transformers/deepspeed.py#L69 + + The differences are: + - disable DeepSpeed version check as we run a custom version on HPU + - remove uncompatible args (e.g. fp16) in config processing """ def __init__(self, config_file_or_dict): @@ -84,43 +84,30 @@ def trainer_config_process(self, args): self._dtype = torch.float32 -def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None, inference=False): +def deepspeed_init(trainer, num_training_steps, inference=False): """ - Init DeepSpeed, after updating the DeepSpeed configuration with any relevant Trainer's args. - If `resume_from_checkpoint` was passed then an attempt to resume from a previously saved checkpoint will be made. 
- Args: - trainer: Trainer object - num_training_steps: per single HPU - resume_from_checkpoint: path to a checkpoint if to resume from after normal DeepSpeedEngine load - inference: launch in inference mode (no optimizer and no lr scheduler) - Returns: model, optimizer, lr_scheduler - We may use `deepspeed_init` more than once during the life of Trainer, when we do - it's a temp hack based on: - https://github.com/microsoft/DeepSpeed/issues/1394#issuecomment-937405374 until Deepspeed fixes a bug where it - can't resume from a checkpoint after it did some stepping https://github.com/microsoft/DeepSpeed/issues/1612 + Adapted from: https://github.com/huggingface/transformers/blob/e42587f596181396e1c4b63660abf0c736b10dae/src/transformers/deepspeed.py#L316 + + The difference is: + - add a workaround to cast the model to the target dtype """ - import deepspeed from deepspeed.utils import logger as ds_logger model = trainer.model args = trainer.args - if hasattr(trainer, "hf_deepspeed_config_orig"): - hf_deepspeed_config = deepcopy(trainer.hf_deepspeed_config_orig) - else: - hf_deepspeed_config = args.hf_deepspeed_config - trainer.hf_deepspeed_config_orig = deepcopy(args.hf_deepspeed_config) + hf_deepspeed_config = trainer.accelerator.state.deepspeed_plugin.hf_ds_config + + # TODO: temporary workaround + # To remove when it is solved, see https://github.com/HabanaAI/Model-References/blob/17fbab7ceebca15b1560ffb2c4e15a3888bb5f33/PyTorch/nlp/pretraining/deepspeed-bert/run_pretraining.py#L527 + model.to(dtype=hf_deepspeed_config.dtype(), device="hpu") # resume config update - some bits like `model` and `num_training_steps` only become available during train hf_deepspeed_config.trainer_config_finalize(args, model, num_training_steps) - config = hf_deepspeed_config.config # set the Deepspeed log level consistent with the Trainer ds_logger.setLevel(args.get_process_log_level()) - # TODO: temporary workaround - # To remove when it is solved, see https://github.com/HabanaAI/Model-References/blob/17fbab7ceebca15b1560ffb2c4e15a3888bb5f33/PyTorch/nlp/pretraining/deepspeed-bert/run_pretraining.py#L527 - model.to(dtype=hf_deepspeed_config._dtype, device="hpu") - if inference: # only Z3 makes sense for the inference if not hf_deepspeed_config.is_zero3(): @@ -133,48 +120,12 @@ def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None, inf model_parameters = None else: trainer.optimizer = None # important for when deepspeed_init is used as re-init - optimizer, lr_scheduler = deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps) model_parameters = list(filter(lambda p: p.requires_grad, model.parameters())) + optimizer, lr_scheduler = deepspeed_optim_sched( + trainer, hf_deepspeed_config, args, num_training_steps, model_parameters + ) # keep for quick debug: # from pprint import pprint; pprint(config) - HabanaArgs = make_dataclass("HabanaArgs", [("use_hpu", bool), ("no_cuda", bool)]) - habana_args = HabanaArgs(use_hpu=args.use_habana, no_cuda=args.no_cuda) - if args.use_habana: - # This env variable is initialized here to make sure it is set to "true" - # It should be done by the launcher but it does not work for multi-node runs - os.environ["DEEPSPEED_USE_HPU"] = "true" - - kwargs = { - "args": habana_args, - "model": model, - "model_parameters": model_parameters, - "config_params": config, - "optimizer": optimizer, - "lr_scheduler": lr_scheduler, - } - - deepspeed_engine, optimizer, _, lr_scheduler = deepspeed.initialize(**kwargs) - - if resume_from_checkpoint is not 
None: - # it's possible that the user is trying to resume from model_path, which doesn't necessarily - # contain a deepspeed checkpoint. e.g. examples just check if the dir exists and assume it's - # a resume from a checkpoint and not just a local pretrained weight. So we check here if the - # path contains what looks like a deepspeed checkpoint - import glob - - deepspeed_checkpoint_dirs = sorted(glob.glob(f"{resume_from_checkpoint}/global_step*")) - - if len(deepspeed_checkpoint_dirs) > 0: - logger.info(f"Attempting to resume from {resume_from_checkpoint}") - # this magically updates self.optimizer and self.lr_scheduler - load_path, _ = deepspeed_engine.load_checkpoint( - resume_from_checkpoint, load_optimizer_states=True, load_lr_scheduler_states=True - ) - if load_path is None: - raise ValueError(f"[deepspeed] failed to resume from checkpoint {resume_from_checkpoint}") - else: - raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}") - - return deepspeed_engine, optimizer, lr_scheduler + return optimizer, lr_scheduler diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index 1c2ba5f82c..2821327c20 100644 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -54,6 +54,8 @@ if TYPE_CHECKING: + from transformers import PreTrainedModel + from .streamers import BaseStreamer @@ -133,6 +135,8 @@ def _update_model_kwargs_for_generation( model_kwargs["past_key_values"] = self._extract_past_from_model_output( outputs, standardize_cache_format=standardize_cache_format ) + if getattr(outputs, "state", None) is not None: + model_kwargs["state"] = outputs.state # update token_type_ids with last value if "token_type_ids" in model_kwargs: @@ -173,36 +177,6 @@ def _update_model_kwargs_for_generation( return model_kwargs - # TODO: remove this method when Transformers v4.31 is released since it solves the issue with Llama - def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]): - """ - Copied from Transformers: https://github.com/huggingface/transformers/blob/main/src/transformers/generation/utils.py - - Remove `token_type_ids` from model_kwargs, which is not used for llama model - """ - if self.config.is_encoder_decoder: - for key in ["decoder_input_ids"]: - model_kwargs.pop(key, None) - if self.config.model_type == "llama": - for key in ["token_type_ids"]: - model_kwargs.pop(key, None) - - unused_model_args = [] - model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters) - # `kwargs`/`model_kwargs` is often used to handle optional forward pass inputs like `attention_mask`. 
If - # `prepare_inputs_for_generation` doesn't accept them, then a stricter check can be made ;) - if "kwargs" in model_args or "model_kwargs" in model_args: - model_args |= set(inspect.signature(self.forward).parameters) - for key, value in model_kwargs.items(): - if value is not None and key not in model_args: - unused_model_args.append(key) - - if unused_model_args: - raise ValueError( - f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the" - " generate arguments will also show up in this list)" - ) - @torch.no_grad() def generate( self, @@ -212,6 +186,7 @@ def generate( stopping_criteria: Optional[StoppingCriteriaList] = None, prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, synced_gpus: Optional[bool] = None, + assistant_model: Optional["PreTrainedModel"] = None, streamer: Optional["BaseStreamer"] = None, lazy_mode: Optional[bool] = False, hpu_graphs: Optional[bool] = False, @@ -267,6 +242,11 @@ def generate( Whether to continue running the while loop until max_length. Unless overridden this flag will be set to `True` under DeepSpeed ZeRO Stage 3 multiple GPUs environment to avoid hanging if one GPU finished generating before other GPUs. Otherwise it'll be set to `False`. + assistant_model (`PreTrainedModel`, *optional*): + An assistant model that can be used to accelerate generation. The assistant model must have the exact + same tokenizer. The acceleration is achieved when forecasting candidate tokens with the assistent model + is much faster than running generation with the model you're calling generate from. As such, the + assistant model should be much smaller. streamer (`BaseStreamer`, *optional*): Streamer object that will be used to stream the generated sequences. Generated tokens are passed through `streamer.put(token_ids)` and the streamer is responsible for any further processing. @@ -278,7 +258,7 @@ def generate( Number of steps to ignore for profling. profiling_steps (`int`, *optional*, defaults to 0): Number of steps to be captured when enabling profiling. - kwargs: + kwargs (`Dict[str, Any]`, *optional*): Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*. @@ -323,7 +303,7 @@ def generate( "You have modified the pretrained model configuration to control generation. This is a" " deprecated strategy to control generation and will be removed soon, in a future version." " Please use a generation configuration file (see" - " https://huggingface.co/docs/transformers/main_classes/text_generation)" + " https://huggingface.co/docs/transformers/main_classes/text_generation )" ) self.generation_config = new_generation_config generation_config = self.generation_config @@ -368,7 +348,12 @@ def generate( # 4. 
Define other model kwargs model_kwargs["output_attentions"] = generation_config.output_attentions model_kwargs["output_hidden_states"] = generation_config.output_hidden_states - model_kwargs["use_cache"] = generation_config.use_cache + # decoder-only models with inputs_embeds forwarding must use caching (otherwise we can't detect whether we are + # generating the first new token or not, and we only want to use the embeddings for the first new token) + if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds": + model_kwargs["use_cache"] = True + else: + model_kwargs["use_cache"] = generation_config.use_cache accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys()) requires_attention_mask = "encoder_outputs" not in model_kwargs @@ -392,9 +377,14 @@ def generate( # decoder-only models should use left-padding for generation if not self.config.is_encoder_decoder: + # If `input_ids` was given, check if the last id in any sequence is `pad_token_id` + # Note: If using, `inputs_embeds` this check does not work, because we want to be more hands-off. if generation_config.pad_token_id is not None: position = model_kwargs["token_idx"] - 1 if "token_idx" in model_kwargs else -1 - if torch.sum(inputs_tensor[:, position] == generation_config.pad_token_id) > 0: + if ( + len(inputs_tensor.shape) == 2 + and torch.sum(inputs_tensor[:, position] == generation_config.pad_token_id) > 0 + ): logger.warning( "A decoder-only architecture is being used, but right-padding was detected! For correct " "generation results, please set `padding_side='left'` when initializing the tokenizer." @@ -409,17 +399,14 @@ def generate( # 5. Prepare `input_ids` which will be used for auto-regressive generation if self.config.is_encoder_decoder: - input_ids = self._prepare_decoder_input_ids_for_generation( - batch_size, + input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation( + batch_size=batch_size, + model_input_name=model_input_name, + model_kwargs=model_kwargs, decoder_start_token_id=generation_config.decoder_start_token_id, bos_token_id=generation_config.bos_token_id, - model_kwargs=model_kwargs, device=inputs_tensor.device, ) - - # conditional generation for multi-modal models. - if "input_ids" in model_kwargs and model_input_name == "pixel_values": - input_ids = torch.cat([input_ids, model_kwargs.pop("input_ids")], dim=-1) else: input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids") @@ -439,13 +426,13 @@ def generate( elif generation_config.max_new_tokens is not None: generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length if not has_default_max_length: - logger.warn( + logger.warning( f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. " "Please refer to the documentation for more information. 
" - "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)", - UserWarning, + "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)" ) + generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length: raise ValueError( @@ -508,6 +495,14 @@ def generate( and not is_constraint_gen_mode and not is_contrastive_search_gen_mode ) + is_assisted_gen_mode = False + if assistant_model is not None: + if not (is_greedy_gen_mode or is_sample_gen_mode): + raise ValueError( + "You've set `assistant_model`, which triggers assisted generate. Currently, assisted generate " + "is only supported with Greedy Search and Sample." + ) + is_assisted_gen_mode = True if generation_config.num_beam_groups > generation_config.num_beams: raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`") @@ -562,11 +557,49 @@ def generate( self.htcore_generation = htcore # 10. go into different generation modes + if is_assisted_gen_mode: + if generation_config.num_return_sequences > 1: + raise ValueError( + "num_return_sequences has to be 1 when doing assisted generate, " + f"but is {generation_config.num_return_sequences}." + ) + if batch_size > 1: + raise ValueError("assisted generate is only supported for batch_size = 1") + if not model_kwargs["use_cache"]: + raise ValueError("assisted generate requires `use_cache=True`") + + # 11. If the assistant model is an encoder-decoder, prepare its encoder outputs + if assistant_model.config.is_encoder_decoder: + assistant_model_kwargs = copy.deepcopy(model_kwargs) + inputs_tensor, model_input_name, assistant_model_kwargs = assistant_model._prepare_model_inputs( + inputs_tensor, assistant_model.generation_config.bos_token_id, assistant_model_kwargs + ) + assistant_model_kwargs = assistant_model._prepare_encoder_decoder_kwargs_for_generation( + inputs_tensor, assistant_model_kwargs, model_input_name + ) + model_kwargs["assistant_encoder_outputs"] = assistant_model_kwargs["encoder_outputs"] + + # 12. run assisted generate + return self.assisted_decoding( + input_ids, + assistant_model=assistant_model, + do_sample=generation_config.do_sample, + logits_processor=logits_processor, + logits_warper=self._get_logits_warper(generation_config) if generation_config.do_sample else None, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + streamer=streamer, + **model_kwargs, + ) if is_greedy_gen_mode: if generation_config.num_return_sequences > 1: raise ValueError( - f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing" - " greedy search." + "num_return_sequences has to be 1 when doing greedy search, " + f"but is {generation_config.num_return_sequences}." ) # 11. run greedy search @@ -590,9 +623,11 @@ def generate( elif is_contrastive_search_gen_mode: if generation_config.num_return_sequences > 1: raise ValueError( - f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing" - " contrastive search." + "num_return_sequences has to be 1 when doing contrastive search, " + f"but is {generation_config.num_return_sequences}." 
) + if not model_kwargs["use_cache"]: + raise ValueError("Contrastive search requires `use_cache=True`") return self.contrastive_search( input_ids, @@ -732,6 +767,11 @@ def generate( if generation_config.num_beams % generation_config.num_beam_groups != 0: raise ValueError("`num_beams` should be divisible by `num_beam_groups` for group beam search.") + if generation_config.diversity_penalty == 0.0: + raise ValueError( + "`diversity_penalty` should be greater than `0.0`, otherwise your beam groups will be identical." + ) + if stopping_criteria.max_length is None: raise ValueError("`max_length` needs to be a stopping_criteria for now.") @@ -880,7 +920,7 @@ def contrastive_search( output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, return_dict_in_generate: Optional[bool] = None, - synced_gpus: Optional[bool] = False, + synced_gpus: bool = False, streamer: Optional["BaseStreamer"] = None, lazy_mode: Optional[bool] = False, profiling_warmup_steps: Optional[int] = 0, @@ -990,7 +1030,7 @@ def greedy_search( output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, return_dict_in_generate: Optional[bool] = None, - synced_gpus: Optional[bool] = False, + synced_gpus: bool = False, streamer: Optional["BaseStreamer"] = None, lazy_mode: Optional[bool] = False, ignore_eos: Optional[bool] = None, @@ -1230,13 +1270,19 @@ def greedy_search( next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) ) - hb_profer.step() - # stop if we exceed the maximum length, or when each sentence is finished (eager mode only) - if (not ignore_eos and unfinished_sequences.max() == 0) or stopping_criteria(input_ids, scores): - if not synced_gpus: - break - else: + # stop when each sentence is finished + if not ignore_eos and unfinished_sequences.max() == 0: this_peer_finished = True + + # stop if we exceed the maximum length + if stopping_criteria(input_ids, scores): + this_peer_finished = True + + hb_profer.step() + + if this_peer_finished and not synced_gpus: + break + hb_profer.stop() if streamer is not None: streamer.end() @@ -1275,7 +1321,7 @@ def sample( output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, return_dict_in_generate: Optional[bool] = None, - synced_gpus: Optional[bool] = False, + synced_gpus: bool = False, streamer: Optional["BaseStreamer"] = None, lazy_mode: Optional[bool] = False, ignore_eos: Optional[bool] = None, @@ -1539,15 +1585,19 @@ def sample( next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) ) - # if lazy_mode and not hpu_graphs: - # self.htcore_generation.mark_step() - hb_profer.step() - # stop if we exceed the maximum length, or when each sentence is finished (eager mode only) - if (not ignore_eos and unfinished_sequences.max() == 0) or stopping_criteria(input_ids, scores): - if not synced_gpus: - break - else: + # stop when each sentence is finished + if not ignore_eos and unfinished_sequences.max() == 0: this_peer_finished = True + + # stop if we exceed the maximum length + if stopping_criteria(input_ids, scores): + this_peer_finished = True + + hb_profer.step() + + if this_peer_finished and not synced_gpus: + break + hb_profer.stop() if streamer is not None: streamer.end() @@ -1586,7 +1636,7 @@ def beam_search( output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, return_dict_in_generate: Optional[bool] = None, - synced_gpus: Optional[bool] = False, + synced_gpus: bool = False, lazy_mode: 
Optional[bool] = False, profiling_warmup_steps: Optional[int] = 0, profiling_steps: Optional[int] = 0, @@ -1871,8 +1921,6 @@ def beam_search( # increase cur_len cur_len = cur_len + 1 - # if lazy_mode and not hpu_graphs: - # self.htcore_generation.mark_step() hb_profer.step() if stopping_criteria(input_ids, scores) or (beam_scorer.is_done and not lazy_mode): if not synced_gpus: @@ -1934,7 +1982,7 @@ def beam_sample( output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, return_dict_in_generate: Optional[bool] = None, - synced_gpus: Optional[bool] = False, + synced_gpus: bool = False, lazy_mode: Optional[bool] = False, profiling_warmup_steps: Optional[int] = 0, profiling_steps: Optional[int] = 0, @@ -2080,7 +2128,7 @@ def group_beam_search( output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, return_dict_in_generate: Optional[bool] = None, - synced_gpus: Optional[bool] = False, + synced_gpus: bool = False, lazy_mode: Optional[bool] = False, profiling_warmup_steps: Optional[int] = 0, profiling_steps: Optional[int] = 0, @@ -2550,3 +2598,130 @@ def constrained_beam_search( ) else: return sequence_outputs["sequences"] + + def assisted_decoding( + self, + input_ids: torch.LongTensor, + assistant_model: "PreTrainedModel", + do_sample: bool = False, + logits_processor: Optional[LogitsProcessorList] = None, + logits_warper: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[Union[int, List[int]]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_scores: Optional[bool] = None, + return_dict_in_generate: Optional[bool] = None, + synced_gpus: bool = False, + lazy_mode: Optional[bool] = False, + profiling_warmup_steps: Optional[int] = 0, + profiling_steps: Optional[int] = 0, + streamer: Optional["BaseStreamer"] = None, + **model_kwargs, + ): + r""" + Generates sequences of token ids for models with a language modeling head using **greedy decoding** or + **sample** (depending on `do_sample`), assisted by a smaller model. Can be used for text-decoder, text-to-text, + speech-to-text, and vision-to-text models. + + + + In most cases, you do not need to call [`~generation.GenerationMixin.assisted_decoding`] directly. Use + generate() instead. For an overview of generation strategies and code examples, check the [following + guide](../generation_strategies). + + + + Parameters: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The sequence used as a prompt for the generation. + assistant_model (`PreTrainedModel`, *optional*): + An assistant model that can be used to accelerate generation. The assistant model must have the exact + same tokenizer. The acceleration is achieved when forecasting candidate tokens with the assistent model + is much faster than running generation with the model you're calling generate from. As such, the + assistant model should be much smaller. + do_sample (`bool`, *optional*, defaults to `False`): + Whether or not to use sampling ; use greedy decoding otherwise. + logits_processor (`LogitsProcessorList`, *optional*): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] + used to modify the prediction scores of the language modeling head applied at each generation step. + logits_warper (`LogitsProcessorList`, *optional*): + An instance of [`LogitsProcessorList`]. 
List of instances of class derived from [`LogitsWarper`] used + to warp the prediction score distribution of the language modeling head applied before multinomial + sampling at each generation step. + stopping_criteria (`StoppingCriteriaList`, *optional*): + An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] + used to tell if the generation loop should stop. + pad_token_id (`int`, *optional*): + The id of the *padding* token. + eos_token_id (`Union[int, List[int]]`, *optional*): + The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more details. + output_hidden_states (`bool`, *optional*, defaults to `False`): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more details. + output_scores (`bool`, *optional*, defaults to `False`): + Whether or not to return the prediction scores. See `scores` under returned tensors for more details. + return_dict_in_generate (`bool`, *optional*, defaults to `False`): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + synced_gpus (`bool`, *optional*, defaults to `False`): + Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + lazy_mode (`bool`, *optional*, defaults to `False`): + Whether the run is executed in lazy mode or not (i.e. eager mode). + profiling_warmup_steps (`int`, *optional*, defaults to 0): + Number of steps to ignore for profling. + profiling_steps (`int`, *optional*, defaults to 0): + Number of steps to be captured when enabling profiling. + streamer (`BaseStreamer`, *optional*): + Streamer object that will be used to stream the generated sequences. Generated tokens are passed + through `streamer.put(token_ids)` and the streamer is responsible for any further processing. + model_kwargs: + Additional model specific keyword arguments will be forwarded to the `forward` function of the model. + If model is an encoder-decoder model the kwargs should include `encoder_outputs`. + + Return: + [`~generation.GreedySearchDecoderOnlyOutput`], [`~generation.GreedySearchEncoderDecoderOutput`] or + `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a + [`~generation.GreedySearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and + `return_dict_in_generate=True` or a [`~generation.GreedySearchEncoderDecoderOutput`] if + `model.config.is_encoder_decoder=True`. + + Examples: + + ```python + >>> from transformers import ( + ... AutoTokenizer, + ... AutoModelForCausalLM, + ... LogitsProcessorList, + ... MinLengthLogitsProcessor, + ... StoppingCriteriaList, + ... MaxLengthCriteria, + ... ) + + >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") + >>> model = AutoModelForCausalLM.from_pretrained("gpt2") + >>> assistant_model = AutoModelForCausalLM.from_pretrained("distilgpt2") + >>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token + >>> model.generation_config.pad_token_id = model.generation_config.eos_token_id + >>> input_prompt = "It might be possible to" + >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids + >>> # instantiate logits processors + >>> logits_processor = LogitsProcessorList( + ... [ + ... 
MinLengthLogitsProcessor(10, eos_token_id=model.generation_config.eos_token_id), + ... ] + ... ) + >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)]) + >>> outputs = model.assisted_decoding( + ... input_ids, + ... assistant_model=assistant_model, + ... logits_processor=logits_processor, + ... stopping_criteria=stopping_criteria, + ... ) + >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) + ["It might be possible to get a better understanding of the nature of the problem, but it's not"] + ```""" + raise NotImplementedError("Assisted decoding is not supported by optimum-habana yet.") diff --git a/optimum/habana/transformers/models/albert/modeling_albert.py b/optimum/habana/transformers/models/albert/modeling_albert.py index a0b6f67f83..2237475c41 100644 --- a/optimum/habana/transformers/models/albert/modeling_albert.py +++ b/optimum/habana/transformers/models/albert/modeling_albert.py @@ -28,9 +28,9 @@ def gaudi_albert_forward( position_ids: Optional[torch.LongTensor] = None, head_mask: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, - output_attentions: Optional[None] = None, - output_hidden_states: Optional[None] = None, - return_dict: Optional[None] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, ) -> Union[BaseModelOutputWithPooling, Tuple]: """ Same as https://github.com/huggingface/transformers/blob/a9eee2ffecc874df7dd635b2c6abb246fdb318cc/src/transformers/models/albert/modeling_albert.py#L689 diff --git a/optimum/habana/transformers/models/bloom/modeling_bloom.py b/optimum/habana/transformers/models/bloom/modeling_bloom.py index a4dcc6b897..b76aa05a64 100644 --- a/optimum/habana/transformers/models/bloom/modeling_bloom.py +++ b/optimum/habana/transformers/models/bloom/modeling_bloom.py @@ -183,7 +183,7 @@ def gaudi_bloom_attention_forward( # matmul: [batch_size * num_heads, q_length, head_dim] context_layer = torch.bmm(attention_probs_reshaped, value_layer) - # change view [batch_size, num_heads, q_length, head_dim] + # change view [batch_size, q_length, num_heads * head_dim] context_layer = self._merge_heads(context_layer) # aggregate results across tp ranks. 
See here: https://github.com/pytorch/pytorch/issues/76232 diff --git a/optimum/habana/transformers/models/gpt2/modeling_gpt2.py b/optimum/habana/transformers/models/gpt2/modeling_gpt2.py index 4fda9abec5..f1fde39b12 100644 --- a/optimum/habana/transformers/models/gpt2/modeling_gpt2.py +++ b/optimum/habana/transformers/models/gpt2/modeling_gpt2.py @@ -25,8 +25,9 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None): torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view( 1, 1, max_positions, max_positions ), + persistent=False, ) - self.register_buffer("masked_bias", torch.tensor(-1e4)) + self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False) self.embed_dim = config.hidden_size self.num_heads = config.num_attention_heads diff --git a/optimum/habana/transformers/models/gpt_neox/modeling_gpt_neox.py b/optimum/habana/transformers/models/gpt_neox/modeling_gpt_neox.py index 3d9683c678..e2b3fbe9c0 100644 --- a/optimum/habana/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/optimum/habana/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -111,12 +111,14 @@ def gaudi_gpt_neox_layer_forward( token_idx=token_idx, ) attn_output = attention_layer_outputs[0] # output_attn: attn_output, present, (attn_weights) + attn_output = self.post_attention_dropout(attn_output) outputs = attention_layer_outputs[1:] if self.use_parallel_residual: # pseudocode: # x = x + attn(ln1(x)) + mlp(ln2(x)) mlp_output = self.mlp(self.post_attention_layernorm(hidden_states)) + mlp_output = self.post_mlp_dropout(mlp_output) hidden_states = mlp_output + attn_output + hidden_states else: # pseudocode: @@ -124,6 +126,7 @@ def gaudi_gpt_neox_layer_forward( # x = x + mlp(ln2(x)) attn_output = attn_output + hidden_states mlp_output = self.mlp(self.post_attention_layernorm(attn_output)) + mlp_output = self.post_mlp_dropout(mlp_output) hidden_states = mlp_output + attn_output if use_cache: @@ -213,7 +216,7 @@ def gaudi_gpt_neox_model_forward( if inputs_embeds is None: inputs_embeds = self.embed_in(input_ids) - hidden_states = inputs_embeds + hidden_states = self.emb_dropout(inputs_embeds) if self.gradient_checkpointing and self.training: if use_cache: @@ -345,7 +348,7 @@ def forward( ) def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, token_idx=None, **kwargs + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, token_idx=None, **kwargs ): input_shape = input_ids.shape @@ -371,10 +374,19 @@ def prepare_inputs_for_generation( if attention_mask is None: attention_mask = input_ids.new_ones(input_shape) - return { - "input_ids": input_ids, - "attention_mask": attention_mask, - "position_ids": position_ids, - "past_key_values": past_key_values, - "token_idx": token_idx, - } + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "attention_mask": attention_mask, + "past_key_values": past_key_values, + "position_ids": position_ids, + "token_idx": token_idx, + } + ) + + return model_inputs diff --git a/optimum/habana/transformers/models/gptj/modeling_gptj.py b/optimum/habana/transformers/models/gptj/modeling_gptj.py index 5b14e45b52..5dd9f4a3a5 100644 --- a/optimum/habana/transformers/models/gptj/modeling_gptj.py +++ b/optimum/habana/transformers/models/gptj/modeling_gptj.py @@ 
-427,6 +427,8 @@ def forward( loss = None if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(lm_logits.device) # Shift so that tokens < n predict n shift_logits = lm_logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index f53ff0e36d..e7462b7685 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -2,9 +2,9 @@ from typing import List, Optional, Tuple, Union import torch -import torch.nn as nn +import torch.nn.functional as F from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast -from transformers.models.llama.modeling_llama import LlamaForCausalLM, apply_rotary_pos_emb, logger +from transformers.models.llama.modeling_llama import LlamaForCausalLM, apply_rotary_pos_emb, logger, repeat_kv def gaudi_llama_attention_forward( @@ -24,9 +24,31 @@ def gaudi_llama_attention_forward( - optimize KV cache """ bsz, q_len, _ = hidden_states.size() - query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + if self.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.pretraining_tp + query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.pretraining_tp, dim=0) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + kv_seq_len = key_states.shape[-2] if past_key_value is not None: if token_idx is None: @@ -35,7 +57,7 @@ def gaudi_llama_attention_forward( kv_seq_len = past_key_value[0].shape[-2] cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - # [bsz, nh, t, hd] + if past_key_value is not None: # reuse k, v, self_attention past_key = past_key_value[0] @@ -50,35 +72,49 @@ def gaudi_llama_attention_forward( value_states = torch.cat([past_key_value[1], value_states], dim=2) past_key_value = (key_states, value_states) if use_cache else None + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, 
self.num_key_value_groups) + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): raise ValueError( f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" f" {attn_weights.size()}" ) + if attention_mask is not None: if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): raise ValueError( f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" ) attn_weights = attn_weights + attention_mask - dtype_min = torch.tensor( - torch.finfo(attn_weights.dtype).min, device=attn_weights.device, dtype=attn_weights.dtype - ) - attn_weights = torch.max(attn_weights, dtype_min) + # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) attn_output = torch.matmul(attn_weights, value_states) + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): raise ValueError( f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" f" {attn_output.size()}" ) - attn_output = attn_output.transpose(1, 2) + + attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - attn_output = self.o_proj(attn_output) + + if self.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.pretraining_tp)]) + else: + attn_output = self.o_proj(attn_output) + if not output_attentions: attn_weights = None + return attn_output, attn_weights, past_key_value @@ -98,7 +134,9 @@ def gaudi_llama_decoder_layer_forward( - add new args token_idx """ residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + # Self Attention hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, @@ -110,16 +148,20 @@ def gaudi_llama_decoder_layer_forward( token_idx=token_idx, ) hidden_states = residual + hidden_states + # Fully Connected residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states + outputs = (hidden_states,) + if output_attentions: outputs += (self_attn_weights,) if use_cache: outputs += (present_key_value,) + return outputs @@ -146,7 +188,9 @@ def gaudi_llama_model_forward( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + # retrieve input_ids and inputs_embeds if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") @@ -156,11 +200,14 @@ def gaudi_llama_model_forward( batch_size, seq_length, _ = inputs_embeds.shape else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + seq_length_with_past = seq_length past_key_values_length = 0 + if past_key_values is not None: past_key_values_length = past_key_values[0][0].shape[2] 
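The `repeat_kv` call introduced above is what supports grouped-query attention: keys and values are projected with `num_key_value_heads` heads and have to be tiled up to `num_heads` before the attention matmul. A self-contained sketch of that expansion, assuming the `(batch, num_kv_heads, seq_len, head_dim)` layout used in the hunk (the helper below is a stand-in; the real `repeat_kv` is imported from `transformers.models.llama.modeling_llama`):

```python
import torch


def repeat_kv_sketch(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each key/value head n_rep times so the KV tensors match the number of query heads."""
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # (b, kv, s, d) -> (b, kv, n_rep, s, d) -> (b, kv * n_rep, s, d)
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)


# Example: 2 KV heads shared by 8 query heads => n_rep = num_key_value_groups = 4
kv = torch.randn(1, 2, 5, 16)
assert repeat_kv_sketch(kv, 4).shape == (1, 8, 5, 16)
```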
seq_length_with_past = seq_length_with_past + past_key_values_length + if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = torch.arange( @@ -169,6 +216,7 @@ def gaudi_llama_model_forward( position_ids = position_ids.unsqueeze(0).view(-1, seq_length) else: position_ids = position_ids.view(-1, seq_length).long() + if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) # embed positions @@ -177,21 +225,27 @@ def gaudi_llama_model_forward( attention_mask = self._prepare_decoder_attention_mask( attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length ) + hidden_states = inputs_embeds + if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False + # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None next_decoder_cache = () if use_cache else None + for idx, decoder_layer in enumerate(self.layers): if output_hidden_states: all_hidden_states += (hidden_states,) + past_key_value = past_key_values[idx] if past_key_values is not None else None + if self.gradient_checkpointing and self.training: def create_custom_forward(module): @@ -218,15 +272,21 @@ def custom_forward(*inputs): use_cache=use_cache, token_idx=token_idx, ) + hidden_states = layer_outputs[0] + if use_cache: next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + if output_attentions: all_self_attns += (layer_outputs[1],) + hidden_states = self.norm(hidden_states) + # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) + next_cache = next_decoder_cache if use_cache else None if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) @@ -283,7 +343,13 @@ def forward( ) hidden_states = outputs[0] - logits = self.lm_head(hidden_states) + if self.pretraining_tp > 1: + lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.pretraining_tp, dim=0) + logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.pretraining_tp)] + logits = torch.cat(logits, dim=-1) + else: + logits = self.lm_head(hidden_states) + logits = logits.float() loss = None if labels is not None: @@ -291,7 +357,7 @@ def forward( shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # Flatten the tokens - loss_fct = nn.CrossEntropyLoss() + loss_fct = torch.nn.CrossEntropyLoss() shift_logits = shift_logits.view(-1, self.config.vocab_size) shift_labels = shift_labels.view(-1) # Enable model parallelism diff --git a/optimum/habana/transformers/models/opt/modeling_opt.py b/optimum/habana/transformers/models/opt/modeling_opt.py index fee4dcf7e8..0fa89f9e23 100644 --- a/optimum/habana/transformers/models/opt/modeling_opt.py +++ b/optimum/habana/transformers/models/opt/modeling_opt.py @@ -1,4 +1,3 @@ -import random from typing import List, Optional, Tuple, Union import torch @@ -114,7 +113,9 @@ def gaudi_opt_attention_forward( f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" ) attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) + attn_weights = torch.max( + attn_weights, 
torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device) + ) attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) from habana_frameworks.torch.hpex import hmp @@ -168,9 +169,9 @@ def gaudi_opt_decoder_layer_forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, - past_key_value: Optional[Tuple[torch.Tensor]] = None, token_idx: Optional[torch.Tensor] = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -280,6 +281,11 @@ def gaudi_opt_decoder_forward( # embed positions if attention_mask is None: attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + elif attention_mask.shape[1] != mask_seq_length: + raise ValueError( + f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be " + f"{mask_seq_length} (sum of the lengths of current and past inputs)" + ) causal_attention_mask = self._prepare_decoder_attention_mask( attention_mask, input_shape, inputs_embeds, past_key_values_length ) @@ -316,9 +322,10 @@ def gaudi_opt_decoder_forward( if output_hidden_states: all_hidden_states += (hidden_states,) - dropout_probability = random.uniform(0, 1) - if self.training and (dropout_probability < self.layerdrop): - continue + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: + continue past_key_value = past_key_values[idx] if past_key_values is not None else None diff --git a/optimum/habana/transformers/models/t5/modeling_t5.py b/optimum/habana/transformers/models/t5/modeling_t5.py index 06f88e5b0e..3d4eb15471 100644 --- a/optimum/habana/transformers/models/t5/modeling_t5.py +++ b/optimum/habana/transformers/models/t5/modeling_t5.py @@ -48,6 +48,7 @@ def gaudi_T5Attention_forward( # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) batch_size, seq_length = hidden_states.shape[:2] + real_seq_length = seq_length if past_key_value is not None: diff --git a/optimum/habana/transformers/trainer.py b/optimum/habana/transformers/trainer.py index ccf08e0c7b..9452936c02 100644 --- a/optimum/habana/transformers/trainer.py +++ b/optimum/habana/transformers/trainer.py @@ -27,22 +27,19 @@ import numpy as np import torch -from packaging import version +from accelerate import skip_first_batches +from accelerate.utils import DistributedDataParallelKwargs, GradientAccumulationPlugin from torch.utils.data import DataLoader, Dataset, RandomSampler -from torch.utils.data.distributed import DistributedSampler -from tqdm.auto import tqdm from transformers import Trainer from transformers.data.data_collator import DataCollator from transformers.debug_utils import DebugOption, DebugUnderflowOverflow -from transformers.deepspeed import is_deepspeed_zero3_enabled +from transformers.deepspeed import deepspeed_load_checkpoint from transformers.integrations import hp_params -from transformers.modeling_utils import PreTrainedModel, unwrap_model +from transformers.modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS from transformers.tokenization_utils_base import PreTrainedTokenizerBase from transformers.trainer_callback import 
TrainerCallback, TrainerState from transformers.trainer_pt_utils import ( - DistributedLengthGroupedSampler, - DistributedSamplerWithLoop, DistributedTensorGatherer, IterableDatasetShard, LengthGroupedSampler, @@ -53,7 +50,6 @@ nested_concat, nested_detach, nested_numpify, - nested_truncate, reissue_pt_warnings, ) from transformers.trainer_utils import ( @@ -70,18 +66,23 @@ get_last_checkpoint, has_length, ) -from transformers.training_args import TrainingArguments +from transformers.training_args import ParallelMode, TrainingArguments from transformers.utils import ( + ADAPTER_SAFE_WEIGHTS_NAME, + ADAPTER_WEIGHTS_NAME, CONFIG_NAME, SAFE_WEIGHTS_NAME, + WEIGHTS_INDEX_NAME, WEIGHTS_NAME, is_datasets_available, + is_peft_available, is_safetensors_available, ) from optimum.habana.distributed import all_reduce_gradients from optimum.utils import logging +from ..accelerate import GaudiAccelerator from ..utils import ( HabanaProfile, get_hpu_memory_stats, @@ -103,6 +104,10 @@ import safetensors.torch +if is_peft_available(): + from peft import PeftModel + + if TYPE_CHECKING: import optuna @@ -254,24 +259,16 @@ def __init__( "ignore", message="User provided device_type of 'cuda', but CUDA is not available. Disabling" ) + def _move_model_to_device(self, model, device): + model = model.to(device) + # Moving a model to HPU disconnects the tied weights, so we have to retie them. + if self.args.use_habana and hasattr(model, "tie_weights"): + model.tie_weights() + def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: if self.train_dataset is None or not has_length(self.train_dataset): return None - generator = None - if self.args.world_size <= 1: - generator = torch.Generator() - # for backwards compatibility, we generate a seed here (which is sampled from a generator seeded with - # `args.seed`) if data_seed isn't provided. - # Further on in this method, we default to `args.seed` instead. - if self.args.data_seed is None: - seed = int(torch.empty((), dtype=torch.int64).random_().item()) - else: - seed = self.args.data_seed - generator.manual_seed(seed) - - seed = self.args.data_seed if self.args.data_seed is not None else self.args.seed - # Build the sampler. 
if self.args.group_by_length: if is_datasets_available() and isinstance(self.train_dataset, datasets.Dataset): @@ -283,56 +280,27 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: else: lengths = None model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None - if self.args.world_size <= 1: - return LengthGroupedSampler( - self.args.train_batch_size * self.args.gradient_accumulation_steps, - dataset=self.train_dataset, - lengths=lengths, - model_input_name=model_input_name, - generator=generator, - ) - else: - return DistributedLengthGroupedSampler( - self.args.train_batch_size * self.args.gradient_accumulation_steps, - dataset=self.train_dataset, - num_replicas=self.args.world_size, - rank=self.args.process_index, - lengths=lengths, - model_input_name=model_input_name, - seed=seed, - ) + return LengthGroupedSampler( + self.args.train_batch_size * self.args.gradient_accumulation_steps, + dataset=self.train_dataset, + lengths=lengths, + model_input_name=model_input_name, + ) else: - if self.args.world_size <= 1: - num_samples = len(self.train_dataset) - if ( - self.args.use_lazy_mode - and not self.args.dataloader_drop_last - and len(self.train_dataset) % self.args.per_device_train_batch_size != 0 - ): - # Make the total number of samples divisible by the batch size in lazy mode if needed - num_samples += ( - self.args.per_device_train_batch_size - - len(self.train_dataset) % self.args.per_device_train_batch_size - ) - return RandomSampler(self.train_dataset, num_samples=num_samples, generator=generator) - else: - if self.args.use_lazy_mode and not self.args.dataloader_drop_last: - # Use a loop for HPUs when drop_last is False to have all batches have the same size - return DistributedSamplerWithLoop( - self.train_dataset, - batch_size=self.args.per_device_train_batch_size, - num_replicas=self.args.world_size, - rank=self.args.process_index, - seed=seed, - ) - else: - return DistributedSampler( - self.train_dataset, - num_replicas=self.args.world_size, - rank=self.args.process_index, - seed=seed, - ) + num_samples = len(self.train_dataset) + if ( + self.args.use_lazy_mode + and not self.args.dataloader_drop_last + and len(self.train_dataset) % self.args.per_device_train_batch_size != 0 + and self.args.parallel_mode != ParallelMode.DISTRIBUTED + ): + # Make the total number of samples divisible by the batch size in lazy mode if needed + num_samples += ( + self.args.per_device_train_batch_size + - len(self.train_dataset) % self.args.per_device_train_batch_size + ) + return RandomSampler(self.train_dataset, num_samples=num_samples) def create_optimizer(self): """ @@ -397,14 +365,6 @@ def _tune_save_checkpoint(self): torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME)) def _wrap_model(self, model, training=True, dataloader=None): - if self.args.use_ipex: - dtype = torch.bfloat16 if self.use_cpu_amp else torch.float32 - model = self.ipex_optimize_model(model, training, dtype=dtype) - - # already initialized its own DDP and AMP - if self.deepspeed: - return self.deepspeed - # train/eval could be run multiple-times - if already wrapped, don't re-wrap it again if unwrap_model(model) is not model: return model @@ -414,7 +374,7 @@ def _wrap_model(self, model, training=True, dataloader=None): if not training: return model - if self.args.local_rank != -1 and self.args.distribution_strategy == "ddp": + if self.args.parallel_mode == ParallelMode.DISTRIBUTED and self.args.distribution_strategy == "ddp": kwargs = {} 
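Most of the sampler machinery above disappears because, in the Accelerate-based flow, the dataloader is sharded per process when it goes through the accelerator rather than by hand-built `DistributedSampler`s. A generic sketch of that pattern (plain `accelerate` usage with made-up data, not code from this patch):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

accelerator = Accelerator()
dataset = TensorDataset(torch.arange(32).float())
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

# prepare() wraps the dataloader so each process only iterates over its own shard,
# which is why the manual DistributedSampler construction is no longer needed.
dataloader = accelerator.prepare(dataloader)
for (batch,) in dataloader:
    pass
```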
kwargs["find_unused_parameters"] = self.args.ddp_find_unused_parameters @@ -428,68 +388,18 @@ def _wrap_model(self, model, training=True, dataloader=None): if self.args.use_habana: kwargs["gradient_as_bucket_view"] = True - model = torch.nn.parallel.DistributedDataParallel( - model, - device_ids=[self.args.local_rank] if self.args._n_gpu != 0 and not self.args.use_habana else None, - output_device=self.args.local_rank if self.args._n_gpu != 0 and not self.args.use_habana else None, - **kwargs, - ) + if self.args.ddp_broadcast_buffers is not None: + kwargs["broadcast_buffers"] = self.args.ddp_broadcast_buffers + + self.accelerator.ddp_handler = DistributedDataParallelKwargs(**kwargs) if self.args.use_hpu_graphs_for_training: import habana_frameworks.torch as ht ht.hpu.ModuleCacher()(model=model, inplace=True) - # torch.compile() needs to be called after wrapping the model with FSDP or DDP - # to ensure that it accounts for the graph breaks required by those wrappers - if self.args.torch_compile: - model = torch.compile(model, backend=self.args.torch_compile_backend, mode=self.args.torch_compile_mode) - return model - def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval): - if self.args.adjust_throughput: - save_start = time.perf_counter() - - if self.control.should_log: - logs: Dict[str, float] = {} - - # all_gather + mean() to get average loss over all processes - tr_loss_scalar = self._nested_gather(tr_loss).mean().item() - - # reset tr_loss to zero - tr_loss -= tr_loss - logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4) - logs["learning_rate"] = self._get_learning_rate() - - self._total_loss_scalar += tr_loss_scalar - self._globalstep_last_logged = self.state.global_step - self.store_flos() - - self.log(logs) - - metrics = None - if self.control.should_evaluate: - if isinstance(self.eval_dataset, dict): - metrics = {} - for eval_dataset_name, eval_dataset in self.eval_dataset.items(): - dataset_metrics = self.evaluate( - eval_dataset=eval_dataset, - ignore_keys=ignore_keys_for_eval, - metric_key_prefix=f"eval_{eval_dataset_name}", - ) - metrics.update(dataset_metrics) - else: - metrics = self.evaluate(ignore_keys=ignore_keys_for_eval) - self._report_to_hp_search(trial, self.state.global_step, metrics) - - if self.control.should_save: - self._save_checkpoint(model, trial, metrics=metrics) - self.control = self.callback_handler.on_save(self.args, self.state, self.control) - - if self.args.adjust_throughput: - self.log_evaluate_save_time += time.perf_counter() - save_start - def train( self, resume_from_checkpoint: Optional[Union[str, bool]] = None, @@ -509,7 +419,7 @@ def train( ignore_keys_for_eval (`List[str]`, *optional*) A list of keys in the output of your model (if it is a dictionary) that should be ignored when gathering predictions for evaluation during the training. 
- kwargs: + kwargs (`Dict[str, Any]`, *optional*): Additional keyword arguments used to hide deprecated arguments """ if resume_from_checkpoint is False: @@ -556,7 +466,7 @@ def train( if resume_from_checkpoint is None: raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})") - if resume_from_checkpoint is not None and args.deepspeed is None: + if resume_from_checkpoint is not None and not self.is_deepspeed_enabled: self._load_from_checkpoint(resume_from_checkpoint) # If model was re-initialized, put it on the right device and update self.model_wrapped @@ -576,72 +486,6 @@ def train( ignore_keys_for_eval=ignore_keys_for_eval, ) - def _prepare_input(self, data: Union[torch.Tensor, Any]) -> Union[torch.Tensor, Any]: - """ - Prepares one `data` before feeding it to the model, be it a tensor or a nested list/dictionary of tensors. - Compared to Transformers, it is also possible to enable non-blocking data copy. - """ - if isinstance(data, Mapping): - return type(data)({k: self._prepare_input(v) for k, v in data.items()}) - elif isinstance(data, (tuple, list)): - return type(data)(self._prepare_input(v) for v in data) - elif isinstance(data, torch.Tensor): - kwargs = {"device": self.args.device} - if self.deepspeed and (torch.is_floating_point(data) or torch.is_complex(data)): - # NLP models inputs are int/uint and those get adjusted to the right dtype of the - # embedding. Other models such as wav2vec2's inputs are already float and thus - # may need special handling to match the dtypes of the model - kwargs.update({"dtype": self.args.hf_deepspeed_config.dtype()}) - if self.args.non_blocking_data_copy: - return data.to(**kwargs, non_blocking=True) - else: - return data.to(**kwargs) - return data - - def training_step(self, model: torch.nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor: - """ - Perform a training step on a batch of inputs. - - Subclass and override to inject custom behavior. - - Args: - model (`torch.nn.Module`): - The model to train. - inputs (`Dict[str, Union[torch.Tensor, Any]]`): - The inputs and targets of the model. - - The dictionary will be unpacked before being fed to the model. Most models expect the targets under the - argument `labels`. Check your model's documentation for all accepted arguments. - - Return: - `torch.Tensor`: The tensor with training loss on this batch. 
- """ - model.train() - inputs = self._prepare_inputs(inputs) - - with self.compute_loss_context_manager(): - loss = self.compute_loss(model, inputs) - - if self.args.n_gpu > 1: - loss = loss.mean() # mean() to average on multi-gpu parallel training - - if self.args.gradient_accumulation_steps > 1 and not self.deepspeed: - # deepspeed handles loss scaling by gradient_accumulation_steps in its `backward` - loss = loss / self.args.gradient_accumulation_steps - - if self.args.pipelining_fwd_bwd: - self.htcore.mark_step() - - if self.do_grad_scaling: - self.scaler.scale(loss).backward() - elif self.deepspeed: - # loss gets scaled under gradient_accumulation_steps in deepspeed - loss = self.deepspeed.backward(loss) - else: - loss.backward() - - return loss.detach() - def _inner_training_loop( self, batch_size=None, @@ -650,7 +494,9 @@ def _inner_training_loop( trial=None, ignore_keys_for_eval=None, ): + self.accelerator.free_memory() self._train_batch_size = batch_size + logger.debug(f"Currently training with a batch size of: {self._train_batch_size}") # Data loader and number of training steps train_dataloader = self.get_train_dataloader() @@ -658,7 +504,7 @@ def _inner_training_loop( # number of training epochs: num_train_epochs # number of training steps per epoch: num_update_steps_per_epoch # total number of training steps to execute: max_steps - total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * args.world_size + total_train_batch_size = self._train_batch_size * args.gradient_accumulation_steps * args.world_size len_dataloader = None if has_length(train_dataloader): @@ -691,20 +537,26 @@ def _inner_training_loop( f" {args.max_steps}" ) + # Compute absolute values for logging, eval, and save if given as ratio + if args.logging_steps and args.logging_steps < 1: + args.logging_steps = math.ceil(max_steps * args.logging_steps) + if args.eval_steps and args.eval_steps < 1: + args.eval_steps = math.ceil(max_steps * args.eval_steps) + if args.save_steps and args.save_steps < 1: + args.save_steps = math.ceil(max_steps * args.save_steps) + if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug: debug_overflow = DebugUnderflowOverflow(self.model) # noqa - if args.deepspeed: - deepspeed_engine, optimizer, lr_scheduler = deepspeed_init( - self, num_training_steps=max_steps, resume_from_checkpoint=resume_from_checkpoint - ) - self.model = deepspeed_engine.module - self.model_wrapped = deepspeed_engine - self.deepspeed = deepspeed_engine - self.optimizer = optimizer - self.lr_scheduler = lr_scheduler - else: - self.create_optimizer_and_scheduler(num_training_steps=max_steps) + # We need to reset the scheduler, as its parameters may be different on subsequent calls + if self._created_lr_scheduler: + self.lr_scheduler = None + self._created_lr_scheduler = False + + if self.is_deepspeed_enabled: + self.optimizer, self.lr_scheduler = deepspeed_init(self, num_training_steps=max_steps) + + self.create_optimizer_and_scheduler(num_training_steps=max_steps) self.state = TrainerState() self.state.is_hyper_param_search = trial is not None @@ -742,10 +594,34 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args): model = self._wrap_model(self.model_wrapped) + # as the model is wrapped, don't use `accelerator.prepare` + # this is for unhandled cases such as + # Fairscale Sharded DDP, FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX + use_accelerator_prepare = True if model is self.model else False + + # prepare using `accelerator` prepare + if use_accelerator_prepare: + 
self.model.train() + if hasattr(self.lr_scheduler, "step"): + model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer) + else: + # to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config. + model, self.optimizer, self.lr_scheduler = self.accelerator.prepare( + self.model, self.optimizer, self.lr_scheduler + ) + # for the rest of this function `model` is the outside model, whether it was wrapped or not if model is not self.model: self.model_wrapped = model + # backward compatibility + if self.is_deepspeed_enabled: + self.deepspeed = self.model_wrapped + + # deepspeed ckpt loading + if resume_from_checkpoint is not None and self.is_deepspeed_enabled: + deepspeed_load_checkpoint(self.model_wrapped, resume_from_checkpoint) + # Check if saved optimizer or scheduler states exist self._load_optimizer_and_scheduler(resume_from_checkpoint) @@ -770,7 +646,9 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args): logger.info("***** Running training *****") logger.info(f" Num examples = {num_examples:,}") logger.info(f" Num Epochs = {num_train_epochs:,}") - logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size:,}") + logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}") + if self.args.per_device_train_batch_size != self._train_batch_size: + logger.info(f" Training with DataParallel so batch size has been adjusted to: {self._train_batch_size:,}") logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size:,}") logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") logger.info(f" Total optimization steps = {max_steps:,}") @@ -800,17 +678,13 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args): logger.info(f" Continuing training from global step {self.state.global_step}") if not args.ignore_data_skip: logger.info( - f" Will skip the first {epochs_trained} epochs then the first {steps_trained_in_current_epoch} " - "batches in the first epoch. If this takes a lot of time, you can add the `--ignore_data_skip` " - "flag to your launch command, but you will resume the training on data already seen by your model." + f" Will skip the first {epochs_trained} epochs then the first" + f" {steps_trained_in_current_epoch} batches in the first epoch." ) - if self.is_local_process_zero() and not args.disable_tqdm: - steps_trained_progress_bar = tqdm(total=steps_trained_in_current_epoch) - steps_trained_progress_bar.set_description("Skipping the first batches") # In multi-worker training: broadcast model parameters from worker:0 to all the others. # This must be done manually unless DistributedDataParallel is used. - if self.args.local_rank != -1 and self.args.distribution_strategy == "fast_ddp": + if self.args.parallel_mode == ParallelMode.DISTRIBUTED and self.args.distribution_strategy == "fast_ddp": logger.debug( f"Broadcasting the model parameters to assure that each of {self.args.world_size} workers start the training from the same point." ) @@ -851,18 +725,8 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args): # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point. 
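For reference, the `accelerator.prepare` call introduced above follows the standard Accelerate preparation pattern. Below is a minimal standalone sketch using the stock `accelerate.Accelerator` and throwaway model/optimizer/scheduler objects; none of these names come from the patch itself.

import torch
from accelerate import Accelerator

model = torch.nn.Linear(8, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

accelerator = Accelerator()
# prepare() wraps the model for the current device/distributed setup and patches the
# optimizer and scheduler so that skipped optimizer steps can be detected later
# (see `optimizer_step_was_skipped` further down in this diff).
model, optimizer, scheduler = accelerator.prepare(model, optimizer, scheduler)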
if not args.ignore_data_skip: for epoch in range(epochs_trained): - is_random_sampler = hasattr(train_dataloader, "sampler") and isinstance( - train_dataloader.sampler, RandomSampler - ) - if version.parse(torch.__version__) < version.parse("1.11") or not is_random_sampler: - # We just need to begin an iteration to create the randomization of the sampler. - # That was before PyTorch 1.11 however... - for _ in train_dataloader: - break - else: - # Otherwise we need to call the whooooole sampler cause there is some random operation added - # AT THE VERY END! - _ = list(train_dataloader.sampler) + for _ in train_dataloader: + break if self.args.adjust_throughput: self.log_evaluate_save_time = 0 @@ -872,12 +736,8 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args): hb_profiler = HabanaProfile(warmup=self.args.profiling_warmup_steps, active=self.args.profiling_steps) hb_profiler.start() + total_batched_samples = 0 for epoch in range(epochs_trained, num_train_epochs): - if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler): - train_dataloader.sampler.set_epoch(epoch) - elif hasattr(train_dataloader, "dataset") and isinstance(train_dataloader.dataset, IterableDatasetShard): - train_dataloader.dataset.set_epoch(epoch) - epoch_iterator = train_dataloader # Reset the past mems state at the beginning of each epoch if necessary. @@ -894,8 +754,15 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args): if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0: self._load_rng_state(resume_from_checkpoint) - step = -1 + rng_to_sync = False + steps_skipped = 0 + if steps_trained_in_current_epoch > 0: + epoch_iterator = skip_first_batches(epoch_iterator, steps_trained_in_current_epoch) + steps_skipped = steps_trained_in_current_epoch + steps_trained_in_current_epoch = 0 + rng_to_sync = True + step = -1 for step, inputs in enumerate(epoch_iterator): if ( args.throughput_warmup_steps > 0 @@ -904,6 +771,11 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args): ): start_time_after_warmup = time.time() + total_batched_samples += 1 + if rng_to_sync: + self._load_rng_state(resume_from_checkpoint) + rng_to_sync = False + # Skip past any already trained steps if resuming training if steps_trained_in_current_epoch > 0: steps_trained_in_current_epoch -= 1 @@ -919,26 +791,26 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args): if step % args.gradient_accumulation_steps == 0: self.control = self.callback_handler.on_step_begin(args, self.state, self.control) - # Proceed with forward and backward passes. - if ( - args.distribution_strategy == "ddp" - and ((step + 1) % args.gradient_accumulation_steps != 0) - and args.local_rank != -1 - and args._no_sync_in_gradient_accumulation - ): - # Avoid unnecessary DDP synchronization since there will be no backward pass on this example. - with model.no_sync(): - tr_loss_step = self.training_step(model, inputs) - else: + # TODO: keep syncs for fast DDP? 
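The `skip_first_batches` helper used in the hunk above replaces the old manual sampler iteration when resuming from a mid-epoch checkpoint. A minimal standalone sketch (the dataloader here is a placeholder, not the trainer's):

import torch
from torch.utils.data import DataLoader
from accelerate import skip_first_batches

dataloader = DataLoader(torch.arange(100), batch_size=10)
# Resume mid-epoch: only the batches that were not consumed before the checkpoint are yielded.
for batch in skip_first_batches(dataloader, num_batches=3):
    print(batch)  # iteration starts from the 4th batch of the epoch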
+ with self.accelerator.accumulate(model): tr_loss_step = self.training_step(model, inputs) - is_optimization_step = (step + 1) % args.gradient_accumulation_steps == 0 or ( + is_last_step_and_steps_less_than_grad_acc = ( + steps_in_epoch <= args.gradient_accumulation_steps and (step + 1) == steps_in_epoch + ) + + is_optimization_step = ( + total_batched_samples % args.gradient_accumulation_steps == 0 + or # last step in epoch but step is always smaller than gradient_accumulation_steps - steps_in_epoch <= args.gradient_accumulation_steps - and (step + 1) == steps_in_epoch + is_last_step_and_steps_less_than_grad_acc ) - if args.local_rank != -1 and args.distribution_strategy == "fast_ddp" and is_optimization_step: + if ( + args.parallel_mode == ParallelMode.DISTRIBUTED + and args.distribution_strategy == "fast_ddp" + and is_optimization_step + ): all_reduce_gradients( model, use_hpu_graphs=True ) # use HPU graphs for gradient fusion regardless of args.use_hpu_graphs_for_training setting @@ -953,13 +825,14 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args): if args.use_lazy_mode: self.htcore.mark_step() - # Optimizer step for deepspeed must be called on every step regardless of the value of gradient_accumulation_steps - if self.deepspeed: - self.deepspeed.step() - if is_optimization_step: + # the `or` condition of `is_last_step_and_steps_less_than_grad_acc` is not covered + # in accelerate. So, explicitly enable sync gradients to True in that case. + if is_last_step_and_steps_less_than_grad_acc: + self.accelerator.gradient_state._set_sync_gradients(True) + # Gradient clipping - if args.max_grad_norm is not None and args.max_grad_norm > 0 and not self.deepspeed: + if args.max_grad_norm is not None and args.max_grad_norm > 0: # deepspeed does its own clipping if hasattr(self.optimizer, "clip_grad_norm"): @@ -969,6 +842,7 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args): # Some models (like FullyShardedDDP) have a specific way to do gradient clipping model.clip_grad_norm_(args.max_grad_norm) elif self.gaudi_config.use_fused_clip_norm and args.use_habana: + # TODO: to merge self.accelerator.clip_grad_norm_ when HMP is removed self.FusedNorm.clip_norm(model.parameters()) else: # Revert to normal clipping otherwise @@ -980,16 +854,14 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args): with self.hmp.disable_casts(): torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) else: - torch.nn.utils.clip_grad_norm_( + self.accelerator.clip_grad_norm_( model.parameters(), args.max_grad_norm, ) # Optimizer step optimizer_was_run = True - if self.deepspeed: - pass # called outside the loop - elif ( + if ( args.use_habana and self.gaudi_config.use_habana_mixed_precision and (not self.gaudi_config.use_fused_adam) @@ -999,14 +871,17 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args): self.optimizer.step() else: self.optimizer.step() + optimizer_was_run = not self.accelerator.optimizer_step_was_skipped - if optimizer_was_run and not self.deepspeed: - self.lr_scheduler.step() + if optimizer_was_run: + # Delay optimizer scheduling until metrics are generated + if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): + self.lr_scheduler.step() self._zero_model_grad(model) self.state.global_step += 1 - self.state.epoch = epoch + (step + 1) / steps_in_epoch + self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch if args.use_lazy_mode: self.htcore.mark_step() self.control = self.callback_handler.on_step_end(args, 
self.state, self.control) @@ -1041,7 +916,7 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args): logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n") if args.load_best_model_at_end and self.state.best_model_checkpoint is not None: # Wait for everyone to get here so we are sure the model has been saved by process 0. - if args.local_rank != -1: + if args.parallel_mode == ParallelMode.DISTRIBUTED: torch.distributed.barrier() self._load_best_model() @@ -1085,6 +960,115 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args): return TrainOutput(self.state.global_step, train_loss, metrics) + def _load_best_model(self): + logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).") + best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME) + best_safe_model_path = os.path.join(self.state.best_model_checkpoint, SAFE_WEIGHTS_NAME) + best_adapter_model_path = os.path.join(self.state.best_model_checkpoint, ADAPTER_WEIGHTS_NAME) + best_safe_adapter_model_path = os.path.join(self.state.best_model_checkpoint, ADAPTER_SAFE_WEIGHTS_NAME) + + model = self.model + if ( + os.path.exists(best_model_path) + or os.path.exists(best_safe_model_path) + or os.path.exists(best_adapter_model_path) + or os.path.exists(best_safe_adapter_model_path) + ): + # TODO: check if the code below works + # if self.is_deepspeed_enabled: + # deepspeed_load_checkpoint(self.model_wrapped, self.state.best_model_checkpoint) + # else: + has_been_loaded = True + if is_peft_available() and isinstance(model, PeftModel): + # If train a model using PEFT & LoRA, assume that adapter have been saved properly. + if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"): + if os.path.exists(best_adapter_model_path) or os.path.exists(best_safe_adapter_model_path): + model.load_adapter(self.state.best_model_checkpoint, model.active_adapter) + # Load_adapter has no return value present, modify it when appropriate. + from torch.nn.modules.module import _IncompatibleKeys + + load_result = _IncompatibleKeys([], []) + else: + logger.warning( + "The intermediate checkpoints of PEFT may not be saved correctly, " + f"consider using a custom callback to save {ADAPTER_WEIGHTS_NAME} in corresponding saving folders. " + "Check some examples here: https://github.com/huggingface/peft/issues/96" + ) + has_been_loaded = False + else: + logger.warning("Could not load adapter model, make sure to have `peft>=0.3.0` installed") + has_been_loaded = False + else: + # We load the model state dict on the CPU to avoid an OOM error. + if self.args.save_safetensors and os.path.isfile(best_safe_model_path): + state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu") + else: + state_dict = torch.load(best_model_path, map_location="cpu") + + # If the model is on the GPU, it still works! + load_result = model.load_state_dict(state_dict, False) + + if has_been_loaded: + self._issue_warnings_after_load(load_result) + elif os.path.exists(os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)): + load_result = load_sharded_checkpoint(model, self.state.best_model_checkpoint, strict=False) + self._issue_warnings_after_load(load_result) + else: + logger.warning( + f"Could not locate the best model at {best_model_path}, if you are running a distributed training " + "on multiple nodes, you should activate `--save_on_each_node`." 
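The inner-loop changes above move gradient accumulation, clipping, and skipped-step detection onto Accelerate. Here is a minimal standalone sketch of that pattern with placeholder model/optimizer/data, without the Habana-specific branches of the trainer:

import torch
from torch.utils.data import DataLoader
from accelerate import Accelerator

accelerator = Accelerator(gradient_accumulation_steps=4)
model = torch.nn.Linear(8, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
dataloader = DataLoader(torch.randn(64, 8), batch_size=8)
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

for batch in dataloader:
    with accelerator.accumulate(model):          # no_sync/sync bookkeeping handled internally
        loss = model(batch).pow(2).mean()        # dummy loss
        accelerator.backward(loss)               # replaces loss.backward()
        if accelerator.sync_gradients:           # True only on real optimization steps
            accelerator.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        optimizer_was_run = not accelerator.optimizer_step_was_skipped
        optimizer.zero_grad()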
+ ) + + def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval): + if self.args.adjust_throughput: + save_start = time.perf_counter() + + if self.control.should_log: + logs: Dict[str, float] = {} + + # all_gather + mean() to get average loss over all processes + tr_loss_scalar = self._nested_gather(tr_loss).mean().item() + + # reset tr_loss to zero + tr_loss -= tr_loss + logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4) + logs["learning_rate"] = self._get_learning_rate() + + self._total_loss_scalar += tr_loss_scalar + self._globalstep_last_logged = self.state.global_step + self.store_flos() + + self.log(logs) + + metrics = None + if self.control.should_evaluate: + if isinstance(self.eval_dataset, dict): + metrics = {} + for eval_dataset_name, eval_dataset in self.eval_dataset.items(): + dataset_metrics = self.evaluate( + eval_dataset=eval_dataset, + ignore_keys=ignore_keys_for_eval, + metric_key_prefix=f"eval_{eval_dataset_name}", + ) + metrics.update(dataset_metrics) + else: + metrics = self.evaluate(ignore_keys=ignore_keys_for_eval) + self._report_to_hp_search(trial, self.state.global_step, metrics) + + # Run delayed LR scheduler now that metrics are populated + if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): + metric_to_check = self.args.metric_for_best_model + if not metric_to_check.startswith("eval_"): + metric_to_check = f"eval_{metric_to_check}" + self.lr_scheduler.step(metrics[metric_to_check]) + + if self.control.should_save: + self._save_checkpoint(model, trial, metrics=metrics) + self.control = self.callback_handler.on_save(self.args, self.state, self.control) + + if self.args.adjust_throughput: + self.log_evaluate_save_time += time.perf_counter() - save_start + def _load_rng_state(self, checkpoint): # Load RNG states from `checkpoint` if checkpoint is None: @@ -1113,11 +1097,11 @@ def _load_rng_state(self, checkpoint): np.random.set_state(checkpoint_rng_state["numpy"]) torch.random.set_rng_state(checkpoint_rng_state["cpu"]) if self.args.use_habana: - if self.args.local_rank != -1: - self.hpu_random.set_rng_state(checkpoint_rng_state["hpu"]) + if self.args.parallel_mode == ParallelMode.DISTRIBUTED: + self.hpu_random.set_rng_state_all(checkpoint_rng_state["hpu"]) else: try: - self.hpu_random.set_rng_state_all(checkpoint_rng_state["hpu"]) + self.hpu_random.set_rng_state(checkpoint_rng_state["hpu"]) except Exception as e: logger.info( f"Didn't manage to set back the RNG states of the HPU because of the following error:\n {e}" @@ -1138,13 +1122,13 @@ def _save_checkpoint(self, model, trial, metrics=None): run_dir = self._get_output_dir(trial=trial) output_dir = os.path.join(run_dir, checkpoint_folder) self.save_model(output_dir, _internal_call=True) - if self.deepspeed: + if self.is_deepspeed_enabled: # under zero3 model file itself doesn't get saved since it's bogus! 
Unless deepspeed # config `stage3_gather_16bit_weights_on_model_save` is True - self.deepspeed.save_checkpoint(output_dir) + self.model_wrapped.save_checkpoint(output_dir) # Save optimizer and scheduler - if self.args.should_save and not self.deepspeed: + if self.args.should_save and not self.is_deepspeed_enabled: # deepspeed.save_checkpoint above saves model/optim/sched # This block is exectuted by the main process only optim_dict = self.optimizer.state_dict() @@ -1192,7 +1176,7 @@ def _save_checkpoint(self, model, trial, metrics=None): } if self.args.use_habana: - if self.args.local_rank == -1: + if self.args.parallel_mode == ParallelMode.DISTRIBUTED: # In non distributed, we save the global HPU RNG state rng_states["hpu"] = self.hpu_random.get_rng_state_all() else: @@ -1215,7 +1199,7 @@ def _save_checkpoint(self, model, trial, metrics=None): self._rotate_checkpoints(use_mtime=True, output_dir=run_dir) # Synchronize all processes after saving the current checkpoint - if self.args.local_rank != -1 and self.args.use_habana: + if self.args.parallel_mode == ParallelMode.DISTRIBUTED and self.args.use_habana: torch.distributed.barrier() def _load_optimizer_and_scheduler(self, checkpoint): @@ -1223,14 +1207,16 @@ def _load_optimizer_and_scheduler(self, checkpoint): if checkpoint is None: return - if self.deepspeed: + if self.is_deepspeed_enabled: # deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init return if os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME)) and os.path.isfile( os.path.join(checkpoint, SCHEDULER_NAME) ): - # Load in optimizer and scheduler states + # We use the CPU when training on one GPU to avoid OOM for GPU RAM when training big models. + # In distributed training however, we load directly on each GPU and risk the GPU OOM as it's more + # likely to get OOM on CPU (since we load num_gpu times the optimizer state map_location = "cpu" if self.args.use_habana else self.args.device self.optimizer.load_state_dict( @@ -1254,39 +1240,206 @@ def _load_optimizer_and_scheduler(self, checkpoint): if self.do_grad_scaling: to_device_dtype(self.scaler.state.values(), target_device=torch.device("hpu")) - def evaluation_loop( - self, - dataloader: DataLoader, - description: str, - prediction_loss_only: Optional[bool] = None, - ignore_keys: Optional[List[str]] = None, - metric_key_prefix: str = "eval", - ) -> EvalLoopOutput: + def log(self, logs: Dict[str, float]) -> None: """ - Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`. - Works both with or without labels. + Log `logs` on the various objects watching training. + Subclass and override this method to inject custom behavior. + Args: + logs (`Dict[str, float]`): + The values to log. 
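The `_maybe_log_save_evaluate` override shown earlier in this diff delays `ReduceLROnPlateau` until an eval metric is available. A short plain-PyTorch sketch of that scheduler's contract (the loss value is made up):

import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", patience=2)

for epoch in range(6):
    optimizer.step()                          # training steps would happen here
    eval_loss = 1.0                           # stand-in for metrics["eval_loss"], stuck on a plateau
    scheduler.step(eval_loss)                 # unlike other schedulers, step() needs the metric
    print(optimizer.param_groups[0]["lr"])    # LR drops only once the metric stops improving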
""" - args = self.args - - prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only + if self.state.epoch is not None: + logs["epoch"] = round(self.state.epoch, 2) - # if eval is called w/o train init deepspeed here - if args.deepspeed and not self.deepspeed: - # XXX: eval doesn't have `resume_from_checkpoint` arg but we should be able to do eval - # from the checkpoint eventually - deepspeed_engine, _, _ = deepspeed_init( - self, num_training_steps=0, resume_from_checkpoint=None, inference=True - ) - self.model = deepspeed_engine.module - self.model_wrapped = deepspeed_engine - self.deepspeed = deepspeed_engine + mem_stats = get_hpu_memory_stats() + logs.update(mem_stats) - model = self._wrap_model(self.model, training=False, dataloader=dataloader) - model.eval() + output = {**logs, **{"step": self.state.global_step}} + self.state.log_history.append(output) + self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs) - # Do not use HPU graphs if the training is ongoing because it detaches gradients - if args.use_hpu_graphs_for_inference and not self.is_in_train: - logger.info("Using HPU graphs for inference.") + def _prepare_input(self, data: Union[torch.Tensor, Any]) -> Union[torch.Tensor, Any]: + """ + Prepares one `data` before feeding it to the model, be it a tensor or a nested list/dictionary of tensors. + Compared to Transformers, it is also possible to enable non-blocking data copy. + """ + if isinstance(data, Mapping): + return type(data)({k: self._prepare_input(v) for k, v in data.items()}) + elif isinstance(data, (tuple, list)): + return type(data)(self._prepare_input(v) for v in data) + elif isinstance(data, torch.Tensor): + kwargs = {"device": self.args.device} + if self.is_deepspeed_enabled and (torch.is_floating_point(data) or torch.is_complex(data)): + # NLP models inputs are int/uint and those get adjusted to the right dtype of the + # embedding. Other models such as wav2vec2's inputs are already float and thus + # may need special handling to match the dtypes of the model + kwargs.update({"dtype": self.accelerator.state.deepspeed_plugin.hf_ds_config.dtype()}) + if self.args.non_blocking_data_copy: + return data.to(**kwargs, non_blocking=True) + else: + return data.to(**kwargs) + return data + + def autocast_smart_context_manager(self, cache_enabled: Optional[bool] = True): + """ + A helper wrapper that creates an appropriate context manager for `autocast` while feeding it the desired + arguments, depending on the situation. Modified by Habana to enable using `autocast` on Gaudi devices. + """ + if self.use_cpu_amp: + ctx_manager = torch.cpu.amp.autocast(cache_enabled=cache_enabled, dtype=torch.bfloat16) + elif self.use_hpu_amp: + ctx_manager = torch.autocast(device_type="hpu", dtype=torch.bfloat16, enabled=True) + else: + import contextlib + + ctx_manager = contextlib.nullcontext() if sys.version_info >= (3, 7) else contextlib.suppress() + return ctx_manager + + def training_step(self, model: torch.nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor: + """ + Perform a training step on a batch of inputs. + + Subclass and override to inject custom behavior. + + Args: + model (`torch.nn.Module`): + The model to train. + inputs (`Dict[str, Union[torch.Tensor, Any]]`): + The inputs and targets of the model. + + The dictionary will be unpacked before being fed to the model. Most models expect the targets under the + argument `labels`. 
Check your model's documentation for all accepted arguments. + + Return: + `torch.Tensor`: The tensor with training loss on this batch. + """ + model.train() + inputs = self._prepare_inputs(inputs) + + with self.compute_loss_context_manager(): + loss = self.compute_loss(model, inputs) + + if self.args.n_gpu > 1: + loss = loss.mean() # mean() to average on multi-gpu parallel training + + if self.args.pipelining_fwd_bwd: + self.htcore.mark_step() + + if self.do_grad_scaling: + self.scaler.scale(loss).backward() + else: + self.accelerator.backward(loss) + + return loss.detach() / self.args.gradient_accumulation_steps + + def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False): + """ + Will save the model, so you can reload it using `from_pretrained()`. + Will only save from the main process. + """ + if output_dir is None: + output_dir = self.args.output_dir + + if self.is_deepspeed_enabled: + # this takes care of everything as long as we aren't under zero3 + try: + state_dict = self.accelerator.get_state_dict(self.deepspeed) + if self.args.should_save: + self._save(output_dir, state_dict=state_dict) + except ValueError: + logger.warning( + " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use" + " zero_to_fp32.py to recover weights" + ) + self.model_wrapped.save_checkpoint(output_dir) + elif self.args.should_save: + self._save(output_dir) + + # Push to the Hub when `save_model` is called by the user. + if self.args.push_to_hub and not _internal_call: + self.push_to_hub(commit_message="Model save") + + def _save(self, output_dir: Optional[str] = None, state_dict=None): + # If we are executing this function, we are the process zero, so we don't check for that. + output_dir = output_dir if output_dir is not None else self.args.output_dir + os.makedirs(output_dir, exist_ok=True) + logger.info(f"Saving model checkpoint to {output_dir}") + + supported_classes = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel) + + if state_dict is None: + state_dict = self.model.state_dict() + if state_dict and self.args.use_habana: + # state_dict items have to be saved on the CPU + state_dict = to_device_dtype(state_dict, target_device=torch.device("cpu")) + + # Save a trained model and configuration using `save_pretrained()`. 
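The save and best-model-load paths in this diff switch between `torch.save` and safetensors depending on `save_safetensors`. A minimal standalone sketch of the safetensors round trip (the file name is just an example):

import torch
from safetensors.torch import load_file, save_file

state_dict = {"weight": torch.randn(2, 2), "bias": torch.zeros(2)}
save_file(state_dict, "model.safetensors")              # pickle-free serialization
restored = load_file("model.safetensors", device="cpu")
assert torch.equal(state_dict["weight"], restored["weight"])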
+ # They can then be reloaded using `from_pretrained()` + if not isinstance(self.model, supported_classes): + if isinstance(unwrap_model(self.model), supported_classes): + unwrap_model(self.model).save_pretrained( + output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors + ) + else: + logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.") + if self.args.save_safetensors: + safetensors.torch.save_file(state_dict, os.path.join(output_dir, SAFE_WEIGHTS_NAME)) + else: + torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME)) + else: + self.model.save_pretrained( + output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors + ) + + if self.tokenizer is not None: + self.tokenizer.save_pretrained(output_dir) + + self.gaudi_config.save_pretrained(output_dir) + + # Good practice: save your training arguments together with the trained model + torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) + + def evaluation_loop( + self, + dataloader: DataLoader, + description: str, + prediction_loss_only: Optional[bool] = None, + ignore_keys: Optional[List[str]] = None, + metric_key_prefix: str = "eval", + ) -> EvalLoopOutput: + """ + Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`. + Works both with or without labels. + """ + args = self.args + + prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only + + # if eval is called w/o train, handle model prep here + if self.is_deepspeed_enabled and self.deepspeed is None: + _, _ = deepspeed_init(self, num_training_steps=0, inference=True) + + model = self._wrap_model(self.model, training=False, dataloader=dataloader) + + if len(self.accelerator._models) == 0 and model is self.model: + model = ( + self.accelerator.prepare(model) + if self.is_deepspeed_enabled + else self.accelerator.prepare_model(model, evaluation_mode=True) + ) + + # for the rest of this function `model` is the outside model, whether it was wrapped or not + if model is not self.model: + self.model_wrapped = model + + # backward compatibility + if self.is_deepspeed_enabled: + self.deepspeed = self.model_wrapped + model.eval() + + # Do not use HPU graphs if the training is ongoing because it detaches gradients + if args.use_hpu_graphs_for_inference and not self.is_in_train: + logger.info("Using HPU graphs for inference.") # Do not wrap the model in HPU graphs if it has already been done if not self.already_wrapped_for_hpu_graphs: from habana_frameworks.torch.hpu import wrap_in_hpu_graph @@ -1346,15 +1499,15 @@ def evaluation_loop( # Update containers on host if loss is not None: - losses = self._nested_gather(loss.repeat(batch_size)) - losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0) + losses = self.accelerator.gather_for_metrics((loss.repeat(batch_size))) + losses_host = losses if losses_host is None else nested_concat(losses_host, losses, padding_index=-100) if labels is not None: - labels = self._pad_across_processes(labels) - labels = self._nested_gather(labels) + labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100) + labels = self.accelerator.gather_for_metrics((labels)) labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100) if inputs_decode is not None: - inputs_decode = self._pad_across_processes(inputs_decode) - inputs_decode = self._nested_gather(inputs_decode) + inputs_decode = 
self.accelerator.pad_across_processes(inputs_decode, dim=1, pad_index=-100) + inputs_decode = self.accelerator.gather_for_metrics((inputs_decode)) inputs_host = ( inputs_decode if inputs_host is None @@ -1363,15 +1516,16 @@ def evaluation_loop( if logits is not None: if args.use_habana and logits_dtype != "float32": logits = to_device_dtype(logits, target_dtype=torch.float32) - logits = self._pad_across_processes(logits) - logits = self._nested_gather(logits) + logits = self.accelerator.pad_across_processes(logits, dim=1, pad_index=-100) if self.preprocess_logits_for_metrics is not None: logits = self.preprocess_logits_for_metrics(logits, labels) + logits = self.accelerator.gather_for_metrics((logits)) preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100) + self.control = self.callback_handler.on_prediction_step(args, self.state, self.control) # Gather all tensors and put them back on the CPU if we have done enough accumulation steps. - if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0: + if args.eval_accumulation_steps is not None and self.accelerator.sync_gradients: if losses_host is not None: losses = nested_numpify(losses_host) all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0) @@ -1438,17 +1592,6 @@ def evaluation_loop( if num_samples == 0 and observed_num_examples > 0: num_samples = observed_num_examples - # Number of losses has been rounded to a multiple of batch_size and in a distributed training, the number of - # samplers has been rounded to a multiple of batch_size, so we truncate. - if all_losses is not None: - all_losses = all_losses[:num_samples] - if all_preds is not None: - all_preds = nested_truncate(all_preds, num_samples) - if all_labels is not None: - all_labels = nested_truncate(all_labels, num_samples) - if all_inputs is not None: - all_inputs = nested_truncate(all_inputs, num_samples) - # Convert predictions back into their original dtype if necessary if all_preds is not None: all_preds = convert_into_dtypes(all_preds, logits_dtype) @@ -1568,6 +1711,51 @@ def prediction_step( return (loss, logits, labels) + def _push_from_checkpoint(self, checkpoint_folder): + # Only push from one node. + if not self.is_world_process_zero() or self.args.hub_strategy == HubStrategy.END: + return + # If we haven't finished the last push, we don't do this one. + if self.push_in_progress is not None and not self.push_in_progress.is_done: + return + + output_dir = self.args.output_dir + # To avoid a new synchronization of all model weights, we just copy the file from the checkpoint folder + modeling_files = [CONFIG_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_NAME, GAUDI_CONFIG_NAME] + for modeling_file in modeling_files: + if os.path.isfile(os.path.join(checkpoint_folder, modeling_file)): + shutil.copy(os.path.join(checkpoint_folder, modeling_file), os.path.join(output_dir, modeling_file)) + # Saving the tokenizer is fast and we don't know how many files it may have spawned, so we resave it to be sure. 
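The evaluation loop above now relies on Accelerate for cross-process padding and gathering instead of `_pad_across_processes`/`_nested_gather`. A minimal standalone sketch of that pair (single process here, so the gather is effectively a no-op):

import torch
from accelerate import Accelerator

accelerator = Accelerator()
logits = torch.randn(4, 7)                                        # per-process predictions
logits = accelerator.pad_across_processes(logits, dim=1, pad_index=-100)
all_logits = accelerator.gather_for_metrics(logits)               # also drops duplicated samples
print(all_logits.shape)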
+ if self.tokenizer is not None: + self.tokenizer.save_pretrained(output_dir) + # Same for the training arguments + torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) + + try: + if self.args.hub_strategy == HubStrategy.CHECKPOINT: + # Temporarily move the checkpoint just saved for the push + tmp_checkpoint = os.path.join(output_dir, "last-checkpoint") + # We have to remove the "last-checkpoint" dir if it exists, otherwise the checkpoint is moved as a + # subfolder. + if os.path.isdir(tmp_checkpoint): + shutil.rmtree(tmp_checkpoint) + shutil.move(checkpoint_folder, tmp_checkpoint) + + if self.args.save_strategy == IntervalStrategy.STEPS: + commit_message = f"Training in progress, step {self.state.global_step}" + else: + commit_message = f"Training in progress, epoch {int(self.state.epoch)}" + _, self.push_in_progress = self.repo.push_to_hub( + commit_message=commit_message, blocking=False, auto_lfs_prune=True + ) + finally: + if self.args.hub_strategy == HubStrategy.CHECKPOINT: + # Move back the checkpoint to its place + shutil.move(tmp_checkpoint, checkpoint_folder) + + # + # Deprecated code + # def prediction_loop( self, dataloader: DataLoader, @@ -1587,23 +1775,29 @@ def prediction_loop( prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only - # if eval is called w/o train init deepspeed here - if args.deepspeed and not self.deepspeed: - # XXX: eval doesn't have `resume_from_checkpoint` arg but we should be able to do eval - # from the checkpoint eventually - deepspeed_engine, _, _ = deepspeed_init( - self, num_training_steps=0, resume_from_checkpoint=None, inference=True - ) - self.model = deepspeed_engine.module - self.model_wrapped = deepspeed_engine - self.deepspeed = deepspeed_engine - # XXX: we don't need optim/sched for inference, but this needs to be sorted out, since - # for example the Z3-optimizer is a must for zero3 to work even for inference - what we - # don't need is the deepspeed basic optimizer which is self.optimizer.optimizer - deepspeed_engine.optimizer.optimizer = None - deepspeed_engine.lr_scheduler = None + # if eval is called w/o train, handle model prep here + if self.is_deepspeed_enabled and self.deepspeed is None: + _, _ = deepspeed_init(self, num_training_steps=0, inference=True) model = self._wrap_model(self.model, training=False, dataloader=dataloader) + + if len(self.accelerator._models) == 0 and model is self.model: + model = ( + self.accelerator.prepare(model) + if self.is_deepspeed_enabled + else self.accelerator.prepare_model(model, evaluation_mode=True) + ) + + if self.is_fsdp_enabled: + self.model = model + + # for the rest of this function `model` is the outside model, whether it was wrapped or not + if model is not self.model: + self.model_wrapped = model + + # backward compatibility + if self.is_deepspeed_enabled: + self.deepspeed = self.model_wrapped model.eval() # Do not use HPU graphs if the training is ongoing because it detaches gradients @@ -1718,206 +1912,31 @@ def prediction_loop( return EvalLoopOutput(predictions=preds, label_ids=label_ids, metrics=metrics, num_samples=num_examples) - def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False): - """ - Will save the model, so you can reload it using `from_pretrained()`. - Will only save from the main process. 
- """ - if output_dir is None: - output_dir = self.args.output_dir - - if self.deepspeed: - # this takes care of everything as long as we aren't under zero3 - if self.args.should_save: - self._save(output_dir) - - if is_deepspeed_zero3_enabled(): - # It's too complicated to try to override different places where the weights dump gets - # saved, so since under zero3 the file is bogus, simply delete it. The user should - # either user deepspeed checkpoint to resume or to recover full weights use - # zero_to_fp32.py stored in the checkpoint. - if self.args.should_save: - file = os.path.join(output_dir, WEIGHTS_NAME) - if os.path.isfile(file): - # logger.info(f"deepspeed zero3: removing {file}, see zero_to_fp32.py to recover weights") - os.remove(file) - - # now save the real model if stage3_gather_16bit_weights_on_model_save=True - # if false it will not be saved. - # This must be called on all ranks - if not self.deepspeed.save_16bit_model(output_dir, WEIGHTS_NAME): - logger.warning( - "deepspeed.save_16bit_model didn't save the model, since" - " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use" - " zero_to_fp32.py to recover weights" - ) - self.deepspeed.save_checkpoint(output_dir) - elif self.args.should_save: - self._save(output_dir) - - # Push to the Hub when `save_model` is called by the user. - if self.args.push_to_hub and not _internal_call: - self.push_to_hub(commit_message="Model save") - - def _save(self, output_dir: Optional[str] = None, state_dict=None): - # If we are executing this function, we are the process zero, so we don't check for that. - output_dir = output_dir if output_dir is not None else self.args.output_dir - os.makedirs(output_dir, exist_ok=True) - logger.info(f"Saving model checkpoint to {output_dir}") - - if state_dict is None: - state_dict = self.model.state_dict() - if state_dict and self.args.use_habana: - # state_dict items have to be saved on the CPU - state_dict = to_device_dtype(state_dict, target_device=torch.device("cpu")) - - # Save a trained model and configuration using `save_pretrained()`. - # They can then be reloaded using `from_pretrained()` - if not isinstance(self.model, PreTrainedModel): - if isinstance(unwrap_model(self.model), PreTrainedModel): - unwrap_model(self.model).save_pretrained( - output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors - ) - else: - logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.") - if self.args.save_safetensors: - safetensors.torch.save_file(state_dict, os.path.join(output_dir, SAFE_WEIGHTS_NAME)) - else: - torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME)) - else: - self.model.save_pretrained( - output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors - ) - - if self.tokenizer is not None: - self.tokenizer.save_pretrained(output_dir) - - self.gaudi_config.save_pretrained(output_dir) - - # Good practice: save your training arguments together with the trained model - torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) - - def _push_from_checkpoint(self, checkpoint_folder): - # Only push from one node. - if not self.is_world_process_zero() or self.args.hub_strategy == HubStrategy.END: - return - # If we haven't finished the last push, we don't do this one. 
- if self.push_in_progress is not None and not self.push_in_progress.is_done: - return - - output_dir = self.args.output_dir - # To avoid a new synchronization of all model weights, we just copy the file from the checkpoint folder - modeling_files = [CONFIG_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_NAME, GAUDI_CONFIG_NAME] - for modeling_file in modeling_files: - if os.path.isfile(os.path.join(checkpoint_folder, modeling_file)): - shutil.copy(os.path.join(checkpoint_folder, modeling_file), os.path.join(output_dir, modeling_file)) - # Saving the tokenizer is fast and we don't know how many files it may have spawned, so we resave it to be sure. - if self.tokenizer is not None: - self.tokenizer.save_pretrained(output_dir) - # Same for the training arguments - torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) + def create_accelerator_and_postprocess(self): + grad_acc_kwargs = {"num_steps": self.args.gradient_accumulation_steps} + grad_acc_kwargs["sync_with_dataloader"] = False + gradient_accumulation_plugin = GradientAccumulationPlugin(**grad_acc_kwargs) - try: - if self.args.hub_strategy == HubStrategy.CHECKPOINT: - # Temporarily move the checkpoint just saved for the push - tmp_checkpoint = os.path.join(output_dir, "last-checkpoint") - # We have to remove the "last-checkpoint" dir if it exists, otherwise the checkpoint is moved as a - # subfolder. - if os.path.isdir(tmp_checkpoint): - shutil.rmtree(tmp_checkpoint) - shutil.move(checkpoint_folder, tmp_checkpoint) - - if self.args.save_strategy == IntervalStrategy.STEPS: - commit_message = f"Training in progress, step {self.state.global_step}" - else: - commit_message = f"Training in progress, epoch {int(self.state.epoch)}" - _, self.push_in_progress = self.repo.push_to_hub( - commit_message=commit_message, blocking=False, auto_lfs_prune=True - ) - finally: - if self.args.hub_strategy == HubStrategy.CHECKPOINT: - # Move back the checkpoint to its place - shutil.move(tmp_checkpoint, checkpoint_folder) - - def _load_best_model(self): - logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).") - best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME) - best_safe_model_path = os.path.join(self.state.best_model_checkpoint, SAFE_WEIGHTS_NAME) - model = self.model - if os.path.exists(best_model_path) or os.path.exists(best_safe_model_path): - # TODO: the code below does not work with Habana DeepSpeed - # if self.deepspeed: - - # if self.model_wrapped is not None: - # # this removes the pre-hooks from the previous engine - # self.model_wrapped.destroy() - # self.model_wrapped = None - - # # temp hack until Deepspeed fixes the problem with resume from an existing engine that did some stepping - # deepspeed_engine, optimizer, lr_scheduler = deepspeed_init( - # self, - # num_training_steps=self.args.max_steps, - # resume_from_checkpoint=self.state.best_model_checkpoint, - # ) - # self.model = deepspeed_engine.module - # self.model_wrapped = deepspeed_engine - # self.deepspeed = deepspeed_engine - # self.optimizer = optimizer - # self.lr_scheduler = lr_scheduler - # else: - # We load the model state dict on the CPU to avoid an OOM error. - if self.args.save_safetensors and os.path.isfile(best_safe_model_path): - state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu") - else: - state_dict = torch.load(best_model_path, map_location="cpu") - - # If the model is on the GPU, it still works! 
- load_result = model.load_state_dict(state_dict, strict=False) - self._issue_warnings_after_load(load_result) - else: - logger.warning( - f"Could not locate the best model at {best_model_path}, if you are running a distributed training " - "on multiple nodes, you should activate `--save_on_each_node`." - ) - - def log(self, logs: Dict[str, float]) -> None: - """ - Log `logs` on the various objects watching training. - Subclass and override this method to inject custom behavior. - Args: - logs (`Dict[str, float]`): - The values to log. - """ - if self.state.epoch is not None: - logs["epoch"] = round(self.state.epoch, 2) - - mem_stats = get_hpu_memory_stats() - logs.update(mem_stats) + # create accelerator object + self.accelerator = GaudiAccelerator( + deepspeed_plugin=self.args.deepspeed_plugin, + gradient_accumulation_plugin=gradient_accumulation_plugin, + even_batches=self.args.use_lazy_mode and not self.args.dataloader_drop_last, + ) - output = {**logs, **{"step": self.state.global_step}} - self.state.log_history.append(output) - self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs) + # deepspeed and accelerate flags covering both trainer args and accelerate launcher + self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None + self.is_fsdp_enabled = False - def _move_model_to_device(self, model, device): - model = model.to(device) - # Moving a model to HPU disconnects the tied weights, so we have to retie them. - if self.args.use_habana and hasattr(model, "tie_weights"): - model.tie_weights() + # post accelerator creation setup + if self.is_deepspeed_enabled: + if getattr(self.args, "hf_deepspeed_config", None) is None: + from .deepspeed import GaudiTrainerDeepSpeedConfig - def autocast_smart_context_manager(self, cache_enabled: Optional[bool] = True): - """ - A helper wrapper that creates an appropriate context manager for `autocast` while feeding it the desired - arguments, depending on the situation. Modified by Habana to enable using `autocast` on Gaudi devices. - """ - if self.use_cpu_amp: - ctx_manager = torch.cpu.amp.autocast(cache_enabled=cache_enabled, dtype=torch.bfloat16) - elif self.use_hpu_amp: - ctx_manager = torch.autocast(device_type="hpu", dtype=torch.bfloat16, enabled=True) - else: - import contextlib + ds_plugin = self.accelerator.state.deepspeed_plugin - ctx_manager = contextlib.nullcontext() if sys.version_info >= (3, 7) else contextlib.suppress() - return ctx_manager + ds_plugin.hf_ds_config = GaudiTrainerDeepSpeedConfig(ds_plugin.hf_ds_config.config) + ds_plugin.deepspeed_config = ds_plugin.hf_ds_config.config def _zero_model_grad(self, model): if hasattr(model, "_zero_grad_kwargs"): @@ -1926,7 +1945,7 @@ def _zero_model_grad(self, model): # Optimization based on setting gradients to None (instead of zeroing them out) may only be used when gradients are not recorded using HPU graphs. # HPU graphs rely on fixed tensors - setting gradients to None will enforce their re-allocation during the backward pass each step. 
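The `_zero_model_grad` comments below describe the trade-off behind `set_to_none`. A short plain-PyTorch illustration of the two zeroing modes:

import torch

model = torch.nn.Linear(4, 4)
model(torch.randn(2, 4)).sum().backward()

model.zero_grad(set_to_none=False)     # gradient tensors stay allocated, filled with zeros
print(model.weight.grad is None)       # False: the fixed tensors HPU graphs rely on are kept
model(torch.randn(2, 4)).sum().backward()
model.zero_grad(set_to_none=True)      # gradient tensors are dropped entirely
print(model.weight.grad is None)       # True: they will be re-allocated on the next backward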
set_to_none = ( - self.args.local_rank == -1 or self.args.distribution_strategy == "ddp" + self.args.parallel_mode != ParallelMode.DISTRIBUTED or self.args.distribution_strategy == "ddp" ) and not self.args.use_hpu_graphs_for_training try: diff --git a/optimum/habana/transformers/trainer_seq2seq.py b/optimum/habana/transformers/trainer_seq2seq.py index 3985a05709..11eac22ded 100644 --- a/optimum/habana/transformers/trainer_seq2seq.py +++ b/optimum/habana/transformers/trainer_seq2seq.py @@ -215,6 +215,7 @@ def prediction_step( inputs: Dict[str, Union[torch.Tensor, Any]], prediction_loss_only: bool, ignore_keys: Optional[List[str]] = None, + **gen_kwargs, ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: """ Perform an evaluation step on `model` using `inputs`. @@ -228,6 +229,8 @@ def prediction_step( argument `labels`. Check your model's documentation for all accepted arguments. prediction_loss_only (`bool`): Whether or not to return the loss only. + gen_kwargs: + Additional `generate` specific kwargs. Return: Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and labels (each being optional). @@ -244,7 +247,10 @@ def prediction_step( # XXX: adapt synced_gpus for fairscale as well # Priority (handled in generate): # gen_kwargs > model.generation_config > default GenerationConfig() - gen_kwargs = self._gen_kwargs.copy() + + if len(gen_kwargs) == 0 and hasattr(self, "_gen_kwargs"): + gen_kwargs = self._gen_kwargs.copy() + if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None: gen_kwargs["max_length"] = self.model.generation_config.max_length gen_kwargs["num_beams"] = ( @@ -266,9 +272,14 @@ def prediction_step( else self.args.use_hpu_graphs_for_inference ) - # TODO (Joao): the following line is needed to keep a consistent result on SQUAD. Ideally, we should not block - # users from preparing a dataset with `decoder_input_ids`. 
- inputs = {k: v for k, v in inputs.items() if k != "decoder_input_ids"} + # If the `decoder_input_ids` was created from `labels`, evict the former, so that the model can freely generate + # (otherwise, it would continue generating from the padded `decoder_input_ids`) + if ( + "labels" in inputs + and "decoder_input_ids" in inputs + and inputs["labels"].shape == inputs["decoder_input_ids"].shape + ): + inputs = {k: v for k, v in inputs.items() if k != "decoder_input_ids"} try: generated_tokens = self.model.generate( **inputs, diff --git a/optimum/habana/transformers/training_args.py b/optimum/habana/transformers/training_args.py index b7dbeebd4a..b705810a76 100644 --- a/optimum/habana/transformers/training_args.py +++ b/optimum/habana/transformers/training_args.py @@ -26,20 +26,21 @@ from transformers.trainer_utils import EvaluationStrategy, HubStrategy, IntervalStrategy, SchedulerType from transformers.training_args import ( OptimizerNames, + ParallelMode, TrainingArguments, default_logdir, - get_int_from_env, ) from transformers.utils import ( - ccl_version, get_full_repo_name, is_accelerate_available, - is_psutil_available, is_safetensors_available, ) from optimum.utils import logging +from ..accelerate.state import GaudiAcceleratorState, GaudiPartialState +from ..accelerate.utils import GaudiDistributedType + if is_torch_available(): import torch @@ -242,11 +243,20 @@ class GaudiTrainingArguments(TrainingArguments): half_precision_backend: str = field( default="hpu_amp", metadata={ - "help": "The backend to use for half precision.", + "help": "The backend to be used for half precision.", "choices": ["cpu_amp", "hpu_amp"], }, ) + # Overriding ddp_backend to replace all possible backends by hccl + ddp_backend: Optional[str] = field( + default="hccl", + metadata={ + "help": "The backend to be used for distributed training.", + "choices": ["hccl"], + }, + ) + def __post_init__(self): if self.use_hpu_graphs: warnings.warn( @@ -296,17 +306,6 @@ def __post_init__(self): if self.throughput_warmup_steps < 0: raise ValueError("--throughput_warmup_steps must be positive.") - # Handle --use_env option in torch.distributed.launch (local_rank not passed as an arg then). - # This needs to happen before any call to self.device or self.n_gpu. 
- env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) - if env_local_rank != -1 and env_local_rank != self.local_rank: - self.local_rank = env_local_rank - - if self.local_rank != -1 and self.use_hpu_graphs_for_training and self.distribution_strategy != "fast_ddp": - raise ValueError( - "`--use_hpu_graphs_for_training` may only be used with `--distribution_strategy fast_ddp`" - ) - # expand paths, if not os.makedirs("~/bar") will make directory # in the current directory instead of the actual home # see https://github.com/huggingface/transformers/issues/10628 @@ -355,6 +354,19 @@ def __post_init__(self): if self.logging_strategy == IntervalStrategy.STEPS and self.logging_steps == 0: raise ValueError(f"logging strategy {self.logging_strategy} requires non-zero --logging_steps") + if self.logging_strategy == IntervalStrategy.STEPS and self.logging_steps > 1: + if self.logging_steps != int(self.logging_steps): + raise ValueError(f"--logging_steps must be an integer if bigger than 1: {self.logging_steps}") + self.logging_steps = int(self.logging_steps) + if self.evaluation_strategy == IntervalStrategy.STEPS and self.eval_steps > 1: + if self.eval_steps != int(self.eval_steps): + raise ValueError(f"--eval_steps must be an integer if bigger than 1: {self.eval_steps}") + self.eval_steps = int(self.eval_steps) + if self.save_strategy == IntervalStrategy.STEPS and self.save_steps > 1: + if self.save_steps != int(self.save_steps): + raise ValueError(f"--save_steps must be an integer if bigger than 1: {self.save_steps}") + self.save_steps = int(self.save_steps) + # Sanity checks for load_best_model_at_end: we require save and eval strategies to be compatible. if self.load_best_model_at_end: if self.evaluation_strategy != self.save_strategy: @@ -363,6 +375,20 @@ def __post_init__(self): f"strategy: {self.evaluation_strategy}\n- Save strategy: {self.save_strategy}" ) if self.evaluation_strategy == IntervalStrategy.STEPS and self.save_steps % self.eval_steps != 0: + if self.eval_steps < 1 or self.save_steps < 1: + if not (self.eval_steps < 1 and self.save_steps < 1): + raise ValueError( + "--load_best_model_at_end requires the saving steps to be a multiple of the evaluation " + "steps, which cannot get guaranteed when mixing ratio and absolute steps for save_steps" + f"{self.save_steps} and eval_steps {self.eval_steps}." + ) + # Work around floating point precision issues + LARGE_MULTIPLIER = 1_000_000 + if (self.save_steps * LARGE_MULTIPLIER) % (self.eval_steps * LARGE_MULTIPLIER) != 0: + raise ValueError( + "--load_best_model_at_end requires the saving steps to be a multiple of the evaluation " + f"steps, but found {self.save_steps}, which is not a multiple of {self.eval_steps}." + ) raise ValueError( "--load_best_model_at_end requires the saving steps to be a round multiple of the evaluation " f"steps, but found {self.save_steps}, which is not a round multiple of {self.eval_steps}." @@ -379,13 +405,21 @@ def __post_init__(self): f"https://github.com/huggingface/safetensors!" 
) - if self.load_best_model_at_end and self.metric_for_best_model is None: + if ( + self.load_best_model_at_end or self.lr_scheduler_type == SchedulerType.REDUCE_ON_PLATEAU + ) and self.metric_for_best_model is None: self.metric_for_best_model = "loss" if self.greater_is_better is None and self.metric_for_best_model is not None: self.greater_is_better = self.metric_for_best_model not in ["loss", "eval_loss"] if self.run_name is None: self.run_name = self.output_dir + if self.lr_scheduler_type == SchedulerType.REDUCE_ON_PLATEAU: + if self.evaluation_strategy == IntervalStrategy.NO: + raise ValueError("lr_scheduler_type reduce_lr_on_plateau requires an eval strategy") + if not is_torch_available(): + raise ValueError("lr_scheduler_type reduce_lr_on_plateau requires torch>=0.2.0") + self.optim = OptimizerNames(self.optim) if self.adafactor: warnings.warn( @@ -427,10 +461,13 @@ def __post_init__(self): if isinstance(self.debug, str): self.debug = [DebugOption(s) for s in self.debug.split()] + elif self.debug is None: + self.debug = [] # This call to self.device is necessary to call _setup_devices so that # torch.distributed is initialized device_is_hpu = self.device.type == "hpu" + self.deepspeed_plugin = None if self.deepspeed: if not device_is_hpu: raise ValueError("This version of DeepSpeed must be run on HPUs.") @@ -446,6 +483,12 @@ def __post_init__(self): self.hf_deepspeed_config = GaudiTrainerDeepSpeedConfig(self.deepspeed) self.hf_deepspeed_config.trainer_config_process(self) + # Accelerate DeepSpeed Plugin + from accelerate.utils import DeepSpeedPlugin + + os.environ["ACCELERATE_USE_DEEPSPEED"] = "true" + self.deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=self.hf_deepspeed_config) + if self.push_to_hub_token is not None: warnings.warn( ( @@ -489,6 +532,21 @@ def __post_init__(self): FutureWarning, ) + # if training args is specified, it will override the one specified in the accelerate config + mixed_precision_dtype = os.environ.get("ACCELERATE_MIXED_PRECISION", "no") + if self.bf16: + mixed_precision_dtype = "bf16" + os.environ["ACCELERATE_MIXED_PRECISION"] = mixed_precision_dtype + + if ( + self.parallel_mode == ParallelMode.DISTRIBUTED + and self.use_hpu_graphs_for_training + and self.distribution_strategy != "fast_ddp" + ): + raise ValueError( + "`--use_hpu_graphs_for_training` may only be used with `--distribution_strategy fast_ddp`." + ) + def __str__(self): self_as_dict = asdict(self) @@ -511,75 +569,23 @@ def __str__(self): def _setup_devices(self) -> "torch.device": requires_backends(self, ["torch"]) + logger.info("PyTorch: setting up devices") + if not is_accelerate_available(min_version="0.21.0"): + raise ImportError( + "Using the `GaudiTrainer` requires `accelerate>=0.21.0`: Please run `pip install accelerate -U`." + ) + GaudiAcceleratorState._reset_state() + GaudiPartialState._reset_state() + self.distributed_state = None + # Set the log level here for optimum.utils.logging # otherwise logs are not sent in this method. log_level = self.get_process_log_level() logging.set_verbosity(log_level) - logger.info("PyTorch: setting up devices") - if torch.distributed.is_available() and torch.distributed.is_initialized() and self.local_rank == -1: - logger.warning("torch.distributed process group is initialized, but local_rank == -1. 
") if self.no_cuda: - device = torch.device("cpu") + self.distributed_state = GaudiPartialState(cpu=True, backend=self.ddp_backend) self._n_gpu = 0 - self.local_rank = get_int_from_env( - ["LOCAL_RANK", "MPI_LOCALRANKID", "OMPI_COMM_WORLD_LOCAL_RANK", "MV2_COMM_WORLD_LOCAL_RANK"], - self.local_rank, - ) - if self.local_rank != -1 and not torch.distributed.is_initialized(): - # Initializes distributed backend for cpu - if self.xpu_backend not in ("mpi", "ccl", "gloo"): - raise ValueError( - "CPU distributed training backend is not properly set. " - "Please set '--xpu_backend' to either 'mpi' or 'ccl' or 'gloo'." - ) - if self.xpu_backend == "ccl": - requires_backends(self, "oneccl_bind_pt") - if ccl_version >= "1.12": - import oneccl_bindings_for_pytorch # noqa: F401 - else: - import torch_ccl # noqa: F401 - if int(os.environ.get("CCL_WORKER_COUNT", 0)) < 1: - raise ValueError( - "CPU distributed training backend is ccl. but CCL_WORKER_COUNT is not correctly set. " - "Please use like 'export CCL_WORKER_COUNT = 1' to set." - ) - - # Try to get launch configuration from environment variables set by MPI launcher - works for Intel MPI, OpenMPI and MVAPICH - rank = get_int_from_env(["RANK", "PMI_RANK", "OMPI_COMM_WORLD_RANK", "MV2_COMM_WORLD_RANK"], 0) - size = get_int_from_env(["WORLD_SIZE", "PMI_SIZE", "OMPI_COMM_WORLD_SIZE", "MV2_COMM_WORLD_SIZE"], 1) - local_size = get_int_from_env( - ["MPI_LOCALNRANKS", "OMPI_COMM_WORLD_LOCAL_SIZE", "MV2_COMM_WORLD_LOCAL_SIZE"], 1 - ) - os.environ["RANK"] = str(rank) - os.environ["WORLD_SIZE"] = str(size) - os.environ["LOCAL_RANK"] = str(self.local_rank) - if not os.environ.get("MASTER_PORT", None): - os.environ["MASTER_PORT"] = "29500" - if not os.environ.get("MASTER_ADDR", None): - if local_size != size or self.xpu_backend != "mpi": - raise ValueError( - "Looks like distributed multinode run but MASTER_ADDR env not set, " - "please try exporting rank 0's hostname as MASTER_ADDR" - ) - if ( - torch.get_num_threads() == 1 - and get_int_from_env(["OMP_NUM_THREADS", "MKL_NUM_THREADS"], 0) == 0 - and is_psutil_available() - ): - import psutil - - num_cpu_threads_per_process = int(psutil.cpu_count(logical=False) / local_size) - if num_cpu_threads_per_process == 0: - num_cpu_threads_per_process = 1 - torch.set_num_threads(num_cpu_threads_per_process) - logger.info( - f"num_cpu_threads_per_process unset, we set it at {num_cpu_threads_per_process} to improve oob" - " performance." - ) - torch.distributed.init_process_group( - backend=self.xpu_backend, rank=rank, world_size=size, timeout=self.ddp_timeout_delta - ) elif self.use_habana: # Some methods needs to be tweaked to optimally run on Gaudi # Calling this method here to be sure it is done before model instantiation @@ -594,40 +600,34 @@ def _setup_devices(self) -> "torch.device": os.environ["PT_HPU_LAZY_MODE"] = "2" logger.info("Enabled eager mode because use_lazy_mode=False.") - device = torch.device("hpu") - self._n_gpu = 1 - - from habana_frameworks.torch.distributed.hccl import initialize_distributed_hpu - - world_size, rank, self.local_rank = initialize_distributed_hpu() - if self.deepspeed: - # deepspeed inits torch.distributed internally - from transformers.deepspeed import is_deepspeed_available - - if not is_deepspeed_available(): - raise ImportError( - "--deepspeed requires deepspeed: `pip install" - " git+https://github.com/HabanaAI/DeepSpeed.git@1.10.0`." 
- ) - import deepspeed - - if world_size > 1: - os.environ["HLS_MODULE_ID"] = str(self.local_rank) - os.environ["ID"] = str(rank) - - deepspeed.init_distributed(dist_backend="hccl", timeout=timedelta(seconds=self.ddp_timeout)) - logger.info("DeepSpeed is enabled.") + # Need to do similar for Accelerator init + os.environ["ACCELERATE_USE_DEEPSPEED"] = "true" + self.distributed_state = GaudiPartialState(timeout=timedelta(seconds=self.ddp_timeout)) + del os.environ["ACCELERATE_USE_DEEPSPEED"] else: - if self.local_rank != -1: - if not torch.distributed.is_initialized(): - torch.distributed.init_process_group(backend="hccl", rank=rank, world_size=world_size) - logger.info("Enabled distributed run.") - else: - logger.info("Single-device run.") + self.distributed_state = GaudiPartialState( + backend=self.ddp_backend, timeout=timedelta(seconds=self.ddp_timeout) + ) + self._n_gpu = 1 else: raise ValueError( "No device has been set. Use either --use_habana to run on HPU or --no_cuda to run on CPU." ) + device = self.distributed_state.device + self.local_rank = self.distributed_state.local_process_index + if ( + torch.distributed.is_available() + and torch.distributed.is_initialized() + and self.parallel_mode != ParallelMode.DISTRIBUTED + ): + logger.warning( + "torch.distributed process group is initialized, but parallel_mode != ParallelMode.DISTRIBUTED. " + "In order to use Torch DDP, launch your script with `python -m torch.distributed.launch" + ) + + if self.distributed_state.distributed_type == GaudiDistributedType.NO: + self._n_gpu = 0 + return device diff --git a/optimum/habana/utils.py b/optimum/habana/utils.py index bc2d603386..860fd3d89b 100644 --- a/optimum/habana/utils.py +++ b/optimum/habana/utils.py @@ -91,6 +91,8 @@ def speed_metrics( runtime = time.time() - start_time result = {f"{split}_runtime": round(runtime, 4)} + if runtime == 0: + return result # Adjust runtime if log_evaluate_save_time should not be included if log_evaluate_save_time is not None: diff --git a/pyproject.toml b/pyproject.toml index 5d9b675a8c..7323ffa36c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,7 @@ target-version = ['py37'] [tool.ruff] # Never enforce `E501` (line length violations). 
-ignore = ["C901", "E501", "E741", "W605"] +ignore = ["C901", "E501", "E741"] select = ["C", "E", "F", "I", "W"] line-length = 119 exclude = ["text-generation-inference"] diff --git a/setup.py b/setup.py index a7300eb5ee..2be656bf38 100644 --- a/setup.py +++ b/setup.py @@ -29,10 +29,10 @@ INSTALL_REQUIRES = [ - "transformers >= 4.26.0, < 4.29.0", + "transformers >= 4.31.0", "optimum", "torch", - "accelerate", + "accelerate >= 0.21.0", "diffusers >= 0.18.0", ] diff --git a/tests/create_diff_file_for_example.py b/tests/create_diff_file_for_example.py index ff45d325cb..467848fb20 100644 --- a/tests/create_diff_file_for_example.py +++ b/tests/create_diff_file_for_example.py @@ -106,16 +106,18 @@ def auto_diff(): # Loop over all the "run_*.py" scripts in the example folder for file in directory.iterdir(): if file.is_file() and file.name.startswith("run_"): - final_diff = create_diff_content( - diff( - path_to_transformers / file.name, - file, - ), - keep_all_diffs=True, - ) - diff_filename = DIFF_DIRECTORY / f"{file.stem}.txt" - with open(diff_filename, "w") as fp: - fp.write(final_diff) + transformers_file = path_to_transformers / file.name + if transformers_file.is_file(): + final_diff = create_diff_content( + diff( + transformers_file, + file, + ), + keep_all_diffs=True, + ) + diff_filename = DIFF_DIRECTORY / f"{file.stem}.txt" + with open(diff_filename, "w") as fp: + fp.write(final_diff) def parse_args(): diff --git a/tests/example_diff/run_audio_classification.txt b/tests/example_diff/run_audio_classification.txt index fd60770573..795f8c69c2 100644 --- a/tests/example_diff/run_audio_classification.txt +++ b/tests/example_diff/run_audio_classification.txt @@ -21,9 +21,9 @@ > from optimum.habana.utils import set_seed > 48c41 -< check_min_version("4.29.0.dev0") +< check_min_version("4.32.0.dev0") --- -> check_min_version("4.28.0") +> check_min_version("4.31.0") 164,166d156 < freeze_feature_extractor: Optional[bool] = field( < default=None, metadata={"help": "Whether to freeze the feature extractor layers of the model."} @@ -56,20 +56,22 @@ > use_auth_token=True if model_args.use_auth_token else None, > ) > -224,225c206,208 +222a205 +> mixed_precision = training_args.bf16 or gaudi_config.use_torch_autocast or gaudi_config.use_habana_mixed_precision +224,225c207,209 < f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu} " -< + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" +< + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}" --- > f"Process rank: {training_args.local_rank}, device: {training_args.device}, " > + f"distributed training: {bool(training_args.local_rank != -1)}, " -> + f"mixed-precision training: {gaudi_config.use_habana_mixed_precision}" -291a275,277 +> + f"mixed-precision training: {mixed_precision}" +291a276,278 > # Max input length > max_length = int(round(feature_extractor.sampling_rate * data_args.max_length_seconds)) > -296a283 +296a284 > -302c289,295 +302c290,296 < inputs = feature_extractor(subsampled_wavs, sampling_rate=feature_extractor.sampling_rate) --- > inputs = feature_extractor( @@ -79,7 +81,7 @@ > padding="max_length", > truncation=True, > ) -311c304,310 +311c305,311 < inputs = feature_extractor(wavs, sampling_rate=feature_extractor.sampling_rate) --- > inputs = feature_extractor( @@ -89,9 +91,9 @@ > padding="max_length", > truncation=True, > ) -376c375 +376c376 < trainer = Trainer( --- > 
trainer = GaudiTrainer( -377a377 +377a378 > gaudi_config=gaudi_config, diff --git a/tests/example_diff/run_clip.txt b/tests/example_diff/run_clip.txt index 66dde55145..c80e2befac 100644 --- a/tests/example_diff/run_clip.txt +++ b/tests/example_diff/run_clip.txt @@ -2,26 +2,32 @@ < 32a32 > import transformers -38,39d37 +33a34 +> from habana_dataloader_trainer import HabanaDataloaderTrainer +38,39d38 < < import transformers -45,47d42 +45,47d43 < Trainer, < TrainingArguments, < set_seed, -52a48,50 +52a49,51 > from optimum.habana import GaudiConfig, GaudiTrainer, GaudiTrainingArguments > from optimum.habana.utils import set_seed > -57c55 -< check_min_version("4.29.0.dev0") +57c56 +< check_min_version("4.32.0.dev0") --- -> check_min_version("4.28.0") -230c228 +> check_min_version("4.31.0") +171a171,173 +> mediapipe_dataloader: bool = field( +> default=False, metadata={"help": "Turn on MediaPipe hardware-based accelerated data loading."} +> ) +230c232 < parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) --- > parser = HfArgumentParser((ModelArguments, DataTrainingArguments, GaudiTrainingArguments)) -258a257,263 +258a261,267 > gaudi_config = GaudiConfig.from_pretrained( > training_args.gaudi_config_name, > cache_dir=model_args.cache_dir, @@ -29,18 +35,57 @@ > use_auth_token=True if model_args.use_auth_token else None, > ) > -261,262c266,268 +259a269 +> mixed_precision = training_args.bf16 or gaudi_config.use_torch_autocast or gaudi_config.use_habana_mixed_precision +261,262c271,273 < f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" -< + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" +< + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}" --- > f"Process rank: {training_args.local_rank}, device: {training_args.device}, " > + f"distributed training: {bool(training_args.local_rank != -1)}, " -> + f"mixed-precision training: {gaudi_config.use_habana_mixed_precision}" -399d404 +> + f"mixed-precision training: {mixed_precision}" +399d409 < image_transformations = torch.jit.script(image_transformations) -496c501 +446,447c456,464 +< # Transform images on the fly as doing it on the whole dataset takes too much time. +< train_dataset.set_transform(transform_images) +--- +> if data_args.mediapipe_dataloader: +> train_dataset.image_mean = image_processor.image_mean +> train_dataset.image_std = image_processor.image_std +> train_dataset.text_max_length = data_args.max_seq_length +> train_dataset.image_resize = config.vision_config.image_size +> train_dataset.transform_func = transform_images +> else: +> # Transform images on the fly as doing it on the whole dataset takes too much time. +> train_dataset.set_transform(transform_images) +469,470c486,494 +< # Transform images on the fly as doing it on the whole dataset takes too much time. +< eval_dataset.set_transform(transform_images) +--- +> if data_args.mediapipe_dataloader: +> eval_dataset.image_mean = image_processor.image_mean +> eval_dataset.image_std = image_processor.image_std +> eval_dataset.text_max_length = data_args.max_seq_length +> eval_dataset.image_resize = config.vision_config.image_size +> eval_dataset.transform_func = transform_images +> else: +> # Transform images on the fly as doing it on the whole dataset takes too much time. 
+> eval_dataset.set_transform(transform_images) +493a518,526 +> if data_args.mediapipe_dataloader: +> test_dataset.image_mean = image_processor.image_mean +> test_dataset.image_std = image_processor.image_std +> test_dataset.text_max_length = data_args.max_seq_length +> test_dataset.image_resize = config.vision_config.image_size +> test_dataset.transform_func = transform_images +> else: +> # Transform images on the fly as doing it on the whole dataset takes too much time. +> test_dataset.set_transform(transform_images) +496c529,530 < trainer = Trainer( --- -> trainer = GaudiTrainer( -497a503 +> trainer_cls = HabanaDataloaderTrainer if data_args.mediapipe_dataloader else GaudiTrainer +> trainer = trainer_cls( +497a532 > gaudi_config=gaudi_config, diff --git a/tests/example_diff/run_clm.txt b/tests/example_diff/run_clm.txt index 252a780a77..4ec8bb108d 100644 --- a/tests/example_diff/run_clm.txt +++ b/tests/example_diff/run_clm.txt @@ -25,9 +25,9 @@ > from optimum.habana.utils import set_seed > 58c55 -< check_min_version("4.29.0.dev0") +< check_min_version("4.32.0.dev0") --- -> check_min_version("4.28.0") +> check_min_version("4.31.0") 79c76,77 < "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch." --- @@ -60,22 +60,24 @@ > use_auth_token=True if model_args.use_auth_token else None, > ) > -265,266c280,282 +263a279 +> mixed_precision = training_args.bf16 or gaudi_config.use_torch_autocast or gaudi_config.use_habana_mixed_precision +265,266c281,283 < f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" -< + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" +< + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}" --- > f"Process rank: {training_args.local_rank}, device: {training_args.device}, " > + f"distributed training: {bool(training_args.local_rank != -1)}, " -> + f"mixed-precision training: {gaudi_config.use_habana_mixed_precision}" -376a393 +> + f"mixed-precision training: {mixed_precision}" +376a394 > "use_cache": False if training_args.gradient_checkpointing else model_args.use_cache, -562c579 +561c579 < trainer = Trainer( --- > trainer = GaudiTrainer( -563a581 +562a581 > gaudi_config=gaudi_config, -570,573c588,589 +569,572c588,589 < compute_metrics=compute_metrics if training_args.do_eval and not is_torch_tpu_available() else None, < preprocess_logits_for_metrics=preprocess_logits_for_metrics < if training_args.do_eval and not is_torch_tpu_available() @@ -83,7 +85,7 @@ --- > compute_metrics=compute_metrics if training_args.do_eval else None, > preprocess_logits_for_metrics=preprocess_logits_for_metrics if training_args.do_eval else None, -627,631d642 +626,630d642 < < < def _mp_fn(index): diff --git a/tests/example_diff/run_generation.txt b/tests/example_diff/run_generation.txt index faa85479f4..7c245f24e3 100644 --- a/tests/example_diff/run_generation.txt +++ b/tests/example_diff/run_generation.txt @@ -5,38 +5,52 @@ --- > Conditional text generation on Habana Gaudi/Gaudi2. 
> """ -23c23,25 +22c22 +< import inspect +--- +> import copy +24c24,25 < from typing import Tuple --- > import os -> import tempfile > import time -25d26 +26d26 < import numpy as np -27,43c28,31 +28,49c28,33 < < from transformers import ( +< AutoTokenizer, +< BloomForCausalLM, +< BloomTokenizerFast, < CTRLLMHeadModel, < CTRLTokenizer, < GenerationMixin, < GPT2LMHeadModel, < GPT2Tokenizer, +< GPTJForCausalLM, +< LlamaForCausalLM, +< LlamaTokenizer, < OpenAIGPTLMHeadModel, < OpenAIGPTTokenizer, +< OPTForCausalLM, < TransfoXLLMHeadModel, < TransfoXLTokenizer, < XLMTokenizer, < XLMWithLMHeadModel, < XLNetLMHeadModel, < XLNetTokenizer, -< ) +--- +> from checkpoint_utils import ( +> get_ds_injection_policy, +> get_repo_root, +> model_is_bloom, +> model_is_optimized, +> write_checkpoints_json, +51c35 < from transformers.modeling_outputs import CausalLMOutputWithPast --- -> import torch.nn.functional as F -> from checkpoint_utils import model_is_bloom, write_checkpoints_json > from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer -> from transformers.generation import GenerationConfig -53,280d40 +61,290d44 < MAX_LENGTH = int(10000) # Hardcoded max length to avoid infinite loop < < MODEL_CLASSES = { @@ -46,6 +60,10 @@ < "xlnet": (XLNetLMHeadModel, XLNetTokenizer), < "transfo-xl": (TransfoXLLMHeadModel, TransfoXLTokenizer), < "xlm": (XLMWithLMHeadModel, XLMTokenizer), +< "gptj": (GPTJForCausalLM, AutoTokenizer), +< "bloom": (BloomForCausalLM, BloomTokenizerFast), +< "llama": (LlamaForCausalLM, LlamaTokenizer), +< "opt": (OPTForCausalLM, GPT2Tokenizer), < } < < # Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia @@ -160,23 +178,26 @@ < raise ValueError("Check the model config") < < num_embedding_size_per_head = int(embedding_size / num_head) -< num_layer = model_config.n_layer +< if hasattr(model_config, "n_layer"): +< num_layer = model_config.n_layer +< elif hasattr(model_config, "num_hidden_layers"): +< num_layer = model_config.num_hidden_layers +< else: +< raise ValueError("Number of hidden layers couldn't be determined from the model config") < < return num_layer, num_head, num_embedding_size_per_head < < -< def prepare_jit_inputs(inputs, model, tokenizer): -< num_batch = len(inputs) -< dummy_input = tokenizer.batch_encode_plus(inputs, return_tensors="pt", padding=True) +< def generate_past_key_values(model, batch_size, seq_len): < num_block_layers, num_attention_heads, num_embedding_size_per_head = sparse_model_config(model.config) < if model.config.model_type == "bloom": < past_key_values = tuple( < ( -< torch.zeros(int(num_attention_heads * num_batch), num_embedding_size_per_head, 1) -< .to(model.config.torch_dtype) +< torch.empty(int(num_attention_heads * batch_size), num_embedding_size_per_head, seq_len) +< .to(model.dtype) < .to(model.device), -< torch.zeros(int(num_attention_heads * num_batch), 1, num_embedding_size_per_head) -< .to(model.config.torch_dtype) +< torch.empty(int(num_attention_heads * batch_size), seq_len, num_embedding_size_per_head) +< .to(model.dtype) < .to(model.device), < ) < for _ in range(num_block_layers) @@ -184,37 +205,34 @@ < else: < past_key_values = tuple( < ( -< torch.zeros(num_batch, num_attention_heads, 1, num_embedding_size_per_head) -< .to(model.config.torch_dtype) +< torch.empty(batch_size, num_attention_heads, seq_len, num_embedding_size_per_head) +< .to(model.dtype) < .to(model.device), -< torch.zeros(num_batch, num_attention_heads, 1, num_embedding_size_per_head) -< .to(model.config.torch_dtype) 
+< torch.empty(batch_size, num_attention_heads, seq_len, num_embedding_size_per_head) +< .to(model.dtype) < .to(model.device), < ) < for _ in range(num_block_layers) < ) +< return past_key_values < +< +< def prepare_jit_inputs(inputs, model, tokenizer): +< batch_size = len(inputs) +< dummy_input = tokenizer.batch_encode_plus(inputs, return_tensors="pt") +< dummy_input = dummy_input.to(model.device) +< if model.config.use_cache: +< dummy_input["past_key_values"] = generate_past_key_values(model, batch_size, 1) < dummy_input["attention_mask"] = torch.cat( < [ -< torch.zeros(dummy_input["attention_mask"].shape[0], 1).to(dummy_input["attention_mask"].dtype), +< torch.zeros(dummy_input["attention_mask"].shape[0], 1) +< .to(dummy_input["attention_mask"].dtype) +< .to(model.device), < dummy_input["attention_mask"], < ], < -1, < ) -< -< if model.config.use_cache: -< jit_inputs = ( -< dummy_input["input_ids"].to(model.device), -< past_key_values, -< dummy_input["attention_mask"].to(model.device), -< ) -< else: -< jit_inputs = ( -< dummy_input["input_ids"].to(model.device), -< dummy_input["attention_mask"].to(model.device), -< ) -< -< return jit_inputs +< return dummy_input < < < class _ModelFallbackWrapper(GenerationMixin): @@ -225,15 +243,13 @@ < self._default = default < < def __call__(self, *args, **kwargs): -< if kwargs["past_key_values"] is None: -< return self._default(*args, **kwargs) -< trace_graph_inputs = [] +< if kwargs["past_key_values"] is None and self._default.config.use_cache: +< kwargs["past_key_values"] = generate_past_key_values(self._default, kwargs["input_ids"].shape[0], 0) < kwargs.pop("position_ids", None) -< for k, v in kwargs.items(): -< if v is not None and not isinstance(v, bool): -< trace_graph_inputs.append(v) -< trace_graph_inputs = tuple(trace_graph_inputs) -< outputs = self._optimized(*trace_graph_inputs) +< for k in list(kwargs.keys()): +< if kwargs[k] is None or isinstance(kwargs[k], bool): +< kwargs.pop(k) +< outputs = self._optimized(**kwargs) < lm_logits = outputs[0] < past_key_values = outputs[1] < fixed_output = CausalLMOutputWithPast( @@ -265,48 +281,73 @@ < """ < return self._default._reorder_cache(past_key_values, beam_idx) < -282a43 +292a47 > # Arguments management -285c46 +295c50 < "--model_type", --- > "--model_name_or_path", -289c50 +299c54 < help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()), --- > help="Path to pre-trained model (on the HF Hub or locally).", -292c53 +302c57,76 < "--model_name_or_path", --- -> "--gaudi_config_name_or_path", -295,307c56 +> "--bf16", +> action="store_true", +> help="Whether to perform generation in bf16 precision.", +> ) +> parser.add_argument("--max_new_tokens", type=int, default=100, help="Number of tokens to generate.") +> parser.add_argument("--batch_size", type=int, default=1, help="Input batch size.") +> parser.add_argument("--n_iterations", type=int, default=5, help="Number of inference iterations for benchmarking.") +> parser.add_argument("--local_rank", type=int, default=-1, metavar="N", help="Local process rank.") +> parser.add_argument( +> "--use_kv_cache", +> action="store_true", +> help="Whether to use the key/value cache for decoding. It should speed up generation.", +> ) +> parser.add_argument( +> "--use_hpu_graphs", +> action="store_true", +> help="Whether to use HPU graphs or not. 
Using HPU graphs should give better latencies.", +> ) +> parser.add_argument( +> "--dataset_name", +305,306c79,85 < required=True, < help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(MODEL_CLASSES.keys()), -< ) +--- +> help="Optional argument if you want to assess your model on a given dataset of the HF Hub.", +> ) +> parser.add_argument( +> "--column_name", +> default=None, +> type=str, +> help="If `--dataset_name` was given, this will be the name of the column to use as prompts for generation.", +308,312d86 < < parser.add_argument("--prompt", type=str, default="") < parser.add_argument("--length", type=int, default=20) < parser.add_argument("--stop_token", type=str, default=None, help="Token at which text generation is stopped") < -< parser.add_argument( +314,317c88,90 < "--temperature", < type=float, < default=1.0, < help="temperature of 1.0 has no effect, lower tend toward greedy sampling", --- -> help="Path to Gaudi configuration (on the HF Hub or locally).", -308a58,61 -> parser.add_argument("--max_new_tokens", type=int, default=100) -> parser.add_argument("--batch_size", type=int, default=1, help="Input batch size.") -> parser.add_argument("--n_iterations", type=int, default=5, help="Number of inference iterations.") -> parser.add_argument("--local_rank", type=int, default=-1, metavar="N", help="Local process rank.") -310c63,65 +> "--do_sample", +> action="store_true", +> help="Whether to use sampling for generation.", +320c93,96 < "--repetition_penalty", type=float, default=1.0, help="primarily useful for CTRL model; in that case, use 1.2" --- -> "--use_kv_cache", -> action="store_true", -> help="Whether to use the key/value cache for decoding. It should speed up generation.", -312,321c67 +> "--num_beams", +> default=1, +> type=int, +> help="Number of beams used for beam search generation. 1 means greedy search will be performed.", +322,331d97 < parser.add_argument("--k", type=int, default=0) < parser.add_argument("--p", type=float, default=0.9) < @@ -317,51 +358,55 @@ < parser.add_argument("--seed", type=int, default=42, help="random seed for initialization") < parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available") < parser.add_argument("--num_return_sequences", type=int, default=1, help="The number of samples to generate.") ---- -> parser.add_argument("--use_hpu_graphs", action="store_true", help="Whether to use HPU graphs or not.") -323,325c69,72 +333,335c99,127 < "--fp16", < action="store_true", < help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit", --- -> "--dataset_name", +> "--seed", +> default=27, +> type=int, +> help="Seed to use for random generation. 
Useful to reproduce your runs with `--do_sample`.", +> ) +> parser.add_argument( +> "--profiling_warmup_steps", +> default=0, +> type=int, +> help="Number of steps to ignore for profling.", +> ) +> parser.add_argument( +> "--profiling_steps", +> default=0, +> type=int, +> help="Number of steps to capture for profiling.", +> ) +> parser.add_argument( +> "--prompt", > default=None, > type=str, -> help="Optional argument if you want to assess your model on a given dataset of the HF Hub.", -328c75,78 -< "--jit", type=bool, default=False, help="Whether or not to use jit trace to accelerate inference" +> help="Optional argument to give a prompt of your choice as input.", +> ) +> parser.add_argument( +> "--bad_words", +> default=None, +> type=str, +> nargs="+", +> help="Optional argument list of words that are not allowed to be generated.", +337c129,137 +< parser.add_argument("--jit", action="store_true", help="Whether or not to use jit trace to accelerate inference") --- -> "--column_name", +> parser.add_argument( +> "--force_words", > default=None, > type=str, -> help="Optional argument if you want to assess your model on a given dataset of the HF Hub, this will be the name of the column to use as prompts for generation.", -330,333d79 -< args = parser.parse_args() -< +> nargs="+", +> help="Optional argument list of words that must be generated.", +> ) +> parser.add_argument("--num_return_sequences", type=int, default=1) +> +340,341c140,174 < args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") < args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count() -335,350c81 -< logger.warning(f"device: {args.device}, n_gpu: {args.n_gpu}, 16-bits training: {args.fp16}") -< -< set_seed(args) -< -< # Initialize the model and tokenizer -< try: -< args.model_type = args.model_type.lower() -< model_class, tokenizer_class = MODEL_CLASSES[args.model_type] -< except KeyError: -< raise KeyError("the model {} you specified is not supported. You are welcome to add it and open a PR :)") -< -< tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path) -< if tokenizer.pad_token is None: -< tokenizer.pad_token = tokenizer.eos_token -< model = model_class.from_pretrained(args.model_name_or_path) -< model.to(args.device) ---- -> args = parser.parse_args() -352,353c83,110 -< if args.fp16: -< model.half() --- > # If the DeepSpeed launcher is used, the env variable _ will be equal to /usr/local/bin/deepspeed > # For multi node, the value of the env variable WORLD_SIZE should be larger than 8 @@ -388,192 +433,176 @@ > > if not is_deepspeed_available(): > raise ImportError( -> "This script requires deepspeed: `pip install" " git+https://github.com/HabanaAI/DeepSpeed.git@1.9.0`." +> "This script requires deepspeed: `pip install" +> " git+https://github.com/HabanaAI/DeepSpeed.git@1.10.0`." > ) > import deepspeed -355,356c112,122 -< args.length = adjust_length_to_model(args.length, max_sequence_length=model.config.max_position_embeddings) -< logger.info(args) ---- +> > # Initialize process(es) for DeepSpeed > deepspeed.init_distributed(dist_backend="hccl") > logger.info("DeepSpeed is enabled.") > else: -> if args.gaudi_config_name_or_path is None: -> gaudi_config = None -> logger.warning( -> "`--gaudi_config_name_or_path` was not specified so not using Habana Mixed Precision for this run." 
-> ) -> else: -> from optimum.habana import GaudiConfig -358c124 -< prompt_text = args.prompt if args.prompt else input("Model prompt >>> ") +> logger.info("Single-device run.") +343c176,177 +< logger.warning(f"device: {args.device}, n_gpu: {args.n_gpu}, 16-bits training: {args.fp16}") +--- +> # Tweak generation so that it runs faster on Gaudi +> from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi +345c179 +< set_seed(args) +--- +> adapt_transformers_to_gaudi() +347,352c181,182 +< # Initialize the model and tokenizer +< try: +< args.model_type = args.model_type.lower() +< model_class, tokenizer_class = MODEL_CLASSES[args.model_type] +< except KeyError: +< raise KeyError("the model {} you specified is not supported. You are welcome to add it and open a PR :)") --- -> gaudi_config = GaudiConfig.from_pretrained(args.gaudi_config_name_or_path) -360,364c126,127 +> # Set seed before initializing model. +> from optimum.habana.utils import set_seed +354,358c184 +< tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path) +< if tokenizer.pad_token is None: +< tokenizer.pad_token = tokenizer.eos_token +< model = model_class.from_pretrained(args.model_name_or_path) +< model.to(args.device) +--- +> set_seed(args.seed) +360,372c186,189 +< if args.fp16: +< model.half() +< max_seq_length = getattr(model.config, "max_position_embeddings", 0) +< args.length = adjust_length_to_model(args.length, max_sequence_length=max_seq_length) +< logger.info(args) +< +< prompt_text = args.prompt if args.prompt else input("Model prompt >>> ") +< < # Different models need different input formatting and/or extra arguments < requires_preprocessing = args.model_type in PREPROCESSING_FUNCTIONS.keys() < if requires_preprocessing: < prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type) < preprocessed_prompt_text = prepare_input(args, model, tokenizer, prompt_text) --- -> if gaudi_config.use_habana_mixed_precision: -> from habana_frameworks.torch.hpex import hmp -366,367c129,160 +> if args.bad_words is not None or args.force_words is not None: +> tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, add_prefix_space=True) +> else: +> tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path) +374,375c191,204 < if model.__class__.__name__ in ["TransfoXLLMHeadModel"]: < tokenizer_kwargs = {"add_space_before_punct_symbol": True} --- -> # Open temporary files to mixed-precision write ops -> with tempfile.NamedTemporaryFile() as hmp_bf16_file: -> with tempfile.NamedTemporaryFile() as hmp_fp32_file: -> # hmp.convert needs ops to be written in text files -> gaudi_config.write_bf16_fp32_ops_to_text_files( -> hmp_bf16_file.name, -> hmp_fp32_file.name, -> ) -> hmp.convert( -> opt_level=gaudi_config.hmp_opt_level, -> bf16_file_path=hmp_bf16_file.name, -> fp32_file_path=hmp_fp32_file.name, -> isVerbose=gaudi_config.hmp_is_verbose, -> ) -> logger.info("Single-device run.") -> -> # Tweak generation so that it runs faster on Gaudi -> from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi -> -> adapt_transformers_to_gaudi() -> -> tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path) +> if use_deepspeed or args.bf16: +> model_dtype = torch.bfloat16 +> else: +> model_dtype = torch.float > > if use_deepspeed: > config = AutoConfig.from_pretrained(args.model_name_or_path) -> args.dtype = torch.bfloat16 +> is_optimized = model_is_optimized(config) > is_bloom = model_is_bloom(config) > > if is_bloom: > # Construct model with fake meta 
tensors, later will be replaced on devices during ds-inference ckpt load -> with deepspeed.OnDevice(dtype=args.dtype, device="meta"): -> model = AutoModelForCausalLM.from_config(config, torch_dtype=args.dtype) -369,373c162,187 +> with deepspeed.OnDevice(dtype=model_dtype, device="meta"): +> model = AutoModelForCausalLM.from_config(config, torch_dtype=model_dtype) +377c206,227 < tokenizer_kwargs = {} -< -< encoded_prompt = tokenizer.encode( -< preprocessed_prompt_text, add_special_tokens=False, return_tensors="pt", **tokenizer_kwargs -< ) --- -> with deepspeed.OnDevice(dtype=args.dtype, device=args.device): -> model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, torch_dtype=args.dtype) +> get_repo_root(args.model_name_or_path, args.local_rank) +> # TODO: revisit placement on CPU when auto-injection is possible +> with deepspeed.OnDevice(dtype=model_dtype, device="cpu"): +> model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, torch_dtype=model_dtype) > model = model.eval() > > # Initialize the model -> ds_inference_kwargs = {"dtype": args.dtype} +> ds_inference_kwargs = {"dtype": model_dtype} > ds_inference_kwargs["tensor_parallel"] = {"tp_size": world_size} > ds_inference_kwargs["enable_cuda_graph"] = args.use_hpu_graphs > -> # BLOOM is managed differently > if is_bloom: +> # BLOOM is managed differently > checkpoints_json = "checkpoints.json" > write_checkpoints_json(args.model_name_or_path, args.local_rank, checkpoints_json) > -> # Make sure all devices/nodes have access to the model checkpoints -> torch.distributed.barrier() -> -> from transformers.models.bloom.modeling_bloom import BloomBlock +> # Make sure all devices/nodes have access to the model checkpoints +> torch.distributed.barrier() > -> ds_inference_kwargs["injection_policy"] = {BloomBlock: ("self_attention.dense", "mlp.dense_4h_to_h")} +> ds_inference_kwargs["injection_policy"] = get_ds_injection_policy(config) +> if is_bloom: > ds_inference_kwargs["checkpoint"] = checkpoints_json -> +379,381c229,230 +< encoded_prompt = tokenizer.encode( +< preprocessed_prompt_text, add_special_tokens=False, return_tensors="pt", **tokenizer_kwargs +< ) +--- > model = deepspeed.init_inference(model, **ds_inference_kwargs) -> if is_bloom: -> model.module.split_lm_head() > model = model.module -375,377c189,191 +383,385c232,247 < prefix = args.prefix if args.prefix else args.padding_text < encoded_prompt = tokenizer.encode(prefix + prompt_text, add_special_tokens=False, return_tensors="pt") < encoded_prompt = encoded_prompt.to(args.device) --- -> model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path) +> get_repo_root(args.model_name_or_path) +> model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, torch_dtype=model_dtype) > model = model.eval().to(args.device) -> is_bloom = model_is_bloom(model.config) -379,382c193,194 -< if encoded_prompt.size()[-1] == 0: -< input_ids = None -< else: -< input_ids = encoded_prompt ---- +> is_optimized = model_is_optimized(model.config) +> > if args.use_hpu_graphs: > from habana_frameworks.torch.hpu import wrap_in_hpu_graph -384,405c196 -< if args.jit: -< jit_input_texts = ["jit"] -< jit_inputs = prepare_jit_inputs(jit_input_texts, model, tokenizer) -< torch._C._jit_set_texpr_fuser_enabled(False) -< model.config.return_dict = False -< traced_model = torch.jit.trace(model, jit_inputs, strict=False) -< traced_model = torch.jit.freeze(traced_model.eval()) -< traced_model(*jit_inputs) -< traced_model(*jit_inputs) -< -< model = 
_ModelFallbackWrapper(traced_model, model) -< -< output_sequences = model.generate( -< input_ids=input_ids, -< max_length=args.length + len(encoded_prompt[0]), -< temperature=args.temperature, -< top_k=args.k, -< top_p=args.p, -< repetition_penalty=args.repetition_penalty, -< do_sample=True, -< num_return_sequences=args.num_return_sequences, -< ) ---- +> > model = wrap_in_hpu_graph(model) -407,409c198,201 -< # Remove the batch dimension when returning multiple sequences -< if len(output_sequences.shape) > 2: -< output_sequences.squeeze_() ---- +> +> if not model.config.is_encoder_decoder: +> tokenizer.padding_side = "left" > # Some models like GPT2 do not have a PAD token so we have to set it if necessary > if tokenizer.pad_token is None: > tokenizer.pad_token = tokenizer.eos_token > model.generation_config.pad_token_id = model.generation_config.eos_token_id -411c203,204 -< generated_sequences = [] +387,388c249,367 +< if encoded_prompt.size()[-1] == 0: +< input_ids = None --- > if rank in [-1, 0]: > logger.info(f"Args: {args}") -413,415c206,215 -< for generated_sequence_idx, generated_sequence in enumerate(output_sequences): -< print(f"=== GENERATED SEQUENCE {generated_sequence_idx + 1} ===") -< generated_sequence = generated_sequence.tolist() ---- -> use_bf16 = False -> if use_deepspeed or (gaudi_config is not None and gaudi_config.use_habana_mixed_precision): -> use_bf16 = True -> logger.info(f"device: {args.device}, n_hpu: {world_size}, bf16: {use_bf16}") +> logger.info(f"device: {args.device}, n_hpu: {world_size}, bf16: {use_deepspeed or args.bf16}") +> +> bad_words_ids = None +> force_words_ids = None +> if args.bad_words is not None: +> bad_words_ids = [tokenizer.encode(bad_word, add_special_tokens=False) for bad_word in args.bad_words] +> if args.force_words is not None: +> force_words_ids = [tokenizer.encode(force_word, add_special_tokens=False) for force_word in args.force_words] > > # Generation configuration -> generation_config = GenerationConfig( -> max_new_tokens=args.max_new_tokens, -> use_cache=args.use_kv_cache, -> ) -417,418c217,320 -< # Decode text -< text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True) ---- +> generation_config = copy.deepcopy(model.generation_config) +> generation_config.max_new_tokens = args.max_new_tokens +> generation_config.use_cache = args.use_kv_cache +> generation_config.static_shapes = is_optimized +> generation_config.do_sample = args.do_sample +> generation_config.num_beams = args.num_beams +> generation_config.bad_words_ids = bad_words_ids +> generation_config.force_words_ids = force_words_ids +> generation_config.num_return_sequences = args.num_return_sequences +> > if args.dataset_name is None: > # Benchmark over the prompts below -> input_sentences = [ -> "DeepSpeed is a machine learning framework", -> "He is working on", -> "He has a", -> "He got all", -> "Everyone is happy and I can", -> "The new movie that got Oscar this year", -> "In the far far distance from our galaxy,", -> "Peace is the only way", -> ] +> if args.prompt: +> input_sentences = [ +> args.prompt, +> ] +> else: +> input_sentences = [ +> "DeepSpeed is a machine learning framework", +> "He is working on", +> "He has a", +> "He got all", +> "Everyone is happy and I can", +> "The new movie that got Oscar this year", +> "In the far far distance from our galaxy,", +> "Peace is the only way", +> ] > > if args.batch_size > len(input_sentences): > # Dynamically extends to support larger batch sizes @@ -589,18 +618,6 @@ > # Tokenization > input_tokens = 
tokenizer.batch_encode_plus(input_sentences, return_tensors="pt", padding=True) > -> # Pad inputs to have static shapes during generation, this gives better performance than dynamic shapes on HPUs -> input_token_len = input_tokens.input_ids.shape[-1] -> input_tokens["input_ids"] = F.pad( -> input_tokens.input_ids, (0, args.max_new_tokens), value=model.config.pad_token_id -> ) -> input_tokens["attention_mask"] = F.pad(input_tokens.attention_mask, (0, args.max_new_tokens), value=0) -> if is_bloom: -> # token_idx is the current index in the generation process, it is incremented each time a new token is generated -> kwargs = {"token_idx": torch.tensor(input_token_len, device=args.device)} -> else: -> kwargs = {} -> > # Move inputs to target device(s) > for t in input_tokens: > if torch.is_tensor(input_tokens[t]): @@ -608,13 +625,18 @@ > > outputs = model.generate( > **input_tokens, -> **kwargs, > generation_config=generation_config, > lazy_mode=True, > hpu_graphs=args.use_hpu_graphs, +> profiling_steps=args.profiling_steps, +> profiling_warmup_steps=args.profiling_warmup_steps, > ).cpu() > return tokenizer.batch_decode(outputs, skip_special_tokens=True) > +> from optimum.habana.utils import HabanaProfile +> +> # compilation stage disable profiling +> HabanaProfile.disable() > # Compilation > if rank in [-1, 0]: > logger.info("Graph compilation...") @@ -624,7 +646,7 @@ > generate() > torch_hpu.synchronize() > compilation_duration = time.perf_counter() - t0 -> +> HabanaProfile.enable() > total_new_tokens_generated = 0 > if rank in [-1, 0]: > logger.info("Running generate...") @@ -637,23 +659,33 @@ > throughput = total_new_tokens_generated / duration > > if rank in [-1, 0]: +> from optimum.habana.utils import get_hpu_memory_stats +> > stats = f"Throughput (including tokenization) = {throughput} tokens/second" > separator = "-" * len(stats) > print() > print("Stats:") > print(separator) > print(stats) +> mem = get_hpu_memory_stats() +> for k, v in mem.items(): +> print("{:35} = {} GB".format(k[:-5].replace("_", " ").capitalize(), v)) > if args.use_hpu_graphs: -> print(f"Graph compilation duration = {compilation_duration} seconds") +> print(f"Graph compilation duration = {compilation_duration} seconds") > print(separator) > print() > print("Input/outputs:") > print(separator) -> for i, (input_sentence, output) in enumerate(zip(input_sentences, generated)): +> for i, input_sentence in enumerate(zip(input_sentences)): > print(f"input {i+1}: {input_sentence}") -> print(f"output {i+1}: {output}") +> for j, output in enumerate( +> zip(generated[args.num_return_sequences * i : args.num_return_sequences * (i + 1)]) +> ): +> print(f"output {j+1}: {output}") > print(separator) -> else: +390c369,380 +< input_ids = encoded_prompt +--- > # Downloading and loading a dataset from the hub. 
> from datasets import load_dataset > from torch.utils.data import DataLoader @@ -666,9 +698,14 @@ > else: > split = "train" > raw_dataset = raw_dataset[split] -420,421c322,332 -< # Remove all text after the stop token -< text = text[: text.find(args.stop_token) if args.stop_token else None] +392,398c382,390 +< if args.jit: +< jit_input_texts = ["enable jit"] +< jit_inputs = prepare_jit_inputs(jit_input_texts, model, tokenizer) +< torch._C._jit_set_texpr_fuser_enabled(False) +< model.config.return_dict = False +< if hasattr(model, "forward"): +< sig = inspect.signature(model.forward) --- > if args.column_name is None: > # If no column name is given, take the first column that has strings @@ -679,25 +716,61 @@ > logger.info( > f"No column name was given so automatically choosing '{column_name}' for prompts. If you would like to use another column of the dataset, you can set the argument `--column_name`." > ) -> else: -> column_name = args.column_name -423,426c334,335 +400,439c392 +< sig = inspect.signature(model.__call__) +< jit_inputs = tuple(jit_inputs[key] for key in sig.parameters if jit_inputs.get(key, None) is not None) +< traced_model = torch.jit.trace(model, jit_inputs, strict=False) +< traced_model = torch.jit.freeze(traced_model.eval()) +< traced_model(*jit_inputs) +< traced_model(*jit_inputs) +< +< model = _ModelFallbackWrapper(traced_model, model) +< +< output_sequences = model.generate( +< input_ids=input_ids, +< max_length=args.length + len(encoded_prompt[0]), +< temperature=args.temperature, +< top_k=args.k, +< top_p=args.p, +< repetition_penalty=args.repetition_penalty, +< do_sample=True, +< num_return_sequences=args.num_return_sequences, +< ) +< +< # Remove the batch dimension when returning multiple sequences +< if len(output_sequences.shape) > 2: +< output_sequences.squeeze_() +< +< generated_sequences = [] +< +< for generated_sequence_idx, generated_sequence in enumerate(output_sequences): +< print(f"=== GENERATED SEQUENCE {generated_sequence_idx + 1} ===") +< generated_sequence = generated_sequence.tolist() +< +< # Decode text +< text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True) +< +< # Remove all text after the stop token +< text = text[: text.find(args.stop_token) if args.stop_token else None] +< < # Add the prompt at the beginning of the sequence. 
Remove the excess text that was used for pre-processing < total_sequence = ( < prompt_text + text[len(tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)) :] < ) --- -> # Remove unused columns -> raw_dataset = raw_dataset.remove_columns([name for name in raw_dataset.column_names if name != column_name]) -428,429c337,338 +> column_name = args.column_name +441,442c394,395 < generated_sequences.append(total_sequence) < print(total_sequence) --- -> # Set the prompt length to 16 -> prompt_length = 16 -431c340,387 +> # Remove unused columns +> raw_dataset = raw_dataset.remove_columns([name for name in raw_dataset.column_names if name != column_name]) +444c397,445 < return generated_sequences --- +> # Set the prompt length to 16 +> prompt_length = 16 +> > def preprocess_function(examples): > # Tokenize the texts > return tokenizer(examples[column_name], padding="max_length", max_length=prompt_length, truncation=True) @@ -718,24 +791,20 @@ > dataloader = DataLoader(raw_dataset, batch_size=args.batch_size) > for i, batch in enumerate(dataloader): > prompt = tokenizer.batch_decode(batch["input_ids"], skip_special_tokens=True) -> # Pad inputs to have static shapes during generation, this gives better performance than dynamic shapes on HPUs -> batch["input_ids"] = F.pad(batch["input_ids"], (0, args.max_new_tokens), value=model.config.pad_token_id) -> batch["attention_mask"] = F.pad(batch["attention_mask"], (0, args.max_new_tokens), value=0) -> # prompt = batch.pop(column_name) +> > # Move inputs to target device(s) > for t in batch: > if torch.is_tensor(batch[t]): > batch[t] = batch[t].to(args.device) -> if is_bloom: -> # token_idx is the current index in the generation process, it is incremented each time a new token is generated -> batch["token_idx"] = torch.tensor(prompt_length, device=args.device) > > # Generate new sequences > outputs = model.generate( > **batch, > generation_config=generation_config, -> lazy_mode=args.use_hpu_graphs, +> lazy_mode=True, > hpu_graphs=args.use_hpu_graphs, +> profiling_steps=args.profiling_steps, +> profiling_warmup_steps=args.profiling_warmup_steps, > ).cpu() > > # Print outputs @@ -745,4 +814,6 @@ > print(separator) > print(f"Batch n°{i+1}") > print(f"Input: {prompt[:args.batch_size]}") -> print(f"Output: {tokenizer.batch_decode(outputs, skip_special_tokens=True)[:args.batch_size]}") +> print( +> f"Output: {tokenizer.batch_decode(outputs, skip_special_tokens=True)[:args.batch_size*args.num_return_sequences]}" +> ) diff --git a/tests/example_diff/run_glue.txt b/tests/example_diff/run_glue.txt index f43e2e056f..7b7167bf7c 100644 --- a/tests/example_diff/run_glue.txt +++ b/tests/example_diff/run_glue.txt @@ -13,9 +13,9 @@ > from optimum.habana.utils import set_seed > 51c50 -< check_min_version("4.29.0.dev0") +< check_min_version("4.32.0.dev0") --- -> check_min_version("4.28.0") +> check_min_version("4.31.0") 211c210 < parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) --- @@ -28,20 +28,22 @@ > use_auth_token=True if model_args.use_auth_token else None, > ) > -243,244c249,251 +241a248 +> mixed_precision = training_args.bf16 or gaudi_config.use_torch_autocast or gaudi_config.use_habana_mixed_precision +243,244c250,252 < f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" -< + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" +< + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 
16-bits training: {training_args.fp16}" --- > f"Process rank: {training_args.local_rank}, device: {training_args.device}, " > + f"distributed training: {bool(training_args.local_rank != -1)}, " -> + f"mixed-precision training: {gaudi_config.use_habana_mixed_precision}" -514c521 +> + f"mixed-precision training: {mixed_precision}" +514c522 < trainer = Trainer( --- > trainer = GaudiTrainer( -515a523 +515a524 > gaudi_config=gaudi_config, -615,619d622 +615,619d623 < < < def _mp_fn(index): diff --git a/tests/example_diff/run_image_classification.txt b/tests/example_diff/run_image_classification.txt index a8ebe18883..9ad0ea959f 100644 --- a/tests/example_diff/run_image_classification.txt +++ b/tests/example_diff/run_image_classification.txt @@ -1,25 +1,27 @@ -24a25 +14a15 +> # limitations under the License. +24a26 > import transformers -36,37d36 +36,37d37 < < import transformers -44,46d42 +44,46d43 < Trainer, < TrainingArguments, < set_seed, -51a48,50 +51a49,51 > from optimum.habana import GaudiConfig, GaudiTrainer, GaudiTrainingArguments > from optimum.habana.utils import set_seed > -58c57 -< check_min_version("4.29.0.dev0") +58c58 +< check_min_version("4.32.0.dev0") --- -> check_min_version("4.28.0") -171c170 +> check_min_version("4.31.0") +171c171 < parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) --- > parser = HfArgumentParser((ModelArguments, DataTrainingArguments, GaudiTrainingArguments)) -199a199,205 +199a200,206 > gaudi_config = GaudiConfig.from_pretrained( > training_args.gaudi_config_name, > cache_dir=model_args.cache_dir, @@ -27,16 +29,18 @@ > use_auth_token=True if model_args.use_auth_token else None, > ) > -202,203c208,210 +200a208 +> mixed_precision = training_args.bf16 or gaudi_config.use_torch_autocast or gaudi_config.use_habana_mixed_precision +202,203c210,212 < f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" -< + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" +< + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}" --- > f"Process rank: {training_args.local_rank}, device: {training_args.device}, " > + f"distributed training: {bool(training_args.local_rank != -1)}, " -> + f"mixed-precision training: {gaudi_config.use_habana_mixed_precision}" -353c360 +> + f"mixed-precision training: {mixed_precision}" +353c362 < trainer = Trainer( --- > trainer = GaudiTrainer( -354a362 +354a364 > gaudi_config=gaudi_config, diff --git a/tests/example_diff/run_mlm.txt b/tests/example_diff/run_mlm.txt index 76b6316ec4..ffcb53b4d3 100644 --- a/tests/example_diff/run_mlm.txt +++ b/tests/example_diff/run_mlm.txt @@ -20,9 +20,9 @@ > from optimum.habana.utils import set_seed > 56c53 -< check_min_version("4.29.0.dev0") +< check_min_version("4.32.0.dev0") --- -> check_min_version("4.28.0") +> check_min_version("4.31.0") 209c206 < streaming: bool = field(default=False, metadata={"help": "Enable streaming mode"}) --- @@ -39,22 +39,24 @@ > use_auth_token=True if model_args.use_auth_token else None, > ) > -265,266c269,271 +263a268 +> mixed_precision = training_args.bf16 or gaudi_config.use_torch_autocast or gaudi_config.use_habana_mixed_precision +265,266c270,272 < f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" -< + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" +< + f"distributed 
training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}" --- > f"Process rank: {training_args.local_rank}, device: {training_args.device}, " > + f"distributed training: {bool(training_args.local_rank != -1)}, " -> + f"mixed-precision training: {gaudi_config.use_habana_mixed_precision}" -268d272 +> + f"mixed-precision training: {mixed_precision}" +268d273 < # Set the verbosity to info of the Transformers logger (on main process only): -588c592 +587c592 < trainer = Trainer( --- > trainer = GaudiTrainer( -589a594 +588a594 > gaudi_config=gaudi_config, -595,598c600,601 +594,597c600,601 < compute_metrics=compute_metrics if training_args.do_eval and not is_torch_tpu_available() else None, < preprocess_logits_for_metrics=preprocess_logits_for_metrics < if training_args.do_eval and not is_torch_tpu_available() @@ -62,7 +64,7 @@ --- > compute_metrics=compute_metrics if training_args.do_eval else None, > preprocess_logits_for_metrics=preprocess_logits_for_metrics if training_args.do_eval else None, -651,655d653 +650,654d653 < < < def _mp_fn(index): diff --git a/tests/example_diff/run_qa.txt b/tests/example_diff/run_qa.txt index 7ed36f88c1..3175b8f217 100644 --- a/tests/example_diff/run_qa.txt +++ b/tests/example_diff/run_qa.txt @@ -18,9 +18,9 @@ > from optimum.habana import GaudiConfig, GaudiTrainingArguments > from optimum.habana.utils import set_seed 52c52 -< check_min_version("4.29.0.dev0") +< check_min_version("4.32.0.dev0") --- -> check_min_version("4.28.0") +> check_min_version("4.31.0") 135c135 < " batching to the maximum length in the batch (which can be faster on GPU but will be slower on TPU)." --- @@ -37,16 +37,18 @@ > use_auth_token=True if model_args.use_auth_token else None, > ) > -254,255c261,263 +252a260 +> mixed_precision = training_args.bf16 or gaudi_config.use_torch_autocast or gaudi_config.use_habana_mixed_precision +254,255c262,264 < f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" -< + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" +< + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}" --- > f"Process rank: {training_args.local_rank}, device: {training_args.device}, " > + f"distributed training: {bool(training_args.local_rank != -1)}, " -> + f"mixed-precision training: {gaudi_config.use_habana_mixed_precision}" -608a617 +> + f"mixed-precision training: {mixed_precision}" +608a618 > gaudi_config=gaudi_config, -677,681d685 +677,681d686 < < < def _mp_fn(index): diff --git a/tests/example_diff/run_seq2seq_qa.txt b/tests/example_diff/run_seq2seq_qa.txt index 79597f7b91..e3b9b3f498 100644 --- a/tests/example_diff/run_seq2seq_qa.txt +++ b/tests/example_diff/run_seq2seq_qa.txt @@ -11,9 +11,9 @@ > from optimum.habana.utils import set_seed > 49c49 -< check_min_version("4.29.0.dev0") +< check_min_version("4.32.0.dev0") --- -> check_min_version("4.28.0") +> check_min_version("4.31.0") 168c168 < " batching to the maximum length in the batch (which can be faster on GPU but will be slower on TPU)." 
--- @@ -30,16 +30,18 @@ > use_auth_token=True if model_args.use_auth_token else None, > ) > -300,301c307,309 +298a306 +> mixed_precision = training_args.bf16 or gaudi_config.use_torch_autocast or gaudi_config.use_habana_mixed_precision +300,301c308,310 < f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" -< + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" +< + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}" --- > f"Process rank: {training_args.local_rank}, device: {training_args.device}, " > + f"distributed training: {bool(training_args.local_rank != -1)}, " -> + f"mixed-precision training: {gaudi_config.use_habana_mixed_precision}" -645a654 +> + f"mixed-precision training: {mixed_precision}" +645a655 > gaudi_config=gaudi_config, -719,723d727 +719,723d728 < < < def _mp_fn(index): diff --git a/tests/example_diff/run_speech_recognition_ctc.txt b/tests/example_diff/run_speech_recognition_ctc.txt index 86fc91e9b2..234468eebd 100644 --- a/tests/example_diff/run_speech_recognition_ctc.txt +++ b/tests/example_diff/run_speech_recognition_ctc.txt @@ -13,9 +13,9 @@ > from optimum.habana.utils import set_seed > 54c53 -< check_min_version("4.29.0.dev0") +< check_min_version("4.32.0.dev0") --- -> check_min_version("4.28.0") +> check_min_version("4.31.0") 141d139 < 374c372 @@ -29,14 +29,16 @@ > use_auth_token=True if data_args.use_auth_token else None, > ) > -411,412c415,417 +409a414 +> mixed_precision = training_args.bf16 or gaudi_config.use_torch_autocast or gaudi_config.use_habana_mixed_precision +411,412c416,418 < f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" -< f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" +< f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}" --- > f"Process rank: {training_args.local_rank}, device: {training_args.device}, " > + f"distributed training: {bool(training_args.local_rank != -1)}, " -> + f"mixed-precision training: {gaudi_config.use_habana_mixed_precision}" -425,431c430,435 +> + f"mixed-precision training: {mixed_precision}" +425,431c431,436 < if training_args.do_train: < raw_datasets["train"] = load_dataset( < data_args.dataset_name, @@ -51,7 +53,7 @@ > split=data_args.train_split_name, > use_auth_token=data_args.use_auth_token, > ) -433,438c437,442 +433,438c438,443 < if data_args.audio_column_name not in raw_datasets["train"].column_names: < raise ValueError( < f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'." @@ -65,7 +67,7 @@ > " Make sure to set `--audio_column_name` to the correct audio column - one of" > f" {', '.join(raw_datasets['train'].column_names)}." > ) -440,445c444,449 +440,445c445,450 < if data_args.text_column_name not in raw_datasets["train"].column_names: < raise ValueError( < f"--text_column_name {data_args.text_column_name} not found in dataset '{data_args.dataset_name}'. " @@ -79,23 +81,23 @@ > "Make sure to set `--text_column_name` to the correct text column - one of " > f"{', '.join(raw_datasets['train'].column_names)}." 
> ) -447,448c451,452 +447,448c452,453 < if data_args.max_train_samples is not None: < raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples)) --- > if data_args.max_train_samples is not None: > raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples)) -466c470 +466c471 < f'[{"".join(data_args.chars_to_ignore)}]' if data_args.chars_to_ignore is not None else None --- > f'[{"".join(data_args.chars_to_ignore).replace(" ", "")}]' if data_args.chars_to_ignore is not None else None -595a600,604 +595a601,605 > raise RuntimeError( > f"The dataset sampling rate ({dataset_sampling_rate}) is different from the feature extractor one" > f" ({feature_extractor.sampling_rate}).Data resampling should be done. The Datasets library does not" > " support it on HPUs yet." > ) -698c707,711 +698c708,712 < data_collator = DataCollatorCTCWithPadding(processor=processor) --- > data_collator = DataCollatorCTCWithPadding( @@ -103,9 +105,9 @@ > pad_to_multiple_of=int(max_input_length), > pad_to_multiple_of_labels=500, > ) -701c714 +701c715 < trainer = Trainer( --- > trainer = GaudiTrainer( -702a716 +702a717 > gaudi_config=gaudi_config, diff --git a/tests/example_diff/run_summarization.txt b/tests/example_diff/run_summarization.txt index 3c4f81d7f1..2038783639 100644 --- a/tests/example_diff/run_summarization.txt +++ b/tests/example_diff/run_summarization.txt @@ -2,26 +2,28 @@ < # Copyright 2021 The HuggingFace Team. All rights reserved. --- > # Copyright 2022 The HuggingFace Team. All rights reserved. -30a31 +20a21 +> import copy +30a32 > import transformers -33,34d33 +33,34d34 < < import transformers -45,47c44 +45,47c45 < Seq2SeqTrainer, < Seq2SeqTrainingArguments, < set_seed, --- > default_data_collator, -52a50,52 +52a51,53 > from optimum.habana import GaudiConfig, GaudiSeq2SeqTrainer, GaudiSeq2SeqTrainingArguments > from optimum.habana.utils import set_seed > -55c55 -< check_min_version("4.29.0.dev0") +55c56 +< check_min_version("4.32.0.dev0") --- -> check_min_version("4.28.0") -119a120,128 +> check_min_version("4.31.0") +119a121,129 > use_cache: bool = field( > default=True, > metadata={ @@ -31,15 +33,17 @@ > ) > }, > ) -203c212 +203c213 < "efficient on GPU but very bad for TPU." --- > "efficient on GPU but very bad for HPU in lazy mode." 
-307c316 +251a262 +> source_suffix: Optional[str] = field(default="", metadata={"help": "A suffix to add after every source text."}) +307c318 < parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments)) --- > parser = HfArgumentParser((ModelArguments, DataTrainingArguments, GaudiSeq2SeqTrainingArguments)) -336a346,352 +336a348,354 > gaudi_config = GaudiConfig.from_pretrained( > training_args.gaudi_config_name, > cache_dir=model_args.cache_dir, @@ -47,16 +51,27 @@ > use_auth_token=True if model_args.use_auth_token else None, > ) > -339,340c355,357 +337a356 +> mixed_precision = training_args.bf16 or gaudi_config.use_torch_autocast or gaudi_config.use_habana_mixed_precision +339,340c358,360 < f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" -< + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" +< + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}" --- > f"Process rank: {training_args.local_rank}, device: {training_args.device}, " > + f"distributed training: {bool(training_args.local_rank != -1)}, " -> + f"mixed-precision training: {gaudi_config.use_habana_mixed_precision}" -420a438 +> + f"mixed-precision training: {mixed_precision}" +420a441 > use_cache=False if training_args.gradient_checkpointing else model_args.use_cache, -611,616c629,637 +473a495 +> suffix = data_args.source_suffix if data_args.source_suffix is not None else "" +544a567,568 +> else: +> raise ValueError("Found case where either text or summary is missing.") +546c570 +< inputs = [prefix + inp for inp in inputs] +--- +> inputs = [prefix + inp + suffix for inp in inputs] +611,616c635,643 < data_collator = DataCollatorForSeq2Seq( < tokenizer, < model=model, @@ -73,13 +88,42 @@ > label_pad_token_id=label_pad_token_id, > pad_to_multiple_of=8 if training_args.fp16 else None, > ) -661c682 +651,658c678,686 +< training_args.generation_max_length = ( +< training_args.generation_max_length +< if training_args.generation_max_length is not None +< else data_args.val_max_target_length +< ) +< training_args.generation_num_beams = ( +< data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams +< ) +--- +> training_args.generation_config = copy.deepcopy(model.generation_config) +> if training_args.generation_max_length is not None: +> training_args.generation_config.max_length = training_args.generation_max_length +> else: +> training_args.generation_config.max_length = data_args.val_max_target_length +> if data_args.num_beams is not None: +> training_args.generation_config.num_beams = data_args.num_beams +> elif training_args.generation_num_beams is not None: +> training_args.generation_config.num_beams = training_args.generation_num_beams +661c689 < trainer = Seq2SeqTrainer( --- > trainer = GaudiSeq2SeqTrainer( -662a684 +662a691 > gaudi_config=gaudi_config, -745,749d766 +695,701c724 +< if isinstance(eval_dataset, dict): +< metrics = {} +< for eval_ds_name, eval_ds in eval_dataset.items(): +< dataset_metrics = trainer.evaluate(eval_dataset=eval_ds, metric_key_prefix=f"eval_{eval_ds_name}") +< metrics.update(dataset_metrics) +< else: +< metrics = trainer.evaluate(metric_key_prefix="eval") +--- +> metrics = trainer.evaluate(metric_key_prefix="eval") +751,755d773 < < < def _mp_fn(index): diff --git a/tests/example_diff/run_translation.txt b/tests/example_diff/run_translation.txt index 
b463c1908a..5bbf43ea9f 100644 --- a/tests/example_diff/run_translation.txt +++ b/tests/example_diff/run_translation.txt @@ -13,9 +13,9 @@ > from optimum.habana.utils import set_seed > 55c54 -< check_min_version("4.29.0.dev0") +< check_min_version("4.32.0.dev0") --- -> check_min_version("4.28.0") +> check_min_version("4.31.0") 100a100,108 > use_cache: bool = field( > default=True, @@ -42,22 +42,24 @@ > use_auth_token=True if model_args.use_auth_token else None, > ) > -288,289c303,305 +286a302 +> mixed_precision = training_args.bf16 or gaudi_config.use_torch_autocast or gaudi_config.use_habana_mixed_precision +288,289c304,306 < f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" -< + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" +< + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}" --- > f"Process rank: {training_args.local_rank}, device: {training_args.device}, " > + f"distributed training: {bool(training_args.local_rank != -1)}, " -> + f"mixed-precision training: {gaudi_config.use_habana_mixed_precision}" -369a386 +> + f"mixed-precision training: {mixed_precision}" +369a387 > use_cache=False if training_args.gradient_checkpointing else model_args.use_cache, -564c581 +564c582 < trainer = Seq2SeqTrainer( --- > trainer = GaudiSeq2SeqTrainer( -565a583 +565a584 > gaudi_config=gaudi_config, -658,662d675 +658,662d676 < < < def _mp_fn(index): diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 4235a8106f..e4cbd9ce27 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -32,6 +32,7 @@ from parameterized import parameterized from requests.exceptions import HTTPError from transformers import IntervalStrategy, PretrainedConfig, is_torch_available +from transformers.hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS from transformers.testing_utils import ( ENDPOINT_STAGING, TOKEN, @@ -47,7 +48,7 @@ require_tokenizers, require_torch, ) -from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR +from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, HPSearchBackend from transformers.training_args import OptimizerNames from transformers.utils import ( SAFE_WEIGHTS_INDEX_NAME, @@ -547,6 +548,87 @@ def test_custom_optimizer(self): self.assertFalse(torch.allclose(trainer.model.b, b)) self.assertEqual(trainer.optimizer.state_dict()["param_groups"][0]["lr"], 1.0) + def test_reduce_lr_on_plateau_args(self): + # test passed arguments for a custom ReduceLROnPlateau scheduler + train_dataset = RegressionDataset(length=64) + eval_dataset = RegressionDataset(length=64) + gaudi_config = get_gaudi_config() + gaudi_config.use_fused_adam = False + args = GaudiTrainingArguments( + "./regression", + evaluation_strategy="epoch", + metric_for_best_model="eval_loss", + use_habana=True, + use_lazy_mode=True, + ) + model = RegressionModel() + optimizer = torch.optim.SGD(model.parameters(), lr=1.0) + lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.2, patience=5, cooldown=2) + trainer = GaudiTrainer( + model, + gaudi_config, + args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + optimizers=(optimizer, lr_scheduler), + ) + trainer.train() + + self.assertIsInstance(trainer.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau) + self.assertEqual(trainer.lr_scheduler.factor, 0.2) + self.assertEqual(trainer.lr_scheduler.patience, 5) + 
self.assertEqual(trainer.lr_scheduler.cooldown, 2) + + def test_reduce_lr_on_plateau(self): + # test the ReduceLROnPlateau scheduler + + class TrainerWithLRLogs(GaudiTrainer): + def log(self, logs): + # the LR is computed after metrics and does not exist for the first epoch + if hasattr(self.lr_scheduler, "_last_lr"): + logs["learning_rate"] = self.lr_scheduler._last_lr + super().log(logs) + + train_dataset = RegressionDataset(length=64) + eval_dataset = RegressionDataset(length=64) + gaudi_config = get_gaudi_config() + gaudi_config.use_fused_adam = False + + args = GaudiTrainingArguments( + "./regression", + lr_scheduler_type="reduce_lr_on_plateau", + evaluation_strategy="epoch", + metric_for_best_model="eval_loss", + num_train_epochs=10, + learning_rate=0.2, + use_habana=True, + use_lazy_mode=True, + ) + model = RegressionModel() + trainer = TrainerWithLRLogs(model, gaudi_config, args, train_dataset=train_dataset, eval_dataset=eval_dataset) + trainer.train() + + self.assertIsInstance(trainer.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau) + patience = trainer.lr_scheduler.patience + + logs = trainer.state.log_history[1:] + best_loss = logs[0]["eval_loss"] + bad_epochs = 0 + for i, log in enumerate(logs[:-1]): # Compare learning rate to next epoch's + loss = log["eval_loss"] + just_decreased = False + if loss > best_loss: + bad_epochs += 1 + if bad_epochs > patience: + self.assertLess(logs[i + 1]["learning_rate"][0], log["learning_rate"][0]) + just_decreased = True + bad_epochs = 0 + else: + best_loss = loss + bad_epochs = 0 + if not just_decreased: + self.assertEqual(logs[i + 1]["learning_rate"][0], log["learning_rate"][0]) + def test_adafactor_lr_none(self): # test the special case where lr=None, since Trainer can't not have lr_scheduler @@ -710,9 +792,9 @@ def is_any_loss_nan_or_inf(log_history): def test_train_and_eval_dataloaders(self): trainer = get_regression_trainer(learning_rate=0.1, per_device_train_batch_size=16) - self.assertEqual(trainer.get_train_dataloader().batch_size, 16) + self.assertEqual(trainer.get_train_dataloader().total_batch_size, 16) trainer = get_regression_trainer(learning_rate=0.1, per_device_eval_batch_size=16) - self.assertEqual(trainer.get_eval_dataloader().batch_size, 16) + self.assertEqual(trainer.get_eval_dataloader().total_batch_size, 16) # Check drop_last works trainer = get_regression_trainer( @@ -750,70 +832,6 @@ def test_dataloader_without_dataset(self): trainer.train() trainer.evaluate() - def test_sampler_seed(self): - # nb: we don't want to inherit from IterableDataset to hit the right code path - class DummyDataset(torch.utils.data.Dataset): - def __init__(self, length: int = 101): - self.length = length - - def __len__(self): - return self.length - - def __getitem__(self, i): - if (i < 0) or (i >= self.length): - raise IndexError - return {"input_ids": [i]} - - class DummyModel(PreTrainedModel): - def __init__(self, num_params: int): - super().__init__(PretrainedConfig()) - # Add some (unused) params. the point here is that randomness in model_init shouldn't influence - # data loader order. 
- self.params = nn.Parameter(torch.randn(num_params)) - - def forward(self, input_ids, labels=None): - if labels is not None: - return torch.tensor(0.0, device=input_ids.device), input_ids - else: - return input_ids - - def _get_first_data_sample(num_params, seed, data_seed, **kwargs): - with tempfile.TemporaryDirectory() as tmpdir: - trainer = GaudiTrainer( - model_init=lambda: DummyModel(num_params), - gaudi_config=get_gaudi_config(), - args=GaudiTrainingArguments( - output_dir=tmpdir, - **kwargs, - seed=seed, - data_seed=data_seed, - local_rank=-1, - use_habana=True, - use_lazy_mode=True, - ), - train_dataset=DummyDataset(), - ) - - return next(iter(trainer.get_train_dataloader())) - - # test that the seed is passed to the sampler - # the codepath we want to hit is world_size <= 1, and both group_by_length - for group_by_length in [True, False]: - sample42_1 = _get_first_data_sample(num_params=10, seed=42, data_seed=42, group_by_length=group_by_length) - sample42_2 = _get_first_data_sample(num_params=11, seed=42, data_seed=42, group_by_length=group_by_length) - self.assertTrue(torch.equal(sample42_1["input_ids"], sample42_2["input_ids"])) - - # should get same samples with different seed, so long as data_seed is the same - sample42_3 = _get_first_data_sample(num_params=11, seed=11, data_seed=42, group_by_length=group_by_length) - self.assertTrue(torch.equal(sample42_1["input_ids"], sample42_3["input_ids"])) - - # make sure we have some randomness in the samples if data_seed is different - others = [ - _get_first_data_sample(num_params=i, seed=42, data_seed=i, group_by_length=group_by_length) - for i in range(10) - ] - self.assertTrue(any(not torch.equal(sample42_1["input_ids"], sample["input_ids"]) for sample in others)) - def test_data_is_not_parallelized_when_model_is_parallel(self): model = RegressionModel() # Make the Trainer believe it's a parallelized model @@ -835,9 +853,9 @@ def test_data_is_not_parallelized_when_model_is_parallel(self): self.assertEqual(trainer.args.n_gpu, 1) # The batch size of the training and evaluation dataloaders should be 16, not 16 * n_gpu - self.assertEqual(trainer.get_train_dataloader().batch_size, 16) + self.assertEqual(trainer.get_train_dataloader().total_batch_size, 16) self.assertEqual(len(trainer.get_train_dataloader()), 64 // 16) - self.assertEqual(trainer.get_eval_dataloader().batch_size, 16) + self.assertEqual(trainer.get_eval_dataloader().total_batch_size, 16) self.assertEqual(len(trainer.get_eval_dataloader()), 64 // 16) def test_evaluate(self): @@ -1390,29 +1408,6 @@ def test_training_iterable_dataset(self): self.assertIsInstance(loader, torch.utils.data.DataLoader) self.assertIsInstance(loader.sampler, torch.utils.data.dataloader._InfiniteConstantSampler) - def test_training_finite_iterable_dataset(self): - config = RegressionModelConfig() - model = RegressionPreTrainedModel(config) - - batch_size = 1 - num_samples = 10 - - available_steps = num_samples // batch_size - - data = FiniteIterableDataset(length=num_samples) - train_args = GaudiTrainingArguments( - "..", - max_steps=available_steps + 1, # set a higher number than actually available - per_device_train_batch_size=batch_size, - use_habana=True, - use_lazy_mode=True, - ) - gaudi_config = get_gaudi_config() - trainer = GaudiTrainer(model, gaudi_config=gaudi_config, train_dataset=data, args=train_args) - with self.assertLogs("optimum.habana.transformers.trainer", level="WARNING") as logs: - trainer.train() - self.assertIn(f"stopping training at step {available_steps}!", logs.output[0]) 
- def test_evaluation_iterable_dataset(self): config = RegressionModelConfig(a=1.5, b=2.5) model = RegressionPreTrainedModel(config) @@ -2022,3 +2017,11 @@ def test_optim_supported(self, name: str, expected_cls, mandatory_kwargs): # trainer.hyperparameter_search( # direction="minimize", hp_space=hp_space, hp_name=hp_name, backend="wandb", n_trials=4, anonymous="must" # ) + + +class HyperParameterSearchBackendsTest(unittest.TestCase): + def test_hyperparameter_search_backends(self): + self.assertEqual( + list(ALL_HYPERPARAMETER_SEARCH_BACKENDS.keys()), + list(HPSearchBackend), + ) diff --git a/tests/test_trainer_seq2seq.py b/tests/test_trainer_seq2seq.py index a4de241892..455fb4bf95 100644 --- a/tests/test_trainer_seq2seq.py +++ b/tests/test_trainer_seq2seq.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from transformers import T5ForConditionalGeneration, T5Tokenizer +from transformers import AutoTokenizer, T5ForConditionalGeneration from transformers.testing_utils import TestCasePlus, require_torch from transformers.utils import is_datasets_available @@ -28,7 +28,7 @@ class GaudiSeq2seqTrainerTester(TestCasePlus): @require_torch def test_finetune_t5(self): model = T5ForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-t5-v1.1") - tokenizer = T5Tokenizer.from_pretrained("t5-small") + tokenizer = AutoTokenizer.from_pretrained("t5-small") model.config.max_length = 128 @@ -77,7 +77,7 @@ def _compute_metrics(pred): ) train_dataset.set_format( type="torch", - columns=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"], + columns=["input_ids", "attention_mask", "decoder_input_ids", "labels"], ) # same for validation dataset @@ -89,7 +89,7 @@ def _compute_metrics(pred): ) val_dataset.set_format( type="torch", - columns=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"], + columns=["input_ids", "attention_mask", "decoder_input_ids", "labels"], ) output_dir = self.get_auto_remove_tmp_dir()