Add support for autocast custom ops in GaudiTrainer #308

Merged · 6 commits · Sep 12, 2023
31 changes: 28 additions & 3 deletions optimum/habana/transformers/gaudi_configuration.py
@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
+import sys
 import warnings
 from pathlib import Path
 
@@ -59,6 +61,8 @@ def __init__(self, **kwargs):
         self.hmp_is_verbose = kwargs.pop("hmp_is_verbose", False)
         # Torch Autocast
         self.use_torch_autocast = kwargs.pop("use_torch_autocast", False)
+        self.autocast_bf16_ops = kwargs.pop("autocast_bf16_ops", None)
+        self.autocast_fp32_ops = kwargs.pop("autocast_fp32_ops", None)
 
         if self.use_habana_mixed_precision and self.use_torch_autocast:
             raise ValueError(
@@ -82,10 +86,31 @@ def write_bf16_fp32_ops_to_text_files(
         self,
         path_to_bf16_file: Path,
         path_to_fp32_file: Path,
+        autocast: bool = False,
     ):
-        for path, ops in zip(
-            [Path(path_to_bf16_file), Path(path_to_fp32_file)], [self.hmp_bf16_ops, self.hmp_fp32_ops]
-        ):
+        bf16_ops = self.autocast_bf16_ops if autocast else self.hmp_bf16_ops
+        fp32_ops = self.autocast_fp32_ops if autocast else self.hmp_fp32_ops
+
+        for path, ops in zip([Path(path_to_bf16_file), Path(path_to_fp32_file)], [bf16_ops, fp32_ops]):
             with path.open("w") as text_file:
                 # writelines does not add new lines after each element so "\n" is inserted
                 text_file.writelines(op + "\n" for op in ops)
+
+    def declare_autocast_bf16_fp32_ops(self):
+        if self.autocast_bf16_ops is not None and self.autocast_fp32_ops is not None:
+            if "habana_frameworks.torch.core" in sys.modules:
+                raise RuntimeError(
+                    "Setting bf16/fp32 ops for Torch Autocast but `habana_frameworks.torch.core` has already been imported. "
+                    "You should instantiate your Gaudi config and your training arguments before importing from `habana_frameworks.torch` or calling a method from `optimum.habana.utils`."
+                )
+            else:
+                autocast_bf16_filename = "/tmp/lower_list.txt"
+                autocast_fp32_filename = "/tmp/fp32_list.txt"
+
+                self.write_bf16_fp32_ops_to_text_files(
+                    autocast_bf16_filename,
+                    autocast_fp32_filename,
+                    autocast=True,
+                )
+                os.environ["LOWER_LIST"] = autocast_bf16_filename
+                os.environ["FP32_LIST"] = autocast_fp32_filename
2 changes: 1 addition & 1 deletion optimum/habana/transformers/models/gpt2/modeling_gpt2.py
@@ -392,7 +392,7 @@ def gaudi_gpt2_forward(
 
             from habana_frameworks.torch.hpex import hmp
 
-            with hmp.disable_casts():
+            with hmp.disable_casts(), torch.autocast(enabled=False, device_type="hpu"):
                 attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
 
     # If a 2D or 3D attention mask is provided for the cross-attention
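The added `torch.autocast(enabled=False, device_type="hpu")` context keeps the additive-mask computation out of autocast when Torch Autocast is the mixed-precision backend, mirroring what `hmp.disable_casts()` already does for HMP. Below is a device-agnostic sketch of the same pattern; CPU is used only so it runs without an HPU.

```python
import torch

dtype = torch.bfloat16
attention_mask = torch.ones(1, 1, 1, 8)

# Disable autocast locally so the large negative fill value from torch.finfo
# is computed in full precision, as in the HPU code path above.
with torch.autocast(device_type="cpu", enabled=False):
    attention_mask = (1.0 - attention_mask) * torch.finfo(dtype).min

print(attention_mask.dtype, attention_mask[0, 0, 0, 0].item())
```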
22 changes: 13 additions & 9 deletions optimum/habana/transformers/trainer.py
@@ -84,7 +84,6 @@
     is_safetensors_available,
 )
 
-from optimum.habana.distributed import all_reduce_gradients
 from optimum.utils import logging
 
 from ..accelerate import GaudiAccelerator
@@ -186,14 +185,6 @@ def __init__(
         self.gaudi_config = copy.deepcopy(gaudi_config)
 
         if self.args.use_habana:
-            if self.args.use_lazy_mode:
-                try:
-                    import habana_frameworks.torch.core as htcore
-                except ImportError as error:
-                    error.msg = f"Could not import habana_frameworks.torch.core. {error.msg}."
-                    raise error
-                self.htcore = htcore
-
             if self.args.use_hpu_graphs_for_inference:
                 self.already_wrapped_for_hpu_graphs = False
 
@@ -219,6 +210,9 @@ def __init__(
                     "`--bf16` was given and `use_habana_mixed_precision` is True in the Gaudi configuration. Using Torch Autocast as mixed-precision backend."
                 )
 
+            if self.use_hpu_amp and "LOWER_LIST" not in os.environ:
+                gaudi_config.declare_autocast_bf16_fp32_ops()
+
             if self.gaudi_config.use_habana_mixed_precision and not (self.use_hpu_amp or self.use_cpu_amp):
                 try:
                     from habana_frameworks.torch.hpex import hmp
@@ -249,6 +243,14 @@ def __init__(
                         isVerbose=self.gaudi_config.hmp_is_verbose,
                     )
 
+            if self.args.use_lazy_mode:
+                try:
+                    import habana_frameworks.torch.core as htcore
+                except ImportError as error:
+                    error.msg = f"Could not import habana_frameworks.torch.core. {error.msg}."
+                    raise error
+                self.htcore = htcore
+
             try:
                 from habana_frameworks.torch.hpu import random as hpu_random
             except ImportError as error:
@@ -719,6 +721,8 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args):
         # In multi-worker training: broadcast model parameters from worker:0 to all the others.
         # This must be done manually unless DistributedDataParallel is used.
         if self.args.parallel_mode == ParallelMode.DISTRIBUTED and self.args.distribution_strategy == "fast_ddp":
+            from ..distributed import all_reduce_gradients
+
             logger.debug(
                 f"Broadcasting the model parameters to assure that each of {self.args.world_size} workers start the training from the same point."
             )
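Net effect of the trainer changes: with HPU autocast enabled and `LOWER_LIST` unset, the trainer declares the custom op lists first and only then imports `habana_frameworks.torch.core` for lazy mode, so the lists are picked up. The end-to-end sketch below is hedged: it requires a Gaudi machine with optimum-habana installed, the model name and op lists are placeholders, and the `GaudiTrainingArguments`/`GaudiTrainer` keyword arguments are assumed from the library's public API rather than shown in this diff.

```python
from optimum.habana import GaudiConfig, GaudiTrainer, GaudiTrainingArguments
from transformers import AutoModelForSequenceClassification

gaudi_config = GaudiConfig(
    use_torch_autocast=True,
    autocast_bf16_ops=["linear", "matmul"],          # illustrative op names
    autocast_fp32_ops=["cross_entropy", "log_softmax"],
)
# Declare the op lists before anything imports habana_frameworks.torch.core
# (GaudiTrainingArguments sets up the HPU device when it is instantiated).
gaudi_config.declare_autocast_bf16_fp32_ops()

training_args = GaudiTrainingArguments(
    output_dir="./out",
    use_habana=True,
    use_lazy_mode=True,
    bf16=True,
    half_precision_backend="hpu_amp",
)

model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")  # placeholder model

# trainer.train() would then run with the custom autocast op lists in effect.
trainer = GaudiTrainer(model=model, gaudi_config=gaudi_config, args=training_args)
```

Alternatively, passing `gaudi_config_name` to `GaudiTrainingArguments` relies on the `_setup_devices` hook added below to make the same call before device setup.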
11 changes: 11 additions & 0 deletions optimum/habana/transformers/training_args.py
@@ -41,6 +41,7 @@
 
 from ..accelerate.state import GaudiAcceleratorState, GaudiPartialState
 from ..accelerate.utils import GaudiDistributedType
+from .gaudi_configuration import GaudiConfig
 
 
 if is_torch_available():
@@ -573,6 +574,16 @@ def __str__(self):
     def _setup_devices(self) -> "torch.device":
         requires_backends(self, ["torch"])
 
+        # Hack to make sure bf16/fp32 ops are specified before calling habana_frameworks.torch.core
+        if self.gaudi_config_name is not None:
+            gaudi_config = GaudiConfig.from_pretrained(self.gaudi_config_name)
+            if (
+                (self.bf16 or gaudi_config.use_torch_autocast)
+                and not self.deepspeed
+                and self.half_precision_backend == "hpu_amp"
+            ):
+                gaudi_config.declare_autocast_bf16_fp32_ops()
+
         logger.info("PyTorch: setting up devices")
         if not is_accelerate_available(min_version="0.21.0"):
             raise ImportError(
6 changes: 4 additions & 2 deletions optimum/habana/utils.py
@@ -20,8 +20,6 @@
 
 import numpy as np
 import torch
-from habana_frameworks.torch.hpu import memory_stats
-from habana_frameworks.torch.hpu import random as hpu_random
 from packaging import version
 from transformers.utils import is_torch_available
 
@@ -136,6 +134,8 @@ def get_hpu_memory_stats(device=None) -> Dict[str, float]:
     Returns:
         Dict[str, float]: memory stats.
     """
+    from habana_frameworks.torch.hpu import memory_stats
+
     mem_stats = memory_stats(device)
 
     mem_dict = {
@@ -156,6 +156,8 @@ def set_seed(seed: int):
     random.seed(seed)
     np.random.seed(seed)
     if is_torch_available():
+        from habana_frameworks.torch.hpu import random as hpu_random
+
         torch.manual_seed(seed)
         hpu_random.manual_seed_all(seed)
 
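Deferring these imports means importing `optimum.habana.utils` no longer pulls in `habana_frameworks`; only calling one of its functions does, which leaves room to declare the autocast op lists first. A small illustration of the pattern follows; the helper is hypothetical and not part of this diff.

```python
import sys


def hpu_memory_raw(device=None):
    # Hypothetical helper in the same deferred-import style as get_hpu_memory_stats above:
    # habana_frameworks is imported only when the function is actually called.
    from habana_frameworks.torch.hpu import memory_stats  # needs a Gaudi runtime

    return memory_stats(device)


# Defining the helper has not imported habana_frameworks yet,
# so LOWER_LIST/FP32_LIST can still be declared at this point.
print("habana_frameworks.torch.hpu" in sys.modules)  # False
```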
2 changes: 2 additions & 0 deletions tests/test_gaudi_configuration.py
@@ -58,6 +58,8 @@ def test_default_parameter_types(self):
 
         self.assertTrue(is_list_of_strings(gaudi_config.hmp_bf16_ops))
         self.assertTrue(is_list_of_strings(gaudi_config.hmp_fp32_ops))
+        self.assertIsNone(gaudi_config.autocast_bf16_ops)
+        self.assertIsNone(gaudi_config.autocast_fp32_ops)
 
     def test_write_bf16_fp32_ops_to_text_files(self):
         gaudi_config = GaudiConfig()