diff --git a/docs/source/package_reference/gaudi_config.mdx b/docs/source/package_reference/gaudi_config.mdx
index 66a31e24f7..8c29d06a66 100644
--- a/docs/source/package_reference/gaudi_config.mdx
+++ b/docs/source/package_reference/gaudi_config.mdx
@@ -16,69 +16,20 @@ limitations under the License.
# Gaudi Configuration
-In order to make the most of Gaudi, it is advised to rely on advanced features such as Habana Mixed Precision or optimized operators.
-You can specify which features to use in a Gaudi configuration, which will take the form of a JSON file following this template:
-
-```JSON
-{
- "use_habana_mixed_precision": true/false,
- "hmp_is_verbose": true/false,
- "use_fused_adam": true/false,
- "use_fused_clip_norm": true/false,
- "hmp_bf16_ops": [
- "torch operator to compute in bf16",
- "..."
- ],
- "hmp_fp32_ops": [
- "torch operator to compute in fp32",
- "..."
- ]
-}
-```
-
Here is a description of each configuration parameter:
-- `use_habana_mixed_precision` enables to decide whether or not Habana Mixed Precision (HMP) should be used. HMP allows to mix *fp32* and *bf16* operations. You can find more information [here](https://docs.habana.ai/en/latest/PyTorch/PyTorch_Mixed_Precision/PT_Mixed_Precision.html).
-- `hmp_is_verbose` enables to decide whether to log precision decisions for each operation for debugging purposes. It is disabled by default. You can find an example of such log [here](https://docs.habana.ai/en/latest/PyTorch/PyTorch_Mixed_Precision/PT_Mixed_Precision.html#hmp-logs).
- `use_fused_adam` enables to decide whether to use the [custom fused implementation of the ADAM optimizer provided by Habana](https://docs.habana.ai/en/latest/PyTorch/Model_Optimization_PyTorch/Custom_Ops_PyTorch.html#custom-optimizers).
- `use_fused_clip_norm` enables to decide whether to use the [custom fused implementation of gradient norm clipping provided by Habana](https://docs.habana.ai/en/latest/PyTorch/Model_Optimization_PyTorch/Custom_Ops_PyTorch.html#other-custom-ops).
-- `hmp_bf16_ops` enables to specify the Torch operations that should be computed in *bf16*. You can find more information about casting rules [here](https://docs.habana.ai/en/latest/PyTorch/PyTorch_Mixed_Precision/PT_Mixed_Precision.html#basic-design-rules).
-- `hmp_fp32_ops` enables to specify the Torch operations that should be computed in *fp32*. You can find more information about casting rules [here](https://docs.habana.ai/en/latest/PyTorch/PyTorch_Mixed_Precision/PT_Mixed_Precision.html#basic-design-rules).
-
-
-
-`hmp_is_verbose`, `hmp_bf16_ops` and `hmp_fp32_ops` will not be used if `use_habana_mixed_precision` is false.
+- `use_torch_autocast` enables PyTorch autocast; it is mainly meant to let pre-defined Gaudi configurations enable mixed precision, and users should rather rely on the `--bf16` training argument
+- `autocast_bf16_ops` is the list of operations that should be computed in *bf16* within the autocast context; setting the `LOWER_LIST` environment variable is the preferred way to override the autocast operator list
+- `autocast_fp32_ops` is the list of operations that should be computed in *fp32* within the autocast context; setting the `FP32_LIST` environment variable is the preferred way to override the autocast operator list (both lists are illustrated in the sketch below)
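+
+Below is a sketch of a Gaudi configuration that explicitly declares its autocast operator lists; the operators listed are the ones that were formerly recommended for BERT and are only given as an illustration:
+```JSON
+{
+  "use_torch_autocast": true,
+  "use_fused_adam": true,
+  "use_fused_clip_norm": true,
+  "autocast_bf16_ops": [
+    "add",
+    "addmm",
+    "bmm",
+    "div",
+    "dropout",
+    "gelu",
+    "iadd",
+    "linear",
+    "layer_norm",
+    "matmul",
+    "mm",
+    "rsub",
+    "softmax",
+    "truediv"
+  ],
+  "autocast_fp32_ops": [
+    "embedding",
+    "nll_loss",
+    "log_softmax"
+  ]
+}
+```
+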
-
You can find examples of Gaudi configurations in the [Habana model repository on the Hugging Face Hub](https://huggingface.co/habana). For instance, [for BERT Large we have](https://huggingface.co/Habana/bert-large-uncased-whole-word-masking/blob/main/gaudi_config.json):
```JSON
{
- "use_habana_mixed_precision": true,
-  "hmp_is_verbose": false,
   "use_fused_adam": true,
-  "use_fused_clip_norm": true,
+  "use_fused_clip_norm": true
- "hmp_bf16_ops": [
- "add",
- "addmm",
- "bmm",
- "div",
- "dropout",
- "gelu",
- "iadd",
- "linear",
- "layer_norm",
- "matmul",
- "mm",
- "rsub",
- "softmax",
- "truediv"
- ],
- "hmp_fp32_ops": [
- "embedding",
- "nll_loss",
- "log_softmax"
- ]
}
```
diff --git a/docs/source/usage_guides/accelerate_training.mdx b/docs/source/usage_guides/accelerate_training.mdx
index 04925a32b6..b3a3934dd6 100644
--- a/docs/source/usage_guides/accelerate_training.mdx
+++ b/docs/source/usage_guides/accelerate_training.mdx
@@ -57,44 +57,16 @@ To not take them into account in the computation of the throughput at the end of
## Mixed-Precision Training
Mixed-precision training enables to compute some operations using lighter data types to accelerate training.
-Habana Mixed Precision (HMP) proposes to mix *fp32* and *bf16* operations.
+Optimum Habana enables mixed-precision training in a similar fashion to 🤗 Transformers:
+- the `--bf16` argument enables PyTorch autocast
+- the `--half_precision_backend [hpu_amp, cpu_amp]` argument specifies the device on which mixed-precision operations should be performed (see the example below)
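+
+For instance, a minimal sketch of a training command with autocast enabled could look like the following (the script and most arguments are illustrative, borrowed from the question-answering example):
+```bash
+python run_qa.py \
+  --model_name_or_path bert-large-uncased-whole-word-masking \
+  --gaudi_config_name Habana/bert-large-uncased-whole-word-masking \
+  --dataset_name squad \
+  --do_train \
+  --do_eval \
+  --output_dir /tmp/squad/ \
+  --use_habana \
+  --use_lazy_mode \
+  --bf16
+```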
-
-
-Please refer to the [list of supported PyTorch operators](https://docs.habana.ai/en/latest/PyTorch/Pytorch_Operators/Pytorch_Operators.html) beforehand to make sure the ones you are interested in are compatible with *bf16*.
-
-
-To apply HMP, you must set `"use_habana_mixed_precision"` to `true` in the Gaudi configuration file.
-Then, you can specify which operators to compute in *bf16* with `"hmp_bf16_ops"` and which operators to compute in *fp32* with `"hmp_fp32_ops"`.
-If these operators are not specified, their default values are set to be the ones written in the [Gaudi configuration file of BERT](https://huggingface.co/Habana/bert-large-uncased-whole-word-masking/blob/main/gaudi_config.json), which is a good starting point for applying HMP:
-```
-"hmp_bf16_ops": [
- "add",
- "addmm",
- "bmm",
- "div",
- "dropout",
- "gelu",
- "iadd",
- "linear",
- "layer_norm",
- "matmul",
- "mm",
- "rsub",
- "softmax",
- "truediv"
-],
-"hmp_fp32_ops": [
- "embedding",
- "nll_loss",
- "log_softmax"
-]
-```
-
-
+
-Torch Autocast can also be used as a backend for mixed-precision training. You need to add the argument `--bf16` to enable it.
+Please refer to the [advanced autocast usage on Gaudi](https://docs.habana.ai/en/latest/PyTorch/PyTorch_Mixed_Precision/Autocast.html) for more information about:
+- the default autocast operations
+- how to override the default autocast operations (a sketch of such an override is given below)
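+
+As a rough sketch of such an override (file paths and operators are illustrative), the autocast lists can be provided as plain-text files with one Torch operator per line, pointed to by the `LOWER_LIST` (bf16) and `FP32_LIST` (fp32) environment variables before launching training:
+```bash
+# one Torch operator per line, same format as written by GaudiConfig.write_bf16_fp32_ops_to_text_files
+printf "addmm\nlinear\nmatmul\n" > /tmp/lower_list.txt
+printf "nll_loss\nlog_softmax\n" > /tmp/fp32_list.txt
+export LOWER_LIST=/tmp/lower_list.txt
+export FP32_LIST=/tmp/fp32_list.txt
+```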
diff --git a/examples/audio-classification/README.md b/examples/audio-classification/README.md
index 54ec3ce821..2c34826eed 100644
--- a/examples/audio-classification/README.md
+++ b/examples/audio-classification/README.md
@@ -47,7 +47,8 @@ python run_audio_classification.py \
--use_lazy_mode \
--use_hpu_graphs_for_inference \
--gaudi_config_name Habana/wav2vec2 \
- --throughput_warmup_steps 3
+ --throughput_warmup_steps 3 \
+ --bf16
```
On a single HPU, this script should run in ~13 minutes and yield an accuracy of **97.96%**.
@@ -83,7 +84,8 @@ python ../gaudi_spawn.py \
--use_lazy_mode \
--use_hpu_graphs_for_inference \
--gaudi_config_name Habana/wav2vec2 \
- --throughput_warmup_steps 3
+ --throughput_warmup_steps 3 \
+ --bf16
```
On 8 HPUs, this script should run in ~12 minutes and yield an accuracy of **80.49%**.
@@ -157,7 +159,8 @@ python run_audio_classification.py \
--use_habana \
--use_lazy_mode \
--use_hpu_graphs_for_inference \
- --gaudi_config_name Habana/wav2vec2
+ --gaudi_config_name Habana/wav2vec2 \
+ --bf16
```
diff --git a/examples/audio-classification/run_audio_classification.py b/examples/audio-classification/run_audio_classification.py
index 8f570c98f1..583d639eec 100644
--- a/examples/audio-classification/run_audio_classification.py
+++ b/examples/audio-classification/run_audio_classification.py
@@ -237,7 +237,7 @@ def main():
)
# Log on each process the small summary:
- mixed_precision = training_args.bf16 or gaudi_config.use_torch_autocast or gaudi_config.use_habana_mixed_precision
+ mixed_precision = training_args.bf16 or gaudi_config.use_torch_autocast
logger.warning(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, "
+ f"distributed training: {training_args.parallel_mode.value == 'distributed'}, "
diff --git a/examples/contrastive-image-text/README.md b/examples/contrastive-image-text/README.md
index 058f126d96..cd6b60801c 100644
--- a/examples/contrastive-image-text/README.md
+++ b/examples/contrastive-image-text/README.md
@@ -110,7 +110,8 @@ python run_clip.py \
--use_hpu_graphs_for_inference \
--gaudi_config_name Habana/clip \
--throughput_warmup_steps 3 \
- --dataloader_num_workers 16
+ --dataloader_num_workers 16 \
+ --bf16
```
@@ -141,7 +142,8 @@ python ../gaudi_spawn.py --world_size 8 --use_mpi run_clip.py \
--throughput_warmup_steps 3 \
--dataloader_num_workers 16 \
--mediapipe_dataloader \
- --use_hpu_graphs_for_training
+ --use_hpu_graphs_for_training \
+ --bf16
```
> `--mediapipe_dataloader` only works on Gaudi2.
@@ -246,5 +248,6 @@ python run_clip.py \
--use_habana \
--use_lazy_mode \
--use_hpu_graphs_for_inference \
- --gaudi_config_name Habana/clip
+ --gaudi_config_name Habana/clip \
+ --bf16
```
diff --git a/examples/contrastive-image-text/run_bridgetower.py b/examples/contrastive-image-text/run_bridgetower.py
index 3f592fdf16..b2f1c5846b 100644
--- a/examples/contrastive-image-text/run_bridgetower.py
+++ b/examples/contrastive-image-text/run_bridgetower.py
@@ -299,7 +299,7 @@ def main():
)
# Log on each process the small summary:
- mixed_precision = training_args.bf16 or gaudi_config.use_torch_autocast or gaudi_config.use_habana_mixed_precision
+ mixed_precision = training_args.bf16 or gaudi_config.use_torch_autocast
logger.warning(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, "
+ f"distributed training: {training_args.parallel_mode.value == 'distributed'}, "
diff --git a/examples/contrastive-image-text/run_clip.py b/examples/contrastive-image-text/run_clip.py
index f361cbffaa..a0fd41c2b4 100644
--- a/examples/contrastive-image-text/run_clip.py
+++ b/examples/contrastive-image-text/run_clip.py
@@ -301,7 +301,7 @@ def main():
)
# Log on each process the small summary:
- mixed_precision = training_args.bf16 or gaudi_config.use_torch_autocast or gaudi_config.use_habana_mixed_precision
+ mixed_precision = training_args.bf16 or gaudi_config.use_torch_autocast
logger.warning(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, "
+ f"distributed training: {training_args.parallel_mode.value == 'distributed'}, "
diff --git a/examples/image-classification/README.md b/examples/image-classification/README.md
index 811b05d4d2..088d1f66af 100644
--- a/examples/image-classification/README.md
+++ b/examples/image-classification/README.md
@@ -47,7 +47,8 @@ python run_image_classification.py \
--use_hpu_graphs_for_inference \
--gaudi_config_name Habana/vit \
--throughput_warmup_steps 3 \
- --dataloader_num_workers 1
+ --dataloader_num_workers 1 \
+ --bf16
```
For Swin, you need to change/add the following arguments:
@@ -95,7 +96,8 @@ python run_image_classification.py \
--use_hpu_graphs_for_inference \
--gaudi_config_name Habana/vit \
--throughput_warmup_steps 3 \
- --dataloader_num_workers 1
+ --dataloader_num_workers 1 \
+ --bf16
```
Internally, the script will use the [`ImageFolder`](https://huggingface.co/docs/datasets/v2.0.0/en/image_process#imagefolder) feature which will automatically turn the folders into 🤗 Dataset objects.
@@ -196,7 +198,8 @@ python ../gaudi_spawn.py \
--use_hpu_graphs_for_inference \
--gaudi_config_name Habana/vit \
--throughput_warmup_steps 3 \
- --dataloader_num_workers 1
+ --dataloader_num_workers 1 \
+ --bf16
```
For Swin, you need to change/add the following arguments:
@@ -279,4 +282,5 @@ python run_image_classification.py \
--use_lazy_mode \
--use_hpu_graphs_for_inference \
--gaudi_config_name Habana/vit \
- --dataloader_num_workers 1
+ --dataloader_num_workers 1 \
+ --bf16
diff --git a/examples/image-classification/run_image_classification.py b/examples/image-classification/run_image_classification.py
index 239ef6e64e..2052117743 100644
--- a/examples/image-classification/run_image_classification.py
+++ b/examples/image-classification/run_image_classification.py
@@ -240,7 +240,7 @@ def main():
)
# Log on each process the small summary:
- mixed_precision = training_args.bf16 or gaudi_config.use_torch_autocast or gaudi_config.use_habana_mixed_precision
+ mixed_precision = training_args.bf16 or gaudi_config.use_torch_autocast
logger.warning(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, "
+ f"distributed training: {training_args.parallel_mode.value == 'distributed'}, "
diff --git a/examples/language-modeling/README.md b/examples/language-modeling/README.md
index dad97020c6..551bd52e57 100644
--- a/examples/language-modeling/README.md
+++ b/examples/language-modeling/README.md
@@ -178,7 +178,8 @@ python run_mlm.py \
--use_lazy_mode \
--use_hpu_graphs_for_inference \
--gaudi_config_name Habana/roberta-base \
- --throughput_warmup_steps 3
+ --throughput_warmup_steps 3 \
+ --bf16
```
To run on your own training and validation files, use the following command:
@@ -197,7 +198,8 @@ python run_mlm.py \
--use_lazy_mode \
--use_hpu_graphs_for_inference \
--gaudi_config_name Habana/roberta-base \
- --throughput_warmup_steps 3
+ --throughput_warmup_steps 3 \
+ --bf16
```
If your dataset is organized with one sample per line, you can use the `--line_by_line` flag (otherwise the script
@@ -223,7 +225,8 @@ python ../gaudi_spawn.py \
--use_lazy_mode \
--use_hpu_graphs_for_inference \
--gaudi_config_name Habana/roberta-base \
- --throughput_warmup_steps 3
+ --throughput_warmup_steps 3 \
+ --bf16
```
@@ -247,7 +250,8 @@ python run_clm.py \
--use_habana \
--use_lazy_mode \
--use_hpu_graphs_for_inference \
- --throughput_warmup_steps 3
+ --throughput_warmup_steps 3 \
+ --bf16
```
@@ -315,7 +319,8 @@ python run_clm.py \
--gaudi_config_name Habana/gpt2 \
--use_habana \
--use_lazy_mode \
- --use_hpu_graphs_for_inference
+ --use_hpu_graphs_for_inference \
+ --bf16
```
diff --git a/examples/language-modeling/run_clm.py b/examples/language-modeling/run_clm.py
index f3b395d3a1..f9546abe03 100644
--- a/examples/language-modeling/run_clm.py
+++ b/examples/language-modeling/run_clm.py
@@ -311,7 +311,7 @@ def main():
)
# Log on each process the small summary:
- mixed_precision = training_args.bf16 or gaudi_config.use_torch_autocast or gaudi_config.use_habana_mixed_precision
+ mixed_precision = training_args.bf16 or gaudi_config.use_torch_autocast
logger.warning(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, "
+ f"distributed training: {training_args.parallel_mode.value == 'distributed'}, "
diff --git a/examples/language-modeling/run_mlm.py b/examples/language-modeling/run_mlm.py
index fad7e53190..37970c6698 100644
--- a/examples/language-modeling/run_mlm.py
+++ b/examples/language-modeling/run_mlm.py
@@ -302,7 +302,7 @@ def main():
)
# Log on each process the small summary:
- mixed_precision = training_args.bf16 or gaudi_config.use_torch_autocast or gaudi_config.use_habana_mixed_precision
+ mixed_precision = training_args.bf16 or gaudi_config.use_torch_autocast
logger.warning(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, "
+ f"distributed training: {training_args.parallel_mode.value == 'distributed'}, "
diff --git a/examples/question-answering/README.md b/examples/question-answering/README.md
index 65cf8c3484..677145387f 100644
--- a/examples/question-answering/README.md
+++ b/examples/question-answering/README.md
@@ -53,7 +53,8 @@ python run_qa.py \
--use_habana \
--use_lazy_mode \
--use_hpu_graphs_for_inference \
- --throughput_warmup_steps 3
+ --throughput_warmup_steps 3 \
+ --bf16
```
@@ -79,7 +80,8 @@ python ../gaudi_spawn.py \
--use_habana \
--use_lazy_mode \
--use_hpu_graphs_for_inference \
- --throughput_warmup_steps 3
+ --throughput_warmup_steps 3 \
+ --bf16
```
@@ -148,7 +150,8 @@ python run_qa.py \
--output_dir /tmp/squad/ \
--use_habana \
--use_lazy_mode \
- --use_hpu_graphs_for_inference
+ --use_hpu_graphs_for_inference \
+ --bf16
```
@@ -198,7 +201,8 @@ python run_seq2seq_qa.py \
--ignore_pad_token_for_loss False \
--pad_to_max_length \
--save_strategy epoch \
- --throughput_warmup_steps 3
+ --throughput_warmup_steps 3 \
+ --bf16
```
For multi-card and DeepSpeed runs, you can use `python ../gaudi_spawn.py --world_size 8 --use_mpi` and `python ../gaudi_spawn.py --world_size 8 --use_deepspeed` as shown in the previous sections.
diff --git a/examples/question-answering/run_qa.py b/examples/question-answering/run_qa.py
index 729269c798..23f1712ec5 100644
--- a/examples/question-answering/run_qa.py
+++ b/examples/question-answering/run_qa.py
@@ -292,7 +292,7 @@ def main():
)
# Log on each process the small summary:
- mixed_precision = training_args.bf16 or gaudi_config.use_torch_autocast or gaudi_config.use_habana_mixed_precision
+ mixed_precision = training_args.bf16 or gaudi_config.use_torch_autocast
logger.warning(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, "
+ f"distributed training: {training_args.parallel_mode.value == 'distributed'}, "
diff --git a/examples/question-answering/run_seq2seq_qa.py b/examples/question-answering/run_seq2seq_qa.py
index 994a3ae216..8982e84825 100644
--- a/examples/question-answering/run_seq2seq_qa.py
+++ b/examples/question-answering/run_seq2seq_qa.py
@@ -338,7 +338,7 @@ def main():
)
# Log on each process the small summary:
- mixed_precision = training_args.bf16 or gaudi_config.use_torch_autocast or gaudi_config.use_habana_mixed_precision
+ mixed_precision = training_args.bf16 or gaudi_config.use_torch_autocast
logger.warning(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, "
+ f"distributed training: {training_args.parallel_mode.value == 'distributed'}, "
diff --git a/examples/speech-recognition/README.md b/examples/speech-recognition/README.md
index 11f65a88e4..78a349f49e 100644
--- a/examples/speech-recognition/README.md
+++ b/examples/speech-recognition/README.md
@@ -67,7 +67,8 @@ python run_speech_recognition_ctc.py \
--use_lazy_mode \
--use_hpu_graphs_for_inference \
--gaudi_config_name="Habana/wav2vec2" \
- --throughput_warmup_steps="3"
+ --throughput_warmup_steps="3" \
+ --bf16
```
On a single HPU, this script should run in *ca.* 6 hours and yield a CTC loss of **0.059** and a word error rate of **0.0423**.
@@ -106,7 +107,8 @@ python ../gaudi_spawn.py \
--use_lazy_mode \
--use_hpu_graphs_for_inference \
--gaudi_config_name Habana/wav2vec2 \
- --throughput_warmup_steps 3
+ --throughput_warmup_steps 3 \
+ --bf16
```
On 8 HPUs, this script should run in *ca.* 49 minutes and yield a CTC loss of **0.0613** and a word error rate of **0.0458**.
@@ -185,5 +187,6 @@ python run_speech_recognition_ctc.py \
--do_eval \
--use_habana \
--use_lazy_mode \
- --gaudi_config_name="Habana/wav2vec2"
+ --gaudi_config_name="Habana/wav2vec2" \
+ --bf16
```
diff --git a/examples/speech-recognition/run_speech_recognition_ctc.py b/examples/speech-recognition/run_speech_recognition_ctc.py
index 497009d4c9..1e6da0e2a6 100644
--- a/examples/speech-recognition/run_speech_recognition_ctc.py
+++ b/examples/speech-recognition/run_speech_recognition_ctc.py
@@ -444,7 +444,7 @@ def main():
)
# Log on each process the small summary:
- mixed_precision = training_args.bf16 or gaudi_config.use_torch_autocast or gaudi_config.use_habana_mixed_precision
+ mixed_precision = training_args.bf16 or gaudi_config.use_torch_autocast
logger.warning(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, "
+ f"distributed training: {training_args.parallel_mode.value == 'distributed'}, "
diff --git a/examples/summarization/README.md b/examples/summarization/README.md
index abf9317047..8ebed989ed 100644
--- a/examples/summarization/README.md
+++ b/examples/summarization/README.md
@@ -47,7 +47,8 @@ python run_summarization.py \
--ignore_pad_token_for_loss False \
--pad_to_max_length \
--save_strategy epoch \
- --throughput_warmup_steps 3
+ --throughput_warmup_steps 3 \
+ --bf16
```
Only T5 models `t5-small`, `t5-base`, `t5-large`, `t5-3b` and `t5-11b` must use an additional argument: `--source_prefix "summarize: "`.
@@ -78,7 +79,8 @@ python run_summarization.py \
--gaudi_config_name Habana/t5 \
--ignore_pad_token_for_loss False \
--pad_to_max_length \
- --throughput_warmup_steps 3
+ --throughput_warmup_steps 3 \
+ --bf16
```
The task of summarization also supports custom CSV and JSONLINES formats.
@@ -163,7 +165,8 @@ python ../gaudi_spawn.py \
--ignore_pad_token_for_loss False \
--pad_to_max_length \
--save_strategy epoch \
- --throughput_warmup_steps 3
+ --throughput_warmup_steps 3 \
+ --bf16
```
@@ -223,7 +226,8 @@ python run_summarization.py \
--use_hpu_graphs_for_inference \
--gaudi_config_name Habana/t5 \
--ignore_pad_token_for_loss False \
- --pad_to_max_length
+ --pad_to_max_length \
+ --bf16
```
You can run inference with BART on the CNN-DailyMail dataset on 1 Gaudi card with the following command:
diff --git a/examples/summarization/run_summarization.py b/examples/summarization/run_summarization.py
index 5e74e38ccd..3d5b220aa7 100644
--- a/examples/summarization/run_summarization.py
+++ b/examples/summarization/run_summarization.py
@@ -389,7 +389,7 @@ def main():
)
# Log on each process the small summary:
- mixed_precision = training_args.bf16 or gaudi_config.use_torch_autocast or gaudi_config.use_habana_mixed_precision
+ mixed_precision = training_args.bf16 or gaudi_config.use_torch_autocast
logger.warning(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, "
+ f"distributed training: {training_args.parallel_mode.value == 'distributed'}, "
diff --git a/examples/text-classification/README.md b/examples/text-classification/README.md
index 012f0d5803..57119ca825 100644
--- a/examples/text-classification/README.md
+++ b/examples/text-classification/README.md
@@ -53,7 +53,8 @@ python run_glue.py \
--use_habana \
--use_lazy_mode \
--use_hpu_graphs_for_inference \
- --throughput_warmup_steps 3
+ --throughput_warmup_steps 3 \
+ --bf16
```
> If your model classification head dimensions do not fit the number of labels in the dataset, you can specify `--ignore_mismatched_sizes` to adapt it.
@@ -80,7 +81,8 @@ python ../gaudi_spawn.py \
--use_habana \
--use_lazy_mode \
--use_hpu_graphs_for_inference \
- --throughput_warmup_steps 3
+ --throughput_warmup_steps 3 \
+ --bf16
```
> If your model classification head dimensions do not fit the number of labels in the dataset, you can specify `--ignore_mismatched_sizes` to adapt it.
@@ -150,5 +152,6 @@ python run_glue.py \
--output_dir ./output/mrpc/ \
--use_habana \
--use_lazy_mode \
- --use_hpu_graphs_for_inference
+ --use_hpu_graphs_for_inference \
+ --bf16
```
diff --git a/examples/text-classification/run_glue.py b/examples/text-classification/run_glue.py
index 579c2c2662..c60fe4c990 100755
--- a/examples/text-classification/run_glue.py
+++ b/examples/text-classification/run_glue.py
@@ -280,7 +280,7 @@ def main():
)
# Log on each process the small summary:
- mixed_precision = training_args.bf16 or gaudi_config.use_torch_autocast or gaudi_config.use_habana_mixed_precision
+ mixed_precision = training_args.bf16 or gaudi_config.use_torch_autocast
logger.warning(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, "
+ f"distributed training: {training_args.parallel_mode.value == 'distributed'}, "
diff --git a/examples/translation/README.md b/examples/translation/README.md
index b429104a98..6a24ad151a 100644
--- a/examples/translation/README.md
+++ b/examples/translation/README.md
@@ -49,7 +49,8 @@ python run_translation.py \
--ignore_pad_token_for_loss False \
--pad_to_max_length \
--save_strategy epoch \
- --throughput_warmup_steps 3
+ --throughput_warmup_steps 3 \
+ --bf16
```
If you get a terrible BLEU score, make sure that you didn't forget to use the `--source_prefix` argument.
@@ -84,7 +85,8 @@ python run_translation.py \
--gaudi_config_name Habana/t5 \
--ignore_pad_token_for_loss False \
--pad_to_max_length \
- --throughput_warmup_steps 3
+ --throughput_warmup_steps 3 \
+ --bf16
```
The task of translation supports only custom JSONLINES files, with each line being a dictionary with the key `"translation"` and its value another dictionary whose keys is the language pair. For example:
@@ -117,7 +119,8 @@ python run_translation.py \
--gaudi_config_name Habana/t5 \
--ignore_pad_token_for_loss False \
--pad_to_max_length \
- --throughput_warmup_steps 3
+ --throughput_warmup_steps 3 \
+ --bf16
```
@@ -148,7 +151,8 @@ python ../gaudi_spawn.py \
--ignore_pad_token_for_loss False \
--pad_to_max_length \
--save_strategy epoch \
- --throughput_warmup_steps 3
+ --throughput_warmup_steps 3 \
+ --bf16
```
@@ -228,5 +232,6 @@ python run_translation.py \
--use_hpu_graphs_for_inference \
--gaudi_config_name Habana/t5 \
--ignore_pad_token_for_loss False \
- --pad_to_max_length
+ --pad_to_max_length \
+ --bf16
```
diff --git a/examples/translation/run_translation.py b/examples/translation/run_translation.py
index d5b6bbdcd4..549d1f92a2 100644
--- a/examples/translation/run_translation.py
+++ b/examples/translation/run_translation.py
@@ -334,7 +334,7 @@ def main():
)
# Log on each process the small summary:
- mixed_precision = training_args.bf16 or gaudi_config.use_torch_autocast or gaudi_config.use_habana_mixed_precision
+ mixed_precision = training_args.bf16 or gaudi_config.use_torch_autocast
logger.warning(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, "
+ f"distributed training: {training_args.parallel_mode.value == 'distributed'}, "
diff --git a/optimum/habana/diffusers/pipelines/pipeline_utils.py b/optimum/habana/diffusers/pipelines/pipeline_utils.py
index 27724d0fd5..83a3a04116 100644
--- a/optimum/habana/diffusers/pipelines/pipeline_utils.py
+++ b/optimum/habana/diffusers/pipelines/pipeline_utils.py
@@ -19,7 +19,6 @@
import inspect
import os
import sys
-import tempfile
from typing import Optional, Union
import torch
@@ -109,47 +108,15 @@ def __init__(
f"`gaudi_config` must be a string or a GaudiConfig object but is {type(gaudi_config)}."
)
- if self.gaudi_config.use_habana_mixed_precision or self.gaudi_config.use_torch_autocast:
+ if self.gaudi_config.use_torch_autocast:
if bf16_full_eval:
logger.warning(
- "`use_habana_mixed_precision` or `use_torch_autocast` is True in the given Gaudi configuration but "
+ "`use_torch_autocast` is True in the given Gaudi configuration but "
"`torch_dtype=torch.blfloat16` was given. Disabling mixed precision and continuing in bf16 only."
)
self.gaudi_config.use_torch_autocast = False
- self.gaudi_config.use_habana_mixed_precision = False
- elif self.gaudi_config.use_torch_autocast:
- # Open temporary files to write mixed-precision ops
- with tempfile.NamedTemporaryFile() as hmp_bf16_file:
- with tempfile.NamedTemporaryFile() as hmp_fp32_file:
- self.gaudi_config.write_bf16_fp32_ops_to_text_files(
- hmp_bf16_file.name,
- hmp_fp32_file.name,
- )
- os.environ["LOWER_LIST"] = str(hmp_bf16_file)
- os.environ["FP32_LIST"] = str(hmp_fp32_file)
-
- import habana_frameworks.torch.core # noqa
- elif self.gaudi_config.use_habana_mixed_precision:
- try:
- from habana_frameworks.torch.hpex import hmp
- except ImportError as error:
- error.msg = f"Could not import habana_frameworks.torch.hpex. {error.msg}."
- raise error
-
- # Open temporary files to write mixed-precision ops
- with tempfile.NamedTemporaryFile() as hmp_bf16_file:
- with tempfile.NamedTemporaryFile() as hmp_fp32_file:
- # hmp.convert needs ops to be written in text files
- self.gaudi_config.write_bf16_fp32_ops_to_text_files(
- hmp_bf16_file.name,
- hmp_fp32_file.name,
- )
- hmp.convert(
- opt_level=self.gaudi_config.hmp_opt_level,
- bf16_file_path=hmp_bf16_file.name,
- fp32_file_path=hmp_fp32_file.name,
- isVerbose=self.gaudi_config.hmp_is_verbose,
- )
+ else:
+ self.gaudi_config.declare_autocast_bf16_fp32_ops()
# Workaround for Synapse 1.11 for full bf16 and Torch Autocast
if bf16_full_eval or self.gaudi_config.use_torch_autocast:
diff --git a/optimum/habana/transformers/gaudi_configuration.py b/optimum/habana/transformers/gaudi_configuration.py
index ca4665491d..76638d8e95 100644
--- a/optimum/habana/transformers/gaudi_configuration.py
+++ b/optimum/habana/transformers/gaudi_configuration.py
@@ -15,7 +15,6 @@
import os
import sys
-import warnings
from pathlib import Path
from optimum.configuration_utils import BaseConfig
@@ -54,45 +53,27 @@ class GaudiConfig(BaseConfig):
FULL_CONFIGURATION_FILE = "gaudi_config.json"
def __init__(self, **kwargs):
- # Habana Mixed Precision (MHP) configuration
- self.use_habana_mixed_precision = kwargs.pop("use_habana_mixed_precision", False)
- self.hmp_bf16_ops = kwargs.pop("hmp_bf16_ops", DEFAULT_BF16_OPS)
- self.hmp_fp32_ops = kwargs.pop("hmp_fp32_ops", DEFAULT_FP32_OPS)
- self.hmp_is_verbose = kwargs.pop("hmp_is_verbose", False)
# Torch Autocast
self.use_torch_autocast = kwargs.pop("use_torch_autocast", False)
self.autocast_bf16_ops = kwargs.pop("autocast_bf16_ops", None)
self.autocast_fp32_ops = kwargs.pop("autocast_fp32_ops", None)
self.use_dynamic_shapes = kwargs.pop("use_dynamic_shapes", False)
- if self.use_habana_mixed_precision and self.use_torch_autocast:
- raise ValueError(
- "`use_habana_mixed_precision` and `use_torch_autocast` cannot be both `True` in your Gaudi configuration, you must choose one or the other to perform mixed-precision training."
- )
-
# Use Habana's custom AdamW implementation
self.use_fused_adam = kwargs.pop("use_fused_adam", False)
# Use Habana's custom fused clip norm implementation
self.use_fused_clip_norm = kwargs.pop("use_fused_clip_norm", False)
# TODO: to remove in a future version
- if "hmp_opt_level" in kwargs:
- warnings.warn(
- "`hmp_opt_level` is deprecated and will be removed in a future version.",
- FutureWarning,
- )
- self.hmp_opt_level = kwargs.pop("hmp_opt_level", "O1")
def write_bf16_fp32_ops_to_text_files(
self,
path_to_bf16_file: Path,
path_to_fp32_file: Path,
- autocast: bool = False,
):
- bf16_ops = self.autocast_bf16_ops if autocast else self.hmp_bf16_ops
- fp32_ops = self.autocast_fp32_ops if autocast else self.hmp_fp32_ops
-
- for path, ops in zip([Path(path_to_bf16_file), Path(path_to_fp32_file)], [bf16_ops, fp32_ops]):
+ for path, ops in zip(
+ [Path(path_to_bf16_file), Path(path_to_fp32_file)], [self.autocast_bf16_ops, self.autocast_fp32_ops]
+ ):
with path.open("w") as text_file:
# writelines does not add new lines after each element so "\n" is inserted
text_file.writelines(op + "\n" for op in ops)
@@ -111,7 +92,6 @@ def declare_autocast_bf16_fp32_ops(self):
self.write_bf16_fp32_ops_to_text_files(
autocast_bf16_filename,
autocast_fp32_filename,
- autocast=True,
)
os.environ["LOWER_LIST"] = autocast_bf16_filename
os.environ["FP32_LIST"] = autocast_fp32_filename
diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py
index aa1610fb09..480e329cfb 100644
--- a/optimum/habana/transformers/modeling_utils.py
+++ b/optimum/habana/transformers/modeling_utils.py
@@ -187,7 +187,7 @@ def adapt_transformers_to_gaudi():
transformers.models.codegen.modeling_codegen.CodeGenBlock.forward = gaudi_codegen_block_forward
# Replace invert_attention_mask and get_extended_attention_mask
- # so that HMP is disabled for specific parts of the code
+ # so that Torch Autocast is disabled for specific parts of the code
transformers.modeling_utils.ModuleUtilsMixin.invert_attention_mask = gaudi_invert_attention_mask
transformers.modeling_utils.ModuleUtilsMixin.get_extended_attention_mask = gaudi_get_extended_attention_mask
# AlbertModel.forward does not rely on get_extended_attention_mask so it also needs to be replaced
diff --git a/optimum/habana/transformers/models/albert/modeling_albert.py b/optimum/habana/transformers/models/albert/modeling_albert.py
index 994afcff91..6ac9b80073 100644
--- a/optimum/habana/transformers/models/albert/modeling_albert.py
+++ b/optimum/habana/transformers/models/albert/modeling_albert.py
@@ -34,7 +34,7 @@ def gaudi_albert_forward(
) -> Union[BaseModelOutputWithPooling, Tuple]:
"""
Same as https://github.com/huggingface/transformers/blob/a9eee2ffecc874df7dd635b2c6abb246fdb318cc/src/transformers/models/albert/modeling_albert.py#L689
- except that HMP is disabled for computing:
+ except that mixed precision is disabled for computing:
extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
diff --git a/optimum/habana/transformers/models/gpt2/modeling_gpt2.py b/optimum/habana/transformers/models/gpt2/modeling_gpt2.py
index d5a26ed0d3..7c0d6ca451 100644
--- a/optimum/habana/transformers/models/gpt2/modeling_gpt2.py
+++ b/optimum/habana/transformers/models/gpt2/modeling_gpt2.py
@@ -355,7 +355,7 @@ def gaudi_gpt2_forward(
"""
Copied from GPT2Model.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py
The only differences are:
- - disable HMP cast for attention_mask
+ - disable autocast for attention_mask
- add new args token_idx
"""
@@ -414,9 +414,7 @@ def gaudi_gpt2_forward(
# effectively the same as removing these entirely.
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
- from habana_frameworks.torch.hpex import hmp
-
- with hmp.disable_casts(), torch.autocast(enabled=False, device_type="hpu"):
+ with torch.autocast(enabled=False, device_type="hpu"):
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
# If a 2D or 3D attention mask is provided for the cross-attention
diff --git a/optimum/habana/transformers/models/modeling_all_models.py b/optimum/habana/transformers/models/modeling_all_models.py
index 7be8216d0a..3647ac8df8 100644
--- a/optimum/habana/transformers/models/modeling_all_models.py
+++ b/optimum/habana/transformers/models/modeling_all_models.py
@@ -24,7 +24,7 @@
def gaudi_invert_attention_mask(self, encoder_attention_mask: torch.Tensor) -> torch.Tensor:
"""
Same as https://github.com/huggingface/transformers/blob/a9eee2ffecc874df7dd635b2c6abb246fdb318cc/src/transformers/modeling_utils.py#L640
- except that HMP is disabled for computing:
+ except that mixed precision is disabled for computing:
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * torch.finfo(self.dtype).min
"""
if encoder_attention_mask.dim() == 3:
@@ -51,7 +51,7 @@ def gaudi_get_extended_attention_mask(
) -> torch.Tensor:
"""
Same as https://github.com/huggingface/transformers/blob/a9eee2ffecc874df7dd635b2c6abb246fdb318cc/src/transformers/modeling_utils.py#L692
- except that HMP is disabled for computing:
+ except that mixed precision is disabled for computing:
extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(dtype).min
"""
if dtype is None:
diff --git a/optimum/habana/transformers/models/opt/modeling_opt.py b/optimum/habana/transformers/models/opt/modeling_opt.py
index f4837e056b..bf6a87133b 100644
--- a/optimum/habana/transformers/models/opt/modeling_opt.py
+++ b/optimum/habana/transformers/models/opt/modeling_opt.py
@@ -46,7 +46,6 @@ def gaudi_opt_attention_forward(
Copied from OPTAttention.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/opt/modeling_opt.py
The only differences are:
- add new args token_idx
- - disable HMP for attention softmax
- optimize KV cache
"""
# if key_value_states are provided this layer is used as a cross-attention layer
@@ -118,10 +117,7 @@ def gaudi_opt_attention_forward(
)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
- from habana_frameworks.torch.hpex import hmp
-
- with hmp.disable_casts():
- attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
+ attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,):
diff --git a/optimum/habana/transformers/trainer.py b/optimum/habana/transformers/trainer.py
index ee3293a774..305d019f97 100644
--- a/optimum/habana/transformers/trainer.py
+++ b/optimum/habana/transformers/trainer.py
@@ -20,7 +20,6 @@
import random
import shutil
import sys
-import tempfile
import time
import warnings
from collections.abc import Mapping
@@ -195,7 +194,6 @@ def __init__(
if self.args.deepspeed:
# Mixed-precision backends are turned off when using DeepSpeed since it manages this itself
- self.gaudi_config.use_habana_mixed_precision = False
self.gaudi_config.use_torch_autocast = False
self.use_hpu_amp = False
@@ -209,38 +207,10 @@ def __init__(
logger.warning(
"The argument `--bf16` was not given but `use_torch_autocast` is True in the Gaudi configuration so mixed-precision training with Torch Autocast is enabled."
)
- elif self.gaudi_config.use_habana_mixed_precision and self.use_hpu_amp:
- self.gaudi_config.use_habana_mixed_precision = False
- logger.warning(
- "`--bf16` was given and `use_habana_mixed_precision` is True in the Gaudi configuration. Using Torch Autocast as mixed-precision backend."
- )
if self.use_hpu_amp and "LOWER_LIST" not in os.environ:
gaudi_config.declare_autocast_bf16_fp32_ops()
- if self.gaudi_config.use_habana_mixed_precision and not (self.use_hpu_amp or self.use_cpu_amp):
- try:
- from habana_frameworks.torch.hpex import hmp
- except ImportError as error:
- error.msg = f"Could not import habana_frameworks.torch.hpex. {error.msg}."
- raise error
- self.hmp = hmp
-
- # Open temporary files to mixed-precision write ops
- with tempfile.NamedTemporaryFile() as hmp_bf16_file:
- with tempfile.NamedTemporaryFile() as hmp_fp32_file:
- # hmp.convert needs ops to be written in text files
- self.gaudi_config.write_bf16_fp32_ops_to_text_files(
- hmp_bf16_file.name,
- hmp_fp32_file.name,
- )
- self.hmp.convert(
- opt_level=self.gaudi_config.hmp_opt_level,
- bf16_file_path=hmp_bf16_file.name,
- fp32_file_path=hmp_fp32_file.name,
- isVerbose=self.gaudi_config.hmp_is_verbose,
- )
-
if self.args.use_lazy_mode:
try:
import habana_frameworks.torch.core as htcore
@@ -922,31 +892,14 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args):
self.FusedNorm.clip_norm(model.parameters())
else:
# Revert to normal clipping otherwise
- if (
- args.use_habana
- and (not (self.use_hpu_amp or self.use_cpu_amp))
- and self.gaudi_config.use_habana_mixed_precision
- ):
- with self.hmp.disable_casts():
- torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
- else:
- self.accelerator.clip_grad_norm_(
- model.parameters(),
- args.max_grad_norm,
- )
+ self.accelerator.clip_grad_norm_(
+ model.parameters(),
+ args.max_grad_norm,
+ )
# Optimizer step
optimizer_was_run = True
- if (
- args.use_habana
- and self.gaudi_config.use_habana_mixed_precision
- and (not self.gaudi_config.use_fused_adam)
- and (not (self.use_hpu_amp or self.use_cpu_amp))
- ):
- with self.hmp.disable_casts():
- self.optimizer.step()
- else:
- self.optimizer.step()
+ self.optimizer.step()
optimizer_was_run = not self.accelerator.optimizer_step_was_skipped
if optimizer_was_run:
diff --git a/tests/configs/gaudi_config_trainer_test.json b/tests/configs/gaudi_config_trainer_test.json
index dda29507f4..c9fbc2a8c3 100644
--- a/tests/configs/gaudi_config_trainer_test.json
+++ b/tests/configs/gaudi_config_trainer_test.json
@@ -1,12 +1,4 @@
{
- "use_habana_mixed_precision": false,
- "hmp_is_verbose": false,
"use_fused_adam": true,
- "use_fused_clip_norm": true,
- "hmp_bf16_ops": [
- "add",
- "mse_loss",
- "mul"
- ],
- "hmp_fp32_ops": []
+ "use_fused_clip_norm": true
}
\ No newline at end of file
diff --git a/tests/test_diffusers.py b/tests/test_diffusers.py
index 84ca30d4c1..4088081542 100644
--- a/tests/test_diffusers.py
+++ b/tests/test_diffusers.py
@@ -24,7 +24,6 @@
import requests
import torch
from diffusers import AutoencoderKL, UNet2DConditionModel
-from habana_frameworks.torch.hpex import hmp
from parameterized import parameterized
from PIL import Image
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
@@ -42,11 +41,9 @@
if os.environ.get("GAUDI2_CI", "0") == "1":
- THROUGHPUT_BASELINE_HMP = 0.981
THROUGHPUT_BASELINE_BF16 = 1.019
THROUGHPUT_BASELINE_AUTOCAST = 0.389
else:
- THROUGHPUT_BASELINE_HMP = 0.298
THROUGHPUT_BASELINE_BF16 = 0.309
THROUGHPUT_BASELINE_AUTOCAST = 0.111
@@ -204,32 +201,31 @@ def get_dummy_inputs(self, device, seed=0):
return inputs
def test_stable_diffusion_ddim(self):
- with hmp.disable_casts():
- device = "cpu"
+ device = "cpu"
- components = self.get_dummy_components()
- gaudi_config = GaudiConfig(use_habana_mixed_precision=False)
+ components = self.get_dummy_components()
+ gaudi_config = GaudiConfig(use_torch_autocast=False)
- sd_pipe = GaudiStableDiffusionPipeline(
- use_habana=True,
- gaudi_config=gaudi_config,
- **components,
- )
- sd_pipe.set_progress_bar_config(disable=None)
+ sd_pipe = GaudiStableDiffusionPipeline(
+ use_habana=True,
+ gaudi_config=gaudi_config,
+ **components,
+ )
+ sd_pipe.set_progress_bar_config(disable=None)
- inputs = self.get_dummy_inputs(device)
- output = sd_pipe(**inputs)
- image = output.images[0]
+ inputs = self.get_dummy_inputs(device)
+ output = sd_pipe(**inputs)
+ image = output.images[0]
- image_slice = image[-3:, -3:, -1]
+ image_slice = image[-3:, -3:, -1]
- self.assertEqual(image.shape, (64, 64, 3))
- expected_slice = np.array([0.3203, 0.4555, 0.4711, 0.3505, 0.3973, 0.4650, 0.5137, 0.3392, 0.4045])
+ self.assertEqual(image.shape, (64, 64, 3))
+ expected_slice = np.array([0.3203, 0.4555, 0.4711, 0.3505, 0.3973, 0.4650, 0.5137, 0.3392, 0.4045])
- self.assertLess(np.abs(image_slice.flatten() - expected_slice).max(), 1e-2)
+ self.assertLess(np.abs(image_slice.flatten() - expected_slice).max(), 1e-2)
def test_stable_diffusion_no_safety_checker(self):
- gaudi_config = GaudiConfig(use_habana_mixed_precision=False)
+ gaudi_config = GaudiConfig()
scheduler = GaudiDDIMScheduler(
beta_start=0.00085,
beta_end=0.012,
@@ -516,33 +512,6 @@ def test_stable_diffusion_hpu_graphs(self):
self.assertEqual(len(images), 10)
self.assertEqual(images[-1].shape, (64, 64, 3))
- @slow
- def test_no_throughput_regression_hmp(self):
- prompts = [
- "An image of a squirrel in Picasso style",
- "High quality photo of an astronaut riding a horse in space",
- ]
- num_images_per_prompt = 11
- batch_size = 4
- model_name = "runwayml/stable-diffusion-v1-5"
- scheduler = GaudiDDIMScheduler.from_pretrained(model_name, subfolder="scheduler")
-
- pipeline = GaudiStableDiffusionPipeline.from_pretrained(
- model_name,
- scheduler=scheduler,
- use_habana=True,
- use_hpu_graphs=True,
- gaudi_config=GaudiConfig.from_pretrained("Habana/stable-diffusion"),
- )
- set_seed(27)
- outputs = pipeline(
- prompt=prompts,
- num_images_per_prompt=num_images_per_prompt,
- batch_size=batch_size,
- )
- self.assertEqual(len(outputs.images), num_images_per_prompt * len(prompts))
- self.assertGreaterEqual(outputs.throughput, 0.95 * THROUGHPUT_BASELINE_HMP)
-
@slow
def test_no_throughput_regression_bf16(self):
prompts = [
@@ -604,163 +573,160 @@ def test_no_throughput_regression_autocast(self):
def test_no_generation_regression(self):
model_name = "CompVis/stable-diffusion-v1-4"
# fp32
- with hmp.disable_casts():
- scheduler = GaudiDDIMScheduler.from_pretrained(model_name, subfolder="scheduler")
- pipeline = GaudiStableDiffusionPipeline.from_pretrained(
- model_name,
- scheduler=scheduler,
- safety_checker=None,
- use_habana=True,
- use_hpu_graphs=True,
- gaudi_config=GaudiConfig(use_habana_mixed_precision=False),
+ scheduler = GaudiDDIMScheduler.from_pretrained(model_name, subfolder="scheduler")
+ pipeline = GaudiStableDiffusionPipeline.from_pretrained(
+ model_name,
+ scheduler=scheduler,
+ safety_checker=None,
+ use_habana=True,
+ use_hpu_graphs=True,
+ gaudi_config=GaudiConfig(use_torch_autocast=False),
+ )
+ set_seed(27)
+ outputs = pipeline(
+ prompt="An image of a squirrel in Picasso style",
+ output_type="np",
+ )
+
+ if os.environ.get("GAUDI2_CI", "0") == "1":
+ expected_slice = np.array(
+ [
+ 0.350823,
+ 0.34849027,
+ 0.33486015,
+ 0.35479546,
+ 0.3231264,
+ 0.33130097,
+ 0.34374988,
+ 0.30728853,
+ 0.30011398,
+ ]
)
- set_seed(27)
- outputs = pipeline(
- prompt="An image of a squirrel in Picasso style",
- output_type="np",
+ else:
+ expected_slice = np.array(
+ [0.70760196, 0.7136303, 0.7000798, 0.714934, 0.6776865, 0.6800843, 0.6923707, 0.6653969, 0.6408076]
)
+ image = outputs.images[0]
- if os.environ.get("GAUDI2_CI", "0") == "1":
- expected_slice = np.array(
- [
- 0.350823,
- 0.34849027,
- 0.33486015,
- 0.35479546,
- 0.3231264,
- 0.33130097,
- 0.34374988,
- 0.30728853,
- 0.30011398,
- ]
- )
- else:
- expected_slice = np.array(
- [0.70760196, 0.7136303, 0.7000798, 0.714934, 0.6776865, 0.6800843, 0.6923707, 0.6653969, 0.6408076]
- )
- image = outputs.images[0]
-
- self.assertEqual(image.shape, (512, 512, 3))
- self.assertLess(np.abs(expected_slice - image[-3:, -3:, -1].flatten()).max(), 5e-3)
+ self.assertEqual(image.shape, (512, 512, 3))
+ self.assertLess(np.abs(expected_slice - image[-3:, -3:, -1].flatten()).max(), 5e-3)
@slow
def test_no_generation_regression_ldm3d(self):
model_name = "Intel/ldm3d-4c"
# fp32
- with hmp.disable_casts():
- scheduler = GaudiDDIMScheduler.from_pretrained(model_name, subfolder="scheduler")
- pipeline = GaudiStableDiffusionLDM3DPipeline.from_pretrained(
- model_name,
- scheduler=scheduler,
- safety_checker=None,
- use_habana=True,
- use_hpu_graphs=True,
- gaudi_config=GaudiConfig(use_habana_mixed_precision=False),
+ scheduler = GaudiDDIMScheduler.from_pretrained(model_name, subfolder="scheduler")
+ pipeline = GaudiStableDiffusionLDM3DPipeline.from_pretrained(
+ model_name,
+ scheduler=scheduler,
+ safety_checker=None,
+ use_habana=True,
+ use_hpu_graphs=True,
+ gaudi_config=GaudiConfig(),
+ )
+ set_seed(27)
+ outputs = pipeline(
+ prompt="An image of a squirrel in Picasso style",
+ output_type="np",
+ )
+
+ if os.environ.get("GAUDI2_CI", "0") == "1":
+ expected_slice_rgb = np.array(
+ [
+ 0.2099357,
+ 0.16664368,
+ 0.08352646,
+ 0.20643419,
+ 0.16748399,
+ 0.08781305,
+ 0.21379063,
+ 0.19943115,
+ 0.04389626,
+ ]
+ )
+ expected_slice_depth = np.array(
+ [
+ 0.68369114,
+ 0.6827824,
+ 0.6852779,
+ 0.6836072,
+ 0.6888298,
+ 0.6895473,
+ 0.6853674,
+ 0.67561126,
+ 0.660434,
+ ]
)
- set_seed(27)
- outputs = pipeline(
- prompt="An image of a squirrel in Picasso style",
- output_type="np",
+ else:
+ expected_slice_rgb = np.array([0.7083766, 1.0, 1.0, 0.70610344, 0.9867363, 1.0, 0.7214538, 1.0, 1.0])
+ expected_slice_depth = np.array(
+ [
+ 0.919621,
+ 0.92072034,
+ 0.9184986,
+ 0.91994286,
+ 0.9242079,
+ 0.93387043,
+ 0.92345214,
+ 0.93558526,
+ 0.9223714,
+ ]
)
+ rgb = outputs.rgb[0]
+ depth = outputs.depth[0]
- if os.environ.get("GAUDI2_CI", "0") == "1":
- expected_slice_rgb = np.array(
- [
- 0.2099357,
- 0.16664368,
- 0.08352646,
- 0.20643419,
- 0.16748399,
- 0.08781305,
- 0.21379063,
- 0.19943115,
- 0.04389626,
- ]
- )
- expected_slice_depth = np.array(
- [
- 0.68369114,
- 0.6827824,
- 0.6852779,
- 0.6836072,
- 0.6888298,
- 0.6895473,
- 0.6853674,
- 0.67561126,
- 0.660434,
- ]
- )
- else:
- expected_slice_rgb = np.array([0.7083766, 1.0, 1.0, 0.70610344, 0.9867363, 1.0, 0.7214538, 1.0, 1.0])
- expected_slice_depth = np.array(
- [
- 0.919621,
- 0.92072034,
- 0.9184986,
- 0.91994286,
- 0.9242079,
- 0.93387043,
- 0.92345214,
- 0.93558526,
- 0.9223714,
- ]
- )
- rgb = outputs.rgb[0]
- depth = outputs.depth[0]
-
- self.assertEqual(rgb.shape, (512, 512, 3))
- self.assertEqual(depth.shape, (512, 512, 1))
- self.assertLess(np.abs(expected_slice_rgb - rgb[-3:, -3:, -1].flatten()).max(), 5e-3)
- self.assertLess(np.abs(expected_slice_depth - depth[-3:, -3:, -1].flatten()).max(), 5e-3)
+ self.assertEqual(rgb.shape, (512, 512, 3))
+ self.assertEqual(depth.shape, (512, 512, 1))
+ self.assertLess(np.abs(expected_slice_rgb - rgb[-3:, -3:, -1].flatten()).max(), 5e-3)
+ self.assertLess(np.abs(expected_slice_depth - depth[-3:, -3:, -1].flatten()).max(), 5e-3)
@slow
def test_no_generation_regression_upscale(self):
model_name = "stabilityai/stable-diffusion-x4-upscaler"
# fp32
- with hmp.disable_casts():
- scheduler = GaudiDDIMScheduler.from_pretrained(model_name, subfolder="scheduler")
- pipeline = GaudiStableDiffusionUpscalePipeline.from_pretrained(
- model_name,
- scheduler=scheduler,
- use_habana=True,
- use_hpu_graphs=True,
- gaudi_config=GaudiConfig(use_habana_mixed_precision=False),
+ scheduler = GaudiDDIMScheduler.from_pretrained(model_name, subfolder="scheduler")
+ pipeline = GaudiStableDiffusionUpscalePipeline.from_pretrained(
+ model_name,
+ scheduler=scheduler,
+ use_habana=True,
+ use_hpu_graphs=True,
+ gaudi_config=GaudiConfig(use_torch_autocast=False),
+ )
+ set_seed(27)
+
+ url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-upscale/low_res_cat.png"
+ response = requests.get(url)
+ low_res_img = Image.open(BytesIO(response.content)).convert("RGB")
+ low_res_img = low_res_img.resize((128, 128))
+ prompt = "a white cat"
+ upscaled_image = pipeline(prompt=prompt, image=low_res_img, output_type="np").images[0]
+ if os.environ.get("GAUDI2_CI", "0") == "1":
+ expected_slice = np.array(
+ [
+ 0.16527882,
+ 0.161616,
+ 0.15665859,
+ 0.1660901,
+ 0.1594379,
+ 0.14936888,
+ 0.1578255,
+ 0.15342498,
+ 0.14590919,
+ ]
+ )
+ else:
+ expected_slice = np.array(
+ [
+ 0.1652787,
+ 0.16161594,
+ 0.15665877,
+ 0.16608998,
+ 0.1594378,
+ 0.14936894,
+ 0.15782538,
+ 0.15342498,
+ 0.14590913,
+ ]
)
- set_seed(27)
-
- url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-upscale/low_res_cat.png"
- response = requests.get(url)
- low_res_img = Image.open(BytesIO(response.content)).convert("RGB")
- low_res_img = low_res_img.resize((128, 128))
- prompt = "a white cat"
- upscaled_image = pipeline(prompt=prompt, image=low_res_img, output_type="np").images[0]
- if os.environ.get("GAUDI2_CI", "0") == "1":
- expected_slice = np.array(
- [
- 0.16527882,
- 0.161616,
- 0.15665859,
- 0.1660901,
- 0.1594379,
- 0.14936888,
- 0.1578255,
- 0.15342498,
- 0.14590919,
- ]
- )
- else:
- expected_slice = np.array(
- [
- 0.1652787,
- 0.16161594,
- 0.15665877,
- 0.16608998,
- 0.1594378,
- 0.14936894,
- 0.15782538,
- 0.15342498,
- 0.14590913,
- ]
- )
- self.assertEqual(upscaled_image.shape, (512, 512, 3))
- self.assertLess(np.abs(expected_slice - upscaled_image[-3:, -3:, -1].flatten()).max(), 5e-3)
+ self.assertEqual(upscaled_image.shape, (512, 512, 3))
+ self.assertLess(np.abs(expected_slice - upscaled_image[-3:, -3:, -1].flatten()).max(), 5e-3)
diff --git a/tests/test_gaudi_configuration.py b/tests/test_gaudi_configuration.py
index 14e3e7838e..8dcb31c738 100644
--- a/tests/test_gaudi_configuration.py
+++ b/tests/test_gaudi_configuration.py
@@ -17,7 +17,6 @@
import tempfile
import unittest
from pathlib import Path
-from typing import List
from optimum.habana import GaudiConfig
@@ -26,22 +25,6 @@
FP32_OPS_REFERENCE_FILE = Path(__file__).parent.resolve() / Path("configs/fp32_ops.txt")
-def is_list_of_strings(my_list: List) -> bool:
- """
- This method assesses whether an object is a list of strings or not.
-
- Args:
- my_list (List): list to assess
-
- Returns:
- bool: whether the input argument is a list of strings or not
- """
- if my_list and isinstance(my_list, list):
- return all(isinstance(op, str) for op in my_list)
- else:
- return False
-
-
class GaudiConfigTester(unittest.TestCase):
"""
Unit tests for Gaudi configuration class GaudiConfig.
@@ -50,19 +33,37 @@ class GaudiConfigTester(unittest.TestCase):
def test_default_parameter_types(self):
gaudi_config = GaudiConfig()
- self.assertIsInstance(gaudi_config.use_habana_mixed_precision, bool)
- self.assertIsInstance(gaudi_config.hmp_is_verbose, bool)
self.assertIsInstance(gaudi_config.use_fused_adam, bool)
self.assertIsInstance(gaudi_config.use_fused_clip_norm, bool)
self.assertIsInstance(gaudi_config.use_torch_autocast, bool)
- self.assertTrue(is_list_of_strings(gaudi_config.hmp_bf16_ops))
- self.assertTrue(is_list_of_strings(gaudi_config.hmp_fp32_ops))
self.assertIsNone(gaudi_config.autocast_bf16_ops)
self.assertIsNone(gaudi_config.autocast_fp32_ops)
def test_write_bf16_fp32_ops_to_text_files(self):
- gaudi_config = GaudiConfig()
+ gaudi_config = GaudiConfig(
+ autocast_bf16_ops=[
+ "add",
+ "addmm",
+ "bmm",
+ "div",
+ "dropout",
+ "gelu",
+ "iadd",
+ "linear",
+ "layer_norm",
+ "matmul",
+ "mm",
+ "rsub",
+ "softmax",
+ "truediv",
+ ],
+ autocast_fp32_ops=[
+ "embedding",
+ "nll_loss",
+ "log_softmax",
+ ],
+ )
with tempfile.NamedTemporaryFile() as bf16_file:
with tempfile.NamedTemporaryFile() as fp32_file: