diff --git a/optimum/habana/transformers/gaudi_configuration.py b/optimum/habana/transformers/gaudi_configuration.py
index 7230c26953..da30341acd 100644
--- a/optimum/habana/transformers/gaudi_configuration.py
+++ b/optimum/habana/transformers/gaudi_configuration.py
@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
+import sys
 import warnings
 from pathlib import Path
 
@@ -59,6 +61,8 @@ def __init__(self, **kwargs):
         self.hmp_is_verbose = kwargs.pop("hmp_is_verbose", False)
         # Torch Autocast
         self.use_torch_autocast = kwargs.pop("use_torch_autocast", False)
+        self.autocast_bf16_ops = kwargs.pop("autocast_bf16_ops", None)
+        self.autocast_fp32_ops = kwargs.pop("autocast_fp32_ops", None)
 
         if self.use_habana_mixed_precision and self.use_torch_autocast:
             raise ValueError(
@@ -82,10 +86,31 @@ def write_bf16_fp32_ops_to_text_files(
         self,
         path_to_bf16_file: Path,
         path_to_fp32_file: Path,
+        autocast: bool = False,
     ):
-        for path, ops in zip(
-            [Path(path_to_bf16_file), Path(path_to_fp32_file)], [self.hmp_bf16_ops, self.hmp_fp32_ops]
-        ):
+        bf16_ops = self.autocast_bf16_ops if autocast else self.hmp_bf16_ops
+        fp32_ops = self.autocast_fp32_ops if autocast else self.hmp_fp32_ops
+
+        for path, ops in zip([Path(path_to_bf16_file), Path(path_to_fp32_file)], [bf16_ops, fp32_ops]):
             with path.open("w") as text_file:
                 # writelines does not add new lines after each element so "\n" is inserted
                 text_file.writelines(op + "\n" for op in ops)
+
+    def declare_autocast_bf16_fp32_ops(self):
+        if self.autocast_bf16_ops is not None and self.autocast_fp32_ops is not None:
+            if "habana_frameworks.torch.core" in sys.modules:
+                raise RuntimeError(
+                    "Setting bf16/fp32 ops for Torch Autocast but `habana_frameworks.torch.core` has already been imported. "
+                    "You should instantiate your Gaudi config and your training arguments before importing from `habana_frameworks.torch` or calling a method from `optimum.habana.utils`."
+                )
+            else:
+                autocast_bf16_filename = "/tmp/lower_list.txt"
+                autocast_fp32_filename = "/tmp/fp32_list.txt"
+
+                self.write_bf16_fp32_ops_to_text_files(
+                    autocast_bf16_filename,
+                    autocast_fp32_filename,
+                    autocast=True,
+                )
+                os.environ["LOWER_LIST"] = autocast_bf16_filename
+                os.environ["FP32_LIST"] = autocast_fp32_filename
diff --git a/optimum/habana/transformers/models/gpt2/modeling_gpt2.py b/optimum/habana/transformers/models/gpt2/modeling_gpt2.py
index b9418bb61f..c61c90a8d2 100644
--- a/optimum/habana/transformers/models/gpt2/modeling_gpt2.py
+++ b/optimum/habana/transformers/models/gpt2/modeling_gpt2.py
@@ -392,7 +392,7 @@ def gaudi_gpt2_forward(
 
             from habana_frameworks.torch.hpex import hmp
 
-            with hmp.disable_casts():
+            with hmp.disable_casts(), torch.autocast(enabled=False, device_type="hpu"):
                 attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
 
         # If a 2D or 3D attention mask is provided for the cross-attention
diff --git a/optimum/habana/transformers/trainer.py b/optimum/habana/transformers/trainer.py
index 6cdf139698..b42b5fe2f3 100644
--- a/optimum/habana/transformers/trainer.py
+++ b/optimum/habana/transformers/trainer.py
@@ -84,7 +84,6 @@
     is_safetensors_available,
 )
 
-from optimum.habana.distributed import all_reduce_gradients
 from optimum.utils import logging
 
 from ..accelerate import GaudiAccelerator
@@ -186,14 +185,6 @@ def __init__(
         self.gaudi_config = copy.deepcopy(gaudi_config)
 
         if self.args.use_habana:
-            if self.args.use_lazy_mode:
-                try:
-                    import habana_frameworks.torch.core as htcore
-                except ImportError as error:
-                    error.msg = f"Could not import habana_frameworks.torch.core. {error.msg}."
-                    raise error
-                self.htcore = htcore
-
             if self.args.use_hpu_graphs_for_inference:
                 self.already_wrapped_for_hpu_graphs = False
 
@@ -219,6 +210,9 @@ def __init__(
                     "`--bf16` was given and `use_habana_mixed_precision` is True in the Gaudi configuration. Using Torch Autocast as mixed-precision backend."
                 )
 
+            if self.use_hpu_amp and "LOWER_LIST" not in os.environ:
+                gaudi_config.declare_autocast_bf16_fp32_ops()
+
             if self.gaudi_config.use_habana_mixed_precision and not (self.use_hpu_amp or self.use_cpu_amp):
                 try:
                     from habana_frameworks.torch.hpex import hmp
@@ -249,6 +243,14 @@ def __init__(
                         isVerbose=self.gaudi_config.hmp_is_verbose,
                     )
 
+            if self.args.use_lazy_mode:
+                try:
+                    import habana_frameworks.torch.core as htcore
+                except ImportError as error:
+                    error.msg = f"Could not import habana_frameworks.torch.core. {error.msg}."
+                    raise error
+                self.htcore = htcore
+
             try:
                 from habana_frameworks.torch.hpu import random as hpu_random
             except ImportError as error:
@@ -719,6 +721,8 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args):
         # In multi-worker training: broadcast model parameters from worker:0 to all the others.
         # This must be done manually unless DistributedDataParallel is used.
         if self.args.parallel_mode == ParallelMode.DISTRIBUTED and self.args.distribution_strategy == "fast_ddp":
+            from ..distributed import all_reduce_gradients
+
             logger.debug(
                 f"Broadcasting the model parameters to assure that each of {self.args.world_size} workers start the training from the same point."
             )
diff --git a/optimum/habana/transformers/training_args.py b/optimum/habana/transformers/training_args.py
index 714bb1e754..73e6efc948 100644
--- a/optimum/habana/transformers/training_args.py
+++ b/optimum/habana/transformers/training_args.py
@@ -41,6 +41,7 @@
 
 from ..accelerate.state import GaudiAcceleratorState, GaudiPartialState
 from ..accelerate.utils import GaudiDistributedType
+from .gaudi_configuration import GaudiConfig
 
 
 if is_torch_available():
@@ -573,6 +574,16 @@ def __str__(self):
     def _setup_devices(self) -> "torch.device":
         requires_backends(self, ["torch"])
 
+        # Hack to make sure bf16/fp32 ops are specified before calling habana_frameworks.torch.core
+        if self.gaudi_config_name is not None:
+            gaudi_config = GaudiConfig.from_pretrained(self.gaudi_config_name)
+            if (
+                (self.bf16 or gaudi_config.use_torch_autocast)
+                and not self.deepspeed
+                and self.half_precision_backend == "hpu_amp"
+            ):
+                gaudi_config.declare_autocast_bf16_fp32_ops()
+
         logger.info("PyTorch: setting up devices")
         if not is_accelerate_available(min_version="0.21.0"):
             raise ImportError(
diff --git a/optimum/habana/utils.py b/optimum/habana/utils.py
index 9bf906dbaf..a158f4f983 100644
--- a/optimum/habana/utils.py
+++ b/optimum/habana/utils.py
@@ -20,8 +20,6 @@
 
 import numpy as np
 import torch
-from habana_frameworks.torch.hpu import memory_stats
-from habana_frameworks.torch.hpu import random as hpu_random
 from packaging import version
 from transformers.utils import is_torch_available
 
@@ -136,6 +134,8 @@ def get_hpu_memory_stats(device=None) -> Dict[str, float]:
     Returns:
         Dict[str, float]: memory stats.
     """
+    from habana_frameworks.torch.hpu import memory_stats
+
     mem_stats = memory_stats(device)
 
     mem_dict = {
@@ -156,6 +156,8 @@ def set_seed(seed: int):
     random.seed(seed)
     np.random.seed(seed)
     if is_torch_available():
+        from habana_frameworks.torch.hpu import random as hpu_random
+
         torch.manual_seed(seed)
         hpu_random.manual_seed_all(seed)
 
diff --git a/tests/test_gaudi_configuration.py b/tests/test_gaudi_configuration.py
index 755def2ce1..14e3e7838e 100644
--- a/tests/test_gaudi_configuration.py
+++ b/tests/test_gaudi_configuration.py
@@ -58,6 +58,8 @@ def test_default_parameter_types(self):
 
         self.assertTrue(is_list_of_strings(gaudi_config.hmp_bf16_ops))
         self.assertTrue(is_list_of_strings(gaudi_config.hmp_fp32_ops))
+        self.assertIsNone(gaudi_config.autocast_bf16_ops)
+        self.assertIsNone(gaudi_config.autocast_fp32_ops)
 
     def test_write_bf16_fp32_ops_to_text_files(self):
         gaudi_config = GaudiConfig()
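For reviewers, a minimal usage sketch (not part of the patch) of the new `autocast_bf16_ops` / `autocast_fp32_ops` fields and of `declare_autocast_bf16_fp32_ops`. The op names below are illustrative placeholders, and `GaudiConfig` is assumed to be importable from `optimum.habana` as in the existing package; the file paths and environment variables follow directly from the method added above.

```python
import os

from optimum.habana import GaudiConfig  # assumed import path

# Op lists are illustrative; real values would come from the Gaudi config JSON.
gaudi_config = GaudiConfig(
    use_torch_autocast=True,
    autocast_bf16_ops=["add", "mm", "matmul"],
    autocast_fp32_ops=["softmax", "cross_entropy"],
)

# Must run before `habana_frameworks.torch.core` is imported,
# otherwise the new method raises a RuntimeError.
gaudi_config.declare_autocast_bf16_fp32_ops()

# The method writes both op lists to /tmp and points the Habana
# autocast backend at them through environment variables.
assert os.environ["LOWER_LIST"] == "/tmp/lower_list.txt"
assert os.environ["FP32_LIST"] == "/tmp/fp32_list.txt"
```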