diff --git a/.github/workflows/fast_tests.yml b/.github/workflows/fast_tests.yml
index 24af938f01..acc42aff01 100644
--- a/.github/workflows/fast_tests.yml
+++ b/.github/workflows/fast_tests.yml
@@ -15,7 +15,8 @@ concurrency:
jobs:
transformers:
name: Run tests for optimum.habana.transformers
- runs-on: [self-hosted, linux, x64, gaudi-habana] # run the job on the newly created runner
+ runs-on:
+ group: aws-dl1-24xlarge
steps:
- name: Checkout
uses: actions/checkout@v2
@@ -39,7 +40,8 @@ jobs:
name: Run tests for optimum.habana.diffusers
needs:
- transformers # required to wait for the previous tests to finish
- runs-on: [self-hosted, linux, x64, gaudi-habana] # run the job on the newly created runner
+ runs-on:
+ group: aws-dl1-24xlarge
steps:
- name: Checkout
uses: actions/checkout@v2
diff --git a/.github/workflows/slow_tests.yml b/.github/workflows/slow_tests.yml
index 82914019e1..588ec3cac0 100644
--- a/.github/workflows/slow_tests.yml
+++ b/.github/workflows/slow_tests.yml
@@ -12,7 +12,8 @@ concurrency:
jobs:
example-diff:
name: Test examples differences
- runs-on: [self-hosted, linux, x64, gaudi-habana] # run the job on the newly created runner
+ runs-on:
+ group: aws-dl1-24xlarge
steps:
- name: Checkout
uses: actions/checkout@v2
@@ -37,7 +38,8 @@ jobs:
if: ${{ !cancelled() && (success() || failure()) }}
needs:
- example-diff # run the job when the previous test job is done
- runs-on: [self-hosted, linux, x64, gaudi-habana] # run the job on the newly created runner
+ runs-on:
+ group: aws-dl1-24xlarge
steps:
- name: Checkout
uses: actions/checkout@v2
@@ -63,7 +65,8 @@ jobs:
needs:
- example-diff
- stable-diffusion # run the job when the previous test job is done
- runs-on: [self-hosted, linux, x64, gaudi-habana] # run the job on the newly created runner
+ runs-on:
+ group: aws-dl1-24xlarge
steps:
- name: Checkout
uses: actions/checkout@v2
@@ -89,7 +92,8 @@ jobs:
needs:
- example-diff
- deepspeed # run the job when the previous test job is done
- runs-on: [self-hosted, linux, x64, gaudi-habana] # run the job on the newly created runner
+ runs-on:
+ group: aws-dl1-24xlarge
steps:
- name: Checkout
uses: actions/checkout@v2
@@ -116,7 +120,8 @@ jobs:
- example-diff
- deepspeed
- multi-card # run the job when the previous test jobs are done
- runs-on: [self-hosted, linux, x64, gaudi-habana] # run the job on the newly created runner
+ runs-on:
+ group: aws-dl1-24xlarge
steps:
- name: Checkout
uses: actions/checkout@v2
@@ -144,7 +149,8 @@ jobs:
- deepspeed
- multi-card
- single-card # run the job when the previous test jobs are done
- runs-on: [self-hosted, linux, x64, gaudi-habana] # run the job on the newly created runner
+ runs-on:
+ group: aws-dl1-24xlarge
steps:
- name: Checkout
if: github.event.schedule == '0 21 * * 6'
@@ -179,7 +185,8 @@ jobs:
- multi-card
- single-card
- albert-xxl-single-card # run the job when the previous test jobs are done
- runs-on: [self-hosted, linux, x64, gaudi-habana] # run the job on the newly created runner
+ runs-on:
+ group: aws-dl1-24xlarge
steps:
- name: Checkout
uses: actions/checkout@v2
@@ -209,7 +216,8 @@ jobs:
- single-card
- albert-xxl-single-card
- text-generation # run the job when the previous test jobs are done
- runs-on: [self-hosted, linux, x64, gaudi-habana] # run the job on the newly created runner
+ runs-on:
+ group: aws-dl1-24xlarge
steps:
- name: Checkout
uses: actions/checkout@v2
@@ -240,7 +248,8 @@ jobs:
- albert-xxl-single-card
- text-generation
- trl # run the job when the previous test jobs are done
- runs-on: [self-hosted, linux, x64, gaudi-habana] # run the job on the newly created runner
+ runs-on:
+ group: aws-dl1-24xlarge
steps:
- name: Checkout Optimum Habana
uses: actions/checkout@v2
diff --git a/Makefile b/Makefile
index 6e87a399a3..dbc1198a8c 100644
--- a/Makefile
+++ b/Makefile
@@ -41,9 +41,32 @@ fast_tests_diffusers:
python -m pip install .[tests]
python -m pytest tests/test_diffusers.py
+# Run single-card non-regression tests on image classification models
+fast_tests_image_classifications:
+ pip install timm
+ python -m pip install .[tests]
+ python -m pytest tests/test_image_classification.py
+
+# Run unit and integration tests related to Image segmentation
+fast_tests_image_segmentation:
+ python -m pip install .[tests]
+ python -m pytest tests/test_image_segmentation.py
+
+# Run unit and integration tests related to text feature extraction
+fast_tests_feature_extraction:
+ python -m pip install .[tests]
+ python -m pytest tests/test_feature_extraction.py
+
+# Run unit and integration tests related to VideoMAE
+fast_test_videomae:
+ python -m pip install .[tests]
+ python -m pytest tests/test_video_mae.py
+
# Run single-card non-regression tests
slow_tests_1x: test_installs
python -m pytest tests/test_examples.py -v -s -k "single_card"
+ python -m pip install peft==0.10.0
+ python -m pytest tests/test_peft_inference.py
python -m pytest tests/test_pipeline.py
# Run multi-card non-regression tests
@@ -61,6 +84,7 @@ slow_tests_diffusers: test_installs
python -m pip install peft==0.7.0
python -m pytest tests/test_diffusers.py -v -s -k "test_train_text_to_image_"
python -m pytest tests/test_diffusers.py -v -s -k "test_train_controlnet"
+ python -m pytest tests/test_diffusers.py -v -s -k "test_deterministic_image_generation"
# Run text-generation non-regression tests
slow_tests_text_generation_example: test_installs
@@ -71,6 +95,11 @@ slow_tests_text_generation_example: test_installs
slow_tests_image_to_text_example: test_installs
python -m pytest tests/test_image_to_text_example.py -v -s --token $(TOKEN)
+# Run visual question answering tests
+slow_tests_openclip_vqa_example: test_installs
+ python -m pip install -r examples/visual-question-answering/openclip_requirements.txt
+ python -m pytest tests/test_openclip_vqa.py
+
slow_tests_fsdp: test_installs
python -m pytest tests/test_fsdp_examples.py -v -s --token $(TOKEN)
diff --git a/README.md b/README.md
index fabff9e260..d6a8a57010 100644
--- a/README.md
+++ b/README.md
@@ -59,9 +59,9 @@ The `--upgrade-strategy eager` option is needed to ensure `optimum-habana` is up
To use the example associated with the latest stable release, run:
> ```
> git clone https://github.com/huggingface/optimum-habana
-> cd optimum-habana && git checkout v1.12.0
+> cd optimum-habana && git checkout v1.12.1
> ```
-> with `v1.12.0` the version number of this release.
+> with `v1.12.1` the version number of this release.
### Option 2: Use the latest main branch under development
@@ -179,7 +179,7 @@ The following model architectures, tasks and device distributions have been vali
| Architecture | Training | Inference | Tasks |
|--------------|:--------:|:---------:|:-----------------------|
-| BERT | :heavy_check_mark: | :heavy_check_mark: | [text classification](https://github.com/huggingface/optimum-habana/tree/main/examples/text-classification)[question answering](https://github.com/huggingface/optimum-habana/tree/main/examples/question-answering)[language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling) |
+| BERT | :heavy_check_mark: | :heavy_check_mark: | [text classification](https://github.com/huggingface/optimum-habana/tree/main/examples/text-classification)[question answering](https://github.com/huggingface/optimum-habana/tree/main/examples/question-answering)[language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)[text feature extraction](https://github.com/huggingface/optimum-habana/tree/main/examples/text-feature-extraction) |
| RoBERTa | :heavy_check_mark: | :heavy_check_mark: | [question answering](https://github.com/huggingface/optimum-habana/tree/main/examples/question-answering)[language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling) |
| ALBERT | :heavy_check_mark: | :heavy_check_mark: | [question answering](https://github.com/huggingface/optimum-habana/tree/main/examples/question-answering)[language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling) |
| DistilBERT |:heavy_check_mark: | :heavy_check_mark: | [question answering](https://github.com/huggingface/optimum-habana/tree/main/examples/question-answering)[language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling) |
@@ -189,7 +189,7 @@ The following model architectures, tasks and device distributions have been vali
| GPT-J | DeepSpeed | Single cardDeepSpeed | [language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation) |
| GPT-NeoX | DeepSpeed | DeepSpeed | [language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation) |
| OPT | | DeepSpeed | [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation) |
-| Llama 2 / CodeLlama / Llama 3 / Llama Guard | :heavy_check_mark: | :heavy_check_mark: | [language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)[question answering](https://github.com/huggingface/optimum-habana/tree/main/examples/question-answering)[text classification](https://github.com/huggingface/optimum-habana/tree/main/examples/text-classification) (Llama Guard) |
+| Llama 2 / CodeLlama / Llama 3 / Llama Guard / Granite | :heavy_check_mark: | :heavy_check_mark: | [language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)[question answering](https://github.com/huggingface/optimum-habana/tree/main/examples/question-answering)[text classification](https://github.com/huggingface/optimum-habana/tree/main/examples/text-classification) (Llama Guard) |
| StableLM | | Single card | [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation) |
| Falcon | LoRA | :heavy_check_mark: | [language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation) |
| CodeGen | | Single card | [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation) |
@@ -214,6 +214,8 @@ The following model architectures, tasks and device distributions have been vali
| OWLViT | | Single card | [zero shot object detection](https://github.com/huggingface/optimum-habana/tree/main/examples/zero-shot-object-detection) |
| ClipSeg | | Single card | [object segmentation](https://github.com/huggingface/optimum-habana/tree/main/examples/object-segementation) |
| Llava / Llava-next | | Single card | [image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text) |
+| Segment Anything Model | | Single card | [object segmentation](https://github.com/huggingface/optimum-habana/tree/main/examples/object-segementation) |
+| VideoMAE | | Single card | [Video classification](https://github.com/huggingface/optimum-habana/tree/main/examples/video-classification) |
@@ -229,6 +231,16 @@ The following model architectures, tasks and device distributions have been vali
+- PyTorch Image Models/TIMM:
+
+
+
+| Architecture | Training | Inference | Tasks |
+|---------------------|:--------:|:---------:|:------|
+| FastViT | | Single card | [image classification](https://github.com/huggingface/optimum-habana/tree/main/examples/image-classification) |
+
+
+
- TRL:
@@ -249,4 +261,4 @@ After training your model, feel free to submit it to the Intel [leaderboard](htt
## Development
-Check the [contributor guide](https://github.com/huggingface/optimum/blob/main/CONTRIBUTING.md) for instructions.
\ No newline at end of file
+Check the [contributor guide](https://github.com/huggingface/optimum/blob/main/CONTRIBUTING.md) for instructions.
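
The README above now links BERT to a text feature extraction example. As a rough orientation only, here is a minimal sketch of what that task looks like on an HPU device; the checkpoint, the mean pooling, and the sentence are illustrative assumptions and are not taken from the example script or from this patch.

```python
# Hedged sketch of BERT text feature extraction on HPU (not part of this patch).
import habana_frameworks.torch as ht  # noqa: F401  # registers the HPU device
import torch
from transformers import AutoModel, AutoTokenizer

from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

adapt_transformers_to_gaudi()

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # assumed checkpoint
model = AutoModel.from_pretrained("bert-base-uncased").to("hpu").eval()

inputs = tokenizer(["Optimum Habana runs on Gaudi."], return_tensors="pt", padding=True).to("hpu")
with torch.no_grad():
    hidden = model(**inputs).last_hidden_state           # (batch, seq_len, hidden_size)
mask = inputs["attention_mask"].unsqueeze(-1)             # mask out padding tokens
features = (hidden * mask).sum(dim=1) / mask.sum(dim=1)   # mean-pooled sentence features
print(features.shape)
```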
diff --git a/docs/source/index.mdx b/docs/source/index.mdx
index b33cfd062e..583c33642b 100644
--- a/docs/source/index.mdx
+++ b/docs/source/index.mdx
@@ -33,11 +33,11 @@ In the tables below, ✅ means single-card, multi-card and DeepSpeed have all be
-- Transformers
+- Transformers:
| Architecture | Training | Inference | Tasks |
|--------------|:--------:|:---------:|:------|
-| BERT | ✅ | ✅ | [text classification](https://github.com/huggingface/optimum-habana/tree/main/examples/text-classification)[question answering](https://github.com/huggingface/optimum-habana/tree/main/examples/question-answering)[language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling) |
+| BERT | ✅ | ✅ | [text classification](https://github.com/huggingface/optimum-habana/tree/main/examples/text-classification)[question answering](https://github.com/huggingface/optimum-habana/tree/main/examples/question-answering)[language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)[text feature extraction](https://github.com/huggingface/optimum-habana/tree/main/examples/text-feature-extraction) |
| RoBERTa | ✅ | ✅ | [question answering](https://github.com/huggingface/optimum-habana/tree/main/examples/question-answering)[language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling) |
| ALBERT | ✅ | ✅ | [question answering](https://github.com/huggingface/optimum-habana/tree/main/examples/question-answering)[language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling) |
| DistilBERT | ✅ | ✅ | [question answering](https://github.com/huggingface/optimum-habana/tree/main/examples/question-answering)[language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling) |
@@ -47,7 +47,7 @@ In the tables below, ✅ means single-card, multi-card and DeepSpeed have all be
| GPT-J | DeepSpeed | Single cardDeepSpeed | [language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation) |
| GPT-NeoX | DeepSpeed | DeepSpeed | [language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation) |
| OPT | | DeepSpeed | [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation) |
-| Llama 2 / CodeLlama / Llama 3 / Llama Guard | ✅ | ✅ | [language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)[question answering](https://github.com/huggingface/optimum-habana/tree/main/examples/question-answering)[text classification](https://github.com/huggingface/optimum-habana/tree/main/examples/text-classification) (Llama Guard) |
+| Llama 2 / CodeLlama / Llama 3 / Llama Guard / Granite | ✅ | ✅ | [language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)[question answering](https://github.com/huggingface/optimum-habana/tree/main/examples/question-answering)[text classification](https://github.com/huggingface/optimum-habana/tree/main/examples/text-classification) (Llama Guard) |
| StableLM | | Single card | [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation) |
| Falcon | LoRA | ✅ | [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation) |
| CodeGen | | Single card | [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation) |
@@ -72,9 +72,10 @@ In the tables below, ✅ means single-card, multi-card and DeepSpeed have all be
| OWLViT | | Single card | [zero shot object detection](https://github.com/huggingface/optimum-habana/tree/main/examples/zero-shot-object-detection) |
| ClipSeg | | Single card | [object segmentation](https://github.com/huggingface/optimum-habana/tree/main/examples/object-segementation) |
| Llava / Llava-next | | Single card | [image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text) |
+| SAM | | Single card | [object segmentation](https://github.com/huggingface/optimum-habana/tree/main/examples/object-segementation) |
+| VideoMAE | | Single card | [Video classification](https://github.com/huggingface/optimum-habana/tree/main/examples/video-classification) |
-
-- Diffusers
+- Diffusers:
| Architecture | Training | Inference | Tasks |
|---------------------|:--------:|:---------:|:------|
@@ -82,6 +83,11 @@ In the tables below, ✅ means single-card, multi-card and DeepSpeed have all be
| Stable Diffusion XL | [fine-tuning](https://github.com/huggingface/optimum-habana/tree/main/examples/stable-diffusion/training#fine-tuning-for-stable-diffusion-xl) | Single card | [text-to-image generation](https://github.com/huggingface/optimum-habana/tree/main/examples/stable-diffusion) |
| LDM3D | | Single card | [text-to-image generation](https://github.com/huggingface/optimum-habana/tree/main/examples/stable-diffusion) |
+- PyTorch Image Models/TIMM:
+
+| Architecture | Training | Inference | Tasks |
+|---------------------|:--------:|:---------:|:------|
+| FastViT | | Single card | [image classification](https://github.com/huggingface/optimum-habana/tree/main/examples/image-classification) |
- TRL:
@@ -116,4 +122,4 @@ Besides, [this page](https://github.com/huggingface/optimum-habana/tree/main/exa
Technical descriptions of how the Habana classes and methods of 🤗 Optimum Habana work.
-
\ No newline at end of file
+
diff --git a/examples/audio-classification/README.md b/examples/audio-classification/README.md
index 7e91e46eac..ec227545a7 100644
--- a/examples/audio-classification/README.md
+++ b/examples/audio-classification/README.md
@@ -84,7 +84,7 @@ python ../gaudi_spawn.py \
--max_length_seconds 8 \
--attention_mask False \
--warmup_ratio 0.1 \
- --num_train_epochs 10 \
+ --num_train_epochs 5 \
--per_device_train_batch_size 16 \
--per_device_eval_batch_size 32 \
--seed 0 \
@@ -94,7 +94,8 @@ python ../gaudi_spawn.py \
--use_hpu_graphs_for_inference \
--gaudi_config_name Habana/wav2vec2 \
--throughput_warmup_steps 3 \
- --bf16
+ --bf16 \
+ --trust_remote_code True
```
On 8 HPUs, this script should run in ~12 minutes and yield an accuracy of **80.49%**.
@@ -141,7 +142,8 @@ python ../gaudi_spawn.py \
--use_hpu_graphs_for_inference \
--gaudi_config_name Habana/wav2vec2 \
--throughput_warmup_steps 3 \
- --deepspeed ../../tests/configs/deepspeed_zero_2.json
+ --deepspeed ../../tests/configs/deepspeed_zero_2.json \
+ --trust_remote_code True
```
[The documentation](https://huggingface.co/docs/optimum/habana/usage_guides/deepspeed) provides more information about how to use DeepSpeed within Optimum Habana.
diff --git a/examples/audio-classification/run_audio_classification.py b/examples/audio-classification/run_audio_classification.py
index b8c1e146c9..74e148efd5 100644
--- a/examples/audio-classification/run_audio_classification.py
+++ b/examples/audio-classification/run_audio_classification.py
@@ -167,9 +167,9 @@ class ModelArguments:
default=False,
metadata={
"help": (
- "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
- "should only be set to `True` for repositories you trust and in which you have read the code, as it will "
- "execute code present on the Hub on your local machine."
+ "Whether to trust the execution of code from datasets/models defined on the Hub."
+ " This option should only be set to `True` for repositories you trust and in which you have read the"
+ " code, as it will execute code present on the Hub on your local machine."
)
},
)
@@ -254,12 +254,14 @@ def main():
data_args.dataset_config_name,
split=data_args.train_split_name,
token=model_args.token,
+ trust_remote_code=model_args.trust_remote_code,
)
raw_datasets["eval"] = load_dataset(
data_args.dataset_name,
data_args.dataset_config_name,
split=data_args.eval_split_name,
token=model_args.token,
+ trust_remote_code=model_args.trust_remote_code,
)
if data_args.audio_column_name not in raw_datasets["train"].column_names:
diff --git a/examples/contrastive-image-text/run_bridgetower.py b/examples/contrastive-image-text/run_bridgetower.py
index dd4b7e3fba..0d3498a511 100644
--- a/examples/contrastive-image-text/run_bridgetower.py
+++ b/examples/contrastive-image-text/run_bridgetower.py
@@ -102,9 +102,9 @@ class ModelArguments:
default=False,
metadata={
"help": (
- "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
- "should only be set to `True` for repositories you trust and in which you have read the code, as it will "
- "execute code present on the Hub on your local machine."
+ "Whether to trust the execution of code from datasets/models defined on the Hub."
+ " This option should only be set to `True` for repositories you trust and in which you have read the"
+ " code, as it will execute code present on the Hub on your local machine."
)
},
)
@@ -203,9 +203,9 @@ def __post_init__(self):
if self.validation_file is not None:
extension = self.validation_file.split(".")[-1]
assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
- if self.validation_file is not None:
- extension = self.validation_file.split(".")[-1]
- assert extension == "json", "`validation_file` should be a json file."
+ if self.test_file is not None:
+ extension = self.test_file.split(".")[-1]
+ assert extension in ["csv", "json"], "`test_file` should be a csv or a json file."
dataset_name_mapping = {
@@ -328,6 +328,7 @@ def main():
data_dir=data_args.data_dir,
token=model_args.token,
revision=data_args.dataset_revision,
+ trust_remote_code=model_args.trust_remote_code,
)
else:
data_files = {}
diff --git a/examples/contrastive-image-text/run_clip.py b/examples/contrastive-image-text/run_clip.py
index 8d3b3a28a5..c1abae0011 100644
--- a/examples/contrastive-image-text/run_clip.py
+++ b/examples/contrastive-image-text/run_clip.py
@@ -107,9 +107,9 @@ class ModelArguments:
default=False,
metadata={
"help": (
- "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
- "should only be set to `True` for repositories you trust and in which you have read the code, as it will "
- "execute code present on the Hub on your local machine."
+ "Whether to trust the execution of code from datasets/models defined on the Hub."
+ " This option should only be set to `True` for repositories you trust and in which you have read the"
+ " code, as it will execute code present on the Hub on your local machine."
)
},
)
@@ -201,9 +201,9 @@ def __post_init__(self):
if self.validation_file is not None:
extension = self.validation_file.split(".")[-1]
assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
- if self.validation_file is not None:
- extension = self.validation_file.split(".")[-1]
- assert extension == "json", "`validation_file` should be a json file."
+ if self.test_file is not None:
+ extension = self.test_file.split(".")[-1]
+ assert extension in ["csv", "json"], "`test_file` should be a csv or a json file."
dataset_name_mapping = {
@@ -325,6 +325,7 @@ def main():
keep_in_memory=False,
data_dir=data_args.data_dir,
token=model_args.token,
+ trust_remote_code=model_args.trust_remote_code,
)
else:
data_files = {}
diff --git a/examples/image-classification/README.md b/examples/image-classification/README.md
index 642cf427df..0ae5a82834 100644
--- a/examples/image-classification/README.md
+++ b/examples/image-classification/README.md
@@ -16,7 +16,7 @@ limitations under the License.
# Image Classification Examples
-This directory contains a script that showcases how to fine-tune any model supported by the [`AutoModelForImageClassification` API](https://huggingface.co/docs/transformers/main/en/model_doc/auto#transformers.AutoModelForImageClassification) (such as [ViT](https://huggingface.co/docs/transformers/main/en/model_doc/vit) or [Swin Transformer](https://huggingface.co/docs/transformers/main/en/model_doc/swin)) on HPUs. They can be used to fine-tune models on both [datasets from the hub](#using-datasets-from-hub) as well as on [your own custom data](#using-your-own-data).
+This directory contains a script that showcases how to fine-tune any model supported by the [`AutoModelForImageClassification` API](https://huggingface.co/docs/transformers/main/en/model_doc/auto#transformers.AutoModelForImageClassification) (such as [ViT](https://huggingface.co/docs/transformers/main/en/model_doc/vit) or [Swin Transformer](https://huggingface.co/docs/transformers/main/en/model_doc/swin)) on HPUs. They can be used to fine-tune models on both [datasets from the hub](#using-datasets-from-hub) as well as on [your own custom data](#using-your-own-data). This directory also contains a script that demonstrates single-HPU inference for [PyTorch-Image-Models/TIMM](https://huggingface.co/docs/timm/index).
## Requirements
@@ -43,7 +43,7 @@ python run_image_classification.py \
--do_eval \
--learning_rate 3e-5 \
--num_train_epochs 5 \
- --per_device_train_batch_size 64 \
+ --per_device_train_batch_size 128 \
--per_device_eval_batch_size 64 \
--evaluation_strategy epoch \
--save_strategy epoch \
@@ -195,7 +195,7 @@ python ../gaudi_spawn.py \
--do_eval \
--learning_rate 2e-4 \
--num_train_epochs 5 \
- --per_device_train_batch_size 64 \
+ --per_device_train_batch_size 128 \
--per_device_eval_batch_size 64 \
--evaluation_strategy epoch \
--save_strategy epoch \
@@ -235,7 +235,7 @@ python ../gaudi_spawn.py \
--do_eval \
--learning_rate 2e-4 \
--num_train_epochs 5 \
- --per_device_train_batch_size 64 \
+ --per_device_train_batch_size 128 \
--per_device_eval_batch_size 64 \
--evaluation_strategy epoch \
--save_strategy epoch \
@@ -295,3 +295,25 @@ python run_image_classification.py \
--gaudi_config_name Habana/vit \
--dataloader_num_workers 1 \
--bf16
+```
+
+## TIMM/FastViT Examples
+
+This directory contains an example script that demonstrates using FastViT with graph mode.
+
+### Single-HPU inference
+
+```bash
+python3 run_timm_example.py \
+ --model_name_or_path "timm/fastvit_t8.apple_in1k" \
+ --image_path "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png" \
+ --warmup 3 \
+ --n_iterations 20 \
+ --use_hpu_graphs \
+ --bf16 \
+ --print_result
+```
+Models that have been validated:
+ - [timm/fastvit_t8.apple_dist_in1k](https://huggingface.co/timm/fastvit_t8.apple_dist_in1k)
+ - [timm/fastvit_t8.apple_in1k](https://huggingface.co/timm/fastvit_t8.apple_in1k)
+ - [timm/fastvit_sa12.apple_in1k](https://huggingface.co/timm/fastvit_sa12.apple_in1k)
\ No newline at end of file
diff --git a/examples/image-classification/requirements.txt b/examples/image-classification/requirements.txt
index 87694059fe..7b0e43a8d2 100644
--- a/examples/image-classification/requirements.txt
+++ b/examples/image-classification/requirements.txt
@@ -3,3 +3,4 @@ torchvision>=0.6.0
datasets>=2.14.0
evaluate
scikit-learn
+timm>=0.9.16
\ No newline at end of file
diff --git a/examples/image-classification/run_image_classification.py b/examples/image-classification/run_image_classification.py
index 9f25269b4b..30779dbc03 100644
--- a/examples/image-classification/run_image_classification.py
+++ b/examples/image-classification/run_image_classification.py
@@ -172,9 +172,9 @@ class ModelArguments:
default=False,
metadata={
"help": (
- "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
- "should only be set to `True` for repositories you trust and in which you have read the code, as it will "
- "execute code present on the Hub on your local machine."
+ "Whether to trust the execution of code from datasets/models defined on the Hub."
+ " This option should only be set to `True` for repositories you trust and in which you have read the"
+ " code, as it will execute code present on the Hub on your local machine."
)
},
)
@@ -259,6 +259,7 @@ def main():
data_args.dataset_config_name,
cache_dir=model_args.cache_dir,
token=model_args.token,
+ trust_remote_code=model_args.trust_remote_code,
)
else:
data_files = {}
diff --git a/examples/image-classification/run_timm_example.py b/examples/image-classification/run_timm_example.py
new file mode 100644
index 0000000000..6d96b01024
--- /dev/null
+++ b/examples/image-classification/run_timm_example.py
@@ -0,0 +1,102 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+
+# Copied from https://huggingface.co/timm/fastvit_t8.apple_in1k
+
+import argparse
+import time
+
+import habana_frameworks.torch as ht
+import requests
+import timm
+import torch
+from PIL import Image
+
+from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--model_name_or_path",
+ default="timm/fastvit_t8.apple_in1k",
+ type=str,
+ help="Path of the pre-trained model",
+ )
+ parser.add_argument(
+ "--image_path",
+ default="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png",
+ type=str,
+ help='Path of the input image. Should be a single string (eg: --image_path "URL")',
+ )
+ parser.add_argument(
+ "--use_hpu_graphs",
+ action="store_true",
+ help="Whether to use HPU graphs or not. Using HPU graphs should give better latencies.",
+ )
+ parser.add_argument(
+ "--bf16",
+ action="store_true",
+ help="Whether to use bf16 precision for classification.",
+ )
+ parser.add_argument(
+ "--print_result",
+ action="store_true",
+ help="Whether to print the classification results.",
+ )
+ parser.add_argument("--warmup", type=int, default=3, help="Number of warmup iterations for benchmarking.")
+ parser.add_argument("--n_iterations", type=int, default=5, help="Number of inference iterations for benchmarking.")
+
+ args = parser.parse_args()
+
+ adapt_transformers_to_gaudi()
+
+ model = timm.create_model(args.model_name_or_path, pretrained=True)
+ model.to("hpu")
+ model = model.eval()
+ data_config = timm.data.resolve_model_data_config(model)
+ transforms = timm.data.create_transform(**data_config, is_training=False)
+
+ img = Image.open(requests.get(args.image_path, stream=True).raw)
+
+ if args.use_hpu_graphs:
+ model = ht.hpu.wrap_in_hpu_graph(model)
+
+ autocast = torch.autocast(device_type="hpu", dtype=torch.bfloat16, enabled=args.bf16)
+ model.to("hpu")
+
+ with torch.no_grad(), autocast:
+ for i in range(args.warmup):
+ inputs = transforms(img).unsqueeze(0).to("hpu")
+ outputs = model(inputs)
+ torch.hpu.synchronize()
+
+ total_model_time = 0
+ for i in range(args.n_iterations):
+ inputs = transforms(img).unsqueeze(0).to("hpu")
+ model_start_time = time.time()
+ outputs = model(inputs)
+ torch.hpu.synchronize()
+ model_end_time = time.time()
+ total_model_time = total_model_time + (model_end_time - model_start_time)
+
+ if args.print_result:
+ top5_probabilities, top5_class_indices = torch.topk(outputs.softmax(dim=1) * 100, k=5)
+ print("top5_class_indices: " + str(top5_class_indices.to("cpu").numpy()))
+
+ print("n_iterations: " + str(args.n_iterations))
+ print("Total latency (ms): " + str(total_model_time * 1000))
+ print("Average latency (ms): " + str(total_model_time * 1000 / args.n_iterations))
diff --git a/examples/image-to-text/README.md b/examples/image-to-text/README.md
index 7b41f870e8..f42fe0c0a0 100644
--- a/examples/image-to-text/README.md
+++ b/examples/image-to-text/README.md
@@ -15,11 +15,23 @@ limitations under the License.
-->
# Image to Text Examples
-
-This directory contains a script that showcases how to use the Transformers pipeline API to run image to text task on HPUs.
+This directory contains a script that showcases how to perform image to text generation on Intel® Gaudi® AI Accelerators.
## Single-HPU inference
+Models that have been validated:
+ - [nlpconnect/vit-gpt2-image-captioning](https://huggingface.co/nlpconnect/vit-gpt2-image-captioning)
+ - [Salesforce/blip-image-captioning-large](https://huggingface.co/Salesforce/blip-image-captioning-large)
+ - [Salesforce/blip-image-captioning-base](https://huggingface.co/Salesforce/blip-image-captioning-base)
+ - [llava-hf/llava-1.5-7b-hf](https://huggingface.co/llava-hf/llava-1.5-7b-hf)
+ - [llava-hf/llava-1.5-13b-hf](https://huggingface.co/llava-hf/llava-1.5-13b-hf)
+ - [llava-hf/llava-v1.6-mistral-7b-hf](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf)
+ - [llava-hf/llava-v1.6-vicuna-7b-hf](https://huggingface.co/llava-hf/llava-v1.6-vicuna-7b-hf)
+ - [llava-hf/llava-v1.6-vicuna-13b-hf](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf)
+
+### Inference with BF16
+
+To run Salesforce/blip-image-captioning-large inference, use the following command:
```bash
python3 run_pipeline.py \
--model_name_or_path Salesforce/blip-image-captioning-large \
@@ -27,16 +39,44 @@ python3 run_pipeline.py \
--use_hpu_graphs \
--bf16
```
-Models that have been validated:
- - [nlpconnect/vit-gpt2-image-captioning](https://huggingface.co/nlpconnect/vit-gpt2-image-captioning)
- - [Salesforce/blip-image-captioning-large](https://huggingface.co/Salesforce/blip-image-captioning-large)
- - [Salesforce/blip-image-captioning-base](https://huggingface.co/Salesforce/blip-image-captioning-base)
-### Running with FP8
+To run Llava-1.5-7b inference, use the following command:
+```bash
+python3 run_pipeline.py \
+ --model_name_or_path llava-hf/llava-1.5-7b-hf \
+ --use_hpu_graphs \
+ --bf16
+```
+
+To run Llava-1.5-13b inference, use the following command:
+```bash
+python3 run_pipeline.py \
+ --model_name_or_path llava-hf/llava-1.5-13b-hf \
+ --use_hpu_graphs \
+ --bf16
+```
+
+To run Llava-v1.6-mistral-7b inference, use the following command:
+```bash
+python3 run_pipeline.py \
+ --model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf \
+ --use_hpu_graphs \
+ --bf16
+```
-Llava-1.5-7b and Llava-1.5-13b in FP8 are enabled using the Quantization Toolkit (HQT), which provides model measurement and quantization capabilities in PyTorch.
+To run Llava-v1.6-vicuna-13b inference, use the following command:
+```bash
+python3 run_pipeline.py \
+ --model_name_or_path llava-hf/llava-v1.6-vicuna-13b-hf \
+ --use_hpu_graphs \
+ --bf16
+```
+
+### Inference with FP8
-More information on enabling fp8 in SynapseAI is available here:
+Inference for Llava-1.5-7b, Llava-1.5-13b, Llava-v1.6-mistral-7b and Llava-v1.6-vicuna-13b in FP8 precision is enabled using the Quantization Toolkit (HQT), which provides model measurement and quantization capabilities in PyTorch.
+
+More information on enabling FP8 in SynapseAI is available here:
https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_FP8.html
Here is an example to measure the tensor quantization statistics on Llava-1.5-7b:
@@ -56,3 +96,65 @@ QUANT_CONFIG=./quantization_config/maxabs_quant.json python run_pipeline.py \
--use_hpu_graphs \
--bf16
```
+
+
+Here is an example to measure the tensor quantization statistics on Llava-v1.6-mistral-7b:
+```bash
+QUANT_CONFIG=./quantization_config/maxabs_measure.json python run_pipeline.py \
+--model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf \
+--image_path "https://llava-vl.github.io/static/images/view.jpg" \
+--use_hpu_graphs \
+--bf16
+```
+
+Here is an example to quantize the model based on previous measurements for Llava-v1.6-mistral-7b:
+```bash
+QUANT_CONFIG=./quantization_config/maxabs_quant.json python run_pipeline.py \
+--model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf \
+--image_path "https://llava-vl.github.io/static/images/view.jpg" \
+--use_hpu_graphs \
+--bf16
+```
+
+Here is an example to measure the tensor quantization statistics on Llava-v1.6-vicuna-13b:
+```bash
+QUANT_CONFIG=./quantization_config/maxabs_measure.json python run_pipeline.py \
+--model_name_or_path llava-hf/llava-v1.6-vicuna-13b-hf \
+--image_path "https://llava-vl.github.io/static/images/view.jpg" \
+--use_hpu_graphs \
+--bf16
+```
+
+Here is an example to quantize the model based on previous measurements for Llava-v1.6-vicuna-13b:
+```bash
+QUANT_CONFIG=./quantization_config/maxabs_quant.json python run_pipeline.py \
+--model_name_or_path llava-hf/llava-v1.6-vicuna-13b-hf \
+--image_path "https://llava-vl.github.io/static/images/view.jpg" \
+--use_hpu_graphs \
+--bf16
+```
+
+### Inference with FusedSDPA
+
+Habana FusedSDPA is a fused and optimized implementation of torch.nn.functional.scaled_dot_product_attention() for Gaudi. For more details, refer to [Gaudi online documentation](https://docs.habana.ai/en/latest/PyTorch/Model_Optimization_PyTorch/Optimization_in_PyTorch_Models.html?highlight=fusedsdpa#using-fused-scaled-dot-product-attention-fusedsdpa). Currently FusedSDPA works with BF16 precision for Llava models.
+
+Use the following command to run Llava-1.5-7b inference with FusedSDPA:
+```bash
+python3 run_pipeline.py \
+ --model_name_or_path llava-hf/llava-1.5-7b-hf \
+ --image_path "https://llava-vl.github.io/static/images/view.jpg" \
+ --use_hpu_graphs \
+ --bf16 \
+ --use_flash_attention
+```
+
+
+Use the following command to run Llava-v1.6-mistral-7b inference with FusedSDPA:
+```bash
+python3 run_pipeline.py \
+ --model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf \
+ --image_path "https://llava-vl.github.io/static/images/view.jpg" \
+ --use_hpu_graphs \
+ --bf16 \
+ --use_flash_attention
+```
\ No newline at end of file
diff --git a/examples/image-to-text/run_pipeline.py b/examples/image-to-text/run_pipeline.py
index 52df29f52f..2d24175b0f 100644
--- a/examples/image-to-text/run_pipeline.py
+++ b/examples/image-to-text/run_pipeline.py
@@ -91,6 +91,12 @@ def main():
action="store_true",
help="Whether to ignore eos, set False to disable it.",
)
+ parser.add_argument(
+ "--use_flash_attention",
+ action="store_true",
+ help="Whether to enable Habana Flash Attention, provided that the model supports it.",
+ )
+
args = parser.parse_args()
# set args.quant_config with env variable if it is set
@@ -109,7 +115,7 @@ def main():
args.prompt = "\nUSER: What's the content of the image?\nASSISTANT:"
elif args.prompt is None and model_type == "llava_next":
args.prompt = "[INST] \nWhat is shown in this image? [/INST]"
- if args.model_name_or_path == "llava-hf/llava-v1.6-vicuna-13b-hf":
+ if args.model_name_or_path in ["llava-hf/llava-v1.6-vicuna-13b-hf", "llava-hf/llava-v1.6-vicuna-7b-hf"]:
args.prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: \nWhat is shown in this image? ASSISTANT:"
image_paths = args.image_path
@@ -149,6 +155,7 @@ def main():
"hpu_graphs": args.use_hpu_graphs,
"max_new_tokens": args.max_new_tokens,
"ignore_eos": args.ignore_eos,
+ "use_flash_attention": args.use_flash_attention,
}
if args.use_hpu_graphs:
from habana_frameworks.torch.hpu import wrap_in_hpu_graph
@@ -165,7 +172,6 @@ def main():
# warm up
for i in range(args.warmup):
generator(images, prompt=args.prompt, batch_size=args.batch_size, generate_kwargs=generate_kwargs)
-
torch.hpu.synchronize()
if args.quant_config:
habana_quantization_toolkit.finish_measurements(generator.model)
diff --git a/examples/language-modeling/README.md b/examples/language-modeling/README.md
index 8c77f0e818..5932726b0f 100644
--- a/examples/language-modeling/README.md
+++ b/examples/language-modeling/README.md
@@ -114,7 +114,7 @@ python ../gaudi_spawn.py \
--model_name_or_path EleutherAI/gpt-j-6b \
--dataset_name wikitext \
--dataset_config_name wikitext-2-raw-v1 \
- --per_device_train_batch_size 4 \
+ --per_device_train_batch_size 16 \
--per_device_eval_batch_size 4 \
--do_train \
--do_eval \
@@ -362,7 +362,7 @@ python run_clm.py \
## PEFT
-### LORA/ADALORA/IA3
+### LORA/ADALORA/IA3/LLAMA_ADAPTER
To run LoRA finetuning, you can use `run_lora_clm.py`.
Here are single-/multi-device command examples for Llama1-7B, Falcon-40B, Llama2-70B, Llama3-8B and Llama3-70B.
@@ -469,6 +469,43 @@ python ../gaudi_spawn.py \
--low_cpu_mem_usage True
```
+- Multi-card finetuning of Llama2-7B with FP8:
+```bash
+LOWER_LIST=ops_bf16.txt python ../gaudi_spawn.py \
+ --world_size 8 --use_mpi run_lora_clm.py \
+ --model_name_or_path meta-llama/Llama-2-7b-hf \
+ --dataset_name tatsu-lab/alpaca \
+ --bf16 True \
+ --output_dir ./model_lora_llama \
+ --num_train_epochs 3 \
+ --per_device_train_batch_size 16 \
+ --gradient_accumulation_steps 1 \
+ --evaluation_strategy "no" \
+ --save_strategy "no" \
+ --learning_rate 3e-4 \
+ --warmup_ratio 0.03 \
+ --lr_scheduler_type "constant" \
+ --max_grad_norm 0.3 \
+ --logging_steps 20 \
+ --do_train \
+ --do_eval \
+ --use_habana \
+ --use_lazy_mode \
+ --throughput_warmup_steps 18 \
+ --lora_rank=8 \
+ --lora_alpha=16 \
+ --lora_dropout=0.05 \
+ --lora_target_modules "q_proj" "v_proj" \
+ --dataset_concatenation \
+ --max_seq_length 512 \
+ --ddp_bucket_cap_mb 50 \
+ --adam_epsilon 1e-08 \
+ --validation_split_percentage 10 \
+ --low_cpu_mem_usage True \
+ --pipelining_fwd_bwd \
+ --fp8 True
+```
+
- Multi-card finetuning of codegen-16B-mono:
```bash
python ../gaudi_spawn.py \
@@ -535,12 +572,12 @@ LOWER_LIST=ops_bf16.txt python3 ../gaudi_spawn.py \
--validation_split_percentage 6
```
-- Multi-card finetuning of Llama2-70B with DeepSpeed ZeRO-3 optimization and LoRA:
+- Multi-card finetuning of Llama2-70B with DeepSpeed ZeRO-3 optimization, LoRA and FP8 precision:
> The following command requires Habana DeepSpeed 1.13.0 or later.
```bash
-PT_HPU_MAX_COMPOUND_OP_SIZE=10 DEEPSPEED_HPU_ZERO3_SYNC_MARK_STEP_REQUIRED=1 \
+PT_HPU_MAX_COMPOUND_OP_SIZE=10 \
python3 ../gaudi_spawn.py --use_deepspeed --world_size 8 run_lora_clm.py \
--model_name_or_path meta-llama/Llama-2-70b-hf \
--deepspeed llama2_ds_zero3_config.json \
@@ -550,7 +587,7 @@ python3 ../gaudi_spawn.py --use_deepspeed --world_size 8 run_lora_clm.py \
--num_train_epochs 2 \
--max_seq_len 2048 \
--per_device_train_batch_size 10 \
- --per_device_eval_batch_size 10 \
+ --per_device_eval_batch_size 1 \
--gradient_checkpointing \
--evaluation_strategy epoch \
--eval_delay 2 \
@@ -571,7 +608,8 @@ python3 ../gaudi_spawn.py --use_deepspeed --world_size 8 run_lora_clm.py \
--lora_target_modules "q_proj" "v_proj" "k_proj" "o_proj" \
--validation_split_percentage 4 \
--use_flash_attention True \
- --flash_attention_causal_mask True
+ --flash_attention_causal_mask True \
+ --fp8 True
```
- Multi-card finetuning of Llama2-70B with FSDP and LoRA:
@@ -653,7 +691,7 @@ DEEPSPEED_HPU_ZERO3_SYNC_MARK_STEP_REQUIRED=1 LOWER_LIST=ops_bf16.txt python3 ..
--validation_split_percentage 5 \
--deepspeed ds_falcon_180b_z3.json
```
-Default `peft_type` is `lora`, you could enable adalora or ia3 using `--peft_type adalora` or `--peft_type ia3`.
+The default `peft_type` is `lora`; you can enable adalora or ia3 with `--peft_type adalora` or `--peft_type ia3`, or enable llama-adapter for Llama models with `--peft_type llama-adapter`.
### Prompt/Prefix/P-tuning
diff --git a/examples/language-modeling/run_clm.py b/examples/language-modeling/run_clm.py
index fb00e93fb2..9433e8f3bf 100644
--- a/examples/language-modeling/run_clm.py
+++ b/examples/language-modeling/run_clm.py
@@ -131,9 +131,9 @@ class ModelArguments:
default=False,
metadata={
"help": (
- "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
- "should only be set to `True` for repositories you trust and in which you have read the code, as it will "
- "execute code present on the Hub on your local machine."
+ "Whether to trust the execution of code from datasets/models defined on the Hub."
+ " This option should only be set to `True` for repositories you trust and in which you have read the"
+ " code, as it will execute code present on the Hub on your local machine."
)
},
)
@@ -341,6 +341,7 @@ def main():
cache_dir=model_args.cache_dir,
token=model_args.token,
streaming=data_args.streaming,
+ trust_remote_code=model_args.trust_remote_code,
)
if "validation" not in raw_datasets.keys():
raw_datasets["validation"] = load_dataset(
@@ -350,6 +351,7 @@ def main():
cache_dir=model_args.cache_dir,
token=model_args.token,
streaming=data_args.streaming,
+ trust_remote_code=model_args.trust_remote_code,
)
raw_datasets["train"] = load_dataset(
data_args.dataset_name,
@@ -358,6 +360,7 @@ def main():
cache_dir=model_args.cache_dir,
token=model_args.token,
streaming=data_args.streaming,
+ trust_remote_code=model_args.trust_remote_code,
)
else:
data_files = {}
diff --git a/examples/language-modeling/run_lora_clm.py b/examples/language-modeling/run_lora_clm.py
index 2f287fb89d..e7989f6d80 100644
--- a/examples/language-modeling/run_lora_clm.py
+++ b/examples/language-modeling/run_lora_clm.py
@@ -30,7 +30,7 @@
import torch
import transformers
from datasets import load_dataset
-from peft import AdaLoraConfig, IA3Config, LoraConfig, TaskType, get_peft_model, tuners
+from peft import AdaLoraConfig, AdaptionPromptConfig, IA3Config, LoraConfig, TaskType, get_peft_model, tuners
from peft.utils.other import fsdp_auto_wrap_policy
from transformers import (
AutoConfig,
@@ -103,7 +103,11 @@ class ModelArguments:
trust_remote_code: bool = field(
default=False,
metadata={
- "help": "should enable when using custom model architecture that is not yet part of the Hugging Face transformers package like MPT)."
+ "help": (
+ "Whether to trust the execution of code from datasets/models defined on the Hub."
+ " This option should only be set to `True` for repositories you trust and in which you have read the"
+ " code, as it will execute code present on the Hub on your local machine."
+ )
},
)
use_cache: bool = field(
@@ -338,7 +342,7 @@ class FinetuneArguments:
default="lora",
metadata={
"help": ("The PEFT type to use."),
- "choices": ["lora", "ia3", "adalora"],
+ "choices": ["lora", "ia3", "adalora", "llama-adapter"],
},
)
ia3_target_modules: List[str] = field(
@@ -349,6 +353,14 @@ class FinetuneArguments:
default_factory=lambda: None,
metadata={"help": "Target feedforward modules for the IA3 method."},
)
+ adapter_layers: int = field(
+ default=30,
+ metadata={"help": "Number of adapter layers (from the top) in llama-adapter"},
+ )
+ adapter_len: int = field(
+ default=10,
+ metadata={"help": "Number of adapter tokens to insert in llama-adapter"},
+ )
PROMPT_DICT = {
@@ -466,6 +478,7 @@ def main():
"use_fast": model_args.use_fast_tokenizer,
"revision": model_args.model_revision,
"token": model_args.token,
+ "padding_side": "right",
}
if model_args.tokenizer_name:
tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, **tokenizer_kwargs)
@@ -493,6 +506,7 @@ def main():
data_args.dataset_config_name,
cache_dir=model_args.cache_dir,
token=model_args.token,
+ trust_remote_code=model_args.trust_remote_code,
)
if "validation" not in raw_datasets.keys() and training_args.do_eval:
@@ -502,6 +516,7 @@ def main():
split=f"train[:{data_args.validation_split_percentage}%]",
cache_dir=model_args.cache_dir,
token=model_args.token,
+ trust_remote_code=model_args.trust_remote_code,
)
raw_datasets["train"] = load_dataset(
data_args.dataset_name,
@@ -509,6 +524,7 @@ def main():
split=f"train[{data_args.validation_split_percentage}%:]",
cache_dir=model_args.cache_dir,
token=model_args.token,
+ trust_remote_code=model_args.trust_remote_code,
)
else:
data_files = {}
@@ -636,11 +652,16 @@ def main():
tokenizer.pad_token_id = tokenizer.eos_token_id
def tokenize(prompt, add_eos_token=True):
+ if not data_args.dataset_concatenation:
+ add_eos_token = False
+ padding = "max_length"
+ else:
+ padding = False
results = tokenizer(
prompt,
truncation=True,
max_length=data_args.max_seq_length,
- padding=False,
+ padding=padding,
return_tensors=None,
)
for i in range(len(results["input_ids"])):
@@ -779,6 +800,19 @@ def compute_metrics(eval_preds):
feedforward_modules=finetune_args.feedforward_modules,
task_type=TaskType.CAUSAL_LM,
)
+ elif finetune_args.peft_type == "llama-adapter":
+ peft_config = AdaptionPromptConfig(
+ adapter_layers=finetune_args.adapter_layers,
+ adapter_len=finetune_args.adapter_len,
+ task_type=TaskType.CAUSAL_LM,
+ )
+ from optimum.habana.peft.layer import (
+ GaudiAdaptedAttention_getattr,
+ GaudiAdaptedAttentionPreAttnForward,
+ )
+
+ tuners.adaption_prompt.layer.AdaptedAttention.pre_attn_forward = GaudiAdaptedAttentionPreAttnForward
+ tuners.adaption_prompt.layer.AdaptedAttention.__getattr__ = GaudiAdaptedAttention_getattr
if training_args.gradient_checkpointing:
model.enable_input_require_grads()
lora_model = get_peft_model(model, peft_config)
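
For readers unfamiliar with the `llama-adapter` option wired in above, here is a minimal, hedged sketch of the plain PEFT API it builds on; the Gaudi-specific `GaudiAdaptedAttention*` patching in this diff is what adapts that code path to HPU. The checkpoint name is an assumption, and the two values mirror the defaults of the new `adapter_layers`/`adapter_len` fields.

```python
# Hedged sketch of PEFT's llama-adapter (AdaptionPromptConfig); not part of this patch.
from peft import AdaptionPromptConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")  # assumes gated access
peft_config = AdaptionPromptConfig(
    adapter_layers=30,  # top decoder layers that receive adaption prompts (new field default)
    adapter_len=10,     # learnable adapter tokens inserted per layer (new field default)
    task_type=TaskType.CAUSAL_LM,
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()  # only the adaption prompts and gates are trainable
```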
diff --git a/examples/language-modeling/run_mlm.py b/examples/language-modeling/run_mlm.py
index 17c2432760..18015ca515 100644
--- a/examples/language-modeling/run_mlm.py
+++ b/examples/language-modeling/run_mlm.py
@@ -129,9 +129,9 @@ class ModelArguments:
default=False,
metadata={
"help": (
- "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
- "should only be set to `True` for repositories you trust and in which you have read the code, as it will "
- "execute code present on the Hub on your local machine."
+ "Whether to trust the execution of code from datasets/models defined on the Hub."
+ " This option should only be set to `True` for repositories you trust and in which you have read the"
+ " code, as it will execute code present on the Hub on your local machine."
)
},
)
@@ -340,6 +340,7 @@ def main():
cache_dir=model_args.cache_dir,
token=model_args.token,
streaming=data_args.streaming,
+ trust_remote_code=model_args.trust_remote_code,
)
if "validation" not in raw_datasets.keys():
raw_datasets["validation"] = load_dataset(
@@ -349,6 +350,7 @@ def main():
cache_dir=model_args.cache_dir,
token=model_args.token,
streaming=data_args.streaming,
+ trust_remote_code=model_args.trust_remote_code,
)
raw_datasets["train"] = load_dataset(
data_args.dataset_name,
@@ -357,6 +359,7 @@ def main():
cache_dir=model_args.cache_dir,
token=model_args.token,
streaming=data_args.streaming,
+ trust_remote_code=model_args.trust_remote_code,
)
else:
data_files = {}
diff --git a/examples/language-modeling/run_prompt_tuning_clm.py b/examples/language-modeling/run_prompt_tuning_clm.py
index 49043f3930..42798c0d5e 100644
--- a/examples/language-modeling/run_prompt_tuning_clm.py
+++ b/examples/language-modeling/run_prompt_tuning_clm.py
@@ -114,9 +114,9 @@ class ModelArguments:
default=False,
metadata={
"help": (
- "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
- "should only be set to `True` for repositories you trust and in which you have read the code, as it will "
- "execute code present on the Hub on your local machine."
+ "Whether to trust the execution of code from datasets/models defined on the Hub."
+ " This option should only be set to `True` for repositories you trust and in which you have read the"
+ " code, as it will execute code present on the Hub on your local machine."
)
},
)
@@ -248,6 +248,7 @@ def main():
cache_dir=model_args.cache_dir,
token=model_args.token,
streaming=data_args.streaming,
+ trust_remote_code=model_args.trust_remote_code,
)
if data_args.dataset_name == "ought/raft" and data_args.dataset_config_name == "twitter_complaints":
text_column = "Tweet text"
diff --git a/examples/object-segementation/README.md b/examples/object-segementation/README.md
index 4afb598492..936180e4f2 100644
--- a/examples/object-segementation/README.md
+++ b/examples/object-segementation/README.md
@@ -13,10 +13,12 @@ limitations under the License.
# Object Segmentation Examples
-This directory contains an example script that demonstrates how to perform object segmentation on Gaudi with graph mode.
+This directory contains two example scripts that demonstrate how to perform object segmentation on Gaudi with graph mode.
## Single-HPU inference
+### ClipSeg Model
+
```bash
python3 run_example.py \
--model_name_or_path "CIDAS/clipseg-rd64-refined" \
@@ -29,4 +31,21 @@ python3 run_example.py \
--print_result
```
Models that have been validated:
- - [clipseg-rd64-refined ](https://huggingface.co/CIDAS/clipseg-rd64-refined)
\ No newline at end of file
+ - [clipseg-rd64-refined ](https://huggingface.co/CIDAS/clipseg-rd64-refined)
+
+### Segment Anything Model
+
+```bash
+python3 run_example_sam.py \
+ --model_name_or_path "facebook/sam-vit-huge" \
+ --image_path "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png" \
+ --point_prompt "450,600" \
+ --warmup 3 \
+ --n_iterations 20 \
+ --use_hpu_graphs \
+ --bf16 \
+ --print_result
+```
+Models that have been validated:
+ - [facebook/sam-vit-base](https://huggingface.co/facebook/sam-vit-base)
+ - [facebook/sam-vit-huge](https://huggingface.co/facebook/sam-vit-huge)
\ No newline at end of file
diff --git a/examples/object-segementation/run_example_sam.py b/examples/object-segementation/run_example_sam.py
new file mode 100644
index 0000000000..d3911c0c2e
--- /dev/null
+++ b/examples/object-segementation/run_example_sam.py
@@ -0,0 +1,110 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+
+# Copied from https://huggingface.co/facebook/sam-vit-base
+
+import argparse
+import time
+
+import habana_frameworks.torch as ht
+import requests
+import torch
+from PIL import Image
+from transformers import AutoModel, AutoProcessor
+
+from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--model_name_or_path",
+ default="facebook/sam-vit-huge",
+ type=str,
+ help="Path of the pre-trained model",
+ )
+ parser.add_argument(
+ "--image_path",
+ default="https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png",
+ type=str,
+ help='Path of the input image. Should be a single string (eg: --image_path "URL")',
+ )
+ parser.add_argument(
+ "--point_prompt",
+ default="450, 600",
+ type=str,
+        help='Prompt for segmentation. It should be a comma-separated string (e.g. --point_prompt "450, 600").',
+ )
+ parser.add_argument(
+ "--use_hpu_graphs",
+ action="store_true",
+ help="Whether to use HPU graphs or not. Using HPU graphs should give better latencies.",
+ )
+ parser.add_argument(
+ "--bf16",
+ action="store_true",
+ help="Whether to use bf16 precision for classification.",
+ )
+ parser.add_argument(
+ "--print_result",
+ action="store_true",
+ help="Whether to save the segmentation result.",
+ )
+ parser.add_argument("--warmup", type=int, default=3, help="Number of warmup iterations for benchmarking.")
+ parser.add_argument("--n_iterations", type=int, default=5, help="Number of inference iterations for benchmarking.")
+
+ args = parser.parse_args()
+
+ adapt_transformers_to_gaudi()
+
+ processor = AutoProcessor.from_pretrained(args.model_name_or_path)
+ model = AutoModel.from_pretrained(args.model_name_or_path)
+
+ image = Image.open(requests.get(args.image_path, stream=True).raw).convert("RGB")
+ points = []
+ for text in args.point_prompt.split(","):
+ points.append(int(text))
+ points = [[points]]
+
+ if args.use_hpu_graphs:
+ model = ht.hpu.wrap_in_hpu_graph(model)
+
+ autocast = torch.autocast(device_type="hpu", dtype=torch.bfloat16, enabled=args.bf16)
+ model.to("hpu")
+
+ with torch.no_grad(), autocast:
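+ # Warmup iterations trigger HPU graph compilation and are excluded from the timed run below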
+ for i in range(args.warmup):
+ inputs = processor(image, input_points=points, return_tensors="pt").to("hpu")
+ outputs = model(**inputs)
+ torch.hpu.synchronize()
+
+ total_model_time = 0
+ for i in range(args.n_iterations):
+ inputs = processor(image, input_points=points, return_tensors="pt").to("hpu")
+ model_start_time = time.time()
+ outputs = model(**inputs)
+ torch.hpu.synchronize()
+ model_end_time = time.time()
+ total_model_time = total_model_time + (model_end_time - model_start_time)
+
+ if args.print_result:
+ if i == 0: # generate/output once only
+ iou = outputs.iou_scores
+ print("iou score: " + str(iou))
+
+ print("n_iterations: " + str(args.n_iterations))
+ print("Total latency (ms): " + str(total_model_time * 1000))
+ print("Average latency (ms): " + str(total_model_time * 1000 / args.n_iterations))
diff --git a/examples/question-answering/README.md b/examples/question-answering/README.md
index d531bd9fcd..fabb165e35 100755
--- a/examples/question-answering/README.md
+++ b/examples/question-answering/README.md
@@ -50,7 +50,7 @@ PT_HPU_LAZY_MODE=0 python run_qa.py \
--dataset_name squad \
--do_train \
--do_eval \
- --per_device_train_batch_size 24 \
+ --per_device_train_batch_size 32 \
--per_device_eval_batch_size 8 \
--learning_rate 3e-5 \
--num_train_epochs 2 \
@@ -78,7 +78,7 @@ PT_HPU_LAZY_MODE=0 python ../gaudi_spawn.py \
--dataset_name squad \
--do_train \
--do_eval \
- --per_device_train_batch_size 24 \
+ --per_device_train_batch_size 32 \
--per_device_eval_batch_size 8 \
--learning_rate 3e-5 \
--num_train_epochs 2 \
@@ -106,7 +106,7 @@ python ../gaudi_spawn.py \
--dataset_name squad \
--do_train \
--do_eval \
- --per_device_train_batch_size 24 \
+ --per_device_train_batch_size 32 \
--per_device_eval_batch_size 8 \
--learning_rate 3e-5 \
--num_train_epochs 2 \
diff --git a/examples/question-answering/run_qa.py b/examples/question-answering/run_qa.py
index e58f7f42a2..b7022310c2 100644
--- a/examples/question-answering/run_qa.py
+++ b/examples/question-answering/run_qa.py
@@ -102,9 +102,9 @@ class ModelArguments:
default=False,
metadata={
"help": (
- "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
- "should only be set to `True` for repositories you trust and in which you have read the code, as it will "
- "execute code present on the Hub on your local machine."
+ "Whether to trust the execution of code from datasets/models defined on the Hub."
+ " This option should only be set to `True` for repositories you trust and in which you have read the"
+ " code, as it will execute code present on the Hub on your local machine."
)
},
)
@@ -319,6 +319,7 @@ def main():
data_args.dataset_config_name,
cache_dir=model_args.cache_dir,
token=model_args.token,
+ trust_remote_code=model_args.trust_remote_code,
)
else:
data_files = {}
diff --git a/examples/question-answering/run_seq2seq_qa.py b/examples/question-answering/run_seq2seq_qa.py
index 50880a1f7c..ff56d5b4e6 100644
--- a/examples/question-answering/run_seq2seq_qa.py
+++ b/examples/question-answering/run_seq2seq_qa.py
@@ -102,9 +102,9 @@ class ModelArguments:
default=False,
metadata={
"help": (
- "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
- "should only be set to `True` for repositories you trust and in which you have read the code, as it will "
- "execute code present on the Hub on your local machine."
+ "Whether to trust the execution of code from datasets/models defined on the Hub."
+ " This option should only be set to `True` for repositories you trust and in which you have read the"
+ " code, as it will execute code present on the Hub on your local machine."
)
},
)
@@ -364,6 +364,7 @@ def main():
data_args.dataset_config_name,
cache_dir=model_args.cache_dir,
token=model_args.token,
+ trust_remote_code=model_args.trust_remote_code,
)
else:
data_files = {}
diff --git a/examples/speech-recognition/run_speech_recognition_ctc.py b/examples/speech-recognition/run_speech_recognition_ctc.py
index 048da1dd5d..c01778d6d9 100644
--- a/examples/speech-recognition/run_speech_recognition_ctc.py
+++ b/examples/speech-recognition/run_speech_recognition_ctc.py
@@ -261,9 +261,9 @@ class DataTrainingArguments:
default=False,
metadata={
"help": (
- "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
- "should only be set to `True` for repositories you trust and in which you have read the code, as it will "
- "execute code present on the Hub on your local machine."
+ "Whether to trust the execution of code from datasets/models defined on the Hub."
+ " This option should only be set to `True` for repositories you trust and in which you have read the"
+ " code, as it will execute code present on the Hub on your local machine."
)
},
)
@@ -467,6 +467,7 @@ def main():
data_args.dataset_config_name,
split=data_args.train_split_name,
token=data_args.token,
+ trust_remote_code=data_args.trust_remote_code,
)
if data_args.audio_column_name not in raw_datasets["train"].column_names:
@@ -492,6 +493,7 @@ def main():
data_args.dataset_config_name,
split=data_args.eval_split_name,
token=data_args.token,
+ trust_remote_code=data_args.trust_remote_code,
)
if data_args.max_eval_samples is not None:
diff --git a/examples/speech-recognition/run_speech_recognition_seq2seq.py b/examples/speech-recognition/run_speech_recognition_seq2seq.py
index 06733f8e7c..5985825601 100755
--- a/examples/speech-recognition/run_speech_recognition_seq2seq.py
+++ b/examples/speech-recognition/run_speech_recognition_seq2seq.py
@@ -106,9 +106,9 @@ class ModelArguments:
default=False,
metadata={
"help": (
- "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
- "should only be set to `True` for repositories you trust and in which you have read the code, as it will "
- "execute code present on the Hub on your local machine."
+ "Whether to trust the execution of code from datasets/models defined on the Hub."
+ " This option should only be set to `True` for repositories you trust and in which you have read the"
+ " code, as it will execute code present on the Hub on your local machine."
)
},
)
@@ -372,6 +372,7 @@ def main():
split=data_args.train_split_name,
cache_dir=model_args.cache_dir,
token=model_args.token,
+ trust_remote_code=model_args.trust_remote_code,
)
if training_args.do_eval:
@@ -381,6 +382,7 @@ def main():
split=data_args.eval_split_name,
cache_dir=model_args.cache_dir,
token=model_args.token,
+ trust_remote_code=model_args.trust_remote_code,
)
if data_args.audio_column_name not in next(iter(raw_datasets.values())).column_names:
diff --git a/examples/stable-diffusion/README.md b/examples/stable-diffusion/README.md
index 5f33c6fb7e..2d3839b262 100644
--- a/examples/stable-diffusion/README.md
+++ b/examples/stable-diffusion/README.md
@@ -31,7 +31,7 @@ python text_to_image_generation.py \
--model_name_or_path runwayml/stable-diffusion-v1-5 \
--prompts "An image of a squirrel in Picasso style" \
--num_images_per_prompt 20 \
- --batch_size 4 \
+ --batch_size 7 \
--image_save_dir /tmp/stable_diffusion_images \
--use_habana \
--use_hpu_graphs \
@@ -60,11 +60,27 @@ python text_to_image_generation.py \
--bf16
```
+### Distributed inference with multiple HPUs
+Here is how to generate images with two prompts on two HPUs:
+```bash
+python ../gaudi_spawn.py \
+ --world_size 2 text_to_image_generation.py \
+ --model_name_or_path runwayml/stable-diffusion-v1-5 \
+ --prompts "An image of a squirrel in Picasso style" "A shiny flying horse taking off" \
+ --num_images_per_prompt 20 \
+ --batch_size 4 \
+ --image_save_dir /tmp/stable_diffusion_images \
+ --use_habana \
+ --use_hpu_graphs \
+ --gaudi_config Habana/stable-diffusion \
+ --bf16 \
+ --distributed
+```
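+
+Under the hood, `--distributed` uses Hugging Face Accelerate to shard the prompt list across the spawned processes. A minimal sketch of that idea, independent of this script:
+
+```python
+from accelerate import PartialState
+
+state = PartialState()
+prompts = ["An image of a squirrel in Picasso style", "A shiny flying horse taking off"]
+
+# Each process only keeps its own slice of the prompts
+with state.split_between_processes(prompts) as local_prompts:
+    print(f"process {state.process_index} generates {local_prompts}")
+```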
+
> HPU graphs are recommended when generating images by batches to get the fastest possible generations.
> The first batch of images entails a performance penalty. All subsequent batches will be generated much faster.
> You can enable this mode with `--use_hpu_graphs`.
-
### Stable Diffusion 2
[Stable Diffusion 2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion_2) can also be used to generate images with this script. Here is an example for a single prompt:
@@ -74,7 +90,7 @@ python text_to_image_generation.py \
--model_name_or_path stabilityai/stable-diffusion-2-1 \
--prompts "An image of a squirrel in Picasso style" \
--num_images_per_prompt 10 \
- --batch_size 2 \
+ --batch_size 7 \
--height 768 \
--width 768 \
--image_save_dir /tmp/stable_diffusion_images \
@@ -100,7 +116,7 @@ python text_to_image_generation.py \
--model_name_or_path "Intel/ldm3d-4c" \
--prompts "An image of a squirrel in Picasso style" \
--num_images_per_prompt 10 \
- --batch_size 2 \
+ --batch_size 7 \
--height 768 \
--width 768 \
--image_save_dir /tmp/stable_diffusion_images \
@@ -109,6 +125,23 @@ python text_to_image_generation.py \
--gaudi_config Habana/stable-diffusion-2 \
--ldm3d
```
+Here is how to generate images and depth maps with two prompts on two HPUs:
+```bash
+python ../gaudi_spawn.py \
+ --world_size 2 text_to_image_generation.py \
+ --model_name_or_path "Intel/ldm3d-4c" \
+ --prompts "An image of a squirrel in Picasso style" "A shiny flying horse taking off" \
+ --num_images_per_prompt 10 \
+ --batch_size 2 \
+ --height 768 \
+ --width 768 \
+ --image_save_dir /tmp/stable_diffusion_images \
+ --use_habana \
+ --use_hpu_graphs \
+ --gaudi_config Habana/stable-diffusion-2 \
+ --ldm3d \
+ --distributed
+```
> There are three different checkpoints for LDM3D:
> - use [original checkpoint](https://huggingface.co/Intel/ldm3d) to generate outputs from the paper
@@ -125,7 +158,7 @@ python text_to_image_generation.py \
--model_name_or_path stabilityai/stable-diffusion-xl-base-1.0 \
--prompts "Sailing ship painting by Van Gogh" \
--num_images_per_prompt 20 \
- --batch_size 4 \
+ --batch_size 7 \
--image_save_dir /tmp/stable_diffusion_xl_images \
--scheduler euler_discrete \
--use_habana \
@@ -173,6 +206,25 @@ python text_to_image_generation.py \
--bf16
```
+Here is how to generate SDXL images with two prompts on two HPUs:
+```bash
+python ../gaudi_spawn.py \
+ --world_size 2 text_to_image_generation.py \
+ --model_name_or_path stabilityai/stable-diffusion-xl-base-1.0 \
+ --prompts "Sailing ship painting by Van Gogh" "A shiny flying horse taking off" \
+ --prompts_2 "Red tone" "Blue tone" \
+ --negative_prompts "Low quality" "Sketch" \
+ --negative_prompts_2 "Clouds" "Clouds" \
+ --num_images_per_prompt 20 \
+ --batch_size 8 \
+ --image_save_dir /tmp/stable_diffusion_xl_images \
+ --scheduler euler_discrete \
+ --use_habana \
+ --use_hpu_graphs \
+ --gaudi_config Habana/stable-diffusion \
+ --bf16 \
+ --distributed
+```
> HPU graphs are recommended when generating images by batches to get the fastest possible generations.
> The first batch of images entails a performance penalty. All subsequent batches will be generated much faster.
> You can enable this mode with `--use_hpu_graphs`.
@@ -219,7 +271,7 @@ python text_to_image_generation.py \
--prompts "futuristic-looking woman" \
--control_image https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png \
--num_images_per_prompt 20 \
- --batch_size 4 \
+ --batch_size 7 \
--image_save_dir /tmp/controlnet_images \
--use_habana \
--use_hpu_graphs \
@@ -236,7 +288,7 @@ python text_to_image_generation.py \
--prompts "futuristic-looking woman" "a rusty robot" \
--control_image https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png \
--num_images_per_prompt 10 \
- --batch_size 4 \
+ --batch_size 7 \
--image_save_dir /tmp/controlnet_images \
--use_habana \
--use_hpu_graphs \
@@ -244,6 +296,25 @@ python text_to_image_generation.py \
--bf16
```
+Here is how to generate images conditioned by the canny edge model, with two prompts on two HPUs:
+```bash
+pip install -r requirements.txt
+python ../gaudi_spawn.py \
+ --world_size 2 text_to_image_generation.py \
+ --model_name_or_path runwayml/stable-diffusion-v1-5 \
+ --controlnet_model_name_or_path lllyasviel/sd-controlnet-canny \
+ --prompts "futuristic-looking woman" "a rusty robot" \
+ --control_image https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png \
+ --num_images_per_prompt 10 \
+ --batch_size 4 \
+ --image_save_dir /tmp/controlnet_images \
+ --use_habana \
+ --use_hpu_graphs \
+ --gaudi_config Habana/stable-diffusion \
+ --bf16 \
+ --distributed
+```
+
Here is how to generate images conditioned by open pose model:
```bash
pip install -r requirements.txt
@@ -254,7 +325,7 @@ python text_to_image_generation.py \
--control_image https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/pose.png \
--control_preprocessing_type "none" \
--num_images_per_prompt 20 \
- --batch_size 4 \
+ --batch_size 7 \
--image_save_dir /tmp/controlnet_images \
--use_habana \
--use_hpu_graphs \
@@ -273,13 +344,107 @@ python text_to_image_generation.py \
--prompts "bird" \
--seed 0 \
--num_images_per_prompt 10 \
- --batch_size 2 \
+ --batch_size 7 \
--image_save_dir /tmp/controlnet-2-1_images \
--use_habana \
--use_hpu_graphs \
--gaudi_config Habana/stable-diffusion-2
```
+## Image-to-image Generation
+
+### Single Prompt
+
+Here is how to generate images with one prompt and one source image, using instruct-pix2pix as an example.
+
+```bash
+pip install -r requirements.txt
+python image_to_image_generation.py \
+ --model_name_or_path "timbrooks/instruct-pix2pix" \
+ --src_image_path "https://raw.githubusercontent.com/timothybrooks/instruct-pix2pix/main/imgs/example.jpg" \
+ --prompts "turn him into cyborg" \
+ --num_images_per_prompt 20 \
+ --batch_size 4 \
+ --guidance_scale 7.5 \
+ --image_guidance_scale 1 \
+ --num_inference_steps 10 \
+ --image_save_dir /tmp/stable_diffusion_images \
+ --use_habana \
+ --use_hpu_graphs \
+ --gaudi_config Habana/stable-diffusion \
+ --bf16
+```
+
+> HPU graphs are recommended when generating images by batches to get the fastest possible generations.
+> The first batch of images entails a performance penalty. All subsequent batches will be generated much faster.
+> You can enable this mode with `--use_hpu_graphs`.
+
+
+### Multiple Prompts
+
+Here is how to generate images with several prompts and one image.
+```bash
+pip install -r requirements.txt
+python image_to_image_generation.py \
+ --model_name_or_path "timbrooks/instruct-pix2pix" \
+ --src_image_path "https://raw.githubusercontent.com/timothybrooks/instruct-pix2pix/main/imgs/example.jpg" \
+ --prompts "turn him into cyborg" "a strong soldier"\
+ --num_images_per_prompt 20 \
+ --batch_size 4 \
+ --guidance_scale 7.5 \
+ --image_guidance_scale 1 \
+ --num_inference_steps 10 \
+ --image_save_dir /tmp/stable_diffusion_images \
+ --use_habana \
+ --use_hpu_graphs \
+ --gaudi_config Habana/stable-diffusion \
+ --bf16
+```
+
+> HPU graphs are recommended when generating images by batches to get the fastest possible generations.
+> The first batch of images entails a performance penalty. All subsequent batches will be generated much faster.
+> You can enable this mode with `--use_hpu_graphs`.
+
+
+### Stable Diffusion XL Refiner
+
+Here is how to generate SDXL images with a single prompt and one image:
+```bash
+pip install -r requirements.txt
+python image_to_image_generation.py \
+ --model_name_or_path "stabilityai/stable-diffusion-xl-refiner-1.0" \
+ --src_image_path "https://raw.githubusercontent.com/timothybrooks/instruct-pix2pix/main/imgs/example.jpg" \
+ --prompts "turn him into cyborg" \
+ --num_images_per_prompt 20 \
+ --batch_size 4 \
+ --guidance_scale 7.5 \
+ --num_inference_steps 10 \
+ --image_save_dir /tmp/stable_diffusion_images \
+ --use_habana \
+ --use_hpu_graphs \
+ --gaudi_config Habana/stable-diffusion \
+ --bf16
+```
+
+### Stable Diffusion Image Variations
+
+Here is how to generate image variations from a single source image; this model does not accept a prompt input:
+```bash
+pip install -r requirements.txt
+python image_to_image_generation.py \
+ --model_name_or_path "lambdalabs/sd-image-variations-diffusers" \
+ --src_image_path "https://github.com/SHI-Labs/Versatile-Diffusion/blob/master/assets/demo/reg_example/ghibli.jpg?raw=true" \
+ --num_images_per_prompt 20 \
+ --batch_size 4 \
+ --image_save_dir /tmp/stable_diffusion_images \
+ --guidance_scale 3 \
+ --use_habana \
+ --use_hpu_graphs \
+ --gaudi_config Habana/stable-diffusion \
+ --bf16
+```
+
# Stable Video Diffusion Examples
Stable Video Diffusion (SVD) was unveiled in [Stable Video Diffusion Announcement](https://stability.ai/news/stable-video-diffusion-open-ai-video-model)
@@ -323,3 +488,38 @@ python image_to_video_generation.py \
--gaudi_config Habana/stable-diffusion \
--bf16
```
+
+## Inpainting Example
+Inpainting replaces or edits specific areas of an image. For more details, please refer to the [Hugging Face Diffusers documentation](https://huggingface.co/docs/diffusers/en/using-diffusers/inpaint).
+### Stable Diffusion Inpainting
+```bash
+python text_to_image_generation.py \
+ --model_name_or_path runwayml/stable-diffusion-inpainting \
+ --base_image https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png \
+ --mask_image https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png \
+ --prompts "concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k" \
+ --seed 0 \
+ --num_images_per_prompt 12 \
+ --batch_size 4 \
+ --image_save_dir ./inpainting_images \
+ --use_habana \
+ --use_hpu_graphs \
+ --gaudi_config Habana/stable-diffusion
+```
+
+### Stable Diffusion XL Inpainting
+```bash
+python text_to_image_generation.py \
+ --model_name_or_path diffusers/stable-diffusion-xl-1.0-inpainting-0.1 \
+ --base_image https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png \
+ --mask_image https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png \
+ --prompts "concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k" \
+ --seed 0 \
+ --scheduler euler_discrete \
+ --num_images_per_prompt 12 \
+ --batch_size 4 \
+ --image_save_dir ./xl_inpainting_images \
+ --use_habana \
+ --use_hpu_graphs \
+ --gaudi_config Habana/stable-diffusion
+```
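+
+Internally, `--base_image` and `--mask_image` are loaded with Diffusers' `load_image` helper, which accepts both URLs and local paths. A minimal sketch, using the same URLs as the commands above:
+
+```python
+from diffusers.utils import load_image
+
+init_image = load_image(
+    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png"
+)
+mask_image = load_image(
+    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png"
+)
+print(init_image.size, mask_image.size)  # both are PIL images of matching size
+```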
diff --git a/examples/stable-diffusion/image_to_image_generation.py b/examples/stable-diffusion/image_to_image_generation.py
new file mode 100755
index 0000000000..d24b2eba4a
--- /dev/null
+++ b/examples/stable-diffusion/image_to_image_generation.py
@@ -0,0 +1,336 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+import sys
+from pathlib import Path
+
+import PIL
+import requests
+import torch
+from torchvision import transforms
+
+from optimum.habana.diffusers import (
+ GaudiDDIMScheduler,
+ GaudiEulerAncestralDiscreteScheduler,
+ GaudiEulerDiscreteScheduler,
+)
+from optimum.habana.utils import set_seed
+
+
+try:
+ from optimum.habana.utils import check_optimum_habana_min_version
+except ImportError:
+
+ def check_optimum_habana_min_version(*a, **b):
+ return ()
+
+
+# Will error if the minimal version of Optimum Habana is not installed. Remove at your own risks.
+check_optimum_habana_min_version("1.10.0")
+
+
+logger = logging.getLogger(__name__)
+
+
+def main():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--model_name_or_path",
+ default="runwayml/stable-diffusion-v1-5",
+ type=str,
+ help="Path to pre-trained model",
+ )
+ parser.add_argument(
+ "--src_image_path",
+ default=None,
+ type=str,
+ help="Path to source image",
+ )
+ # Pipeline arguments
+ parser.add_argument(
+ "--prompts",
+ type=str,
+ nargs="*",
+ default="An image of a squirrel in Picasso style",
+ help="The prompt or prompts to guide the image generation.",
+ )
+ parser.add_argument(
+ "--prompts_2",
+ type=str,
+ nargs="*",
+ default=None,
+ help="The second prompt or prompts to guide the image generation (applicable to SDXL).",
+ )
+ parser.add_argument(
+ "--num_images_per_prompt", type=int, default=1, help="The number of images to generate per prompt."
+ )
+ parser.add_argument("--batch_size", type=int, default=1, help="The number of images in a batch.")
+ parser.add_argument(
+ "--height",
+ type=int,
+ default=0,
+ help="The height in pixels of the generated images (0=default from model config).",
+ )
+ parser.add_argument(
+ "--width",
+ type=int,
+ default=0,
+ help="The width in pixels of the generated images (0=default from model config).",
+ )
+ parser.add_argument(
+ "--num_inference_steps",
+ type=int,
+ default=50,
+ help=(
+ "The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense"
+ " of slower inference."
+ ),
+ )
+ parser.add_argument(
+ "--guidance_scale",
+ type=float,
+ default=7.5,
+ help=(
+ "Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598)."
+ " Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,"
+ " usually at the expense of lower image quality."
+ ),
+ )
+ parser.add_argument(
+ "--image_guidance_scale",
+ type=float,
+ default=1.5,
+ help=(
+ "Image guidance scale is to push the generated image towards the inital image `image`. Image guidance"
+ "scale is enabled by setting `image_guidance_scale > 1`. Higher image guidance scale encourages to"
+ "generate images that are closely linked to the source image `image`, usually at the expense of lower"
+ "image quality. This pipeline requires a value of at least `1`.used in intruct_pix2pix"
+ ),
+ )
+ parser.add_argument(
+ "--negative_prompts",
+ type=str,
+ nargs="*",
+ default=None,
+ help="The prompt or prompts not to guide the image generation.",
+ )
+ parser.add_argument(
+ "--negative_prompts_2",
+ type=str,
+ nargs="*",
+ default=None,
+ help="The second prompt or prompts not to guide the image generation (applicable to SDXL).",
+ )
+ parser.add_argument(
+ "--eta",
+ type=float,
+ default=0.0,
+ help="Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502.",
+ )
+ parser.add_argument(
+ "--output_type",
+ type=str,
+ choices=["pil", "np"],
+ default="pil",
+ help="Whether to return PIL images or Numpy arrays.",
+ )
+
+ parser.add_argument(
+ "--pipeline_save_dir",
+ type=str,
+ default=None,
+ help="The directory where the generation pipeline will be saved.",
+ )
+ parser.add_argument(
+ "--image_save_dir",
+ type=str,
+ default="./stable-diffusion-generated-images",
+ help="The directory where images will be saved.",
+ )
+
+ parser.add_argument("--seed", type=int, default=42, help="Random seed for initialization.")
+
+ # HPU-specific arguments
+ parser.add_argument("--use_habana", action="store_true", help="Use HPU.")
+ parser.add_argument(
+ "--use_hpu_graphs", action="store_true", help="Use HPU graphs on HPU. This should lead to faster generations."
+ )
+ parser.add_argument(
+ "--gaudi_config_name",
+ type=str,
+ default="Habana/stable-diffusion",
+ help=(
+ "Name or path of the Gaudi configuration. In particular, it enables to specify how to apply Habana Mixed"
+ " Precision."
+ ),
+ )
+ parser.add_argument("--bf16", action="store_true", help="Whether to perform generation in bf16 precision.")
+ parser.add_argument(
+ "--ldm3d", action="store_true", help="Use LDM3D to generate an image and a depth map from a given text prompt."
+ )
+ parser.add_argument(
+ "--profiling_warmup_steps",
+ default=0,
+ type=int,
+ help="Number of steps to ignore for profiling.",
+ )
+ parser.add_argument(
+ "--profiling_steps",
+ default=0,
+ type=int,
+ help="Number of steps to capture for profiling.",
+ )
+ parser.add_argument(
+ "--throughput_warmup_steps",
+ type=int,
+ default=None,
+ help="Number of steps to ignore for throughput calculation.",
+ )
+ args = parser.parse_args()
+
+ # Set image resolution
+ res = {}
+ if args.width > 0 and args.height > 0:
+ res["width"] = args.width
+ res["height"] = args.height
+ sdxl_models = ["stable-diffusion-xl", "sdxl"]
+ sdxl = False
+ kwargs = {
+ "use_habana": args.use_habana,
+ "use_hpu_graphs": args.use_hpu_graphs,
+ "gaudi_config": args.gaudi_config_name,
+ }
+
+ # Import selected pipeline
+ if any(model in args.model_name_or_path for model in sdxl_models):
+ from optimum.habana.diffusers import GaudiStableDiffusionXLImg2ImgPipeline as Img2ImgPipeline
+
+ sdxl = True
+ elif "instruct-pix2pix" in args.model_name_or_path:
+ from optimum.habana.diffusers import GaudiStableDiffusionInstructPix2PixPipeline as Img2ImgPipeline
+
+ kwargs["safety_checker"] = None
+ res["image_guidance_scale"] = args.image_guidance_scale
+ elif "image-variations" in args.model_name_or_path:
+ from optimum.habana.diffusers import GaudiStableDiffusionImageVariationPipeline as Img2ImgPipeline
+
+ kwargs["revision"] = "v2.0"
+
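+ # For image-variation checkpoints, resize the source image to 224x224 and apply CLIP normalization before passing it to the pipeline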
+ if "image-variations" in args.model_name_or_path:
+ im = PIL.Image.open(requests.get(args.src_image_path, stream=True).raw)
+ tform = transforms.Compose(
+ [
+ transforms.ToTensor(),
+ transforms.Resize(
+ (224, 224),
+ interpolation=transforms.InterpolationMode.BICUBIC,
+ antialias=False,
+ ),
+ transforms.Normalize([0.48145466, 0.4578275, 0.40821073], [0.26862954, 0.26130258, 0.27577711]),
+ ]
+ )
+ image = tform(im).unsqueeze(0)
+ else:
+ image = PIL.Image.open(requests.get(args.src_image_path, stream=True).raw)
+ image = PIL.ImageOps.exif_transpose(image)
+ image = image.convert("RGB")
+
+ # Setup logging
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ handlers=[logging.StreamHandler(sys.stdout)],
+ )
+ logger.setLevel(logging.INFO)
+
+ if args.bf16:
+ kwargs["torch_dtype"] = torch.bfloat16
+
+ if args.throughput_warmup_steps is not None:
+ kwargs["throughput_warmup_steps"] = args.throughput_warmup_steps
+
+ pipeline = Img2ImgPipeline.from_pretrained(
+ args.model_name_or_path,
+ **kwargs,
+ )
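+ # Replace the checkpoint's default scheduler with its Gaudi-compatible counterpart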
+ if pipeline.scheduler.config._class_name == "EulerAncestralDiscreteScheduler":
+ pipeline.scheduler = GaudiEulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
+ elif pipeline.scheduler.config._class_name == "EulerDiscreteScheduler":
+ pipeline.scheduler = GaudiEulerDiscreteScheduler.from_config(pipeline.scheduler.config)
+ else:
+ pipeline.scheduler = GaudiDDIMScheduler.from_config(pipeline.scheduler.config)
+ # Set seed before running the model
+ set_seed(args.seed)
+ # Generate images
+ if sdxl:
+ outputs = pipeline(
+ image=image,
+ prompt=args.prompts,
+ prompt_2=args.prompts_2,
+ num_images_per_prompt=args.num_images_per_prompt,
+ batch_size=args.batch_size,
+ num_inference_steps=args.num_inference_steps,
+ guidance_scale=args.guidance_scale,
+ negative_prompt=args.negative_prompts,
+ negative_prompt_2=args.negative_prompts_2,
+ eta=args.eta,
+ output_type=args.output_type,
+ profiling_warmup_steps=args.profiling_warmup_steps,
+ profiling_steps=args.profiling_steps,
+ **res,
+ )
+ else:
+ outputs = pipeline(
+ image=image,
+ prompt=args.prompts,
+ num_images_per_prompt=args.num_images_per_prompt,
+ batch_size=args.batch_size,
+ num_inference_steps=args.num_inference_steps,
+ guidance_scale=args.guidance_scale,
+ negative_prompt=args.negative_prompts,
+ eta=args.eta,
+ output_type=args.output_type,
+ profiling_warmup_steps=args.profiling_warmup_steps,
+ profiling_steps=args.profiling_steps,
+ **res,
+ )
+
+ # Save the pipeline in the specified directory if not None
+ if args.pipeline_save_dir is not None:
+ pipeline.save_pretrained(args.pipeline_save_dir)
+
+ # Save images in the specified directory if not None and if they are in PIL format
+ if args.image_save_dir is not None:
+ if args.output_type == "pil":
+ image_save_dir = Path(args.image_save_dir)
+ image_save_dir.mkdir(parents=True, exist_ok=True)
+ logger.info(f"Saving images in {image_save_dir.resolve()}...")
+ if args.ldm3d:
+ for i, rgb in enumerate(outputs.rgb):
+ rgb.save(image_save_dir / f"rgb_{i+1}.png")
+ for i, depth in enumerate(outputs.depth):
+ depth.save(image_save_dir / f"depth_{i+1}.png")
+ else:
+ for i, image in enumerate(outputs.images):
+ image.save(image_save_dir / f"image_{i+1}.png")
+ else:
+ logger.warning("--output_type should be equal to 'pil' to save images in --image_save_dir.")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/stable-diffusion/text_to_image_generation.py b/examples/stable-diffusion/text_to_image_generation.py
index e52fc5fcba..5ff5ec96ad 100755
--- a/examples/stable-diffusion/text_to_image_generation.py
+++ b/examples/stable-diffusion/text_to_image_generation.py
@@ -20,6 +20,7 @@
import numpy as np
import torch
+from accelerate import PartialState
from optimum.habana.diffusers import (
GaudiDDIMScheduler,
@@ -91,6 +92,18 @@ def main():
default=None,
help="The second prompt or prompts to guide the image generation (applicable to SDXL).",
)
+ parser.add_argument(
+ "--base_image",
+ type=str,
+ default=None,
+ help=("Path to inpaint base image"),
+ )
+ parser.add_argument(
+ "--mask_image",
+ type=str,
+ default=None,
+ help=("Path to inpaint mask image"),
+ )
parser.add_argument(
"--control_image",
type=str,
@@ -220,6 +233,7 @@ def main():
default=0,
help="Number of steps to capture for profiling.",
)
+ parser.add_argument("--distributed", action="store_true", help="Use distributed inference on multi-cards")
parser.add_argument(
"--unet_adapter_name_or_path",
default=None,
@@ -238,6 +252,11 @@ def main():
type=str,
help="Path to lora id",
)
+ parser.add_argument(
+ "--use_cpu_rng",
+ action="store_true",
+ help="Enable deterministic generation using CPU Generator",
+ )
args = parser.parse_args()
# Set image resolution
@@ -272,6 +291,10 @@ def main():
from optimum.habana.diffusers import GaudiStableDiffusionControlNetPipeline
sdxl = False
+
+ elif (args.base_image is not None) and (args.mask_image is not None):
+ from optimum.habana.diffusers import AutoPipelineForInpainting
+
elif any(model in args.model_name_or_path for model in sdxl_models):
from optimum.habana.diffusers import GaudiStableDiffusionXLPipeline
@@ -317,9 +340,36 @@ def main():
if args.bf16:
kwargs["torch_dtype"] = torch.bfloat16
+ negative_prompts = args.negative_prompts
+ if args.distributed:
+ distributed_state = PartialState()
+ if args.negative_prompts is not None:
+ with distributed_state.split_between_processes(args.negative_prompts) as negative_prompt:
+ negative_prompts = negative_prompt
+
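+ # Call arguments shared by all pipeline variants below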
+ kwargs_common = {
+ "num_images_per_prompt": args.num_images_per_prompt,
+ "batch_size": args.batch_size,
+ "num_inference_steps": args.num_inference_steps,
+ "guidance_scale": args.guidance_scale,
+ "negative_prompt": negative_prompts,
+ "eta": args.eta,
+ "output_type": args.output_type,
+ "profiling_warmup_steps": args.profiling_warmup_steps,
+ "profiling_steps": args.profiling_steps,
+ }
+
+ kwargs_call.update(kwargs_common)
if args.throughput_warmup_steps is not None:
kwargs_call["throughput_warmup_steps"] = args.throughput_warmup_steps
+ if args.use_cpu_rng:
+ # For deterministic generation, the torch Generator must be created on the CPU
+ generator = torch.Generator(device="cpu").manual_seed(args.seed)
+ else:
+ generator = None
+ kwargs_call["generator"] = generator
+
# Generate images
if args.control_image is not None:
model_dtype = torch.bfloat16 if args.bf16 else None
@@ -332,23 +382,17 @@ def main():
if args.lora_id:
pipeline.load_lora_weights(args.lora_id)
- # Set seed before running the model
- set_seed(args.seed)
-
- outputs = pipeline(
- prompt=args.prompts,
- image=control_image,
- num_images_per_prompt=args.num_images_per_prompt,
- batch_size=args.batch_size,
- num_inference_steps=args.num_inference_steps,
- guidance_scale=args.guidance_scale,
- negative_prompt=args.negative_prompts,
- eta=args.eta,
- output_type=args.output_type,
- profiling_warmup_steps=args.profiling_warmup_steps,
- profiling_steps=args.profiling_steps,
- **kwargs_call,
- )
+ kwargs_call["image"] = control_image
+
+ elif (args.base_image is not None) and (args.mask_image is not None):
+ from diffusers.utils import load_image
+
+ pipeline = AutoPipelineForInpainting.from_pretrained(args.model_name_or_path, **kwargs)
+ init_image = load_image(args.base_image)
+ mask_image = load_image(args.mask_image)
+ kwargs_call["image"] = init_image
+ kwargs_call["mask_image"] = mask_image
+
elif sdxl:
pipeline = GaudiStableDiffusionXLPipeline.from_pretrained(
args.model_name_or_path,
@@ -357,24 +401,18 @@ def main():
if args.lora_id:
pipeline.load_lora_weights(args.lora_id)
- # Set seed before running the model
- set_seed(args.seed)
-
- outputs = pipeline(
- prompt=args.prompts,
- prompt_2=args.prompts_2,
- num_images_per_prompt=args.num_images_per_prompt,
- batch_size=args.batch_size,
- num_inference_steps=args.num_inference_steps,
- guidance_scale=args.guidance_scale,
- negative_prompt=args.negative_prompts,
- negative_prompt_2=args.negative_prompts_2,
- eta=args.eta,
- output_type=args.output_type,
- profiling_warmup_steps=args.profiling_warmup_steps,
- profiling_steps=args.profiling_steps,
- **kwargs_call,
- )
+ prompts_2 = args.prompts_2
+ negative_prompts_2 = args.negative_prompts_2
+ if args.distributed and args.prompts_2 is not None:
+ with distributed_state.split_between_processes(args.prompts_2) as prompt_2:
+ prompts_2 = prompt_2
+ if args.distributed and args.negative_prompts_2 is not None:
+ with distributed_state.split_between_processes(args.negative_prompts_2) as negative_prompt_2:
+ negative_prompts_2 = negative_prompt_2
+
+ kwargs_call["prompt_2"] = prompts_2
+ kwargs_call["negative_prompt_2"] = negative_prompts_2
+
else:
pipeline = GaudiStableDiffusionPipeline.from_pretrained(
args.model_name_or_path,
@@ -392,30 +430,29 @@ def main():
pipeline.text_encoder, args.text_encoder_adapter_name_or_path
)
pipeline.text_encoder = pipeline.text_encoder.merge_and_unload()
- set_seed(args.seed)
-
- outputs = pipeline(
- prompt=args.prompts,
- num_images_per_prompt=args.num_images_per_prompt,
- batch_size=args.batch_size,
- num_inference_steps=args.num_inference_steps,
- guidance_scale=args.guidance_scale,
- negative_prompt=args.negative_prompts,
- eta=args.eta,
- output_type=args.output_type,
- profiling_warmup_steps=args.profiling_warmup_steps,
- profiling_steps=args.profiling_steps,
- **kwargs_call,
- )
+
+ set_seed(args.seed)
+
+ if args.distributed:
+ with distributed_state.split_between_processes(args.prompts) as prompt:
+ outputs = pipeline(prompt=prompt, **kwargs_call)
+ else:
+ outputs = pipeline(prompt=args.prompts, **kwargs_call)
# Save the pipeline in the specified directory if not None
if args.pipeline_save_dir is not None:
- pipeline.save_pretrained(args.pipeline_save_dir)
+ save_dir = args.pipeline_save_dir
+ if args.distributed:
+ save_dir = f"{args.pipeline_save_dir}_{distributed_state.process_index}"
+ pipeline.save_pretrained(save_dir)
# Save images in the specified directory if not None and if they are in PIL format
if args.image_save_dir is not None:
if args.output_type == "pil":
image_save_dir = Path(args.image_save_dir)
+ if args.distributed:
+ image_save_dir = Path(f"{image_save_dir}_{distributed_state.process_index}")
+
image_save_dir.mkdir(parents=True, exist_ok=True)
logger.info(f"Saving images in {image_save_dir.resolve()}...")
if args.ldm3d:
diff --git a/examples/stable-diffusion/training/train_text_to_image_sdxl.py b/examples/stable-diffusion/training/train_text_to_image_sdxl.py
index 9a7193db24..38e08a8541 100644
--- a/examples/stable-diffusion/training/train_text_to_image_sdxl.py
+++ b/examples/stable-diffusion/training/train_text_to_image_sdxl.py
@@ -1136,6 +1136,7 @@ def unwrap_model(model, training=False):
logger.info(f" Total optimization steps = {args.max_train_steps}")
global_step = 0
first_epoch = 0
+ pipeline = None
# Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint:
@@ -1382,22 +1383,21 @@ def compute_time_ids(original_size, crops_coords_top_left):
ema_unet.copy_to(unet.parameters())
# create pipeline
- vae = AutoencoderKL.from_pretrained(
- vae_path,
- subfolder=("vae" if args.pretrained_vae_model_name_or_path is None else None),
- revision=args.revision,
- variant=args.variant,
- )
- pipeline = GaudiStableDiffusionXLPipeline.from_pretrained(
- args.pretrained_model_name_or_path,
- vae=vae,
- unet=unwrap_model(unet),
- revision=args.revision,
- variant=args.variant,
- use_habana=True,
- use_hpu_graphs=args.use_hpu_graphs_for_inference,
- gaudi_config=args.gaudi_config_name,
- )
+ if pipeline is None:
+ pipeline = GaudiStableDiffusionXLPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ vae=vae,
+ unet=unwrap_model(unet),
+ revision=args.revision,
+ variant=args.variant,
+ use_habana=True,
+ use_hpu_graphs=args.use_hpu_graphs_for_inference,
+ gaudi_config=args.gaudi_config_name,
+ )
+ else:
+ # vae and text encoders are frozen, only need to update unet
+ pipeline.unet = unwrap_model(unet)
+
if args.prediction_type is not None:
scheduler_args = {"prediction_type": args.prediction_type}
pipeline.scheduler = pipeline.scheduler.from_config(pipeline.scheduler.config, **scheduler_args)
@@ -1433,8 +1433,6 @@ def compute_time_ids(original_size, crops_coords_top_left):
}
)
- del pipeline
-
if t0 is not None:
duration = time.perf_counter() - t0 - (checkpoint_time if args.adjust_throughput else 0)
ttt = time.perf_counter() - t_start
@@ -1457,13 +1455,6 @@ def compute_time_ids(original_size, crops_coords_top_left):
ema_unet.copy_to(unet.parameters())
# Serialize pipeline.
- vae = AutoencoderKL.from_pretrained(
- vae_path,
- subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
- revision=args.revision,
- variant=args.variant,
- torch_dtype=weight_dtype,
- )
pipeline = GaudiStableDiffusionXLPipeline.from_pretrained(
args.pretrained_model_name_or_path,
unet=unet,
diff --git a/examples/summarization/run_summarization.py b/examples/summarization/run_summarization.py
index db7a4913c9..2d83568092 100755
--- a/examples/summarization/run_summarization.py
+++ b/examples/summarization/run_summarization.py
@@ -124,9 +124,9 @@ class ModelArguments:
default=False,
metadata={
"help": (
- "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
- "should only be set to `True` for repositories you trust and in which you have read the code, as it will "
- "execute code present on the Hub on your local machine."
+ "Whether to trust the execution of code from datasets/models defined on the Hub."
+ " This option should only be set to `True` for repositories you trust and in which you have read the"
+ " code, as it will execute code present on the Hub on your local machine."
)
},
)
@@ -428,6 +428,7 @@ def main():
data_args.dataset_config_name,
cache_dir=model_args.cache_dir,
token=model_args.token,
+ trust_remote_code=model_args.trust_remote_code,
)
else:
data_files = {}
diff --git a/examples/text-classification/README.md b/examples/text-classification/README.md
index f5af6bc7d3..8f8313efdb 100644
--- a/examples/text-classification/README.md
+++ b/examples/text-classification/README.md
@@ -51,7 +51,7 @@ python run_glue.py \
--task_name mrpc \
--do_train \
--do_eval \
- --per_device_train_batch_size 32 \
+ --per_device_train_batch_size 64 \
--learning_rate 3e-5 \
--num_train_epochs 3 \
--max_seq_length 128 \
@@ -78,7 +78,7 @@ python ../gaudi_spawn.py \
--task_name mrpc \
--do_train \
--do_eval \
- --per_device_train_batch_size 32 \
+ --per_device_train_batch_size 64 \
--per_device_eval_batch_size 8 \
--learning_rate 3e-5 \
--num_train_epochs 3 \
@@ -106,7 +106,7 @@ python ../gaudi_spawn.py \
--task_name mrpc \
--do_train \
--do_eval \
- --per_device_train_batch_size 32 \
+ --per_device_train_batch_size 64 \
--per_device_eval_batch_size 8 \
--learning_rate 3e-5 \
--num_train_epochs 3 \
@@ -156,6 +156,7 @@ python run_glue.py \
--do_eval \
--max_seq_length 128 \
--output_dir ./output/mrpc/ \
+ --per_device_eval_batch_size 8 \
--use_habana \
--use_lazy_mode \
--use_hpu_graphs_for_inference \
@@ -178,7 +179,7 @@ python ../gaudi_spawn.py \
--task_name mrpc \
--do_train \
--do_eval \
- --per_device_train_batch_size 32 \
+ --per_device_train_batch_size 64 \
--per_device_eval_batch_size 8 \
--learning_rate 3e-5 \
--num_train_epochs 3 \
diff --git a/examples/text-classification/run_glue.py b/examples/text-classification/run_glue.py
index 155d1dd650..69a15579b0 100755
--- a/examples/text-classification/run_glue.py
+++ b/examples/text-classification/run_glue.py
@@ -213,9 +213,9 @@ class ModelArguments:
default=False,
metadata={
"help": (
- "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
- "should only be set to `True` for repositories you trust and in which you have read the code, as it will "
- "execute code present on the Hub on your local machine."
+ "Whether to trust the execution of code from datasets/models defined on the Hub."
+ " This option should only be set to `True` for repositories you trust and in which you have read the"
+ " code, as it will execute code present on the Hub on your local machine."
)
},
)
@@ -325,6 +325,7 @@ def main():
data_args.dataset_config_name,
cache_dir=model_args.cache_dir,
token=model_args.token,
+ trust_remote_code=model_args.trust_remote_code,
)
else:
# Loading a dataset from your local files.
@@ -458,7 +459,7 @@ def main():
label_to_id = {i: int(label_name_to_id[label_list[i]]) for i in range(num_labels)}
else:
logger.warning(
- "Your model seems to have been trained with labels, but they don't match the dataset: ",
+ "Your model seems to have been trained with labels, but they don't match the dataset: "
f"model labels: {sorted(label_name_to_id.keys())}, dataset labels: {sorted(label_list)}."
"\nIgnoring the model labels as a result.",
)
diff --git a/examples/text-feature-extraction/README.md b/examples/text-feature-extraction/README.md
new file mode 100644
index 0000000000..9c34ede54a
--- /dev/null
+++ b/examples/text-feature-extraction/README.md
@@ -0,0 +1,39 @@
+
+
+# Feature Extraction Examples
+
+This directory contains a script that showcases how to use text embedding models as feature extractors on HPUs.
+
+## Single-HPU inference
+
+```bash
+python run_feature_extraction.py \
+ --model_name_or_path Supabase/gte-small \
+ --source_sentence "What is a deep learning architecture for feature extraction?" \
+ --input_texts "There are many different variants of apples created every year." \
+ "BERT is a common machine learning architecture for text-based applications." \
+ "Alexander Hamilton is one of the founding fathers of the United States." \
+ --use_hpu_graphs \
+ --bf16
+```
+
+Models that have been validated:
+
+- [Supabase/gte-small](https://huggingface.co/Supabase/gte-small)
+- [thenlper/gte-small](https://huggingface.co/thenlper/gte-small)
+- [thenlper/gte-base](https://huggingface.co/thenlper/gte-base)
+- [thenlper/gte-large](https://huggingface.co/thenlper/gte-large)
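+
+For reference, the score reported by the script is a cosine-style similarity between mean-pooled token embeddings. A minimal CPU sketch of that computation (no HPU required), using one of the validated checkpoints:
+
+```python
+import torch
+import torch.nn.functional as F
+from transformers import AutoModel, AutoTokenizer
+
+tokenizer = AutoTokenizer.from_pretrained("thenlper/gte-small")
+model = AutoModel.from_pretrained("thenlper/gte-small")
+
+texts = [
+    "What is a deep learning architecture for feature extraction?",
+    "BERT is a common machine learning architecture for text-based applications.",
+]
+batch = tokenizer(texts, padding=True, truncation=True, max_length=512, return_tensors="pt")
+
+with torch.no_grad():
+    hidden = model(**batch).last_hidden_state
+
+# Mean-pool over non-padding tokens, then L2-normalize
+mask = batch["attention_mask"].unsqueeze(-1)
+embeddings = (hidden * mask).sum(dim=1) / mask.sum(dim=1)
+embeddings = F.normalize(embeddings, p=2, dim=1)
+print((embeddings[:1] @ embeddings[1:].T) * 100)
+```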
diff --git a/examples/text-feature-extraction/run_feature_extraction.py b/examples/text-feature-extraction/run_feature_extraction.py
new file mode 100644
index 0000000000..47320b1979
--- /dev/null
+++ b/examples/text-feature-extraction/run_feature_extraction.py
@@ -0,0 +1,133 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+import time
+
+import habana_frameworks.torch as ht
+import torch
+import torch.nn.functional as F
+from tqdm import tqdm
+from transformers import AutoModel, AutoTokenizer
+
+from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
+
+
+# Adapted from https://huggingface.co/Supabase/gte-small example
+
+adapt_transformers_to_gaudi()
+
+logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+)
+logger = logging.getLogger(__name__)
+
+
+SOURCE_SENTENCE = "what is the capital of China?"
+COMPARE_TEXTS = [
+ "how to implement quick sort in Python?",
+ "Beijing",
+ "sorting algorithms",
+]
+
+
+def average_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor):
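+ # Mean-pool the token embeddings over the sequence, ignoring padded positions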
+ last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
+ return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--model_name_or_path",
+ default="Supabase/gte-small",
+ type=str,
+ help="Path to pre-trained model",
+ )
+ parser.add_argument(
+ "--source_sentence",
+ default=SOURCE_SENTENCE,
+ type=str,
+ help="Source sentence to compare with",
+ )
+ parser.add_argument(
+ "--input_texts",
+ default=COMPARE_TEXTS,
+ type=str,
+ nargs="+",
+ help='Text input. Can be a single string (eg: --input_texts "text1"), or a list of space-separated strings (eg: --input_texts "text1" "text2")',
+ )
+ parser.add_argument(
+ "--use_hpu_graphs",
+ action="store_true",
+ help="Whether to wrap model in HPU graph mode (recommended)",
+ )
+ parser.add_argument(
+ "--bf16",
+ action="store_true",
+ help="Whether to perform generation in bf16 precision.",
+ )
+ parser.add_argument(
+ "--warmup",
+ type=int,
+ default=3,
+ help="Number of warmup iterations for benchmarking.",
+ )
+ parser.add_argument(
+ "--n_iterations",
+ type=int,
+ default=5,
+ help="Number of inference iterations for benchmarking.",
+ )
+ return parser.parse_args()
+
+
+def main():
+ args = parse_args()
+
+ tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
+ model = AutoModel.from_pretrained(args.model_name_or_path).to("hpu")
+ if args.use_hpu_graphs:
+ model = ht.hpu.wrap_in_hpu_graph(model)
+ input_texts = [args.source_sentence] + args.input_texts
+ batch_dict = tokenizer(input_texts, max_length=512, padding=True, truncation=True, return_tensors="pt").to("hpu")
+
+ if args.warmup:
+ logger.info(f"Initializing warmup for {args.warmup} iterations")
+ with torch.autocast(device_type="hpu", dtype=torch.bfloat16, enabled=args.bf16), torch.no_grad():
+ for _ in tqdm(range(args.warmup), leave=False):
+ model(**batch_dict)
+ torch.hpu.synchronize()
+
+ start_time = time.time()
+ with torch.autocast(device_type="hpu", dtype=torch.bfloat16, enabled=args.bf16), torch.no_grad():
+ for _ in tqdm(range(args.n_iterations), leave=False):
+ outputs = model(**batch_dict)
+ embeddings = average_pool(outputs.last_hidden_state, batch_dict["attention_mask"])
+ torch.hpu.synchronize()
+ end_time = time.time()
+ logger.info(f"Total time: {end_time - start_time:.5f} s")
+ logger.info(f"Average time per iteration: {(end_time - start_time) * 1000 / args.n_iterations:.5f} ms")
+ embeddings = F.normalize(embeddings, p=2, dim=1)
+ scores = (embeddings[:1] @ embeddings[1:].T) * 100
+ logger.info(f"Scores for input texts relating to the source sentence: {scores.tolist()}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/text-generation/README.md b/examples/text-generation/README.md
index e020e72a79..440b18713c 100755
--- a/examples/text-generation/README.md
+++ b/examples/text-generation/README.md
@@ -142,7 +142,7 @@ python ../gaudi_spawn.py --use_deepspeed --world_size 8 run_generation.py \
--bf16 \
--use_hpu_graphs \
--use_kv_cache \
---batch_size 52 \
+--batch_size 180 \
--attn_softmax_bf16 \
--limit_hpu_graphs \
--reuse_cache \
diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py
index b30c2e4447..0a24058f2a 100755
--- a/examples/text-generation/run_generation.py
+++ b/examples/text-generation/run_generation.py
@@ -285,7 +285,7 @@ def setup_parser(parser):
parser.add_argument(
"--trust_remote_code",
action="store_true",
- help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
+ help="Whether to trust the execution of code from datasets/models defined on the Hub. This option should only be set to `True` for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.",
)
args = parser.parse_args()
diff --git a/examples/text-generation/text-generation-pipeline/README.md b/examples/text-generation/text-generation-pipeline/README.md
index 2e6e5b84fa..41b1811006 100644
--- a/examples/text-generation/text-generation-pipeline/README.md
+++ b/examples/text-generation/text-generation-pipeline/README.md
@@ -27,7 +27,8 @@ pip install git+https://github.com/HabanaAI/DeepSpeed.git@1.16.0
If you would like to use the pipeline with LangChain classes, you can install LangChain as follows:
```bash
-pip install langchain==0.1.16
+pip install langchain==0.2.5
+pip install langchain-huggingface
```
## Usage
@@ -139,9 +140,10 @@ python run_pipeline_langchain.py \
--use_hpu_graphs \
--use_kv_cache \
--batch_size 32 \
+ --max_input_tokens 200 \
--max_new_tokens 1024 \
--do_sample \
--device=hpu
```
-> The pipeline class has been validated for LangChain version 0.1.16 and may not work with other versions of the package.
+> The pipeline class has been validated for LangChain version 0.2.5 and may not work with other versions of the package.
diff --git a/examples/text-generation/text-generation-pipeline/run_pipeline_langchain.py b/examples/text-generation/text-generation-pipeline/run_pipeline_langchain.py
index 5d19640c44..556494cd37 100644
--- a/examples/text-generation/text-generation-pipeline/run_pipeline_langchain.py
+++ b/examples/text-generation/text-generation-pipeline/run_pipeline_langchain.py
@@ -19,9 +19,8 @@
import math
import time
-from langchain.chains import LLMChain
-from langchain.llms import HuggingFacePipeline
-from langchain.prompts import PromptTemplate
+from langchain_core.prompts import PromptTemplate
+from langchain_huggingface.llms import HuggingFacePipeline
from pipeline import GaudiTextGenerationPipeline
from run_generation import setup_parser
@@ -42,7 +41,7 @@ def main():
pipe = GaudiTextGenerationPipeline(args, logger, use_with_langchain=True, warmup_on_init=False)
# Create LangChain object
- llm = HuggingFacePipeline(pipeline=pipe)
+ hf = HuggingFacePipeline(pipeline=pipe)
template = """Use the following pieces of context to answer the question at the end. If you don't know the answer,\
just say that you don't know, don't try to make up an answer.
@@ -57,7 +56,7 @@ def main():
Answer: """
prompt = PromptTemplate(input_variables=["question"], template=template)
- llm_chain = LLMChain(prompt=prompt, llm=llm)
+ chain = prompt | hf
questions = [
{"question": "Which libraries and model providers offer LLMs?"},
@@ -78,20 +77,18 @@ def main():
logger.info("LangChain warmup (graph compilation)...")
for _ in range(args.warmup):
- # Use LangChain object
- _ = llm_chain.generate(input_questions)
+ _ = chain.batch(input_questions)
torch_hpu.synchronize()
duration = 0
for iteration in range(args.n_iterations):
t0 = time.perf_counter()
- # Use LangChain object
- responses = llm_chain.generate(input_questions)
+ responses = chain.batch(input_questions)
duration += time.perf_counter() - t0
- for i, (question, answer) in enumerate(zip(input_questions, responses.generations)):
+ for i, (question, answer) in enumerate(zip(input_questions, responses)):
print(f"Question[{iteration+1}][{i+1}]: {question['question']}")
- print(f"Response[{iteration+1}][{i+1}]: {repr(answer[0].text)}\n")
+ print(f"Response[{iteration+1}][{i+1}]: {answer}\n")
throughput = args.n_iterations * args.batch_size * args.max_new_tokens / duration
print(f"Inference Duration (for {args.n_iterations} iterations): {duration} seconds")
diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py
index 987c055abe..fa1946b914 100644
--- a/examples/text-generation/utils.py
+++ b/examples/text-generation/utils.py
@@ -238,6 +238,8 @@ def setup_model(args, model_dtype, model_kwargs, logger):
assistant_model = wrap_in_hpu_graph(assistant_model)
if _is_peft_model(model):
model.base_model = wrap_in_hpu_graph(model.base_model)
+ if model.peft_type == "ADAPTION_PROMPT":
+ model.base_model.model = wrap_in_hpu_graph(model.base_model.model)
if args.torch_compile and model.config.model_type == "llama":
model = get_torch_compiled_model(model)
@@ -372,6 +374,17 @@ def peft_model(args, model_dtype, logger, **model_kwargs):
model.__class__.generate = gaudi_generate
model.__class__.prepare_inputs_for_generation = gaudi_prepare_inputs_for_generation
+ if model.peft_type == "ADAPTION_PROMPT":
+ from peft import tuners
+
+ from optimum.habana.peft.layer import (
+ GaudiAdaptedAttention_getattr,
+ GaudiAdaptedAttentionPreAttnForward,
+ )
+
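+ # Monkey-patch PEFT's AdaptedAttention so adaption-prompt models use the Gaudi-optimized attention path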
+ tuners.adaption_prompt.layer.AdaptedAttention.pre_attn_forward = GaudiAdaptedAttentionPreAttnForward
+ tuners.adaption_prompt.layer.AdaptedAttention.__getattr__ = GaudiAdaptedAttention_getattr
+
return model
@@ -461,9 +474,28 @@ def setup_generation_config(args, model, assistant_model, tokenizer):
return generation_config
+def exclude_hpu_graph_configs(args):
+ # Return True for batch-size-1 configurations where --limit_hpu_graphs should be disabled
+ if args.batch_size == 1 and args.limit_hpu_graphs:
+ if "falcon-180B" in args.model_name_or_path or "falcon-180b" in args.model_name_or_path:
+ return False
+ if args.world_size == 2 or args.world_size == 4 or args.world_size == 8:
+ if args.quant_config:
+ if args.max_input_tokens >= 8192 and args.max_new_tokens >= 128:
+ return False
+ else:
+ if args.max_input_tokens >= 4096 and args.max_new_tokens >= 128:
+ return False
+ return True
+ else:
+ return False
+
+
def initialize_model(args, logger):
init_start = time.perf_counter()
setup_distributed(args)
+ if exclude_hpu_graph_configs(args):
+ args.limit_hpu_graphs = False
override_prints(args.global_rank == 0 or args.verbose_workers, logger)
setup_env(args)
setup_device(args)
diff --git a/examples/translation/run_translation.py b/examples/translation/run_translation.py
index db40ef8f28..942503c4ad 100644
--- a/examples/translation/run_translation.py
+++ b/examples/translation/run_translation.py
@@ -118,9 +118,9 @@ class ModelArguments:
default=False,
metadata={
"help": (
- "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
- "should only be set to `True` for repositories you trust and in which you have read the code, as it will "
- "execute code present on the Hub on your local machine."
+ "Whether to trust the execution of code from datasets/models defined on the Hub."
+ " This option should only be set to `True` for repositories you trust and in which you have read the"
+ " code, as it will execute code present on the Hub on your local machine."
)
},
)
@@ -380,6 +380,7 @@ def main():
data_args.dataset_config_name,
cache_dir=model_args.cache_dir,
token=model_args.token,
+ trust_remote_code=model_args.trust_remote_code,
)
else:
data_files = {}
diff --git a/examples/trl/README.md b/examples/trl/README.md
index e6a4f0b006..747e95be79 100644
--- a/examples/trl/README.md
+++ b/examples/trl/README.md
@@ -7,6 +7,37 @@ First, you should install the requirements:
```
$ pip install -U -r requirements.txt
```
+## Supervised Finetuning
+The following example runs supervised LoRA fine-tuning of the Qwen2 model on a dataset in conversational format.
+
+```
+python sft.py \
+    --model_name_or_path "Qwen/Qwen2-7B" \
+    --dataset_name "philschmid/dolly-15k-oai-style" \
+    --streaming False \
+    --bf16 True \
+    --subset '' \
+    --output_dir ./model_qwen \
+    --num_train_epochs 1 \
+    --per_device_train_batch_size 16 \
+    --evaluation_strategy "no" \
+    --save_strategy "no" \
+    --learning_rate 3e-4 \
+    --warmup_ratio 0.03 \
+    --lr_scheduler_type "cosine" \
+    --max_grad_norm 0.3 \
+    --logging_steps 1 \
+    --do_train \
+    --do_eval \
+    --use_habana \
+    --use_lazy_mode \
+    --throughput_warmup_steps 3 \
+    --use_peft True \
+    --lora_r 4 \
+    --lora_alpha=16 \
+    --lora_dropout=0.05 \
+    --lora_target_modules "q_proj" "v_proj" "k_proj" "o_proj" \
+    --max_seq_length 512 \
+    --adam_epsilon 1e-08
+```
+
## DPO pipeline
@@ -19,10 +50,12 @@ There are two main steps to the DPO training process:
```
python ../gaudi_spawn.py --world_size 8 --use_mpi sft.py \
--model_name_or_path meta-llama/Llama-2-7b-hf \
+ --dataset_name "lvwerra/stack-exchange-paired" \
--output_dir="./sft" \
--max_steps=500 \
--logging_steps=10 \
--save_steps=100 \
+ --do_train \
--per_device_train_batch_size=4 \
--per_device_eval_batch_size=1 \
--gradient_accumulation_steps=2 \
@@ -60,8 +93,10 @@ steps like:
```
DEEPSPEED_HPU_ZERO3_SYNC_MARK_STEP_REQUIRED=1 python ../gaudi_spawn.py --world_size 8 --use_deepspeed sft.py \
--model_name_or_path meta-llama/Llama-2-70b-hf \
+ --dataset_name "lvwerra/stack-exchange-paired" \
--deepspeed ../language-modeling/llama2_ds_zero3_config.json \
--output_dir="./sft" \
+ --do_train \
--max_steps=500 \
--logging_steps=10 \
--save_steps=100 \
@@ -133,7 +168,9 @@ There are three main steps to the PPO training process:
```
python ../gaudi_spawn.py --world_size 8 --use_mpi sft.py \
--model_name_or_path meta-llama/Llama-2-7b-hf \
+ --dataset_name "lvwerra/stack-exchange-paired" \
--output_dir="./sft" \
+ --do_train \
--max_steps=500 \
--logging_steps=10 \
--save_steps=100 \
@@ -160,8 +197,8 @@ There are three main steps to the PPO training process:
2. Reward modeling using dialog pairs from the SE dataset on the llama-v2-7b-se to create llama-v2-7b-se-rm
```
python ../gaudi_spawn.py --world_size 8 --use_mpi reward_modeling.py \
- --model_name=./sft/final_merged_checkpoint \
- --tokenizer_name=meta-llama/Llama-2-7b-hf \
+ --model_name_or_path=./sft/final_merged_checkpoint \
+ --tokenizer_name_or_path=meta-llama/Llama-2-7b-hf \
--output_dir=./rm
```
To merge the adaptors into the base model we can use the `merge_peft_adapter.py` helper script that comes with TRL:
@@ -173,9 +210,9 @@ There are three main steps to the PPO training process:
3. RL fine-tuning of llama-v2-7b-se with the llama-v2-7b-se-rm reward model:
```
python ../gaudi_spawn.py --world_size 8 --use_mpi ppo.py \
- --model_name=./sft/final_merged_checkpoint \
+ --model_name_or_path=./sft/final_merged_checkpoint \
--reward_model_name=./rm_merged_checkpoint \
- --tokenizer_name=meta-llama/Llama-2-7b-hf \
+ --tokenizer_name_or_path=meta-llama/Llama-2-7b-hf \
--adafactor=False \
--output_max_length=128 \
--batch_size=8 \
diff --git a/examples/trl/ppo.py b/examples/trl/ppo.py
index d4ad127641..b0d24805e6 100644
--- a/examples/trl/ppo.py
+++ b/examples/trl/ppo.py
@@ -1,4 +1,6 @@
# copy from https://github.com/huggingface/trl/blob/v0.7.6/examples/research_projects/stack_llama/scripts/rl_training.py, enable it for Gaudi2
+import json
+import time
from dataclasses import dataclass, field
from typing import List, Optional
@@ -26,8 +28,10 @@ class ScriptArguments:
# NOTE: gpt2 models use Conv1D instead of Linear layers which are not yet supported in 8 bit mode
# models like gpt-neo* models are more suitable.
- model_name: Optional[str] = field(default="meta-llama/Llama-2-7b-hf", metadata={"help": "the model name"})
- tokenizer_name: Optional[str] = field(default="meta-llama/Llama-2-7b-hf", metadata={"help": "the tokenizer name"})
+ model_name_or_path: Optional[str] = field(default="meta-llama/Llama-2-7b-hf", metadata={"help": "the model name"})
+ tokenizer_name_or_path: Optional[str] = field(
+ default="meta-llama/Llama-2-7b-hf", metadata={"help": "the tokenizer name"}
+ )
reward_model_name: Optional[str] = field(default="", metadata={"help": "the reward model name"})
log_with: Optional[str] = field(default=None, metadata={"help": "use 'wandb' to log with wandb"})
learning_rate: Optional[float] = field(default=1.41e-5, metadata={"help": "the learning rate"})
@@ -83,7 +87,7 @@ class ScriptArguments:
dataset_name = "lvwerra/stack-exchange-paired"
config = GaudiPPOConfig(
steps=script_args.steps,
- model_name=script_args.model_name,
+ model_name=script_args.model_name_or_path,
learning_rate=script_args.learning_rate,
log_with=script_args.log_with,
batch_size=script_args.batch_size,
@@ -119,7 +123,7 @@ class ScriptArguments:
sent_kwargs["padding"] = "max_length"
sent_kwargs["max_length"] = script_args.input_max_length + script_args.output_max_length
-tokenizer = AutoTokenizer.from_pretrained(script_args.tokenizer_name)
+tokenizer = AutoTokenizer.from_pretrained(script_args.tokenizer_name_or_path)
# GPT-2 tokenizer has a pad token, but it is not eos_token by default. We need to set it to eos_token.
# only for this model.
@@ -133,6 +137,7 @@ class ScriptArguments:
def build_dataset(
tokenizer,
dataset_name="lvwerra/stack-exchange-paired",
+ input_max_length=512,
):
"""
Build dataset for training. This builds the dataset from `load_dataset`, one should
@@ -168,14 +173,14 @@ def preprocess_function(examples):
num_proc=num_proc,
remove_columns=original_columns,
)
- ds = ds.filter(lambda x: len(x["input_ids"]) < 512, batched=False)
+ ds = ds.filter(lambda x: len(x["input_ids"]) < input_max_length, batched=False)
ds.set_format(type="torch")
return ds
# We retrieve the dataloader by calling the `build_dataset` function.
-dataset = build_dataset(tokenizer)
+dataset = build_dataset(tokenizer, input_max_length=script_args.input_max_length)
def collator(data):
@@ -232,7 +237,6 @@ def collator(data):
data_collator=collator,
optimizer=optimizer,
)
-
# We then build the sentiment analysis pipeline using our reward model, passing the
# model name and the sentiment analysis pipeline arguments. Let's also make sure to
# set the device to the same device as the PPOTrainer.
@@ -283,12 +287,13 @@ def collator(data):
output_length_sampler = LengthSampler(output_min_length, output_max_length)
else:
output_length_sampler = LengthSampler(output_max_length, output_max_length + 1)
+s0 = time.time()
+sample = 0
for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
if epoch >= config.total_ppo_epochs:
break
-
question_tensors = batch["input_ids"]
-
+ sample = sample + len(question_tensors)
response_tensors = ppo_trainer.generate(
question_tensors,
return_prompt=False,
@@ -308,5 +313,9 @@ def collator(data):
if script_args.save_freq and epoch and epoch % script_args.save_freq == 0:
ppo_trainer.save_pretrained(script_args.output_dir + f"step_{epoch}")
+s1 = time.time()
ppo_trainer.save_pretrained(script_args.output_dir)
+metrics = {"train_runtime": s1 - s0, "train_samples_per_second": sample / (s1 - s0)}
+with open(f"{script_args.output_dir}/all_results.json", mode="w") as file:
+ json.dump(metrics, file)
diff --git a/examples/trl/reward_modeling.py b/examples/trl/reward_modeling.py
index f67d657944..1bd8e65ecf 100644
--- a/examples/trl/reward_modeling.py
+++ b/examples/trl/reward_modeling.py
@@ -43,13 +43,13 @@ class ScriptArguments:
gradient_accumulation_steps: Optional[int] = field(default=1)
learning_rate: Optional[float] = field(default=2e-5)
weight_decay: Optional[float] = field(default=0.001)
- model_name: Optional[str] = field(
+ model_name_or_path: Optional[str] = field(
default="meta-llama/Llama-2-7b-hf",
metadata={
"help": "The model that you want to train from the Hugging Face hub. E.g. gpt2, gpt2-xl, bert, etc."
},
)
- tokenizer_name: Optional[str] = field(
+ tokenizer_name_or_path: Optional[str] = field(
default="meta-llama/Llama-2-7b-hf",
metadata={
"help": "The tokenizer for your model, if left empty will use the default for your model",
@@ -156,10 +156,12 @@ class ScriptArguments:
)
# Load the value-head model and tokenizer.
-tokenizer_name = script_args.tokenizer_name if script_args.tokenizer_name is not None else script_args.model_name
+tokenizer_name = (
+ script_args.tokenizer_name_or_path
+ if script_args.tokenizer_name_or_path is not None
+ else script_args.model_name_or_path
+)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, token=script_args.token)
-tokenizer.pad_token = tokenizer.eos_token
-
peft_config = LoraConfig(
task_type=TaskType.SEQ_CLS,
inference_mode=False,
@@ -171,14 +173,14 @@ class ScriptArguments:
)
torch.autograd.set_detect_anomaly(True)
model = AutoModelForSequenceClassification.from_pretrained(
- script_args.model_name, num_labels=1, torch_dtype=torch.bfloat16
+ script_args.model_name_or_path, num_labels=1, torch_dtype=torch.bfloat16
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
-
# Need to do this for gpt2, because it doesn't have an official pad token.
tokenizer.pad_token = tokenizer.eos_token
+tokenizer.padding_side = "right"
model.config.pad_token_id = tokenizer.eos_token_id
model.config.use_cache = not script_args.gradient_checkpointing
model.config.use_fused_rope = False
@@ -268,7 +270,10 @@ def on_step_end(self, args, state, control, **kwargs):
trainer.add_callback(EvaluateFirstStepCallback())
-trainer.train(script_args.resume_from_checkpoint)
+train_result = trainer.train(script_args.resume_from_checkpoint)
+metrics = train_result.metrics
+trainer.log_metrics("train", metrics)
+trainer.save_metrics("train", metrics)
print("Saving last checkpoint of the model")
trainer.save_model(script_args.output_dir)
diff --git a/examples/trl/sft.py b/examples/trl/sft.py
index 6edaa43308..170526a99f 100644
--- a/examples/trl/sft.py
+++ b/examples/trl/sft.py
@@ -1,6 +1,7 @@
# Fine-Tune Llama2-7b on SE paired dataset
# copy from https://github.com/huggingface/trl/blob/v0.7.6/examples/research_projects/stack_llama_2/scripts/sft_llama2.py, enable it for Gaudi2
import logging
+import math
from dataclasses import dataclass, field
from typing import List, Optional
@@ -13,7 +14,6 @@
from transformers.integrations.deepspeed import (
is_deepspeed_available,
)
-from trl.trainer import ConstantLengthDataset
from optimum.habana import GaudiConfig, GaudiTrainingArguments
from optimum.habana.trl import GaudiSFTTrainer
@@ -26,15 +26,31 @@
@dataclass
class ScriptArguments:
model_name_or_path: Optional[str] = field(default="meta-llama/Llama-2-7b-hf", metadata={"help": "the model name"})
- dataset_name: Optional[str] = field(default="lvwerra/stack-exchange-paired", metadata={"help": "the dataset name"})
+ dataset_name: Optional[str] = field(default=None, metadata={"help": "the dataset name"})
+ use_peft: Optional[bool] = field(default=True, metadata={"help": "whether to use peft"})
subset: Optional[str] = field(default="data/finetune", metadata={"help": "the subset to use"})
split: Optional[str] = field(default="train", metadata={"help": "the split to use"})
size_valid_set: Optional[int] = field(default=4000, metadata={"help": "the size of the validation set"})
streaming: Optional[bool] = field(default=True, metadata={"help": "whether to stream the dataset"})
shuffle_buffer: Optional[int] = field(default=5000, metadata={"help": "the shuffle buffer size"})
- seq_length: Optional[int] = field(default=1024, metadata={"help": "the sequence length"})
+ max_seq_length: Optional[int] = field(default=1024, metadata={"help": "the max sequence length"})
num_workers: Optional[int] = field(default=4, metadata={"help": "the number of workers"})
packing: Optional[bool] = field(default=True, metadata={"help": "whether to use packing for SFTTrainer"})
+ validation_split_percentage: Optional[int] = field(
+ default=5,
+ metadata={
+ "help": "The percentage of the train set used as validation set in case there's no validation split"
+ },
+ )
+ use_flash_attention: Optional[bool] = field(
+ default=False, metadata={"help": "Whether to use Habana flash attention for fine-tuning."}
+ )
+ flash_attention_recompute: Optional[bool] = field(
+ default=False, metadata={"help": "Whether to enable recompute in Habana flash attention for fine-tuning."}
+ )
+ flash_attention_causal_mask: Optional[bool] = field(
+ default=False, metadata={"help": "Whether to enable causal mask in Habana flash attention for fine-tuning."}
+ )
# LoraConfig
lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"})
@@ -56,131 +72,147 @@ class ScriptArguments:
)
-parser = HfArgumentParser((ScriptArguments, GaudiTrainingArguments))
-script_args, training_args = parser.parse_args_into_dataclasses()
-peft_config = LoraConfig(
- r=script_args.lora_r,
- lora_alpha=script_args.lora_alpha,
- lora_dropout=script_args.lora_dropout,
- target_modules=script_args.lora_target_modules,
- bias="none",
- task_type="CAUSAL_LM",
-)
-
-if training_args.group_by_length and script_args.packing:
- raise ValueError("Cannot use both packing and group by length")
-
-set_seed(training_args.seed)
-
-
-def chars_token_ratio(dataset, tokenizer, nb_examples=400):
- """
- Estimate the average number of characters per token in the dataset.
- """
- total_characters, total_tokens = 0, 0
- for _, example in tqdm(zip(range(nb_examples), iter(dataset)), total=nb_examples):
- text = prepare_sample_text(example)
- total_characters += len(text)
- if tokenizer.is_fast:
- total_tokens += len(tokenizer(text).tokens())
+if __name__ == "__main__":
+ parser = HfArgumentParser((ScriptArguments, GaudiTrainingArguments))
+ script_args, training_args = parser.parse_args_into_dataclasses()
+ if script_args.use_peft:
+ peft_config = LoraConfig(
+ r=script_args.lora_r,
+ lora_alpha=script_args.lora_alpha,
+ lora_dropout=script_args.lora_dropout,
+ target_modules=script_args.lora_target_modules,
+ bias="none",
+ task_type="CAUSAL_LM",
+ )
+ else:
+ peft_config = None
+
+ if training_args.group_by_length and script_args.packing:
+ raise ValueError("Cannot use both packing and group by length")
+
+ set_seed(training_args.seed)
+
+ def chars_token_ratio(dataset, tokenizer, nb_examples=400):
+ """
+ Estimate the average number of characters per token in the dataset.
+ """
+ total_characters, total_tokens = 0, 0
+ for _, example in tqdm(zip(range(nb_examples), iter(dataset)), total=nb_examples):
+ text = prepare_sample_text(example)
+ total_characters += len(text)
+ if tokenizer.is_fast:
+ total_tokens += len(tokenizer(text).tokens())
+ else:
+ total_tokens += len(tokenizer.tokenize(text))
+
+ return total_characters / total_tokens
+
+ def prepare_sample_text(example):
+ """Prepare the text from a sample of the dataset."""
+ text = f"Question: {example['question']}\n\nAnswer: {example['response_j']}"
+ return text
+
+ def create_datasets(tokenizer, args, seed=None):
+ if args.dataset_name:
+ dataset = load_dataset(
+ args.dataset_name,
+ data_dir=args.subset,
+ split=args.split,
+ token=script_args.token,
+ num_proc=args.num_workers if not args.streaming else None,
+ streaming=args.streaming,
+ )
else:
- total_tokens += len(tokenizer.tokenize(text))
-
- return total_characters / total_tokens
-
+ raise ValueError("No dataset_name")
+ if args.streaming:
+ logger.info("Loading the dataset in streaming mode")
+ valid_data = dataset.take(args.size_valid_set)
+ train_data = dataset.skip(args.size_valid_set)
+ train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=seed)
+ else:
+ dataset = dataset.train_test_split(test_size=args.validation_split_percentage * 0.01, seed=seed)
+ train_data = dataset["train"]
+ valid_data = dataset["test"]
+ logger.info(f"Size of the train set: {len(train_data)}. Size of the validation set: {len(valid_data)}")
+ if args.dataset_name == "lvwerra/stack-exchange-paired":
+ chars_per_token = chars_token_ratio(train_data, tokenizer)
+ logger.info(f"The character to token ratio of the dataset is: {chars_per_token:.2f}")
+            formatting_func = prepare_sample_text
+        else:
+            formatting_func = None
+        return train_data, valid_data, formatting_func
-def prepare_sample_text(example):
- """Prepare the text from a sample of the dataset."""
- text = f"Question: {example['question']}\n\nAnswer: {example['response_j']}"
- return text
+ low_cpu_mem_usage = True
+ if is_deepspeed_available():
+ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
+ if is_deepspeed_zero3_enabled():
+ low_cpu_mem_usage = False
-def create_datasets(tokenizer, args, seed=None):
- dataset = load_dataset(
- args.dataset_name,
- data_dir=args.subset,
- split=args.split,
+ base_model = AutoModelForCausalLM.from_pretrained(
+ script_args.model_name_or_path,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ torch_dtype=torch.bfloat16,
token=script_args.token,
- num_proc=args.num_workers if not args.streaming else None,
- streaming=args.streaming,
- )
- if args.streaming:
- print("Loading the dataset in streaming mode")
- valid_data = dataset.take(args.size_valid_set)
- train_data = dataset.skip(args.size_valid_set)
- train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=seed)
- else:
- dataset = dataset.train_test_split(test_size=0.005, seed=seed)
- train_data = dataset["train"]
- valid_data = dataset["test"]
- print(f"Size of the train set: {len(train_data)}. Size of the validation set: {len(valid_data)}")
-
- chars_per_token = chars_token_ratio(train_data, tokenizer)
- print(f"The character to token ratio of the dataset is: {chars_per_token:.2f}")
-
- train_dataset = ConstantLengthDataset(
- tokenizer,
- train_data,
- formatting_func=prepare_sample_text,
- infinite=True,
- seq_length=args.seq_length,
- chars_per_token=chars_per_token,
)
- valid_dataset = ConstantLengthDataset(
- tokenizer,
- valid_data,
- formatting_func=prepare_sample_text,
- infinite=False,
- seq_length=args.seq_length,
- chars_per_token=chars_per_token,
- )
- return train_dataset, valid_dataset
-
-
-low_cpu_mem_usage = True
-if is_deepspeed_available():
- from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
- if is_deepspeed_zero3_enabled():
- low_cpu_mem_usage = False
-
-base_model = AutoModelForCausalLM.from_pretrained(
- script_args.model_name_or_path,
- low_cpu_mem_usage=low_cpu_mem_usage,
- torch_dtype=torch.bfloat16,
- token=script_args.token,
-)
-base_model.config.use_cache = False
-base_model.config.use_fused_rope = False
-
-tokenizer = AutoTokenizer.from_pretrained(script_args.model_name_or_path, trust_remote_code=True)
-tokenizer.pad_token = tokenizer.eos_token
-tokenizer.padding_side = "right" # Fix weird overflow issue with fp16 training
-
-log_level = training_args.get_process_log_level()
-logger.setLevel(log_level)
-transformers.utils.logging.set_verbosity(log_level)
-transformers.utils.logging.enable_default_handler()
-transformers.utils.logging.enable_explicit_format()
-
-train_dataset, eval_dataset = create_datasets(tokenizer, script_args, seed=training_args.seed)
-
-gaudi_config = GaudiConfig()
-gaudi_config.use_fused_adam = True
-gaudi_config.use_fused_clip_norm = True
-trainer = GaudiSFTTrainer(
- model=base_model,
- gaudi_config=gaudi_config,
- train_dataset=train_dataset,
- eval_dataset=eval_dataset,
- peft_config=peft_config,
- packing=script_args.packing,
- max_seq_length=None,
- tokenizer=tokenizer,
- args=training_args,
-)
-train_result = trainer.train()
-trainer.save_model(training_args.output_dir)
-metrics = train_result.metrics
-trainer.log_metrics("train", metrics)
-trainer.save_metrics("train", metrics)
+ base_model.config.use_cache = False
+ if not script_args.use_flash_attention and (
+        script_args.flash_attention_recompute or script_args.flash_attention_causal_mask
+    ):
+        raise ValueError("Need to enable use_flash_attention")
+ base_model.generation_config.use_flash_attention = script_args.use_flash_attention
+ base_model.generation_config.flash_attention_recompute = script_args.flash_attention_recompute
+ base_model.generation_config.flash_attention_causal_mask = script_args.flash_attention_causal_mask
+
+ tokenizer = AutoTokenizer.from_pretrained(script_args.model_name_or_path, trust_remote_code=True)
+ tokenizer.pad_token = tokenizer.eos_token
+ tokenizer.padding_side = "right" # Fix weird overflow issue with fp16 training
+
+ log_level = training_args.get_process_log_level()
+ logger.setLevel(log_level)
+ transformers.utils.logging.set_verbosity(log_level)
+ transformers.utils.logging.enable_default_handler()
+ transformers.utils.logging.enable_explicit_format()
+
+ train_dataset, eval_dataset, formatting_func = create_datasets(tokenizer, script_args, seed=training_args.seed)
+
+ gaudi_config = GaudiConfig()
+ gaudi_config.use_fused_adam = True
+ gaudi_config.use_fused_clip_norm = True
+    # Instantiate the trainer unconditionally so evaluation can run even when --do_train is not set.
+    trainer = GaudiSFTTrainer(
+        model=base_model,
+        gaudi_config=gaudi_config,
+        train_dataset=train_dataset,
+        eval_dataset=eval_dataset,
+        peft_config=peft_config,
+        packing=script_args.packing,
+        max_seq_length=script_args.max_seq_length,
+        tokenizer=tokenizer,
+        args=training_args,
+        formatting_func=formatting_func,
+    )
+    if training_args.do_train:
+        train_result = trainer.train()
+        trainer.save_model(training_args.output_dir)
+        metrics = train_result.metrics
+        trainer.log_metrics("train", metrics)
+        trainer.save_metrics("train", metrics)
+
+ # Evaluation
+ if training_args.do_eval:
+ logger.info("*** Evaluate ***")
+ metrics = trainer.evaluate()
+ if isinstance(eval_dataset, torch.utils.data.IterableDataset):
+ eval_dataset = list(eval_dataset)
+
+ metrics["eval_samples"] = len(eval_dataset)
+
+ try:
+ perplexity = math.exp(metrics["eval_loss"])
+ except OverflowError:
+ perplexity = float("inf")
+ metrics["perplexity"] = perplexity
+
+ trainer.log_metrics("eval", metrics)
+ trainer.save_metrics("eval", metrics)
diff --git a/examples/video-classification/README.md b/examples/video-classification/README.md
new file mode 100644
index 0000000000..6e672b5c7c
--- /dev/null
+++ b/examples/video-classification/README.md
@@ -0,0 +1,70 @@
+
+
+# Video Classification
+
+This directory contains an example script that showcases how to run video classification on HPUs.
+
+## Requirements
+
+First, install the requirements:
+```bash
+pip install -r requirements.txt
+```
+
+## Single-HPU inference
+
+### Single video example
+
+```bash
+python3 run_example.py \
+ --model_name_or_path MCG-NJU/videomae-base-finetuned-kinetics \
+ --video_paths "https://ak.picdn.net/shutterstock/videos/21179416/preview/stock-footage-aerial-shot-winter-forest.mp4" \
+ --use_hpu_graphs \
+ --bf16
+```
+
+Outputs:
+```
+Predicted class for stock-footage-aerial-shot-winter-forest.mp4 is sled dog racing and took 1.243e+00 seconds
+```
+
+### Multi-video example
+
+```bash
+python3 run_example.py \
+ --model_name_or_path MCG-NJU/videomae-base-finetuned-kinetics \
+ --use_hpu_graphs \
+ --bf16 \
+ --warm_up_epochs 3 \
+ --video_paths "https://ak.picdn.net/shutterstock/videos/5629184/preview/stock-footage-senior-couple-looking-through-binoculars-on-sailboat-together-shot-on-red-epic-for-high-quality-k.mp4" \
+ "https://ak.picdn.net/shutterstock/videos/21179416/preview/stock-footage-aerial-shot-winter-forest.mp4" \
+ "https://ak.picdn.net/shutterstock/videos/1063125190/preview/stock-footage-a-beautiful-cookie-with-oranges-lies-on-a-green-tablecloth.mp4" \
+ "https://ak.picdn.net/shutterstock/videos/1039695998/preview/stock-footage-japanese-highrise-office-skyscrapers-tokyo-square.mp4" \
+ "https://ak.picdn.net/shutterstock/videos/9607838/preview/stock-footage-zrenjanin-serbia-march-fans-watching-live-concert-bokeh-blur-urban-background-x.mp4"
+```
+
+Outputs:
+```
+Predicted class for stock-footage-senior-couple-looking-through-binoculars-on-sailboat-together-shot-on-red-epic-for-high-quality-k.mp4 is sailing and took 3.372e-01 seconds
+Predicted class for stock-footage-aerial-shot-winter-forest.mp4 is sled dog racing and took 3.360e-01 seconds
+Predicted class for stock-footage-a-beautiful-cookie-with-oranges-lies-on-a-green-tablecloth.mp4 is cooking sausages and took 3.349e-01 seconds
+Predicted class for stock-footage-japanese-highrise-office-skyscrapers-tokyo-square.mp4 is marching and took 3.362e-01 seconds
+Predicted class for stock-footage-zrenjanin-serbia-march-fans-watching-live-concert-bokeh-blur-urban-background-x.mp4 is slacklining and took 3.358e-01 seconds
+```
+
+Models that have been validated:
+- [MCG-NJU/videomae-base-finetuned-kinetics](https://huggingface.co/MCG-NJU/videomae-base-finetuned-kinetics)
diff --git a/examples/video-classification/requirements.txt b/examples/video-classification/requirements.txt
new file mode 100644
index 0000000000..308106f0c9
--- /dev/null
+++ b/examples/video-classification/requirements.txt
@@ -0,0 +1 @@
+decord
diff --git a/examples/video-classification/run_example.py b/examples/video-classification/run_example.py
new file mode 100644
index 0000000000..b593fb5955
--- /dev/null
+++ b/examples/video-classification/run_example.py
@@ -0,0 +1,183 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Loosely adapted from https://github.com/huggingface/optimum-habana/pull/783/files#diff-8361a5cbb8a1de8387eaff47125cce70f695f2a5994c66725c942c071835e82b
+
+import argparse
+import io
+import logging
+import os
+import time
+
+import decord
+import habana_frameworks.torch as ht
+import requests
+import torch
+from tqdm import tqdm
+from transformers import VideoMAEForVideoClassification, VideoMAEImageProcessor
+
+from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
+
+
+adapt_transformers_to_gaudi()
+
+
+def load_video(path):
+ vr = decord.VideoReader(path)
+ batch = vr.get_batch(list(range(16))).asnumpy()
+ buf = [batch[i, :, :, :] for i in range(16)]
+ logging.info(batch.shape)
+ return buf
+
+
+def download_file(link: str):
+ resp = requests.get(link)
+ return io.BytesIO(resp.content)
+
+
+def get_image_buffers(video_paths: list[str]):
+ for vp in video_paths:
+ logging.info(f"Extracting images from {vp}")
+ try:
+ if vp.startswith("https://") or vp.startswith("http://"):
+ file = download_file(vp)
+ yield load_video(file)
+ elif os.path.isfile(vp):
+ yield load_video(vp)
+ else:
+                logging.error(f"Video path {vp} is not a link or a file.")
+ except Exception as e:
+ logging.error(f"Error extracting video information from {vp}")
+ logging.error(f"Trace: {e}")
+ continue
+
+
+def infer(model, inputs, cast_bf16: bool):
+ with torch.autocast(device_type="hpu", dtype=torch.bfloat16, enabled=cast_bf16), torch.no_grad():
+ outputs = model(**inputs)
+ torch.hpu.synchronize()
+ predicted_class_idx = outputs.logits.argmax(-1).item()
+ class_str = model.config.id2label[predicted_class_idx]
+ return class_str
+
+
+def run(
+ model_name: str,
+ video_paths: list[str],
+    warm_up_epochs: int,
+ use_hpu_graphs: bool,
+ cast_bf16: bool,
+):
+ processor = VideoMAEImageProcessor.from_pretrained(model_name)
+ device = torch.device("hpu")
+ model = VideoMAEForVideoClassification.from_pretrained(model_name)
+ if use_hpu_graphs:
+ model = ht.hpu.wrap_in_hpu_graph(model)
+ model = model.to(device)
+ model.eval()
+
+ bufs = list(get_image_buffers(video_paths))
+
+ start_time = time.time()
+    if warm_up_epochs:
+        logging.info(f"Warming up model with {warm_up_epochs} epochs")
+        for i in tqdm(range(warm_up_epochs), leave=False):
+ for buf in bufs:
+ inputs = processor(buf, return_tensors="pt")
+ inputs.to(device)
+ infer(model, inputs, cast_bf16)
+    if warm_up_epochs:
+ end_time = time.time()
+ logging.info(f"Completed warm up in {end_time - start_time:.3e} seconds")
+
+ for i, buf in enumerate(bufs):
+ start_time = time.time()
+ inputs = processor(buf, return_tensors="pt")
+ inputs.to(device)
+ class_str = infer(model, inputs, cast_bf16)
+ end_time = time.time()
+
+ print(
+ f"Predicted class for {video_paths[i].split('/')[-1]} is {class_str} and took {end_time - start_time:.3e} seconds"
+ )
+
+
+def main():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--model_name_or_path",
+ default="MCG-NJU/videomae-base-finetuned-kinetics",
+ type=str,
+ help="Path to pre-trained model",
+ )
+ parser.add_argument(
+ "--video_paths",
+ default=[
+ "https://ak.picdn.net/shutterstock/videos/21179416/preview/stock-footage-aerial-shot-winter-forest.mp4"
+ ],
+ type=str,
+ nargs="*",
+ help="Paths to video input. Can specify multiple in a space-separated list",
+ )
+ parser.add_argument(
+ "--warm_up_epochs",
+ "-w",
+ default=0,
+ type=int,
+ help="Number of epochs to warm up the model",
+ )
+ parser.add_argument(
+ "--use_hpu_graphs",
+ "-g",
+ action="store_true",
+ help="Whether to use HPU graphs or not. Using HPU graphs should give better latencies.",
+ )
+ parser.add_argument(
+ "--bf16",
+ "-b",
+ action="store_true",
+ help="Whether to perform in bf16 precision.",
+ )
+ parser.add_argument(
+ "--log_level",
+ default=None,
+ type=int,
+ help="Log level for printout information",
+ )
+
+ args = parser.parse_args()
+
+ logging_config = {"format": "[%(levelname)s]%(asctime)s : %(message)s"}
+ if args.log_level:
+ logging_config["level"] = args.log_level
+ logging.basicConfig(**logging_config)
+ logging.info(f"Config: {args}")
+
+ if args.warm_up_epochs <= 0:
+ logging.warning("No warm up sequence, inference time may be inaccurate.")
+
+ run(
+ args.model_name_or_path,
+ args.video_paths,
+ args.warm_up_epochs,
+ args.use_hpu_graphs,
+ args.bf16,
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/visual-question-answering/README.md b/examples/visual-question-answering/README.md
index efbe1f2a92..ed4d2a8a66 100644
--- a/examples/visual-question-answering/README.md
+++ b/examples/visual-question-answering/README.md
@@ -16,10 +16,10 @@ limitations under the License.
# Visual Question Answering Examples
-This directory contains a script that showcases how to use the Transformers pipeline API to run visual question answering task on HPUs.
-
## Single-HPU inference
+The `run_pipeline.py` script showcases how to use the Transformers pipeline API to run the visual question answering task on HPUs.
+
```bash
python3 run_pipeline.py \
--model_name_or_path Salesforce/blip-vqa-capfilt-large \
@@ -32,4 +32,37 @@ python3 run_pipeline.py \
Models that have been validated:
- [Salesforce/blip-vqa-base](https://huggingface.co/Salesforce/blip-vqa-base)
- [dandelin/vilt-b32-finetuned-vqa](https://huggingface.co/dandelin/vilt-b32-finetuned-vqa)
- - [Salesforce/blip-vqa-capfilt-large](https://huggingface.co/Salesforce/blip-vqa-capfilt-large)
\ No newline at end of file
+ - [Salesforce/blip-vqa-capfilt-large](https://huggingface.co/Salesforce/blip-vqa-capfilt-large)
+
+## OpenCLIP inference
+
+The `run_openclip_vqa.py` script can be used to run zero-shot image classification with [OpenCLIP Huggingface Models](https://huggingface.co/docs/hub/en/open_clip#using-openclip-at-hugging-face).
+The requirements for `run_openclip_vqa.py` can be installed with `openclip_requirements.txt` as follows:
+
+```bash
+pip install -r openclip_requirements.txt
+```
+
+By default, the script runs the sample outlined in [BiomedCLIP-PubMedBERT_256-vit_base_patch16_224 notebook](https://huggingface.co/microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224/blob/main/biomed_clip_example.ipynb) which can be run as follows:
+
+```bash
+python run_openclip_vqa.py \
+ --use_hpu_graphs \
+ --bf16
+```
+
+One can also run other OpenCLIP models by specifying the model, classifier labels, and image URL(s) like so:
+
+```bash
+python run_openclip_vqa.py \
+ --model_name_or_path laion/CLIP-ViT-g-14-laion2B-s12B-b42K \
+ --labels "a dog" "a cat" \
+ --image_path "http://images.cocodataset.org/val2017/000000039769.jpg" \
+ --use_hpu_graphs \
+ --bf16
+```
+
+Models that have been validated:
+ - [BiomedCLIP-PubMedBERT_256-vit_base_patch16_224](https://huggingface.co/microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224)
+ - [laion/CLIP-ViT-g-14-laion2B-s12B-b42K](https://huggingface.co/laion/CLIP-ViT-g-14-laion2B-s12B-b42K)
+ - [apple/DFN5B-CLIP-ViT-H-14](https://huggingface.co/apple/DFN5B-CLIP-ViT-H-14/tree/main)
\ No newline at end of file
diff --git a/examples/visual-question-answering/openclip_requirements.txt b/examples/visual-question-answering/openclip_requirements.txt
new file mode 100644
index 0000000000..c132e5eb90
--- /dev/null
+++ b/examples/visual-question-answering/openclip_requirements.txt
@@ -0,0 +1,3 @@
+open_clip_torch==2.23.0
+matplotlib
+
diff --git a/examples/visual-question-answering/run_openclip_vqa.py b/examples/visual-question-answering/run_openclip_vqa.py
new file mode 100644
index 0000000000..76b4159149
--- /dev/null
+++ b/examples/visual-question-answering/run_openclip_vqa.py
@@ -0,0 +1,232 @@
+# This script is based on https://huggingface.co/microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224/blob/main/biomed_clip_example.ipynb
+import argparse
+import json
+import logging
+import os
+import time
+from pathlib import Path
+from pprint import pprint
+from urllib.request import urlopen
+
+import matplotlib.pyplot as plt
+import numpy
+import torch
+from open_clip import create_model_from_pretrained, get_tokenizer, model
+from PIL import Image
+
+from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
+
+
+logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+)
+logger = logging.getLogger(__name__)
+
+DATASET_URL = "https://huggingface.co/microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224/resolve/main/example_data/biomed_image_classification_example_data/"
+LABELS = [
+ "adenocarcinoma histopathology",
+ "brain MRI",
+ "covid line chart",
+ "squamous cell carcinoma histopathology",
+ "immunohistochemistry histopathology",
+ "bone X-ray",
+ "chest X-ray",
+ "pie chart",
+ "hematoxylin and eosin histopathology",
+]
+
+TEST_IMGS = [
+ "squamous_cell_carcinoma_histopathology.jpeg",
+ "H_and_E_histopathology.jpg",
+ "bone_X-ray.jpg",
+ "adenocarcinoma_histopathology.jpg",
+ "covid_line_chart.png",
+ "IHC_histopathology.jpg",
+ "chest_X-ray.jpg",
+ "brain_MRI.jpg",
+ "pie_chart.png",
+]
+
+
+def plot_images_with_metadata(images: list, metadata, output_dir: str, plot_name: str) -> None:
+ print(f"plottypes {type(images)} {type(metadata)} {type(output_dir)} {type(plot_name)}")
+
+ num_images = len(images)
+ fig, axes = plt.subplots(nrows=num_images, ncols=1, figsize=(5, 5 * num_images))
+
+    for i, (img_path, meta) in enumerate(zip(images, metadata)):
+        img = Image.open(urlopen(img_path))
+        if isinstance(axes, (list, numpy.ndarray)):
+            ax = axes[i]
+        else:
+            ax = axes
+        ax.imshow(img)
+        ax.axis("off")
+        ax.set_title(f"{meta['filename']}\n{meta['top_probs']}", fontsize=14)
+
+ plt.tight_layout()
+ plt.savefig(f"{output_dir}/{plot_name}.png")
+
+
+def run_qa(model: model, images: torch.Tensor, texts: torch.Tensor, device: torch.device) -> tuple:
+ with torch.no_grad():
+ image_features, text_features, logit_scale = model(images, texts)
+ logits = (logit_scale * image_features @ text_features.t()).detach().softmax(dim=-1)
+ sorted_indices = torch.argsort(logits, dim=-1, descending=True)
+ return sorted_indices, logits
+
+
+def postprocess(args: argparse.Namespace, sorted_indices: torch.Tensor, logits: torch.Tensor, topk: int) -> list:
+ logits = logits.float().cpu().numpy()
+ sorted_indices = sorted_indices.int().cpu().numpy()
+ metadata_list = []
+ for i, img in enumerate(args.image_path):
+ img_name = img.split("/")[-1]
+
+ top_probs = []
+ topk = len(args.labels) if topk == -1 else topk
+ for j in range(topk):
+ jth_index = sorted_indices[i][j]
+ top_probs.append(f"{args.labels[jth_index]}: {logits[i][jth_index] * 100:.1f}")
+
+ metadata = {"filename": img_name, "top_probs": "\n".join(top_probs)}
+ metadata_list.append(metadata)
+ return metadata_list
+
+
+def main():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--model_name_or_path",
+ default="microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224",
+ type=str,
+ help="Path to pre-trained model",
+ )
+ parser.add_argument(
+ "--image_path",
+ default=[DATASET_URL + img for img in TEST_IMGS],
+ type=str,
+ nargs="*",
+ help='Path to image as input. Can be a single string (eg: --image_path "URL1"), or a list of space-separated strings (eg: --image_path "URL1" "URL2")',
+ )
+ parser.add_argument(
+ "--topk",
+ default=1,
+ type=int,
+ help="topk num. Provides top K probabilities for the labels provided.",
+ )
+ parser.add_argument(
+ "--prompt",
+ default="this is a picture of ",
+ type=str,
+        help='Prompt prepended to each label for classification (eg: --prompt "a photo of ")',
+ )
+ parser.add_argument(
+ "--labels",
+ default=LABELS,
+ type=str,
+ nargs="*",
+ help='Labels for classification (eg: --labels "LABEL1"), or a list of space-separated strings (eg: --labels "LABEL1" "LABEL2")',
+ )
+ parser.add_argument(
+ "--use_hpu_graphs",
+ action="store_true",
+ help="Whether to use HPU graphs or not. Using HPU graphs should give better latencies.",
+ )
+ parser.add_argument(
+ "--bf16",
+ action="store_true",
+ help="Whether to perform in bf16 precision.",
+ )
+ parser.add_argument(
+ "--output_dir",
+ default=os.getcwd(),
+ type=str,
+ help="Output directory to store results in.",
+ )
+ parser.add_argument("--warmup", type=int, default=3, help="Number of warmup iterations for benchmarking.")
+ parser.add_argument(
+ "--n_iterations", type=int, default=10, help="Number of inference iterations for benchmarking."
+ )
+ parser.add_argument("--plot_images", action="store_true", help="Plot images with metadata for verification")
+ parser.add_argument(
+ "--plot_name",
+ default="openclip_vqa_plot",
+ type=str,
+ help="Name of the plot generated with the image and corresponding top K results",
+ )
+ parser.add_argument(
+ "--print_result",
+ action="store_true",
+ help="Whether to print the zero shot classification results.",
+ )
+
+ args = parser.parse_args()
+
+ adapt_transformers_to_gaudi()
+
+ precision = "fp32"
+ dtype = torch.float32
+ if args.bf16:
+ precision = "bf16"
+ dtype = torch.bfloat16
+
+ model, preprocess = create_model_from_pretrained(f"hf-hub:{args.model_name_or_path}", precision=precision)
+ tokenizer = get_tokenizer(f"hf-hub:{args.model_name_or_path}")
+
+ device = torch.device("hpu") if torch.hpu.is_available() else torch.device("cpu")
+ device_type = "hpu" if torch.hpu.is_available() else "cpu"
+
+ # Initialize model
+ if args.use_hpu_graphs:
+ from habana_frameworks.torch.hpu import wrap_in_hpu_graph
+
+ model = wrap_in_hpu_graph(model)
+ model = model.to(device)
+ model.eval()
+
+ images = torch.stack([preprocess(Image.open(urlopen(img))) for img in args.image_path]).to(device)
+ texts = tokenizer([args.prompt + l for l in args.labels]).to(device)
+
+ # Warm up
+ logger.info("Running warmup")
+ for i in range(args.warmup):
+ with torch.autocast(device_type=device_type, dtype=dtype, enabled=True):
+ _, _ = run_qa(model, images, texts, device=device)
+
+ logger.info("Running inference")
+ start = time.time()
+ for i in range(args.n_iterations):
+ logits = None
+ with torch.autocast(device_type=device_type, dtype=dtype, enabled=True):
+ sorted_indices, logits = run_qa(model, images, texts, device=device)
+ end = time.time()
+
+ # Results and metrics
+ metadata_list = []
+ metadata_list = postprocess(args, sorted_indices, logits, args.topk)
+ if args.print_result:
+ logger.info("Results from the last iteration:")
+ pprint(metadata_list)
+ inference_time_per_iteration = (end - start) * 1000 / args.n_iterations
+ logger.info(f"Inference Time per iteration = {inference_time_per_iteration:.4}ms")
+ throughput = len(args.image_path) * args.n_iterations / (end - start)
+ logger.info(f"Throughput = {throughput:.4} images/s")
+
+ # Store results if necessary
+ if args.output_dir is not None:
+ output_dir = Path(args.output_dir)
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ results = {"throughput": throughput, "inference time per iteration ": inference_time_per_iteration}
+ with (output_dir / "results.json").open("w", encoding="utf-8") as f:
+ json.dump(results, f, ensure_ascii=False, indent=4)
+ if args.plot_images:
+ plot_images_with_metadata(args.image_path, metadata_list, args.output_dir, args.plot_name)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/optimum/habana/accelerate/accelerator.py b/optimum/habana/accelerate/accelerator.py
index fcfd47c319..d826908ee5 100644
--- a/optimum/habana/accelerate/accelerator.py
+++ b/optimum/habana/accelerate/accelerator.py
@@ -37,7 +37,6 @@
DeepSpeedPlugin,
DistributedDataParallelKwargs,
DistributedType,
- FP8RecipeKwargs,
GradientAccumulationPlugin,
GradScalerKwargs,
InitProcessGroupKwargs,
@@ -73,12 +72,11 @@
from .utils import (
GaudiDistributedType,
GaudiDynamoBackend,
+ GaudiFP8RecipeKwargs,
GaudiFullyShardedDataParallelPlugin,
GaudiTorchDynamoPlugin,
- te_forward_convert,
- te_setup_fp8_recipe_handler,
- te_wrap_fp8,
- te_wrap_fp8_forward_convert,
+ convert_model,
+ get_fp8_recipe,
)
@@ -113,7 +111,6 @@ def __init__(
dynamo_backend: GaudiDynamoBackend | str | None = None,
distribution_strategy: str = None,
force_autocast: bool = False,
- fp8_recipe_format: str = None,
):
self.trackers = []
if project_config is not None:
@@ -181,7 +178,6 @@ def __init__(
self.scaler_handler = None
self.init_handler = None
self.fp8_recipe_handler = None
- self.fp8_recipe_format = None
self.autocast_handler = None
if kwargs_handlers is not None:
for handler in kwargs_handlers:
@@ -203,9 +199,9 @@ def __init__(
raise ValueError("You can only pass one `InitProcessGroupKwargs` in `kwargs_handler`.")
else:
self.init_handler = handler
- elif isinstance(handler, FP8RecipeKwargs):
+ elif isinstance(handler, GaudiFP8RecipeKwargs):
if self.fp8_recipe_handler is not None:
- raise ValueError("You can only pass one `FP8RecipeKwargs` in `kwargs_handler`.")
+ raise ValueError("You can only pass one `GaudiFP8RecipeKwargs` in `kwargs_handler`.")
else:
self.fp8_recipe_handler = handler
elif isinstance(handler, AutocastKwargs):
@@ -225,8 +221,14 @@ def __init__(
_from_accelerator=True,
**kwargs,
)
- if self.fp8_recipe_handler is None and self.state.is_fp8_enabled:
- self.fp8_recipe_handler = te_setup_fp8_recipe_handler(self.fp8_recipe_format)
+
+ if self.state.is_fp8_enabled:
+ if self.fp8_recipe_handler is None:
+ self.fp8_recipe_handler = GaudiFP8RecipeKwargs()
+ # Handling FP8 recipe creation in init since both `prepare_model` and `_prepare_deepspeed` require it.
+ # (Base accelerator handles this in `prepare_model` function)
+ self.fp8_recipe_handler = get_fp8_recipe(self.fp8_recipe_handler)
+
trackers = filter_trackers(log_with, self.logging_dir)
if len(trackers) < 1 and log_with is not None:
warnings.warn(f"`log_with={log_with}` was passed but no supported trackers are currently installed.")
@@ -349,31 +351,8 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
model.forward = MethodType(convert_outputs_to_fp32(model.forward.__func__), model)
else:
model.forward = convert_outputs_to_fp32(new_forward)
- elif self.state.is_fp8_enabled:
- model = te_wrap_fp8_forward_convert(model, self.fp8_recipe_handler)
- # FP8 is not supported on Gaudi2 yet
- # elif self.mixed_precision == "fp8":
- # if not has_transformer_engine_layers(model):
- # with torch.no_grad():
- # convert_model(model)
- # model._converted_to_transformer_engine = True
- # model._original_forward = model.forward
-
- # kwargs = self.fp8_recipe_handler.to_kwargs() if self.fp8_recipe_handler is not None else {}
- # if "fp8_format" in kwargs:
- # kwargs["fp8_format"] = getattr(te_recipe.Format, kwargs["fp8_format"])
- # fp8_recipe = te_recipe.DelayedScaling(**kwargs)
- # cuda_device_capacity = torch.cuda.get_device_capability()
- # fp8_enabled = cuda_device_capacity[0] >= 9 or (
- # cuda_device_capacity[0] == 8 and cuda_device_capacity[1] >= 9
- # )
- # if not fp8_enabled:
- # logger.warn(
- # f"The current device has compute capability of {cuda_device_capacity} which is "
- # "insufficient for FP8 mixed precision training (requires a GPU Hopper/Ada Lovelace "
- # "or higher, compute capability of 8.9 or higher). Will use FP16 instead."
- # )
- # model.forward = fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe)(model.forward)
+ if self.state.is_fp8_enabled:
+ model = convert_model(model)
if (getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False)) and getattr(
model, "hf_device_map", False
@@ -469,7 +448,7 @@ def _prepare_deepspeed(self, *args):
result = [
self._prepare_one(obj, first_pass=True)
if isinstance(obj, torch.utils.data.DataLoader)
- else te_wrap_fp8(obj)
+ else convert_model(obj)
if isinstance(obj, torch.nn.Module) and self.state.is_fp8_enabled
else obj
for obj in args
@@ -685,8 +664,6 @@ def _prepare_deepspeed(self, *args):
result[i] = scheduler
# pointing for deepspeed_engine_wrapped.backward()
self.deepspeed_engine_wrapped = DeepSpeedEngineWrapper(engine)
- if self.state.is_fp8_enabled:
- model = te_forward_convert(engine, self.fp8_recipe_handler)
self._models.append(engine)
if optimizer is not None:
self._optimizers.append(optimizer)
diff --git a/optimum/habana/accelerate/utils/__init__.py b/optimum/habana/accelerate/utils/__init__.py
index 37181aee9b..ee25954b95 100755
--- a/optimum/habana/accelerate/utils/__init__.py
+++ b/optimum/habana/accelerate/utils/__init__.py
@@ -1,12 +1,12 @@
from .dataclasses import (
GaudiDistributedType,
GaudiDynamoBackend,
+ GaudiFP8RecipeKwargs,
GaudiFullyShardedDataParallelPlugin,
GaudiTorchDynamoPlugin,
)
from .transformer_engine import (
- te_forward_convert,
- te_setup_fp8_recipe_handler,
- te_wrap_fp8,
- te_wrap_fp8_forward_convert,
+ FP8ContextWrapper,
+ convert_model,
+ get_fp8_recipe,
)
diff --git a/optimum/habana/accelerate/utils/dataclasses.py b/optimum/habana/accelerate/utils/dataclasses.py
index eaf5f09158..fce2c06c8c 100644
--- a/optimum/habana/accelerate/utils/dataclasses.py
+++ b/optimum/habana/accelerate/utils/dataclasses.py
@@ -20,7 +20,7 @@
import torch
from accelerate.utils import FullyShardedDataParallelPlugin
from accelerate.utils.constants import FSDP_BACKWARD_PREFETCH
-from accelerate.utils.dataclasses import BaseEnum, TorchDynamoPlugin
+from accelerate.utils.dataclasses import BaseEnum, KwargsHandler, TorchDynamoPlugin
from accelerate.utils.environment import str_to_bool
@@ -144,3 +144,47 @@ def __post_init__(self):
if self.sync_module_states:
device = torch.device("hpu")
self.param_init_fn = lambda x: x.to_empty(device=device, recurse=False)
+
+
+@dataclass
+class GaudiFP8RecipeKwargs(KwargsHandler):
+ """
+ Use this object in your [`Accelerator`] to customize the initialization of the recipe for FP8 mixed precision training with `transformer-engine`.
+
+ Adapted from: https://github.com/huggingface/accelerate/blob/v0.27.2/src/accelerate/utils/dataclasses.py#L180
+
+ Args:
+ margin (`int`, *optional*, defaults to 0):
+ The margin to use for the scaling factor computation.
+ interval (`int`, *optional*, defaults to 16):
+ The interval to use for how often the scaling factor is recomputed.
+ fp8_format (`str`, *optional*, defaults to "HYBRID"):
+ The format to use for the FP8 recipe. Must be one of `E5M2` or `HYBRID`.
+ amax_history_len (`int`, *optional*, defaults to 1):
+            The length of the history to use for the scaling factor computation.
+ amax_compute_algo (`str`, *optional*, defaults to "most_recent"):
+ The algorithm to use for the scaling factor computation. Must be one of `max` or `most_recent`.
+        reduce_amax (`bool`, *optional*, defaults to `False`):
+ By default, if `torch.distributed` is initialized, the `amax` value for FP8
+ tensors is reduced across the `fp8_group` (specified in the `fp8_autocast`
+ call). This keeps the amaxes and scaling factors synced across the given
+ distributed group. If set to `False`, this reduction is skipped and every
+ HPU maintains local amaxes and scaling factors. To ensure results are
+ numerically identical across checkpointing boundaries in this case, all
+ ranks must checkpoint in order to store the local tensors.
+ """
+
+ margin: int = 0
+ interval: int = 16
+ fp8_format: str = "HYBRID"
+ amax_compute_algo: str = "most_recent"
+ amax_history_len: int = 1
+ reduce_amax: bool = False
+
+ def __post_init__(self):
+ self.fp8_format = self.fp8_format.upper()
+ assert self.fp8_format in ("E5M2", "HYBRID"), "Only E5M2 and HYBRID FP8 formats are currently supported."
+ assert self.amax_compute_algo in (
+ "max",
+ "most_recent",
+ ), "Only max and most_recent `amax_compute_algo` modes are currently supported."
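As a usage illustration (not taken from this diff), the new kwargs handler would presumably be passed to the Gaudi accelerator through `kwargs_handlers`, which the `isinstance(handler, GaudiFP8RecipeKwargs)` branch added in `accelerator.py` above consumes. The `GaudiAccelerator` import path and the `mixed_precision="fp8"` switch are assumptions here; only the handler itself is introduced by this change.

```python
# Hypothetical sketch: customizing the FP8 recipe through kwargs_handlers.
# GaudiAccelerator and mixed_precision="fp8" are assumed and not shown in this diff.
from optimum.habana.accelerate import GaudiAccelerator
from optimum.habana.accelerate.utils import GaudiFP8RecipeKwargs

# Values must satisfy __post_init__: fp8_format in {"E5M2", "HYBRID"},
# amax_compute_algo in {"max", "most_recent"}.
fp8_kwargs = GaudiFP8RecipeKwargs(fp8_format="E5M2", amax_compute_algo="max", interval=8)

accelerator = GaudiAccelerator(mixed_precision="fp8", kwargs_handlers=[fp8_kwargs])
# During __init__, the handler is converted into a te.recipe.DelayedScaling object by
# get_fp8_recipe() and applied later when the model is prepared.
```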
diff --git a/optimum/habana/accelerate/utils/transformer_engine.py b/optimum/habana/accelerate/utils/transformer_engine.py
index 07aa71aa6c..823da61d5c 100755
--- a/optimum/habana/accelerate/utils/transformer_engine.py
+++ b/optimum/habana/accelerate/utils/transformer_engine.py
@@ -13,61 +13,41 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import functools
+
import torch
-te = None
+has_transformer_engine = False
-class SwitchableForwardMaker:
- def __init__(self, module, fp8_recipe_handler):
- self.original_forward = module.forward
- self.fp8_forward = te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe_handler)(module.forward)
- self.module = module
- module.forward = self.forward
+def import_te():
+ global te, has_transformer_engine
+ try:
+ import habana_frameworks.torch.hpex.experimental.transformer_engine as te
- def forward(self, *args, **kwargs):
- if self.module.training:
- return self.fp8_forward(*args, **kwargs)
- else:
- return self.original_forward(*args, **kwargs)
+ has_transformer_engine = True
- @staticmethod
- def convert(module, fp8_recipe_handler):
- SwitchableForwardMaker(module, fp8_recipe_handler)
+ except ImportError:
+ has_transformer_engine = False
-def get_te():
- global te
- if te is None:
- try:
- import habana_frameworks.torch.hpex.experimental.transformer_engine as te
+def is_fp8_available():
+ if not has_transformer_engine:
+ import_te()
+ return has_transformer_engine
- te = te
- except ImportError:
- te = None
-
-def convert_model(model, to_transformer_engine=True, _convert_linear=True):
+def _convert_model(model, to_transformer_engine=True, _convert_linear=True):
"""
- Recursively converts the linear and layernorm layers of a model to their `transformers_engine` counterpart.
+ Recursively converts the linear layer of a model to their `transformers_engine` counterpart.
"""
- if te is None:
+ if not is_fp8_available():
raise ImportError("Using `convert_model` requires transformer_engine to be installed.")
- from peft.tuners.lora.layer import Linear as PEFTLinear
-
- from optimum.habana.peft.layer import LoRALinear
-
for name, module in model.named_children():
- if type(module) == PEFTLinear and to_transformer_engine and _convert_linear:
- LoRALinear.replace_forward(module)
- if (
- isinstance(module, torch.nn.Linear)
- and not type(module) == PEFTLinear
- and to_transformer_engine
- and _convert_linear
- ):
+ if isinstance(module, torch.nn.Linear) and to_transformer_engine and _convert_linear:
has_bias = module.bias is not None
+ # Initializing TE linear without weights and biases and shallow copying them from the original module.
te_module = te.Linear(
module.in_features,
module.out_features,
@@ -81,11 +61,14 @@ def convert_model(model, to_transformer_engine=True, _convert_linear=True):
te_module.bias = module.bias
setattr(model, name, te_module)
-
elif isinstance(module, te.Linear) and not to_transformer_engine and _convert_linear:
has_bias = module.bias is not None
new_module = torch.nn.Linear(
- module.in_features, module.out_features, bias=has_bias, params_dtype=module.weight.dtype
+ module.in_features,
+ module.out_features,
+ bias=has_bias,
+ dtype=module.weight.dtype,
+ device=module.weight.device,
)
new_module.weight.copy_(module.weight)
if has_bias:
@@ -93,14 +76,14 @@ def convert_model(model, to_transformer_engine=True, _convert_linear=True):
setattr(model, name, new_module)
else:
- convert_model(module, to_transformer_engine=to_transformer_engine, _convert_linear=_convert_linear)
+ _convert_model(module, to_transformer_engine=to_transformer_engine, _convert_linear=_convert_linear)
def has_transformer_engine_layers(model):
"""
Returns whether a given model has some `transformer_engine` layer or not.
"""
- if te is None:
+ if not is_fp8_available():
raise ImportError("Using `has_transformer_engine_layers` requires transformer_engine to be installed.")
for m in model.modules():
if isinstance(m, (te.Linear)):
@@ -108,38 +91,80 @@ def has_transformer_engine_layers(model):
return False
-def te_setup_fp8_recipe_handler(fp8_recipe_format):
- get_te()
- fp8_format = te.recipe.Format.E5M2
- if fp8_recipe_format == "E4M3":
- fp8_format = te.recipe.Format.E4M3
- elif fp8_recipe_format == "HYBRID":
- fp8_format = te.recipe.Format.HYBRID
- fp8_recipe_handler = te.recipe.DelayedScaling(
- fp8_format=fp8_format,
- margin=0,
- interval=16,
- amax_history_len=1,
- amax_compute_algo="most_recent",
- reduce_amax=False,
- )
- fp8_recipe_handler.backend = "TE"
- return fp8_recipe_handler
-
-
-def te_wrap_fp8(model):
+def convert_model(model):
+ """
+ Converts torch.nn.Linear modules to `transformers_engine` Linear modules.
+ Adapted from: https://github.com/huggingface/accelerate/blob/v0.27.2/src/accelerate/accelerator.py#L1303
+ """
if not has_transformer_engine_layers(model):
with torch.no_grad():
- convert_model(model)
+ _convert_model(model)
model._converted_to_transformer_engine = True
return model
-def te_wrap_fp8_forward_convert(model, fp8_recipe_handler):
- model = te_wrap_fp8(model)
- SwitchableForwardMaker.convert(model, fp8_recipe_handler)
- return model
+def get_fp8_recipe(fp8_recipe_handler):
+ """
+ Creates transformer engine FP8 recipe object.
+ Adapted from: https://github.com/huggingface/accelerate/blob/v0.27.2/src/accelerate/accelerator.py#L1309
+ """
+ if not is_fp8_available():
+ raise ImportError("Using `get_fp8_recipe` requires transformer_engine to be installed.")
+ kwargs = fp8_recipe_handler.to_dict() if fp8_recipe_handler is not None else {}
+ if "fp8_format" in kwargs:
+ kwargs["fp8_format"] = getattr(te.recipe.Format, kwargs["fp8_format"])
+ fp8_recipe_handler = te.recipe.DelayedScaling(**kwargs)
+ fp8_recipe_handler.backend = "TE"
+ return fp8_recipe_handler
+
+
+class FP8ContextWrapper:
+ """
+ Helper class for FP8 context related operations.
+ """
+
+ def __init__(self, ctx, fp8_recipe):
+ self.ctx = ctx
+ self.fp8_ctx = self.create_fp8_context(fp8_recipe)
+
+ def __enter__(self):
+ self.ctx.__enter__()
+ self.fp8_ctx.__enter__()
+ def __exit__(self, exc_type, exc_value, exc_traceback):
+ self.fp8_ctx.__exit__(exc_type, exc_value, exc_traceback)
+ self.ctx.__exit__(exc_type, exc_value, exc_traceback)
+
+ @staticmethod
+ def create_fp8_context(fp8_recipe):
+ return te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe)
+
+ @staticmethod
+ def _gradient_checkpointing_wrap(func, *args, **kwargs):
+ """
+ `_gradient_checkpointing_func` always takes the function to be recomputed as the first argument. The function
+ below wraps this first argument with `transformer_engine`'s `activation_checkpointing` context.
+ """
+ _args = list(args)
+ _args[0] = te.distributed.activation_checkpointing()(_args[0])
+ args = tuple(_args)
+
+ return func(*args, **kwargs)
+
+ @staticmethod
+ def gradient_checkpointing_wrap(model):
+ """
+ Wrap `_gradient_checkpointing_func` in the model with `transformer_engine`'s `activation_checkpointing` context.
+ This context is used to signal the `transformer_engine` modules whether they have been called with activation checkpointing enabled or not.
+ """
+ if hasattr(model, "gradient_checkpointing") and model.gradient_checkpointing:
+ model._gradient_checkpointing_func = functools.partial(
+ FP8ContextWrapper._gradient_checkpointing_wrap, model._gradient_checkpointing_func
+ )
+ return
-def te_forward_convert(model, fp8_recipe_handler):
- SwitchableForwardMaker.convert(model, fp8_recipe_handler)
+ for module in model.modules():
+ if hasattr(module, "gradient_checkpointing") and module.gradient_checkpointing:
+ module._gradient_checkpointing_func = functools.partial(
+ FP8ContextWrapper._gradient_checkpointing_wrap, module._gradient_checkpointing_func
+ )
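For context, here is a minimal sketch (not part of this diff) of how the helpers above could fit together at runtime. It assumes a Gaudi device with Habana's `transformer_engine` package available; the toy `Sequential` model and the random input are placeholders.

```python
# Hypothetical sketch: convert a small module to TE Linear layers and run a forward
# pass under the combined bf16 autocast + te.fp8_autocast context.
import torch

from optimum.habana.accelerate.utils import (
    FP8ContextWrapper,
    GaudiFP8RecipeKwargs,
    convert_model,
    get_fp8_recipe,
)

device = torch.device("hpu")  # requires a Gaudi device
model = convert_model(torch.nn.Sequential(torch.nn.Linear(16, 16))).to(device)  # nn.Linear -> te.Linear
inputs = torch.randn(4, 16, device=device)

recipe = get_fp8_recipe(GaudiFP8RecipeKwargs())  # DelayedScaling built from the dataclass defaults
autocast_ctx = torch.autocast(device_type="hpu", dtype=torch.bfloat16)

with FP8ContextWrapper(autocast_ctx, recipe):  # enters autocast first, then te.fp8_autocast
    outputs = model(inputs)
```

When gradient checkpointing is enabled, the trainer would presumably call `FP8ContextWrapper.gradient_checkpointing_wrap(model)` once so that each recomputed segment re-enters TE's `activation_checkpointing` context, as the helper's docstring describes.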
diff --git a/optimum/habana/diffusers/__init__.py b/optimum/habana/diffusers/__init__.py
index 26d5d2d359..e4057b553a 100644
--- a/optimum/habana/diffusers/__init__.py
+++ b/optimum/habana/diffusers/__init__.py
@@ -1,8 +1,18 @@
+from .pipelines.auto_pipeline import AutoPipelineForInpainting, AutoPipelineForText2Image
from .pipelines.controlnet.pipeline_controlnet import GaudiStableDiffusionControlNetPipeline
from .pipelines.pipeline_utils import GaudiDiffusionPipeline
from .pipelines.stable_diffusion.pipeline_stable_diffusion import GaudiStableDiffusionPipeline
+from .pipelines.stable_diffusion.pipeline_stable_diffusion_image_variation import (
+ GaudiStableDiffusionImageVariationPipeline,
+)
+from .pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import GaudiStableDiffusionInpaintPipeline
+from .pipelines.stable_diffusion.pipeline_stable_diffusion_instruct_pix2pix import (
+ GaudiStableDiffusionInstructPix2PixPipeline,
+)
from .pipelines.stable_diffusion.pipeline_stable_diffusion_ldm3d import GaudiStableDiffusionLDM3DPipeline
from .pipelines.stable_diffusion.pipeline_stable_diffusion_upscale import GaudiStableDiffusionUpscalePipeline
from .pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import GaudiStableDiffusionXLPipeline
+from .pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img import GaudiStableDiffusionXLImg2ImgPipeline
+from .pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint import GaudiStableDiffusionXLInpaintPipeline
from .pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import GaudiStableVideoDiffusionPipeline
from .schedulers import GaudiDDIMScheduler, GaudiEulerAncestralDiscreteScheduler, GaudiEulerDiscreteScheduler
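
These new exports mirror the corresponding upstream diffusers pipelines. As a quick illustration (not part of the patch), the new inpainting pipeline takes the same Gaudi-specific keyword arguments as the other Gaudi pipelines; the model ID and Gaudi config name below are examples only:

```py
# Illustrative usage of the newly exported GaudiStableDiffusionInpaintPipeline.
from io import BytesIO

import requests
from PIL import Image

from optimum.habana.diffusers import GaudiStableDiffusionInpaintPipeline


def download_image(url):
    return Image.open(BytesIO(requests.get(url).content)).convert("RGB")


img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"

pipe = GaudiStableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",
    use_habana=True,                         # run on HPU instead of CPU
    use_hpu_graphs=True,                     # capture/replay HPU graphs
    gaudi_config="Habana/stable-diffusion",  # Gaudi config from the Hub (example)
)

image = pipe(
    prompt="Face of a yellow cat, high resolution, sitting on a park bench",
    image=download_image(img_url).resize((512, 512)),
    mask_image=download_image(mask_url).resize((512, 512)),
    batch_size=1,
).images[0]
```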
diff --git a/optimum/habana/diffusers/pipelines/auto_pipeline.py b/optimum/habana/diffusers/pipelines/auto_pipeline.py
new file mode 100644
index 0000000000..77171c9502
--- /dev/null
+++ b/optimum/habana/diffusers/pipelines/auto_pipeline.py
@@ -0,0 +1,141 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Adapted from: https://github.com/huggingface/diffusers/blob/v0.26.3/src/diffusers/pipelines/auto_pipeline.py
+- Added GAUDI_PREFIX_NAME to support Gaudi pipelines in _gaudi_get_task_class.
+- Only AutoPipelineForText2Image and AutoPipelineForInpainting are retained; their from_pretrained and from_pipe methods are reimplemented to support the Gaudi pipelines.
+"""
+
+from collections import OrderedDict
+
+from diffusers.pipelines import (
+ AutoPipelineForInpainting,
+ AutoPipelineForText2Image,
+ auto_pipeline,
+)
+from huggingface_hub.utils import validate_hf_hub_args
+
+from .controlnet.pipeline_controlnet import GaudiStableDiffusionControlNetPipeline
+from .stable_diffusion.pipeline_stable_diffusion import GaudiStableDiffusionPipeline
+from .stable_diffusion.pipeline_stable_diffusion_inpaint import GaudiStableDiffusionInpaintPipeline
+from .stable_diffusion_xl.pipeline_stable_diffusion_xl import GaudiStableDiffusionXLPipeline
+from .stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint import GaudiStableDiffusionXLInpaintPipeline
+
+
+GAUDI_PREFIX_NAME = "Gaudi"
+
+GAUDI_AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
+ [
+ ("stable-diffusion", GaudiStableDiffusionPipeline),
+ ("stable-diffusion-xl", GaudiStableDiffusionXLPipeline),
+ ("stable-diffusion-controlnet", GaudiStableDiffusionControlNetPipeline),
+ ]
+)
+
+
+GAUDI_AUTO_INPAINT_PIPELINES_MAPPING = OrderedDict(
+ [
+ ("stable-diffusion", GaudiStableDiffusionInpaintPipeline),
+ ("stable-diffusion-xl", GaudiStableDiffusionXLInpaintPipeline),
+ ]
+)
+
+
+GAUDI_SUPPORTED_TASKS_MAPPINGS = [
+ GAUDI_AUTO_TEXT2IMAGE_PIPELINES_MAPPING,
+ GAUDI_AUTO_INPAINT_PIPELINES_MAPPING,
+]
+
+
+def _gaudi_get_task_class(mapping, pipeline_class_name, throw_error_if_not_exist: bool = True):
+ def get_model(pipeline_class_name):
+ for task_mapping in GAUDI_SUPPORTED_TASKS_MAPPINGS:
+ for model_name, pipeline in task_mapping.items():
+ if pipeline.__name__ == pipeline_class_name:
+ return model_name
+
+ pipeline_class_name = GAUDI_PREFIX_NAME + pipeline_class_name
+ model_name = get_model(pipeline_class_name)
+
+ if model_name is not None:
+ task_class = mapping.get(model_name, None)
+ if task_class is not None:
+ return task_class
+
+ if throw_error_if_not_exist:
+ raise ValueError(f"AutoPipeline can't find a pipeline linked to {pipeline_class_name} for {model_name}")
+
+
+class AutoPipelineForText2Image(AutoPipelineForText2Image):
+ @classmethod
+ @validate_hf_hub_args
+ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
+ orig_supported_mappings = auto_pipeline.SUPPORTED_TASKS_MAPPINGS
+ orig_txt2img_mappings = auto_pipeline.AUTO_TEXT2IMAGE_PIPELINES_MAPPING
+ orig_func = auto_pipeline._get_task_class
+ auto_pipeline.SUPPORTED_TASKS_MAPPINGS = GAUDI_SUPPORTED_TASKS_MAPPINGS
+ auto_pipeline.AUTO_TEXT2IMAGE_PIPELINES_MAPPING = GAUDI_AUTO_TEXT2IMAGE_PIPELINES_MAPPING
+ auto_pipeline._get_task_class = _gaudi_get_task_class
+ pipeline = super().from_pretrained(pretrained_model_or_path, **kwargs)
+ auto_pipeline.SUPPORTED_TASKS_MAPPINGS = orig_supported_mappings
+ auto_pipeline.AUTO_TEXT2IMAGE_PIPELINES_MAPPING = orig_txt2img_mappings
+ auto_pipeline._get_task_class = orig_func
+ return pipeline
+
+ @classmethod
+ def from_pipe(cls, pipeline, **kwargs):
+ orig_supported_mappings = auto_pipeline.SUPPORTED_TASKS_MAPPINGS
+ orig_txt2img_mappings = auto_pipeline.AUTO_TEXT2IMAGE_PIPELINES_MAPPING
+ orig_func = auto_pipeline._get_task_class
+ auto_pipeline.SUPPORTED_TASKS_MAPPINGS = GAUDI_SUPPORTED_TASKS_MAPPINGS
+ auto_pipeline.AUTO_TEXT2IMAGE_PIPELINES_MAPPING = GAUDI_AUTO_TEXT2IMAGE_PIPELINES_MAPPING
+ auto_pipeline._get_task_class = _gaudi_get_task_class
+ model = super().from_pipe(pipeline, **kwargs)
+ auto_pipeline.SUPPORTED_TASKS_MAPPINGS = orig_supported_mappings
+ auto_pipeline.AUTO_TEXT2IMAGE_PIPELINES_MAPPING = orig_txt2img_mappings
+ auto_pipeline._get_task_class = orig_func
+ return model
+
+
+class AutoPipelineForInpainting(AutoPipelineForInpainting):
+ @classmethod
+ @validate_hf_hub_args
+ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
+ orig_supported_mappings = auto_pipeline.SUPPORTED_TASKS_MAPPINGS
+ orig_inpaint_mappings = auto_pipeline.AUTO_INPAINT_PIPELINES_MAPPING
+ orig_func = auto_pipeline._get_task_class
+ auto_pipeline.SUPPORTED_TASKS_MAPPINGS = GAUDI_SUPPORTED_TASKS_MAPPINGS
+ auto_pipeline.AUTO_INPAINT_PIPELINES_MAPPING = GAUDI_AUTO_INPAINT_PIPELINES_MAPPING
+ auto_pipeline._get_task_class = _gaudi_get_task_class
+ pipeline = super().from_pretrained(pretrained_model_or_path, **kwargs)
+ auto_pipeline.SUPPORTED_TASKS_MAPPINGS = orig_supported_mappings
+ auto_pipeline.AUTO_INPAINT_PIPELINES_MAPPING = orig_inpaint_mappings
+ auto_pipeline._get_task_class = orig_func
+ return pipeline
+
+ @classmethod
+ def from_pipe(cls, pipeline, **kwargs):
+ orig_supported_mappings = auto_pipeline.SUPPORTED_TASKS_MAPPINGS
+ orig_inpaint_mappings = auto_pipeline.AUTO_INPAINT_PIPELINES_MAPPING
+ orig_func = auto_pipeline._get_task_class
+ auto_pipeline.SUPPORTED_TASKS_MAPPINGS = GAUDI_SUPPORTED_TASKS_MAPPINGS
+ auto_pipeline.AUTO_INPAINT_PIPELINES_MAPPING = GAUDI_AUTO_INPAINT_PIPELINES_MAPPING
+ auto_pipeline._get_task_class = _gaudi_get_task_class
+ model = super().from_pipe(pipeline, **kwargs)
+ auto_pipeline.SUPPORTED_TASKS_MAPPINGS = orig_supported_mappings
+ auto_pipeline.AUTO_INPAINT_PIPELINES_MAPPING = orig_inpaint_mappings
+ auto_pipeline._get_task_class = orig_func
+ return model
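
In short, `from_pretrained`/`from_pipe` temporarily swap the `diffusers.pipelines.auto_pipeline` task mappings and `_get_task_class` for the Gaudi variants, so the class name read from the model config is resolved with the `Gaudi` prefix, then restore the originals. A small sketch of the expected behavior (not part of the patch; the model ID and Gaudi config are illustrative):

```py
# Sketch of the expected auto-pipeline resolution with the Gaudi mappings above.
from optimum.habana.diffusers import AutoPipelineForInpainting, AutoPipelineForText2Image

# _gaudi_get_task_class prefixes the detected class name, e.g.
# "StableDiffusionXLPipeline" -> "GaudiStableDiffusionXLPipeline",
# then looks up the task class in GAUDI_AUTO_TEXT2IMAGE_PIPELINES_MAPPING.
pipe_txt2img = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    use_habana=True,
    use_hpu_graphs=True,
    gaudi_config="Habana/stable-diffusion",
)
print(type(pipe_txt2img).__name__)  # expected: GaudiStableDiffusionXLPipeline

# from_pipe reuses the already-loaded components for the inpainting task.
pipe_inpaint = AutoPipelineForInpainting.from_pipe(pipe_txt2img)
print(type(pipe_inpaint).__name__)  # expected: GaudiStableDiffusionXLInpaintPipeline
```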
diff --git a/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
index 1096dec9ab..118ec641ff 100644
--- a/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+++ b/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -244,6 +244,7 @@ def _split_inputs_into_batches(cls, batch_size, latents, prompt_embeds, negative
zip(negative_prompt_embeds_batches, prompt_embeds_batches[:])
):
prompt_embeds_batches[i] = torch.cat([negative_prompt_embeds_batch, prompt_embeds_batch])
+
prompt_embeds_batches = torch.stack(prompt_embeds_batches)
return latents_batches, prompt_embeds_batches, num_dummy_samples
@@ -431,10 +432,9 @@ def __call__(
lora_scale=lora_scale,
clip_skip=self.clip_skip,
)
-
if ip_adapter_image is not None:
image_embeds = self.prepare_ip_adapter_image_embeds(
- ip_adapter_image, device, batch_size * num_images_per_prompt
+ ip_adapter_image, device, num_prompts * num_images_per_prompt
)
# 4. Prepare timesteps
diff --git a/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py b/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py
new file mode 100644
index 0000000000..1c5964b3f7
--- /dev/null
+++ b/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py
@@ -0,0 +1,506 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import time
+from math import ceil
+from typing import Callable, List, Optional, Union
+
+import numpy as np
+import PIL
+import torch
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.pipelines.stable_diffusion import StableDiffusionImageVariationPipeline, StableDiffusionSafetyChecker
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
+
+from optimum.utils import logging
+
+from ....transformers.gaudi_configuration import GaudiConfig
+from ....utils import HabanaProfile, speed_metrics, warmup_inference_steps_time_adjustment
+from ..pipeline_utils import GaudiDiffusionPipeline
+from .pipeline_stable_diffusion import GaudiStableDiffusionPipelineOutput
+
+
+logger = logging.get_logger(__name__)
+
+
+class GaudiStableDiffusionImageVariationPipeline(GaudiDiffusionPipeline, StableDiffusionImageVariationPipeline):
+ """
+ Adapted from: https://github.com/huggingface/diffusers/blob/v0.26.3/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py
+    - Generation is performed in batches
+    - Two `mark_step()` calls were added to support lazy mode
+ - Added support for HPU graphs
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
+ image_encoder ([`~transformers.CLIPVisionModelWithProjection`]):
+ Frozen CLIP image-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
+ tokenizer ([`~transformers.CLIPTokenizer`]):
+ A `CLIPTokenizer` to tokenize text.
+ unet ([`UNet2DConditionModel`]):
+ A `UNet2DConditionModel` to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
+ about a model's potential harms.
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
+ A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
+ use_habana (bool, defaults to `False`):
+ Whether to use Gaudi (`True`) or CPU (`False`).
+ use_hpu_graphs (bool, defaults to `False`):
+ Whether to use HPU graphs or not.
+ gaudi_config (Union[str, [`GaudiConfig`]], defaults to `None`):
+ Gaudi configuration to use. Can be a string to download it from the Hub.
+ Or a previously initialized config can be passed.
+ bf16_full_eval (bool, defaults to `False`):
+ Whether to use full bfloat16 evaluation instead of 32-bit.
+ This will be faster and save memory compared to fp32/mixed precision but can harm generated images.
+ """
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ image_encoder: CLIPVisionModelWithProjection,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ requires_safety_checker: bool = True,
+ use_habana: bool = False,
+ use_hpu_graphs: bool = False,
+ gaudi_config: Union[str, GaudiConfig] = None,
+ bf16_full_eval: bool = False,
+ ):
+ GaudiDiffusionPipeline.__init__(
+ self,
+ use_habana,
+ use_hpu_graphs,
+ gaudi_config,
+ bf16_full_eval,
+ )
+
+ # Workaround for Synapse 1.11 for full bf16
+ if bf16_full_eval:
+ unet.conv_in.float()
+
+ StableDiffusionImageVariationPipeline.__init__(
+ self,
+ vae,
+ image_encoder,
+ unet,
+ scheduler,
+ safety_checker,
+ feature_extractor,
+ requires_safety_checker,
+ )
+
+ self.to(self._device)
+
+ def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free_guidance):
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(images=image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ image_embeddings = self.image_encoder_hpu(image)
+ image_embeddings = image_embeddings.unsqueeze(1)
+
+ # duplicate image embeddings for each generation per prompt, using mps friendly method
+ bs_embed, seq_len, _ = image_embeddings.shape
+ image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1)
+ image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ if do_classifier_free_guidance:
+ negative_prompt_embeds = torch.zeros_like(image_embeddings)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings])
+
+ return image_embeddings
+
+ def prepare_latents(self, num_images, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (num_images, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if isinstance(generator, list) and len(generator) != num_images:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective number"
+ f" of images of {num_images}. Make sure the number of images matches the length of the generators."
+ )
+
+ if latents is None:
+ # torch.randn is broken on HPU so running it on CPU
+ rand_device = "cpu" if device.type == "hpu" else device
+ if isinstance(generator, list):
+ shape = (1,) + shape[1:]
+ latents = [
+ torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
+ for i in range(num_images)
+ ]
+ latents = torch.cat(latents, dim=0).to(device)
+ else:
+ latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
+ else:
+ if latents.shape != shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ @classmethod
+ def _split_inputs_into_batches(cls, batch_size, latents, image_embeds, do_classifier_free_guidance):
+ # Use torch.split to generate num_batches batches of size batch_size
+ latents_batches = list(torch.split(latents, batch_size))
+ if do_classifier_free_guidance:
+ negative_prompt_embeds = torch.chunk(image_embeds, 2)[0]
+ image_embeds = torch.chunk(image_embeds, 2)[1]
+ else:
+ negative_prompt_embeds = None
+
+ image_embeds_batches = list(torch.split(image_embeds, batch_size))
+ if negative_prompt_embeds is not None:
+ negative_prompt_embeds_batches = list(torch.split(negative_prompt_embeds, batch_size))
+
+ # If the last batch has less samples than batch_size, pad it with dummy samples
+ num_dummy_samples = 0
+ if latents_batches[-1].shape[0] < batch_size:
+ num_dummy_samples = batch_size - latents_batches[-1].shape[0]
+ # Pad latents_batches
+ sequence_to_stack = (latents_batches[-1],) + tuple(
+ torch.zeros_like(latents_batches[-1][0][None, :]) for _ in range(num_dummy_samples)
+ )
+ latents_batches[-1] = torch.vstack(sequence_to_stack)
+ # Pad image_embeds_batches
+ sequence_to_stack = (image_embeds_batches[-1],) + tuple(
+ torch.zeros_like(image_embeds_batches[-1][0][None, :]) for _ in range(num_dummy_samples)
+ )
+ image_embeds_batches[-1] = torch.vstack(sequence_to_stack)
+
+ if negative_prompt_embeds is not None:
+ sequence_to_stack = (negative_prompt_embeds_batches[-1],) + tuple(
+ torch.zeros_like(negative_prompt_embeds_batches[-1][0][None, :]) for _ in range(num_dummy_samples)
+ )
+ negative_prompt_embeds_batches[-1] = torch.vstack(sequence_to_stack)
+
+ # Stack batches in the same tensor
+ latents_batches = torch.stack(latents_batches)
+ if negative_prompt_embeds is not None:
+ for i in range(len(negative_prompt_embeds_batches)):
+ image_embeds_batches[i] = torch.cat([negative_prompt_embeds_batches[i], image_embeds_batches[i]])
+ image_embeds_batches = torch.stack(image_embeds_batches)
+ return latents_batches, image_embeds_batches, num_dummy_samples
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor],
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ num_images_per_prompt: Optional[int] = 1,
+ batch_size: int = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ profiling_warmup_steps: Optional[int] = 0,
+ profiling_steps: Optional[int] = 0,
+ **kwargs,
+ ):
+ """
+ Adapted from: https://github.com/huggingface/diffusers/blob/v0.26.3/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py
+        - Two `mark_step()` calls were added to support lazy mode
+        - Added support for HPU graphs
+        - Added the `batch_size` argument
+ """
+ with torch.autocast(device_type="hpu", dtype=torch.bfloat16, enabled=self.gaudi_config.use_torch_autocast):
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(image, height, width, callback_steps)
+
+ # 2. Define call parameters
+ if isinstance(image, PIL.Image.Image):
+ num_images = 1
+ elif isinstance(image, list):
+ num_images = len(image)
+ else:
+ num_images = image.shape[0]
+
+ num_batches = ceil((num_images_per_prompt * num_images) / batch_size)
+ logger.info(
+ f"{num_images} image(s) received, {num_images_per_prompt} generation(s) per prompt,"
+ f" {batch_size} sample(s) per batch, {num_batches} total batch(es)."
+ )
+ if num_batches < 3:
+ logger.warning("The first two iterations are slower so it is recommended to feed more batches.")
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input image
+ image_embeddings = self._encode_image(image, device, num_images_per_prompt, do_classifier_free_guidance)
+ if not self.use_hpu_graphs:
+ self.htcore.mark_step()
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device="cpu")
+ timesteps = self.scheduler.timesteps.to(device)
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ num_images * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ image_embeddings.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Split into batches (HPU-specific step)
+ latents_batches, image_embeddings_batches, num_dummy_samples = self._split_inputs_into_batches(
+ batch_size,
+ latents,
+ image_embeddings,
+ do_classifier_free_guidance,
+ )
+ outputs = {
+ "images": [],
+ "has_nsfw_concept": [],
+ }
+ hb_profiler = HabanaProfile(
+ warmup=profiling_warmup_steps,
+ active=profiling_steps,
+ record_shapes=False,
+ )
+ hb_profiler.start()
+
+ # 8. Denoising loop
+ t0 = time.time()
+ t1 = t0
+ throughput_warmup_steps = kwargs.get("throughput_warmup_steps", 3)
+ use_warmup_inference_steps = (
+ num_batches < throughput_warmup_steps and num_inference_steps > throughput_warmup_steps
+ )
+ for j in self.progress_bar(range(num_batches)):
+ # The throughput is calculated from the 3rd iteration
+ # because compilation occurs in the first two iterations
+ if j == throughput_warmup_steps:
+ t1 = time.time()
+ if use_warmup_inference_steps:
+ t0_inf = time.time()
+ latents_batch = latents_batches[0]
+ latents_batches = torch.roll(latents_batches, shifts=-1, dims=0)
+ image_embeddings_batch = image_embeddings_batches[0]
+ image_embeddings_batches = torch.roll(image_embeddings_batches, shifts=-1, dims=0)
+ for i in range(len(timesteps)):
+ if use_warmup_inference_steps and i == throughput_warmup_steps:
+ t1_inf = time.time()
+ t1 += t1_inf - t0_inf
+ t = timesteps[0]
+ timesteps = torch.roll(timesteps, shifts=-1, dims=0)
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = (
+ torch.cat([latents_batch] * 2) if do_classifier_free_guidance else latents_batch
+ )
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet_hpu(latent_model_input, t, encoder_hidden_states=image_embeddings_batch)
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_batch = self.scheduler.step(noise_pred, t, latents_batch, **extra_step_kwargs).prev_sample
+ if not self.use_hpu_graphs:
+ self.htcore.mark_step()
+
+ # call the callback, if provided
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents_batch)
+ hb_profiler.step()
+ if use_warmup_inference_steps:
+ t1 = warmup_inference_steps_time_adjustment(
+ t1, t1_inf, num_inference_steps, throughput_warmup_steps
+ )
+ if not output_type == "latent":
+ image = self.vae.decode(latents_batch / self.vae.config.scaling_factor, return_dict=False)[0]
+ else:
+ image = latents_batch
+ outputs["images"].append(image)
+ if not self.use_hpu_graphs:
+ self.htcore.mark_step()
+
+ hb_profiler.stop()
+ speed_metrics_prefix = "generation"
+ speed_measures = speed_metrics(
+ split=speed_metrics_prefix,
+ start_time=t0,
+ num_samples=num_batches * batch_size
+ if t1 == t0 or use_warmup_inference_steps
+ else (num_batches - throughput_warmup_steps) * batch_size,
+ num_steps=num_batches,
+ start_time_after_warmup=t1,
+ )
+ logger.info(f"Speed metrics: {speed_measures}")
+ # Remove dummy generations if needed
+ if num_dummy_samples > 0:
+ outputs["images"][-1] = outputs["images"][-1][:-num_dummy_samples]
+
+ # Process generated images
+ for i, image in enumerate(outputs["images"][:]):
+ if i == 0:
+ outputs["images"].clear()
+
+ if output_type == "latent":
+ has_nsfw_concept = None
+ else:
+ image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype)
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ if output_type == "pil" and isinstance(image, list):
+ outputs["images"] += image
+ elif output_type in ["np", "numpy"] and isinstance(image, np.ndarray):
+ if len(outputs["images"]) == 0:
+ outputs["images"] = image
+ else:
+ outputs["images"] = np.concatenate((outputs["images"], image), axis=0)
+ else:
+ if len(outputs["images"]) == 0:
+ outputs["images"] = image
+ else:
+ outputs["images"] = torch.cat((outputs["images"], image), 0)
+
+ if has_nsfw_concept is not None:
+ outputs["has_nsfw_concept"] += has_nsfw_concept
+ else:
+ outputs["has_nsfw_concept"] = None
+
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (outputs["images"], outputs["has_nsfw_concept"])
+
+ return GaudiStableDiffusionPipelineOutput(
+ images=outputs["images"],
+ nsfw_content_detected=outputs["has_nsfw_concept"],
+ throughput=speed_measures[f"{speed_metrics_prefix}_samples_per_second"],
+ )
+
+ @torch.no_grad()
+ def unet_hpu(
+ self,
+ latent_model_input,
+ timestep,
+ encoder_hidden_states,
+ ):
+ if self.use_hpu_graphs:
+ return self.capture_replay(latent_model_input, timestep, encoder_hidden_states)
+ else:
+ return self.unet(
+ latent_model_input,
+ timestep,
+ encoder_hidden_states=encoder_hidden_states,
+ return_dict=False,
+ )[0]
+
+ @torch.no_grad()
+ def capture_replay(self, latent_model_input, timestep, encoder_hidden_states):
+ inputs = [latent_model_input, timestep, encoder_hidden_states, False]
+ h = self.ht.hpu.graphs.input_hash(inputs)
+ cached = self.cache.get(h)
+ if cached is None:
+ # Capture the graph and cache it
+ with self.ht.hpu.stream(self.hpu_stream):
+ graph = self.ht.hpu.HPUGraph()
+ graph.capture_begin()
+ outputs = self.unet(inputs[0], inputs[1], encoder_hidden_states=inputs[2], return_dict=inputs[3])[0]
+ graph.capture_end()
+ graph_inputs = inputs
+ graph_outputs = outputs
+ self.cache[h] = self.ht.hpu.graphs.CachedParams(graph_inputs, graph_outputs, graph)
+ return outputs
+
+ # Replay the cached graph with updated inputs
+ self.ht.hpu.graphs.copy_to(cached.graph_inputs, inputs)
+ cached.graph.replay()
+ self.ht.core.hpu.default_stream().synchronize()
+
+ return cached.graph_outputs
+
+ @torch.no_grad()
+ def image_encoder_hpu(
+ self,
+ image,
+ ):
+ if self.use_hpu_graphs:
+ return self.image_capture_replay(image)
+ else:
+ return self.image_encoder(image).image_embeds
+
+ @torch.no_grad()
+ def image_capture_replay(self, image):
+ inputs = [image]
+ h = self.ht.hpu.graphs.input_hash(inputs)
+ cached = self.cache.get(h)
+ if cached is None:
+ # Capture the graph and cache it
+ with self.ht.hpu.stream(self.hpu_stream):
+ graph = self.ht.hpu.HPUGraph()
+ graph.capture_begin()
+ outputs = self.image_encoder(inputs[0]).image_embeds
+ graph.capture_end()
+ graph_inputs = inputs
+ graph_outputs = outputs
+ self.cache[h] = self.ht.hpu.graphs.CachedParams(graph_inputs, graph_outputs, graph)
+ return outputs
+
+ # Replay the cached graph with updated inputs
+ self.ht.hpu.graphs.copy_to(cached.graph_inputs, inputs)
+ cached.graph.replay()
+ self.ht.core.hpu.default_stream().synchronize()
+
+ return cached.graph_outputs
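
The batching step marked "HPU-specific" above (and reused by the inpainting pipeline below) keeps every denoising iteration at a static shape: inputs are split into `batch_size` chunks, the last chunk is zero-padded with dummy samples, and the loop rotates the stacked batches with `torch.roll`. A standalone sketch of that shape logic, not tied to the pipeline classes:

```py
# Standalone sketch of the split-and-pad scheme used by _split_inputs_into_batches.
import torch


def split_and_pad(latents: torch.Tensor, batch_size: int):
    """Split along dim 0 into chunks of `batch_size`, zero-padding the last chunk."""
    batches = list(torch.split(latents, batch_size))
    num_dummy_samples = 0
    if batches[-1].shape[0] < batch_size:
        num_dummy_samples = batch_size - batches[-1].shape[0]
        padding = tuple(torch.zeros_like(batches[-1][0][None, :]) for _ in range(num_dummy_samples))
        batches[-1] = torch.vstack((batches[-1],) + padding)
    return torch.stack(batches), num_dummy_samples


# 5 latents with batch_size=2 -> 3 batches, the last padded with 1 dummy sample.
latents = torch.randn(5, 4, 64, 64)
batches, num_dummy = split_and_pad(latents, batch_size=2)
assert batches.shape == (3, 2, 4, 64, 64) and num_dummy == 1

# The denoising loop consumes batches[0] and then rotates, so every iteration
# (and every cached HPU graph) sees the same static batch shape.
current = batches[0]
batches = torch.roll(batches, shifts=-1, dims=0)
# After decoding, dummy outputs are dropped: outputs["images"][-1] = outputs["images"][-1][:-num_dummy]
```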
diff --git a/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
new file mode 100644
index 0000000000..6b4331c763
--- /dev/null
+++ b/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
@@ -0,0 +1,819 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import time
+from math import ceil
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy
+import torch
+from diffusers.image_processor import PipelineImageInput
+from diffusers.models import AsymmetricAutoencoderKL, AutoencoderKL, UNet2DConditionModel
+from diffusers.pipelines.stable_diffusion import StableDiffusionInpaintPipeline, StableDiffusionSafetyChecker
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import retrieve_timesteps
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import deprecate, logging
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
+
+from ....transformers.gaudi_configuration import GaudiConfig
+from ....utils import speed_metrics, warmup_inference_steps_time_adjustment
+from ..pipeline_utils import GaudiDiffusionPipeline
+from .pipeline_stable_diffusion import GaudiStableDiffusionPipelineOutput
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class GaudiStableDiffusionInpaintPipeline(GaudiDiffusionPipeline, StableDiffusionInpaintPipeline):
+ r"""
+ Adapted from: https://github.com/huggingface/diffusers/blob/v0.26.3/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py#L222
+    - Two `mark_step()` calls were added to support lazy mode
+ - Added support for HPU graphs
+
+
+ Pipeline for text-guided image inpainting using Stable Diffusion.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ The pipeline also inherits the following loading methods:
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+ - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+ - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
+
+ Args:
+ vae ([`AutoencoderKL`, `AsymmetricAutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
+ tokenizer ([`~transformers.CLIPTokenizer`]):
+ A `CLIPTokenizer` to tokenize text.
+ unet ([`UNet2DConditionModel`]):
+ A `UNet2DConditionModel` to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
+ about a model's potential harms.
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
+ A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
+ use_habana (bool, defaults to `False`):
+ Whether to use Gaudi (`True`) or CPU (`False`).
+ use_hpu_graphs (bool, defaults to `False`):
+ Whether to use HPU graphs or not.
+ gaudi_config (Union[str, [`GaudiConfig`]], defaults to `None`):
+ Gaudi configuration to use. Can be a string to download it from the Hub.
+ Or a previously initialized config can be passed.
+ bf16_full_eval (bool, defaults to `False`):
+ Whether to use full bfloat16 evaluation instead of 32-bit.
+ This will be faster and save memory compared to fp32/mixed precision but can harm generated images.
+ """
+
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "mask", "masked_image_latents"]
+
+ def __init__(
+ self,
+ vae: Union[AutoencoderKL, AsymmetricAutoencoderKL],
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ image_encoder: CLIPVisionModelWithProjection = None,
+ requires_safety_checker: bool = True,
+ use_habana: bool = True,
+ use_hpu_graphs: bool = True,
+ gaudi_config: Union[str, GaudiConfig] = None,
+ bf16_full_eval: bool = False,
+ ):
+ GaudiDiffusionPipeline.__init__(
+ self,
+ use_habana,
+ use_hpu_graphs,
+ gaudi_config,
+ bf16_full_eval,
+ )
+
+ StableDiffusionInpaintPipeline.__init__(
+ self,
+ vae,
+ text_encoder,
+ tokenizer,
+ unet,
+ scheduler,
+ safety_checker,
+ feature_extractor,
+ image_encoder,
+ requires_safety_checker,
+ )
+
+ self.to(self._device)
+
+ @classmethod
+ def _split_inputs_into_batches(
+ cls, batch_size, latents, prompt_embeds, negative_prompt_embeds, mask, masked_image_latents
+ ):
+ # Use torch.split to generate num_batches batches of size batch_size
+ latents_batches = list(torch.split(latents, batch_size))
+ prompt_embeds_batches = list(torch.split(prompt_embeds, batch_size))
+ if negative_prompt_embeds is not None:
+ negative_prompt_embeds_batches = list(torch.split(negative_prompt_embeds, batch_size))
+ mask_batches = list(torch.split(mask, batch_size))
+ masked_image_latents_batches = list(torch.split(masked_image_latents, batch_size))
+
+ # If the last batch has less samples than batch_size, pad it with dummy samples
+ num_dummy_samples = 0
+ if latents_batches[-1].shape[0] < batch_size:
+ num_dummy_samples = batch_size - latents_batches[-1].shape[0]
+ # Pad latents_batches
+ sequence_to_stack = (latents_batches[-1],) + tuple(
+ torch.zeros_like(latents_batches[-1][0][None, :]) for _ in range(num_dummy_samples)
+ )
+ latents_batches[-1] = torch.vstack(sequence_to_stack)
+ # Pad prompt_embeds_batches
+ sequence_to_stack = (prompt_embeds_batches[-1],) + tuple(
+ torch.zeros_like(prompt_embeds_batches[-1][0][None, :]) for _ in range(num_dummy_samples)
+ )
+ prompt_embeds_batches[-1] = torch.vstack(sequence_to_stack)
+ # Pad negative_prompt_embeds_batches if necessary
+ if negative_prompt_embeds is not None:
+ sequence_to_stack = (negative_prompt_embeds_batches[-1],) + tuple(
+ torch.zeros_like(negative_prompt_embeds_batches[-1][0][None, :]) for _ in range(num_dummy_samples)
+ )
+ negative_prompt_embeds_batches[-1] = torch.vstack(sequence_to_stack)
+
+ if mask_batches[-1].shape[0] < batch_size:
+ num_dummy_samples = batch_size - mask_batches[-1].shape[0]
+ # Pad mask_batches
+ sequence_to_stack = (mask_batches[-1],) + tuple(
+ torch.zeros_like(mask_batches[-1][0][None, :]) for _ in range(num_dummy_samples)
+ )
+ mask_batches[-1] = torch.vstack(sequence_to_stack)
+
+ if masked_image_latents_batches[-1].shape[0] < batch_size:
+ num_dummy_samples = batch_size - masked_image_latents_batches[-1].shape[0]
+ # Pad masked_image_latents_batches
+ sequence_to_stack = (masked_image_latents_batches[-1],) + tuple(
+ torch.zeros_like(masked_image_latents_batches[-1][0][None, :]) for _ in range(num_dummy_samples)
+ )
+ masked_image_latents_batches[-1] = torch.vstack(sequence_to_stack)
+
+ # Stack batches in the same tensor
+ latents_batches = torch.stack(latents_batches)
+ if negative_prompt_embeds is not None:
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ for i, (negative_prompt_embeds_batch, prompt_embeds_batch) in enumerate(
+ zip(negative_prompt_embeds_batches, prompt_embeds_batches[:])
+ ):
+ prompt_embeds_batches[i] = torch.cat([negative_prompt_embeds_batch, prompt_embeds_batch])
+
+ prompt_embeds_batches = torch.stack(prompt_embeds_batches)
+ mask_batches = torch.stack(mask_batches)
+ masked_image_latents_batches = torch.stack(masked_image_latents_batches)
+
+ return latents_batches, prompt_embeds_batches, num_dummy_samples, mask_batches, masked_image_latents_batches
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ image: PipelineImageInput = None,
+ mask_image: PipelineImageInput = None,
+ masked_image_latents: torch.FloatTensor = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ padding_mask_crop: Optional[int] = None,
+ strength: float = 1.0,
+ num_inference_steps: int = 50,
+ timesteps: List[int] = None,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ batch_size: int = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ clip_skip: int = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ **kwargs,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
+ image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image`, numpy array or tensor representing an image batch to be inpainted (which parts of the image to
+ be masked out with `mask_image` and repainted according to `prompt`). For both numpy array and pytorch
+ tensor, the expected value range is between `[0, 1]` If it's a tensor or a list or tensors, the
+ expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a list of arrays, the
+ expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image latents as `image`, but
+ if passing latents directly it is not encoded again.
+ mask_image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask
+ are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a
+ single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one
+                color channel (L) instead of 3, so the expected shape for a pytorch tensor would be `(B, 1, H, W)`, `(B,
+                H, W)`, `(1, H, W)`, or `(H, W)`. For a numpy array, it would be `(B, H, W, 1)`, `(B, H, W)`, `(H, W,
+                1)`, or `(H, W)`.
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The width in pixels of the generated image.
+ padding_mask_crop (`int`, *optional*, defaults to `None`):
+                The size of the margin in the crop to be applied to the image and mask. If `None`, no crop is applied to image and mask_image. If
+                `padding_mask_crop` is not `None`, it will first find a rectangular region with the same aspect ratio as the image that
+                contains all masked areas, and then expand that area based on `padding_mask_crop`. The image and mask_image will then be cropped based on
+                the expanded area before resizing to the original image size for inpainting. This is useful when the masked area is small while the image is large
+                and contains information irrelevant for inpainting, such as background.
+ strength (`float`, *optional*, defaults to 1.0):
+ Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
+ starting point and more noise is added the higher the `strength`. The number of denoising steps depends
+ on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
+ process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
+ essentially ignores `image`.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference. This parameter is modulated by `strength`.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ A higher guidance scale value encourages the model to generate images closely linked to the text
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ batch_size (`int`, *optional*, defaults to 1):
+ The number of images in a batch.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ Examples:
+
+ ```py
+ >>> import PIL
+ >>> import requests
+ >>> import torch
+ >>> from io import BytesIO
+
+ >>> from diffusers import StableDiffusionInpaintPipeline
+
+
+ >>> def download_image(url):
+ ... response = requests.get(url)
+ ... return PIL.Image.open(BytesIO(response.content)).convert("RGB")
+
+
+ >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+ >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+
+ >>> init_image = download_image(img_url).resize((512, 512))
+ >>> mask_image = download_image(mask_url).resize((512, 512))
+
+ >>> pipe = StableDiffusionInpaintPipeline.from_pretrained(
+ ... "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16
+ ... )
+ >>> pipe = pipe.to("cuda")
+
+ >>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
+ >>> image = pipe(prompt=prompt, image=init_image, mask_image=mask_image).images[0]
+ ```
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
+ "not-safe-for-work" (nsfw) content.
+ """
+
+ callback = kwargs.pop("callback", None)
+ callback_steps = kwargs.pop("callback_steps", None)
+ if callback is not None:
+ deprecate(
+ "callback",
+ "1.0.0",
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
+ )
+ if callback_steps is not None:
+ deprecate(
+ "callback_steps",
+ "1.0.0",
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
+ )
+
+ with torch.autocast(device_type="hpu", dtype=torch.bfloat16, enabled=self.gaudi_config.use_torch_autocast):
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ # 1. Check inputs
+ self.check_inputs(
+ prompt,
+ image,
+ mask_image,
+ height,
+ width,
+ strength,
+ callback_steps,
+ output_type,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ callback_on_step_end_tensor_inputs,
+ padding_mask_crop,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._clip_skip = clip_skip
+ self._cross_attention_kwargs = cross_attention_kwargs
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ num_prompts = 1
+ elif prompt is not None and isinstance(prompt, list):
+ num_prompts = len(prompt)
+ else:
+ num_prompts = prompt_embeds.shape[0]
+ num_batches = ceil((num_images_per_prompt * num_prompts) / batch_size)
+ logger.info(
+ f"{num_prompts} prompt(s) received, {num_images_per_prompt} generation(s) per prompt,"
+ f" {batch_size} sample(s) per batch, {num_batches} total batch(es)."
+ )
+ if num_batches < 3:
+ logger.warning("The first two iterations are slower so it is recommended to feed more batches.")
+
+ device = self._execution_device
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ )
+
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ self.do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ clip_skip=self.clip_skip,
+ )
+
+ if ip_adapter_image is not None:
+ image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image,
+ device,
+ num_prompts * num_images_per_prompt,
+ )
+
+ # 4. set timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ timesteps, num_inference_steps = self.get_timesteps(
+ num_inference_steps=num_inference_steps, strength=strength, device=device
+ )
+
+ # check that number of inference steps is not < 1 - as this doesn't make sense
+ if num_inference_steps < 1:
+ raise ValueError(
+ f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
+ f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
+ )
+ # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
+ latent_timestep = timesteps[:1].repeat(num_prompts * num_images_per_prompt)
+ # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
+ is_strength_max = strength == 1.0
+
+ # 5. Preprocess mask and image
+
+ if padding_mask_crop is not None:
+ crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
+ resize_mode = "fill"
+ else:
+ crops_coords = None
+ resize_mode = "default"
+
+ original_image = image
+ init_image = self.image_processor.preprocess(
+ image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
+ )
+ init_image = init_image.to(dtype=torch.float32)
+
+ # 6. Prepare latent variables
+ num_channels_latents = self.vae.config.latent_channels
+ num_channels_unet = self.unet.config.in_channels
+ return_image_latents = num_channels_unet == 4
+
+ latents_outputs = self.prepare_latents(
+ num_prompts * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ image=init_image,
+ timestep=latent_timestep,
+ is_strength_max=is_strength_max,
+ return_noise=True,
+ return_image_latents=return_image_latents,
+ )
+
+ if return_image_latents:
+ latents, noise, image_latents = latents_outputs
+ else:
+ latents, noise = latents_outputs
+
+ # 7. Prepare mask latent variables
+ mask_condition = self.mask_processor.preprocess(
+ mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
+ )
+
+ if masked_image_latents is None:
+ masked_image = init_image * (mask_condition < 0.5)
+ else:
+ masked_image = masked_image_latents
+
+ mask, masked_image_latents = self.prepare_mask_latents(
+ mask_condition,
+ masked_image,
+ num_prompts * num_images_per_prompt,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ self.do_classifier_free_guidance,
+ )
+
+ # 8. Check that sizes of mask, masked image and latents match
+ if num_channels_unet == 9:
+ # default case for runwayml/stable-diffusion-inpainting
+ num_channels_mask = mask.shape[1]
+ num_channels_masked_image = masked_image_latents.shape[1]
+ if (
+ num_channels_latents + num_channels_mask + num_channels_masked_image
+ != self.unet.config.in_channels
+ ):
+ raise ValueError(
+ f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
+ f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
+ f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
+ f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ " `pipeline.unet` or your `mask_image` or `image` input."
+ )
+ elif num_channels_unet != 4:
+ raise ValueError(
+ f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}."
+ )
+
+ # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 9.1 Add image embeds for IP-Adapter
+ added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
+
+ # 9.2 Optionally get Guidance Scale Embedding
+ timestep_cond = None
+ if self.unet.config.time_cond_proj_dim is not None:
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size)
+ timestep_cond = self.get_guidance_scale_embedding(
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+ ).to(device=device, dtype=latents.dtype)
+
+ # 10. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ throughput_warmup_steps = kwargs.get("throughput_warmup_steps", 3)
+ use_warmup_inference_steps = (
+ num_batches < throughput_warmup_steps and num_inference_steps > throughput_warmup_steps
+ )
+
+ self._num_timesteps = len(timesteps)
+
+ # 11. Split into batches (HPU-specific step)
+ latents_batches, prompt_embeds_batches, num_dummy_samples, mask_batches, masked_image_latents_batches = (
+ self._split_inputs_into_batches(
+ batch_size,
+ latents,
+ prompt_embeds,
+ negative_prompt_embeds,
+ mask,
+ masked_image_latents,
+ )
+ )
+
+ outputs = {
+ "images": [],
+ "has_nsfw_concept": [],
+ }
+ t0 = time.time()
+ t1 = t0
+
+ for j in self.progress_bar(range(num_batches)):
+ # The throughput is calculated from the 3rd iteration
+ # because compilation occurs in the first two iterations
+ if j == throughput_warmup_steps:
+ t1 = time.time()
+ if use_warmup_inference_steps:
+ t0_inf = time.time()
+
+ latents_batch = latents_batches[0]
+ latents_batches = torch.roll(latents_batches, shifts=-1, dims=0)
+ prompt_embeds_batch = prompt_embeds_batches[0]
+ prompt_embeds_batches = torch.roll(prompt_embeds_batches, shifts=-1, dims=0)
+ mask_batch = mask_batches[0]
+ mask_batches = torch.roll(mask_batches, shifts=-1, dims=0)
+ masked_image_latents_batch = masked_image_latents_batches[0]
+ masked_image_latents_batches = torch.roll(masked_image_latents_batches, shifts=-1, dims=0)
+
+ for i in range(len(timesteps)):
+ if use_warmup_inference_steps and i == throughput_warmup_steps:
+ t1_inf = time.time()
+ t1 += t1_inf - t0_inf
+
+ if self.interrupt:
+ continue
+
+ timestep = timesteps[0]
+ timesteps = torch.roll(timesteps, shifts=-1, dims=0)
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = (
+ torch.cat([latents_batch] * 2) if self.do_classifier_free_guidance else latents_batch
+ )
+ mask_batch_input = torch.cat([mask_batch] * 2) if self.do_classifier_free_guidance else mask_batch
+ masked_image_latents_batch_input = (
+ torch.cat([masked_image_latents_batch] * 2)
+ if self.do_classifier_free_guidance
+ else masked_image_latents_batch
+ )
+
+ # concat latents, mask, masked_image_latents in the channel dimension
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, timestep)
+ if num_channels_unet == 9:
+ latent_model_input = torch.cat(
+ [latent_model_input, mask_batch_input, masked_image_latents_batch_input], dim=1
+ )
+ # predict the noise residual
+ noise_pred = self.unet_hpu(
+ latent_model_input,
+ timestep,
+ encoder_hidden_states=prompt_embeds_batch,
+ timestep_cond=timestep_cond,
+ cross_attention_kwargs=self.cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )
+ noise_pred.to(torch.float)
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_batch = self.scheduler.step(
+ noise_pred, timestep, latents_batch, **extra_step_kwargs, return_dict=False
+ )[0]
+ if num_channels_unet == 4:
+ init_latents_proper = image_latents
+ if self.do_classifier_free_guidance:
+ init_mask, _ = mask_batch.chunk(2)
+ else:
+ init_mask = mask_batch
+
+ if i < len(timesteps) - 1:
+ # after the roll above, index 0 holds the next timestep
+ noise_timestep = timesteps[0]
+ init_latents_proper = self.scheduler.add_noise(
+ init_latents_proper, noise, torch.tensor([noise_timestep])
+ )
+
+ latents_batch = (1 - init_mask) * init_latents_proper + init_mask * latents_batch
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ k_batch = k + "_batch"
+ callback_kwargs[k] = locals()[k_batch]
+ callback_outputs = callback_on_step_end(self, i, timestep, callback_kwargs)
+
+ latents_batch = callback_outputs.pop("latents", latents_batch)
+ prompt_embeds_batch = callback_outputs.pop("prompt_embeds", prompt_embeds_batch)
+
+ mask_batch = callback_outputs.pop("mask", mask_batch)
+ masked_image_latents_batch = callback_outputs.pop(
+ "masked_image_latents", masked_image_latents_batch
+ )
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, timestep, latents_batch)
+
+ if use_warmup_inference_steps:
+ t1 = warmup_inference_steps_time_adjustment(
+ t1, t1_inf, num_inference_steps, throughput_warmup_steps
+ )
+
+ if not output_type == "latent":
+ condition_kwargs = {}
+ if isinstance(self.vae, AsymmetricAutoencoderKL):
+ init_image = init_image.to(device=device, dtype=masked_image_latents.dtype)
+ init_image_condition = init_image.clone()
+ init_image = self._encode_vae_image(init_image, generator=generator)
+ mask_condition = mask_condition.to(device=device, dtype=masked_image_latents.dtype)
+ condition_kwargs = {"image": init_image_condition, "mask": mask_condition}
+ image = self.vae.decode(
+ latents_batch / self.vae.config.scaling_factor,
+ return_dict=False,
+ generator=generator,
+ **condition_kwargs,
+ )[0]
+ else:
+ image = latents_batch
+
+ outputs["images"].append(image)
+
+ if not self.use_hpu_graphs:
+ self.htcore.mark_step()
+
+ # Remove dummy generations if needed
+ if num_dummy_samples > 0:
+ outputs["images"][-1] = outputs["images"][-1][:-num_dummy_samples]
+
+ speed_metrics_prefix = "inpainting"
+ speed_measures = speed_metrics(
+ split=speed_metrics_prefix,
+ start_time=t0,
+ num_samples=num_batches * batch_size
+ if t1 == t0 or use_warmup_inference_steps
+ else (num_batches - throughput_warmup_steps) * batch_size,
+ num_steps=num_batches,
+ start_time_after_warmup=t1,
+ )
+ logger.info(f"Speed metrics: {speed_measures}")
+
+ # Process generated images
+ for i, image in enumerate(outputs["images"][:]):
+ if i == 0:
+ outputs["images"].clear()
+
+ if output_type == "latent":
+ has_nsfw_concept = None
+ else:
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+ if padding_mask_crop is not None:
+ image = [
+ self.image_processor.apply_overlay(mask_image, original_image, img, crops_coords) for img in image
+ ]
+
+ if output_type == "pil" and isinstance(image, list):
+ outputs["images"] += image
+ elif output_type in ["np", "numpy"] and isinstance(image, numpy.ndarray):
+ if len(outputs["images"]) == 0:
+ outputs["images"] = image
+ else:
+ outputs["images"] = numpy.concatenate((outputs["images"], image), axis=0)
+ else:
+ if len(outputs["images"]) == 0:
+ outputs["images"] = image
+ else:
+ outputs["images"] = torch.cat((outputs["images"], image), 0)
+
+ if has_nsfw_concept is not None:
+ outputs["has_nsfw_concept"] += has_nsfw_concept
+ else:
+ outputs["has_nsfw_concept"] = None
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+ if not return_dict:
+ return (outputs["images"], outputs["has_nsfw_concept"])
+
+ return GaudiStableDiffusionPipelineOutput(
+ images=outputs["images"],
+ nsfw_content_detected=outputs["has_nsfw_concept"],
+ throughput=speed_measures[f"{speed_metrics_prefix}_samples_per_second"],
+ )
+
+ @torch.no_grad()
+ def unet_hpu(
+ self,
+ latent_model_input,
+ timestep,
+ encoder_hidden_states,
+ timestep_cond,
+ cross_attention_kwargs,
+ added_cond_kwargs,
+ return_dict=False,
+ ):
+ if self.use_hpu_graphs:
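+ # Note: the HPU-graph path only forwards the three tensor inputs; `timestep_cond`,
+ # `cross_attention_kwargs` and `added_cond_kwargs` are not passed to `capture_replay`.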
+ return self.capture_replay(latent_model_input, timestep, encoder_hidden_states)
+ else:
+ return self.unet(
+ latent_model_input,
+ timestep,
+ encoder_hidden_states=encoder_hidden_states,
+ timestep_cond=timestep_cond,
+ cross_attention_kwargs=cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ @torch.no_grad()
+ def capture_replay(self, latent_model_input, timestep, encoder_hidden_states):
+ inputs = [latent_model_input, timestep, encoder_hidden_states, False]
+ h = self.ht.hpu.graphs.input_hash(inputs)
+ cached = self.cache.get(h)
+ if cached is None:
+ # Capture the graph and cache it
+ with self.ht.hpu.stream(self.hpu_stream):
+ graph = self.ht.hpu.HPUGraph()
+ graph.capture_begin()
+ outputs = self.unet(inputs[0], inputs[1], encoder_hidden_states=inputs[2], return_dict=inputs[3])[0]
+ graph.capture_end()
+ graph_inputs = inputs
+ graph_outputs = outputs
+ self.cache[h] = self.ht.hpu.graphs.CachedParams(graph_inputs, graph_outputs, graph)
+ return outputs
+
+ # Replay the cached graph with updated inputs
+ self.ht.hpu.graphs.copy_to(cached.graph_inputs, inputs)
+ cached.graph.replay()
+ self.ht.core.hpu.default_stream().synchronize()
+
+ return cached.graph_outputs
diff --git a/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py b/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
new file mode 100644
index 0000000000..f87c59ece4
--- /dev/null
+++ b/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
@@ -0,0 +1,592 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import time
+from math import ceil
+from typing import Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from diffusers.image_processor import PipelineImageInput
+from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
+from diffusers.pipelines.stable_diffusion import StableDiffusionInstructPix2PixPipeline, StableDiffusionSafetyChecker
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import deprecate
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
+
+from optimum.utils import logging
+
+from ....transformers.gaudi_configuration import GaudiConfig
+from ....utils import HabanaProfile, speed_metrics, warmup_inference_steps_time_adjustment
+from ..pipeline_utils import GaudiDiffusionPipeline
+from .pipeline_stable_diffusion import GaudiStableDiffusionPipelineOutput
+
+
+logger = logging.get_logger(__name__)
+
+
+class GaudiStableDiffusionInstructPix2PixPipeline(GaudiDiffusionPipeline, StableDiffusionInstructPix2PixPipeline):
+ """
+ Adapted from: https://github.com/huggingface/diffusers/blob/v0.26.3/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
+ - Generation is performed in batches
+ - Two `mark_step()` calls were added to support lazy mode
+ - Added support for HPU graphs
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
+ text_encoder ([`~transformers.CLIPTextModel`]):
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
+ image_encoder ([`~transformers.CLIPVisionModelWithProjection`]):
+ Frozen CLIP image-encoder, used to compute image embeddings for IP-Adapter inputs.
+ tokenizer ([`~transformers.CLIPTokenizer`]):
+ A `CLIPTokenizer` to tokenize text.
+ unet ([`UNet2DConditionModel`]):
+ A `UNet2DConditionModel` to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
+ about a model's potential harms.
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
+ A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
+ use_habana (bool, defaults to `False`):
+ Whether to use Gaudi (`True`) or CPU (`False`).
+ use_hpu_graphs (bool, defaults to `False`):
+ Whether to use HPU graphs or not.
+ gaudi_config (Union[str, [`GaudiConfig`]], defaults to `None`):
+ Gaudi configuration to use. Can be a string to download it from the Hub.
+ Or a previously initialized config can be passed.
+ bf16_full_eval (bool, defaults to `False`):
+ Whether to use full bfloat16 evaluation instead of 32-bit.
+ This will be faster and save memory compared to fp32/mixed precision but can harm generated images.
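+
+ Example:
+ A minimal usage sketch, assuming the class is exported from `optimum.habana.diffusers` like the
+ other Gaudi pipelines; the checkpoint name, Gaudi configuration and input image below are
+ illustrative assumptions only.
+
+ ```py
+ from PIL import Image
+ from optimum.habana.diffusers import GaudiStableDiffusionInstructPix2PixPipeline
+
+ pipeline = GaudiStableDiffusionInstructPix2PixPipeline.from_pretrained(
+ "timbrooks/instruct-pix2pix",  # assumed checkpoint
+ use_habana=True,
+ use_hpu_graphs=True,
+ gaudi_config="Habana/stable-diffusion",  # assumed Gaudi config from the Hub
+ )
+ init_image = Image.open("photo.png").convert("RGB")  # any RGB image to edit
+ edited = pipeline(prompt="make it look like winter", image=init_image, batch_size=1).images[0]
+ ```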
+ """
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ image_encoder: Optional[CLIPVisionModelWithProjection] = None,
+ requires_safety_checker: bool = True,
+ use_habana: bool = False,
+ use_hpu_graphs: bool = False,
+ gaudi_config: Union[str, GaudiConfig] = None,
+ bf16_full_eval: bool = False,
+ ):
+ GaudiDiffusionPipeline.__init__(
+ self,
+ use_habana,
+ use_hpu_graphs,
+ gaudi_config,
+ bf16_full_eval,
+ )
+
+ # Workaround for Synapse 1.11 for full bf16
+ if bf16_full_eval:
+ unet.conv_in.float()
+
+ StableDiffusionInstructPix2PixPipeline.__init__(
+ self,
+ vae,
+ text_encoder,
+ tokenizer,
+ unet,
+ scheduler,
+ safety_checker,
+ feature_extractor,
+ image_encoder,
+ requires_safety_checker,
+ )
+
+ self.to(self._device)
+
+ def prepare_latents(self, num_images, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (num_images, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if isinstance(generator, list) and len(generator) != num_images:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective number"
+ f" of images of {num_images}. Make sure the number of images matches the length of the generators."
+ )
+
+ if latents is None:
+ # torch.randn is broken on HPU so running it on CPU
+ rand_device = "cpu" if device.type == "hpu" else device
+ if isinstance(generator, list):
+ shape = (1,) + shape[1:]
+ latents = [
+ torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
+ for i in range(num_images)
+ ]
+ latents = torch.cat(latents, dim=0).to(device)
+ else:
+ latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
+ else:
+ if latents.shape != shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ @classmethod
+ def _split_inputs_into_batches(
+ cls, batch_size, latents, prompt_embeds, image_latents, do_classifier_free_guidance
+ ):
+ # Use torch.split to generate num_batches batches of size batch_size
+ latents_batches = list(torch.split(latents, batch_size))
+ if do_classifier_free_guidance:
+ negative_prompt_embeds = torch.chunk(prompt_embeds, 3)[2]
+ prompt_embeds = torch.chunk(prompt_embeds, 3)[0]
+ uncond_image_latents = torch.chunk(image_latents, 3)[2]
+ image_latents = torch.chunk(image_latents, 3)[0]
+ else:
+ negative_prompt_embeds = None
+ uncond_image_latents = None
+
+ prompt_embeds_batches = list(torch.split(prompt_embeds, batch_size))
+ image_latents_batches = list(torch.split(image_latents, batch_size))
+ if negative_prompt_embeds is not None:
+ negative_prompt_embeds_batches = list(torch.split(negative_prompt_embeds, batch_size))
+ if uncond_image_latents is not None:
+ uncond_image_latents_batches = list(torch.split(uncond_image_latents, batch_size))
+
+ # If the last batch has less samples than batch_size, pad it with dummy samples
+ num_dummy_samples = 0
+ if latents_batches[-1].shape[0] < batch_size:
+ num_dummy_samples = batch_size - latents_batches[-1].shape[0]
+ # Pad latents_batches
+ sequence_to_stack = (latents_batches[-1],) + tuple(
+ torch.zeros_like(latents_batches[-1][0][None, :]) for _ in range(num_dummy_samples)
+ )
+ latents_batches[-1] = torch.vstack(sequence_to_stack)
+
+ # Pad image latents_batches
+ sequence_to_stack = (image_latents_batches[-1],) + tuple(
+ torch.zeros_like(image_latents_batches[-1][0][None, :]) for _ in range(num_dummy_samples)
+ )
+ image_latents_batches[-1] = torch.vstack(sequence_to_stack)
+
+ # Pad prompt_embeds_batches
+ sequence_to_stack = (prompt_embeds_batches[-1],) + tuple(
+ torch.zeros_like(prompt_embeds_batches[-1][0][None, :]) for _ in range(num_dummy_samples)
+ )
+ prompt_embeds_batches[-1] = torch.vstack(sequence_to_stack)
+
+ if negative_prompt_embeds is not None:
+ sequence_to_stack = (negative_prompt_embeds_batches[-1],) + tuple(
+ torch.zeros_like(negative_prompt_embeds_batches[-1][0][None, :]) for _ in range(num_dummy_samples)
+ )
+ negative_prompt_embeds_batches[-1] = torch.vstack(sequence_to_stack)
+
+ if uncond_image_latents is not None:
+ sequence_to_stack = (uncond_image_latents_batches[-1],) + tuple(
+ torch.zeros_like(uncond_image_latents_batches[-1][0][None, :]) for _ in range(num_dummy_samples)
+ )
+ uncond_image_latents_batches[-1] = torch.vstack(sequence_to_stack)
+
+ # Stack batches in the same tensor
+ latents_batches = torch.stack(latents_batches)
+ if negative_prompt_embeds is not None:
+ for i in range(len(negative_prompt_embeds_batches)):
+ prompt_embeds_batches[i] = torch.cat(
+ [prompt_embeds_batches[i], negative_prompt_embeds_batches[i], negative_prompt_embeds_batches[i]]
+ )
+ prompt_embeds_batches = torch.stack(prompt_embeds_batches)
+ if uncond_image_latents is not None:
+ for i in range(len(uncond_image_latents_batches)):
+ image_latents_batches[i] = torch.cat(
+ [image_latents_batches[i], image_latents_batches[i], uncond_image_latents_batches[i]]
+ )
+ image_latents_batches = torch.stack(image_latents_batches)
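+ # With classifier-free guidance, each batch is laid out as [text, negative, negative] prompt
+ # embeddings paired with [image, image, unconditional] image latents, matching the 3-way
+ # `noise_pred.chunk(3)` performed in the denoising loop.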
+ return latents_batches, prompt_embeds_batches, image_latents_batches, num_dummy_samples
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ image: PipelineImageInput = None,
+ num_inference_steps: int = 100,
+ guidance_scale: float = 7.5,
+ image_guidance_scale: float = 1.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ batch_size: int = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ profiling_warmup_steps: Optional[int] = 0,
+ profiling_steps: Optional[int] = 0,
+ **kwargs,
+ ):
+ """
+ Adapted from: https://github.com/huggingface/diffusers/blob/v0.26.3/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
+ - Two `mark_step()` calls were added to support lazy mode
+ - Added support for HPU graphs
+ - Added the `batch_size` argument
+ """
+ callback = kwargs.pop("callback", None)
+ callback_steps = kwargs.pop("callback_steps", None)
+
+ if callback is not None:
+ deprecate(
+ "callback",
+ "1.0.0",
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
+ )
+ if callback_steps is not None:
+ deprecate(
+ "callback_steps",
+ "1.0.0",
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
+ )
+ with torch.autocast(device_type="hpu", dtype=torch.bfloat16, enabled=self.gaudi_config.use_torch_autocast):
+ # 0. Check inputs
+ self.check_inputs(
+ prompt,
+ callback_steps,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ callback_on_step_end_tensor_inputs,
+ )
+ self._guidance_scale = guidance_scale
+ self._image_guidance_scale = image_guidance_scale
+
+ device = self._execution_device
+
+ if ip_adapter_image is not None:
+ output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
+ image_embeds, negative_image_embeds = self.encode_image(
+ ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+ )
+ if self.do_classifier_free_guidance:
+ image_embeds = torch.cat([image_embeds, negative_image_embeds, negative_image_embeds])
+
+ if image is None:
+ raise ValueError("`image` input cannot be undefined.")
+
+ # 1. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ num_prompts = 1
+ elif prompt is not None and isinstance(prompt, list):
+ num_prompts = len(prompt)
+ else:
+ num_prompts = prompt_embeds.shape[0]
+
+ num_batches = ceil((num_images_per_prompt * num_prompts) / batch_size)
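+ # e.g. 2 prompts with num_images_per_prompt=4 and batch_size=3 -> ceil(8 / 3) = 3 batches,
+ # the last of which is padded with one dummy sample by `_split_inputs_into_batches`.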
+ logger.info(
+ f"{num_prompts} prompt(s) received, {num_images_per_prompt} generation(s) per prompt,"
+ f" {batch_size} sample(s) per batch, {num_batches} total batch(es)."
+ )
+ if num_batches < 3:
+ logger.warning("The first two iterations are slower so it is recommended to feed more batches.")
+
+ # check if scheduler is in sigmas space
+ scheduler_is_in_sigma_space = hasattr(self.scheduler, "sigmas")
+
+ # 2. Encode input prompt
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ self.do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ )
+
+ # 3. Preprocess image
+ image = self.image_processor.preprocess(image)
+
+ # 4. Prepare timesteps
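+ # The scheduler timesteps are computed on the CPU and only then moved to the device,
+ # presumably to keep the (numpy-based) scheduler setup off the HPU.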
+ self.scheduler.set_timesteps(num_inference_steps, device="cpu")
+ timesteps = self.scheduler.timesteps.to(device)
+
+ # 5. Prepare Image latents
+ image_latents = self.prepare_image_latents(
+ image,
+ num_prompts,
+ num_images_per_prompt,
+ prompt_embeds.dtype,
+ device,
+ self.do_classifier_free_guidance,
+ )
+
+ height, width = image_latents.shape[-2:]
+ height = height * self.vae_scale_factor
+ width = width * self.vae_scale_factor
+
+ # 6. Prepare latent variables
+ num_channels_latents = self.vae.config.latent_channels
+ latents = self.prepare_latents(
+ num_prompts * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+ # 7. Check that shapes of latents and image match the UNet channels
+ num_channels_image = image_latents.shape[1]
+ if num_channels_latents + num_channels_image != self.unet.config.in_channels:
+ raise ValueError(
+ f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
+ f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
+ f" `num_channels_image`: {num_channels_image} "
+ f" = {num_channels_latents+num_channels_image}. Please verify the config of"
+ " `pipeline.unet` or your `image` input."
+ )
+
+ # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 8.1 Add image embeds for IP-Adapter
+ added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
+
+ # 9. Split into batches (HPU-specific step)
+ latents_batches, prompt_embeds_batches, image_latents_batches, num_dummy_samples = (
+ self._split_inputs_into_batches(
+ batch_size,
+ latents,
+ prompt_embeds,
+ image_latents,
+ self.do_classifier_free_guidance,
+ )
+ )
+ outputs = {
+ "images": [],
+ "has_nsfw_concept": [],
+ }
+ hb_profiler = HabanaProfile(
+ warmup=profiling_warmup_steps,
+ active=profiling_steps,
+ record_shapes=False,
+ )
+ hb_profiler.start()
+
+ # 10. Denoising loop
+ t0 = time.time()
+ t1 = t0
+ throughput_warmup_steps = kwargs.get("throughput_warmup_steps", 3)
+ use_warmup_inference_steps = (
+ num_batches < throughput_warmup_steps and num_inference_steps > throughput_warmup_steps
+ )
+ for j in self.progress_bar(range(num_batches)):
+ # The throughput is calculated from the 3rd iteration
+ # because compilation occurs in the first two iterations
+ if j == throughput_warmup_steps:
+ t1 = time.time()
+ if use_warmup_inference_steps:
+ t0_inf = time.time()
+
+ latents_batch = latents_batches[0]
+ latents_batches = torch.roll(latents_batches, shifts=-1, dims=0)
+ image_latents_batch = image_latents_batches[0]
+ image_latents_batches = torch.roll(image_latents_batches, shifts=-1, dims=0)
+ prompt_embeds_batch = prompt_embeds_batches[0]
+ prompt_embeds_batches = torch.roll(prompt_embeds_batches, shifts=-1, dims=0)
+
+ for i in range(len(timesteps)):
+ if use_warmup_inference_steps and i == throughput_warmup_steps:
+ t1_inf = time.time()
+ t1 += t1_inf - t0_inf
+ t = timesteps[0]
+ timesteps = torch.roll(timesteps, shifts=-1, dims=0)
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = (
+ torch.cat([latents_batch] * 3) if self.do_classifier_free_guidance else latents_batch
+ )
+ scaled_latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+ scaled_latent_model_input = torch.cat([scaled_latent_model_input, image_latents_batch], dim=1)
+
+ # predict the noise residual
+ noise_pred = self.unet_hpu(
+ scaled_latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds_batch,
+ added_cond_kwargs=added_cond_kwargs,
+ )
+
+ if scheduler_is_in_sigma_space:
+ step_index = (self.scheduler.timesteps == t).nonzero()[0].item()
+ sigma = self.scheduler.sigmas[step_index]
+ noise_pred = latent_model_input - sigma * noise_pred
+
+ # perform guidance
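+ # InstructPix2Pix combines two classifier-free guidance terms:
+ # noise = uncond + guidance_scale * (text - image) + image_guidance_scale * (image - uncond)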
+ if self.do_classifier_free_guidance:
+ noise_pred_text, noise_pred_image, noise_pred_uncond = noise_pred.chunk(3)
+ noise_pred = (
+ noise_pred_uncond
+ + self.guidance_scale * (noise_pred_text - noise_pred_image)
+ + self.image_guidance_scale * (noise_pred_image - noise_pred_uncond)
+ )
+
+ if scheduler_is_in_sigma_space:
+ noise_pred = (noise_pred - latents_batch) / (-sigma)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_batch = self.scheduler.step(
+ noise_pred, t, latents_batch, **extra_step_kwargs, return_dict=False
+ )[0]
+ if not self.use_hpu_graphs:
+ self.htcore.mark_step()
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents_batch = callback_outputs.pop("latents", latents_batch)
+ prompt_embeds_batch = callback_outputs.pop("prompt_embeds", prompt_embeds_batch)
+ image_latents_batch = callback_outputs.pop("image_latents", image_latents_batch)
+
+ # call the callback, if provided
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents_batch)
+ hb_profiler.step()
+ if use_warmup_inference_steps:
+ t1 = warmup_inference_steps_time_adjustment(
+ t1, t1_inf, num_inference_steps, throughput_warmup_steps
+ )
+
+ if not output_type == "latent":
+ image = self.vae.decode(latents_batch / self.vae.config.scaling_factor, return_dict=False)[0]
+ else:
+ image = latents_batch
+ outputs["images"].append(image)
+ if not self.use_hpu_graphs:
+ self.htcore.mark_step()
+
+ hb_profiler.stop()
+ speed_metrics_prefix = "generation"
+ speed_measures = speed_metrics(
+ split=speed_metrics_prefix,
+ start_time=t0,
+ num_samples=num_batches * batch_size
+ if t1 == t0 or use_warmup_inference_steps
+ else (num_batches - throughput_warmup_steps) * batch_size,
+ num_steps=num_batches,
+ start_time_after_warmup=t1,
+ )
+ logger.info(f"Speed metrics: {speed_measures}")
+ # Remove dummy generations if needed
+ if num_dummy_samples > 0:
+ outputs["images"][-1] = outputs["images"][-1][:-num_dummy_samples]
+
+ # Process generated images
+ for i, image in enumerate(outputs["images"][:]):
+ if i == 0:
+ outputs["images"].clear()
+
+ if output_type == "latent":
+ has_nsfw_concept = None
+ else:
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ if output_type == "pil" and isinstance(image, list):
+ outputs["images"] += image
+ elif output_type in ["np", "numpy"] and isinstance(image, np.ndarray):
+ if len(outputs["images"]) == 0:
+ outputs["images"] = image
+ else:
+ outputs["images"] = np.concatenate((outputs["images"], image), axis=0)
+ else:
+ if len(outputs["images"]) == 0:
+ outputs["images"] = image
+ else:
+ outputs["images"] = torch.cat((outputs["images"], image), 0)
+
+ if has_nsfw_concept is not None:
+ outputs["has_nsfw_concept"] += has_nsfw_concept
+ else:
+ outputs["has_nsfw_concept"] = None
+
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (outputs["images"], outputs["has_nsfw_concept"])
+
+ return GaudiStableDiffusionPipelineOutput(
+ images=outputs["images"],
+ nsfw_content_detected=outputs["has_nsfw_concept"],
+ throughput=speed_measures[f"{speed_metrics_prefix}_samples_per_second"],
+ )
+
+ @torch.no_grad()
+ def unet_hpu(
+ self,
+ latent_model_input,
+ timestep,
+ encoder_hidden_states,
+ added_cond_kwargs,
+ ):
+ if self.use_hpu_graphs:
+ return self.capture_replay(latent_model_input, timestep, encoder_hidden_states)
+ else:
+ return self.unet(
+ latent_model_input,
+ timestep,
+ encoder_hidden_states=encoder_hidden_states,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ @torch.no_grad()
+ def capture_replay(self, latent_model_input, timestep, encoder_hidden_states):
+ inputs = [latent_model_input, timestep, encoder_hidden_states, False]
+ h = self.ht.hpu.graphs.input_hash(inputs)
+ cached = self.cache.get(h)
+ if cached is None:
+ # Capture the graph and cache it
+ with self.ht.hpu.stream(self.hpu_stream):
+ graph = self.ht.hpu.HPUGraph()
+ graph.capture_begin()
+ outputs = self.unet(inputs[0], inputs[1], encoder_hidden_states=inputs[2], return_dict=inputs[3])[0]
+ graph.capture_end()
+ graph_inputs = inputs
+ graph_outputs = outputs
+ self.cache[h] = self.ht.hpu.graphs.CachedParams(graph_inputs, graph_outputs, graph)
+ return outputs
+
+ # Replay the cached graph with updated inputs
+ self.ht.hpu.graphs.copy_to(cached.graph_inputs, inputs)
+ cached.graph.replay()
+ self.ht.core.hpu.default_stream().synchronize()
+
+ return cached.graph_outputs
diff --git a/optimum/habana/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/optimum/habana/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
index e9e5596d71..d010883bbd 100644
--- a/optimum/habana/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
+++ b/optimum/habana/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
@@ -93,7 +93,7 @@ class GaudiStableDiffusionXLPipeline(GaudiDiffusionPipeline, StableDiffusionXLPi
Whether to use HPU graphs or not.
gaudi_config (Union[str, [`GaudiConfig`]], defaults to `None`):
Gaudi configuration to use. Can be a string to download it from the Hub.
- Or a previously initialized config can be passed.
+ Or a previously initialized config can be passed.
bf16_full_eval (bool, defaults to `False`):
Whether to use full bfloat16 evaluation instead of 32-bit.
This will be faster and save memory compared to fp32/mixed precision but can harm generated images.
diff --git a/optimum/habana/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/optimum/habana/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
new file mode 100644
index 0000000000..1320cb11a9
--- /dev/null
+++ b/optimum/habana/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
@@ -0,0 +1,794 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import time
+from math import ceil
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from diffusers.image_processor import PipelineImageInput
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLImg2ImgPipeline
+from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import deprecate
+from transformers import (
+ CLIPImageProcessor,
+ CLIPTextModel,
+ CLIPTextModelWithProjection,
+ CLIPTokenizer,
+ CLIPVisionModelWithProjection,
+)
+
+from optimum.utils import logging
+
+from ....transformers.gaudi_configuration import GaudiConfig
+from ....utils import HabanaProfile, speed_metrics, warmup_inference_steps_time_adjustment
+from ..pipeline_utils import GaudiDiffusionPipeline
+from ..stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps
+from .pipeline_stable_diffusion_xl import GaudiStableDiffusionXLPipelineOutput
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class GaudiStableDiffusionXLImg2ImgPipeline(GaudiDiffusionPipeline, StableDiffusionXLImg2ImgPipeline):
+ """
+ Pipeline for image-to-image generation using Stable Diffusion XL on Gaudi devices
+ Adapted from: https://github.com/huggingface/diffusers/blob/v0.26.3/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
+
+ Extends the [`StableDiffusionXLImg2ImgPipeline`] class:
+ - Generation is performed in batches
+ - Two `mark_step()` calls were added to support lazy mode
+ - Added support for HPU graphs
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion XL uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ text_encoder_2 ([`CLIPTextModelWithProjection`]):
+ Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
+ specifically the
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
+ variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer_2 (`CLIPTokenizer`):
+ Second Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ requires_aesthetics_score (`bool`, *optional*, defaults to `"False"`):
+ Whether the `unet` requires an `aesthetic_score` condition to be passed during inference. Also see the
+ config of `stabilityai/stable-diffusion-xl-refiner-1-0`.
+ force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`):
+ Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of
+ `stabilityai/stable-diffusion-xl-base-1-0`.
+ add_watermarker (`bool`, *optional*):
+ Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to
+ watermark output images. If not defined, it will default to True if the package is installed, otherwise no
+ watermarker will be used.
+ use_habana (bool, defaults to `False`):
+ Whether to use Gaudi (`True`) or CPU (`False`).
+ use_hpu_graphs (bool, defaults to `False`):
+ Whether to use HPU graphs or not.
+ gaudi_config (Union[str, [`GaudiConfig`]], defaults to `None`):
+ Gaudi configuration to use. Can be a string to download it from the Hub.
+ Or a previously initialized config can be passed.
+ bf16_full_eval (bool, defaults to `False`):
+ Whether to use full bfloat16 evaluation instead of 32-bit.
+ This will be faster and save memory compared to fp32/mixed precision but can harm generated images.
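+
+ Example:
+ A minimal usage sketch, assuming the class is exported from `optimum.habana.diffusers` like the
+ other Gaudi pipelines; the checkpoint name, Gaudi configuration and input image below are
+ illustrative assumptions only.
+
+ ```py
+ from PIL import Image
+ from optimum.habana.diffusers import GaudiStableDiffusionXLImg2ImgPipeline
+
+ pipeline = GaudiStableDiffusionXLImg2ImgPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-refiner-1.0",  # assumed checkpoint
+ use_habana=True,
+ use_hpu_graphs=True,
+ gaudi_config="Habana/stable-diffusion",  # assumed Gaudi config from the Hub
+ )
+ init_image = Image.open("photo.png").convert("RGB")  # any RGB image
+ image = pipeline(prompt="an astronaut riding a horse", image=init_image, strength=0.3).images[0]
+ ```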
+ """
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ text_encoder_2: CLIPTextModelWithProjection,
+ tokenizer: CLIPTokenizer,
+ tokenizer_2: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ image_encoder: CLIPVisionModelWithProjection = None,
+ feature_extractor: CLIPImageProcessor = None,
+ requires_aesthetics_score: bool = False,
+ force_zeros_for_empty_prompt: bool = True,
+ add_watermarker: Optional[bool] = None,
+ use_habana: bool = False,
+ use_hpu_graphs: bool = False,
+ gaudi_config: Union[str, GaudiConfig] = None,
+ bf16_full_eval: bool = False,
+ ):
+ GaudiDiffusionPipeline.__init__(
+ self,
+ use_habana,
+ use_hpu_graphs,
+ gaudi_config,
+ bf16_full_eval,
+ )
+
+ StableDiffusionXLImg2ImgPipeline.__init__(
+ self,
+ vae,
+ text_encoder,
+ text_encoder_2,
+ tokenizer,
+ tokenizer_2,
+ unet,
+ scheduler,
+ image_encoder,
+ feature_extractor,
+ requires_aesthetics_score,
+ force_zeros_for_empty_prompt,
+ add_watermarker,
+ )
+
+ self.to(self._device)
+
+ @classmethod
+ def _split_inputs_into_batches(
+ cls,
+ batch_size,
+ latents,
+ prompt_embeds,
+ negative_prompt_embeds,
+ add_text_embeds,
+ negative_pooled_prompt_embeds,
+ add_time_ids,
+ negative_add_time_ids,
+ ):
+ # Use torch.split to generate num_batches batches of size batch_size
+ latents_batches = list(torch.split(latents, batch_size))
+ prompt_embeds_batches = list(torch.split(prompt_embeds, batch_size))
+ if negative_prompt_embeds is not None:
+ negative_prompt_embeds_batches = list(torch.split(negative_prompt_embeds, batch_size))
+ if add_text_embeds is not None:
+ add_text_embeds_batches = list(torch.split(add_text_embeds, batch_size))
+ if negative_pooled_prompt_embeds is not None:
+ negative_pooled_prompt_embeds_batches = list(torch.split(negative_pooled_prompt_embeds, batch_size))
+ if add_time_ids is not None:
+ add_time_ids_batches = list(torch.split(add_time_ids, batch_size))
+ if negative_add_time_ids is not None:
+ negative_add_time_ids_batches = list(torch.split(negative_add_time_ids, batch_size))
+
+ # If the last batch has less samples than batch_size, pad it with dummy samples
+ num_dummy_samples = 0
+ if latents_batches[-1].shape[0] < batch_size:
+ num_dummy_samples = batch_size - latents_batches[-1].shape[0]
+ # Pad latents_batches
+ sequence_to_stack = (latents_batches[-1],) + tuple(
+ torch.zeros_like(latents_batches[-1][0][None, :]) for _ in range(num_dummy_samples)
+ )
+ latents_batches[-1] = torch.vstack(sequence_to_stack)
+ # Pad prompt_embeds_batches
+ sequence_to_stack = (prompt_embeds_batches[-1],) + tuple(
+ torch.zeros_like(prompt_embeds_batches[-1][0][None, :]) for _ in range(num_dummy_samples)
+ )
+ prompt_embeds_batches[-1] = torch.vstack(sequence_to_stack)
+ # Pad negative_prompt_embeds_batches if necessary
+ if negative_prompt_embeds is not None:
+ sequence_to_stack = (negative_prompt_embeds_batches[-1],) + tuple(
+ torch.zeros_like(negative_prompt_embeds_batches[-1][0][None, :]) for _ in range(num_dummy_samples)
+ )
+ negative_prompt_embeds_batches[-1] = torch.vstack(sequence_to_stack)
+ # Pad add_text_embeds_batches if necessary
+ if add_text_embeds is not None:
+ sequence_to_stack = (add_text_embeds_batches[-1],) + tuple(
+ torch.zeros_like(add_text_embeds_batches[-1][0][None, :]) for _ in range(num_dummy_samples)
+ )
+ add_text_embeds_batches[-1] = torch.vstack(sequence_to_stack)
+ # Pad negative_pooled_prompt_embeds_batches if necessary
+ if negative_pooled_prompt_embeds is not None:
+ sequence_to_stack = (negative_pooled_prompt_embeds_batches[-1],) + tuple(
+ torch.zeros_like(negative_pooled_prompt_embeds_batches[-1][0][None, :])
+ for _ in range(num_dummy_samples)
+ )
+ negative_pooled_prompt_embeds_batches[-1] = torch.vstack(sequence_to_stack)
+ # Pad add_time_ids_batches if necessary
+ if add_time_ids is not None:
+ sequence_to_stack = (add_time_ids_batches[-1],) + tuple(
+ torch.zeros_like(add_time_ids_batches[-1][0][None, :]) for _ in range(num_dummy_samples)
+ )
+ add_time_ids_batches[-1] = torch.vstack(sequence_to_stack)
+ # Pad negative_add_time_ids_batches if necessary
+ if negative_add_time_ids is not None:
+ sequence_to_stack = (negative_add_time_ids_batches[-1],) + tuple(
+ torch.zeros_like(negative_add_time_ids_batches[-1][0][None, :]) for _ in range(num_dummy_samples)
+ )
+ negative_add_time_ids_batches[-1] = torch.vstack(sequence_to_stack)
+
+ # Stack batches in the same tensor
+ latents_batches = torch.stack(latents_batches)
+
+ if negative_prompt_embeds is not None:
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ for i, (negative_prompt_embeds_batch, prompt_embeds_batch) in enumerate(
+ zip(negative_prompt_embeds_batches, prompt_embeds_batches[:])
+ ):
+ prompt_embeds_batches[i] = torch.cat([negative_prompt_embeds_batch, prompt_embeds_batch])
+ prompt_embeds_batches = torch.stack(prompt_embeds_batches)
+
+ if add_text_embeds is not None:
+ if negative_pooled_prompt_embeds is not None:
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ for i, (negative_pooled_prompt_embeds_batch, add_text_embeds_batch) in enumerate(
+ zip(negative_pooled_prompt_embeds_batches, add_text_embeds_batches[:])
+ ):
+ add_text_embeds_batches[i] = torch.cat(
+ [negative_pooled_prompt_embeds_batch, add_text_embeds_batch]
+ )
+ add_text_embeds_batches = torch.stack(add_text_embeds_batches)
+ else:
+ add_text_embeds_batches = None
+
+ if add_time_ids is not None:
+ if negative_add_time_ids is not None:
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ for i, (negative_add_time_ids_batch, add_time_ids_batch) in enumerate(
+ zip(negative_add_time_ids_batches, add_time_ids_batches[:])
+ ):
+ add_time_ids_batches[i] = torch.cat([negative_add_time_ids_batch, add_time_ids_batch])
+ add_time_ids_batches = torch.stack(add_time_ids_batches)
+ else:
+ add_time_ids_batches = None
+
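+ # With classifier-free guidance, each stacked batch now holds the unconditional inputs first and
+ # the conditional inputs second along the batch dimension, so a single UNet forward covers both.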
+ return latents_batches, prompt_embeds_batches, add_text_embeds_batches, add_time_ids_batches, num_dummy_samples
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ image: PipelineImageInput = None,
+ strength: float = 0.3,
+ num_inference_steps: int = 50,
+ timesteps: List[int] = None,
+ denoising_start: Optional[float] = None,
+ denoising_end: Optional[float] = None,
+ guidance_scale: float = 5.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ batch_size: int = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ original_size: Tuple[int, int] = None,
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
+ target_size: Tuple[int, int] = None,
+ negative_original_size: Optional[Tuple[int, int]] = None,
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
+ negative_target_size: Optional[Tuple[int, int]] = None,
+ aesthetic_score: float = 6.0,
+ negative_aesthetic_score: float = 2.5,
+ clip_skip: Optional[int] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ profiling_warmup_steps: Optional[int] = 0,
+ profiling_steps: Optional[int] = 0,
+ **kwargs,
+ ):
+ """
+ Adapted from: https://github.com/huggingface/diffusers/blob/v0.26.3/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
+ - Two `mark_step()` calls were added to support lazy mode
+ - Added support for HPU graphs
+ - Added the `batch_size` argument
+ """
+
+ callback = kwargs.pop("callback", None)
+ callback_steps = kwargs.pop("callback_steps", None)
+
+ if callback is not None:
+ deprecate(
+ "callback",
+ "1.0.0",
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
+ )
+ if callback_steps is not None:
+ deprecate(
+ "callback_steps",
+ "1.0.0",
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
+ )
+
+ with torch.autocast(device_type="hpu", dtype=torch.bfloat16, enabled=self.gaudi_config.use_torch_autocast):
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ strength,
+ num_inference_steps,
+ callback_steps,
+ negative_prompt,
+ negative_prompt_2,
+ prompt_embeds,
+ negative_prompt_embeds,
+ callback_on_step_end_tensor_inputs,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._guidance_rescale = guidance_rescale
+ self._clip_skip = clip_skip
+ self._cross_attention_kwargs = cross_attention_kwargs
+ self._denoising_end = denoising_end
+ self._denoising_start = denoising_start
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ num_prompts = 1
+ elif prompt is not None and isinstance(prompt, list):
+ num_prompts = len(prompt)
+ else:
+ num_prompts = prompt_embeds.shape[0]
+ num_batches = ceil((num_images_per_prompt * num_prompts) / batch_size)
+ logger.info(
+ f"{num_prompts} prompt(s) received, {num_images_per_prompt} generation(s) per prompt,"
+ f" {batch_size} sample(s) per batch, {num_batches} total batch(es)."
+ )
+ if num_batches < 3:
+ logger.warning("The first two iterations are slower so it is recommended to feed more batches.")
+
+ device = self._execution_device
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+ )
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ clip_skip=self.clip_skip,
+ )
+
+ # 4. Preprocess image
+ image = self.image_processor.preprocess(image)
+
+ # 5. Prepare timesteps
+ def denoising_value_valid(dnv):
+ return isinstance(dnv, float) and 0 < dnv < 1
+
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+
+ timesteps, num_inference_steps = self.get_timesteps(
+ num_inference_steps,
+ strength,
+ device,
+ denoising_start=self.denoising_start if denoising_value_valid(self.denoising_start) else None,
+ )
+ timesteps = timesteps.to(device)
+ latent_timestep = timesteps[:1].repeat(num_prompts * num_images_per_prompt)
+
+ add_noise = True if self.denoising_start is None else False
+ # 6. Prepare latent variables
+ latents = self.prepare_latents(
+ image,
+ latent_timestep,
+ num_prompts,
+ num_images_per_prompt,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ add_noise,
+ )
+
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ height, width = latents.shape[-2:]
+ height = height * self.vae_scale_factor
+ width = width * self.vae_scale_factor
+
+ original_size = original_size or (height, width)
+ target_size = target_size or (height, width)
+
+ # 8. Prepare added time ids & embeddings
+ if negative_original_size is None:
+ negative_original_size = original_size
+ if negative_target_size is None:
+ negative_target_size = target_size
+
+ add_text_embeds = pooled_prompt_embeds
+ if self.text_encoder_2 is None:
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
+ else:
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
+
+ add_time_ids, add_neg_time_ids = self._get_add_time_ids(
+ original_size,
+ crops_coords_top_left,
+ target_size,
+ aesthetic_score,
+ negative_aesthetic_score,
+ negative_original_size,
+ negative_crops_coords_top_left,
+ negative_target_size,
+ dtype=prompt_embeds.dtype,
+ text_encoder_projection_dim=text_encoder_projection_dim,
+ )
+ add_time_ids = add_time_ids.repeat(num_prompts * num_images_per_prompt, 1)
+ if self.do_classifier_free_guidance:
+ add_neg_time_ids = add_neg_time_ids.repeat(num_prompts * num_images_per_prompt, 1)
+ add_neg_time_ids = add_neg_time_ids.to(device)
+
+ prompt_embeds = prompt_embeds.to(device)
+ add_text_embeds = add_text_embeds.to(device)
+ add_time_ids = add_time_ids.to(device)
+
+ if ip_adapter_image is not None:
+ image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image, device, num_prompts * num_images_per_prompt
+ )
+
+ # 8.5 Split into batches (HPU-specific step)
+
+ (
+ latents_batches,
+ text_embeddings_batches,
+ add_text_embeddings_batches,
+ add_time_ids_batches,
+ num_dummy_samples,
+ ) = self._split_inputs_into_batches(
+ batch_size,
+ latents,
+ prompt_embeds,
+ negative_prompt_embeds,
+ add_text_embeds,
+ negative_pooled_prompt_embeds,
+ add_time_ids,
+ add_neg_time_ids,
+ )
+
+ outputs = {
+ "images": [],
+ }
+ t0 = time.time()
+ t1 = t0
+
+ hb_profiler = HabanaProfile(
+ warmup=profiling_warmup_steps,
+ active=profiling_steps,
+ record_shapes=False,
+ )
+ hb_profiler.start()
+
+ # 9. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+ # 9.1 Apply denoising_end
+ if (
+ self.denoising_end is not None
+ and self.denoising_start is not None
+ and denoising_value_valid(self.denoising_end)
+ and denoising_value_valid(self.denoising_start)
+ and self.denoising_start >= self.denoising_end
+ ):
+ raise ValueError(
+ f"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: "
+ + f" {self.denoising_end} when using type float."
+ )
+ elif self.denoising_end is not None and denoising_value_valid(self.denoising_end):
+ discrete_timestep_cutoff = int(
+ round(
+ self.scheduler.config.num_train_timesteps
+ - (self.denoising_end * self.scheduler.config.num_train_timesteps)
+ )
+ )
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
+ timesteps = timesteps[:num_inference_steps]
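+ # e.g. with denoising_end=0.8 and 1000 training timesteps, the cutoff is round(1000 - 0.8 * 1000) = 200
+ # and only timesteps >= 200 are kept, i.e. the first 80% of the denoising trajectory.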
+ # 9.2 Optionally get Guidance Scale Embedding
+ timestep_cond = None
+ if self.unet.config.time_cond_proj_dim is not None:
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(
+ batch_size * num_images_per_prompt
+ )
+ timestep_cond = self.get_guidance_scale_embedding(
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+ ).to(device=device, dtype=latents.dtype)
+
+ self._num_timesteps = len(timesteps)
+
+ # 9.3 Denoising loop
+ throughput_warmup_steps = kwargs.get("throughput_warmup_steps", 3)
+ use_warmup_inference_steps = (
+ num_batches < throughput_warmup_steps and num_inference_steps > throughput_warmup_steps
+ )
+ for j in self.progress_bar(range(num_batches)):
+ # The throughput is calculated from the 3rd iteration
+ # because compilation occurs in the first two iterations
+ if j == throughput_warmup_steps:
+ t1 = time.time()
+ if use_warmup_inference_steps:
+ t0_inf = time.time()
+
+ latents_batch = latents_batches[0]
+ latents_batches = torch.roll(latents_batches, shifts=-1, dims=0)
+ text_embeddings_batch = text_embeddings_batches[0]
+ text_embeddings_batches = torch.roll(text_embeddings_batches, shifts=-1, dims=0)
+ add_text_embeddings_batch = add_text_embeddings_batches[0]
+ add_text_embeddings_batches = torch.roll(add_text_embeddings_batches, shifts=-1, dims=0)
+ add_time_ids_batch = add_time_ids_batches[0]
+ add_time_ids_batches = torch.roll(add_time_ids_batches, shifts=-1, dims=0)
+
+ for i in range(len(timesteps)):
+ if use_warmup_inference_steps and i == throughput_warmup_steps:
+ t1_inf = time.time()
+ t1 += t1_inf - t0_inf
+ if self.interrupt:
+ continue
+ timestep = timesteps[0]
+ timesteps = torch.roll(timesteps, shifts=-1, dims=0)
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = (
+ torch.cat([latents_batch] * 2) if self.do_classifier_free_guidance else latents_batch
+ )
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, timestep)
+
+ # predict the noise residual
+ added_cond_kwargs = {"text_embeds": add_text_embeddings_batch, "time_ids": add_time_ids_batch}
+ if ip_adapter_image is not None:
+ added_cond_kwargs["image_embeds"] = image_embeds
+ noise_pred = self.unet_hpu(
+ latent_model_input,
+ timestep,
+ text_embeddings_batch,
+ timestep_cond,
+ self.cross_attention_kwargs,
+ added_cond_kwargs,
+ )
+
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(
+ noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale
+ )
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_batch = self.scheduler.step(
+ noise_pred, timestep, latents_batch, **extra_step_kwargs, return_dict=False
+ )[0]
+
+ if not self.use_hpu_graphs:
+ self.htcore.mark_step()
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, timestep, callback_kwargs)
+
+ latents_batch = callback_outputs.pop("latents", latents_batch)
+ _prompt_embeds = callback_outputs.pop("prompt_embeds", None)
+ _negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", None)
+ if _prompt_embeds is not None and _negative_prompt_embeds is not None:
+ text_embeddings_batch = torch.cat([_negative_prompt_embeds, _prompt_embeds])
+ _add_text_embeds = callback_outputs.pop("add_text_embeds", None)
+ _negative_pooled_prompt_embeds = callback_outputs.pop("negative_pooled_prompt_embeds", None)
+ if _add_text_embeds is not None and _negative_pooled_prompt_embeds is not None:
+ add_text_embeddings_batch = torch.cat([_negative_pooled_prompt_embeds, _add_text_embeds])
+ _add_time_ids = callback_outputs.pop("add_time_ids", None)
+ _negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", None)
+ if _add_time_ids is not None and _negative_add_time_ids is not None:
+ add_time_ids_batch = torch.cat([_negative_add_time_ids, _add_time_ids])
+
+ # call the callback, if provided
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, timestep, latents_batch)
+
+ hb_profiler.step()
+ if use_warmup_inference_steps:
+ t1 = warmup_inference_steps_time_adjustment(
+ t1, t1_inf, num_inference_steps, throughput_warmup_steps
+ )
+
+ if not output_type == "latent":
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+
+ if needs_upcasting:
+ self.upcast_vae()
+ latents_batch = latents_batch.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+
+ image = self.vae.decode(latents_batch / self.vae.config.scaling_factor, return_dict=False)[0]
+
+ # cast back to fp16 if needed
+ if needs_upcasting:
+ self.vae.to(dtype=torch.float16)
+
+ else:
+ image = latents_batch
+
+ outputs["images"].append(image)
+
+ if not self.use_hpu_graphs:
+ self.htcore.mark_step()
+
+ hb_profiler.stop()
+
+ speed_metrics_prefix = "generation"
+ speed_measures = speed_metrics(
+ split=speed_metrics_prefix,
+ start_time=t0,
+ num_samples=num_batches * batch_size
+ if t1 == t0 or use_warmup_inference_steps
+ else (num_batches - throughput_warmup_steps) * batch_size,
+ num_steps=num_batches,
+ start_time_after_warmup=t1,
+ )
+ logger.info(f"Speed metrics: {speed_measures}")
+
+ # Remove dummy generations if needed
+ if num_dummy_samples > 0:
+ outputs["images"][-1] = outputs["images"][-1][:-num_dummy_samples]
+
+ # Process generated images
+ for i, image in enumerate(outputs["images"][:]):
+ if i == 0:
+ outputs["images"].clear()
+
+ if not output_type == "latent":
+ # apply watermark if available
+ if self.watermark is not None:
+ image = self.watermark.apply_watermark(image)
+
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ if output_type == "pil" and isinstance(image, list):
+ outputs["images"] += image
+ elif output_type in ["np", "numpy"] and isinstance(image, np.ndarray):
+ if len(outputs["images"]) == 0:
+ outputs["images"] = image
+ else:
+ outputs["images"] = np.concatenate((outputs["images"], image), axis=0)
+ else:
+ if len(outputs["images"]) == 0:
+ outputs["images"] = image
+ else:
+ outputs["images"] = torch.cat((outputs["images"], image), 0)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return outputs["images"]
+
+ return GaudiStableDiffusionXLPipelineOutput(
+ images=outputs["images"],
+ throughput=speed_measures[f"{speed_metrics_prefix}_samples_per_second"],
+ )
+
+ @torch.no_grad()
+ def unet_hpu(
+ self,
+ latent_model_input,
+ timestep,
+ encoder_hidden_states,
+ timestep_cond,
+ cross_attention_kwargs,
+ added_cond_kwargs,
+ ):
+ if self.use_hpu_graphs:
+ return self.capture_replay(
+ latent_model_input,
+ timestep,
+ encoder_hidden_states,
+ timestep_cond,
+ cross_attention_kwargs,
+ added_cond_kwargs,
+ )
+ else:
+ return self.unet(
+ latent_model_input,
+ timestep,
+ encoder_hidden_states=encoder_hidden_states,
+ timestep_cond=timestep_cond,
+ cross_attention_kwargs=cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ @torch.no_grad()
+ def capture_replay(
+ self,
+ latent_model_input,
+ timestep,
+ encoder_hidden_states,
+ timestep_cond,
+ cross_attention_kwargs,
+ added_cond_kwargs,
+ ):
+ inputs = [
+ latent_model_input,
+ timestep,
+ encoder_hidden_states,
+ timestep_cond,
+ cross_attention_kwargs,
+ added_cond_kwargs,
+ ]
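+        # Look up a previously captured HPU graph using a hash of the inputs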
+ h = self.ht.hpu.graphs.input_hash(inputs)
+ cached = self.cache.get(h)
+
+ if cached is None:
+ # Capture the graph and cache it
+ with self.ht.hpu.stream(self.hpu_stream):
+ graph = self.ht.hpu.HPUGraph()
+ graph.capture_begin()
+
+ outputs = self.unet(
+ sample=inputs[0],
+ timestep=inputs[1],
+ encoder_hidden_states=inputs[2],
+ timestep_cond=inputs[3],
+ cross_attention_kwargs=inputs[4],
+ added_cond_kwargs=inputs[5],
+ return_dict=False,
+ )[0]
+
+ graph.capture_end()
+ graph_inputs = inputs
+ graph_outputs = outputs
+ self.cache[h] = self.ht.hpu.graphs.CachedParams(graph_inputs, graph_outputs, graph)
+ return outputs
+
+ # Replay the cached graph with updated inputs
+ self.ht.hpu.graphs.copy_to(cached.graph_inputs, inputs)
+ cached.graph.replay()
+ self.ht.core.hpu.default_stream().synchronize()
+
+ return cached.graph_outputs
diff --git a/optimum/habana/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/optimum/habana/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
new file mode 100644
index 0000000000..131962df3f
--- /dev/null
+++ b/optimum/habana/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
@@ -0,0 +1,1045 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import time
+from math import ceil
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy
+import torch
+from diffusers.image_processor import PipelineImageInput
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLInpaintPipeline
+from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint import (
+ rescale_noise_cfg,
+ retrieve_timesteps,
+)
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import deprecate, logging, replace_example_docstring
+from transformers import (
+ CLIPImageProcessor,
+ CLIPTextModel,
+ CLIPTextModelWithProjection,
+ CLIPTokenizer,
+ CLIPVisionModelWithProjection,
+)
+
+from ....transformers.gaudi_configuration import GaudiConfig
+from ....utils import speed_metrics, warmup_inference_steps_time_adjustment
+from ..pipeline_utils import GaudiDiffusionPipeline
+from .pipeline_stable_diffusion_xl import GaudiStableDiffusionXLPipelineOutput
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from optimum.habana.diffusers import GaudiStableDiffusionXLInpaintPipeline
+ >>> from diffusers.utils import load_image
+
+ >>> pipe = GaudiStableDiffusionXLInpaintPipeline.from_pretrained(
+ ... "stabilityai/stable-diffusion-xl-base-1.0",
+ ... torch_dtype=torch.float16,
+ ... variant="fp16",
+ ... use_safetensors=True,
+ ... )
+
+ >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+ >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+
+ >>> init_image = load_image(img_url).convert("RGB")
+ >>> mask_image = load_image(mask_url).convert("RGB")
+
+ >>> prompt = "A majestic tiger sitting on a bench"
+ >>> image = pipe(
+ ... prompt=prompt, image=init_image, mask_image=mask_image, num_inference_steps=50, strength=0.80
+ ... ).images[0]
+ ```
+"""
+
+
+class GaudiStableDiffusionXLInpaintPipeline(GaudiDiffusionPipeline, StableDiffusionXLInpaintPipeline):
+ r"""
+ Adapted from: https://github.com/huggingface/diffusers/blob/v0.26.3/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py#L312
+    - Two `mark_step()` calls were added to support lazy mode
+ - Added support for HPU graphs
+
+    Pipeline for text-guided image inpainting using Stable Diffusion XL.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ The pipeline also inherits the following loading methods:
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
+ - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+ - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion XL uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+        text_encoder_2 ([`CLIPTextModelWithProjection`]):
+ Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
+ specifically the
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
+ variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer_2 (`CLIPTokenizer`):
+ Second Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ requires_aesthetics_score (`bool`, *optional*, defaults to `"False"`):
+            Whether the `unet` requires an `aesthetic_score` condition to be passed during inference. Also see the config
+ of `stabilityai/stable-diffusion-xl-refiner-1-0`.
+ force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`):
+ Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of
+ `stabilityai/stable-diffusion-xl-base-1-0`.
+ add_watermarker (`bool`, *optional*):
+ Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to
+ watermark output images. If not defined, it will default to True if the package is installed, otherwise no
+ watermarker will be used.
+ use_habana (bool, defaults to `False`):
+ Whether to use Gaudi (`True`) or CPU (`False`).
+ use_hpu_graphs (bool, defaults to `False`):
+ Whether to use HPU graphs or not.
+ gaudi_config (Union[str, [`GaudiConfig`]], defaults to `None`):
+ Gaudi configuration to use. Can be a string to download it from the Hub.
+ Or a previously initialized config can be passed.
+ bf16_full_eval (bool, defaults to `False`):
+ Whether to use full bfloat16 evaluation instead of 32-bit.
+ This will be faster and save memory compared to fp32/mixed precision but can harm generated images.
+ """
+
+ _callback_tensor_inputs = [
+ "latents",
+ "prompt_embeds",
+ "add_time_ids",
+ "add_text_embeds",
+ "mask",
+ "masked_image_latents",
+ ]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ text_encoder_2: CLIPTextModelWithProjection,
+ tokenizer: CLIPTokenizer,
+ tokenizer_2: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ image_encoder: CLIPVisionModelWithProjection = None,
+ feature_extractor: CLIPImageProcessor = None,
+ requires_aesthetics_score: bool = False,
+ force_zeros_for_empty_prompt: bool = True,
+ add_watermarker: Optional[bool] = None,
+ use_habana: bool = False,
+ use_hpu_graphs: bool = False,
+ gaudi_config: Union[str, GaudiConfig] = None,
+ bf16_full_eval: bool = False,
+ ):
+ GaudiDiffusionPipeline.__init__(
+ self,
+ use_habana,
+ use_hpu_graphs,
+ gaudi_config,
+ bf16_full_eval,
+ )
+
+ StableDiffusionXLInpaintPipeline.__init__(
+ self,
+ vae,
+ text_encoder,
+ text_encoder_2,
+ tokenizer,
+ tokenizer_2,
+ unet,
+ scheduler,
+ image_encoder,
+ feature_extractor,
+ requires_aesthetics_score,
+ force_zeros_for_empty_prompt,
+ add_watermarker,
+ )
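+        # Move the pipeline to the device selected by GaudiDiffusionPipeline (HPU when use_habana=True, otherwise CPU)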
+ self.to(self._device)
+
+ @classmethod
+ def _split_and_cat_tensors(cls, batch_size, input_a, input_b=None, do_classifier_free_guidance=True):
+ if input_a is None:
+ return None, 0
+
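+        # Split into fixed-size batches; the last batch is padded with zero-filled dummy samples so all batches share the same static shape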
+ input_a_batches = list(torch.split(input_a, batch_size))
+ if input_b is not None:
+ input_b_batches = list(torch.split(input_b, batch_size))
+
+ num_dummy_samples = 0
+ if input_a_batches[-1].shape[0] < batch_size:
+ num_dummy_samples = batch_size - input_a_batches[-1].shape[0]
+ # Pad input a
+ sequence_to_stack = (input_a_batches[-1],) + tuple(
+ torch.zeros_like(input_a_batches[-1][0][None, :]) for _ in range(num_dummy_samples)
+ )
+ input_a_batches[-1] = torch.vstack(sequence_to_stack)
+
+ if input_b is not None:
+                # Pad input b
+ sequence_to_stack = (input_b_batches[-1],) + tuple(
+ torch.zeros_like(input_b_batches[-1][0][None, :]) for _ in range(num_dummy_samples)
+ )
+ input_b_batches[-1] = torch.vstack(sequence_to_stack)
+
+ if input_b is not None and do_classifier_free_guidance:
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ for i, (input_b_batch, input_a_batch) in enumerate(zip(input_b_batches, input_a_batches[:])):
+ input_a_batches[i] = torch.cat([input_b_batch, input_a_batch])
+
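+        # Stack the batches along a new leading dimension so the denoising loop can consume them one by one via torch.roll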
+ input_a_batches = torch.stack(input_a_batches)
+ return input_a_batches, num_dummy_samples
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ image: PipelineImageInput = None,
+ mask_image: PipelineImageInput = None,
+ masked_image_latents: torch.FloatTensor = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ batch_size: int = 1,
+ padding_mask_crop: Optional[int] = None,
+ strength: float = 0.9999,
+ num_inference_steps: int = 50,
+ timesteps: List[int] = None,
+ denoising_start: Optional[float] = None,
+ denoising_end: Optional[float] = None,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ original_size: Tuple[int, int] = None,
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
+ target_size: Tuple[int, int] = None,
+ negative_original_size: Optional[Tuple[int, int]] = None,
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
+ negative_target_size: Optional[Tuple[int, int]] = None,
+ aesthetic_score: float = 6.0,
+ negative_aesthetic_score: float = 2.5,
+ clip_skip: Optional[int] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ **kwargs,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+                instead.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders
+ image (`PIL.Image.Image`):
+ `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
+ be masked out with `mask_image` and repainted according to `prompt`.
+ mask_image (`PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
+ repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
+ to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
+ instead of 3, so the expected shape would be `(B, H, W, 1)`.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ Anything below 512 pixels won't work well for
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+ and checkpoints that are not specifically fine-tuned on low resolutions.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ Anything below 512 pixels won't work well for
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+ and checkpoints that are not specifically fine-tuned on low resolutions.
+ batch_size (`int`, *optional*, defaults to 1):
+ The number of images in a batch.
+ padding_mask_crop (`int`, *optional*, defaults to `None`):
+                The size of the margin in the crop to be applied to the image and masking. If `None`, no crop is applied to image and mask_image. If
+                `padding_mask_crop` is not `None`, it will first find a rectangular region with the same aspect ratio as the image that
+                contains the whole masked area, and then expand that area based on `padding_mask_crop`. The image and mask_image will then be cropped based on
+                the expanded area before resizing to the original image size for inpainting. This is useful when the masked area is small while the image is large
+                and contains information irrelevant for inpainting, such as background.
+ strength (`float`, *optional*, defaults to 0.9999):
+ Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be
+ between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the
+ `strength`. The number of denoising steps depends on the amount of noise initially added. When
+ `strength` is 1, added noise will be maximum and the denoising process will run for the full number of
+ iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores the masked
+ portion of the reference `image`. Note that in the case of `denoising_start` being declared as an
+ integer, the value of `strength` will be ignored.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ denoising_start (`float`, *optional*):
+ When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be
+ bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and
+ it is assumed that the passed `image` is a partly denoised image. Note that when this is specified,
+ strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline
+ is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refining the Image
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
+ denoising_end (`float`, *optional*):
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
+ still retain a substantial amount of noise (ca. final 20% of timesteps still needed) and should be
+ denoised by a successor pipeline that has `denoising_start` set to 0.8 so that it only denoises the
+ final 20% of the scheduler. The denoising_end parameter should ideally be utilized when this pipeline
+ forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+            ip_adapter_image (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
+ explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
+ not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ To negatively condition the generation process based on a specific image resolution. Part of SDXL's
+ micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+                To negatively condition the generation process based on specific crop coordinates. Part of SDXL's
+ micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+                To negatively condition the generation process based on a target image resolution. It should be the same
+                as the `target_size` in most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ aesthetic_score (`float`, *optional*, defaults to 6.0):
+ Used to simulate an aesthetic score of the generated image by influencing the positive text condition.
+ Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ negative_aesthetic_score (`float`, *optional*, defaults to 2.5):
+ Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to
+ simulate an aesthetic score of the generated image by influencing the negative text condition.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
+                `tuple`. When returning a tuple, the first element is a list with the generated images.
+ """
+
+ callback = kwargs.pop("callback", None)
+ callback_steps = kwargs.pop("callback_steps", None)
+
+ if callback is not None:
+ deprecate(
+ "callback",
+ "1.0.0",
+                "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+ )
+ if callback_steps is not None:
+ deprecate(
+ "callback_steps",
+ "1.0.0",
+                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+ )
+
+ with torch.autocast(device_type="hpu", dtype=torch.bfloat16, enabled=self.gaudi_config.use_torch_autocast):
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ # 1. Check inputs
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ image,
+ mask_image,
+ height,
+ width,
+ strength,
+ callback_steps,
+ output_type,
+ negative_prompt,
+ negative_prompt_2,
+ prompt_embeds,
+ negative_prompt_embeds,
+ callback_on_step_end_tensor_inputs,
+ padding_mask_crop,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._guidance_rescale = guidance_rescale
+ self._clip_skip = clip_skip
+ self._cross_attention_kwargs = cross_attention_kwargs
+ self._denoising_end = denoising_end
+ self._denoising_start = denoising_start
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ num_prompts = 1
+ elif prompt is not None and isinstance(prompt, list):
+ num_prompts = len(prompt)
+ else:
+ num_prompts = prompt_embeds.shape[0]
+
+ num_batches = ceil((num_images_per_prompt * num_prompts) / batch_size)
+
+ logger.info(
+ f"{num_prompts} prompt(s) received, {num_images_per_prompt} generation(s) per prompt,"
+ f" {batch_size} sample(s) per batch, {num_batches} total batch(es)."
+ )
+ if num_batches < 3:
+ logger.warning("The first two iterations are slower so it is recommended to feed more batches.")
+
+ device = self._execution_device
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+ )
+
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ clip_skip=self.clip_skip,
+ )
+
+ # 4. set timesteps
+ def denoising_value_valid(dnv):
+                return isinstance(dnv, float) and 0 < dnv < 1
+
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ timesteps, num_inference_steps = self.get_timesteps(
+ num_inference_steps,
+ strength,
+ device,
+                denoising_start=self.denoising_start if denoising_value_valid(self.denoising_start) else None,
+ )
+
+ # check that number of inference steps is not < 1 - as this doesn't make sense
+ if num_inference_steps < 1:
+ raise ValueError(
+ f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
+                    f" steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
+ )
+ # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
+ latent_timestep = timesteps[:1].repeat(num_prompts * num_images_per_prompt)
+ # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
+ is_strength_max = strength == 1.0
+
+ # 5. Preprocess mask and image
+ if padding_mask_crop is not None:
+ crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
+ resize_mode = "fill"
+ else:
+ crops_coords = None
+ resize_mode = "default"
+
+ original_image = image
+ init_image = self.image_processor.preprocess(
+ image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
+ )
+ init_image = init_image.to(dtype=torch.float32)
+
+ mask = self.mask_processor.preprocess(
+ mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
+ )
+
+ if masked_image_latents is not None:
+ masked_image = masked_image_latents
+ elif init_image.shape[1] == 4:
+ # if images are in latent space, we can't mask it
+ masked_image = None
+ else:
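+                # Zero out the regions to repaint (white mask pixels) and keep the rest of the original image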
+ masked_image = init_image * (mask < 0.5)
+
+ # 6. Prepare latent variables
+ num_channels_latents = self.vae.config.latent_channels
+ num_channels_unet = self.unet.config.in_channels
+ return_image_latents = num_channels_unet == 4
+
+ add_noise = True if self.denoising_start is None else False
+ latents_outputs = self.prepare_latents(
+ num_prompts * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ image=init_image,
+ timestep=latent_timestep,
+ is_strength_max=is_strength_max,
+ add_noise=add_noise,
+ return_noise=True,
+ return_image_latents=return_image_latents,
+ )
+ image_latents = None
+ if return_image_latents:
+ latents, noise, image_latents = latents_outputs
+ else:
+ latents, noise = latents_outputs
+
+ # 7. Prepare mask latent variables
+ mask, masked_image_latents = self.prepare_mask_latents(
+ mask,
+ masked_image,
+ batch_size,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ self.do_classifier_free_guidance,
+ )
+
+ # 8. Check that sizes of mask, masked image and latents match
+ if num_channels_unet == 9:
+ # default case for runwayml/stable-diffusion-inpainting
+ num_channels_mask = mask.shape[1]
+ num_channels_masked_image = masked_image_latents.shape[1]
+ if (
+ num_channels_latents + num_channels_mask + num_channels_masked_image
+ != self.unet.config.in_channels
+ ):
+ raise ValueError(
+ f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
+ f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
+ f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
+ f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ " `pipeline.unet` or your `mask_image` or `image` input."
+ )
+ elif num_channels_unet != 4:
+ raise ValueError(
+ f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}."
+ )
+ # 8.1 Prepare extra step kwargs.
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+            # 9. Compute the target height and width from the latents
+ height, width = latents.shape[-2:]
+ height = height * self.vae_scale_factor
+ width = width * self.vae_scale_factor
+
+ original_size = original_size or (height, width)
+ target_size = target_size or (height, width)
+
+ # 10. Prepare added time ids & embeddings
+ if negative_original_size is None:
+ negative_original_size = original_size
+ if negative_target_size is None:
+ negative_target_size = target_size
+
+ add_text_embeds = pooled_prompt_embeds
+ if self.text_encoder_2 is None:
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
+ else:
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
+
+ add_time_ids, add_neg_time_ids = self._get_add_time_ids(
+ original_size,
+ crops_coords_top_left,
+ target_size,
+ aesthetic_score,
+ negative_aesthetic_score,
+ negative_original_size,
+ negative_crops_coords_top_left,
+ negative_target_size,
+ dtype=prompt_embeds.dtype,
+ text_encoder_projection_dim=text_encoder_projection_dim,
+ )
+ add_time_ids = add_time_ids.repeat(batch_size, 1)
+ if self.do_classifier_free_guidance:
+ add_neg_time_ids = add_neg_time_ids.repeat(batch_size, 1)
+ add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
+
+ prompt_embeds = prompt_embeds.to(device)
+ add_time_ids = add_time_ids.to(device)
+ add_neg_time_ids = add_neg_time_ids.to(device)
+ image_embeds = []
+ if ip_adapter_image is not None:
+ image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image,
+ device,
+ batch_size,
+ )
+
+ # 11 Split into batches (HPU-specific step)
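+            # Keeping a constant batch size for every forward pass avoids dynamic shapes and recompilations on HPU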
+ latents_batches, num_dummy_samples = self._split_and_cat_tensors(batch_size, latents)
+ prompt_embeds_batches, _ = self._split_and_cat_tensors(
+ batch_size, prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance
+ )
+ noise_batches, _ = self._split_and_cat_tensors(batch_size, noise)
+ add_text_embeds_batches, _ = self._split_and_cat_tensors(
+ batch_size, add_text_embeds, negative_pooled_prompt_embeds, self.do_classifier_free_guidance
+ )
+ mask_batches, _ = self._split_and_cat_tensors(batch_size, mask)
+ masked_image_latents_batches, _ = self._split_and_cat_tensors(batch_size, masked_image_latents)
+ image_latents_batches, _ = self._split_and_cat_tensors(batch_size, image_latents)
+
+ # 12. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ throughput_warmup_steps = kwargs.get("throughput_warmup_steps", 0)
+
+ if (
+ self.denoising_end is not None
+ and self.denoising_start is not None
+ and denoising_value_valid(self.denoising_end)
+ and denoising_value_valid(self.denoising_start)
+ and self.denoising_start >= self.denoising_end
+ ):
+ raise ValueError(
+ f"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: "
+ + f" {self.denoising_end} when using type float."
+ )
+ elif self.denoising_end is not None and denoising_value_valid(self.denoising_end):
+ discrete_timestep_cutoff = int(
+ round(
+ self.scheduler.config.num_train_timesteps
+ - (self.denoising_end * self.scheduler.config.num_train_timesteps)
+ )
+ )
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
+ timesteps = timesteps[:num_inference_steps]
+
+ # 12.1 Optionally get Guidance Scale Embedding
+ timestep_cond = None
+ if self.unet.config.time_cond_proj_dim is not None:
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size)
+ timestep_cond = self.get_guidance_scale_embedding(
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+ ).to(device=device, dtype=latents.dtype)
+
+ self._num_timesteps = len(timesteps)
+
+ outputs = {
+ "images": [],
+ }
+ t0 = time.time()
+ t1 = t0
+ throughput_warmup_steps = kwargs.get("throughput_warmup_steps", 3)
+ use_warmup_inference_steps = (
+ num_batches < throughput_warmup_steps and num_inference_steps > throughput_warmup_steps
+ )
+
+ for j in self.progress_bar(range(num_batches)):
+ # The throughput is calculated from the 3rd iteration
+ # because compilation occurs in the first two iterations
+ if j == throughput_warmup_steps:
+ t1 = time.time()
+ if use_warmup_inference_steps:
+ t0_inf = time.time()
+
+ latents_batch = latents_batches[0]
+ latents_batches = torch.roll(latents_batches, shifts=-1, dims=0)
+ noise_batch = noise_batches[0]
+ noise_batches = torch.roll(noise_batches, shifts=-1, dims=0)
+ prompt_embeds_batch = prompt_embeds_batches[0]
+ prompt_embeds_batches = torch.roll(prompt_embeds_batches, shifts=-1, dims=0)
+ add_text_embeds_batch = add_text_embeds_batches[0]
+ add_text_embeds_batches = torch.roll(add_text_embeds_batches, shifts=-1, dims=0)
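+                # The added time ids are identical for every batch, so they are reused as-is instead of being split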
+ add_time_ids_batch = add_time_ids
+ mask_batch = mask_batches[0]
+ mask_batches = torch.roll(mask_batches, shifts=-1, dims=0)
+ if masked_image_latents_batches is not None:
+ masked_image_latents_batch = masked_image_latents_batches[0]
+ masked_image_latents_batches = torch.roll(masked_image_latents_batches, shifts=-1, dims=0)
+
+ if image_latents_batches is not None:
+ image_latents_batch = image_latents_batches[0]
+ image_latents_batches = torch.roll(image_latents_batches, shifts=-1, dims=0)
+
+                # If using a non-Gaudi diffusers scheduler, reset its step index for every batch to avoid an index overflow on the timesteps
+ if j > 0 and "Gaudi" not in self.scheduler.__class__.__name__:
+ self.scheduler._init_step_index(timesteps[0])
+
+ for i, _ in enumerate(timesteps):
+ if use_warmup_inference_steps and i == throughput_warmup_steps:
+ t1_inf = time.time()
+ t1 += t1_inf - t0_inf
+
+ if self.interrupt:
+ continue
+
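+                    # Read the current timestep at index 0 and roll the tensor so that indexing stays static across iterations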
+ t = timesteps[0]
+ timesteps = torch.roll(timesteps, shifts=-1, dims=0)
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = (
+ torch.cat([latents_batch] * 2) if self.do_classifier_free_guidance else latents_batch
+ )
+ # concat latents, mask, masked_image_latents in the channel dimension
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+ mask_batch_input = torch.cat([mask_batch] * 2) if self.do_classifier_free_guidance else mask_batch
+
+ if num_channels_unet == 9:
+ masked_image_latents_batch_input = (
+ torch.cat([masked_image_latents_batch] * 2)
+ if self.do_classifier_free_guidance
+ else masked_image_latents_batch
+ )
+ latent_model_input = torch.cat(
+ [latent_model_input, mask_batch_input, masked_image_latents_batch_input], dim=1
+ )
+
+ # predict the noise residual
+ added_cond_kwargs = {"text_embeds": add_text_embeds_batch, "time_ids": add_time_ids_batch}
+ if ip_adapter_image is not None:
+ added_cond_kwargs["image_embeds"] = image_embeds
+
+ noise_pred = self.unet_hpu(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds_batch,
+ timestep_cond=timestep_cond,
+ cross_attention_kwargs=self.cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )
+
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(
+ noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale
+ )
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_batch = self.scheduler.step(
+ noise_pred, t, latents_batch, **extra_step_kwargs, return_dict=False
+ )[0]
+ if not self.use_hpu_graphs:
+ self.htcore.mark_step()
+
+ if num_channels_unet == 4:
+ init_latents_proper = image_latents_batch
+ if self.do_classifier_free_guidance:
+ init_mask, _ = mask_batch_input.chunk(2)
+ else:
+ init_mask = mask_batch
+ if i < len(timesteps) - 1:
+ noise_timestep = timesteps[1]
+ init_latents_proper = self.scheduler.add_noise(
+ init_latents_proper, noise_batch, torch.tensor([noise_timestep])
+ )
+
+ latents_batch = (1 - init_mask) * init_latents_proper + init_mask * latents_batch
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ k_batch = k + "_batch"
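+                            # Resolve the batched local variable corresponding to the requested tensor name (e.g. "latents" -> latents_batch)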
+ callback_kwargs[k] = locals()[k_batch]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents_batch = callback_outputs.pop("latents", latents_batch)
+ prompt_embeds_batch = callback_outputs.pop("prompt_embeds", prompt_embeds_batch)
+ add_text_embeds_batch = callback_outputs.pop("add_text_embeds", add_text_embeds_batch)
+ add_time_ids_batch = callback_outputs.pop("add_time_ids", add_time_ids_batch)
+ mask_batch = callback_outputs.pop("mask", mask_batch)
+ masked_image_latents_batch = callback_outputs.pop(
+ "masked_image_latents", masked_image_latents_batch
+ )
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ if use_warmup_inference_steps:
+ t1 = warmup_inference_steps_time_adjustment(
+ t1, t1_inf, num_inference_steps, throughput_warmup_steps
+ )
+
+ if not output_type == "latent":
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ needs_upcasting = self.vae.dtype == torch.bfloat16 and self.vae.config.force_upcast
+
+ if needs_upcasting:
+ self.upcast_vae()
+ latents_batch = latents_batch.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+
+ image = self.vae.decode(latents_batch / self.vae.config.scaling_factor, return_dict=False)[0]
+
+ # cast back to fp16 if needed
+ if needs_upcasting:
+ self.vae.to(dtype=torch.bfloat16)
+ else:
+ image = latents_batch
+
+ outputs["images"].append(image)
+
+ if not self.use_hpu_graphs:
+ self.htcore.mark_step()
+
+ # Remove dummy generations if needed
+ if num_dummy_samples > 0:
+ outputs["images"][-1] = outputs["images"][-1][:-num_dummy_samples]
+
+ speed_metrics_prefix = "inpainting"
+ speed_measures = speed_metrics(
+ split=speed_metrics_prefix,
+ start_time=t0,
+ num_samples=num_batches * batch_size
+ if t1 == t0 or use_warmup_inference_steps
+ else (num_batches - throughput_warmup_steps) * batch_size,
+ num_steps=num_batches,
+ start_time_after_warmup=t1,
+ )
+ logger.info(f"Speed metrics: {speed_measures}")
+
+ # Process generated images
+ for i, image in enumerate(outputs["images"][:]):
+ if i == 0:
+ outputs["images"].clear()
+
+ if not output_type == "latent":
+ # apply watermark if available
+ if self.watermark is not None:
+ image = self.watermark.apply_watermark(image)
+
+ image = self.image_processor.postprocess(image, output_type=output_type)
+ if padding_mask_crop is not None:
+ image = [
+ self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image
+ ]
+
+ if output_type == "pil" and isinstance(image, list):
+ outputs["images"] += image
+ elif output_type in ["np", "numpy"] and isinstance(image, numpy.ndarray):
+ if len(outputs["images"]) == 0:
+ outputs["images"] = image
+ else:
+ outputs["images"] = numpy.concatenate((outputs["images"], image), axis=0)
+ else:
+ if len(outputs["images"]) == 0:
+ outputs["images"] = image
+ else:
+ outputs["images"] = torch.cat((outputs["images"], image), 0)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+ if not return_dict:
+ return (outputs["images"],)
+
+ return GaudiStableDiffusionXLPipelineOutput(
+ images=outputs["images"], throughput=speed_measures[f"{speed_metrics_prefix}_samples_per_second"]
+ )
+
+ @torch.no_grad()
+ def unet_hpu(
+ self,
+ latent_model_input,
+ timestep,
+ encoder_hidden_states,
+ timestep_cond,
+ cross_attention_kwargs,
+ added_cond_kwargs,
+ return_dict=False,
+ ):
+ if self.use_hpu_graphs:
+ return self.capture_replay(
+ latent_model_input,
+ timestep,
+ encoder_hidden_states,
+ timestep_cond,
+ cross_attention_kwargs,
+ added_cond_kwargs,
+ )
+ else:
+ return self.unet(
+ latent_model_input,
+ timestep,
+ encoder_hidden_states=encoder_hidden_states,
+ timestep_cond=timestep_cond,
+ cross_attention_kwargs=cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ @torch.no_grad()
+ def capture_replay(
+ self,
+ latent_model_input,
+ timestep,
+ encoder_hidden_states,
+ timestep_cond,
+ cross_attention_kwargs,
+ added_cond_kwargs,
+ ):
+ inputs = [
+ latent_model_input,
+ timestep,
+ encoder_hidden_states,
+ timestep_cond,
+ cross_attention_kwargs,
+ added_cond_kwargs,
+ ]
+ h = self.ht.hpu.graphs.input_hash(inputs)
+ cached = self.cache.get(h)
+
+ if cached is None:
+ # Capture the graph and cache it
+ with self.ht.hpu.stream(self.hpu_stream):
+ graph = self.ht.hpu.HPUGraph()
+ graph.capture_begin()
+
+ outputs = self.unet(
+ sample=inputs[0],
+ timestep=inputs[1],
+ encoder_hidden_states=inputs[2],
+ timestep_cond=inputs[3],
+ cross_attention_kwargs=inputs[4],
+ added_cond_kwargs=inputs[5],
+ return_dict=False,
+ )[0]
+
+ graph.capture_end()
+ graph_inputs = inputs
+ graph_outputs = outputs
+ self.cache[h] = self.ht.hpu.graphs.CachedParams(graph_inputs, graph_outputs, graph)
+ return outputs
+
+ # Replay the cached graph with updated inputs
+ self.ht.hpu.graphs.copy_to(cached.graph_inputs, inputs)
+ cached.graph.replay()
+ self.ht.core.hpu.default_stream().synchronize()
+
+ return cached.graph_outputs
diff --git a/optimum/habana/peft/__init__.py b/optimum/habana/peft/__init__.py
index 3030f74e89..912681ac90 100644
--- a/optimum/habana/peft/__init__.py
+++ b/optimum/habana/peft/__init__.py
@@ -1,2 +1,6 @@
-from .layer import GaudiAdaloraLayerSVDLinearForward
+from .layer import (
+ GaudiAdaloraLayerSVDLinearForward,
+ GaudiAdaptedAttention_getattr,
+ GaudiAdaptedAttentionPreAttnForward,
+)
from .peft_model import gaudi_generate, gaudi_prepare_inputs_for_generation
diff --git a/optimum/habana/peft/layer.py b/optimum/habana/peft/layer.py
index dacafb1155..e3a650d310 100755
--- a/optimum/habana/peft/layer.py
+++ b/optimum/habana/peft/layer.py
@@ -1,7 +1,11 @@
+import inspect
+import math
from typing import Any
import torch
-from peft.utils.other import transpose
+import torch.nn.functional as F
+from peft.tuners.adaption_prompt.config import TRANSFORMERS_MODEL_CONFIG
+from peft.tuners.adaption_prompt.utils import llama_apply_rotary_pos_emb, llama_rotate_half
def GaudiAdaloraLayerSVDLinearForward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
@@ -34,54 +38,140 @@ def GaudiAdaloraLayerSVDLinearForward(self, x: torch.Tensor, *args: Any, **kwarg
return result
-class LoRALinear:
- def __init__(self, module):
- has_bias = module.bias is not None
- self.module = module
- import habana_frameworks.torch.hpex.experimental.transformer_engine as te
-
- self.module.te_linear = te.Linear(
- module.in_features,
- module.out_features,
- bias=has_bias,
- params_dtype=module.weight.dtype,
- skip_weight_param_allocation=True,
- )
-
- def _linear(self, input: torch.Tensor) -> torch.Tensor:
- # TODO: to check if bias is removed from lora linear
- if hasattr(self.module, "bias"):
- return self.module.te_linear(
- input, transpose(self.module.weight, self.module.fan_in_fan_out), bias=self.module.bias
- )
+def compute_query_states(model: torch.nn.Module, **kwargs) -> torch.Tensor:
+ """
+ Copied from https://github.com/huggingface/peft/blob/v0.10.0/src/peft/tuners/adaption_prompt/utils.py#L60
+ The only differences are:
+    - add reuse cache support
+    - add past key value list support
+ """
+ hidden_states = kwargs.get("hidden_states")
+ position_ids = kwargs.get("position_ids")
+ past_key_value = kwargs.get("past_key_value")
+ bsz, q_len, _ = hidden_states.size()
+ query_states = model.q_proj(hidden_states).view(bsz, q_len, model.num_heads, model.head_dim).transpose(1, 2)
+
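+    # Ratio between query heads and key/value heads (> 1 for grouped-query attention models such as Mistral)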
+ factor = model.k_proj.in_features // model.k_proj.out_features
+ value_states = (
+ model.v_proj(hidden_states).view(bsz, q_len, (model.num_heads // factor), model.head_dim).transpose(1, 2)
+ )
+
+ seq_len = q_len
+
+ if past_key_value is not None:
+ if kwargs.get("reuse_cache", False):
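+            # With reuse_cache, the cache entry stores shape information instead of tensors, so the cached length is read from it directly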
+ seq_len += past_key_value[0][-2]
+ elif isinstance(past_key_value, tuple) or isinstance(past_key_value, list):
+ # for transformers <= 4.35
+ seq_len += past_key_value[0].shape[-2]
else:
- return self.module.te_linear(input, transpose(self.module.weight, self.module.fan_in_fan_out))
+ # since transformers 4.36, this is a DynamicCache instance
+ seq_len += past_key_value.get_seq_length(model.layer_idx)
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- previous_dtype = x.dtype
+    # For transformers > 4.37.2 `position_ids` became a required argument in the rotary embedding's forward pass.
+ if "position_ids" not in inspect.signature(model.rotary_emb.forward).parameters:
+ # TODO we assume that position_ids is not None here, not sure if that is safe but the old code also did that
+ cos, sin = model.rotary_emb(value_states, seq_len=seq_len)
+ return llama_apply_rotary_pos_emb(query_states, cos, sin, position_ids)
- if self.module.disable_adapters:
- if self.module.merged:
- self.module.unmerge()
- result = self._linear(x)
- elif self.module.merged:
- result = self._linear(x)
+ past_seen_tokens = 0
+ if position_ids is None:
+ # Compute position_ids, since they are required for transformers > 4.37.2
+ if past_key_value is None:
+ new_cache_positions = torch.arange(q_len, q_len + q_len, device=value_states.device)
else:
- result = self._linear(x)
- for active_adapter in self.module.active_adapters:
- if active_adapter not in self.module.lora_A.keys():
- continue
- lora_A = self.module.lora_A[active_adapter]
- lora_B = self.module.lora_B[active_adapter]
- dropout = self.module.lora_dropout[active_adapter]
- scaling = self.module.scaling[active_adapter]
- x = x.to(lora_A.weight.dtype)
- result = result.clone() + lora_B(lora_A(dropout(x))) * scaling
-
- result = result.to(previous_dtype)
- return result
-
- @staticmethod
- def replace_forward(module):
- lora_linear = LoRALinear(module)
- module.forward = lora_linear.forward
+ past_seen_tokens = past_key_value.get_usable_length(q_len, model.layer_idx)
+ new_cache_positions = torch.arange(past_seen_tokens, past_seen_tokens + q_len, device=value_states.device)
+ position_ids = new_cache_positions.unsqueeze(0)
+
+ rotary_emb_kwargs = {"position_ids": position_ids}
+ # The `seq_len` argument has been officially removed in transformers >= 4.39.0
+ if "seq_len" in inspect.signature(model.rotary_emb.forward).parameters:
+ rotary_emb_kwargs["seq_len"] = q_len + past_seen_tokens
+
+ cos, sin = model.rotary_emb(value_states, **rotary_emb_kwargs)
+
+ # For batched inference unsqueeze it on the correct dim
+ # since: https://github.com/huggingface/transformers/pull/29109
+ if len(cos.shape) == 3:
+ cos = cos.unsqueeze(1)
+ sin = sin.unsqueeze(1)
+
+ return (query_states * cos) + (llama_rotate_half(query_states) * sin)
+
+
+def GaudiAdaptedAttentionPreAttnForward(self, *args, **kwargs):
+ """
+ Copied from AdaptedAttention.forward: https://github.com/huggingface/peft/blob/v0.10.0/src/peft/tuners/adaption_prompt/layer.py#L57
+ The only differences are:
+ - replace self.model() with self.model.pre_attn_forward()
+ """
+ if kwargs.get("output_attention", False):
+ raise NotImplementedError("output_attention is not currently supported.")
+
+ output, _, past_key_value = self.model.pre_attn_forward(*args, **kwargs)
+ bsz = output.shape[0]
+ q_len = output.shape[1]
+ embed_dim = output.shape[2]
+ k_proj_layer = TRANSFORMERS_MODEL_CONFIG[self.model_type].k_proj_layer
+ v_proj_layer = TRANSFORMERS_MODEL_CONFIG[self.model_type].v_proj_layer
+ o_proj_layer = TRANSFORMERS_MODEL_CONFIG[self.model_type].o_proj_layer
+ factor = (
+ self.model.k_proj.in_features // self.model.k_proj.out_features
+ ) # Mistral has different input and output dimension for k_proj and v_proj layers
+
+ if k_proj_layer == v_proj_layer:
+ _, key, value = getattr(self.model, k_proj_layer)(self.adaption_prompt).split(embed_dim, dim=2)
+ else:
+ key = getattr(self.model, k_proj_layer)(self.adaption_prompt)
+ value = getattr(self.model, v_proj_layer)(self.adaption_prompt)
+
+ # (bsz, num_key_value_heads, adapter_len, head_dim)
+ adapter_k = (
+ key.view(1, self.adapter_len, (self.model.num_heads // factor), self.model.head_dim)
+ .repeat(bsz, 1, 1, 1)
+ .transpose(1, 2)
+ )
+ adapter_v = (
+ value.view(1, self.adapter_len, (self.model.num_heads // factor), self.model.head_dim)
+ .repeat(bsz, 1, 1, 1)
+ .transpose(1, 2)
+ )
+ # Below is taken from https://github.com/huggingface/transformers/blob/e547458c43dfdbbb8f6a7757237e234c44e20a8f/src/transformers/models/mistral/modeling_mistral.py#L181
+ # (bsz, num_heads, adapter_len, head_dim)
+ adapter_k = torch.repeat_interleave(adapter_k, repeats=factor, dim=1)
+ adapter_v = torch.repeat_interleave(adapter_v, repeats=factor, dim=1)
+ # Recompute query states.
+ # (bsz, num_heads, q_len, head_dim)
+ query_states = compute_query_states(model=self.model, **kwargs)
+
+ previous_dtype = query_states.dtype
+
+ # (bsz, num_heads, q_len, adapter_len)
+ scores = torch.matmul(query_states, adapter_k.transpose(2, 3).to(previous_dtype)) / math.sqrt(self.model.head_dim)
+ # Upcast attention to fp32
+ # (bsz, num_heads, q_len, adapter_len)
+ scores = self.adaption_gate * F.softmax(scores, dim=-1, dtype=torch.float32).to(previous_dtype)
+ # (bsz, q_len, num_heads * head_dim)
+ adapter_output = torch.matmul(scores, adapter_v).transpose(1, 2).reshape(bsz, q_len, -1)
+
+ # (bsz, q_len, hidden_size)
+ if o_proj_layer is not None:
+ adapter_output = getattr(self.model, o_proj_layer)(adapter_output)
+
+ # Add adaption prompt output to original output.
+ output = output + adapter_output
+
+ # Restore original dtype.
+ output = output.to(previous_dtype)
+ return output, None, past_key_value
+
+
+def GaudiAdaptedAttention_getattr(self, name: str):
+ """Forward missing attributes to the wrapped module."""
+ try:
+ return super(self.__class__, self).__getattr__(name)
+ except AttributeError:
+ # This is necessary as e.g. causal models have various methods that we
+ # don't want to re-implement here.
+ return getattr(self.model, name)
diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py
index c858ef1e27..a43ebf6375 100755
--- a/optimum/habana/transformers/generation/utils.py
+++ b/optimum/habana/transformers/generation/utils.py
@@ -97,6 +97,7 @@
"llava",
"llava_next",
"stablelm",
+ "mamba",
]
@@ -1863,16 +1864,6 @@ def _greedy_search(
input_ids, scores, token_idx=cur_len, ignore_eos=ignore_eos, eos_token_id=eos_token_id
)
this_peer_finished = unfinished_sequences.max() == 0
-
- if (
- not model_kwargs.get("pad_done", False)
- and not model_kwargs.get("reuse_cache", False)
- and bucket_internal
- ):
- # Pad the returned pask key values tensors from prefill phase forward run to maximum length
- # before starting the decode phase.
- self._pad_past_key_values(model_kwargs)
- model_kwargs["pad_done"] = True
hb_profer.step()
if hb_gen_time is not None:
if not time_to_first_token_done:
@@ -1882,6 +1873,17 @@ def _greedy_search(
torch_hpu.synchronize()
hb_gen_time.step()
+ if (
+ not model_kwargs.get("pad_done", False)
+ and not model_kwargs.get("reuse_cache", False)
+ and bucket_internal
+ ):
+                # Pad the past key value tensors returned by the prefill forward pass to the maximum length
+                # before starting the decode phase.
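+                # Pad only once, right after the prefill step, i.e. while the KV length still equals the prompt length.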
+ if outputs.past_key_values[0][0].shape[2] == model_inputs["input_ids"].shape[1]:
+ self._pad_past_key_values(model_kwargs)
+ model_kwargs["pad_done"] = True
+
if (
model_kwargs.get("use_hpu_graphs", False)
and model_kwargs.get("limit_hpu_graphs", False)
@@ -2282,17 +2284,6 @@ def _sample(
input_ids, scores, token_idx=cur_len, ignore_eos=ignore_eos, eos_token_id=eos_token_id
)
this_peer_finished = unfinished_sequences.max() == 0
-
- if (
- not model_kwargs.get("pad_done", False)
- and not model_kwargs.get("reuse_cache", False)
- and bucket_internal
- ):
- # Pad the returned pask key values tensors from prefill phase forward run to maximum length
- # before starting the decode phase.
- self._pad_past_key_values(model_kwargs)
- model_kwargs["pad_done"] = True
-
hb_profer.step()
if hb_gen_time is not None:
if not time_to_first_token_done:
@@ -2302,6 +2293,17 @@ def _sample(
torch_hpu.synchronize()
hb_gen_time.step()
+ if (
+ not model_kwargs.get("pad_done", False)
+ and not model_kwargs.get("reuse_cache", False)
+ and bucket_internal
+ ):
+                # Pad the past key value tensors returned by the prefill forward pass to the maximum length
+                # before starting the decode phase.
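+                # Pad only once, right after the prefill step, i.e. while the KV length still equals the prompt length.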
+ if outputs.past_key_values[0][0].shape[2] == model_inputs["input_ids"].shape[1]:
+ self._pad_past_key_values(model_kwargs)
+ model_kwargs["pad_done"] = True
+
if (
model_kwargs.get("use_hpu_graphs", False)
and model_kwargs.get("limit_hpu_graphs", False)
@@ -2803,12 +2805,19 @@ def expand_if_needed(tensor, new_size, value, dim=-1):
next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
if self.generation_config.static_shapes:
beam_scores = next_token_scores.flatten()
- static_beam_indices = next_indices.flatten()
+ next_indices_flattened = next_indices.flatten()
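+            # Offset each batch's beam indices by batch_idx * num_beams so they index into the flattened (batch * beam) dimension.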
+ static_beam_indices = (
+ next_indices_flattened
+ + torch.tensor(
+ [[batch_idx * num_beams] * next_indices.shape[1] for batch_idx in range(batch_size)],
+ device=next_indices.device,
+ ).flatten()
+ )
beam_tokens = next_tokens.remainder(vocab_size).flatten()
beam_trace_scores.index_copy_(0, beam_trace_idx, beam_scores.unsqueeze(0))
- beam_trace_indices.index_copy_(0, beam_trace_idx, static_beam_indices.unsqueeze(0))
+ beam_trace_indices.index_copy_(0, beam_trace_idx, next_indices_flattened.unsqueeze(0))
beam_trace_tokens.index_copy_(0, beam_trace_idx, beam_tokens.unsqueeze(0))
beam_trace_idx.add_(1)
diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py
index 8d67911886..3cc6b83abf 100644
--- a/optimum/habana/transformers/modeling_utils.py
+++ b/optimum/habana/transformers/modeling_utils.py
@@ -28,7 +28,12 @@
from .models import (
GaudiBloomForCausalLM,
GaudiBloomMLP,
+ GaudiCLIPAttention,
+ GaudiCLIPEncoder,
+ GaudiCLIPEncoderLayer,
GaudiCLIPVisionEmbeddings,
+ GaudiCLIPVisionModel,
+ GaudiCLIPVisionTransformer,
GaudiCodeGenAttention,
GaudiCodeGenForCausalLM,
GaudiFalconAttention,
@@ -97,6 +102,7 @@
gaudi_BartForConditionalGeneration_prepare_inputs_for_generation,
gaudi_BartLearnedPositionalEmbedding,
gaudi_BartModel_forward,
+ gaudi_BertModel_forward,
gaudi_BlipForConditionalGeneration_generate,
gaudi_BlipForQuestionAnswering_generate,
gaudi_BlipTextAttention_forward,
@@ -133,6 +139,8 @@
gaudi_gptj_model_forward,
gaudi_invert_attention_mask,
gaudi_llama_rmsnorm_forward,
+ gaudi_MambaForCausalLM_prepare_inputs_for_generation,
+ gaudi_MambaForCausalLM_update_model_kwargs_for_generation,
gaudi_mistral_rmsnorm_forward,
gaudi_mixtral_block_sparse_moe_forward,
gaudi_mixtral_rmsnorm_forward,
@@ -177,7 +185,6 @@
gaudi_unconstrained_rational_quadratic_spline,
gaudi_VisionEncoderDecoderModel_prepare_inputs_for_generation,
gaudi_vit_self_attention_forward,
- gaudi_VitsResidualCouplingLayer_forward,
gaudi_wav2vec2_encoder_forward,
gaudi_wav2vec2_forward,
gaudi_wav2vec2_tdnnlayer_forward,
@@ -287,6 +294,9 @@ def adapt_transformers_to_gaudi():
gaudi_BartForConditionalGeneration_prepare_inputs_for_generation
)
+ # Optimization for BERT on Gaudi
+ transformers.models.bert.modeling_bert.BertModel.forward = gaudi_BertModel_forward
+
# Optimization for codegen generation on Gaudi
transformers.models.codegen.modeling_codegen.CodeGenAttention = GaudiCodeGenAttention
transformers.models.codegen.modeling_codegen.CodeGenForCausalLM = GaudiCodeGenForCausalLM
@@ -370,6 +380,11 @@ def adapt_transformers_to_gaudi():
# Optimization for Clip on Gaudi
transformers.models.clip.modeling_clip.CLIPVisionEmbeddings = GaudiCLIPVisionEmbeddings
+ transformers.models.clip.modeling_clip.CLIPAttention = GaudiCLIPAttention
+ transformers.models.clip.modeling_clip.CLIPEncoderLayer = GaudiCLIPEncoderLayer
+ transformers.models.clip.modeling_clip.CLIPEncoder = GaudiCLIPEncoder
+ transformers.models.clip.modeling_clip.CLIPVisionTransformer = GaudiCLIPVisionTransformer
+ transformers.models.clip.modeling_clip.CLIPVisionModel = GaudiCLIPVisionModel
# Optimization for falcon generation on Gaudi
transformers.models.falcon.modeling_falcon.FalconAttention = GaudiFalconAttention
@@ -490,7 +505,6 @@ def adapt_transformers_to_gaudi():
transformers.models.vits.modeling_vits._unconstrained_rational_quadratic_spline = (
gaudi_unconstrained_rational_quadratic_spline
)
- transformers.models.vits.modeling_vits.VitsResidualCouplingLayer.forward = gaudi_VitsResidualCouplingLayer_forward
# Optimization for starcoder2 on Gaudi
transformers.models.starcoder2.modeling_starcoder2.Starcoder2ForCausalLM = GaudiStarcoder2ForCausalLM
@@ -521,3 +535,11 @@ def adapt_transformers_to_gaudi():
# Tell transformers which Gaudi models support tracing
transformers.utils.fx._SUPPORTED_MODELS += tuple(cls.__name__ for cls in models_with_tracing_support)
+
+ # Optimization for mamba on Gaudi
+ transformers.models.mamba.modeling_mamba.MambaForCausalLM.prepare_inputs_for_generation = (
+ gaudi_MambaForCausalLM_prepare_inputs_for_generation
+ )
+ transformers.models.mamba.modeling_mamba.MambaForCausalLM._update_model_kwargs_for_generation = (
+ gaudi_MambaForCausalLM_update_model_kwargs_for_generation
+ )
diff --git a/optimum/habana/transformers/models/__init__.py b/optimum/habana/transformers/models/__init__.py
index f17018e317..4b6c885482 100644
--- a/optimum/habana/transformers/models/__init__.py
+++ b/optimum/habana/transformers/models/__init__.py
@@ -10,6 +10,7 @@
gaudi_BartLearnedPositionalEmbedding,
gaudi_BartModel_forward,
)
+from .bert import gaudi_BertModel_forward
from .blip import (
gaudi_BlipForConditionalGeneration_generate,
gaudi_BlipForQuestionAnswering_generate,
@@ -30,7 +31,14 @@
gaudi_bloom_convert_to_standard_cache,
gaudi_bloom_model_forward,
)
-from .clip import GaudiCLIPVisionEmbeddings
+from .clip import (
+ GaudiCLIPAttention,
+ GaudiCLIPEncoder,
+ GaudiCLIPEncoderLayer,
+ GaudiCLIPVisionEmbeddings,
+ GaudiCLIPVisionModel,
+ GaudiCLIPVisionTransformer,
+)
from .codegen import (
GaudiCodeGenAttention,
GaudiCodeGenForCausalLM,
@@ -91,6 +99,10 @@
)
from .llava import GaudiLlavaForConditionalGeneration
from .llava_next import GaudiLlavaNextForConditionalGeneration
+from .mamba import (
+ gaudi_MambaForCausalLM_prepare_inputs_for_generation,
+ gaudi_MambaForCausalLM_update_model_kwargs_for_generation,
+)
from .mistral import (
GaudiMistralAttention,
GaudiMistralDecoderLayer,
@@ -193,10 +205,7 @@
gaudi_VisionEncoderDecoderModel_prepare_inputs_for_generation,
)
from .vit import gaudi_vit_self_attention_forward
-from .vits import (
- gaudi_unconstrained_rational_quadratic_spline,
- gaudi_VitsResidualCouplingLayer_forward,
-)
+from .vits import gaudi_unconstrained_rational_quadratic_spline
from .wav2vec2 import (
_gaudi_wav2vec2_compute_mask_indices,
_gaudi_wav2vec2_mask_hidden_states,
diff --git a/optimum/habana/transformers/models/bert/__init__.py b/optimum/habana/transformers/models/bert/__init__.py
new file mode 100644
index 0000000000..2e2b086f9b
--- /dev/null
+++ b/optimum/habana/transformers/models/bert/__init__.py
@@ -0,0 +1 @@
+from .modeling_bert import gaudi_BertModel_forward
diff --git a/optimum/habana/transformers/models/bert/modeling_bert.py b/optimum/habana/transformers/models/bert/modeling_bert.py
new file mode 100644
index 0000000000..b49095ba60
--- /dev/null
+++ b/optimum/habana/transformers/models/bert/modeling_bert.py
@@ -0,0 +1,121 @@
+from typing import List, Optional, Tuple, Union
+
+import torch
+from transformers.models.bert.modeling_bert import BaseModelOutputWithPoolingAndCrossAttentions
+
+
+def gaudi_BertModel_forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
+ r"""
+ Copied from https://github.com/huggingface/transformers/blob/15c74a28294fe9082b81b24efe58df16fed79a9e/src/transformers/models/bert/modeling_bert.py
+ Changes:
+ - Added dtype to allow for bf16 autocast support on HPU
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if self.config.is_decoder:
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ else:
+ use_cache = False
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+ input_shape = input_ids.size()
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ batch_size, seq_length = input_shape
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+ # past_key_values_length
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+
+ if attention_mask is None:
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
+
+ if token_type_ids is None:
+ if hasattr(self.embeddings, "token_type_ids"):
+ buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
+ token_type_ids = buffered_token_type_ids_expanded
+ else:
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
+
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+ # ourselves in which case we just need to make it broadcastable to all heads.
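+    # Build the extended attention mask in the HPU autocast dtype (e.g. bf16) when autocast is enabled.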
+ dtype = torch.hpu.get_autocast_hpu_dtype() if torch.hpu.is_autocast_hpu_enabled() else self.dtype
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, dtype=dtype)
+
+ # If a 2D or 3D attention mask is provided for the cross-attention
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ if self.config.is_decoder and encoder_hidden_states is not None:
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+ if encoder_attention_mask is None:
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+ else:
+ encoder_extended_attention_mask = None
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+ embedding_output = self.embeddings(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ token_type_ids=token_type_ids,
+ inputs_embeds=inputs_embeds,
+ past_key_values_length=past_key_values_length,
+ )
+ encoder_outputs = self.encoder(
+ embedding_output,
+ attention_mask=extended_attention_mask,
+ head_mask=head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_extended_attention_mask,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ sequence_output = encoder_outputs[0]
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+ if not return_dict:
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPoolingAndCrossAttentions(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ past_key_values=encoder_outputs.past_key_values,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ cross_attentions=encoder_outputs.cross_attentions,
+ )
diff --git a/optimum/habana/transformers/models/clip/__init__.py b/optimum/habana/transformers/models/clip/__init__.py
index faa3a3355b..8245f81472 100644
--- a/optimum/habana/transformers/models/clip/__init__.py
+++ b/optimum/habana/transformers/models/clip/__init__.py
@@ -1 +1,8 @@
-from .modeling_clip import GaudiCLIPVisionEmbeddings
+from .modeling_clip import (
+ GaudiCLIPAttention,
+ GaudiCLIPEncoder,
+ GaudiCLIPEncoderLayer,
+ GaudiCLIPVisionEmbeddings,
+ GaudiCLIPVisionModel,
+ GaudiCLIPVisionTransformer,
+)
diff --git a/optimum/habana/transformers/models/clip/modeling_clip.py b/optimum/habana/transformers/models/clip/modeling_clip.py
index 604c878365..99854a3799 100644
--- a/optimum/habana/transformers/models/clip/modeling_clip.py
+++ b/optimum/habana/transformers/models/clip/modeling_clip.py
@@ -1,5 +1,23 @@
+from typing import Optional, Tuple, Union
+
import torch
-from transformers.models.clip.modeling_clip import CLIPVisionEmbeddings
+from torch import nn
+from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
+from transformers.models.clip.modeling_clip import (
+ CLIPAttention,
+ CLIPEncoder,
+ CLIPEncoderLayer,
+ CLIPVisionEmbeddings,
+ CLIPVisionModel,
+ CLIPVisionTransformer,
+)
+
+
+try:
+ from habana_frameworks.torch.hpex.kernels import FusedSDPA
+except ImportError:
+ print("Not using HPU fused scaled dot-product attention kernel.")
+ FusedSDPA = None
class GaudiCLIPVisionEmbeddings(CLIPVisionEmbeddings):
@@ -16,3 +34,302 @@ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
embeddings = embeddings + self.position_embedding(self.position_ids)
return embeddings
+
+
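+# Thin wrapper exposing the HPU FusedSDPA kernel as a torch.nn.Module.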
+class ModuleFusedSDPA(torch.nn.Module):
+ def __init__(self, fusedSDPA):
+ super().__init__()
+ self._hpu_kernel_fsdpa = fusedSDPA
+
+    def forward(self, query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode):
+        return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode)
+
+
+class Matmul(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x, y):
+ return torch.matmul(x, y)
+
+
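+# Softmax module that dispatches to the HPU softmax_fp8 kernel.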
+class Softmax(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x, dim=None, invAttnHead=None):
+ return torch.ops.hpu.softmax_fp8(x, dim, None, None, invAttnHead)
+
+
+class GaudiCLIPAttention(CLIPAttention):
+ def __init__(self, config):
+ super().__init__(config=config)
+ self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA) if FusedSDPA else None
+ self.bmm1 = Matmul()
+ self.bmm2 = Matmul()
+ self.softmax = Softmax()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ causal_attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = False,
+ use_flash_attention: Optional[bool] = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ """
+ Copied from CLIPAttention.forward: https://github.com/huggingface/transformers/blob/ab0f050b42d903f34d6eb97f3f8c0c07f0517ad2/src/transformers/models/clip/modeling_clip.py
+ The only differences are:
+ - add new args use_flash_attention to enable FusedSDPA
+ """
+ bsz, tgt_len, embed_dim = hidden_states.size()
+ attn_weights_reshaped = None
+ # get query proj
+ query_states = self.q_proj(hidden_states) * self.scale
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
+ key_states = key_states.view(*proj_shape)
+ value_states = value_states.view(*proj_shape)
+
+ src_len = key_states.size(1)
+ if FusedSDPA and use_flash_attention:
+ import habana_frameworks.torch.hpu as ht
+
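+            # Enable the recompute variant of the fused SDPA kernel only outside of training.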
+ use_recompute = not self.training
+ with ht.sdp_kernel(enable_recompute=use_recompute):
+ attn_output = self.fused_scaled_dot_product_attention(
+ query_states, key_states, value_states, attention_mask, self.dropout, False, 1, "fast"
+ )
+ else:
+ attn_weights = self.bmm1(query_states, key_states.transpose(1, 2))
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+ raise ValueError(
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ # apply the causal_attention_mask first
+ if causal_attention_mask is not None:
+ if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
+ f" {causal_attention_mask.size()}"
+ )
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+ )
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ attn_weights = self.softmax(attn_weights, dim=-1)
+
+ if output_attentions:
+                # this operation is a bit awkward, but it's required to
+                # make sure that attn_weights keeps its gradient.
+                # In order to do so, attn_weights has to be reshaped
+                # twice and then reused in the following
+ attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
+ else:
+ attn_weights_reshaped = None
+
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+
+ attn_output = self.bmm2(attn_probs, value_states)
+
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+ attn_output = attn_output.transpose(1, 2)
+ attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
+
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, attn_weights_reshaped
+
+
+class GaudiCLIPEncoderLayer(CLIPEncoderLayer):
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ causal_attention_mask: torch.Tensor,
+ output_attentions: Optional[bool] = False,
+ use_flash_attention: Optional[bool] = False,
+ ) -> Tuple[torch.FloatTensor]:
+ """
+ Copied from CLIPEncoderLayer.forward: https://github.com/huggingface/transformers/blob/ab0f050b42d903f34d6eb97f3f8c0c07f0517ad2/src/transformers/models/clip/modeling_clip.py
+ The only differences are:
+ - add new args use_flash_attention
+ """
+ residual = hidden_states
+
+ hidden_states = self.layer_norm1(hidden_states)
+ hidden_states, attn_weights = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ causal_attention_mask=causal_attention_mask,
+ output_attentions=output_attentions,
+ use_flash_attention=use_flash_attention,
+ )
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.layer_norm2(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (attn_weights,)
+
+ return outputs
+
+
+class GaudiCLIPEncoder(CLIPEncoder):
+ def forward(
+ self,
+ inputs_embeds,
+ attention_mask: Optional[torch.Tensor] = None,
+ causal_attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ use_flash_attention: Optional[bool] = False,
+ ) -> Union[Tuple, BaseModelOutput]:
+ """
+ Copied from CLIPEncoder.forward: https://github.com/huggingface/transformers/blob/ab0f050b42d903f34d6eb97f3f8c0c07f0517ad2/src/transformers/models/clip/modeling_clip.py
+ The only differences are:
+ - add new args use_flash_attention
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ encoder_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+
+ hidden_states = inputs_embeds
+ for idx, encoder_layer in enumerate(self.layers):
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ encoder_layer.__call__,
+ hidden_states,
+ attention_mask,
+ causal_attention_mask,
+ output_attentions,
+ use_flash_attention=use_flash_attention,
+ )
+ else:
+ layer_outputs = encoder_layer(
+ hidden_states,
+ attention_mask,
+ causal_attention_mask,
+ output_attentions=output_attentions,
+ use_flash_attention=use_flash_attention,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
+ )
+
+
+class GaudiCLIPVisionTransformer(CLIPVisionTransformer):
+ def forward(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ use_flash_attention: Optional[bool] = False,
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
+ """
+ Copied from CLIPVisionTransformer.forward: https://github.com/huggingface/transformers/blob/ab0f050b42d903f34d6eb97f3f8c0c07f0517ad2/src/transformers/models/clip/modeling_clip.py
+ The only differences are:
+ - add new args use_flash_attention
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ hidden_states = self.embeddings(pixel_values)
+ hidden_states = self.pre_layrnorm(hidden_states)
+
+ encoder_outputs = self.encoder(
+ inputs_embeds=hidden_states,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ use_flash_attention=use_flash_attention,
+ )
+
+ last_hidden_state = encoder_outputs[0]
+ pooled_output = last_hidden_state[:, 0, :]
+ pooled_output = self.post_layernorm(pooled_output)
+
+ if not return_dict:
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=last_hidden_state,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+class GaudiCLIPVisionModel(CLIPVisionModel):
+ def forward(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ use_flash_attention: Optional[bool] = False,
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
+ """
+ Copied from CLIPVisionModel.forward: https://github.com/huggingface/transformers/blob/ab0f050b42d903f34d6eb97f3f8c0c07f0517ad2/src/transformers/models/clip/modeling_clip.py
+ The only differences are:
+ - add new args use_flash_attention
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ return self.vision_model(
+ pixel_values=pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ use_flash_attention=use_flash_attention,
+ )
diff --git a/optimum/habana/transformers/models/gpt_neox/modeling_gpt_neox.py b/optimum/habana/transformers/models/gpt_neox/modeling_gpt_neox.py
index b43a5c3237..d4614fe959 100644
--- a/optimum/habana/transformers/models/gpt_neox/modeling_gpt_neox.py
+++ b/optimum/habana/transformers/models/gpt_neox/modeling_gpt_neox.py
@@ -31,11 +31,6 @@ def gaudi_gpt_neox_attention_forward(
- add new args token_idx
- optimize KV cache
"""
- # Workaround till FusedRoPE is fixed
- global FusedRoPE
- if self.training and FusedRoPE is not None:
- FusedRoPE = None
-
has_layer_past = layer_past is not None
# Compute QKV
@@ -64,7 +59,7 @@ def gaudi_gpt_neox_attention_forward(
if has_layer_past:
seq_len += layer_past[0].shape[-2]
cos, sin = self.rotary_emb(value, seq_len=seq_len)
- query, key = apply_customized_rope(query_rot, key_rot, cos, sin, position_ids)
+ query, key = apply_customized_rope(query_rot, key_rot, cos, sin, position_ids, training=self.training)
query = torch.cat((query, query_pass), dim=-1).contiguous()
key = torch.cat((key, key_pass), dim=-1).contiguous()
value = value.contiguous()
@@ -420,26 +415,30 @@ def gaudi_gpt_neox_rotary_embedding_set_cos_sin_cache(self, seq_len, device, dty
self.sin_cached = emb.sin()
-def apply_customized_rope(q, k, cos, sin, position_ids):
+def apply_customized_rope(q, k, cos, sin, position_ids, training=True):
if q.device.type == "hpu" and FusedRoPE:
- if q.dtype == torch.bfloat16:
- rope_q = FusedRoPE.apply(
- q,
- cos.unsqueeze(0).unsqueeze(0).to(torch.bfloat16),
- sin.unsqueeze(0).unsqueeze(0).to(torch.bfloat16),
- position_ids,
- )
- else:
+ if training:
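+            # Training path: apply FusedRoPE with cos/sin kept in their original dtype (no bf16 cast).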
rope_q = FusedRoPE.apply(q, cos.unsqueeze(0).unsqueeze(0), sin.unsqueeze(0).unsqueeze(0), position_ids)
- if k.dtype == torch.bfloat16:
- rope_k = FusedRoPE.apply(
- k,
- cos.unsqueeze(0).unsqueeze(0).to(torch.bfloat16),
- sin.unsqueeze(0).unsqueeze(0).to(torch.bfloat16),
- position_ids,
- )
- else:
rope_k = FusedRoPE.apply(k, cos.unsqueeze(0).unsqueeze(0), sin.unsqueeze(0).unsqueeze(0), position_ids)
+ else:
+ if q.dtype == torch.bfloat16:
+ rope_q = FusedRoPE.apply(
+ q,
+ cos.unsqueeze(0).unsqueeze(0).to(torch.bfloat16),
+ sin.unsqueeze(0).unsqueeze(0).to(torch.bfloat16),
+ position_ids,
+ )
+ else:
+ rope_q = FusedRoPE.apply(q, cos.unsqueeze(0).unsqueeze(0), sin.unsqueeze(0).unsqueeze(0), position_ids)
+ if k.dtype == torch.bfloat16:
+ rope_k = FusedRoPE.apply(
+ k,
+ cos.unsqueeze(0).unsqueeze(0).to(torch.bfloat16),
+ sin.unsqueeze(0).unsqueeze(0).to(torch.bfloat16),
+ position_ids,
+ )
+ else:
+ rope_k = FusedRoPE.apply(k, cos.unsqueeze(0).unsqueeze(0), sin.unsqueeze(0).unsqueeze(0), position_ids)
return rope_q, rope_k
else:
return apply_rotary_pos_emb(q, k, cos, sin, position_ids)
diff --git a/optimum/habana/transformers/models/llama/configuration_llama.py b/optimum/habana/transformers/models/llama/configuration_llama.py
index 7cc66488d5..dcba1c0738 100644
--- a/optimum/habana/transformers/models/llama/configuration_llama.py
+++ b/optimum/habana/transformers/models/llama/configuration_llama.py
@@ -26,6 +26,7 @@ def __init__(
attention_bias=False,
attention_dropout=0.0,
mlp_bias=False,
+ fused_qkv=False,
**kwargs,
):
super().__init__(
@@ -53,3 +54,4 @@ def __init__(
)
self.mlp_bias = mlp_bias
+ self.fused_qkv = fused_qkv
diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py
index f98fa90e9c..1cbd714df5 100755
--- a/optimum/habana/transformers/models/llama/modeling_llama.py
+++ b/optimum/habana/transformers/models/llama/modeling_llama.py
@@ -290,6 +290,19 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
self.k_cache = KVCache()
self.v_cache = KVCache()
self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA) if FusedSDPA else None
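+        # When fused_qkv is enabled, a single qkv_proj linear layer replaces the separate q/k/v projections.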
+ if config.fused_qkv:
+ self.num_heads = config.num_attention_heads
+ self.head_dim = config.hidden_size // self.num_heads
+ self.dim1 = self.num_heads * self.head_dim
+ self.dim2 = config.num_key_value_heads * self.head_dim
+ self.qkv_proj = torch.nn.Linear(
+ self.hidden_size,
+ self.dim1 + 2 * self.dim2,
+ bias=config.attention_bias,
+ )
+ self.q_proj = None
+ self.k_proj = None
+ self.v_proj = None
self.inp_seq_len = -1
self.norm_factor = 1.0 / math.sqrt(self.head_dim)
@@ -375,10 +388,15 @@ def pre_attn_forward(
value_states = torch.cat(value_states, dim=-1)
else:
- query_states = self.q_proj(hidden_states)
- key_states = self.k_proj(hidden_states)
- value_states = self.v_proj(hidden_states)
-
+ if self.config.fused_qkv:
+ qkv_states = self.qkv_proj(hidden_states)
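+                # Split the fused projection output back into query/key/value chunks along the last dimension.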
+ query_states, key_states, value_states = torch.split(
+ qkv_states, [self.dim1, self.dim2, self.dim2], dim=-1
+ )
+ else:
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
# TODO: update when auto mp params is enabled in DeepSpeed (cf. https://github.com/HabanaAI/DeepSpeed/blob/94309c7b5dfc1a69858f5c9f25737b2f81a332a5/deepspeed/module_inject/replace_module.py#L440)
key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
@@ -638,20 +656,20 @@ def pre_attn(
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
hidden_states = self.input_layernorm(hidden_states)
hidden_states, attn_weights, present_key_value = self.self_attn.pre_attn_forward(
- hidden_states,
- attention_mask,
- position_ids,
- past_key_value,
- output_attentions,
- use_cache,
- cache_position,
- token_idx,
- attn_softmax_bf16,
- reuse_cache,
- use_flash_attention,
- flash_attention_recompute,
- flash_attention_causal_mask,
- flash_attention_fast_softmax,
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ token_idx=token_idx,
+ attn_softmax_bf16=attn_softmax_bf16,
+ reuse_cache=reuse_cache,
+ use_flash_attention=use_flash_attention,
+ flash_attention_recompute=flash_attention_recompute,
+ flash_attention_causal_mask=flash_attention_causal_mask,
+ flash_attention_fast_softmax=flash_attention_fast_softmax,
cache_idx=cache_idx,
num_virtual_tokens=num_virtual_tokens,
)
diff --git a/optimum/habana/transformers/models/llava/modeling_llava.py b/optimum/habana/transformers/models/llava/modeling_llava.py
index f2746c2a3a..fa3a321e77 100644
--- a/optimum/habana/transformers/models/llava/modeling_llava.py
+++ b/optimum/habana/transformers/models/llava/modeling_llava.py
@@ -123,6 +123,7 @@ def forward(
token_idx: Optional[torch.Tensor] = None,
image_offset: Optional[int] = None,
tokens_pos: Optional[torch.LongTensor] = None,
+ use_flash_attention: Optional[bool] = False,
) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
"""
Inherits from LlavaForConditionalGeneration: https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llava/modeling_llava.py
@@ -152,7 +153,9 @@ def forward(
# 2. Merge text and images
if pixel_values is not None and input_ids.shape[1] != 1:
- image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
+ image_outputs = self.vision_tower(
+ pixel_values, output_hidden_states=True, use_flash_attention=use_flash_attention
+ )
# this is not memory efficient at all (output_hidden_states=True) will save all the hidden stated.
selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
@@ -180,6 +183,8 @@ def forward(
output_hidden_states=output_hidden_states,
return_dict=return_dict,
token_idx=token_idx + image_offset,
+ use_flash_attention=use_flash_attention,
+ flash_attention_recompute=use_flash_attention,
)
if input_ids.shape[1] != 1 and pixel_values is not None:
@@ -290,7 +295,7 @@ def prepare_inputs_for_generation(
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
-
+ use_flash_attention = kwargs.get("use_flash_attention", False)
model_inputs.update(
{
"position_ids": position_ids,
@@ -301,6 +306,7 @@ def prepare_inputs_for_generation(
"token_idx": token_idx,
"image_offset": image_offset,
"tokens_pos": tokens_pos,
+ "use_flash_attention": use_flash_attention,
}
)
diff --git a/optimum/habana/transformers/models/llava_next/modeling_llava_next.py b/optimum/habana/transformers/models/llava_next/modeling_llava_next.py
index 7fd76c5640..fdf9276123 100644
--- a/optimum/habana/transformers/models/llava_next/modeling_llava_next.py
+++ b/optimum/habana/transformers/models/llava_next/modeling_llava_next.py
@@ -54,6 +54,7 @@ def forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
token_idx: Optional[torch.Tensor] = None,
+ use_flash_attention: Optional[bool] = False,
) -> Union[Tuple, LlavaNextCausalLMOutputWithPast]:
"""
Inherits from LlavaForConditionalGeneration: https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava_next/modeling_llava_next.py#L433
@@ -81,6 +82,8 @@ def forward(
output_hidden_states=output_hidden_states,
return_dict=return_dict,
token_idx=token_idx + self.image_offset,
+ use_flash_attention=use_flash_attention,
+ flash_attention_recompute=use_flash_attention,
)
if inputs_embeds.shape[1] != 1 and pixel_values is not None:
@@ -244,6 +247,7 @@ def prepare_inputs_for_generation(
**kwargs,
)
else:
+ use_flash_attention = kwargs.get("use_flash_attention", False)
position_ids = kwargs.get("position_ids", None)
labels = kwargs.get("labels", None)
if past_key_values is None and pixel_values is not None and input_ids.shape[1] != 1:
@@ -263,7 +267,9 @@ def prepare_inputs_for_generation(
# 2. Merge text and images
batch_size, num_patches, num_channels, height, width = pixel_values.shape
reshaped_pixel_values = pixel_values.view(batch_size * num_patches, num_channels, height, width)
- image_features = self.vision_tower(reshaped_pixel_values, output_hidden_states=True)
+ image_features = self.vision_tower(
+ reshaped_pixel_values, output_hidden_states=True, use_flash_attention=use_flash_attention
+ )
selected_image_feature = image_features.hidden_states[vision_feature_layer]
@@ -383,6 +389,7 @@ def prepare_inputs_for_generation(
"token_idx": token_idx,
"image_sizes": image_sizes,
"labels": labels,
+ "use_flash_attention": use_flash_attention,
}
)
diff --git a/optimum/habana/transformers/models/mamba/__init__.py b/optimum/habana/transformers/models/mamba/__init__.py
new file mode 100644
index 0000000000..c22d12877c
--- /dev/null
+++ b/optimum/habana/transformers/models/mamba/__init__.py
@@ -0,0 +1,4 @@
+from .modeling_mamba import (
+ gaudi_MambaForCausalLM_prepare_inputs_for_generation,
+ gaudi_MambaForCausalLM_update_model_kwargs_for_generation,
+)
diff --git a/optimum/habana/transformers/models/mamba/modeling_mamba.py b/optimum/habana/transformers/models/mamba/modeling_mamba.py
new file mode 100644
index 0000000000..7b917e163e
--- /dev/null
+++ b/optimum/habana/transformers/models/mamba/modeling_mamba.py
@@ -0,0 +1,46 @@
+from typing import Any, Dict, Optional
+
+import torch
+from transformers.models.mamba.modeling_mamba import (
+ MambaCache,
+)
+from transformers.utils import (
+ ModelOutput,
+ logging,
+)
+
+
+logger = logging.get_logger(__name__)
+
+
+def gaudi_MambaForCausalLM_update_model_kwargs_for_generation(
+ self, outputs: ModelOutput, model_kwargs: Dict[str, Any], **kwargs
+) -> Dict[str, Any]:
+ model_kwargs["cache_params"] = outputs.get("cache_params", None)
+ token_idx = model_kwargs.get("token_idx", None)
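+    # Advance the static-shape decoding position trackers after each generated token.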
+ if token_idx is not None:
+ token_idx.add_(1)
+ if "token_idx_cpu" in model_kwargs:
+ model_kwargs["token_idx_cpu"] += 1
+ return model_kwargs
+
+
+def gaudi_MambaForCausalLM_prepare_inputs_for_generation(
+ self, input_ids, cache_params: Optional[MambaCache] = None, inputs_embeds=None, attention_mask=None, **kwargs
+):
+ token_idx = kwargs.get("token_idx", None)
+ token_idx_cpu = kwargs.get("token_idx_cpu", None)
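+    # token_idx (used with static shapes on Gaudi) points at the current decoding position in the padded input.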
+ if cache_params is not None:
+ if token_idx is None:
+ input_ids = input_ids[:, -1].unsqueeze(-1)
+ else:
+ input_ids = torch.index_select(input_ids, 1, token_idx - 1)
+ else:
+ if token_idx is not None:
+ input_ids = torch.index_select(input_ids, 1, torch.arange(token_idx_cpu, device=input_ids.device))
+ if inputs_embeds is not None and cache_params is None:
+ model_inputs = {"inputs_embeds": inputs_embeds}
+ else:
+ model_inputs = {"input_ids": input_ids}
+ model_inputs["cache_params"] = cache_params
+ return model_inputs
diff --git a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py
index 3a15c6df11..0dd6ffc47e 100644
--- a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py
+++ b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py
@@ -337,7 +337,7 @@ def forward(
)
past_key_value = (past_key, past_value)
key_states = self.k_cache.update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len)
- value_states = self.k_cache.update(past_key_value[1], value_states, 2, token_idx, self.inp_seq_len)
+ value_states = self.v_cache.update(past_key_value[1], value_states, 2, token_idx, self.inp_seq_len)
if token_idx is None:
past_key_value = (key_states, value_states)
diff --git a/optimum/habana/transformers/models/qwen2/modeling_qwen2.py b/optimum/habana/transformers/models/qwen2/modeling_qwen2.py
index f192cf4898..27f3319579 100644
--- a/optimum/habana/transformers/models/qwen2/modeling_qwen2.py
+++ b/optimum/habana/transformers/models/qwen2/modeling_qwen2.py
@@ -17,12 +17,12 @@
###############################################################################
import math
+import os
import warnings
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
-import torch.utils.checkpoint
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
@@ -124,6 +124,16 @@ def gaudi_qwen2_repeat_kv(
return query_states, key_states, value_states, attention_mask
+# FusedScaledDotProductAttention
+class ModuleFusedSDPA(torch.nn.Module):
+ def __init__(self, fusedSDPA):
+ super().__init__()
+ self._hpu_kernel_fsdpa = fusedSDPA
+
+    def forward(self, query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode):
+        return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode)
+
+
class Matmul(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -181,6 +191,7 @@ def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None):
self.matmul_av = Matmul()
self.k_cache = KVCache()
self.v_cache = KVCache()
+ self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA) if FusedSDPA else None
self.inp_seq_len = -1
self.norm_factor = 1.0 / math.sqrt(self.head_dim)
self.block_size = 4096
@@ -255,6 +266,7 @@ def pre_attn_forward(
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
+ flash_attention_fast_softmax: Optional[bool] = False,
cache_idx: int = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
@@ -266,6 +278,7 @@ def pre_attn_forward(
- add new args reuse_cache
- add new args use_flash_attention
- add new arg flash_attention_recompute
+ - add new arg flash_attention_fast_softmax
"""
if "padding_mask" in kwargs:
warnings.warn(
@@ -310,7 +323,7 @@ def pre_attn_forward(
past_value = torch.zeros(
key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device
)
- past_key_value = (past_key, past_value)
+ past_key_value = [past_key, past_value]
key_states = self.k_cache.update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len)
value_states = self.v_cache.update(past_key_value[1], value_states, 2, token_idx, self.inp_seq_len)
if token_idx is None:
@@ -328,18 +341,23 @@ def pre_attn_forward(
if use_flash_attention and FusedSDPA:
import habana_frameworks.torch.hpu as ht
+ softmax_mode = "fast" if flash_attention_fast_softmax else "None"
+
if q_len == 1:
# next token
- with ht.sdp_kernel(enable_recompute=False):
- attn_output = FusedSDPA.apply(
- query_states, key_states, value_states, attention_mask, 0.0, False, None
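+                # Enable SDPA recompute for the decode phase only when a quantization config is set via QUANT_CONFIG.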
+ use_recompute = True if os.getenv("QUANT_CONFIG", "") else False
+ with ht.sdp_kernel(enable_recompute=use_recompute):
+ attn_output = self.fused_scaled_dot_product_attention(
+ query_states, key_states, value_states, attention_mask, 0.0, False, None, softmax_mode
)
else:
# first token
if flash_attention_causal_mask:
# causal masking on first token requires inputs to be of the same length
with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
- attn_output = FusedSDPA.apply(query_states, key_states, value_states, None, 0.0, True, None)
+ attn_output = self.fused_scaled_dot_product_attention(
+ query_states, key_states, value_states, None, 0.0, True, None, softmax_mode
+ )
else:
with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
if q_len > 8192:
@@ -348,8 +366,8 @@ def pre_attn_forward(
)
htcore.mark_step()
else:
- attn_output = FusedSDPA.apply(
- query_states, key_states, value_states, attention_mask, 0.0, False, None
+ attn_output = self.fused_scaled_dot_product_attention(
+ query_states, key_states, value_states, attention_mask, 0.0, False, None, softmax_mode
)
else:
@@ -391,6 +409,11 @@ def pre_attn_forward(
if not output_attentions:
attn_weights = None
+ if not reuse_cache and token_idx is not None and cache_idx is not None and q_len == 1:
+ # Return only past key value shapes and not the tensors during decode phase (q len is 1)
+ # to avoid making past key values as persistent output tensors of HPU graphs.
+ past_key_value = (past_key_value[0].shape, past_key_value[1].shape)
+
return attn_output, attn_weights, past_key_value
def attention_all_reduce(self, attn_output):
@@ -438,7 +461,9 @@ def forward(
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
+ flash_attention_fast_softmax: Optional[bool] = False,
cache_idx: int = None,
+ **kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states
@@ -456,6 +481,7 @@ def forward(
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
flash_attention_causal_mask=flash_attention_causal_mask,
+ flash_attention_fast_softmax=flash_attention_fast_softmax,
cache_idx=cache_idx,
)
self.self_attn.attention_all_reduce(hidden_states)
@@ -488,6 +514,7 @@ def pre_attn(
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
+ flash_attention_fast_softmax: Optional[bool] = False,
cache_idx: int = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
hidden_states = self.input_layernorm(hidden_states)
@@ -505,6 +532,7 @@ def pre_attn(
use_flash_attention,
flash_attention_recompute,
flash_attention_causal_mask,
+ flash_attention_fast_softmax,
cache_idx=cache_idx,
)
return hidden_states, attn_weights, present_key_value
@@ -566,6 +594,7 @@ def forward(
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
+ flash_attention_fast_softmax: Optional[bool] = False,
cache_idx: int = None,
lazy_mode: Optional[bool] = True,
) -> Union[Tuple, BaseModelOutputWithPast]:
@@ -666,6 +695,8 @@ def forward(
use_flash_attention,
flash_attention_recompute,
flash_attention_causal_mask,
+ flash_attention_fast_softmax,
+ None,
)
else:
layer_outputs = decoder_layer(
@@ -682,6 +713,7 @@ def forward(
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
flash_attention_causal_mask=flash_attention_causal_mask,
+ flash_attention_fast_softmax=flash_attention_fast_softmax,
cache_idx=cache_idx,
)
@@ -744,6 +776,7 @@ def forward(
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
+ flash_attention_fast_softmax: Optional[bool] = False,
cache_idx: int = None,
lazy_mode: Optional[bool] = True,
) -> Union[Tuple, CausalLMOutputWithPast]:
@@ -776,6 +809,7 @@ def forward(
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
flash_attention_causal_mask=flash_attention_causal_mask,
+ flash_attention_fast_softmax=flash_attention_fast_softmax,
cache_idx=cache_idx,
lazy_mode=lazy_mode,
)
@@ -821,6 +855,7 @@ def prepare_inputs_for_generation(
past_length = 0
reuse_cache = kwargs.get("reuse_cache")
+ bucket_internal = kwargs.get("bucket_internal")
if past_key_values is not None:
if token_idx is not None:
input_ids = torch.index_select(input_ids, 1, token_idx - 1)
@@ -852,8 +887,9 @@ def prepare_inputs_for_generation(
and cache_length + input_ids.shape[1] > max_cache_length
):
attention_mask = attention_mask[:, -max_cache_length:]
- elif reuse_cache and token_idx is not None:
- # With reuse_cache, KV cache is pre allocated hence for the 1st token we can slice the inputs till token idx for the fwd pass
+ elif (reuse_cache or bucket_internal) and token_idx is not None:
+            # The KV cache is pre-allocated with reuse_cache or will be padded with bucket_internal,
+            # so for the first token we can slice the inputs up to token_idx for the forward pass.
input_ids = input_ids[:, :token_idx]
attention_mask = attention_mask[:, :token_idx]
@@ -890,6 +926,7 @@ def prepare_inputs_for_generation(
"use_flash_attention": kwargs.get("use_flash_attention"),
"flash_attention_recompute": kwargs.get("flash_attention_recompute"),
"flash_attention_causal_mask": kwargs.get("flash_attention_causal_mask"),
+ "flash_attention_fast_softmax": kwargs.get("flash_attention_fast_softmax"),
"cache_idx": kwargs.get("cache_idx"),
"lazy_mode": kwargs.get("lazy_mode"),
}
@@ -900,6 +937,15 @@ def prepare_inputs_for_generation(
def apply_customized_rope(q, k, cos, sin, position_ids):
if q.device.type == "hpu" and FusedRoPE:
# TODO: remove `.clone()` when it is fixed in SynapseAI
+ if k.dtype == torch.bfloat16:
+ return FusedRoPE.apply(
+ q, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids
+ ), FusedRoPE.apply(
+ k,
+ cos.unsqueeze(0).unsqueeze(0).clone().to(torch.bfloat16),
+ sin.unsqueeze(0).unsqueeze(0).clone().to(torch.bfloat16),
+ position_ids,
+ )
return FusedRoPE.apply(
q, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids
), FusedRoPE.apply(
diff --git a/optimum/habana/transformers/models/starcoder2/modeling_starcoder2.py b/optimum/habana/transformers/models/starcoder2/modeling_starcoder2.py
index c402833319..ce5fa35ec9 100644
--- a/optimum/habana/transformers/models/starcoder2/modeling_starcoder2.py
+++ b/optimum/habana/transformers/models/starcoder2/modeling_starcoder2.py
@@ -22,6 +22,13 @@
)
+try:
+ from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE
+except ImportError:
+ print("Not using HPU fused kernel for apply_rotary_pos_emb")
+ FusedRoPE = None
+
+
logger = logging.get_logger(__name__)
@@ -47,6 +54,7 @@ def gaudi_starcoder2_attention_forward(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)
bsz, q_len, _ = hidden_states.size()
+ norm_factor = 1.0 / math.sqrt(self.head_dim)
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
@@ -70,7 +78,7 @@ def gaudi_starcoder2_attention_forward(
else:
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+ query_states, key_states = apply_customized_rope(query_states, key_states, cos, sin, position_ids, self.training)
if past_key_value is not None:
if token_idx is not None:
@@ -90,7 +98,7 @@ def gaudi_starcoder2_attention_forward(
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
- attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * norm_factor
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
@@ -470,3 +478,25 @@ def prepare_inputs_for_generation(
}
)
return model_inputs
+
+
+def apply_customized_rope(q, k, cos, sin, position_ids, is_training):
+ if q.device.type == "hpu" and FusedRoPE:
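+        # At inference with bf16 inputs, cast cos/sin to bf16 for the fused kernel; otherwise keep their original dtype.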
+ if not is_training and (q.dtype == torch.bfloat16 or k.dtype == torch.bfloat16):
+ return FusedRoPE.apply(
+ q,
+ cos.unsqueeze(0).unsqueeze(0).to(torch.bfloat16),
+ sin.unsqueeze(0).unsqueeze(0).to(torch.bfloat16),
+ position_ids,
+ ), FusedRoPE.apply(
+ k,
+ cos.unsqueeze(0).unsqueeze(0).to(torch.bfloat16),
+ sin.unsqueeze(0).unsqueeze(0).to(torch.bfloat16),
+ position_ids,
+ )
+ else:
+ return FusedRoPE.apply(
+ q, cos.unsqueeze(0).unsqueeze(0), sin.unsqueeze(0).unsqueeze(0), position_ids
+ ), FusedRoPE.apply(k, cos.unsqueeze(0).unsqueeze(0), sin.unsqueeze(0).unsqueeze(0), position_ids)
+ else:
+ return apply_rotary_pos_emb(q, k, cos, sin, position_ids)
diff --git a/optimum/habana/transformers/models/vits/__init__.py b/optimum/habana/transformers/models/vits/__init__.py
index b0cf4ecfeb..e91196fd99 100644
--- a/optimum/habana/transformers/models/vits/__init__.py
+++ b/optimum/habana/transformers/models/vits/__init__.py
@@ -1,4 +1 @@
-from .modeling_vits import (
- gaudi_unconstrained_rational_quadratic_spline,
- gaudi_VitsResidualCouplingLayer_forward,
-)
+from .modeling_vits import gaudi_unconstrained_rational_quadratic_spline
diff --git a/optimum/habana/transformers/models/vits/modeling_vits.py b/optimum/habana/transformers/models/vits/modeling_vits.py
index 174e8a3ac9..d957aaee6e 100644
--- a/optimum/habana/transformers/models/vits/modeling_vits.py
+++ b/optimum/habana/transformers/models/vits/modeling_vits.py
@@ -52,26 +52,3 @@ def gaudi_unconstrained_rational_quadratic_spline(
outputs = outputs_i * inside_interval_mask + outputs * outside_interval_mask
log_abs_det = log_abs_det_i * inside_interval_mask + log_abs_det * outside_interval_mask
return outputs, log_abs_det
-
-
-def gaudi_VitsResidualCouplingLayer_forward(self, inputs, padding_mask, global_conditioning=None, reverse=False):
- """
- Copied from VitsResidualCouplingLayer:forward: https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/vits/modeling_vits.py
- The only differences are:
- - WA to fix torch.flip issue after conv1d
- """
- first_half, second_half = torch.split(inputs, [self.half_channels] * 2, dim=1)
- hidden_states = self.conv_pre(first_half) * padding_mask
- hidden_states = self.wavenet(hidden_states, padding_mask, global_conditioning)
- mean = self.conv_post(hidden_states) * padding_mask
- log_stddev = torch.zeros_like(mean)
-
- if not reverse:
- second_half = mean.cpu() + second_half * torch.exp(log_stddev) * padding_mask
- outputs = torch.cat([first_half, second_half], dim=1)
- log_determinant = torch.sum(log_stddev, [1, 2])
- return outputs, log_determinant
- else:
- second_half = (second_half - mean.cpu()) * torch.exp(-log_stddev) * padding_mask
- outputs = torch.cat([first_half, second_half], dim=1)
- return outputs, None
diff --git a/optimum/habana/transformers/trainer.py b/optimum/habana/transformers/trainer.py
index 5fa006873a..5966790d48 100644
--- a/optimum/habana/transformers/trainer.py
+++ b/optimum/habana/transformers/trainer.py
@@ -96,7 +96,7 @@
from optimum.utils import logging
from ..accelerate import GaudiAccelerator
-from ..accelerate.utils import GaudiDistributedType
+from ..accelerate.utils import FP8ContextWrapper, GaudiDistributedType
from ..utils import (
HabanaProfile,
get_hpu_memory_stats,
@@ -692,6 +692,10 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args):
transformers.modeling_utils.checkpoint = lazy_mode_checkpointing
self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)
+
+ # Wrap `_gradient_checkpointing_func` in the model with `transformer_engine` `activation_checkpointing` context.
+ if self.accelerator.state.is_fp8_enabled:
+ FP8ContextWrapper.gradient_checkpointing_wrap(self.model)
else:
# Hack because `RegressionModel` in test_trainer.py doesn't have `gradient_checkpointing_disable`
if hasattr(self.model, "gradient_checkpointing_disable"):
@@ -1518,6 +1522,11 @@ def autocast_smart_context_manager(self, cache_enabled: Optional[bool] = True):
else:
ctx_manager = contextlib.nullcontext()
+ # Merge autocast context and `fp8_autocast` context if FP8 is enabled.
+ # Currently FP8 is enabled only for training.
+ if self.accelerator.state.is_fp8_enabled and self.model.training:
+ ctx_manager = FP8ContextWrapper(ctx_manager, self.accelerator.fp8_recipe_handler)
+
return ctx_manager
def training_step(self, model: torch.nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
@@ -1551,6 +1560,9 @@ def training_step(self, model: torch.nn.Module, inputs: Dict[str, Union[torch.Te
self.htcore.mark_step()
if _is_peft_model(self.model) and self.model.peft_type == PeftType.ADALORA:
+ assert not (
+ self.accelerator.state.is_fp8_enabled and self.args.gradient_checkpointing
+ ), "FP8 precision with gradient_checkpointing is currently not supported with PeftType.ADALORA"
if self.is_deepspeed_enabled and not is_deepspeed_zero3_enabled():
self.accelerator.deepspeed_engine_wrapped.engine.backward(loss)
self.model.base_model.update_and_allocate(self.state.global_step)
@@ -1559,7 +1571,15 @@ def training_step(self, model: torch.nn.Module, inputs: Dict[str, Union[torch.Te
self.accelerator.backward(loss)
self.model.base_model.update_and_allocate(self.state.global_step)
else:
- self.accelerator.backward(loss)
+ if self.accelerator.state.is_fp8_enabled and self.args.gradient_checkpointing:
+                # The precision used in the backward pass should be the same as in the forward pass.
+                # However, when training with gradient_checkpointing and FP8 precision, the forward
+                # recomputation during backward does not automatically run in FP8. To handle this,
+                # the backward pass is run inside an `fp8_autocast` context.
+ with FP8ContextWrapper.create_fp8_context(self.accelerator.fp8_recipe_handler):
+ self.accelerator.backward(loss)
+ else:
+ self.accelerator.backward(loss)
return loss.detach() / self.args.gradient_accumulation_steps
def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
diff --git a/optimum/habana/transformers/training_args.py b/optimum/habana/transformers/training_args.py
index eff3c1ede7..c280e46888 100644
--- a/optimum/habana/transformers/training_args.py
+++ b/optimum/habana/transformers/training_args.py
@@ -291,14 +291,6 @@ class GaudiTrainingArguments(TrainingArguments):
metadata={"help": "Whether to use fp8 for training."},
)
- fp8_recipe_format: Optional[str] = field(
- default="E5M2",
- metadata={
- "help": "Which fp8 format to use for fp8 training.",
- "choices": ["E5M2", "E4M3", "HYBRID"],
- },
- )
-
def __post_init__(self):
if self.use_hpu_graphs:
warnings.warn(
diff --git a/optimum/habana/trl/trainer/sft_trainer.py b/optimum/habana/trl/trainer/sft_trainer.py
index c6728f1ce2..3d35c64202 100644
--- a/optimum/habana/trl/trainer/sft_trainer.py
+++ b/optimum/habana/trl/trainer/sft_trainer.py
@@ -30,6 +30,7 @@
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalPrediction
from trl import SFTTrainer
+from trl.extras.dataset_formatting import get_formatting_func_from_dataset
from trl.import_utils import is_peft_available
from trl.trainer.utils import (
DataCollatorForCompletionOnlyLM,
@@ -70,6 +71,7 @@ def __init__(
neftune_noise_alpha: Optional[float] = None,
model_init_kwargs: Optional[Dict] = None,
dataset_kwargs: Optional[Dict] = None,
+ eval_packing: Optional[bool] = None,
):
"""
Copied from SFTTrainer.__init__: https://github.com/huggingface/trl/blob/v0.7.6/trl/trainer/sft_trainer.py#L120
@@ -171,6 +173,11 @@ def make_inputs_require_grad(module, input, output):
elif not self._trainer_supports_neftune:
self.neftune_noise_alpha = neftune_noise_alpha
+ if formatting_func is None and dataset_text_field is None:
+ # Check whether the dataset has a supported ChatML or instruction format;
+ # if not, formatting_func stays None
+ formatting_func = get_formatting_func_from_dataset(train_dataset, tokenizer)
+
if not packing:
if dataset_text_field is None and formatting_func is None:
raise ValueError(
@@ -194,6 +201,7 @@ def make_inputs_require_grad(module, input, output):
chars_per_token,
**dataset_kwargs,
)
+
if eval_dataset is not None:
_multiple = isinstance(eval_dataset, dict)
_eval_datasets = eval_dataset if _multiple else {"singleton": eval_dataset}
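The `get_formatting_func_from_dataset` fallback above lets the SFT trainer infer how to turn dataset rows into text when neither `formatting_func` nor `dataset_text_field` is provided. A rough sketch of that kind of detection is shown below; the helper and the column names it checks are hypothetical examples of the idea, not the exact logic implemented in TRL.

def infer_formatting_func(dataset, tokenizer):
    # Illustrative sketch: pick a formatting function from the dataset schema.
    # Assumes `dataset.features` behaves like a dict of column names; the columns
    # checked here are examples, not the exhaustive set TRL supports.
    columns = set(getattr(dataset, "features", {}) or {})

    if "messages" in columns:
        # ChatML-style conversations: render them with the tokenizer's chat template.
        def format_chatml(examples):
            return [
                tokenizer.apply_chat_template(conversation, tokenize=False)
                for conversation in examples["messages"]
            ]

        return format_chatml

    if {"prompt", "completion"} <= columns:
        # Instruction-style pairs: concatenate prompt and completion.
        def format_instruction(examples):
            return [p + c for p, c in zip(examples["prompt"], examples["completion"])]

        return format_instruction

    # Unknown schema: leave it to the caller (formatting_func stays None).
    return None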
diff --git a/optimum/habana/version.py b/optimum/habana/version.py
index a6896016c7..0add7978d8 100644
--- a/optimum/habana/version.py
+++ b/optimum/habana/version.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-__version__ = "1.12.0.dev0"
+__version__ = "1.13.0.dev0"
diff --git a/setup.py b/setup.py
index 9c08257902..d0899959a7 100644
--- a/setup.py
+++ b/setup.py
@@ -35,7 +35,6 @@
"accelerate < 0.28.0",
"diffusers >= 0.26.0, < 0.27.0",
"huggingface_hub < 0.23.0",
- "datasets < 2.20.0",
]
TESTS_REQUIRE = [
@@ -47,6 +46,7 @@
"datasets",
"safetensors",
"pytest < 8.0.0",
+ "torchsde",
]
QUALITY_REQUIRES = [
diff --git a/tests/baselines/Qwen2_7B.json b/tests/baselines/Qwen2_7B.json
new file mode 100644
index 0000000000..453d60848a
--- /dev/null
+++ b/tests/baselines/Qwen2_7B.json
@@ -0,0 +1,74 @@
+{
+ "gaudi2": {
+ "trl-sft-chat-peft": {
+ "num_train_epochs": 1,
+ "eval_batch_size": 32,
+ "distribution": {
+ "multi_card": {
+ "learning_rate": 3e-4,
+ "train_batch_size": 32,
+ "train_runtime": 410,
+ "train_samples_per_second": 120,
+ "extra_arguments": [
+ "--bf16 True",
+ "--subset ''",
+ "--streaming False",
+ "--packing True",
+ "--gradient_accumulation_steps 8",
+ "--gradient_checkpointing True",
+ "--evaluation_strategy no",
+ "--save_strategy no",
+ "--throughput_warmup_steps 5",
+ "--warmup_ratio 0.03",
+ "--lr_scheduler_type cosine",
+ "--max_grad_norm 0.3",
+ "--logging_steps 1",
+ "--adam_epsilon 3e-4",
+ "--use_peft True",
+ "--lora_r 4",
+ "--lora_alpha 16",
+ "--lora_dropout 0.05",
+ "--lora_target_modules q_proj v_proj k_proj o_proj",
+ "--max_seq_length 512",
+ "--weight_decay 0.05",
+ "--report_to none",
+ "--max_steps 20"
+ ]
+ }
+ }
+ },
+ "trl-sft-chat": {
+ "num_train_epochs": 1,
+ "eval_batch_size": 2,
+ "distribution": {
+ "multi_card": {
+ "learning_rate": 3e-4,
+ "train_batch_size": 2,
+ "train_runtime": 360,
+ "train_samples_per_second": 8.5,
+ "extra_arguments": [
+ "--bf16 True",
+ "--subset ''",
+ "--streaming False",
+ "--packing True",
+ "--gradient_accumulation_steps 8",
+ "--gradient_checkpointing True",
+ "--evaluation_strategy no",
+ "--save_strategy no",
+ "--throughput_warmup_steps 5",
+ "--warmup_ratio 0.03",
+ "--lr_scheduler_type cosine",
+ "--max_grad_norm 0.3",
+ "--logging_steps 1",
+ "--adam_epsilon 3e-4",
+ "--use_peft False",
+ "--max_seq_length 4096",
+ "--report_to none",
+ "--use_flash_attention True",
+ "--max_steps 20"
+ ]
+ }
+ }
+ }
+ }
+}
\ No newline at end of file
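Baseline files like this one pair the reference metrics (runtime, samples per second, perplexity where applicable) with the extra CLI arguments the corresponding test run should use. A minimal sketch of how such a file could be consumed is shown below; `load_baseline` is a hypothetical helper, not the project's actual test harness.

import json
from pathlib import Path

def load_baseline(path, device="gaudi2", task="trl-sft-chat-peft", distribution="multi_card"):
    # Illustrative only: read one baseline entry and split it into CLI flags and reference numbers.
    config = json.loads(Path(path).read_text())
    entry = config[device][task]["distribution"][distribution]

    # Extra flags the test run is expected to pass on top of the defaults.
    extra_args = []
    for argument in entry["extra_arguments"]:
        extra_args.extend(argument.split())

    # Reference numbers the measured run is compared against, typically with a tolerance.
    references = {
        "train_runtime": entry["train_runtime"],
        "train_samples_per_second": entry["train_samples_per_second"],
    }
    return extra_args, references

# Hypothetical usage:
# args, refs = load_baseline("tests/baselines/Qwen2_7B.json")
# assert measured_samples_per_second >= 0.95 * refs["train_samples_per_second"]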
diff --git a/tests/baselines/falcon_40b.json b/tests/baselines/falcon_40b.json
index cb08dc4ed4..d3563466c5 100644
--- a/tests/baselines/falcon_40b.json
+++ b/tests/baselines/falcon_40b.json
@@ -34,6 +34,40 @@
]
}
}
- }
+ },
+ "mamamiya405/finred": {
+ "num_train_epochs": 3,
+ "eval_batch_size": 1,
+ "distribution": {
+ "multi_card": {
+ "learning_rate": 4e-4,
+ "train_batch_size": 1,
+ "perplexity": 4.0893,
+ "train_runtime": 1170,
+ "train_samples_per_second": 28.162,
+ "extra_arguments": [
+ "--bf16",
+ "--gradient_accumulation_steps 16",
+ "--evaluation_strategy no",
+ "--save_strategy no",
+ "--warmup_ratio 0.03",
+ "--lr_scheduler_type constant",
+ "--max_grad_norm 0.3",
+ "--logging_steps 1",
+ "--use_hpu_graphs_for_inference",
+ "--lora_rank 64",
+ "--lora_alpha 16",
+ "--lora_dropout 0.1",
+ "--lora_target_modules query_key_value dense dense_h_to_4h dense_4h_to_h",
+ "--max_seq_length 256",
+ "--low_cpu_mem_usage True",
+ "--adam_epsilon 1e-08",
+ "--ddp_bucket_cap_mb 50",
+ "--pipelining_fwd_bwd",
+ "--validation_split_percentage 10"
+ ]
+ }
+ }
+ }
}
-}
\ No newline at end of file
+}
diff --git a/tests/baselines/llama_7b.json b/tests/baselines/llama_7b.json
index f1a556a35b..ff9f3f1485 100644
--- a/tests/baselines/llama_7b.json
+++ b/tests/baselines/llama_7b.json
@@ -1,6 +1,42 @@
{
"gaudi": {
- "tatsu-lab/alpaca": {
+ "databricks/databricks-dolly-15k": {
+ "num_train_epochs": 1,
+ "eval_batch_size": 2,
+ "distribution": {
+ "single_card": {
+ "learning_rate": 2e-4,
+ "train_batch_size": 2,
+ "perplexity": 3.9168,
+ "train_runtime": 132.665,
+ "train_samples_per_second": 2.295,
+ "extra_arguments": [
+ "--bf16",
+ "--gradient_accumulation_steps 1",
+ "--evaluation_strategy no",
+ "--save_strategy no",
+ "--warmup_ratio 0.03",
+ "--lr_scheduler_type constant",
+ "--max_grad_norm 0.3",
+ "--logging_steps 1",
+ "--use_hpu_graphs_for_inference",
+ "--lora_rank 8",
+ "--lora_alpha 16",
+ "--lora_dropout 0.1",
+ "--lora_target_modules q_proj v_proj",
+ "--dataset_concatenation",
+ "--low_cpu_mem_usage True",
+ "--adam_epsilon 1e-08",
+ "--validation_split_percentage 20",
+ "--attn_softmax_bf16",
+ "--max_steps 100",
+ "--input_column_name context",
+ "--output_column_name response"
+ ]
+ }
+ }
+ },
+ "tatsu-lab/alpaca": {
"num_train_epochs": 1,
"eval_batch_size": 2,
"distribution": {
@@ -96,6 +132,40 @@
}
}
},
+ "mamamiya405/finred": {
+ "num_train_epochs": 3,
+ "eval_batch_size": 4,
+ "distribution": {
+ "multi_card": {
+ "learning_rate": 3e-4,
+ "train_batch_size": 8,
+ "perplexity": 2.3665,
+ "train_runtime": 294.5707,
+ "train_samples_per_second": 148.093,
+ "extra_arguments": [
+ "--bf16",
+ "--gradient_accumulation_steps 2",
+ "--evaluation_strategy no",
+ "--save_strategy no",
+ "--warmup_ratio 0.03",
+ "--lr_scheduler_type constant",
+ "--max_grad_norm 0.3",
+ "--logging_steps 1",
+ "--use_hpu_graphs_for_inference",
+ "--lora_rank 8",
+ "--lora_alpha 16",
+ "--lora_dropout 0.05",
+ "--lora_target_modules q_proj v_proj",
+ "--max_seq_length 512",
+ "--low_cpu_mem_usage True",
+ "--adam_epsilon 1e-08",
+ "--ddp_bucket_cap_mb 50",
+ "--validation_split_percentage 10",
+ "--attn_softmax_bf16"
+ ]
+ }
+ }
+ },
"tatsu-lab/alpaca_fsdpcompile": {
"num_train_epochs": 1,
"eval_batch_size": 1,
@@ -135,6 +205,44 @@
}
}
},
+ "llama-adapter": {
+ "num_train_epochs": 3,
+ "eval_batch_size": 4,
+ "distribution": {
+ "multi_card": {
+ "learning_rate": 3e-4,
+ "train_batch_size": 8,
+ "perplexity": 5.575,
+ "train_runtime": 131.7,
+ "train_samples_per_second": 294,
+ "extra_arguments": [
+ "--bf16",
+ "--gradient_accumulation_steps 2",
+ "--evaluation_strategy no",
+ "--save_strategy no",
+ "--warmup_ratio 0.03",
+ "--lr_scheduler_type constant",
+ "--max_grad_norm 0.3",
+ "--logging_steps 1",
+ "--use_hpu_graphs_for_inference",
+ "--lora_rank 8",
+ "--lora_alpha 16",
+ "--lora_dropout 0.05",
+ "--lora_target_modules q_proj v_proj",
+ "--dataset_concatenation",
+ "--max_seq_length 512",
+ "--low_cpu_mem_usage True",
+ "--adam_epsilon 1e-08",
+ "--ddp_bucket_cap_mb 50",
+ "--validation_split_percentage 10",
+ "--attn_softmax_bf16",
+ "--adapter_layers 2",
+ "--adapter_len 4",
+ "--peft_type llama-adapter"
+ ]
+ }
+ }
+ },
"trl-sft": {
"num_train_epochs": 1,
"eval_batch_size": 1,
@@ -157,7 +265,7 @@
"--lora_alpha 16",
"--lora_dropout 0.05",
"--lora_target_modules q_proj v_proj",
- "--seq_length 1024",
+ "--max_seq_length 1024",
"--optim paged_adamw_32bit",
"--weight_decay 0.05",
"--report_to none",
@@ -197,6 +305,60 @@
}
}
},
+ "trl-reward": {
+ "num_train_epochs": 1,
+ "eval_batch_size": 1,
+ "distribution": {
+ "multi_card": {
+ "learning_rate": 5e-4,
+ "train_batch_size": 1,
+ "train_runtime": 250,
+ "train_samples_per_second": 1.6,
+ "extra_arguments": [
+ "--logging_steps 1",
+ "--lora_r 8",
+ "--lora_alpha 16",
+ "--lora_dropout 0.05",
+ "--lora_target_modules q_proj v_proj k_proj out_proj fc_in fc_out wte",
+ "--max_length 1024",
+ "--eval_steps 200",
+ "--lr_scheduler_type cosine",
+ "--weight_decay 0.05",
+ "--gradient_accumulation_steps 4",
+ "--train_subset 500",
+ "--eval_subset 100"
+ ]
+ }
+ }
+ },
+ "trl-ppo": {
+ "num_train_epochs": 1,
+ "eval_batch_size": 1,
+ "distribution": {
+ "multi_card": {
+ "learning_rate": 5e-4,
+ "train_batch_size": 8,
+ "train_runtime": 62,
+ "train_samples_per_second": 0.50,
+ "extra_arguments": [
+ "--lora_r 8",
+ "--lora_alpha 16",
+ "--lora_dropout 0.05",
+ "--reward_model_name HuggingFaceH4/tiny-random-LlamaForSequenceClassification",
+ "--lora_target_modules q_proj v_proj k_proj out_proj fc_in fc_out wte",
+ "--max_train_samples 1000",
+ "--use_habana",
+ "--ppo_epochs 1",
+ "--batched_gen True",
+ "--mini_batch_size 1",
+ "--output_max_length 128",
+ "--input_max_length 128",
+ "--learning_rate 1.4e-5",
+ "--early_stopping"
+ ]
+ }
+ }
+ },
"prompt-tuning": {
"num_train_epochs": 20,
"eval_batch_size": 1,
@@ -274,6 +436,43 @@
]
}
}
+ },
+ "tatsu-lab/alpaca_fp8": {
+ "num_train_epochs": 3,
+ "eval_batch_size": 4,
+ "distribution": {
+ "multi_card": {
+ "learning_rate": 3e-4,
+ "train_batch_size": 16,
+ "perplexity": 2.3692,
+ "train_runtime": 411.9935,
+ "train_samples_per_second": 232.439,
+ "extra_arguments": [
+ "--bf16",
+ "--gradient_accumulation_steps 1",
+ "--evaluation_strategy no",
+ "--save_strategy no",
+ "--warmup_ratio 0.03",
+ "--lr_scheduler_type constant",
+ "--logging_steps 40",
+ "--lora_rank 8",
+ "--lora_alpha 16",
+ "--lora_dropout 0.05",
+ "--lora_target_modules q_proj v_proj",
+ "--dataset_concatenation",
+ "--max_seq_length 512",
+ "--low_cpu_mem_usage True",
+ "--adam_epsilon 1e-08",
+ "--ddp_bucket_cap_mb 50",
+ "--validation_split_percentage 10",
+ "--pipelining_fwd_bwd",
+ "--throughput_warmup_steps 18",
+ "--use_lazy_mode",
+ "--max_grad_norm 0.3",
+ "--fp8"
+ ]
+ }
+ }
}
}
}
diff --git a/tests/baselines/wav2vec2_base.json b/tests/baselines/wav2vec2_base.json
index 3927ec4a5b..64c16c70de 100644
--- a/tests/baselines/wav2vec2_base.json
+++ b/tests/baselines/wav2vec2_base.json
@@ -21,7 +21,8 @@
"--seed 0",
"--dataloader_num_workers 1",
"--use_hpu_graphs_for_training",
- "--use_hpu_graphs_for_inference"
+ "--use_hpu_graphs_for_inference",
+ "--trust_remote_code True"
]
}
}
@@ -49,7 +50,8 @@
"--seed 0",
"--dataloader_num_workers 1",
"--use_hpu_graphs_for_training",
- "--use_hpu_graphs_for_inference"
+ "--use_hpu_graphs_for_inference",
+ "--trust_remote_code True"
]
}
}
diff --git a/tests/example_diff/run_audio_classification.txt b/tests/example_diff/run_audio_classification.txt
index 278d3485ff..d7b474164d 100644
--- a/tests/example_diff/run_audio_classification.txt
+++ b/tests/example_diff/run_audio_classification.txt
@@ -30,7 +30,7 @@
>
47,48c48,50
< # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-< check_min_version("4.42.0.dev0")
+< check_min_version("4.44.0.dev0")
---
> # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks.
> check_min_version("4.40.0")
@@ -76,13 +76,13 @@
> f"Process rank: {training_args.local_rank}, device: {training_args.device}, "
> + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, "
> + f"mixed-precision training: {mixed_precision}"
-302a296,298
+304a298,300
> # Max input length
> max_length = int(round(feature_extractor.sampling_rate * data_args.max_length_seconds))
>
-307a304
+309a306
>
-313c310,316
+315c312,318
< inputs = feature_extractor(subsampled_wavs, sampling_rate=feature_extractor.sampling_rate)
---
> inputs = feature_extractor(
@@ -92,7 +92,7 @@
> padding="max_length",
> truncation=True,
> )
-322c325,331
+324c327,333
< inputs = feature_extractor(wavs, sampling_rate=feature_extractor.sampling_rate)
---
> inputs = feature_extractor(
@@ -102,15 +102,15 @@
> padding="max_length",
> truncation=True,
> )
-368,369c377,378
+370,371c379,380
< # freeze the convolutional waveform encoder
< if model_args.freeze_feature_encoder:
---
> # freeze the convolutional waveform encoder if supported by model
> if hasattr(model, "freeze_feature_encoder") and model_args.freeze_feature_encoder:
-389c398
+391c400
< trainer = Trainer(
---
> trainer = GaudiTrainer(
-390a400
+392a402
> gaudi_config=gaudi_config,
diff --git a/tests/example_diff/run_clip.txt b/tests/example_diff/run_clip.txt
index 2eebcc2d7b..1099d3c94a 100644
--- a/tests/example_diff/run_clip.txt
+++ b/tests/example_diff/run_clip.txt
@@ -25,7 +25,7 @@
>
56,57c63,65
< # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-< check_min_version("4.42.0.dev0")
+< check_min_version("4.44.0.dev0")
---
> # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks.
> check_min_version("4.40.0")
@@ -55,9 +55,9 @@
> f"Process rank: {training_args.local_rank}, device: {training_args.device}, "
> + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, "
> + f"mixed-precision training: {mixed_precision}"
-419d438
+420d439
< image_transformations = torch.jit.script(image_transformations)
-466,467c485,493
+467,468c486,494
< # Transform images on the fly as doing it on the whole dataset takes too much time.
< train_dataset.set_transform(transform_images)
---
@@ -70,7 +70,7 @@
> else:
> # Transform images on the fly as doing it on the whole dataset takes too much time.
> train_dataset.set_transform(transform_images)
-489,490c515,523
+490,491c516,524
< # Transform images on the fly as doing it on the whole dataset takes too much time.
< eval_dataset.set_transform(transform_images)
---
@@ -83,7 +83,7 @@
> else:
> # Transform images on the fly as doing it on the whole dataset takes too much time.
> eval_dataset.set_transform(transform_images)
-513a547,555
+514a548,556
> if data_args.mediapipe_dataloader:
> test_dataset.image_mean = image_processor.image_mean
> test_dataset.image_std = image_processor.image_std
@@ -93,10 +93,10 @@
> else:
> # Transform images on the fly as doing it on the whole dataset takes too much time.
> test_dataset.set_transform(transform_images)
-516c558,559
+517c559,560
< trainer = Trainer(
---
> trainer_cls = HabanaDataloaderTrainer if data_args.mediapipe_dataloader else GaudiTrainer
> trainer = trainer_cls(
-517a561
+518a562
> gaudi_config=gaudi_config,
diff --git a/tests/example_diff/run_clm.txt b/tests/example_diff/run_clm.txt
index 7db8099ecf..c91df2d5cd 100644
--- a/tests/example_diff/run_clm.txt
+++ b/tests/example_diff/run_clm.txt
@@ -25,7 +25,7 @@
> from optimum.habana.utils import set_seed
57,58d52
< # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-< check_min_version("4.42.0.dev0")
+< check_min_version("4.44.0.dev0")
60c54,60
< require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
---
@@ -92,25 +92,25 @@
> f"Process rank: {training_args.local_rank}, device: {training_args.device}, "
> + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, "
> + f"mixed-precision training: {mixed_precision}"
-387a417
+390a420
> "use_cache": False if training_args.gradient_checkpointing else model_args.use_cache,
-483a514
+486a517
>
-547a579,582
+550a582,585
>
> def tensor_mapper(x):
> return {i: torch.tensor(x[i], dtype=torch.int32) for i in x}
>
-550a586,587
+553a589,590
> if training_args.resume_from_checkpoint is not None and training_args.resume_from_checkpoint != "":
> train_dataset = train_dataset.map(tensor_mapper)
-581c618
+584c621
< trainer = Trainer(
---
> trainer = GaudiTrainer(
-582a620
+585a623
> gaudi_config=gaudi_config,
-589,592c627,628
+592,595c630,631
< compute_metrics=compute_metrics if training_args.do_eval and not is_torch_xla_available() else None,
< preprocess_logits_for_metrics=preprocess_logits_for_metrics
< if training_args.do_eval and not is_torch_xla_available()
@@ -118,12 +118,12 @@
---
> compute_metrics=compute_metrics if training_args.do_eval else None,
> preprocess_logits_for_metrics=preprocess_logits_for_metrics if training_args.do_eval else None,
-603c639,640
+606c642,643
< trainer.save_model() # Saves the tokenizer too for easy upload
---
> if data_args.save_last_ckpt:
> trainer.save_model() # Saves the tokenizer too for easy upload
-607,610c644,650
+610,613c647,653
< max_train_samples = (
< data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
< )
@@ -136,9 +136,9 @@
> data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
> )
> metrics["train_samples"] = min(max_train_samples, len(train_dataset))
-619d658
+622d661
<
-622,623c661,666
+625,626c664,669
< max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
< metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
---
@@ -148,7 +148,7 @@
> )
> metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
>
-646,650d688
+649,653d691
<
<
< def _mp_fn(index):
diff --git a/tests/example_diff/run_generation.txt b/tests/example_diff/run_generation.txt
index 5da903f6e8..e1745cf95b 100644
--- a/tests/example_diff/run_generation.txt
+++ b/tests/example_diff/run_generation.txt
@@ -551,7 +551,7 @@
> parser.add_argument(
> "--trust_remote_code",
> action="store_true",
-> help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
+> help="Whether to trust the execution of code from datasets/models defined on the Hub. This option should only be set to `True` for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.",
333d289
< parser.add_argument("--jit", action="store_true", help="Whether or not to use jit trace to accelerate inference")
336,339c292,293
diff --git a/tests/example_diff/run_glue.txt b/tests/example_diff/run_glue.txt
index 78cafcf01c..46005ba396 100644
--- a/tests/example_diff/run_glue.txt
+++ b/tests/example_diff/run_glue.txt
@@ -21,7 +21,7 @@
> return ()
50,51c56,61
< # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-< check_min_version("4.42.0.dev0")
+< check_min_version("4.44.0.dev0")
---
>
> logger = logging.getLogger(__name__)
@@ -63,21 +63,21 @@
> f"Process rank: {training_args.local_rank}, device: {training_args.device}, "
> + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, "
> + f"mixed-precision training: {mixed_precision}"
-375a401
+376a402
> problem_type=data_args.problem_type,
-416a443,447
+417a444,448
> if model_args.add_pad_token:
> if not model.config.pad_token_id and not tokenizer.pad_token:
> tokenizer.pad_token = tokenizer.eos_token
> model.config.pad_token_id = tokenizer.eos_token_id
>
-527c558
+528c559
< trainer = Trainer(
---
> trainer = GaudiTrainer(
-528a560
+529a561
> gaudi_config=gaudi_config,
-628,632d659
+629,633d660
<
<
< def _mp_fn(index):
diff --git a/tests/example_diff/run_image_classification.txt b/tests/example_diff/run_image_classification.txt
index 0dbbe3f6c2..49ab2bb6a1 100644
--- a/tests/example_diff/run_image_classification.txt
+++ b/tests/example_diff/run_image_classification.txt
@@ -25,7 +25,7 @@
< """ Fine-tuning a 🤗 Transformers model for image classification"""
58,59c65,67
< # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-< check_min_version("4.42.0.dev0")
+< check_min_version("4.44.0.dev0")
---
> # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks.
> check_min_version("4.40.0")
@@ -51,9 +51,9 @@
> f"Process rank: {training_args.local_rank}, device: {training_args.device}, "
> + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, "
> + f"mixed-precision training: {mixed_precision}"
-392c409
+393c410
< trainer = Trainer(
---
> trainer = GaudiTrainer(
-393a411
+394a412
> gaudi_config=gaudi_config,
diff --git a/tests/example_diff/run_mlm.txt b/tests/example_diff/run_mlm.txt
index 372a913834..3e4f6c5863 100644
--- a/tests/example_diff/run_mlm.txt
+++ b/tests/example_diff/run_mlm.txt
@@ -20,7 +20,7 @@
> from optimum.habana.utils import set_seed
56,57d51
< # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-< check_min_version("4.42.0.dev0")
+< check_min_version("4.44.0.dev0")
59c53,59
< require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
---
@@ -75,13 +75,13 @@
> + f"mixed-precision training: {mixed_precision}"
289d305
< # Set the verbosity to info of the Transformers logger (on main process only):
-617c633
+620c636
< trainer = Trainer(
---
> trainer = GaudiTrainer(
-618a635
+621a638
> gaudi_config=gaudi_config,
-624,627c641,642
+627,630c644,645
< compute_metrics=compute_metrics if training_args.do_eval and not is_torch_xla_available() else None,
< preprocess_logits_for_metrics=preprocess_logits_for_metrics
< if training_args.do_eval and not is_torch_xla_available()
@@ -89,7 +89,7 @@
---
> compute_metrics=compute_metrics if training_args.do_eval else None,
> preprocess_logits_for_metrics=preprocess_logits_for_metrics if training_args.do_eval else None,
-641,644c656,662
+644,647c659,665
< max_train_samples = (
< data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
< )
@@ -102,9 +102,9 @@
> data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
> )
> metrics["train_samples"] = min(max_train_samples, len(train_dataset))
-653d670
+656d673
<
-656,657c673,678
+659,660c676,681
< max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
< metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
---
@@ -114,7 +114,7 @@
> )
> metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
>
-680,684d700
+683,687d703
<
<
< def _mp_fn(index):
diff --git a/tests/example_diff/run_qa.txt b/tests/example_diff/run_qa.txt
index 118add46a1..961785aaac 100644
--- a/tests/example_diff/run_qa.txt
+++ b/tests/example_diff/run_qa.txt
@@ -19,7 +19,7 @@
> from optimum.habana.utils import set_seed
52,53d50
< # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-< check_min_version("4.42.0.dev0")
+< check_min_version("4.44.0.dev0")
55c52,58
< require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
---
@@ -62,14 +62,14 @@
> f"Process rank: {training_args.local_rank}, device: {training_args.device}, "
> + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, "
> + f"mixed-precision training: {mixed_precision}"
-346a365,368
+347a366,369
> if config.model_type == "llama":
> if tokenizer.pad_token is None:
> tokenizer.add_special_tokens({"pad_token": "[PAD]"})
> tokenizer.cls_token = tokenizer.bos_token
-637a660
+638a661
> gaudi_config=gaudi_config,
-706,710d728
+707,711d729
<
<
< def _mp_fn(index):
diff --git a/tests/example_diff/run_seq2seq_qa.txt b/tests/example_diff/run_seq2seq_qa.txt
index 817c72b5a9..322661ff62 100644
--- a/tests/example_diff/run_seq2seq_qa.txt
+++ b/tests/example_diff/run_seq2seq_qa.txt
@@ -11,7 +11,7 @@
> from optimum.habana.utils import set_seed
48,49d46
< # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-< check_min_version("4.42.0.dev0")
+< check_min_version("4.44.0.dev0")
51c48,54
< require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
---
@@ -54,9 +54,9 @@
> f"Process rank: {training_args.local_rank}, device: {training_args.device}, "
> + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, "
> + f"mixed-precision training: {mixed_precision}"
-660a679
+661a680
> gaudi_config=gaudi_config,
-734,738d752
+735,739d753
<
<
< def _mp_fn(index):
diff --git a/tests/example_diff/run_speech_recognition_ctc.txt b/tests/example_diff/run_speech_recognition_ctc.txt
index 1fab0abcf2..a99ee732b3 100644
--- a/tests/example_diff/run_speech_recognition_ctc.txt
+++ b/tests/example_diff/run_speech_recognition_ctc.txt
@@ -13,7 +13,7 @@
> from optimum.habana.utils import set_seed
52,53d49
< # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-< check_min_version("4.42.0.dev0")
+< check_min_version("4.44.0.dev0")
55c51,56
< require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
---
@@ -56,13 +56,14 @@
> f"Process rank: {training_args.local_rank}, device: {training_args.device}, "
> + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, "
> + f"mixed-precision training: {mixed_precision}"
-451,457c465,470
+451,458c465,471
< if training_args.do_train:
< raw_datasets["train"] = load_dataset(
< data_args.dataset_name,
< data_args.dataset_config_name,
< split=data_args.train_split_name,
< token=data_args.token,
+< trust_remote_code=data_args.trust_remote_code,
< )
---
> raw_datasets["train"] = load_dataset(
@@ -70,8 +71,9 @@
> data_args.dataset_config_name,
> split=data_args.train_split_name,
> token=data_args.token,
+> trust_remote_code=data_args.trust_remote_code,
> )
-459,464c472,477
+460,465c473,478
< if data_args.audio_column_name not in raw_datasets["train"].column_names:
< raise ValueError(
< f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'."
@@ -85,7 +87,7 @@
> " Make sure to set `--audio_column_name` to the correct audio column - one of"
> f" {', '.join(raw_datasets['train'].column_names)}."
> )
-466,471c479,484
+467,472c480,485
< if data_args.text_column_name not in raw_datasets["train"].column_names:
< raise ValueError(
< f"--text_column_name {data_args.text_column_name} not found in dataset '{data_args.dataset_name}'. "
@@ -99,32 +101,32 @@
> "Make sure to set `--text_column_name` to the correct text column - one of "
> f"{', '.join(raw_datasets['train'].column_names)}."
> )
-473,474c486,487
+474,475c487,488
< if data_args.max_train_samples is not None:
< raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))
---
> if data_args.max_train_samples is not None:
> raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))
-492c505
+494c507
< f'[{"".join(data_args.chars_to_ignore)}]' if data_args.chars_to_ignore is not None else None
---
> f'[{"".join(data_args.chars_to_ignore).replace(" ", "")}]' if data_args.chars_to_ignore is not None else None
-631a645,649
+633a647,651
> raise RuntimeError(
> f"The dataset sampling rate ({dataset_sampling_rate}) is different from the feature extractor one"
> f" ({feature_extractor.sampling_rate}).Data resampling should be done. The Datasets library does not"
> " support it on HPUs yet."
> )
-741c759,762
+743c761,764
< processor=processor, feature_extractor_input_name=feature_extractor_input_name
---
> processor=processor,
> feature_extractor_input_name=feature_extractor_input_name,
> pad_to_multiple_of=int(max_input_length),
> pad_to_multiple_of_labels=500,
-745c766
+747c768
< trainer = Trainer(
---
> trainer = GaudiTrainer(
-746a768
+748a770
> gaudi_config=gaudi_config,
diff --git a/tests/example_diff/run_speech_recognition_seq2seq.txt b/tests/example_diff/run_speech_recognition_seq2seq.txt
index 45b00bef9b..196d356171 100644
--- a/tests/example_diff/run_speech_recognition_seq2seq.txt
+++ b/tests/example_diff/run_speech_recognition_seq2seq.txt
@@ -20,7 +20,7 @@
> return ()
>
51c58,59
-< check_min_version("4.42.0.dev0")
+< check_min_version("4.44.0.dev0")
---
> check_min_version("4.40.0")
> check_optimum_habana_min_version("1.11.0")
@@ -59,18 +59,18 @@
> f"Process rank: {training_args.local_rank}, device: {training_args.device}, "
> + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, "
> + f"mixed-precision training: {mixed_precision}"
-442d466
+444d468
< model.generation_config.forced_decoder_ids = model_args.forced_decoder_ids
-456a481,484
+458a483,486
> logger.warning(
> f"The dataset sampling rate ({dataset_sampling_rate}) is different from the feature extractor one"
> f" ({feature_extractor.sampling_rate}).Data resampling should be done."
> )
-561a590
+563a592
> label_features_max_length=data_args.label_features_max_length,
-565c594
+567c596
< trainer = Seq2SeqTrainer(
---
> trainer = GaudiSeq2SeqTrainer(
-566a596
+568a598
> gaudi_config=gaudi_config,
diff --git a/tests/example_diff/run_summarization.txt b/tests/example_diff/run_summarization.txt
index 9f01193b14..81868ab221 100644
--- a/tests/example_diff/run_summarization.txt
+++ b/tests/example_diff/run_summarization.txt
@@ -23,7 +23,7 @@
> from optimum.habana.utils import set_seed
54,55d55
< # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-< check_min_version("4.42.0.dev0")
+< check_min_version("4.44.0.dev0")
57c57,63
< require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
---
@@ -78,16 +78,16 @@
> f"Process rank: {training_args.local_rank}, device: {training_args.device}, "
> + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, "
> + f"mixed-precision training: {mixed_precision}"
-431a463
+432a464
> use_cache=False if training_args.gradient_checkpointing else model_args.use_cache,
-450a483,488
+451a484,489
> is_bart = model.config.model_type == "bart"
> if is_bart and training_args.do_train:
> raise ValueError(
> "Training is not yet supported for BART. Eval or predict can be enabled with `--do_eval` and `--do_predict`."
> )
>
-453c491,498
+454c492,499
< embedding_size = model.get_input_embeddings().weight.shape[0]
---
> embeddings = model.get_input_embeddings()
@@ -98,16 +98,16 @@
> embedding_size = embeddings.weight.shape[0]
> else:
> embedding_size = embeddings.weight.shape[0]
-486a532
+487a533
> suffix = data_args.source_suffix if data_args.source_suffix is not None else ""
-557a604,605
+558a605,606
> else:
> raise ValueError("Found case where either text or summary is missing.")
-559c607
+560c608
< inputs = [prefix + inp for inp in inputs]
---
> inputs = [prefix + inp + suffix for inp in inputs]
-574a623,662
+575a624,663
> def preprocess_bucketing_function(examples):
> # remove pairs where at least one record is None
>
@@ -148,22 +148,22 @@
> model_inputs["labels"] = labels["input_ids"]
> return model_inputs
>
-589a678,683
+590a679,684
> def wrapper_preprocess_function(examples):
> if model.config.is_encoder_decoder:
> return preprocess_bucketing_function(examples)
> else:
> return preprocess_function(examples)
>
-598c692
+599c693
< preprocess_function,
---
> wrapper_preprocess_function,
-614c708
+615c709
< preprocess_function,
---
> wrapper_preprocess_function,
-624,629c718,726
+625,630c719,727
< data_collator = DataCollatorForSeq2Seq(
< tokenizer,
< model=model,
@@ -180,7 +180,7 @@
> label_pad_token_id=label_pad_token_id,
> pad_to_multiple_of=8 if training_args.fp16 else None,
> )
-664,671c761,769
+665,672c762,773
< training_args.generation_max_length = (
< training_args.generation_max_length
< if training_args.generation_max_length is not None
@@ -196,16 +196,19 @@
> else:
> training_args.generation_config.max_length = data_args.val_max_target_length
> if data_args.num_beams is not None:
+> if data_args.num_beams == 1:
+> training_args.generation_config.length_penalty = None
+> training_args.generation_config.early_stopping = False
> training_args.generation_config.num_beams = data_args.num_beams
> elif training_args.generation_num_beams is not None:
> training_args.generation_config.num_beams = training_args.generation_num_beams
-674c772
+675c776
< trainer = Seq2SeqTrainer(
---
> trainer = GaudiSeq2SeqTrainer(
-675a774
+676a778
> gaudi_config=gaudi_config,
-764,768d862
+765,769d866
<
<
< def _mp_fn(index):
diff --git a/tests/example_diff/run_translation.txt b/tests/example_diff/run_translation.txt
index 1aa504c06f..e7038d847c 100644
--- a/tests/example_diff/run_translation.txt
+++ b/tests/example_diff/run_translation.txt
@@ -15,7 +15,7 @@
> from optimum.habana.utils import set_seed
54,55d52
< # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-< check_min_version("4.42.0.dev0")
+< check_min_version("4.44.0.dev0")
57c54,60
< require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")
---
@@ -79,19 +79,19 @@
> f"Process rank: {training_args.local_rank}, device: {training_args.device}, "
> + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, "
> + f"mixed-precision training: {mixed_precision}"
-384a419
+385a420
> use_cache=False if training_args.gradient_checkpointing else model_args.use_cache,
-456c491
+457c492
< # Check the whether the source target length fits in the model, if it has absolute positional embeddings
---
> # Check whether the source target length fits in the model, if it has absolute positional embeddings
-594c629
+595c630
< trainer = Seq2SeqTrainer(
---
> trainer = GaudiSeq2SeqTrainer(
-595a631
+596a632
> gaudi_config=gaudi_config,
-688,692d723
+689,693d724
<
<
< def _mp_fn(index):
diff --git a/tests/test_diffusers.py b/tests/test_diffusers.py
index 89a101fdf7..0c0bae76d9 100755
--- a/tests/test_diffusers.py
+++ b/tests/test_diffusers.py
@@ -14,17 +14,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import contextlib
+import copy
+import gc
+import inspect
import json
import os
import random
import re
import subprocess
import tempfile
-from io import BytesIO
+from io import BytesIO, StringIO
from pathlib import Path
-from typing import Union
-from unittest import TestCase, skipUnless
+from typing import Callable, Union
+from unittest import TestCase, skipIf, skipUnless
+import diffusers
import numpy as np
import requests
import safetensors
@@ -32,14 +37,29 @@
from diffusers import (
AutoencoderKL,
AutoencoderKLTemporalDecoder,
+ AutoencoderTiny,
ControlNetModel,
+ DiffusionPipeline,
+ DPMSolverMultistepScheduler,
+ EulerDiscreteScheduler,
+ LCMScheduler,
+ PNDMScheduler,
UNet2DConditionModel,
UNetSpatioTemporalConditionModel,
UniPCMultistepScheduler,
)
+from diffusers.image_processor import VaeImageProcessor
from diffusers.pipelines.controlnet.pipeline_controlnet import MultiControlNetModel
-from diffusers.utils import load_image, numpy_to_pil
-from diffusers.utils.testing_utils import floats_tensor
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import logging, numpy_to_pil
+from diffusers.utils.import_utils import is_accelerate_available, is_accelerate_version, is_xformers_available
+from diffusers.utils.testing_utils import (
+ enable_full_determinism,
+ floats_tensor,
+ load_image,
+ load_numpy,
+ require_torch,
+)
from diffusers.utils.torch_utils import randn_tensor
from huggingface_hub import snapshot_download
from parameterized import parameterized
@@ -62,9 +82,14 @@
GaudiEulerAncestralDiscreteScheduler,
GaudiEulerDiscreteScheduler,
GaudiStableDiffusionControlNetPipeline,
+ GaudiStableDiffusionImageVariationPipeline,
+ GaudiStableDiffusionInpaintPipeline,
+ GaudiStableDiffusionInstructPix2PixPipeline,
GaudiStableDiffusionLDM3DPipeline,
GaudiStableDiffusionPipeline,
GaudiStableDiffusionUpscalePipeline,
+ GaudiStableDiffusionXLImg2ImgPipeline,
+ GaudiStableDiffusionXLInpaintPipeline,
GaudiStableDiffusionXLPipeline,
GaudiStableVideoDiffusionPipeline,
)
@@ -83,6 +108,9 @@
TEXTUAL_INVERSION_RUNTIME = 114.1344320399221
CONTROLNET_THROUGHPUT = 92.886919836857
CONTROLNET_RUNTIME = 537.4276602957398
+ INPAINT_THROUGHPUT_BASELINE_BF16 = 4.584
+ INPAINT_XL_THROUGHPUT_BASELINE_BF16 = 1.151
+ DETERMINISTIC_IMAGE_GENERATION_THROUGHPUT = 0.946
else:
THROUGHPUT_BASELINE_BF16 = 0.309
THROUGHPUT_BASELINE_AUTOCAST = 0.114
@@ -90,7 +118,9 @@
TEXTUAL_INVERSION_RUNTIME = 196.43840550999994
CONTROLNET_THROUGHPUT = 44.7278034963213
CONTROLNET_RUNTIME = 1116.084316640001
-
+ INPAINT_THROUGHPUT_BASELINE_BF16 = 1.42
+ INPAINT_XL_THROUGHPUT_BASELINE_BF16 = 0.271
+ DETERMINISTIC_IMAGE_GENERATION_THROUGHPUT = 0.302
_run_custom_bf16_ops_test_ = parse_flag_from_env("CUSTOM_BF16_OPS", default=False)
@@ -2211,3 +2241,2836 @@ def test_stable_video_diffusion_no_throughput_regression_bf16(self):
self.assertEqual(len(outputs.frames[0]), 25)
if IS_GAUDI2:
self.assertGreaterEqual(outputs.throughput, 0.95 * 0.012)
+
+
+class GaudiStableDiffusionInstructPix2PixPipelineTests(TestCase):
+ """
+ Tests the class StableDiffusionInstructPix2PixPipeline for Gaudi.
+ Adapted from: https://github.com/huggingface/diffusers/blob/v0.26.3/tests/pipelines/stable_diffusion/test_stable_diffusion_instruction_pix2pix.py
+ """
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ unet = UNet2DConditionModel(
+ block_out_channels=(32, 64),
+ layers_per_block=2,
+ sample_size=32,
+ in_channels=8,
+ out_channels=4,
+ down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
+ up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
+ cross_attention_dim=32,
+ )
+ scheduler = GaudiDDIMScheduler(
+ beta_start=0.00085,
+ beta_end=0.012,
+ beta_schedule="scaled_linear",
+ clip_sample=False,
+ set_alpha_to_one=False,
+ )
+ torch.manual_seed(0)
+ vae = AutoencoderKL(
+ block_out_channels=[32, 64],
+ in_channels=3,
+ out_channels=3,
+ down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
+ up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
+ latent_channels=4,
+ )
+ torch.manual_seed(0)
+ text_encoder_config = CLIPTextConfig(
+ bos_token_id=0,
+ eos_token_id=2,
+ hidden_size=32,
+ intermediate_size=37,
+ layer_norm_eps=1e-05,
+ num_attention_heads=4,
+ num_hidden_layers=5,
+ pad_token_id=1,
+ vocab_size=1000,
+ )
+ text_encoder = CLIPTextModel(text_encoder_config)
+ tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+ components = {
+ "unet": unet,
+ "scheduler": scheduler,
+ "vae": vae,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "safety_checker": None,
+ "feature_extractor": None,
+ "image_encoder": None,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
+ image = image.cpu().permute(0, 2, 3, 1)[0]
+ image = Image.fromarray(np.uint8(image)).convert("RGB")
+ generator = torch.Generator(device=device).manual_seed(seed)
+ inputs = {
+ "prompt": "A painting of a squirrel eating a burger",
+ "image": image,
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 6.0,
+ "image_guidance_scale": 1,
+ "output_type": "np",
+ }
+ return inputs
+
+ def test_stable_diffusion_pix2pix_default_case(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ components = self.get_dummy_components()
+ sd_pipe = GaudiStableDiffusionInstructPix2PixPipeline(
+ **components,
+ use_habana=True,
+ use_hpu_graphs=True,
+ gaudi_config=GaudiConfig(use_torch_autocast=False),
+ )
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = sd_pipe(**inputs).images
+ image_slice = image[0, -3:, -3:, -1]
+ self.assertEqual(image.shape, (1, 32, 32, 3))
+ expected_slice = np.array([0.7526, 0.3750, 0.4547, 0.6117, 0.5866, 0.5016, 0.4327, 0.5642, 0.4815])
+
+ self.assertLess(np.abs(image_slice.flatten() - expected_slice).max(), 1e-1)
+
+ def test_stable_diffusion_pix2pix_negative_prompt(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ components = self.get_dummy_components()
+ sd_pipe = GaudiStableDiffusionInstructPix2PixPipeline(
+ **components,
+ use_habana=True,
+ use_hpu_graphs=True,
+ gaudi_config=GaudiConfig(use_torch_autocast=False),
+ )
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ negative_prompt = "french fries"
+ output = sd_pipe(**inputs, negative_prompt=negative_prompt)
+ image = output.images
+ image_slice = image[0, -3:, -3:, -1]
+
+ self.assertEqual(image.shape, (1, 32, 32, 3))
+ expected_slice = np.array([0.7511, 0.3642, 0.4553, 0.6236, 0.5797, 0.5013, 0.4343, 0.5611, 0.4831])
+
+ self.assertLess(np.abs(image_slice.flatten() - expected_slice).max(), 1e-1)
+
+ def test_stable_diffusion_pix2pix_multiple_init_images(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ components = self.get_dummy_components()
+ sd_pipe = GaudiStableDiffusionInstructPix2PixPipeline(
+ **components,
+ use_habana=True,
+ use_hpu_graphs=True,
+ gaudi_config=GaudiConfig(use_torch_autocast=False),
+ )
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ inputs["prompt"] = [inputs["prompt"]] * 2
+
+ image = np.array(inputs["image"]).astype(np.float32) / 255.0
+ image = torch.from_numpy(image).unsqueeze(0).to(device)
+ image = image / 2 + 0.5
+ image = image.permute(0, 3, 1, 2)
+ inputs["image"] = image.repeat(2, 1, 1, 1)
+
+ image = sd_pipe(**inputs).images
+ image_slice = image[-1, -3:, -3:, -1]
+
+ self.assertEqual(image.shape, (2, 32, 32, 3))
+ expected_slice = np.array([0.5812, 0.5748, 0.5222, 0.5908, 0.5695, 0.7174, 0.6804, 0.5523, 0.5579])
+
+ self.assertLess(np.abs(image_slice.flatten() - expected_slice).max(), 1e-1)
+
+ def test_stable_diffusion_pix2pix_euler(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ components = self.get_dummy_components()
+ components["scheduler"] = GaudiEulerAncestralDiscreteScheduler(
+ beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
+ )
+ sd_pipe = GaudiStableDiffusionInstructPix2PixPipeline(
+ **components,
+ use_habana=True,
+ use_hpu_graphs=True,
+ gaudi_config=GaudiConfig(use_torch_autocast=False),
+ )
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = sd_pipe(**inputs).images
+ image_slice = image[0, -3:, -3:, -1]
+
+ self.assertEqual(image.shape, (1, 32, 32, 3))
+ expected_slice = np.array([0.7417, 0.3842, 0.4732, 0.5776, 0.5891, 0.5139, 0.4052, 0.5673, 0.4986])
+
+ self.assertLess(np.abs(image_slice.flatten() - expected_slice).max(), 1e-3)
+
+
+class GaudiStableDiffusionImageVariationPipelineTests(TestCase):
+ """
+ Tests the class StableDiffusionImageVariationPipeline for Gaudi.
+ Adapted from: https://github.com/huggingface/diffusers/blob/v0.26.3/tests/pipelines/stable_diffusion_image_variation/test_stable_diffusion_image_variation.py
+ """
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ unet = UNet2DConditionModel(
+ block_out_channels=(32, 64),
+ layers_per_block=2,
+ sample_size=32,
+ in_channels=4,
+ out_channels=4,
+ down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
+ up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
+ cross_attention_dim=32,
+ )
+ scheduler = GaudiDDIMScheduler(
+ beta_start=0.00085,
+ beta_end=0.012,
+ beta_schedule="scaled_linear",
+ clip_sample=False,
+ set_alpha_to_one=False,
+ )
+ torch.manual_seed(0)
+ vae = AutoencoderKL(
+ block_out_channels=[32, 64],
+ in_channels=3,
+ out_channels=3,
+ down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
+ up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
+ latent_channels=4,
+ )
+ torch.manual_seed(0)
+ image_encoder_config = CLIPVisionConfig(
+ hidden_size=32,
+ projection_dim=32,
+ intermediate_size=37,
+ layer_norm_eps=1e-05,
+ num_attention_heads=4,
+ num_hidden_layers=5,
+ image_size=32,
+ patch_size=4,
+ )
+ image_encoder = CLIPVisionModelWithProjection(image_encoder_config)
+ feature_extractor = CLIPImageProcessor(crop_size=32, size=32)
+
+ components = {
+ "unet": unet,
+ "scheduler": scheduler,
+ "vae": vae,
+ "image_encoder": image_encoder,
+ "feature_extractor": feature_extractor,
+ "safety_checker": None,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed))
+ image = image.cpu().permute(0, 2, 3, 1)[0]
+ image = Image.fromarray(np.uint8(image)).convert("RGB").resize((32, 32))
+ generator = torch.Generator(device=device).manual_seed(seed)
+ inputs = {
+ "image": image,
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 6.0,
+ "output_type": "numpy",
+ }
+ return inputs
+
+ def test_stable_diffusion_img_variation_default_case(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ components = self.get_dummy_components()
+ sd_pipe = GaudiStableDiffusionImageVariationPipeline(
+ **components,
+ use_habana=True,
+ use_hpu_graphs=True,
+ gaudi_config=GaudiConfig(use_torch_autocast=False),
+ )
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = sd_pipe(**inputs).images
+ image_slice = image[0, -3:, -3:, -1]
+
+ self.assertEqual(image.shape, (1, 64, 64, 3))
+ expected_slice = np.array([0.5239, 0.5723, 0.4796, 0.5049, 0.5550, 0.4685, 0.5329, 0.4891, 0.4921])
+ self.assertLess(np.abs(image_slice.flatten() - expected_slice).max(), 1e-1)
+
+ def test_stable_diffusion_img_variation_multiple_images(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ components = self.get_dummy_components()
+ sd_pipe = GaudiStableDiffusionImageVariationPipeline(
+ **components,
+ use_habana=True,
+ use_hpu_graphs=True,
+ gaudi_config=GaudiConfig(use_torch_autocast=False),
+ )
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ inputs["image"] = 2 * [inputs["image"]]
+ output = sd_pipe(**inputs)
+
+ image = output.images
+
+ image_slice = image[-1, -3:, -3:, -1]
+
+ self.assertEqual(image.shape, (2, 64, 64, 3))
+ expected_slice = np.array([0.6892, 0.5637, 0.5836, 0.5771, 0.6254, 0.6409, 0.5580, 0.5569, 0.5289])
+
+ self.assertLess(np.abs(image_slice.flatten() - expected_slice).max(), 1e-1)
+
+
+class GaudiStableDiffusionXLImg2ImgPipelineTests(TestCase):
+ def get_dummy_components(self, skip_first_text_encoder=False, time_cond_proj_dim=None):
+ torch.manual_seed(0)
+ unet = UNet2DConditionModel(
+ block_out_channels=(32, 64),
+ layers_per_block=2,
+ sample_size=32,
+ in_channels=4,
+ out_channels=4,
+ time_cond_proj_dim=time_cond_proj_dim,
+ down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
+ up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
+ # SD2-specific config below
+ attention_head_dim=(2, 4),
+ use_linear_projection=True,
+ addition_embed_type="text_time",
+ addition_time_embed_dim=8,
+ transformer_layers_per_block=(1, 2),
+ projection_class_embeddings_input_dim=72, # 5 * 8 + 32
+ cross_attention_dim=64 if not skip_first_text_encoder else 32,
+ )
+ scheduler = GaudiEulerDiscreteScheduler(
+ beta_start=0.00085,
+ beta_end=0.012,
+ steps_offset=1,
+ beta_schedule="scaled_linear",
+ timestep_spacing="leading",
+ )
+ torch.manual_seed(0)
+ vae = AutoencoderKL(
+ block_out_channels=[32, 64],
+ in_channels=3,
+ out_channels=3,
+ down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
+ up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
+ latent_channels=4,
+ sample_size=128,
+ )
+ torch.manual_seed(0)
+ image_encoder_config = CLIPVisionConfig(
+ hidden_size=32,
+ image_size=224,
+ projection_dim=32,
+ intermediate_size=37,
+ num_attention_heads=4,
+ num_channels=3,
+ num_hidden_layers=5,
+ patch_size=14,
+ )
+
+ image_encoder = CLIPVisionModelWithProjection(image_encoder_config)
+
+ feature_extractor = CLIPImageProcessor(
+ crop_size=224,
+ do_center_crop=True,
+ do_normalize=True,
+ do_resize=True,
+ image_mean=[0.48145466, 0.4578275, 0.40821073],
+ image_std=[0.26862954, 0.26130258, 0.27577711],
+ resample=3,
+ size=224,
+ )
+
+ torch.manual_seed(0)
+ text_encoder_config = CLIPTextConfig(
+ bos_token_id=0,
+ eos_token_id=2,
+ hidden_size=32,
+ intermediate_size=37,
+ layer_norm_eps=1e-05,
+ num_attention_heads=4,
+ num_hidden_layers=5,
+ pad_token_id=1,
+ vocab_size=1000,
+ # SD2-specific config below
+ hidden_act="gelu",
+ projection_dim=32,
+ )
+ text_encoder = CLIPTextModel(text_encoder_config)
+ tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+ text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config)
+ tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+ components = {
+ "unet": unet,
+ "scheduler": scheduler,
+ "vae": vae,
+ "text_encoder": text_encoder if not skip_first_text_encoder else None,
+ "tokenizer": tokenizer if not skip_first_text_encoder else None,
+ "text_encoder_2": text_encoder_2,
+ "tokenizer_2": tokenizer_2,
+ "requires_aesthetics_score": True,
+ "image_encoder": image_encoder,
+ "feature_extractor": feature_extractor,
+ }
+ return components
+
+ def get_dummy_tiny_autoencoder(self):
+ return AutoencoderTiny(in_channels=3, out_channels=3, latent_channels=4)
+
+ def test_components_function(self):
+ init_components = self.get_dummy_components()
+ init_components.pop("requires_aesthetics_score")
+ pipe = GaudiStableDiffusionXLImg2ImgPipeline(
+ **init_components,
+ use_habana=True,
+ use_hpu_graphs=True,
+ gaudi_config=GaudiConfig(use_torch_autocast=False),
+ )
+
+ self.assertTrue(hasattr(pipe, "components"))
+ self.assertTrue(set(pipe.components.keys()) == set(init_components.keys()))
+
+ def get_dummy_inputs(self, device, seed=0):
+ image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
+ image = image / 2 + 0.5
+ generator = torch.Generator(device=device).manual_seed(seed)
+ inputs = {
+ "prompt": "A painting of a squirrel eating a burger",
+ "image": image,
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 5.0,
+ "output_type": "np",
+ "strength": 0.8,
+ }
+ return inputs
+
+ def test_stable_diffusion_xl_img2img_euler(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ components = self.get_dummy_components()
+ sd_pipe = GaudiStableDiffusionXLImg2ImgPipeline(
+ **components,
+ use_habana=True,
+ use_hpu_graphs=True,
+ gaudi_config=GaudiConfig(use_torch_autocast=False),
+ )
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = sd_pipe(**inputs).images
+ image_slice = image[0, -3:, -3:, -1]
+
+ self.assertEqual(image.shape, (1, 32, 32, 3))
+
+ expected_slice = np.array([0.4664, 0.4886, 0.4403, 0.6902, 0.5592, 0.4534, 0.5931, 0.5951, 0.5224])
+ self.assertLess(np.abs(image_slice.flatten() - expected_slice).max(), 1e-2)
+
+
+class GaudiDeterministicImageGenerationTester(TestCase):
+ """
+ Test deterministic generation using text_to_image_generation.py.
+ """
+
+ @slow
+ def test_deterministic_image_generation(self):
+ path_to_script = (
+ Path(os.path.dirname(__file__)).parent / "examples" / "stable-diffusion" / "text_to_image_generation.py"
+ )
+ install_requirements(path_to_script.parent / "requirements.txt")
+
+ with tempfile.TemporaryDirectory():
+ test_args = f"""
+ python3
+ {path_to_script}
+ --model_name_or_path runwayml/stable-diffusion-v1-5
+ --num_images_per_prompt 20
+ --batch_size 4
+ --image_save_dir /tmp/stable_diffusion_images
+ --use_habana
+ --use_hpu_graphs
+ --gaudi_config Habana/stable-diffusion
+ --bf16
+ --use_cpu_rng
+ """.split()
+ test_args.append("--prompts")
+ test_args.append("An image of a squirrel in Picasso style")
+ p = subprocess.Popen(test_args)
+ return_code = p.wait()
+
+ # Ensure the run finished without any issue
+ self.assertEqual(return_code, 0)
+
+ @slow
+ def test_deterministic_image_generation_no_throughput_regression_bf16(self):
+ kwargs = {"timestep_spacing": "linspace"}
+ scheduler = GaudiDDIMScheduler.from_pretrained(
+ "runwayml/stable-diffusion-v1-5", **kwargs, subfolder="scheduler"
+ )
+
+ kwargs = {
+ "scheduler": scheduler,
+ "use_habana": True,
+ "use_hpu_graphs": True,
+ "gaudi_config": "Habana/stable-diffusion",
+ }
+
+ pipeline = GaudiStableDiffusionPipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5",
+ **kwargs,
+ )
+
+ num_images_per_prompt = 20
+ res = {}
+ generator = [set_seed(27) for i in range(num_images_per_prompt)]
+ outputs = pipeline(
+ prompt="An image of a squirrel in Picasso style",
+ num_images_per_prompt=num_images_per_prompt,
+ batch_size=4,
+ num_inference_steps=50,
+ guidance_scale=7.5,
+ negative_prompt=None,
+ eta=0.0,
+ output_type="pil",
+ generator=generator,
+ **res,
+ )
+
+ self.assertGreaterEqual(outputs.throughput, 0.95 * DETERMINISTIC_IMAGE_GENERATION_THROUGHPUT)
+
+
+"""
+Copied from: https://github.com/huggingface/diffusers/blob/v0.26.3/tests/pipelines/test_pipelines_common.py
+- Remove PipelinePushToHubTester testcase.
+- Remove test_multi_vae testcase.
+- Remove test_save_load_local.
+- Remove test_save_load_optional_components.
+- Modified the get_dummy_components to add the Gaudi pipeline parameters: use_habana, use_hpu_graphs, gaudi_config, bf16_full_eval
+"""
+
+
+torch_device = "hpu"
+
+
+def to_np(tensor):
+ if isinstance(tensor, torch.Tensor):
+ tensor = tensor.detach().cpu().numpy()
+
+ return tensor
+
+
+def check_same_shape(tensor_list):
+ shapes = [tensor.shape for tensor in tensor_list]
+ return all(shape == shapes[0] for shape in shapes[1:])
+
+
+class PipelineLatentTesterMixin:
+ """
+ This mixin is designed to be used with PipelineTesterMixin and unittest.TestCase classes.
+ It provides a set of common tests for PyTorch pipeline that has vae, e.g.
+ equivalence of different input and output types, etc.
+ """
+
+ @property
+ def image_params(self) -> frozenset:
+ raise NotImplementedError(
+ "You need to set the attribute `image_params` in the child test class. "
+ "`image_params` are tested for if all accepted input image types (i.e. `pt`,`pil`,`np`) are producing same results"
+ )
+
+ @property
+ def image_latents_params(self) -> frozenset:
+ raise NotImplementedError(
+ "You need to set the attribute `image_latents_params` in the child test class. "
+ "`image_latents_params` are tested for if passing latents directly are producing same results"
+ )
+
+ def get_dummy_inputs_by_type(self, device, seed=0, input_image_type="pt", output_type="np"):
+ inputs = self.get_dummy_inputs(device, seed)
+
+ def convert_to_pt(image):
+ if isinstance(image, torch.Tensor):
+ input_image = image
+ elif isinstance(image, np.ndarray):
+ input_image = VaeImageProcessor.numpy_to_pt(image)
+ elif isinstance(image, Image.Image):
+ input_image = VaeImageProcessor.pil_to_numpy(image)
+ input_image = VaeImageProcessor.numpy_to_pt(input_image)
+ else:
+ raise ValueError(f"unsupported input_image_type {type(image)}")
+ return input_image
+
+ def convert_pt_to_type(image, input_image_type):
+ if input_image_type == "pt":
+ input_image = image
+ elif input_image_type == "np":
+ input_image = VaeImageProcessor.pt_to_numpy(image)
+ elif input_image_type == "pil":
+ input_image = VaeImageProcessor.pt_to_numpy(image)
+ input_image = VaeImageProcessor.numpy_to_pil(input_image)
+ else:
+ raise ValueError(f"unsupported input_image_type {input_image_type}.")
+ return input_image
+
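+        # Normalize each accepted image input to a torch tensor, then convert it to the requested input type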
+ for image_param in self.image_params:
+ if image_param in inputs.keys():
+ inputs[image_param] = convert_pt_to_type(
+ convert_to_pt(inputs[image_param]).to(device), input_image_type
+ )
+
+ inputs["output_type"] = output_type
+
+ return inputs
+
+ def test_pt_np_pil_outputs_equivalent(self, expected_max_diff=1e-4):
+ self._test_pt_np_pil_outputs_equivalent(expected_max_diff=expected_max_diff)
+
+ def _test_pt_np_pil_outputs_equivalent(self, expected_max_diff=1e-4, input_image_type="pt"):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.set_progress_bar_config(disable=None)
+
+ output_pt = pipe(
+ **self.get_dummy_inputs_by_type(torch_device, input_image_type=input_image_type, output_type="pt")
+ )[0]
+ output_np = pipe(
+ **self.get_dummy_inputs_by_type(torch_device, input_image_type=input_image_type, output_type="np")
+ )[0]
+ output_pil = pipe(
+ **self.get_dummy_inputs_by_type(torch_device, input_image_type=input_image_type, output_type="pil")
+ )[0]
+
+ max_diff = np.abs(output_pt.cpu().numpy().transpose(0, 2, 3, 1) - output_np).max()
+ self.assertLess(
+            max_diff, expected_max_diff, "`output_type=='pt'` generates different results from `output_type=='np'`"
+ )
+
+ max_diff = np.abs(np.array(output_pil[0]) - (output_np * 255).round()).max()
+        self.assertLess(max_diff, 2.0, "`output_type=='pil'` generates different results from `output_type=='np'`")
+
+ def test_pt_np_pil_inputs_equivalent(self):
+ if len(self.image_params) == 0:
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.set_progress_bar_config(disable=None)
+
+ out_input_pt = pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="pt"))[0]
+ out_input_np = pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="np"))[0]
+ out_input_pil = pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="pil"))[0]
+
+ max_diff = np.abs(out_input_pt - out_input_np).max()
+        self.assertLess(max_diff, 1e-4, "`input_type=='pt'` generates different results from `input_type=='np'`")
+        max_diff = np.abs(out_input_pil - out_input_np).max()
+        self.assertLess(max_diff, 1e-2, "`input_type=='pil'` generates different results from `input_type=='np'`")
+
+ def test_latents_input(self):
+ if len(self.image_latents_params) == 0:
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.image_processor = VaeImageProcessor(do_resize=False, do_normalize=False)
+ pipe.set_progress_bar_config(disable=None)
+
+ out = pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="pt"))[0]
+
+ vae = components["vae"]
+ inputs = self.get_dummy_inputs_by_type(torch_device, input_image_type="pt")
+ generator = inputs["generator"]
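+        # Replace each image input with its VAE-encoded latents (scaled by the VAE scaling factor)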
+ for image_param in self.image_latents_params:
+ if image_param in inputs.keys():
+ inputs[image_param] = (
+ vae.encode(inputs[image_param]).latent_dist.sample(generator) * vae.config.scaling_factor
+ )
+ out_latents_inputs = pipe(**inputs)[0]
+
+ max_diff = np.abs(out - out_latents_inputs).max()
+        self.assertLess(max_diff, 1e-4, "passing latents as the image input generates different results from passing an image")
+
+
+@require_torch
+class PipelineKarrasSchedulerTesterMixin:
+ """
+ This mixin is designed to be used with unittest.TestCase classes.
+    It provides a set of common tests for each PyTorch pipeline that makes use of KarrasDiffusionSchedulers,
+    e.g. checking that every scheduler produces outputs of the same shape.
+ """
+
+ def test_karras_schedulers_shape(self):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+
+ # make sure that PNDM does not need warm-up
+ pipe.scheduler.register_to_config(skip_prk_steps=True)
+
+ pipe.to("cpu")
+ pipe.set_progress_bar_config(disable=None)
+ inputs = self.get_dummy_inputs(torch_device)
+ inputs["num_inference_steps"] = 2
+
+ if "strength" in inputs:
+ inputs["num_inference_steps"] = 4
+ inputs["strength"] = 0.5
+
+ outputs = []
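+        # Run the pipeline once per Karras scheduler; every run should produce outputs of the same shape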
+ for scheduler_enum in KarrasDiffusionSchedulers:
+ if "KDPM2" in scheduler_enum.name:
+ inputs["num_inference_steps"] = 5
+
+ scheduler_cls = getattr(diffusers, scheduler_enum.name)
+ pipe.scheduler = scheduler_cls.from_config(pipe.scheduler.config)
+ output = pipe(**inputs)[0]
+ outputs.append(output)
+
+ if "KDPM2" in scheduler_enum.name:
+ inputs["num_inference_steps"] = 2
+
+ assert check_same_shape(outputs)
+
+
+@require_torch
+class PipelineTesterMixin:
+ """
+ This mixin is designed to be used with unittest.TestCase classes.
+ It provides a set of common tests for each PyTorch pipeline, e.g. saving and loading the pipeline,
+ equivalence of dict and tuple outputs, etc.
+ """
+
+ # Canonical parameters that are passed to `__call__` regardless
+ # of the type of pipeline. They are always optional and have common
+ # sense default values.
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "num_images_per_prompt",
+ "generator",
+ "latents",
+ "output_type",
+ "return_dict",
+ ]
+ )
+
+ # set these parameters to False in the child class if the pipeline does not support the corresponding functionality
+ test_attention_slicing = True
+
+ test_xformers_attention = True
+
+ def get_generator(self, seed):
+ device = "cpu"
+ generator = torch.Generator(device).manual_seed(seed)
+ return generator
+
+ @property
+ def pipeline_class(self) -> Union[Callable, DiffusionPipeline]:
+ raise NotImplementedError(
+ "You need to set the attribute `pipeline_class = ClassNameOfPipeline` in the child test class. "
+ "See existing pipeline tests for reference."
+ )
+
+ def get_dummy_components(self):
+ raise NotImplementedError(
+ "You need to implement `get_dummy_components(self)` in the child test class. "
+ "See existing pipeline tests for reference."
+ )
+
+ def get_dummy_inputs(self, device, seed=0):
+ raise NotImplementedError(
+ "You need to implement `get_dummy_inputs(self, device, seed)` in the child test class. "
+ "See existing pipeline tests for reference."
+ )
+
+ @property
+ def params(self) -> frozenset:
+ raise NotImplementedError(
+ "You need to set the attribute `params` in the child test class. "
+ "`params` are checked for if all values are present in `__call__`'s signature."
+ " You can set `params` using one of the common set of parameters defined in `pipeline_params.py`"
+ " e.g., `TEXT_TO_IMAGE_PARAMS` defines the common parameters used in text to "
+ "image pipelines, including prompts and prompt embedding overrides."
+ "If your pipeline's set of arguments has minor changes from one of the common sets of arguments, "
+ "do not make modifications to the existing common sets of arguments. I.e. a text to image pipeline "
+ "with non-configurable height and width arguments should set the attribute as "
+ "`params = TEXT_TO_IMAGE_PARAMS - {'height', 'width'}`. "
+ "See existing pipeline tests for reference."
+ )
+
+ @property
+ def batch_params(self) -> frozenset:
+ raise NotImplementedError(
+ "You need to set the attribute `batch_params` in the child test class. "
+ "`batch_params` are the parameters required to be batched when passed to the pipeline's "
+ "`__call__` method. `pipeline_params.py` provides some common sets of parameters such as "
+ "`TEXT_TO_IMAGE_BATCH_PARAMS`, `IMAGE_VARIATION_BATCH_PARAMS`, etc... If your pipeline's "
+ "set of batch arguments has minor changes from one of the common sets of batch arguments, "
+ "do not make modifications to the existing common sets of batch arguments. I.e. a text to "
+            "image pipeline whose `negative_prompt` is not batched should set the attribute as "
+ "`batch_params = TEXT_TO_IMAGE_BATCH_PARAMS - {'negative_prompt'}`. "
+ "See existing pipeline tests for reference."
+ )
+
+ @property
+ def callback_cfg_params(self) -> frozenset:
+ raise NotImplementedError(
+            "You need to set the attribute `callback_cfg_params` in the child test class that needs to run test_callback_cfg. "
+            "`callback_cfg_params` are the parameters that need to be passed to the pipeline's callback "
+            "function when dynamically adjusting `guidance_scale`. They are variables that require special "
+            "treatment when `do_classifier_free_guidance` is `True`. `pipeline_params.py` provides some common "
+            "sets of parameters such as `TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS`. If your pipeline's "
+            "set of cfg arguments has minor changes from one of the common sets of cfg arguments, "
+            "do not make modifications to the existing common sets of cfg arguments. I.e. for an inpainting pipeline, "
+            "you need to adjust the batch size of `mask` and `masked_image_latents`, so you should set the attribute as "
+            "`callback_cfg_params = TEXT_TO_IMAGE_CFG_PARAMS.union({'mask', 'masked_image_latents'})`"
+ )
+
+ def tearDown(self):
+ # clean up the VRAM after each test in case of CUDA runtime errors
+ super().tearDown()
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def test_pipeline_call_signature(self):
+ self.assertTrue(
+ hasattr(self.pipeline_class, "__call__"), f"{self.pipeline_class} should have a `__call__` method"
+ )
+
+ parameters = inspect.signature(self.pipeline_class.__call__).parameters
+
+ optional_parameters = set()
+
+ for k, v in parameters.items():
+ if v.default != inspect._empty:
+ optional_parameters.add(k)
+
+ parameters = set(parameters.keys())
+ parameters.remove("self")
+ parameters.discard("kwargs") # kwargs can be added if arguments of pipeline call function are deprecated
+
+ remaining_required_parameters = set()
+
+ for param in self.params:
+ if param not in parameters:
+ remaining_required_parameters.add(param)
+
+ self.assertTrue(
+ len(remaining_required_parameters) == 0,
+ f"Required parameters not present: {remaining_required_parameters}",
+ )
+
+ remaining_required_optional_parameters = set()
+
+ for param in self.required_optional_params:
+ if param not in optional_parameters:
+ remaining_required_optional_parameters.add(param)
+
+ self.assertTrue(
+ len(remaining_required_optional_parameters) == 0,
+ f"Required optional parameters not present: {remaining_required_optional_parameters}",
+ )
+
+ def test_inference_batch_consistent(self, batch_sizes=[2]):
+ self._test_inference_batch_consistent(batch_sizes=batch_sizes)
+
+ def _test_inference_batch_consistent(
+ self, batch_sizes=[2], additional_params_copy_to_batched_inputs=["num_inference_steps"], batch_generator=True
+ ):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ inputs["generator"] = self.get_generator(0)
+
+ logger = logging.get_logger(pipe.__module__)
+ logger.setLevel(level=diffusers.logging.FATAL)
+
+ # prepare batched inputs
+ batched_inputs = []
+ for batch_size in batch_sizes:
+ batched_input = {}
+ batched_input.update(inputs)
+
+ for name in self.batch_params:
+ if name not in inputs:
+ continue
+
+ value = inputs[name]
+ if name == "prompt":
+ len_prompt = len(value)
+ # make unequal batch sizes
+ batched_input[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)]
+
+ # make last batch super long
+ batched_input[name][-1] = 100 * "very long"
+
+ else:
+ batched_input[name] = batch_size * [value]
+
+ if batch_generator and "generator" in inputs:
+ batched_input["generator"] = [self.get_generator(i) for i in range(batch_size)]
+
+ if "batch_size" in inputs:
+ batched_input["batch_size"] = batch_size
+
+ batched_inputs.append(batched_input)
+ logger.setLevel(level=diffusers.logging.WARNING)
+ for batch_size, batched_input in zip(batch_sizes, batched_inputs):
+ output = pipe(**batched_input)
+ assert len(output[0]) == batch_size
+
+ def test_inference_batch_single_identical(self, batch_size=3, expected_max_diff=1e-4):
+ self._test_inference_batch_single_identical(batch_size=batch_size, expected_max_diff=expected_max_diff)
+
+ def _test_inference_batch_single_identical(
+ self,
+ batch_size=2,
+ expected_max_diff=1e-4,
+ additional_params_copy_to_batched_inputs=["num_inference_steps"],
+ ):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+        for component in pipe.components.values():
+            if hasattr(component, "set_default_attn_processor"):
+                component.set_default_attn_processor()
+
+ pipe.set_progress_bar_config(disable=None)
+ inputs = self.get_dummy_inputs(torch_device)
+        # Reset the generator in case it has been used in self.get_dummy_inputs
+ inputs["generator"] = self.get_generator(0)
+
+ logger = logging.get_logger(pipe.__module__)
+ logger.setLevel(level=diffusers.logging.FATAL)
+
+ # batchify inputs
+ batched_inputs = {}
+ batched_inputs.update(inputs)
+
+ for name in self.batch_params:
+ if name not in inputs:
+ continue
+
+ value = inputs[name]
+ if name == "prompt":
+ len_prompt = len(value)
+ batched_inputs[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)]
+ batched_inputs[name][-1] = 100 * "very long"
+
+ else:
+ batched_inputs[name] = batch_size * [value]
+
+ if "generator" in inputs:
+ batched_inputs["generator"] = [self.get_generator(i) for i in range(batch_size)]
+
+ if "batch_size" in inputs:
+ batched_inputs["batch_size"] = batch_size
+
+ for arg in additional_params_copy_to_batched_inputs:
+ batched_inputs[arg] = inputs[arg]
+
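+        # The first sample of the batched run should match the output of the single-sample run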
+ output = pipe(**inputs)
+ output_batch = pipe(**batched_inputs)
+
+ assert output_batch[0].shape[0] == batch_size
+
+ max_diff = np.abs(to_np(output_batch[0][0]) - to_np(output[0][0])).max()
+ assert max_diff < expected_max_diff
+
+ def test_dict_tuple_outputs_equivalent(self, expected_max_difference=1e-4):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+
+ pipe.set_progress_bar_config(disable=None)
+
+ generator_device = "cpu"
+ output = pipe(**self.get_dummy_inputs(generator_device))[0]
+ output_tuple = pipe(**self.get_dummy_inputs(generator_device), return_dict=False)[0]
+
+ max_diff = np.abs(to_np(output) - to_np(output_tuple)).max()
+ self.assertLess(max_diff, expected_max_difference)
+
+ def test_components_function(self):
+ init_components = self.get_dummy_components()
+
+ # init_components = {k: v for k, v in init_components.items() if not isinstance(v, (str, int, float))}
+
+ pipe = self.pipeline_class(**init_components)
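+        # The Gaudi-specific constructor arguments are not pipeline components, so drop them before comparing with pipe.components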
+ init_components.pop("use_habana")
+ init_components.pop("use_hpu_graphs")
+ init_components.pop("bf16_full_eval")
+ init_components.pop("gaudi_config")
+
+ self.assertTrue(hasattr(pipe, "components"))
+ self.assertTrue(set(pipe.components.keys()) == set(init_components.keys()))
+
+ @skipIf(torch_device != "cuda", reason="float16 requires CUDA")
+ def test_float16_inference(self, expected_max_diff=5e-2):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ components = self.get_dummy_components()
+ pipe_fp16 = self.pipeline_class(**components)
+ for component in pipe_fp16.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+
+ pipe_fp16.to(torch_device, torch.float16)
+ pipe_fp16.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ # Reset generator in case it is used inside dummy inputs
+ if "generator" in inputs:
+ inputs["generator"] = self.get_generator(0)
+
+ output = pipe(**inputs)[0]
+
+ fp16_inputs = self.get_dummy_inputs(torch_device)
+ # Reset generator in case it is used inside dummy inputs
+ if "generator" in fp16_inputs:
+ fp16_inputs["generator"] = self.get_generator(0)
+
+ output_fp16 = pipe_fp16(**fp16_inputs)[0]
+
+ max_diff = np.abs(to_np(output) - to_np(output_fp16)).max()
+ self.assertLess(max_diff, expected_max_diff, "The outputs of the fp16 and fp32 pipelines are too different.")
+
+ @skipIf(torch_device != "cuda", reason="float16 requires CUDA")
+ def test_save_load_float16(self, expected_max_diff=1e-2):
+ components = self.get_dummy_components()
+ for name, module in components.items():
+ if hasattr(module, "half"):
+ components[name] = module.to(torch_device).half()
+
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ output = pipe(**inputs)[0]
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ pipe.save_pretrained(tmpdir)
+ pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, torch_dtype=torch.float16)
+ for component in pipe_loaded.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe_loaded.to(torch_device)
+ pipe_loaded.set_progress_bar_config(disable=None)
+
+ for name, component in pipe_loaded.components.items():
+ if hasattr(component, "dtype"):
+ self.assertTrue(
+ component.dtype == torch.float16,
+ f"`{name}.dtype` switched from `float16` to {component.dtype} after loading.",
+ )
+
+ inputs = self.get_dummy_inputs(torch_device)
+ output_loaded = pipe_loaded(**inputs)[0]
+ max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
+ self.assertLess(
+ max_diff, expected_max_diff, "The output of the fp16 pipeline changed after saving and loading."
+ )
+
+ @skipIf(torch_device != "cuda", reason="CUDA and CPU are required to switch devices")
+ def test_to_device(self):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.set_progress_bar_config(disable=None)
+
+ pipe.to("cpu")
+ model_devices = [component.device.type for component in components.values() if hasattr(component, "device")]
+ self.assertTrue(all(device == "cpu" for device in model_devices))
+
+ output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0]
+ self.assertTrue(np.isnan(output_cpu).sum() == 0)
+
+ pipe.to("cuda")
+ model_devices = [component.device.type for component in components.values() if hasattr(component, "device")]
+ self.assertTrue(all(device == "cuda" for device in model_devices))
+
+ output_cuda = pipe(**self.get_dummy_inputs("cuda"))[0]
+ self.assertTrue(np.isnan(to_np(output_cuda)).sum() == 0)
+
+ def test_to_dtype(self):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.set_progress_bar_config(disable=None)
+
+ model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")]
+ self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
+
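+        # Cast to bfloat16, the reduced-precision dtype used on Gaudi, instead of float16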
+ pipe.to(torch_dtype=torch.bfloat16)
+ model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")]
+ self.assertTrue(all(dtype == torch.bfloat16 for dtype in model_dtypes))
+
+ def test_attention_slicing_forward_pass(self, expected_max_diff=1e-3):
+ self._test_attention_slicing_forward_pass(expected_max_diff=expected_max_diff)
+
+ def _test_attention_slicing_forward_pass(
+ self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
+ ):
+ if not self.test_attention_slicing:
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.set_progress_bar_config(disable=None)
+
+ generator_device = "cpu"
+ inputs = self.get_dummy_inputs(generator_device)
+ output_without_slicing = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=1)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing = pipe(**inputs)[0]
+
+ if test_max_difference:
+ max_diff = np.abs(to_np(output_with_slicing) - to_np(output_without_slicing)).max()
+ self.assertLess(max_diff, expected_max_diff, "Attention slicing should not affect the inference results")
+
+ if test_mean_pixel_difference:
+ assert_mean_pixel_difference(to_np(output_with_slicing[0]), to_np(output_without_slicing[0]))
+
+ @skipIf(
+ torch_device != "cuda" or not is_accelerate_available() or is_accelerate_version("<", "0.14.0"),
+ reason="CPU offload is only available with CUDA and `accelerate v0.14.0` or higher",
+ )
+ def test_sequential_cpu_offload_forward_pass(self, expected_max_diff=1e-4):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+
+ pipe.set_progress_bar_config(disable=None)
+
+ generator_device = "cpu"
+ inputs = self.get_dummy_inputs(generator_device)
+ output_without_offload = pipe(**inputs)[0]
+
+ pipe.enable_sequential_cpu_offload()
+
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_offload = pipe(**inputs)[0]
+
+ max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max()
+ self.assertLess(max_diff, expected_max_diff, "CPU offloading should not affect the inference results")
+
+ @skipIf(
+ torch_device != "cuda" or not is_accelerate_available() or is_accelerate_version("<", "0.17.0"),
+ reason="CPU offload is only available with CUDA and `accelerate v0.17.0` or higher",
+ )
+ def test_model_cpu_offload_forward_pass(self, expected_max_diff=2e-4):
+ generator_device = "cpu"
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(generator_device)
+ output_without_offload = pipe(**inputs)[0]
+
+ pipe.enable_model_cpu_offload()
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_offload = pipe(**inputs)[0]
+
+ max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max()
+ self.assertLess(max_diff, expected_max_diff, "CPU offloading should not affect the inference results")
+ offloaded_modules = [
+ v
+ for k, v in pipe.components.items()
+ if isinstance(v, torch.nn.Module) and k not in pipe._exclude_from_cpu_offload
+ ]
+        self.assertTrue(
+            all(v.device.type == "cpu" for v in offloaded_modules),
+            f"Not offloaded: {[v for v in offloaded_modules if v.device.type != 'cpu']}",
+        )
+
+ @skipIf(
+ torch_device != "cuda" or not is_xformers_available(),
+ reason="XFormers attention is only available with CUDA and `xformers` installed",
+ )
+ def test_xformers_attention_forwardGenerator_pass(self):
+ self._test_xformers_attention_forwardGenerator_pass()
+
+ def _test_xformers_attention_forwardGenerator_pass(
+ self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-4
+ ):
+ if not self.test_xformers_attention:
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ output_without_offload = pipe(**inputs)[0]
+ output_without_offload = (
+ output_without_offload.cpu() if torch.is_tensor(output_without_offload) else output_without_offload
+ )
+
+ pipe.enable_xformers_memory_efficient_attention()
+ inputs = self.get_dummy_inputs(torch_device)
+ output_with_offload = pipe(**inputs)[0]
+ output_with_offload = (
+            output_with_offload.cpu() if torch.is_tensor(output_with_offload) else output_with_offload
+ )
+
+ if test_max_difference:
+ max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max()
+ self.assertLess(max_diff, expected_max_diff, "XFormers attention should not affect the inference results")
+
+ if test_mean_pixel_difference:
+ assert_mean_pixel_difference(output_with_offload[0], output_without_offload[0])
+
+ def test_progress_bar(self):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ with StringIO() as stderr, contextlib.redirect_stderr(stderr):
+ _ = pipe(**inputs)
+ stderr = stderr.getvalue()
+ # we can't calculate the number of progress steps beforehand e.g. for strength-dependent img2img,
+ # so we just match "5" in "#####| 1/5 [00:01<00:00]"
+ max_steps = re.search("/(.*?) ", stderr).group(1)
+ self.assertTrue(max_steps is not None and len(max_steps) > 0)
+ self.assertTrue(
+ f"{max_steps}/{max_steps}" in stderr, "Progress bar should be enabled and stopped at the max step"
+ )
+
+ pipe.set_progress_bar_config(disable=True)
+ with StringIO() as stderr, contextlib.redirect_stderr(stderr):
+ _ = pipe(**inputs)
+ self.assertTrue(stderr.getvalue() == "", "Progress bar should be disabled")
+
+ def test_num_images_per_prompt(self):
+ sig = inspect.signature(self.pipeline_class.__call__)
+
+ if "num_images_per_prompt" not in sig.parameters:
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.set_progress_bar_config(disable=None)
+
+ batch_sizes = [1, 2]
+ num_images_per_prompts = [1, 2]
+
+ for batch_size in batch_sizes:
+ for num_images_per_prompt in num_images_per_prompts:
+ inputs = self.get_dummy_inputs(torch_device)
+
+ for key in inputs.keys():
+ if key in self.batch_params:
+ inputs[key] = batch_size * [inputs[key]]
+
+ images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt)[0]
+
+ assert images.shape[0] == batch_size * num_images_per_prompt
+
+ def test_cfg(self):
+ sig = inspect.signature(self.pipeline_class.__call__)
+
+ if "guidance_scale" not in sig.parameters:
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(torch_device)
+
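+        # guidance_scale <= 1 disables classifier-free guidance in diffusers pipelines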
+ inputs["guidance_scale"] = 1.0
+ out_no_cfg = pipe(**inputs)[0]
+
+ inputs["guidance_scale"] = 7.5
+ out_cfg = pipe(**inputs)[0]
+
+ assert out_cfg.shape == out_no_cfg.shape
+
+ def test_callback_inputs(self):
+ sig = inspect.signature(self.pipeline_class.__call__)
+ has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
+ has_callback_step_end = "callback_on_step_end" in sig.parameters
+
+ if not (has_callback_tensor_inputs and has_callback_step_end):
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.set_progress_bar_config(disable=None)
+ self.assertTrue(
+ hasattr(pipe, "_callback_tensor_inputs"),
+            f"{self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
+ )
+
+ def callback_inputs_subset(pipe, i, t, callback_kwargs):
+            # iterate over callback args
+ for tensor_name, tensor_value in callback_kwargs.items():
+ # check that we're only passing in allowed tensor inputs
+ assert tensor_name in pipe._callback_tensor_inputs
+
+ return callback_kwargs
+
+ def callback_inputs_all(pipe, i, t, callback_kwargs):
+ for tensor_name in pipe._callback_tensor_inputs:
+ assert tensor_name in callback_kwargs
+
+            # iterate over callback args
+ for tensor_name, tensor_value in callback_kwargs.items():
+ # check that we're only passing in allowed tensor inputs
+ assert tensor_name in pipe._callback_tensor_inputs
+
+ return callback_kwargs
+
+ inputs = self.get_dummy_inputs(torch_device)
+
+ # Test passing in a subset
+ inputs["callback_on_step_end"] = callback_inputs_subset
+ inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
+ inputs["output_type"] = "latent"
+ output = pipe(**inputs)[0]
+
+        # Test passing in everything
+ inputs["callback_on_step_end"] = callback_inputs_all
+ inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
+ inputs["output_type"] = "latent"
+ output = pipe(**inputs)[0]
+
+ def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
+ is_last = i == (pipe.num_timesteps - 1)
+ if is_last:
+ callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
+ return callback_kwargs
+
+ inputs["callback_on_step_end"] = callback_inputs_change_tensor
+ inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
+ inputs["output_type"] = "latent"
+ output = pipe(**inputs)[0]
+ assert output.abs().sum() == 0
+
+ def test_callback_cfg(self):
+ sig = inspect.signature(self.pipeline_class.__call__)
+ has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
+ has_callback_step_end = "callback_on_step_end" in sig.parameters
+
+ if not (has_callback_tensor_inputs and has_callback_step_end):
+ return
+
+ if "guidance_scale" not in sig.parameters:
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.set_progress_bar_config(disable=None)
+ self.assertTrue(
+ hasattr(pipe, "_callback_tensor_inputs"),
+            f"{self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
+ )
+
+ def callback_increase_guidance(pipe, i, t, callback_kwargs):
+ pipe._guidance_scale += 1.0
+ return callback_kwargs
+
+ inputs = self.get_dummy_inputs(torch_device)
+
+ # use cfg guidance because some pipelines modify the shape of the latents
+ # outside of the denoising loop
+ inputs["guidance_scale"] = 2.0
+ inputs["callback_on_step_end"] = callback_increase_guidance
+ inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
+ _ = pipe(**inputs)[0]
+
+ # we increase the guidance scale by 1.0 at every step
+ # check that the guidance scale is increased by the number of scheduler timesteps
+ # accounts for models that modify the number of inference steps based on strength
+ assert pipe.guidance_scale == (inputs["guidance_scale"] + pipe.num_timesteps)
+
+
+# For SDXL and its derivative pipelines (such as ControlNet), we have the text encoders
+# and the tokenizers as optional components. So, we need to override the `test_save_load_optional_components()`
+# test for all such pipelines. This requires us to use a custom `encode_prompt()` function.
+class SDXLOptionalComponentsTesterMixin:
+ def encode_prompt(
+ self, tokenizers, text_encoders, prompt: str, num_images_per_prompt: int = 1, negative_prompt: str = None
+ ):
+ device = text_encoders[0].device
+
+ if isinstance(prompt, str):
+ prompt = [prompt]
+ batch_size = len(prompt)
+
+ prompt_embeds_list = []
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+
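+            # Use the penultimate hidden state of each text encoder, as SDXL does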
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
+ pooled_prompt_embeds = prompt_embeds[0]
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+
+ if negative_prompt is None:
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
+ else:
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ negative_prompt_embeds_list = []
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
+ uncond_input = tokenizer(
+ negative_prompt,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ negative_prompt_embeds = text_encoder(uncond_input.input_ids.to(device), output_hidden_states=True)
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
+
+ negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # for classifier-free guidance
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+
+ # for classifier-free guidance
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+
+ def _test_save_load_optional_components(self, expected_max_difference=1e-4):
+ components = self.get_dummy_components()
+
+ pipe = self.pipeline_class(**components)
+ for optional_component in pipe._optional_components:
+ setattr(pipe, optional_component, None)
+
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+
+ pipe.set_progress_bar_config(disable=None)
+
+ generator_device = "cpu"
+ inputs = self.get_dummy_inputs(generator_device)
+
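+        # Precompute the prompt embeddings with the original text encoders so the pipeline can run with its optional components set to None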
+ tokenizer = components.pop("tokenizer")
+ tokenizer_2 = components.pop("tokenizer_2")
+ text_encoder = components.pop("text_encoder")
+ text_encoder_2 = components.pop("text_encoder_2")
+
+ tokenizers = [tokenizer, tokenizer_2] if tokenizer is not None else [tokenizer_2]
+ text_encoders = [text_encoder, text_encoder_2] if text_encoder is not None else [text_encoder_2]
+ prompt = inputs.pop("prompt")
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(tokenizers, text_encoders, prompt)
+ inputs["prompt_embeds"] = prompt_embeds
+ inputs["negative_prompt_embeds"] = negative_prompt_embeds
+ inputs["pooled_prompt_embeds"] = pooled_prompt_embeds
+ inputs["negative_pooled_prompt_embeds"] = negative_pooled_prompt_embeds
+
+ output = pipe(**inputs)[0]
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ pipe.save_pretrained(tmpdir)
+ pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
+ for component in pipe_loaded.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe_loaded.to(torch_device)
+ pipe_loaded.set_progress_bar_config(disable=None)
+
+ for optional_component in pipe._optional_components:
+ self.assertTrue(
+ getattr(pipe_loaded, optional_component) is None,
+ f"`{optional_component}` did not stay set to None after loading.",
+ )
+
+ inputs = self.get_dummy_inputs(generator_device)
+ _ = inputs.pop("prompt")
+ inputs["prompt_embeds"] = prompt_embeds
+ inputs["negative_prompt_embeds"] = negative_prompt_embeds
+ inputs["pooled_prompt_embeds"] = pooled_prompt_embeds
+ inputs["negative_pooled_prompt_embeds"] = negative_pooled_prompt_embeds
+
+ output_loaded = pipe_loaded(**inputs)[0]
+
+ max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
+ self.assertLess(max_diff, expected_max_difference)
+
+
+# Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used.
+# This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a
+# reference image.
+def assert_mean_pixel_difference(image, expected_image, expected_max_diff=10):
+ image = np.asarray(DiffusionPipeline.numpy_to_pil(image)[0], dtype=np.float32)
+ expected_image = np.asarray(DiffusionPipeline.numpy_to_pil(expected_image)[0], dtype=np.float32)
+ avg_diff = np.abs(image - expected_image).mean()
+    assert avg_diff < expected_max_diff, f"Error: image deviates {avg_diff} pixels on average"
+
+
+"""
+Copied from: https://github.com/huggingface/diffusers/blob/v0.26.3/tests/pipelines/pipeline_params.py
+"""
+
+TEXT_TO_IMAGE_PARAMS = frozenset(
+ [
+ "prompt",
+ "height",
+ "width",
+ "guidance_scale",
+ "negative_prompt",
+ "prompt_embeds",
+ "negative_prompt_embeds",
+ "cross_attention_kwargs",
+ ]
+)
+
+TEXT_TO_IMAGE_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])
+
+TEXT_TO_IMAGE_IMAGE_PARAMS = frozenset([])
+
+IMAGE_TO_IMAGE_IMAGE_PARAMS = frozenset(["image"])
+
+IMAGE_VARIATION_PARAMS = frozenset(
+ [
+ "image",
+ "height",
+ "width",
+ "guidance_scale",
+ ]
+)
+
+IMAGE_VARIATION_BATCH_PARAMS = frozenset(["image"])
+
+TEXT_GUIDED_IMAGE_VARIATION_PARAMS = frozenset(
+ [
+ "prompt",
+ "image",
+ "height",
+ "width",
+ "guidance_scale",
+ "negative_prompt",
+ "prompt_embeds",
+ "negative_prompt_embeds",
+ ]
+)
+
+TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS = frozenset(["prompt", "image", "negative_prompt"])
+
+TEXT_GUIDED_IMAGE_INPAINTING_PARAMS = frozenset(
+ [
+ # Text guided image variation with an image mask
+ "prompt",
+ "image",
+ "mask_image",
+ "height",
+ "width",
+ "guidance_scale",
+ "negative_prompt",
+ "prompt_embeds",
+ "negative_prompt_embeds",
+ ]
+)
+
+TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["prompt", "image", "mask_image", "negative_prompt"])
+
+IMAGE_INPAINTING_PARAMS = frozenset(
+ [
+ # image variation with an image mask
+ "image",
+ "mask_image",
+ "height",
+ "width",
+ "guidance_scale",
+ ]
+)
+
+IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["image", "mask_image"])
+
+IMAGE_GUIDED_IMAGE_INPAINTING_PARAMS = frozenset(
+ [
+ "example_image",
+ "image",
+ "mask_image",
+ "height",
+ "width",
+ "guidance_scale",
+ ]
+)
+
+IMAGE_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["example_image", "image", "mask_image"])
+
+CLASS_CONDITIONED_IMAGE_GENERATION_PARAMS = frozenset(["class_labels"])
+
+CLASS_CONDITIONED_IMAGE_GENERATION_BATCH_PARAMS = frozenset(["class_labels"])
+
+UNCONDITIONAL_IMAGE_GENERATION_PARAMS = frozenset(["batch_size"])
+
+UNCONDITIONAL_IMAGE_GENERATION_BATCH_PARAMS = frozenset([])
+
+UNCONDITIONAL_AUDIO_GENERATION_PARAMS = frozenset(["batch_size"])
+
+UNCONDITIONAL_AUDIO_GENERATION_BATCH_PARAMS = frozenset([])
+
+TEXT_TO_AUDIO_PARAMS = frozenset(
+ [
+ "prompt",
+ "audio_length_in_s",
+ "guidance_scale",
+ "negative_prompt",
+ "prompt_embeds",
+ "negative_prompt_embeds",
+ "cross_attention_kwargs",
+ ]
+)
+
+TEXT_TO_AUDIO_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])
+TOKENS_TO_AUDIO_GENERATION_PARAMS = frozenset(["input_tokens"])
+
+TOKENS_TO_AUDIO_GENERATION_BATCH_PARAMS = frozenset(["input_tokens"])
+
+TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS = frozenset(["prompt_embeds"])
+
+VIDEO_TO_VIDEO_BATCH_PARAMS = frozenset(["prompt", "negative_prompt", "video"])
+
+
+"""
+Copied from: https://github.com/huggingface/diffusers/blob/v0.26.3/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py
+- Modified pipeline to Gaudi pipeline.
+- Modified the get_dummy_components to add the Gaudi pipeline parameters: use_habana, use_hpu_graphs, gaudi_config, bf16_full_eval
+- Added test cases:
+ test_stable_diffusion_inpaint_no_safety_checker
+ test_stable_diffusion_inpaint_enable_safety_checker
+ test_stable_diffusion_inpaint_no_throughput_regression
+"""
+
+enable_full_determinism()
+
+
+class StableDiffusionInpaintPipelineFastTests(
+ PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, TestCase
+):
+ pipeline_class = GaudiStableDiffusionInpaintPipeline
+ params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
+ batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS
+ image_params = frozenset(
+ []
+ ) # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess
+ image_latents_params = frozenset([])
+ callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union({"mask", "masked_image_latents"})
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ unet = UNet2DConditionModel(
+ block_out_channels=(32, 64),
+ layers_per_block=2,
+ sample_size=32,
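+            # 9 input channels: 4 noisy latents + 4 masked-image latents + 1 mask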
+ in_channels=9,
+ out_channels=4,
+ down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
+ up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
+ cross_attention_dim=32,
+ # SD2-specific config below
+ attention_head_dim=(2, 4),
+ use_linear_projection=True,
+ )
+ scheduler = PNDMScheduler(skip_prk_steps=True)
+ set_seed(0)
+ vae = AutoencoderKL(
+ block_out_channels=[32, 64],
+ in_channels=3,
+ out_channels=3,
+ down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
+ up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
+ latent_channels=4,
+ sample_size=128,
+ )
+ set_seed(0)
+ text_encoder_config = CLIPTextConfig(
+ bos_token_id=0,
+ eos_token_id=2,
+ hidden_size=32,
+ intermediate_size=37,
+ layer_norm_eps=1e-05,
+ num_attention_heads=4,
+ num_hidden_layers=5,
+ pad_token_id=1,
+ vocab_size=1000,
+ # SD2-specific config below
+ hidden_act="gelu",
+ projection_dim=512,
+ )
+ text_encoder = CLIPTextModel(text_encoder_config)
+ tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+ components = {
+ "unet": unet,
+ "scheduler": scheduler,
+ "vae": vae,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "safety_checker": None,
+ "feature_extractor": None,
+ "image_encoder": None,
+ "use_habana": True,
+ "use_hpu_graphs": True,
+ "gaudi_config": "Habana/stable-diffusion-2",
+ "bf16_full_eval": True,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ # TODO: use tensor inputs instead of PIL, this is here just to leave the old expected_slices untouched
+        # torch.Generator() does not support the HPU device type, so use a CPU generator to ensure determinism
+        device = "cpu"
+ image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
+ image = image.cpu().permute(0, 2, 3, 1)[0]
+ init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64))
+ mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((64, 64))
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+ inputs = {
+ "prompt": "A painting of a squirrel eating a burger",
+ "image": init_image,
+ "mask_image": mask_image,
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 6.0,
+ "output_type": "numpy",
+ }
+ return inputs
+
+ def test_stable_diffusion_inpaint(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ components = self.get_dummy_components()
+ sd_pipe = GaudiStableDiffusionInpaintPipeline(**components)
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = sd_pipe(**inputs).images
+ image_slice = image[0, -3:, -3:, -1]
+
+ assert image.shape == (1, 64, 64, 3)
+ expected_slice = np.array([0.4727, 0.5735, 0.3941, 0.5446, 0.5926, 0.4394, 0.5062, 0.4654, 0.4476])
+
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
+ def test_inference_batch_single_identical(self):
+ super().test_inference_batch_single_identical(expected_max_diff=3e-3)
+
+
+class StableDiffusionInpaintPipelineIntegrationTests(TestCase):
+ def tearDown(self):
+ # clean up the VRAM after each test
+ super().tearDown()
+ gc.collect()
+
+ def create_inpaint_pipe(
+ self,
+ model_name="stabilityai/stable-diffusion-2-inpainting",
+ scheduler=None,
+ use_hpu_graphs=False,
+ gaudi_config="Habana/stable-diffusion",
+ disable_safety_checker=False,
+ torch_dtype=torch.bfloat16,
+ ):
+ if scheduler is None:
+ scheduler = GaudiDDIMScheduler.from_pretrained(model_name, subfolder="scheduler")
+
+ kwargs = {
+ "scheduler": scheduler,
+ "use_habana": True,
+ "use_hpu_graphs": use_hpu_graphs,
+ "gaudi_config": gaudi_config,
+ }
+
+ if disable_safety_checker is True:
+ kwargs["safety_checker"] = None
+
+ sdi_pipe = GaudiStableDiffusionInpaintPipeline.from_pretrained(model_name, **kwargs).to(torch_dtype)
+
+ sdi_pipe.set_progress_bar_config(disable=None)
+
+ return sdi_pipe
+
+ def test_stable_diffusion_inpaint_pipeline(self):
+ init_image = load_image(
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+ "/sd2-inpaint/init_image.png"
+ )
+ mask_image = load_image(
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-inpaint/mask.png"
+ )
+ expected_image = load_numpy(
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-inpaint"
+ "/yellow_cat_sitting_on_a_park_bench.npy"
+ )
+
+ model_id = "stabilityai/stable-diffusion-2-inpainting"
+ init_kwargs = {
+ "use_habana": True,
+ "use_hpu_graphs": True,
+ "gaudi_config": "Habana/stable-diffusion",
+ "torch_dtype": torch.float,
+ }
+
+ pipe = GaudiStableDiffusionInpaintPipeline.from_pretrained(model_id, safety_checker=None, **init_kwargs)
+ pipe.set_progress_bar_config(disable=None)
+ pipe.enable_attention_slicing()
+
+ prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
+
+ generator = torch.manual_seed(0)
+ output = pipe(
+ prompt=prompt,
+ image=init_image,
+ mask_image=mask_image,
+ generator=generator,
+ output_type="np",
+ )
+ image = output.images[0]
+
+ assert image.shape == (512, 512, 3)
+        # No difference is visible to the human eye between the generated image and the reference.
+        # np.abs(expected_image - image).max() = 0.31966144
+ assert np.abs(expected_image - image).max() < 0.4
+
+ def test_stable_diffusion_inpaint_pipeline_bf16(self):
+ init_image = load_image(
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+ "/sd2-inpaint/init_image.png"
+ )
+ mask_image = load_image(
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-inpaint/mask.png"
+ )
+ expected_image = load_numpy(
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-inpaint"
+ "/yellow_cat_sitting_on_a_park_bench_fp16.npy"
+ )
+
+ model_id = "stabilityai/stable-diffusion-2-inpainting"
+ init_kwargs = {
+ "use_habana": True,
+ "use_hpu_graphs": True,
+ "gaudi_config": "Habana/stable-diffusion-2",
+ "torch_dtype": torch.bfloat16,
+ }
+
+ pipe = GaudiStableDiffusionInpaintPipeline.from_pretrained(model_id, safety_checker=None, **init_kwargs)
+ pipe.set_progress_bar_config(disable=None)
+ pipe.enable_attention_slicing()
+
+ prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
+
+ generator = torch.manual_seed(0)
+ output = pipe(
+ prompt=prompt,
+ image=init_image,
+ mask_image=mask_image,
+ generator=generator,
+ output_type="np",
+ )
+ image = output.images[0]
+
+ assert image.shape == (512, 512, 3)
+        # The expected_image reference was generated in float16; no difference is visible to the human eye.
+        # np.abs(expected_image - image).max() = 0.9626465
+ assert np.abs(expected_image - image).max() < 0.97
+
+ @slow
+ def test_stable_diffusion_inpaint_no_safety_checker(self):
+        """Test that stable diffusion inpainting works without a safety checker"""
+ from diffusers.utils import load_image
+
+ # Create test inpaint pipeline
+ gaudi_config = GaudiConfig()
+ scheduler = GaudiDDIMScheduler(
+ beta_start=0.00085,
+ beta_end=0.012,
+ beta_schedule="scaled_linear",
+ clip_sample=False,
+ set_alpha_to_one=False,
+ )
+ sdi_pipe = self.create_inpaint_pipe(
+ gaudi_config=gaudi_config, scheduler=scheduler, disable_safety_checker=True
+ )
+
+ # Initialize inpaint parameters
+ init_image = load_image(
+ "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png"
+ )
+ mask_image = load_image(
+ "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png"
+ )
+
+ self.assertIsInstance(sdi_pipe, GaudiStableDiffusionInpaintPipeline)
+ self.assertIsInstance(sdi_pipe.scheduler, GaudiDDIMScheduler)
+ self.assertIsNone(sdi_pipe.safety_checker)
+
+ image = sdi_pipe("example prompt", image=init_image, mask_image=mask_image, num_inference_steps=2).images[0]
+ self.assertIsNotNone(image)
+
+ # Check that there's no error when saving a pipeline with one of the models being None
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ sdi_pipe.save_pretrained(tmpdirname)
+ sdi_pipe = GaudiStableDiffusionInpaintPipeline.from_pretrained(
+ tmpdirname,
+ use_habana=True,
+ gaudi_config=tmpdirname,
+ )
+
+ # Sanity check that the pipeline still works
+ self.assertIsNone(sdi_pipe.safety_checker)
+ image = sdi_pipe("example prompt", image=init_image, mask_image=mask_image, num_inference_steps=2).images[0]
+ self.assertIsNotNone(image)
+
+ @slow
+ def test_stable_diffusion_inpaint_enable_safety_checker(self):
+        """Test that stable diffusion inpainting works with a safety checker loaded via from_pretrained"""
+ from diffusers.utils import load_image
+
+ # Create test inpaint pipeline
+ gaudi_config = GaudiConfig()
+ scheduler = GaudiDDIMScheduler(
+ beta_start=0.00085,
+ beta_end=0.012,
+ beta_schedule="scaled_linear",
+ clip_sample=False,
+ set_alpha_to_one=False,
+ )
+ sdi_pipe = self.create_inpaint_pipe(
+ gaudi_config=gaudi_config, scheduler=scheduler, disable_safety_checker=False
+ )
+
+ # Initialize inpaint parameters
+ init_image = load_image(
+ "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png"
+ )
+ mask_image = load_image(
+ "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png"
+ )
+
+ self.assertIsInstance(sdi_pipe, GaudiStableDiffusionInpaintPipeline)
+ self.assertIsInstance(sdi_pipe.scheduler, GaudiDDIMScheduler)
+ # self.assertIsNotNone(sdi_pipe.safety_checker) <--- The safety checker is not being found.
+
+ image = sdi_pipe("example prompt", image=init_image, mask_image=mask_image, num_inference_steps=2).images[0]
+ self.assertIsNotNone(image)
+
+ # Check that there's no error when saving a pipeline with one of the models being None
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ sdi_pipe.save_pretrained(tmpdirname)
+ sdi_pipe = GaudiStableDiffusionInpaintPipeline.from_pretrained(
+ tmpdirname,
+ use_habana=True,
+ gaudi_config=tmpdirname,
+ )
+
+ # Sanity check that the pipeline still works
+ self.assertIsNone(sdi_pipe.safety_checker)
+ image = sdi_pipe("example prompt", image=init_image, mask_image=mask_image, num_inference_steps=2).images[0]
+ self.assertIsNotNone(image)
+
+ @slow
+ def test_stable_diffusion_inpaint_no_throughput_regression(self):
+        """Test that stable diffusion inpainting shows no throughput regression with autocast"""
+ from diffusers.utils import load_image
+
+ # Initialize inpaint parameters
+ init_image = load_image(
+ "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png"
+ )
+ mask_image = load_image(
+ "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png"
+ )
+
+ prompts = [
+ "a black cat with glowing eyes, cute, adorable, disney, pixar, highly detailed, 8k",
+ "concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k",
+ ]
+ num_images_per_prompt = 10
+ num_inference_steps = 10
+ model_name = "runwayml/stable-diffusion-inpainting"
+
+ init_kwargs = {
+ "use_habana": True,
+ "use_hpu_graphs": True,
+ "gaudi_config": "Habana/stable-diffusion",
+ "torch_dtype": torch.bfloat16,
+ }
+ sdi_pipe = GaudiStableDiffusionInpaintPipeline.from_pretrained(model_name, **init_kwargs)
+
+ set_seed(0)
+ outputs = sdi_pipe(
+ prompt=prompts,
+ image=init_image,
+ mask_image=mask_image,
+ num_images_per_prompt=num_images_per_prompt,
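+            # The first throughput_warmup_steps iterations are treated as warm-up and excluded from the throughput measurement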
+ throughput_warmup_steps=3,
+ num_inference_steps=num_inference_steps,
+ batch_size=4,
+ )
+
+ self.assertEqual(len(outputs.images), num_images_per_prompt * len(prompts))
+ self.assertGreaterEqual(outputs.throughput, 0.95 * INPAINT_THROUGHPUT_BASELINE_BF16)
+
+
+"""
+Copied from: https://github.com/huggingface/diffusers/blob/v0.26.3/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py
+- Modified pipeline to Gaudi pipeline.
+- Modified the get_dummy_components to add the Gaudi pipeline parameters: use_habana, use_hpu_graphs, gaudi_config, bf16_full_eval
+- Added test_stable_diffusion_xl_inpaint_no_throughput_regression
+"""
+
+
+class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, TestCase):
+ pipeline_class = GaudiStableDiffusionXLInpaintPipeline
+ params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
+ batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS
+ image_params = frozenset([])
+ # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess
+ image_latents_params = frozenset([])
+ callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union(
+ {
+ "add_text_embeds",
+ "add_time_ids",
+ "mask",
+ "masked_image_latents",
+ }
+ )
+
+ def get_dummy_components(self, skip_first_text_encoder=False, time_cond_proj_dim=None):
+ torch.manual_seed(0)
+ set_seed(0)
+ unet = UNet2DConditionModel(
+ block_out_channels=(32, 64),
+ layers_per_block=2,
+ sample_size=32,
+ in_channels=4,
+ out_channels=4,
+ time_cond_proj_dim=time_cond_proj_dim,
+ down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
+ up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
+ # SD2-specific config below
+ attention_head_dim=(2, 4),
+ use_linear_projection=True,
+ addition_embed_type="text_time",
+ addition_time_embed_dim=8,
+ transformer_layers_per_block=(1, 2),
+ projection_class_embeddings_input_dim=72, # 5 * 8 + 32
+ cross_attention_dim=64 if not skip_first_text_encoder else 32,
+ )
+ scheduler = EulerDiscreteScheduler(
+ beta_start=0.00085,
+ beta_end=0.012,
+ steps_offset=1,
+ beta_schedule="scaled_linear",
+ timestep_spacing="leading",
+ )
+ torch.manual_seed(0)
+ set_seed(0)
+ vae = AutoencoderKL(
+ block_out_channels=[32, 64],
+ in_channels=3,
+ out_channels=3,
+ down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
+ up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
+ latent_channels=4,
+ sample_size=128,
+ )
+ torch.manual_seed(0)
+ set_seed(0)
+ text_encoder_config = CLIPTextConfig(
+ bos_token_id=0,
+ eos_token_id=2,
+ hidden_size=32,
+ intermediate_size=37,
+ layer_norm_eps=1e-05,
+ num_attention_heads=4,
+ num_hidden_layers=5,
+ pad_token_id=1,
+ vocab_size=1000,
+ # SD2-specific config below
+ hidden_act="gelu",
+ projection_dim=32,
+ )
+ text_encoder = CLIPTextModel(text_encoder_config)
+ tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+ text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config)
+ tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+ torch.manual_seed(0)
+ set_seed(0)
+ image_encoder_config = CLIPVisionConfig(
+ hidden_size=32,
+ image_size=224,
+ projection_dim=32,
+ intermediate_size=37,
+ num_attention_heads=4,
+ num_channels=3,
+ num_hidden_layers=5,
+ patch_size=14,
+ )
+
+ image_encoder = CLIPVisionModelWithProjection(image_encoder_config)
+
+ feature_extractor = CLIPImageProcessor(
+ crop_size=224,
+ do_center_crop=True,
+ do_normalize=True,
+ do_resize=True,
+ image_mean=[0.48145466, 0.4578275, 0.40821073],
+ image_std=[0.26862954, 0.26130258, 0.27577711],
+ resample=3,
+ size=224,
+ )
+
+ components = {
+ "unet": unet,
+ "scheduler": scheduler,
+ "vae": vae,
+ "text_encoder": text_encoder if not skip_first_text_encoder else None,
+ "tokenizer": tokenizer if not skip_first_text_encoder else None,
+ "text_encoder_2": text_encoder_2,
+ "tokenizer_2": tokenizer_2,
+ "image_encoder": image_encoder,
+ "feature_extractor": feature_extractor,
+ "requires_aesthetics_score": True,
+ "use_habana": True,
+ "use_hpu_graphs": True,
+ "gaudi_config": "Habana/stable-diffusion",
+ "bf16_full_eval": True,
+ }
+ return components
+
+ def get_dummy_inputs(self, device="cpu", seed=0):
+ # TODO: use tensor inputs instead of PIL, this is here just to leave the old expected_slices untouched
+ image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
+ image = image.cpu().permute(0, 2, 3, 1)[0]
+ init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64))
+ # create mask
+ image[8:, 8:, :] = 255
+ mask_image = Image.fromarray(np.uint8(image)).convert("L").resize((64, 64))
+
+        # torch.Generator() does not support the HPU device type, so use a CPU generator to ensure determinism
+        device = "cpu"
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+ inputs = {
+ "prompt": "A painting of a squirrel eating a burger",
+ "image": init_image,
+ "mask_image": mask_image,
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 6.0,
+ "strength": 1.0,
+ "output_type": "np",
+ }
+ return inputs
+
+ def get_dummy_inputs_2images(self, device, seed=0, img_res=64):
+        # Generate random floats in [0, 1] as images with spatial size (img_res, img_res)
+ image1 = floats_tensor((1, 3, img_res, img_res), rng=random.Random(seed)).to(device)
+ image2 = floats_tensor((1, 3, img_res, img_res), rng=random.Random(seed + 22)).to(device)
+ # Convert images to [-1, 1]
+ init_image1 = 2.0 * image1 - 1.0
+ init_image2 = 2.0 * image2 - 1.0
+
+ # empty mask
+ mask_image = torch.zeros((1, 1, img_res, img_res), device=device)
+
+ # Device type HPU is not supported for torch.Generator() api
+ device = "cpu"
+ if str(device).startswith("mps"):
+ generator1 = torch.manual_seed(seed)
+ generator2 = torch.manual_seed(seed)
+ else:
+ generator1 = torch.Generator(device=device).manual_seed(seed)
+ generator2 = torch.Generator(device=device).manual_seed(seed)
+
+ inputs = {
+ "prompt": ["A painting of a squirrel eating a burger"] * 2,
+ "image": [init_image1, init_image2],
+ "mask_image": [mask_image] * 2,
+ "generator": [generator1, generator2],
+ "num_inference_steps": 2,
+ "guidance_scale": 6.0,
+ "output_type": "np",
+ "batch_size": 2,
+ }
+ return inputs
+
+ def test_components_function(self):
+ init_components = self.get_dummy_components()
+ init_components.pop("requires_aesthetics_score")
+ init_components.pop("use_habana")
+ init_components.pop("use_hpu_graphs")
+ init_components.pop("bf16_full_eval")
+ init_components.pop("gaudi_config")
+ pipe = self.pipeline_class(**init_components)
+ self.assertTrue(hasattr(pipe, "components"))
+ self.assertTrue(set(pipe.components.keys()) == set(init_components.keys()))
+
+ def test_stable_diffusion_xl_inpaint_euler(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ components = self.get_dummy_components()
+ sd_pipe = GaudiStableDiffusionXLInpaintPipeline(**components)
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = sd_pipe(**inputs).images
+ image_slice = image[0, -3:, -3:, -1]
+
+ assert image.shape == (1, 64, 64, 3)
+
+ expected_slice = np.array([0.8029, 0.5523, 0.5825, 0.6003, 0.6702, 0.7018, 0.6369, 0.5955, 0.5123])
+
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
+ def test_stable_diffusion_xl_inpaint_euler_lcm(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ components = self.get_dummy_components(time_cond_proj_dim=256)
+ sd_pipe = GaudiStableDiffusionXLInpaintPipeline(**components)
+ sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.config)
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = sd_pipe(**inputs).images
+ image_slice = image[0, -3:, -3:, -1]
+
+ assert image.shape == (1, 64, 64, 3)
+
+ expected_slice = np.array([0.6611, 0.5569, 0.5531, 0.5471, 0.5918, 0.6393, 0.5074, 0.5468, 0.5185])
+
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
+ def test_stable_diffusion_xl_inpaint_euler_lcm_custom_timesteps(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ components = self.get_dummy_components(time_cond_proj_dim=256)
+ sd_pipe = GaudiStableDiffusionXLInpaintPipeline(**components)
+ sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.config)
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ del inputs["num_inference_steps"]
+ inputs["timesteps"] = [999, 499]
+ image = sd_pipe(**inputs).images
+ image_slice = image[0, -3:, -3:, -1]
+
+ assert image.shape == (1, 64, 64, 3)
+
+ expected_slice = np.array([0.6611, 0.5569, 0.5531, 0.5471, 0.5918, 0.6393, 0.5074, 0.5468, 0.5185])
+
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
+ def test_attention_slicing_forward_pass(self):
+ super().test_attention_slicing_forward_pass(expected_max_diff=3e-3)
+
+ def test_inference_batch_single_identical(self):
+ super().test_inference_batch_single_identical(expected_max_diff=3e-3)
+
+ # TODO(Patrick, Sayak) - skip for now as this requires more refiner tests
+ def test_save_load_optional_components(self):
+ pass
+
+ def test_stable_diffusion_xl_inpaint_negative_prompt_embeds(self):
+ device = "cpu"
+ components = self.get_dummy_components()
+ sd_pipe = GaudiStableDiffusionXLInpaintPipeline(**components)
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ # forward without prompt embeds
+ inputs = self.get_dummy_inputs(device)
+ negative_prompt = 3 * ["this is a negative prompt"]
+ inputs["negative_prompt"] = negative_prompt
+ inputs["prompt"] = 3 * [inputs["prompt"]]
+
+ output = sd_pipe(**inputs)
+ image_slice_1 = output.images[0, -3:, -3:, -1]
+
+ # forward with prompt embeds
+ inputs = self.get_dummy_inputs(device)
+ negative_prompt = 3 * ["this is a negative prompt"]
+ prompt = 3 * [inputs.pop("prompt")]
+
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = sd_pipe.encode_prompt(prompt, negative_prompt=negative_prompt)
+
+ output = sd_pipe(
+ **inputs,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ )
+ image_slice_2 = output.images[0, -3:, -3:, -1]
+
+        # make sure the two outputs are equal
+ assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
+
+ def test_stable_diffusion_xl_refiner(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ components = self.get_dummy_components(skip_first_text_encoder=True)
+
+ sd_pipe = self.pipeline_class(**components)
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = sd_pipe(**inputs).images
+ image_slice = image[0, -3:, -3:, -1]
+
+ assert image.shape == (1, 64, 64, 3)
+
+ expected_slice = np.array([0.7045, 0.4838, 0.5454, 0.6270, 0.6168, 0.6717, 0.6484, 0.5681, 0.4922])
+
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
+ def test_stable_diffusion_two_xl_mixture_of_denoiser_fast(self):
+ device = "cpu"
+ components = self.get_dummy_components()
+ pipe_1 = GaudiStableDiffusionXLInpaintPipeline(**components)
+ pipe_1.unet.set_default_attn_processor()
+ pipe_2 = GaudiStableDiffusionXLInpaintPipeline(**components)
+ pipe_2.unet.set_default_attn_processor()
+
+ def assert_run_mixture(
+ num_steps, split, scheduler_cls_orig, num_train_timesteps=pipe_1.scheduler.config.num_train_timesteps
+ ):
+ inputs = self.get_dummy_inputs(device)
+ inputs["num_inference_steps"] = num_steps
+
+ class scheduler_cls(scheduler_cls_orig):
+ pass
+
+ pipe_1.scheduler = scheduler_cls.from_config(pipe_1.scheduler.config)
+ pipe_2.scheduler = scheduler_cls.from_config(pipe_2.scheduler.config)
+
+ # Let's retrieve the number of timesteps we want to use
+ pipe_1.scheduler.set_timesteps(num_steps)
+ expected_steps = pipe_1.scheduler.timesteps.tolist()
+
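+            # `split` is the denoising_end/denoising_start fraction; convert it to the training-timestep
+            # boundary at which pipe_1 hands off to pipe_2 (second-order schedulers repeat the boundary step)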
+ split_ts = num_train_timesteps - int(round(num_train_timesteps * split))
+
+ if pipe_1.scheduler.order == 2:
+ expected_steps_1 = list(filter(lambda ts: ts >= split_ts, expected_steps))
+ expected_steps_2 = expected_steps_1[-1:] + list(filter(lambda ts: ts < split_ts, expected_steps))
+ expected_steps = expected_steps_1 + expected_steps_2
+ else:
+ expected_steps_1 = list(filter(lambda ts: ts >= split_ts, expected_steps))
+ expected_steps_2 = list(filter(lambda ts: ts < split_ts, expected_steps))
+
+            # now we monkey-patch the scheduler's `step` method to record every
+            # timestep it is called with in the `done_steps` list for testing
+ done_steps = []
+ old_step = copy.copy(scheduler_cls.step)
+
+ def new_step(self, *args, **kwargs):
+ done_steps.append(args[1].cpu().item()) # args[1] is always the passed `t`
+ return old_step(self, *args, **kwargs)
+
+ scheduler_cls.step = new_step
+
+ inputs_1 = {**inputs, **{"denoising_end": split, "output_type": "latent"}}
+ latents = pipe_1(**inputs_1).images[0]
+
+ assert expected_steps_1 == done_steps, f"Failure with {scheduler_cls.__name__} and {num_steps} and {split}"
+
+ inputs_2 = {**inputs, **{"denoising_start": split, "image": latents}}
+ pipe_2(**inputs_2).images[0]
+
+ assert expected_steps_2 == done_steps[len(expected_steps_1) :]
+ assert expected_steps == done_steps, f"Failure with {scheduler_cls.__name__} and {num_steps} and {split}"
+
+ for steps in [7, 20]:
+ assert_run_mixture(steps, 0.33, EulerDiscreteScheduler)
+ # Currently cannot support the default HeunDiscreteScheduler
+ # assert_run_mixture(steps, 0.33, HeunDiscreteScheduler)
+
+ @slow
+ def test_stable_diffusion_two_xl_mixture_of_denoiser(self):
+ components = self.get_dummy_components()
+ pipe_1 = GaudiStableDiffusionXLInpaintPipeline(**components)
+ pipe_1.unet.set_default_attn_processor()
+ pipe_2 = GaudiStableDiffusionXLInpaintPipeline(**components)
+ pipe_2.unet.set_default_attn_processor()
+
+ def assert_run_mixture(
+ num_steps, split, scheduler_cls_orig, num_train_timesteps=pipe_1.scheduler.config.num_train_timesteps
+ ):
+ inputs = self.get_dummy_inputs()
+ inputs["num_inference_steps"] = num_steps
+
+ class scheduler_cls(scheduler_cls_orig):
+ pass
+
+ pipe_1.scheduler = scheduler_cls.from_config(pipe_1.scheduler.config)
+ pipe_2.scheduler = scheduler_cls.from_config(pipe_2.scheduler.config)
+
+ # Let's retrieve the number of timesteps we want to use
+ pipe_1.scheduler.set_timesteps(num_steps)
+ expected_steps = pipe_1.scheduler.timesteps.tolist()
+
+ split_ts = num_train_timesteps - int(round(num_train_timesteps * split))
+
+ if pipe_1.scheduler.order == 2:
+ expected_steps_1 = list(filter(lambda ts: ts >= split_ts, expected_steps))
+ expected_steps_2 = expected_steps_1[-1:] + list(filter(lambda ts: ts < split_ts, expected_steps))
+ expected_steps = expected_steps_1 + expected_steps_2
+ else:
+ expected_steps_1 = list(filter(lambda ts: ts >= split_ts, expected_steps))
+ expected_steps_2 = list(filter(lambda ts: ts < split_ts, expected_steps))
+
+            # now we monkey-patch the scheduler's `step` method to record every
+            # timestep it is called with in the `done_steps` list for testing
+ done_steps = []
+ old_step = copy.copy(scheduler_cls.step)
+
+ def new_step(self, *args, **kwargs):
+ done_steps.append(args[1].cpu().item()) # args[1] is always the passed `t`
+ return old_step(self, *args, **kwargs)
+
+ scheduler_cls.step = new_step
+
+ inputs_1 = {**inputs, **{"denoising_end": split, "output_type": "latent"}}
+ latents = pipe_1(**inputs_1).images[0]
+
+ assert expected_steps_1 == done_steps, f"Failure with {scheduler_cls.__name__} and {num_steps} and {split}"
+
+ inputs_2 = {**inputs, **{"denoising_start": split, "image": latents}}
+ pipe_2(**inputs_2).images[0]
+
+ assert expected_steps_2 == done_steps[len(expected_steps_1) :]
+ assert expected_steps == done_steps, f"Failure with {scheduler_cls.__name__} and {num_steps} and {split}"
+
+ for steps in [5, 8, 20]:
+ for split in [0.33, 0.49, 0.71]:
+ for scheduler_cls in [
+ GaudiDDIMScheduler,
+ GaudiEulerDiscreteScheduler,
+ GaudiEulerAncestralDiscreteScheduler,
+ DPMSolverMultistepScheduler,
+ UniPCMultistepScheduler,
+ # HeunDiscreteScheduler,
+ ]:
+ assert_run_mixture(steps, split, scheduler_cls)
+
+ @slow
+ def test_stable_diffusion_three_xl_mixture_of_denoiser(self):
+ components = self.get_dummy_components()
+ pipe_1 = GaudiStableDiffusionXLInpaintPipeline(**components)
+ pipe_1.unet.set_default_attn_processor()
+ pipe_2 = GaudiStableDiffusionXLInpaintPipeline(**components)
+ pipe_2.unet.set_default_attn_processor()
+ pipe_3 = GaudiStableDiffusionXLInpaintPipeline(**components)
+ pipe_3.unet.set_default_attn_processor()
+
+ def assert_run_mixture(
+ num_steps,
+ split_1,
+ split_2,
+ scheduler_cls_orig,
+ num_train_timesteps=pipe_1.scheduler.config.num_train_timesteps,
+ ):
+ inputs = self.get_dummy_inputs()
+ inputs["num_inference_steps"] = num_steps
+
+ class scheduler_cls(scheduler_cls_orig):
+ pass
+
+ pipe_1.scheduler = scheduler_cls.from_config(pipe_1.scheduler.config)
+ pipe_2.scheduler = scheduler_cls.from_config(pipe_2.scheduler.config)
+ pipe_3.scheduler = scheduler_cls.from_config(pipe_3.scheduler.config)
+
+ # Let's retrieve the number of timesteps we want to use
+ pipe_1.scheduler.set_timesteps(num_steps)
+ expected_steps = pipe_1.scheduler.timesteps.tolist()
+
+ split_1_ts = num_train_timesteps - int(round(num_train_timesteps * split_1))
+ split_2_ts = num_train_timesteps - int(round(num_train_timesteps * split_2))
+
+ if pipe_1.scheduler.order == 2:
+ expected_steps_1 = list(filter(lambda ts: ts >= split_1_ts, expected_steps))
+ expected_steps_2 = expected_steps_1[-1:] + list(
+ filter(lambda ts: ts >= split_2_ts and ts < split_1_ts, expected_steps)
+ )
+ expected_steps_3 = expected_steps_2[-1:] + list(filter(lambda ts: ts < split_2_ts, expected_steps))
+ expected_steps = expected_steps_1 + expected_steps_2 + expected_steps_3
+ else:
+ expected_steps_1 = list(filter(lambda ts: ts >= split_1_ts, expected_steps))
+ expected_steps_2 = list(filter(lambda ts: ts >= split_2_ts and ts < split_1_ts, expected_steps))
+ expected_steps_3 = list(filter(lambda ts: ts < split_2_ts, expected_steps))
+
+            # now we monkey-patch the scheduler's `step` method to record every
+            # timestep it is called with in the `done_steps` list for testing
+ done_steps = []
+ old_step = copy.copy(scheduler_cls.step)
+
+ def new_step(self, *args, **kwargs):
+ done_steps.append(args[1].cpu().item()) # args[1] is always the passed `t`
+ return old_step(self, *args, **kwargs)
+
+ scheduler_cls.step = new_step
+
+ inputs_1 = {**inputs, **{"denoising_end": split_1, "output_type": "latent"}}
+ latents = pipe_1(**inputs_1).images[0]
+
+ assert (
+ expected_steps_1 == done_steps
+ ), f"Failure with {scheduler_cls.__name__} and {num_steps} and {split_1} and {split_2}"
+
+ inputs_2 = {
+ **inputs,
+ **{"denoising_start": split_1, "denoising_end": split_2, "image": latents, "output_type": "latent"},
+ }
+ pipe_2(**inputs_2).images[0]
+
+ assert expected_steps_2 == done_steps[len(expected_steps_1) :]
+
+ inputs_3 = {**inputs, **{"denoising_start": split_2, "image": latents}}
+ pipe_3(**inputs_3).images[0]
+
+ assert expected_steps_3 == done_steps[len(expected_steps_1) + len(expected_steps_2) :]
+ assert (
+ expected_steps == done_steps
+ ), f"Failure with {scheduler_cls.__name__} and {num_steps} and {split_1} and {split_2}"
+
+ for steps in [7, 11, 20]:
+ for split_1, split_2 in zip([0.19, 0.32], [0.81, 0.68]):
+ for scheduler_cls in [
+ GaudiDDIMScheduler,
+ GaudiEulerDiscreteScheduler,
+ GaudiEulerAncestralDiscreteScheduler,
+ DPMSolverMultistepScheduler,
+ UniPCMultistepScheduler,
+ # HeunDiscreteScheduler,
+ ]:
+ assert_run_mixture(steps, split_1, split_2, scheduler_cls)
+
+ def test_stable_diffusion_xl_multi_prompts(self):
+ device = "cpu"
+ components = self.get_dummy_components()
+ sd_pipe = self.pipeline_class(**components)
+ # forward with single prompt
+ inputs = self.get_dummy_inputs(device)
+ inputs["num_inference_steps"] = 5
+ output = sd_pipe(**inputs)
+ image_slice_1 = output.images[0, -3:, -3:, -1]
+
+ # forward with same prompt duplicated
+ inputs = self.get_dummy_inputs(device)
+ inputs["num_inference_steps"] = 5
+ inputs["prompt_2"] = inputs["prompt"]
+ output = sd_pipe(**inputs)
+ image_slice_2 = output.images[0, -3:, -3:, -1]
+
+ # ensure the results are equal
+ assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
+
+ # forward with different prompt
+ inputs = self.get_dummy_inputs(device)
+ inputs["num_inference_steps"] = 5
+ inputs["prompt_2"] = "different prompt"
+ output = sd_pipe(**inputs)
+ image_slice_3 = output.images[0, -3:, -3:, -1]
+
+ # ensure the results are not equal
+ assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4
+
+ # manually set a negative_prompt
+ inputs = self.get_dummy_inputs(device)
+ inputs["num_inference_steps"] = 5
+ inputs["negative_prompt"] = "negative prompt"
+ output = sd_pipe(**inputs)
+ image_slice_1 = output.images[0, -3:, -3:, -1]
+
+ # forward with same negative_prompt duplicated
+ inputs = self.get_dummy_inputs(device)
+ inputs["num_inference_steps"] = 5
+ inputs["negative_prompt"] = "negative prompt"
+ inputs["negative_prompt_2"] = inputs["negative_prompt"]
+ output = sd_pipe(**inputs)
+ image_slice_2 = output.images[0, -3:, -3:, -1]
+
+ # ensure the results are equal
+ assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
+
+ # forward with different negative_prompt
+ inputs = self.get_dummy_inputs(device)
+ inputs["num_inference_steps"] = 5
+ inputs["negative_prompt"] = "negative prompt"
+ inputs["negative_prompt_2"] = "different negative prompt"
+ output = sd_pipe(**inputs)
+ image_slice_3 = output.images[0, -3:, -3:, -1]
+
+ # ensure the results are not equal
+ assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4
+
+ def test_stable_diffusion_xl_img2img_negative_conditions(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ components = self.get_dummy_components()
+ sd_pipe = self.pipeline_class(**components)
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = sd_pipe(**inputs).images
+ image_slice_with_no_neg_conditions = image[0, -3:, -3:, -1]
+
+ image = sd_pipe(
+ **inputs,
+ negative_original_size=(512, 512),
+ negative_crops_coords_top_left=(
+ 0,
+ 0,
+ ),
+ negative_target_size=(1024, 1024),
+ ).images
+ image_slice_with_neg_conditions = image[0, -3:, -3:, -1]
+
+ assert (
+ np.abs(image_slice_with_no_neg_conditions.flatten() - image_slice_with_neg_conditions.flatten()).max()
+ > 1e-4
+ )
+
+ def test_stable_diffusion_xl_inpaint_mask_latents(self):
+ device = "cpu"
+ components = self.get_dummy_components()
+ sd_pipe = self.pipeline_class(**components)
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ # normal mask + normal image
+        ## `image`: pil, `mask_image`: pil, `masked_image_latents`: None
+ inputs = self.get_dummy_inputs(device)
+ inputs["strength"] = 0.9
+ out_0 = sd_pipe(**inputs).images
+
+ # image latents + mask latents
+ inputs = self.get_dummy_inputs(device)
+ image = sd_pipe.image_processor.preprocess(inputs["image"]).to(sd_pipe.device)
+ mask = sd_pipe.mask_processor.preprocess(inputs["mask_image"]).to(sd_pipe.device)
+ masked_image = image * (mask < 0.5)
+
+ generator = torch.Generator(device=device).manual_seed(0)
+ image_latents = sd_pipe._encode_vae_image(image, generator=generator)
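+        # draw and discard one latent-shaped sample so the generator state stays aligned with the default code path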
+ torch.randn((1, 4, 32, 32), generator=generator)
+ mask_latents = sd_pipe._encode_vae_image(masked_image, generator=generator)
+ inputs["image"] = image_latents
+ inputs["masked_image_latents"] = mask_latents
+ inputs["mask_image"] = mask
+ inputs["strength"] = 0.9
+ generator = torch.Generator(device=device).manual_seed(0)
+ torch.randn((1, 4, 32, 32), generator=generator)
+ inputs["generator"] = generator
+ out_1 = sd_pipe(**inputs).images
+ assert np.abs(out_0 - out_1).max() < 1e-2
+
+ def test_stable_diffusion_xl_inpaint_2_images(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ components = self.get_dummy_components()
+ sd_pipe = self.pipeline_class(**components)
+ sd_pipe.set_progress_bar_config(disable=None)
+
+        # test to confirm that passing the same image twice yields the same output
+ inputs = self.get_dummy_inputs(device)
+ gen1 = torch.Generator(device=device).manual_seed(0)
+ gen2 = torch.Generator(device=device).manual_seed(0)
+ for name in ["prompt", "image", "mask_image"]:
+ inputs[name] = [inputs[name]] * 2
+ inputs["generator"] = [gen1, gen2]
+ images = sd_pipe(**inputs).images
+
+ assert images.shape == (2, 64, 64, 3)
+
+ image_slice1 = images[0, -3:, -3:, -1]
+ image_slice2 = images[1, -3:, -3:, -1]
+ assert np.abs(image_slice1.flatten() - image_slice2.flatten()).max() < 1e-4
+
+        # test to confirm that passing two different images yields different outputs
+ inputs = self.get_dummy_inputs_2images(device)
+ images = sd_pipe(**inputs).images
+ assert images.shape == (2, 64, 64, 3)
+
+ image_slice1 = images[0, -3:, -3:, -1]
+ image_slice2 = images[1, -3:, -3:, -1]
+ assert np.abs(image_slice1.flatten() - image_slice2.flatten()).max() > 1e-2
+
+ def test_pipeline_interrupt(self):
+ components = self.get_dummy_components()
+ sd_pipe = GaudiStableDiffusionXLInpaintPipeline(**components)
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs()
+
+ prompt = "hey"
+ num_inference_steps = 5
+
+ # store intermediate latents from the generation process
+ class PipelineState:
+ def __init__(self):
+ self.state = []
+
+ def apply(self, pipe, i, t, callback_kwargs):
+ self.state.append(callback_kwargs["latents"])
+ return callback_kwargs
+
+ pipe_state = PipelineState()
+ sd_pipe(
+ prompt,
+ image=inputs["image"],
+ mask_image=inputs["mask_image"],
+ strength=0.8,
+ num_inference_steps=num_inference_steps,
+ output_type="np",
+ generator=torch.Generator("cpu").manual_seed(0),
+ callback_on_step_end=pipe_state.apply,
+ ).images
+
+ # interrupt generation at step index
+ interrupt_step_idx = 1
+
+ def callback_on_step_end(pipe, i, t, callback_kwargs):
+ if i == interrupt_step_idx:
+ pipe._interrupt = True
+
+ return callback_kwargs
+
+ output_interrupted = sd_pipe(
+ prompt,
+ image=inputs["image"],
+ mask_image=inputs["mask_image"],
+ strength=0.8,
+ num_inference_steps=num_inference_steps,
+ output_type="latent",
+ generator=torch.Generator("cpu").manual_seed(0),
+ callback_on_step_end=callback_on_step_end,
+ ).images
+
+ # fetch intermediate latents at the interrupted step
+ # from the completed generation process
+ intermediate_latent = pipe_state.state[interrupt_step_idx]
+
+ # compare the intermediate latent to the output of the interrupted process
+ # they should be the same
+ assert torch.allclose(intermediate_latent, output_interrupted, atol=1e-4)
+
+ @slow
+ def test_stable_diffusion_xl_inpaint_no_throughput_regression(self):
+        """Test that Stable Diffusion XL inpainting shows no throughput regression with autocast (bf16)"""
+
+ # Initialize inpaint parameters
+ init_image = load_image(
+ "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png"
+ )
+ mask_image = load_image(
+ "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png"
+ )
+
+ prompts = [
+ "a black cat with glowing eyes, cute, adorable, disney, pixar, highly detailed, 8k",
+ "concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k",
+ ]
+ model_name = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"
+ num_images_per_prompt = 10
+ num_inference_steps = 10
+ init_kwargs = {
+ "use_habana": True,
+ "use_hpu_graphs": True,
+ "gaudi_config": "Habana/stable-diffusion",
+ "torch_dtype": torch.bfloat16,
+ }
+ sdi_pipe = GaudiStableDiffusionXLInpaintPipeline.from_pretrained(model_name, **init_kwargs)
+
+ set_seed(0)
+ outputs = sdi_pipe(
+ prompt=prompts,
+ image=init_image,
+ mask_image=mask_image,
+ num_images_per_prompt=num_images_per_prompt,
+ throughput_warmup_steps=3,
+ num_inference_steps=num_inference_steps,
+ batch_size=4,
+ )
+
+ self.assertEqual(len(outputs.images), num_images_per_prompt * len(prompts))
+ self.assertGreaterEqual(outputs.throughput, 0.95 * INPAINT_XL_THROUGHPUT_BASELINE_BF16)
diff --git a/tests/test_examples.py b/tests/test_examples.py
old mode 100755
new mode 100644
index c3c033e440..f3b6ded9d7
--- a/tests/test_examples.py
+++ b/tests/test_examples.py
@@ -164,13 +164,23 @@ def is_valid_model_type(model_type: str) -> bool:
"sft": _get_supported_models_for_script(
MODELS_TO_TEST_MAPPING,
MODEL_FOR_CAUSAL_LM_MAPPING,
- ["llama"],
+ ["llama", "qwen2"],
),
"dpo": _get_supported_models_for_script(
MODELS_TO_TEST_MAPPING,
MODEL_FOR_CAUSAL_LM_MAPPING,
["llama"],
),
+ "reward_modeling": _get_supported_models_for_script(
+ MODELS_TO_TEST_MAPPING,
+ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
+ ["llama"],
+ ),
+ "ppo": _get_supported_models_for_script(
+ MODELS_TO_TEST_MAPPING,
+ MODEL_FOR_CAUSAL_LM_MAPPING,
+ ["llama"],
+ ),
"run_prompt_tuning_clm": _get_supported_models_for_script(
MODELS_TO_TEST_MAPPING,
MODEL_FOR_CAUSAL_LM_MAPPING,
@@ -192,7 +202,9 @@ class ExampleTestMeta(type):
"""
@staticmethod
- def to_test(model_name: str, multi_card: bool, deepspeed: bool, example_name: str, fsdp: bool):
+ def to_test(
+ model_name: str, multi_card: bool, deepspeed: bool, example_name: str, fsdp: bool, fp8: bool, task_name: str
+ ):
models_with_specific_rules = [
"albert-xxlarge-v1",
"gpt2-xl",
@@ -208,15 +220,23 @@ def to_test(model_name: str, multi_card: bool, deepspeed: bool, example_name: st
"meta-llama/LlamaGuard-7b",
]
- if fsdp and not IS_GAUDI2:
+ if (fsdp or fp8) and not IS_GAUDI2:
return False
elif (
"sft" in example_name
or "dpo" in example_name
+ or "reward_modeling" in example_name
+ or "ppo" in example_name
or "prompt_tuning" in example_name
or example_name == "run_sequence_classification"
) and not IS_GAUDI2:
return False
+ elif "llama" in model_name and "trl-sft-chat" in task_name:
+ return False
+ elif ("qwen2" in model_name or "Qwen2" in model_name) and task_name == "trl-sft":
+ return False
+ elif "falcon" in model_name and task_name in ("llama-adapter", "databricks/databricks-dolly-15k"):
+ return False
elif model_name not in models_with_specific_rules and not deepspeed:
return True
elif model_name == "gpt2-xl" and deepspeed:
@@ -241,7 +261,7 @@ def to_test(model_name: str, multi_card: bool, deepspeed: bool, example_name: st
return True
elif "bridgetower" in model_name and IS_GAUDI2:
return True
- elif "falcon" in model_name and IS_GAUDI2 and not fsdp:
+ elif "falcon" in model_name and IS_GAUDI2 and not fsdp and not fp8:
return True
elif "bloom" in model_name and deepspeed and not IS_GAUDI2:
return True
@@ -253,14 +273,22 @@ def to_test(model_name: str, multi_card: bool, deepspeed: bool, example_name: st
return False
def __new__(
- cls, name, bases, attrs, example_name=None, multi_card=False, deepspeed=False, fsdp=False, torch_compile=False
+ cls,
+ name,
+ bases,
+ attrs,
+ example_name=None,
+ multi_card=False,
+ deepspeed=False,
+ fsdp=False,
+ torch_compile=False,
+ fp8=False,
):
distribution = "single_card"
if multi_card:
distribution = "multi_card"
elif deepspeed:
distribution = "deepspeed"
-
if example_name is not None:
models_to_test = _SCRIPT_TO_MODEL_MAPPING.get(example_name)
if models_to_test is None:
@@ -274,9 +302,9 @@ def __new__(
)
for model_name, gaudi_config_name in models_to_test:
- if cls.to_test(model_name, multi_card, deepspeed, example_name, fsdp):
+ if cls.to_test(model_name, multi_card, deepspeed, example_name, fsdp, fp8, attrs["TASK_NAME"]):
attrs[f"test_{example_name}_{model_name.split('/')[-1]}_{distribution}"] = cls._create_test(
- model_name, gaudi_config_name, multi_card, deepspeed, fsdp, torch_compile
+ model_name, gaudi_config_name, multi_card, deepspeed, fsdp, torch_compile, fp8
)
attrs["EXAMPLE_NAME"] = example_name
return super().__new__(cls, name, bases, attrs)
@@ -290,6 +318,7 @@ def _create_test(
deepspeed: bool = False,
fsdp: bool = False,
torch_compile: bool = False,
+ fp8: bool = False,
) -> Callable[[], None]:
"""
Create a test function that runs an example for a specific (model_name, gaudi_config_name) pair.
@@ -393,6 +422,9 @@ def test(self):
elif deepspeed and "gpt-neox-20b" in model_name:
env_variables["LD_PRELOAD"] = ""
+ if fp8 and "llama" in model_name:
+ env_variables["LOWER_LIST"] = str(example_script.parent / "ops_bf16.txt")
+
extra_command_line_arguments = baseline.get("distribution").get(distribution).get("extra_arguments", [])
if os.environ.get("DATA_CACHE", None) is not None and self.EXAMPLE_NAME == "run_clip":
@@ -432,7 +464,6 @@ def test(self):
with open(Path(tmp_dir) / "all_results.json") as fp:
results = json.load(fp)
-
# Ensure performance requirements (accuracy, training time) are met
self.assert_no_regression(results, baseline.get("distribution").get(distribution), model_name)
@@ -502,7 +533,7 @@ def _create_command_line(
"--num_gpus 8",
"--no_local_rank",
]
- if self.EXAMPLE_NAME == "dpo":
+ if self.EXAMPLE_NAME in ["dpo", "reward_modeling"]:
cmd_line += [
f"{script}",
f"--model_name_or_path {model_name}",
@@ -511,6 +542,14 @@ def _create_command_line(
f"--per_device_train_batch_size {train_batch_size}",
f"--per_device_eval_batch_size {eval_batch_size}",
]
+ elif self.EXAMPLE_NAME == "ppo":
+ cmd_line += [
+ f"{script}",
+ f"--model_name_or_path {model_name}",
+ f"--tokenizer_name_or_path {model_name}",
+ f"--output_dir {output_dir}",
+ f"--batch_size {train_batch_size}",
+ ]
else:
cmd_line += [
f"{script}",
@@ -531,10 +570,10 @@ def _create_command_line(
if "compile" in task:
cmd_line += ["--use_lazy_mode False"]
- elif self.EXAMPLE_NAME != "dpo":
+ elif self.EXAMPLE_NAME not in ["dpo", "ppo", "reward_modeling"]:
cmd_line += ["--use_lazy_mode"]
- if "bloom" not in model_name and self.EXAMPLE_NAME != "dpo":
+ if "bloom" not in model_name and self.EXAMPLE_NAME not in ["dpo", "ppo", "reward_modeling"]:
cmd_line.append("--do_eval")
if extra_command_line_arguments is not None:
@@ -568,10 +607,12 @@ def assert_no_regression(self, results: Dict, baseline: Dict, model_name: str):
for metric_name in self.REGRESSION_METRICS.keys():
if metric_name in baseline and metric_name in results:
metrics_to_assess.append(metric_name)
-
# There is no accuracy metric for `run_clip.py`, `run_bridgetower.py` and BLOOM
min_number_metrics = 3
- if self.EXAMPLE_NAME in ["run_clip", "run_bridgetower", "sft", "dpo"] or "bloom" in model_name:
+ if (
+ self.EXAMPLE_NAME in ["run_clip", "run_bridgetower", "sft", "dpo", "ppo", "reward_modeling"]
+ or "bloom" in model_name
+ ):
min_number_metrics = 2
# Check that at least 3 metrics are assessed:
@@ -719,6 +760,12 @@ class CausalLanguageModelingLORAExampleTester(
TASK_NAME = "databricks/databricks-dolly-15k"
+class MultiCardCausalLanguageModelingLORAExampleTester2(
+ ExampleTesterBase, metaclass=ExampleTestMeta, example_name="run_lora_clm", multi_card=True
+):
+ TASK_NAME = "mamamiya405/finred"
+
+
class MultiCardCausalLanguageModelingLORAExampleTester(
ExampleTesterBase, metaclass=ExampleTestMeta, example_name="run_lora_clm", multi_card=True
):
@@ -753,11 +800,35 @@ class MultiCardSFTExampleTester(ExampleTesterBase, metaclass=ExampleTestMeta, ex
DATASET_NAME = "lvwerra/stack-exchange-paired"
+class MultiCardSFTChatExampleTester(ExampleTesterBase, metaclass=ExampleTestMeta, example_name="sft", multi_card=True):
+ TASK_NAME = "trl-sft-chat"
+ DATASET_NAME = "philschmid/dolly-15k-oai-style"
+
+
+class MultiCardSFTChatPeftExampleTester(
+ ExampleTesterBase, metaclass=ExampleTestMeta, example_name="sft", multi_card=True
+):
+ TASK_NAME = "trl-sft-chat-peft"
+ DATASET_NAME = "philschmid/dolly-15k-oai-style"
+
+
class MultiCardDPOExampleTester(ExampleTesterBase, metaclass=ExampleTestMeta, example_name="dpo", multi_card=True):
TASK_NAME = "trl-dpo"
DATASET_NAME = "lvwerra/stack-exchange-paired"
+class MultiCardRewardExampleTester(
+ ExampleTesterBase, metaclass=ExampleTestMeta, example_name="reward_modeling", multi_card=True
+):
+ TASK_NAME = "trl-reward"
+ DATASET_NAME = "lvwerra/stack-exchange-paired"
+
+
+class MultiCardPPOExampleTester(ExampleTesterBase, metaclass=ExampleTestMeta, example_name="ppo", multi_card=True):
+ TASK_NAME = "trl-ppo"
+ DATASET_NAME = "lvwerra/stack-exchange-paired"
+
+
class MultiCardProteinFoldingClassificationTester(
ExampleTesterBase, metaclass=ExampleTestMeta, example_name="run_sequence_classification", multi_card=True
):
@@ -768,19 +839,33 @@ class MultiCardProteinFoldingClassificationTester(
class MultiCardCausalLanguageModelingPromptTuningExampleTester(
ExampleTesterBase, metaclass=ExampleTestMeta, example_name="run_prompt_tuning_clm", multi_card=True
):
- TASK_NAME = ["prompt-tuning"]
+ TASK_NAME = "prompt-tuning"
DATASET_NAME = "ought/raft"
class MultiCardCausalLanguageModelingPrefixTuningExampleTester(
ExampleTesterBase, metaclass=ExampleTestMeta, example_name="run_prompt_tuning_clm", multi_card=True
):
- TASK_NAME = ["prefix-tuning"]
+ TASK_NAME = "prefix-tuning"
DATASET_NAME = "ought/raft"
class MultiCardCausalLanguageModelingPTuningExampleTester(
ExampleTesterBase, metaclass=ExampleTestMeta, example_name="run_prompt_tuning_clm", multi_card=True
):
- TASK_NAME = ["p-tuning"]
+ TASK_NAME = "p-tuning"
DATASET_NAME = "ought/raft"
+
+
+class MultiCardCausalLanguageModelingLlamaAdapterExampleTester(
+ ExampleTesterBase, metaclass=ExampleTestMeta, example_name="run_lora_clm", multi_card=True
+):
+ TASK_NAME = "llama-adapter"
+ DATASET_NAME = "tatsu-lab/alpaca"
+
+
+class MultiCardCausalLanguageModelingLoRAFP8ExampleTester(
+ ExampleTesterBase, metaclass=ExampleTestMeta, example_name="run_lora_clm", multi_card=True, fp8=True
+):
+ TASK_NAME = "tatsu-lab/alpaca_fp8"
+ DATASET_NAME = "tatsu-lab/alpaca"
diff --git a/tests/test_feature_extraction.py b/tests/test_feature_extraction.py
new file mode 100644
index 0000000000..f934ac15a9
--- /dev/null
+++ b/tests/test_feature_extraction.py
@@ -0,0 +1,137 @@
+# coding=utf-8
+# Copyright 2022 the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import time
+from unittest import TestCase
+
+import habana_frameworks.torch as ht
+import pytest
+import torch
+import torch.nn.functional as F
+from transformers import AutoModel, AutoTokenizer
+
+from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
+
+
+adapt_transformers_to_gaudi()
+
+if os.environ.get("GAUDI2_CI", "0") == "1":
+ # Gaudi2 CI baselines
+ LATENCY_GTE_SMALL_BF16_GRAPH_BASELINE = 0.6812
+else:
+ # Gaudi1 CI baselines
+ LATENCY_GTE_SMALL_BF16_GRAPH_BASELINE = 0.7987
+MODEL_NAME = "Supabase/gte-small"
+
+INPUT_TEXTS = [
+ "what is the capital of China?",
+ "how to implement quick sort in Python?",
+ "Beijing",
+ "sorting algorithms",
+]
+
+TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME)
+
+
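+# Mean-pool the last hidden states over the sequence dimension, ignoring padded positions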
+def average_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor):
+ last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
+ return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
+
+
+def embeddings(outputs, batch_dict):
+ return F.normalize(average_pool(outputs.last_hidden_state, batch_dict["attention_mask"]))
+
+
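+# Similarity of the query embedding (first row) against the candidates; embeddings are L2-normalized, so these are cosine similarities scaled by 100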
+def scores(embeddings):
+ return (embeddings[:1] @ embeddings[1:].T) * 100
+
+
+def get_batch_dict():
+ return TOKENIZER(INPUT_TEXTS, max_length=512, padding=True, truncation=True, return_tensors="pt")
+
+
+@pytest.fixture(scope="module")
+def model():
+ return AutoModel.from_pretrained(MODEL_NAME)
+
+
+@pytest.fixture(autouse=True, scope="class")
+def cpu_results(request, model):
+ batch_dict = get_batch_dict()
+ with torch.no_grad():
+ outputs = model(**batch_dict)
+ embeddings_cpu = embeddings(outputs, batch_dict)
+ request.cls.scores_cpu = scores(embeddings_cpu)
+
+
+@pytest.fixture(autouse=True, scope="class")
+def default_hpu_results(request, model):
+ request.cls.model_hpu = model.to("hpu")
+ request.cls.model_hpu_graph = ht.hpu.wrap_in_hpu_graph(model.to("hpu"))
+ batch_dict = get_batch_dict().to("hpu")
+ with torch.no_grad():
+ outputs = request.cls.model_hpu(**batch_dict)
+ embeddings_hpu_default = embeddings(outputs, batch_dict)
+ request.cls.scores_hpu_default = scores(embeddings_hpu_default)
+
+
+class GaudiFeatureExtractionTester(TestCase):
+ """
+ Tests for Supabase/gte-small feature extraction on Gaudi
+ """
+
+ def test_inference_default(self):
+ """
+ Tests for equivalent CPU and HPU outputs
+ """
+ self.assertTrue(torch.allclose(self.scores_cpu, self.scores_hpu_default, rtol=1e-3))
+
+ def test_inference_bf16(self):
+ """
+        Test that bf16 outputs closely match the default outputs
+ """
+ batch_dict = get_batch_dict()
+ with torch.autocast(device_type="hpu", dtype=torch.bfloat16), torch.no_grad():
+ outputs = self.model_hpu(**batch_dict)
+ embeddings_hpu_bf16 = embeddings(outputs, batch_dict)
+ scores_hpu_bf16 = scores(embeddings_hpu_bf16)
+ self.assertTrue(torch.allclose(scores_hpu_bf16, self.scores_hpu_default, rtol=1e-2))
+
+ def test_inference_graph_bf16(self):
+ batch_dict = get_batch_dict().to("hpu")
+ with torch.autocast(device_type="hpu", dtype=torch.bfloat16), torch.no_grad():
+ outputs = self.model_hpu_graph(**batch_dict)
+ embeddings_hpu_graph_bf16 = embeddings(outputs, batch_dict)
+ scores_hpu_graph_bf16 = scores(embeddings_hpu_graph_bf16)
+ self.assertTrue(torch.allclose(scores_hpu_graph_bf16, self.scores_hpu_default, rtol=1e-2))
+
+ def test_latency_graph_bf16(self):
+ batch_dict = get_batch_dict().to("hpu")
+ warm_up_iters = 5
+ test_iters = 50
+ with torch.autocast(device_type="hpu", dtype=torch.bfloat16), torch.no_grad():
+ for _ in range(warm_up_iters):
+ self.model_hpu_graph(**batch_dict)
+ torch.hpu.synchronize()
+ start_time = time.time()
+ with torch.autocast(device_type="hpu", dtype=torch.bfloat16), torch.no_grad():
+ for _ in range(test_iters):
+ outputs = self.model_hpu_graph(**batch_dict)
+ embeddings(outputs, batch_dict)
+ torch.hpu.synchronize()
+ end_time = time.time()
+ time_per_iter = (end_time - start_time) * 1000 / test_iters # time in ms
+ self.assertLess(time_per_iter, 1.05 * LATENCY_GTE_SMALL_BF16_GRAPH_BASELINE)
diff --git a/tests/test_image_classification.py b/tests/test_image_classification.py
new file mode 100644
index 0000000000..6e59b7ac40
--- /dev/null
+++ b/tests/test_image_classification.py
@@ -0,0 +1,120 @@
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import time
+from unittest import TestCase
+
+import habana_frameworks.torch as ht
+import numpy as np
+import requests
+import timm
+import torch
+from PIL import Image
+
+from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
+
+
+adapt_transformers_to_gaudi()
+
+# For Gaudi 2
+LATENCY_FastViT_BF16_GRAPH_BASELINE = 2.5270626640319824
+
+
+class GaudiFastViTTester(TestCase):
+ """
+ Tests for FastViT model
+ """
+
+ def prepare_model_and_processor(self):
+ model = timm.create_model("timm/fastvit_t8.apple_in1k", pretrained=True)
+ model.to("hpu")
+ model = model.eval()
+ data_config = timm.data.resolve_model_data_config(model)
+ processor = timm.data.create_transform(**data_config, is_training=False)
+ return model, processor
+
+ def prepare_data(self):
+ url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png"
+ image = Image.open(requests.get(url, stream=True).raw)
+ return image
+
+ def test_inference_default(self):
+ model, processor = self.prepare_model_and_processor()
+ image = self.prepare_data()
+ inputs = processor(image).unsqueeze(0).to("hpu")
+ outputs = model(inputs)
+ top1_probabilities, top1_class_indices = torch.topk(outputs.softmax(dim=1) * 100, k=1)
+ top1_probabilities = top1_probabilities.to("cpu").detach().numpy()
+ top1_class_indices = top1_class_indices.to("cpu").numpy()
+ expected_scores = np.array([21.406523]) # from CPU
+ expected_class = np.array([960])
+ self.assertEqual(top1_class_indices, expected_class)
+ self.assertLess(np.abs(top1_probabilities - expected_scores).max(), 1)
+
+ def test_inference_autocast(self):
+ model, processor = self.prepare_model_and_processor()
+ image = self.prepare_data()
+ inputs = processor(image).unsqueeze(0).to("hpu")
+
+ with torch.autocast(device_type="hpu", dtype=torch.bfloat16): # Autocast BF16
+ outputs = model(inputs)
+ top1_probabilities, top1_class_indices = torch.topk(outputs.softmax(dim=1) * 100, k=1)
+ top1_probabilities = top1_probabilities.to("cpu").detach().numpy()
+ top1_class_indices = top1_class_indices.to("cpu").numpy()
+ expected_scores = np.array([21.406523]) # from CPU
+ expected_class = np.array([960])
+ self.assertEqual(top1_class_indices, expected_class)
+ self.assertLess(np.abs(top1_probabilities - expected_scores).max(), 1)
+
+ def test_inference_hpu_graphs(self):
+ model, processor = self.prepare_model_and_processor()
+ image = self.prepare_data()
+ inputs = processor(image).unsqueeze(0).to("hpu")
+
+ model = ht.hpu.wrap_in_hpu_graph(model) # Apply graph
+
+ outputs = model(inputs)
+ top1_probabilities, top1_class_indices = torch.topk(outputs.softmax(dim=1) * 100, k=1)
+ top1_probabilities = top1_probabilities.to("cpu").detach().numpy()
+ top1_class_indices = top1_class_indices.to("cpu").numpy()
+ expected_scores = np.array([21.406523]) # from CPU
+ expected_class = np.array([960])
+ self.assertEqual(top1_class_indices, expected_class)
+ self.assertLess(np.abs(top1_probabilities - expected_scores).max(), 1)
+
+ def test_no_latency_regression_autocast(self):
+ warmup = 3
+ iterations = 20
+
+ model, processor = self.prepare_model_and_processor()
+ image = self.prepare_data()
+
+ model = ht.hpu.wrap_in_hpu_graph(model)
+
+ with torch.no_grad(), torch.autocast(device_type="hpu", dtype=torch.bfloat16, enabled=True):
+ for i in range(warmup):
+ inputs = processor(image).unsqueeze(0).to("hpu")
+ _ = model(inputs)
+ torch.hpu.synchronize()
+
+ total_model_time = 0
+ for i in range(iterations):
+ inputs = processor(image).unsqueeze(0).to("hpu")
+ model_start_time = time.time()
+ _ = model(inputs)
+ torch.hpu.synchronize()
+ model_end_time = time.time()
+ total_model_time = total_model_time + (model_end_time - model_start_time)
+
+ latency = total_model_time * 1000 / iterations # in terms of ms
+ self.assertLessEqual(latency, 1.05 * LATENCY_FastViT_BF16_GRAPH_BASELINE)
diff --git a/tests/test_image_segmentation.py b/tests/test_image_segmentation.py
new file mode 100644
index 0000000000..15c2c1b863
--- /dev/null
+++ b/tests/test_image_segmentation.py
@@ -0,0 +1,119 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import time
+from unittest import TestCase
+
+import habana_frameworks.torch as ht
+import numpy as np
+import requests
+import torch
+from PIL import Image
+from transformers import AutoModel, AutoProcessor
+
+from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
+
+
+adapt_transformers_to_gaudi()
+
+# For Gaudi 2
+LATENCY_OWLVIT_BF16_GRAPH_BASELINE = 3.7109851837158203
+LATENCY_SAM_BF16_GRAPH_BASELINE = 98.92215728759766
+
+
+class GaudiSAMTester(TestCase):
+ """
+ Tests for Segment Anything Model - SAM
+ """
+
+ def prepare_model_and_processor(self):
+ model = AutoModel.from_pretrained("facebook/sam-vit-huge").to("hpu")
+ processor = AutoProcessor.from_pretrained("facebook/sam-vit-huge")
+ model = model.eval()
+ return model, processor
+
+ def prepare_data(self):
+ image = Image.open(
+ requests.get(
+ "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png", stream=True
+ ).raw
+ ).convert("RGB")
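+        # a single (x, y) point prompt; SAM predicts three candidate masks (with IoU scores) for it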
+ input_points = [[[450, 600]]]
+ return input_points, image
+
+ def test_inference_default(self):
+ model, processor = self.prepare_model_and_processor()
+ input_points, image = self.prepare_data()
+ inputs = processor(image, input_points=input_points, return_tensors="pt").to("hpu")
+ outputs = model(**inputs)
+ scores = outputs.iou_scores
+ scores = scores[0][0]
+ expected_scores = np.array([0.9912, 0.9818, 0.9666])
+ self.assertEqual(len(scores), 3)
+ self.assertLess(np.abs(scores.cpu().detach().numpy() - expected_scores).max(), 0.02)
+
+ def test_inference_bf16(self):
+ model, processor = self.prepare_model_and_processor()
+ input_points, image = self.prepare_data()
+ inputs = processor(image, input_points=input_points, return_tensors="pt").to("hpu")
+
+ with torch.autocast(device_type="hpu", dtype=torch.bfloat16): # Autocast BF16
+ outputs = model(**inputs)
+ scores = outputs.iou_scores
+ scores = scores[0][0]
+ expected_scores = np.array([0.9912, 0.9818, 0.9666])
+ self.assertEqual(len(scores), 3)
+ self.assertLess(np.abs(scores.to(torch.float32).cpu().detach().numpy() - expected_scores).max(), 0.02)
+
+ def test_inference_hpu_graphs(self):
+ model, processor = self.prepare_model_and_processor()
+ input_points, image = self.prepare_data()
+ inputs = processor(image, input_points=input_points, return_tensors="pt").to("hpu")
+
+ model = ht.hpu.wrap_in_hpu_graph(model) # Apply graph
+
+ outputs = model(**inputs)
+ scores = outputs.iou_scores
+ scores = scores[0][0]
+ expected_scores = np.array([0.9912, 0.9818, 0.9666])
+ self.assertEqual(len(scores), 3)
+ self.assertLess(np.abs(scores.to(torch.float32).cpu().detach().numpy() - expected_scores).max(), 0.02)
+
+ def test_no_latency_regression_bf16(self):
+ warmup = 3
+ iterations = 10
+
+ model, processor = self.prepare_model_and_processor()
+ input_points, image = self.prepare_data()
+
+ model = ht.hpu.wrap_in_hpu_graph(model)
+
+ with torch.no_grad(), torch.autocast(device_type="hpu", dtype=torch.bfloat16, enabled=True):
+ for i in range(warmup):
+ inputs = processor(image, input_points=input_points, return_tensors="pt").to("hpu")
+ _ = model(**inputs)
+ torch.hpu.synchronize()
+
+ total_model_time = 0
+ for i in range(iterations):
+ inputs = processor(image, input_points=input_points, return_tensors="pt").to("hpu")
+ model_start_time = time.time()
+ _ = model(**inputs)
+ torch.hpu.synchronize()
+ model_end_time = time.time()
+ total_model_time = total_model_time + (model_end_time - model_start_time)
+
+ latency = total_model_time * 1000 / iterations # in terms of ms
+ self.assertLessEqual(latency, 1.05 * LATENCY_SAM_BF16_GRAPH_BASELINE)
diff --git a/tests/test_image_to_text_example.py b/tests/test_image_to_text_example.py
index 85982ce7c7..95a153c1f0 100644
--- a/tests/test_image_to_text_example.py
+++ b/tests/test_image_to_text_example.py
@@ -17,11 +17,15 @@
("llava-hf/llava-1.5-7b-hf", 1, 87.2901500056982),
("llava-hf/llava-1.5-13b-hf", 1, 54.41252589197953),
("llava-hf/llava-v1.6-mistral-7b-hf", 1, 33.17984878151546),
+ ("llava-hf/llava-v1.6-vicuna-7b-hf", 1, 35.00608681379742),
("llava-hf/llava-v1.6-vicuna-13b-hf", 1, 23.527610042925),
],
"fp8": [
("llava-hf/llava-1.5-7b-hf", 1, 123.00953973789325),
("llava-hf/llava-1.5-13b-hf", 1, 82.81132373492122),
+ ("llava-hf/llava-v1.6-mistral-7b-hf", 1, 45.011551008367084),
+ ("llava-hf/llava-v1.6-vicuna-7b-hf", 1, 45.18544502949674),
+ ("llava-hf/llava-v1.6-vicuna-13b-hf", 1, 30.9535718774675),
],
}
else:
diff --git a/tests/test_openclip_vqa.py b/tests/test_openclip_vqa.py
new file mode 100644
index 0000000000..c0c3d38521
--- /dev/null
+++ b/tests/test_openclip_vqa.py
@@ -0,0 +1,81 @@
+import json
+import os
+import re
+import subprocess
+from pathlib import Path
+from tempfile import TemporaryDirectory
+
+import pytest
+
+from .test_examples import TIME_PERF_FACTOR
+
+
+if os.environ.get("GAUDI2_CI", "0") == "1":
+ # Gaudi2 CI baselines
+ MODELS_TO_TEST = {
+ "bf16": [
+ ("laion/CLIP-ViT-g-14-laion2B-s12B-b42K", 1472),
+ ("microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224", 1816),
+ ],
+ }
+else:
+ # Gaudi1 CI baselines
+ MODELS_TO_TEST = {
+ "bf16": [
+ ("laion/CLIP-ViT-g-14-laion2B-s12B-b42K", 550),
+ ("microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224", 1200),
+ ],
+ }
+
+
+def _install_requirements():
+ PATH_TO_EXAMPLE_DIR = Path(__file__).resolve().parent.parent / "examples"
+ cmd_line = (
+ f"pip install -r {PATH_TO_EXAMPLE_DIR / 'visual-question-answering' / 'openclip_requirements.txt'}".split()
+ )
+ p = subprocess.Popen(cmd_line)
+ return_code = p.wait()
+ assert return_code == 0
+
+
+def _test_openclip_vqa(model_name: str, baseline: float):
+ _install_requirements()
+ command = ["python3"]
+ path_to_example_dir = Path(__file__).resolve().parent.parent / "examples"
+ env_variables = os.environ.copy()
+
+ command += [
+ f"{path_to_example_dir / 'visual-question-answering' / 'run_openclip_vqa.py'}",
+ f"--model_name_or_path {model_name}",
+ "--bf16",
+ "--use_hpu_graphs",
+ ]
+
+ with TemporaryDirectory() as tmp_dir:
+ command.append(f"--output_dir {tmp_dir}")
+ print(f"\n\nCommand to test: {' '.join(command)}\n")
+
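+        # split each command element on whitespace while keeping quoted substrings intact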
+ pattern = re.compile(r"([\"\'].+?[\"\'])|\s")
+ command = [x for y in command for x in re.split(pattern, y) if x]
+
+ proc = subprocess.run(command, env=env_variables)
+
+ # Ensure the run finished without any issue
+ # Use try-except to avoid logging the token if used
+ try:
+ assert proc.returncode == 0
+ except AssertionError as e:
+ if "'--token', 'hf_" in e.args[0]:
+ e.args = (f"The following command failed:\n{' '.join(command[:-2])}",)
+ raise
+
+ with open(Path(tmp_dir) / "results.json") as fp:
+ results = json.load(fp)
+
+ # Ensure performance requirements (throughput) are met
+ assert results["throughput"] >= (2 - TIME_PERF_FACTOR) * baseline
+
+
+@pytest.mark.parametrize("model_name, baseline", MODELS_TO_TEST["bf16"])
+def test_openclip_vqa_bf16(model_name: str, baseline: float):
+ _test_openclip_vqa(model_name, baseline)
diff --git a/tests/test_peft_inference.py b/tests/test_peft_inference.py
index 205f9d4fd9..05e0058515 100644
--- a/tests/test_peft_inference.py
+++ b/tests/test_peft_inference.py
@@ -13,14 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from unittest import TestCase
-
+import pytest
+import torch
from peft import (
+ AdaptionPromptConfig,
PrefixTuningConfig,
PromptEncoderConfig,
PromptTuningConfig,
TaskType,
get_peft_model,
+ tuners,
)
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
@@ -28,11 +30,15 @@
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
-class GaudiPeftTester(TestCase):
- def __init__(self, *args, **kwargs):
- adapt_transformers_to_gaudi()
- super().__init__(*args, **kwargs)
+TEST_CASES = [
+ ("huggyllama/llama-7b", "prompt-tuning"),
+ ("huggyllama/llama-7b", "prefix-tuning"),
+ ("huggyllama/llama-7b", "p-tuning"),
+ ("huggyllama/llama-7b", "llama-adapter"),
+]
+
+class TestGaudiPeftTextGeneration:
def _text_generation(self, model, tokenizer, extra_kwargs=None):
generate_kwargs = {
"lazy_mode": True,
@@ -40,6 +46,8 @@ def _text_generation(self, model, tokenizer, extra_kwargs=None):
"max_new_tokens": 128,
"ignore_eos": True,
}
+ if extra_kwargs:
+ generate_kwargs.update(extra_kwargs)
generator = pipeline(
"text-generation",
model=model,
@@ -50,7 +58,8 @@ def _text_generation(self, model, tokenizer, extra_kwargs=None):
return output[0]["generated_text"]
def _test_text_generation(self, model_name_or_path, peft_method):
- model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
+ adapt_transformers_to_gaudi()
+ model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
if peft_method == "prompt-tuning":
config = PromptTuningConfig(
@@ -67,6 +76,19 @@ def _test_text_generation(self, model_name_or_path, peft_method):
task_type=TaskType.CAUSAL_LM,
num_virtual_tokens=8,
)
+ elif peft_method == "llama-adapter":
+ from optimum.habana.peft.layer import (
+ GaudiAdaptedAttention_getattr,
+ GaudiAdaptedAttentionPreAttnForward,
+ )
+
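+            # monkey-patch PEFT's AdaptedAttention with the Gaudi-specific pre-attention forward and attribute lookup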
+ tuners.adaption_prompt.layer.AdaptedAttention.pre_attn_forward = GaudiAdaptedAttentionPreAttnForward
+ tuners.adaption_prompt.layer.AdaptedAttention.__getattr__ = GaudiAdaptedAttention_getattr
+ config = AdaptionPromptConfig(
+ adapter_layers=2,
+ adapter_len=4,
+ task_type=TaskType.CAUSAL_LM,
+ )
result = self._text_generation(model, tokenizer)
model = get_peft_model(model, config)
@@ -74,15 +96,15 @@ def _test_text_generation(self, model_name_or_path, peft_method):
model.__class__.prepare_inputs_for_generation = gaudi_prepare_inputs_for_generation
result1 = self._text_generation(model, tokenizer)
- self.assertNotEqual(result, result1)
+ if peft_method != "llama-adapter":
+ assert result != result1
result2 = self._text_generation(model, tokenizer, extra_kwargs={"reuse_cache": True})
- self.assertEqual(result1, result2)
+ assert result1 == result2
result3 = self._text_generation(model, tokenizer, extra_kwargs={"bucket_size": 10})
- self.assertEqual(result1, result3)
+ assert result1 == result3
- def test_text_generation_llama(self):
- self._test_text_generation("huggyllama/llama-7b", "prompt-tuning")
- self._test_text_generation("huggyllama/llama-7b", "p-tuning")
- self._test_text_generation("huggyllama/llama-7b", "prefix-tuning")
+ @pytest.mark.parametrize("model, method", TEST_CASES)
+ def test_text_generation_llama(self, model, method):
+ self._test_text_generation(model, method)
diff --git a/tests/test_text_generation_example.py b/tests/test_text_generation_example.py
index f5acbb7c9c..022b6b5e12 100644
--- a/tests/test_text_generation_example.py
+++ b/tests/test_text_generation_example.py
@@ -38,8 +38,9 @@
("codellama/CodeLlama-34b-hf", 1, True, 32.644),
("bigcode/starcoder2-3b", 1, False, 234.2649120507936),
("adept/persimmon-8b-base", 4, False, 366.73968820698406),
- ("Qwen/Qwen1.5-7B", 4, False, 488.82855464593257),
+ ("Qwen/Qwen1.5-7B", 4, False, 518.894516133132),
("google/gemma-7b", 1, False, 109.70751574382221),
+ ("state-spaces/mamba-130m-hf", 1536, False, 8600),
],
"fp8": [
("tiiuae/falcon-180B", 4, 950, True, 128, 128, 2506.68),
@@ -92,6 +93,7 @@
("Qwen/Qwen1.5-7B", 1, False, 39.29068423087616),
("adept/persimmon-8b-base", 1, False, 34.53559807384106),
("bigcode/starcoder2-3b", 1, False, 82.09655684566117),
+ ("state-spaces/mamba-130m-hf", 224, False, 794.542),
],
"fp8": [],
"deepspeed": [
diff --git a/tests/test_video_mae.py b/tests/test_video_mae.py
new file mode 100644
index 0000000000..00dc9c2d26
--- /dev/null
+++ b/tests/test_video_mae.py
@@ -0,0 +1,135 @@
+# coding=utf-8
+# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+import time
+from unittest import TestCase
+
+import habana_frameworks.torch as ht
+import numpy as np
+import pytest
+import torch
+from transformers import VideoMAEForVideoClassification, VideoMAEImageProcessor
+
+
+if os.environ.get("GAUDI2_CI", "0") == "1":
+ # Gaudi2 CI baselines
+ LATENCY_VIDEOMAE_BF16_GRAPH_BASELINE = 17.544198036193848
+else:
+ # Gaudi1 CI baselines
+ LATENCY_VIDEOMAE_BF16_GRAPH_BASELINE = 61.953186988830566
+MODEL_NAME = "MCG-NJU/videomae-base-finetuned-kinetics"
+
+
+@pytest.fixture(scope="module")
+def frame_buf():
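+    # one synthetic clip: 16 random frames of shape (3, 224, 224) with values in [0, 1)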
+ return list(np.random.default_rng(123).random((16, 3, 224, 224)))
+
+
+@pytest.fixture(scope="module")
+def processor():
+ return VideoMAEImageProcessor.from_pretrained(MODEL_NAME)
+
+
+@pytest.fixture(autouse=True, scope="class")
+def inputs(request, frame_buf, processor):
+ request.cls.inputs = processor(frame_buf, return_tensors="pt")
+ request.cls.inputs_hpu = request.cls.inputs.copy().to("hpu")
+
+
+@pytest.fixture(autouse=True, scope="class")
+def outputs_cpu(request):
+ model = VideoMAEForVideoClassification.from_pretrained(MODEL_NAME)
+ model.eval()
+
+ with torch.no_grad():
+ output = model(**request.cls.inputs)
+ request.cls.outputs_cpu = output
+
+
+@pytest.fixture(autouse=True, scope="class")
+def model_hpu(request):
+ request.cls.model_hpu = VideoMAEForVideoClassification.from_pretrained(MODEL_NAME).to("hpu")
+ request.cls.model_hpu_graph = ht.hpu.wrap_in_hpu_graph(request.cls.model_hpu)
+
+
+@pytest.fixture(autouse=True, scope="class")
+def outputs_hpu_default(request):
+ with torch.no_grad():
+ output = request.cls.model_hpu(**request.cls.inputs_hpu)
+ request.cls.outputs_hpu_default = output
+
+
+class GaudiVideoMAETester(TestCase):
+ """
+ Tests for VideoMAE on Gaudi
+ """
+
+ def test_inference_default(self):
+ """
+        Tests for equivalent CPU and HPU runs
+ """
+ self.assertTrue(
+ torch.equal(
+ self.outputs_cpu.logits.topk(10).indices,
+ self.outputs_hpu_default.logits.cpu().topk(10).indices,
+ )
+ )
+ self.assertTrue(torch.allclose(self.outputs_cpu.logits, self.outputs_hpu_default.logits, atol=5e-3))
+
+ def test_inference_bf16(self):
+ """
+        Tests that bf16 inference closely matches regular inference
+ """
+ with torch.no_grad(), torch.autocast(device_type="hpu", dtype=torch.bfloat16):
+ outputs = self.model_hpu(**self.inputs_hpu)
+ self.assertTrue(
+ torch.equal(
+ self.outputs_hpu_default.logits.topk(5).indices,
+ outputs.logits.topk(5).indices,
+ )
+ )
+
+ def test_inference_graph_bf16(self):
+ """
+        Tests that bf16 inference in HPU graph mode closely matches regular inference
+ """
+ with torch.no_grad(), torch.autocast(device_type="hpu", dtype=torch.bfloat16):
+ outputs = self.model_hpu_graph(**self.inputs_hpu)
+ self.assertTrue(
+ torch.equal(
+ self.outputs_hpu_default.logits.topk(5).indices,
+ outputs.logits.topk(5).indices,
+ )
+ )
+
+ def test_latency_graph_bf16(self):
+ """
+        Tests that latency does not regress by more than 5% over the baseline
+ """
+ warm_up_iters = 5
+ test_iters = 10
+ with torch.no_grad(), torch.autocast(device_type="hpu", dtype=torch.bfloat16):
+ for _ in range(warm_up_iters):
+ self.model_hpu_graph(**self.inputs_hpu)
+ torch.hpu.synchronize()
+ start_time = time.time()
+ with torch.no_grad(), torch.autocast(device_type="hpu", dtype=torch.bfloat16):
+ for _ in range(test_iters):
+ self.model_hpu_graph(**self.inputs_hpu)
+ torch.hpu.synchronize()
+ time_per_iter = (time.time() - start_time) * 1000 / test_iters # Time in ms
+ self.assertLess(time_per_iter, 1.05 * LATENCY_VIDEOMAE_BF16_GRAPH_BASELINE)
diff --git a/tests/utils.py b/tests/utils.py
index cce23476e5..3e9114d14a 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -60,6 +60,7 @@
"llama_guard": [("meta-llama/LlamaGuard-7b", "Habana/llama")],
"code_llama": [("codellama/CodeLlama-13b-Instruct-hf", "Habana/llama")],
"protst": [("mila-intel/protst-esm1b-for-sequential-classification", "Habana/gpt2")],
+ "qwen2": [("Qwen/Qwen2-7B", "Habana/qwen")],
}
MODELS_TO_TEST_FOR_QUESTION_ANSWERING = [