From fc5aa192005cee01354bd30cebf37e2cc7827ae1 Mon Sep 17 00:00:00 2001 From: regisss <15324346+regisss@users.noreply.github.com> Date: Wed, 26 Jul 2023 16:38:55 +0000 Subject: [PATCH 1/4] Add support for autocast custom ops in `GaudiTrainer` --- .../gaudi_config_custom_autocast.json | 27 +++++++++++++++++ .../transformers/gaudi_configuration.py | 2 ++ .../transformers/models/gpt2/modeling_gpt2.py | 5 ++-- optimum/habana/transformers/trainer.py | 30 ++++++++++++++----- 4 files changed, 54 insertions(+), 10 deletions(-) create mode 100644 examples/language-modeling/gaudi_config_custom_autocast.json diff --git a/examples/language-modeling/gaudi_config_custom_autocast.json b/examples/language-modeling/gaudi_config_custom_autocast.json new file mode 100644 index 0000000000..e97c77324a --- /dev/null +++ b/examples/language-modeling/gaudi_config_custom_autocast.json @@ -0,0 +1,27 @@ +{ + "use_torch_autocast": true, + "use_fused_adam": true, + "use_fused_clip_norm": true, + "autocast_bf16_ops": [ + "add", + "addmm", + "bmm", + "div", + "dropout", + "gelu", + "iadd", + "linear", + "layer_norm", + "matmul", + "mm", + "rsub", + "softmax", + "truediv" + ], + "autocast_fp32_ops": [ + "embedding", + "nll_loss", + "log_softmax", + "cross_entropy" + ] +} \ No newline at end of file diff --git a/optimum/habana/transformers/gaudi_configuration.py b/optimum/habana/transformers/gaudi_configuration.py index 7230c26953..3ce6c0bc14 100644 --- a/optimum/habana/transformers/gaudi_configuration.py +++ b/optimum/habana/transformers/gaudi_configuration.py @@ -59,6 +59,8 @@ def __init__(self, **kwargs): self.hmp_is_verbose = kwargs.pop("hmp_is_verbose", False) # Torch Autocast self.use_torch_autocast = kwargs.pop("use_torch_autocast", False) + self.autocast_bf16_ops = kwargs.pop("autocast_bf16_ops", None) + self.autocast_fp32_ops = kwargs.pop("autocast_fp32_ops", None) if self.use_habana_mixed_precision and self.use_torch_autocast: raise ValueError( diff --git a/optimum/habana/transformers/models/gpt2/modeling_gpt2.py b/optimum/habana/transformers/models/gpt2/modeling_gpt2.py index 4fda9abec5..5168bffbb7 100644 --- a/optimum/habana/transformers/models/gpt2/modeling_gpt2.py +++ b/optimum/habana/transformers/models/gpt2/modeling_gpt2.py @@ -387,11 +387,12 @@ def gaudi_gpt2_forward( # Since we are adding it to the raw scores before the softmax, this is # effectively the same as removing these entirely. attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = 1.0 - attention_mask from habana_frameworks.torch.hpex import hmp - with hmp.disable_casts(): - attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + with hmp.disable_casts(), torch.autocast(enabled=False, device_type="hpu"): + attention_mask = attention_mask * torch.finfo(self.dtype).min # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] diff --git a/optimum/habana/transformers/trainer.py b/optimum/habana/transformers/trainer.py index ccf08e0c7b..4ca2f1d069 100644 --- a/optimum/habana/transformers/trainer.py +++ b/optimum/habana/transformers/trainer.py @@ -176,14 +176,6 @@ def __init__( self.gaudi_config = copy.deepcopy(gaudi_config) if self.args.use_habana: - if self.args.use_lazy_mode: - try: - import habana_frameworks.torch.core as htcore - except ImportError as error: - error.msg = f"Could not import habana_frameworks.torch.core. {error.msg}." 
- raise error - self.htcore = htcore - if self.args.use_hpu_graphs_for_inference: self.already_wrapped_for_hpu_graphs = False @@ -209,6 +201,20 @@ def __init__( "`--bf16` was given and `use_habana_mixed_precision` is True in the Gaudi configuration. Using Torch Autocast as mixed-precision backend." ) + if self.use_hpu_amp: + if self.gaudi_config.autocast_bf16_ops is not None and self.gaudi_config.autocast_fp32_ops is not None: + # Open temporary files to write mixed-precision ops + with tempfile.NamedTemporaryFile() as autocast_bf16_file: + with tempfile.NamedTemporaryFile() as autocast_fp32_file: + self.gaudi_config.write_bf16_fp32_ops_to_text_files( + autocast_bf16_file.name, + autocast_fp32_file.name, + ) + os.environ["LOWER_LIST"] = str(autocast_bf16_file) + os.environ["FP32_LIST"] = str(autocast_fp32_file) + + import habana_frameworks.torch.core # noqa + if self.gaudi_config.use_habana_mixed_precision and not (self.use_hpu_amp or self.use_cpu_amp): try: from habana_frameworks.torch.hpex import hmp @@ -232,6 +238,14 @@ def __init__( isVerbose=self.gaudi_config.hmp_is_verbose, ) + if self.args.use_lazy_mode: + try: + import habana_frameworks.torch.core as htcore + except ImportError as error: + error.msg = f"Could not import habana_frameworks.torch.core. {error.msg}." + raise error + self.htcore = htcore + try: from habana_frameworks.torch.hpu import random as hpu_random except ImportError as error: From 11c544bf057695f5b354c78a460f7a582991bc43 Mon Sep 17 00:00:00 2001 From: regisss <15324346+regisss@users.noreply.github.com> Date: Sat, 9 Sep 2023 13:59:50 +0000 Subject: [PATCH 2/4] Add possibility to use custom bf16/fp32 ops with Torch Autocast --- .../transformers/gaudi_configuration.py | 29 +++++++++++++++++-- .../transformers/models/gpt2/modeling_gpt2.py | 3 +- optimum/habana/transformers/trainer.py | 18 +++--------- optimum/habana/transformers/training_args.py | 11 +++++++ optimum/habana/utils.py | 6 ++-- 5 files changed, 46 insertions(+), 21 deletions(-) diff --git a/optimum/habana/transformers/gaudi_configuration.py b/optimum/habana/transformers/gaudi_configuration.py index 3ce6c0bc14..da30341acd 100644 --- a/optimum/habana/transformers/gaudi_configuration.py +++ b/optimum/habana/transformers/gaudi_configuration.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os +import sys import warnings from pathlib import Path @@ -84,10 +86,31 @@ def write_bf16_fp32_ops_to_text_files( self, path_to_bf16_file: Path, path_to_fp32_file: Path, + autocast: bool = False, ): - for path, ops in zip( - [Path(path_to_bf16_file), Path(path_to_fp32_file)], [self.hmp_bf16_ops, self.hmp_fp32_ops] - ): + bf16_ops = self.autocast_bf16_ops if autocast else self.hmp_bf16_ops + fp32_ops = self.autocast_fp32_ops if autocast else self.hmp_fp32_ops + + for path, ops in zip([Path(path_to_bf16_file), Path(path_to_fp32_file)], [bf16_ops, fp32_ops]): with path.open("w") as text_file: # writelines does not add new lines after each element so "\n" is inserted text_file.writelines(op + "\n" for op in ops) + + def declare_autocast_bf16_fp32_ops(self): + if self.autocast_bf16_ops is not None and self.autocast_fp32_ops is not None: + if "habana_frameworks.torch.core" in sys.modules: + raise RuntimeError( + "Setting bf16/fp32 ops for Torch Autocast but `habana_frameworks.torch.core` has already been imported. 
" + "You should instantiate your Gaudi config and your training arguments before importing from `habana_frameworks.torch` or calling a method from `optimum.habana.utils`." + ) + else: + autocast_bf16_filename = "/tmp/lower_list.txt" + autocast_fp32_filename = "/tmp/fp32_list.txt" + + self.write_bf16_fp32_ops_to_text_files( + autocast_bf16_filename, + autocast_fp32_filename, + autocast=True, + ) + os.environ["LOWER_LIST"] = autocast_bf16_filename + os.environ["FP32_LIST"] = autocast_fp32_filename diff --git a/optimum/habana/transformers/models/gpt2/modeling_gpt2.py b/optimum/habana/transformers/models/gpt2/modeling_gpt2.py index 5e2655dfd5..c61c90a8d2 100644 --- a/optimum/habana/transformers/models/gpt2/modeling_gpt2.py +++ b/optimum/habana/transformers/models/gpt2/modeling_gpt2.py @@ -389,12 +389,11 @@ def gaudi_gpt2_forward( # Since we are adding it to the raw scores before the softmax, this is # effectively the same as removing these entirely. attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility - attention_mask = 1.0 - attention_mask from habana_frameworks.torch.hpex import hmp with hmp.disable_casts(), torch.autocast(enabled=False, device_type="hpu"): - attention_mask = attention_mask * torch.finfo(self.dtype).min + attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] diff --git a/optimum/habana/transformers/trainer.py b/optimum/habana/transformers/trainer.py index 66cc55d430..b42b5fe2f3 100644 --- a/optimum/habana/transformers/trainer.py +++ b/optimum/habana/transformers/trainer.py @@ -84,7 +84,6 @@ is_safetensors_available, ) -from optimum.habana.distributed import all_reduce_gradients from optimum.utils import logging from ..accelerate import GaudiAccelerator @@ -211,19 +210,8 @@ def __init__( "`--bf16` was given and `use_habana_mixed_precision` is True in the Gaudi configuration. Using Torch Autocast as mixed-precision backend." ) - if self.use_hpu_amp: - if self.gaudi_config.autocast_bf16_ops is not None and self.gaudi_config.autocast_fp32_ops is not None: - # Open temporary files to write mixed-precision ops - with tempfile.NamedTemporaryFile() as autocast_bf16_file: - with tempfile.NamedTemporaryFile() as autocast_fp32_file: - self.gaudi_config.write_bf16_fp32_ops_to_text_files( - autocast_bf16_file.name, - autocast_fp32_file.name, - ) - os.environ["LOWER_LIST"] = str(autocast_bf16_file) - os.environ["FP32_LIST"] = str(autocast_fp32_file) - - import habana_frameworks.torch.core # noqa + if self.use_hpu_amp and "LOWER_LIST" not in os.environ: + gaudi_config.declare_autocast_bf16_fp32_ops() if self.gaudi_config.use_habana_mixed_precision and not (self.use_hpu_amp or self.use_cpu_amp): try: @@ -733,6 +721,8 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args): # In multi-worker training: broadcast model parameters from worker:0 to all the others. # This must be done manually unless DistributedDataParallel is used. if self.args.parallel_mode == ParallelMode.DISTRIBUTED and self.args.distribution_strategy == "fast_ddp": + from ..distributed import all_reduce_gradients + logger.debug( f"Broadcasting the model parameters to assure that each of {self.args.world_size} workers start the training from the same point." 
) diff --git a/optimum/habana/transformers/training_args.py b/optimum/habana/transformers/training_args.py index 714bb1e754..73e6efc948 100644 --- a/optimum/habana/transformers/training_args.py +++ b/optimum/habana/transformers/training_args.py @@ -41,6 +41,7 @@ from ..accelerate.state import GaudiAcceleratorState, GaudiPartialState from ..accelerate.utils import GaudiDistributedType +from .gaudi_configuration import GaudiConfig if is_torch_available(): @@ -573,6 +574,16 @@ def __str__(self): def _setup_devices(self) -> "torch.device": requires_backends(self, ["torch"]) + # Hack to make sure bf16/fp32 ops are specified before calling habana_frameworks.torch.core + if self.gaudi_config_name is not None: + gaudi_config = GaudiConfig.from_pretrained(self.gaudi_config_name) + if ( + (self.bf16 or gaudi_config.use_torch_autocast) + and not self.deepspeed + and self.half_precision_backend == "hpu_amp" + ): + gaudi_config.declare_autocast_bf16_fp32_ops() + logger.info("PyTorch: setting up devices") if not is_accelerate_available(min_version="0.21.0"): raise ImportError( diff --git a/optimum/habana/utils.py b/optimum/habana/utils.py index 9bf906dbaf..a158f4f983 100644 --- a/optimum/habana/utils.py +++ b/optimum/habana/utils.py @@ -20,8 +20,6 @@ import numpy as np import torch -from habana_frameworks.torch.hpu import memory_stats -from habana_frameworks.torch.hpu import random as hpu_random from packaging import version from transformers.utils import is_torch_available @@ -136,6 +134,8 @@ def get_hpu_memory_stats(device=None) -> Dict[str, float]: Returns: Dict[str, float]: memory stats. """ + from habana_frameworks.torch.hpu import memory_stats + mem_stats = memory_stats(device) mem_dict = { @@ -156,6 +156,8 @@ def set_seed(seed: int): random.seed(seed) np.random.seed(seed) if is_torch_available(): + from habana_frameworks.torch.hpu import random as hpu_random + torch.manual_seed(seed) hpu_random.manual_seed_all(seed) From 5c6a2f37f14b35e952c26a6a1df222fe6940595d Mon Sep 17 00:00:00 2001 From: regisss <15324346+regisss@users.noreply.github.com> Date: Sun, 10 Sep 2023 15:01:26 +0000 Subject: [PATCH 3/4] Delete debug file --- .../gaudi_config_custom_autocast.json | 27 ------------------- 1 file changed, 27 deletions(-) delete mode 100644 examples/language-modeling/gaudi_config_custom_autocast.json diff --git a/examples/language-modeling/gaudi_config_custom_autocast.json b/examples/language-modeling/gaudi_config_custom_autocast.json deleted file mode 100644 index e97c77324a..0000000000 --- a/examples/language-modeling/gaudi_config_custom_autocast.json +++ /dev/null @@ -1,27 +0,0 @@ -{ - "use_torch_autocast": true, - "use_fused_adam": true, - "use_fused_clip_norm": true, - "autocast_bf16_ops": [ - "add", - "addmm", - "bmm", - "div", - "dropout", - "gelu", - "iadd", - "linear", - "layer_norm", - "matmul", - "mm", - "rsub", - "softmax", - "truediv" - ], - "autocast_fp32_ops": [ - "embedding", - "nll_loss", - "log_softmax", - "cross_entropy" - ] -} \ No newline at end of file From cc1b68a6b13b54e28b993fab8a6761d4488a3793 Mon Sep 17 00:00:00 2001 From: regisss <15324346+regisss@users.noreply.github.com> Date: Tue, 12 Sep 2023 16:44:03 +0000 Subject: [PATCH 4/4] Add test --- tests/test_gaudi_configuration.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_gaudi_configuration.py b/tests/test_gaudi_configuration.py index 755def2ce1..14e3e7838e 100644 --- a/tests/test_gaudi_configuration.py +++ b/tests/test_gaudi_configuration.py @@ -58,6 +58,8 @@ def test_default_parameter_types(self): 
self.assertTrue(is_list_of_strings(gaudi_config.hmp_bf16_ops)) self.assertTrue(is_list_of_strings(gaudi_config.hmp_fp32_ops)) + self.assertIsNone(gaudi_config.autocast_bf16_ops) + self.assertIsNone(gaudi_config.autocast_fp32_ops) def test_write_bf16_fp32_ops_to_text_files(self): gaudi_config = GaudiConfig()
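
Usage sketch (illustrative only): how the custom autocast op lists introduced in this series could be declared from user code. It assumes the GaudiConfig API added in patches 1/4 and 2/4; the op lists reuse the example JSON that patch 3/4 removes, and the snippet is a sketch rather than an official example.

    # Minimal sketch: declare custom bf16/fp32 ops for Torch Autocast on Gaudi.
    from optimum.habana import GaudiConfig

    gaudi_config = GaudiConfig(
        use_torch_autocast=True,
        autocast_bf16_ops=["add", "addmm", "bmm", "linear", "matmul", "mm", "softmax"],
        autocast_fp32_ops=["embedding", "nll_loss", "log_softmax", "cross_entropy"],
    )

    # Must run before `habana_frameworks.torch.core` is imported anywhere, otherwise
    # the RuntimeError added in patch 2/4 is raised. The call writes the two lists to
    # /tmp/lower_list.txt and /tmp/fp32_list.txt and exports the LOWER_LIST / FP32_LIST
    # environment variables that the patches use to configure the HPU autocast backend.
    gaudi_config.declare_autocast_bf16_fp32_ops()

In the GaudiTrainer path the same declaration happens automatically: the hook added to GaudiTrainingArguments._setup_devices in patch 2/4 loads the Gaudi configuration named by `gaudi_config_name` and calls declare_autocast_bf16_fp32_ops() before `habana_frameworks.torch.core` is imported, so users only need to list the ops in their Gaudi configuration file.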