Skip to content

Commit

Permalink
Fix log bf16
Browse files Browse the repository at this point in the history
  • Loading branch information
regisss committed Apr 27, 2023
1 parent 7cd2a93 commit 3ca95a1
Showing 1 changed file with 24 additions and 17 deletions.
41 changes: 24 additions & 17 deletions examples/text-generation/run_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,29 +114,32 @@ def main():
logger.info("DeepSpeed is enabled.")
else:
if args.gaudi_config_name_or_path is None:
gaudi_config = None
logger.warning(
"`--gaudi_config_name_or_path` was not specified so not using Habana Mixed Precision for this run."
)
else:
from optimum.habana import GaudiConfig

gaudi_config = GaudiConfig.from_pretrained(args.gaudi_config_name_or_path)
from habana_frameworks.torch.hpex import hmp

# Open temporary files to mixed-precision write ops
with tempfile.NamedTemporaryFile() as hmp_bf16_file:
with tempfile.NamedTemporaryFile() as hmp_fp32_file:
# hmp.convert needs ops to be written in text files
gaudi_config.write_bf16_fp32_ops_to_text_files(
hmp_bf16_file.name,
hmp_fp32_file.name,
)
hmp.convert(
opt_level=gaudi_config.hmp_opt_level,
bf16_file_path=hmp_bf16_file.name,
fp32_file_path=hmp_fp32_file.name,
isVerbose=gaudi_config.hmp_is_verbose,
)

if gaudi_config.use_habana_mixed_precision:
from habana_frameworks.torch.hpex import hmp

# Open temporary files to mixed-precision write ops
with tempfile.NamedTemporaryFile() as hmp_bf16_file:
with tempfile.NamedTemporaryFile() as hmp_fp32_file:
# hmp.convert needs ops to be written in text files
gaudi_config.write_bf16_fp32_ops_to_text_files(
hmp_bf16_file.name,
hmp_fp32_file.name,
)
hmp.convert(
opt_level=gaudi_config.hmp_opt_level,
bf16_file_path=hmp_bf16_file.name,
fp32_file_path=hmp_fp32_file.name,
isVerbose=gaudi_config.hmp_is_verbose,
)
logger.info("Single-device run.")

# Tweak generation so that it runs faster on Gaudi
Expand Down Expand Up @@ -199,7 +202,11 @@ def main():

if rank in [-1, 0]:
logger.info(f"Args: {args}")
logger.info(f"device: {args.device}, n_hpu: {world_size}, bf16: True")

use_bf16 = False
if use_deepspeed or (gaudi_config is not None and gaudi_config.use_habana_mixed_precision):
use_bf16 = True
logger.info(f"device: {args.device}, n_hpu: {world_size}, bf16: {use_bf16}")

# Generation configuration
generation_config = GenerationConfig(
Expand Down

0 comments on commit 3ca95a1

Please sign in to comment.