diff --git a/Makefile b/Makefile index ac8ff3f108..e88d6b3562 100644 --- a/Makefile +++ b/Makefile @@ -80,7 +80,7 @@ slow_tests_custom_file_input: test_installs # Run single-card non-regression tests slow_tests_1x: test_installs python -m pytest tests/test_examples.py -v -s -k "single_card" - python -m pip install peft==0.10.0 + python -m pip install peft==0.12.0 python -m pytest tests/test_peft_inference.py python -m pytest tests/test_pipeline.py @@ -96,7 +96,7 @@ slow_tests_deepspeed: test_installs slow_tests_diffusers: test_installs python -m pytest tests/test_diffusers.py -v -s -k "test_no_" python -m pytest tests/test_diffusers.py -v -s -k "test_textual_inversion" - python -m pip install peft==0.7.0 + python -m pip install peft==0.12.0 python -m pytest tests/test_diffusers.py -v -s -k "test_train_text_to_image_" python -m pytest tests/test_diffusers.py -v -s -k "test_train_controlnet" python -m pytest tests/test_diffusers.py -v -s -k "test_deterministic_image_generation" diff --git a/examples/stable-diffusion/text_to_image_generation.py b/examples/stable-diffusion/text_to_image_generation.py index 8caa659ca6..0bcec10e7d 100755 --- a/examples/stable-diffusion/text_to_image_generation.py +++ b/examples/stable-diffusion/text_to_image_generation.py @@ -464,19 +464,50 @@ def main(): ) if args.unet_adapter_name_or_path is not None: - from peft import PeftModel + from peft import PeftModel, tuners + from peft.utils import PeftType + + from optimum.habana.peft.layer import GaudiBoftGetDeltaWeight + + tuners.boft.layer.Linear.get_delta_weight = GaudiBoftGetDeltaWeight + tuners.boft.layer.Conv2d.get_delta_weight = GaudiBoftGetDeltaWeight + tuners.boft.layer._FBD_CUDA = False pipeline.unet = PeftModel.from_pretrained(pipeline.unet, args.unet_adapter_name_or_path) - pipeline.unet = pipeline.unet.merge_and_unload() + if pipeline.unet.peft_type in [PeftType.OFT, PeftType.BOFT]: + # WA torch.inverse issue in Synapse AI 1.17 for oft and boft + if args.bf16: + pipeline.unet = pipeline.unet.to(torch.float32) + pipeline.unet = pipeline.unet.merge_and_unload() + if args.bf16: + pipeline.unet = pipeline.unet.to(torch.bfloat16) + else: + with torch.autocast(device_type="hpu", dtype=torch.bfloat16, enabled=args.bf16): + pipeline.unet = pipeline.unet.merge_and_unload() if args.text_encoder_adapter_name_or_path is not None: - from peft import PeftModel + from peft import PeftModel, tuners + from peft.utils import PeftType + + from optimum.habana.peft.layer import GaudiBoftGetDeltaWeight + + tuners.boft.layer.Linear.get_delta_weight = GaudiBoftGetDeltaWeight + tuners.boft.layer.Conv2d.get_delta_weight = GaudiBoftGetDeltaWeight + tuners.boft.layer._FBD_CUDA = False pipeline.text_encoder = PeftModel.from_pretrained( pipeline.text_encoder, args.text_encoder_adapter_name_or_path ) - pipeline.text_encoder = pipeline.text_encoder.merge_and_unload() - + if pipeline.text_encoder.peft_type in [PeftType.OFT, PeftType.BOFT]: + # WA torch.inverse issue in Synapse AI 1.17 for oft and boft + if args.bf16: + pipeline.text_encoder = pipeline.text_encoder.to(torch.float32) + pipeline.text_encoder = pipeline.text_encoder.merge_and_unload() + if args.bf16: + pipeline.text_encoder = pipeline.text_encoder.to(torch.bfloat16) + else: + with torch.autocast(device_type="hpu", dtype=torch.bfloat16, enabled=args.bf16): + pipeline.text_encoder = pipeline.text_encoder.merge_and_unload() else: # SD LDM3D use-case from optimum.habana.diffusers import GaudiStableDiffusionLDM3DPipeline as GaudiStableDiffusionPipeline diff --git a/examples/stable-diffusion/training/README.md b/examples/stable-diffusion/training/README.md index 28e2d4e8c0..d9c7cf36e7 100644 --- a/examples/stable-diffusion/training/README.md +++ b/examples/stable-diffusion/training/README.md @@ -319,10 +319,10 @@ Prior-preservation is used to avoid overfitting and language-drift. Refer to the According to the paper, it's recommended to generate `num_epochs * num_samples` images for prior-preservation. 200-300 works well for most cases. The `num_class_images` flag sets the number of images to generate with the class prompt. You can place existing images in `class_data_dir`, and the training script will generate any additional images so that `num_class_images` are present in `class_data_dir` during training time. ### PEFT model finetune -We provide example for dreambooth to use lora/lokr/loha/oft to finetune unet or text encoder. +We provide example for dreambooth to use `lora`, `lokr`, `loha`, `oft` and `boft` to finetune unet or text encoder. -**___Note: When using peft method we can use a much higher learning rate compared to vanilla dreambooth. Here we -use *1e-4* instead of the usual *5e-6*.___** +> [!NOTE] +> When using peft method we can use a much higher learning rate compared to vanilla dreambooth. Here we use *1e-4* instead of the usual *5e-6*. Launch the multi-card training using: ```bash @@ -355,19 +355,18 @@ python ../../gaudi_spawn.py --world_size 8 --use_mpi train_dreambooth.py \ lora --unet_r 8 --unet_alpha 8 ``` -Similar command could be applied to loha, lokr, oft. +Similar command could be applied to `loha`, `lokr`, `oft` or `boft`. You could check each adapter specific args by "--help", like you could use following command to check oft specific args. ```bash -python3 train_dreambooth.py oft --help +python train_dreambooth.py oft --help ``` -**___Note: oft could not work with hpu graphs mode. since "torch.inverse" need to fallback to cpu. -there's error like "cpu fallback is not supported during hpu graph capturing"___** - +> [!NOTE] +> `boft` and `oft` do not work with hpu graphs mode since `torch.inverse` `torch.linalg.solve` need to fallback to cpu. Pls. remove `--use_hpu_graphs_for_training` and `--use_hpu_graphs_for_inference` to avoid `cpu fallback is not supported during hpu graph capturing` error. -You could use text_to_image_generation.py to generate picture using the peft adapter like +You could use `text_to_image_generation.py` script to generate picture using the peft adapter: ```bash python ../text_to_image_generation.py \ @@ -384,8 +383,8 @@ python ../text_to_image_generation.py \ ``` ### DreamBooth training example for Stable Diffusion XL -You could use the dog images as example as well. -You can launch training using: +You could use the dog images as example for training: + ```bash export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0" export INSTANCE_DIR="dog" @@ -415,7 +414,7 @@ python ../../gaudi_spawn.py --world_size 8 --use_mpi train_dreambooth_lora_sdxl. ``` -You could use text_to_image_generation.py to generate picture using the peft adapter like +Then, you could use `text_to_image_generation.py` to generate picture using the peft adapter: ```bash python ../text_to_image_generation.py \ diff --git a/examples/stable-diffusion/training/requirements.txt b/examples/stable-diffusion/training/requirements.txt index 7fb1748675..cb491f0789 100644 --- a/examples/stable-diffusion/training/requirements.txt +++ b/examples/stable-diffusion/training/requirements.txt @@ -1,2 +1,2 @@ imagesize -peft == 0.10.0 +peft == 0.12.0 diff --git a/examples/stable-diffusion/training/train_dreambooth.py b/examples/stable-diffusion/training/train_dreambooth.py index b34f3c12c5..a589efea1f 100644 --- a/examples/stable-diffusion/training/train_dreambooth.py +++ b/examples/stable-diffusion/training/train_dreambooth.py @@ -52,7 +52,7 @@ from diffusers.utils.torch_utils import is_compiled_module from habana_frameworks.torch.hpu import memory_stats from huggingface_hub import HfApi -from peft import LoHaConfig, LoKrConfig, LoraConfig, OFTConfig, get_peft_model +from peft import BOFTConfig, LoHaConfig, LoKrConfig, LoraConfig, OFTConfig, get_peft_model, tuners from PIL import Image from torch.utils.data import Dataset from torchvision import transforms @@ -108,7 +108,9 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st raise ValueError(f"{model_class} is not supported.") -def create_unet_adapter_config(args: argparse.Namespace) -> Union[LoraConfig, LoHaConfig, LoKrConfig, OFTConfig]: +def create_unet_adapter_config( + args: argparse.Namespace, +) -> Union[LoraConfig, LoHaConfig, LoKrConfig, OFTConfig, BOFTConfig]: if args.adapter == "full": raise ValueError("Cannot create unet adapter config for full parameter") @@ -152,6 +154,21 @@ def create_unet_adapter_config(args: argparse.Namespace) -> Union[LoraConfig, Lo coft=args.unet_use_coft, eps=args.unet_eps, ) + elif args.adapter == "boft": + config = BOFTConfig( + boft_block_size=args.unet_block_size, + boft_block_num=args.unet_block_num, + boft_n_butterfly_factor=args.unet_n_butterfly_factor, + target_modules=UNET_TARGET_MODULES, + boft_dropout=args.unet_dropout, + bias=args.unet_bias, + ) + from optimum.habana.peft.layer import GaudiBoftConv2dForward, GaudiBoftLinearForward + + tuners.boft.layer.Linear.forward = GaudiBoftLinearForward + tuners.boft.layer.Conv2d.forward = GaudiBoftConv2dForward + tuners.boft.layer._FBD_CUDA = False + else: raise ValueError(f"Unknown adapter type {args.adapter}") @@ -160,7 +177,7 @@ def create_unet_adapter_config(args: argparse.Namespace) -> Union[LoraConfig, Lo def create_text_encoder_adapter_config( args: argparse.Namespace, -) -> Union[LoraConfig, LoHaConfig, LoKrConfig, OFTConfig]: +) -> Union[LoraConfig, LoHaConfig, LoKrConfig, OFTConfig, BOFTConfig]: if args.adapter == "full": raise ValueError("Cannot create text_encoder adapter config for full parameter") @@ -202,6 +219,20 @@ def create_text_encoder_adapter_config( coft=args.te_use_coft, eps=args.te_eps, ) + elif args.adapter == "boft": + config = BOFTConfig( + boft_block_size=args.te_block_size, + boft_block_num=args.te_block_num, + boft_n_butterfly_factor=args.te_n_butterfly_factor, + target_modules=TEXT_ENCODER_TARGET_MODULES, + boft_dropout=args.te_dropout, + bias=args.te_bias, + ) + from optimum.habana.peft.layer import GaudiBoftConv2dForward, GaudiBoftLinearForward + + tuners.boft.layer.Linear.forward = GaudiBoftLinearForward + tuners.boft.layer.Conv2d.forward = GaudiBoftConv2dForward + tuners.boft.layer._FBD_CUDA = False else: raise ValueError(f"Unknown adapter type {args.adapter}") @@ -632,6 +663,44 @@ def parse_args(input_args=None): help="The control strength of COFT for text_encoder, only used if `train_text_encoder` is True", ) + # boft adapter + boft = subparsers.add_parser("boft", help="Use Boft adapter") + boft.add_argument("--unet_block_size", type=int, default=8, help="Boft block_size for unet") + boft.add_argument("--unet_block_num", type=int, default=0, help="Boft block_num for unet") + boft.add_argument("--unet_n_butterfly_factor", type=int, default=1, help="Boft n_butterfly_factor for unet") + boft.add_argument("--unet_dropout", type=float, default=0.1, help="Boft dropout for unet") + boft.add_argument("--unet_bias", type=str, default="boft_only", help="Boft bias for unet") + boft.add_argument( + "--te_block_size", + type=int, + default=8, + help="Boft block_size for text_encoder,only used if `train_text_encoder` is True", + ) + boft.add_argument( + "--te_block_num", + type=int, + default=0, + help="Boft block_num for text_encoder,only used if `train_text_encoder` is True", + ) + boft.add_argument( + "--te_n_butterfly_factor", + type=int, + default=1, + help="Boft n_butterfly_factor for text_encoder,only used if `train_text_encoder` is True", + ) + boft.add_argument( + "--te_dropout", + type=float, + default=0.1, + help="Boft dropout for text_encoder,only used if `train_text_encoder` is True", + ) + boft.add_argument( + "--te_bias", + type=str, + default="boft_only", + help="Boft bias for text_encoder, only used if `train_text_encoder` is True", + ) + if input_args is not None: args = parser.parse_args(input_args) else: diff --git a/optimum/habana/peft/__init__.py b/optimum/habana/peft/__init__.py index ed33e84393..229721fc5a 100644 --- a/optimum/habana/peft/__init__.py +++ b/optimum/habana/peft/__init__.py @@ -2,6 +2,9 @@ GaudiAdaloraLayerSVDLinearForward, GaudiAdaptedAttention_getattr, GaudiAdaptedAttentionPreAttnForward, + GaudiBoftConv2dForward, + GaudiBoftGetDeltaWeight, + GaudiBoftLinearForward, GaudiPolyLayerLinearForward, ) from .peft_model import gaudi_generate, gaudi_prepare_inputs_for_generation diff --git a/optimum/habana/peft/layer.py b/optimum/habana/peft/layer.py index fb6074cdbc..87adc439c4 100755 --- a/optimum/habana/peft/layer.py +++ b/optimum/habana/peft/layer.py @@ -217,3 +217,174 @@ def GaudiAdaptedAttention_getattr(self, name: str): # This is necessary as e.g. causal models have various methods that we # don't want to re-implement here. return getattr(self.model, name) + + +def GaudiBoftConv2dForward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: + """ + Copied from Conv2d.forward: https://github.com/huggingface/peft/blob/v0.12.0/src/peft/tuners/boft/layer.py#L899 + The only differences are: + - torch.block_diag operate in cpu, or else lazy mode will hang + - delete fbd_cuda_available logic, + """ + previous_dtype = x.dtype + + if self.disable_adapters: + if self.merged: + self.unmerge() + result = self.base_layer(x, *args, **kwargs) + elif self.merged: + result = self.base_layer(x, *args, **kwargs) + else: + boft_rotation = torch.eye( + self.in_features * self.base_layer.kernel_size[0] * self.base_layer.kernel_size[0], + device=x.device, + dtype=x.dtype, + ) + boft_scale = torch.ones((1, int(self.out_features)), device=x.device, dtype=x.dtype) + + for active_adapter in self.active_adapters: + if active_adapter not in self.boft_R.keys(): + continue + boft_R = self.boft_R[active_adapter] + boft_s = self.boft_s[active_adapter] + dropout = self.boft_dropout[active_adapter] + + N, D, H, _ = boft_R.shape + boft_R = boft_R.view(N * D, H, H) + orth_rotate_butterfly = self.cayley_batch(boft_R) + orth_rotate_butterfly = orth_rotate_butterfly.view(N, D, H, H) + orth_rotate_butterfly = dropout(orth_rotate_butterfly) + orth_rotate_butterfly = orth_rotate_butterfly.squeeze(0).cpu() + block_diagonal_butterfly = torch.block_diag(*torch.unbind(orth_rotate_butterfly)) + block_diagonal_butterfly = block_diagonal_butterfly.unsqueeze(0) + + boft_P = self.boft_P.to(x) + block_diagonal_butterfly = block_diagonal_butterfly.to(x) + butterfly_oft_mat_batch = torch.bmm(block_diagonal_butterfly, boft_P.permute(0, 2, 1)) + butterfly_oft_mat_batch = torch.bmm(boft_P, butterfly_oft_mat_batch) + butterfly_oft_mat = butterfly_oft_mat_batch[0] + + for i in range(1, butterfly_oft_mat_batch.shape[0]): + butterfly_oft_mat = butterfly_oft_mat_batch[i] @ butterfly_oft_mat + + boft_rotation = butterfly_oft_mat @ boft_rotation + boft_scale = boft_s * boft_scale + + x = x.to(self.base_layer.weight.data.dtype) + + orig_weight = self.base_layer.weight.data + orig_weight = orig_weight.view( + self.in_features * self.base_layer.kernel_size[0] * self.base_layer.kernel_size[0], + self.out_features, + ) + rotated_weight = torch.mm(boft_rotation, orig_weight) + + scaled_rotated_weight = rotated_weight * boft_scale + + scaled_rotated_weight = scaled_rotated_weight.view( + self.out_features, self.in_features, self.base_layer.kernel_size[0], self.base_layer.kernel_size[0] + ) + result = F.conv2d( + input=x, + weight=scaled_rotated_weight, + bias=self.base_layer.bias, + padding=self.base_layer.padding[0], + stride=self.base_layer.stride[0], + ) + + result = result.to(previous_dtype) + return result + + +def GaudiBoftLinearForward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: + """ + Copied from Linear.forward: https://github.com/huggingface/peft/blob/v0.12.0/src/peft/tuners/boft/layer.py#L587 + The only differences are: + - torch.block_diag operate in cpu, or else lazy mode will hang + - delete fbd_cuda_available logic, + """ + previous_dtype = x.dtype + + if self.disable_adapters: + if self.merged: + self.unmerge() + result = self.base_layer(x, *args, **kwargs) + elif self.merged: + result = self.base_layer(x, *args, **kwargs) + else: + boft_rotation = torch.eye(self.in_features, device=x.device) + boft_scale = torch.ones((int(self.out_features), 1), device=x.device) + + for active_adapter in self.active_adapters: + if active_adapter not in self.boft_R.keys(): + continue + boft_R = self.boft_R[active_adapter] + boft_s = self.boft_s[active_adapter] + dropout = self.boft_dropout[active_adapter] + + N, D, H, _ = boft_R.shape + boft_R = boft_R.view(N * D, H, H) + orth_rotate_butterfly = self.cayley_batch(boft_R) + orth_rotate_butterfly = orth_rotate_butterfly.view(N, D, H, H) + orth_rotate_butterfly = dropout(orth_rotate_butterfly) + orth_rotate_butterfly = orth_rotate_butterfly.squeeze(0).cpu() + block_diagonal_butterfly = torch.block_diag(*torch.unbind(orth_rotate_butterfly)) + block_diagonal_butterfly = block_diagonal_butterfly.unsqueeze(0) + + # The BOFT author's cayley_batch, dropout and FastBlockDiag ONLY return fp32 outputs. + boft_P = self.boft_P.to(x) + block_diagonal_butterfly = block_diagonal_butterfly.to(x) + butterfly_oft_mat_batch = torch.bmm(block_diagonal_butterfly, boft_P.permute(0, 2, 1)) + butterfly_oft_mat_batch = torch.bmm(boft_P, butterfly_oft_mat_batch) + butterfly_oft_mat = butterfly_oft_mat_batch[0] + + for i in range(1, butterfly_oft_mat_batch.shape[0]): + butterfly_oft_mat = butterfly_oft_mat_batch[i] @ butterfly_oft_mat + + boft_rotation = butterfly_oft_mat @ boft_rotation + boft_scale = boft_s * boft_scale + + x = x.to(self.get_base_layer().weight.data.dtype) + + orig_weight = self.get_base_layer().weight.data + orig_weight = torch.transpose(orig_weight, 0, 1) + rotated_weight = torch.mm(boft_rotation, orig_weight) + rotated_weight = torch.transpose(rotated_weight, 0, 1) + + scaled_rotated_weight = rotated_weight * boft_scale + + result = F.linear(input=x, weight=scaled_rotated_weight, bias=self.base_layer.bias) + + result = result.to(previous_dtype) + return result + + +def GaudiBoftGetDeltaWeight(self, adapter) -> tuple[torch.Tensor, torch.Tensor]: + """ + Copied from Linear.get_delta_weight: https://github.com/huggingface/peft/blob/v0.12.0/src/peft/tuners/boft/layer.py#L555 + The only differences are: + - torch.block_diag operate in cpu, or else lazy mode will hang + - delete fbd_cuda_available logic, + """ + + boft_R = self.boft_R[adapter] + boft_s = self.boft_s[adapter] + + N, D, H, _ = boft_R.shape + boft_R = boft_R.view(N * D, H, H) + orth_rotate_butterfly = self.cayley_batch(boft_R) + orth_rotate_butterfly = orth_rotate_butterfly.view(N, D, H, H) + orth_rotate_butterfly = orth_rotate_butterfly.squeeze(0).cpu() + block_diagonal_butterfly = torch.block_diag(*torch.unbind(orth_rotate_butterfly)) + block_diagonal_butterfly = block_diagonal_butterfly.unsqueeze(0) + + boft_P = self.boft_P + block_diagonal_butterfly = block_diagonal_butterfly.to(boft_P) + butterfly_oft_mat_batch = torch.bmm(block_diagonal_butterfly, boft_P.permute(0, 2, 1)) + butterfly_oft_mat_batch = torch.bmm(boft_P, butterfly_oft_mat_batch) + butterfly_oft_mat = butterfly_oft_mat_batch[0] + + for i in range(1, butterfly_oft_mat_batch.shape[0]): + butterfly_oft_mat = butterfly_oft_mat_batch[i] @ butterfly_oft_mat + + return butterfly_oft_mat, boft_s diff --git a/tests/test_diffusers.py b/tests/test_diffusers.py index 3015dc21db..bacd685e76 100755 --- a/tests/test_diffusers.py +++ b/tests/test_diffusers.py @@ -246,6 +246,88 @@ class GaudiStableDiffusionPipelineTester(TestCase): Tests the StableDiffusionPipeline for Gaudi. """ + def merge_peft_adapter(self, model, adapter): + from peft import BOFTConfig, LoHaConfig, LoKrConfig, LoraConfig, OFTConfig, get_peft_model + + UNET_TARGET_MODULES = [ + "to_q", + "to_k", + "to_v", + "proj", + "proj_in", + "proj_out", + "conv", + "conv1", + "conv2", + "conv_shortcut", + "to_out.0", + "time_emb_proj", + "ff.net.2", + ] + TEXT_ENCODER_TARGET_MODULES = ["fc1", "fc2", "q_proj", "k_proj", "v_proj", "out_proj"] + target_modules = ( + UNET_TARGET_MODULES if isinstance(model, UNet2DConditionModel) else TEXT_ENCODER_TARGET_MODULES + ) + + if adapter == "lora": + config = LoraConfig( + r=2, + lora_alpha=2, + target_modules=target_modules, + lora_dropout=0.0, + bias="none", + init_lora_weights=True, + ) + elif adapter == "loha": + config = LoHaConfig( + r=2, + alpha=2, + target_modules=target_modules, + rank_dropout=0.0, + module_dropout=0.0, + use_effective_conv2d=False, + init_weights=True, + ) + elif adapter == "lokr": + config = LoKrConfig( + r=2, + alpha=2, + target_modules=target_modules, + rank_dropout=0.0, + module_dropout=0.0, + use_effective_conv2d=False, + decompose_both=False, + decompose_factor=-1, + init_weights=True, + ) + elif adapter == "oft": + config = OFTConfig( + r=2, + target_modules=target_modules, + module_dropout=0.0, + init_weights=True, + coft=False, + eps=0.0, + ) + elif adapter == "boft": + from peft import tuners + + from optimum.habana.peft.layer import GaudiBoftGetDeltaWeight + + tuners.boft.layer.Linear.get_delta_weight = GaudiBoftGetDeltaWeight + tuners.boft.layer.Conv2d.get_delta_weight = GaudiBoftGetDeltaWeight + tuners.boft.layer._FBD_CUDA = False + config = BOFTConfig( + boft_block_size=1, + boft_block_num=0, + boft_n_butterfly_factor=1, + target_modules=target_modules, + boft_dropout=0.1, + bias="boft_only", + ) + model = get_peft_model(model, config) + return model.merge_and_unload() + def get_dummy_components(self, time_cond_proj_dim=None): torch.manual_seed(0) unet = UNet2DConditionModel( @@ -626,6 +708,45 @@ def test_stable_diffusion_hpu_graphs(self): self.assertEqual(len(images), 10) self.assertEqual(images[-1].shape, (64, 64, 3)) + @parameterized.expand(["lora", "loha", "lokr", "oft", "boft"]) + @slow + def test_no_peft_regression_bf16(self, peft_adapter): + prompts = [ + "An image of a squirrel in Picasso style", + ] + num_images_per_prompt = 1 + batch_size = 1 + model_name = "runwayml/stable-diffusion-v1-5" + scheduler = GaudiDDIMScheduler.from_pretrained(model_name, subfolder="scheduler") + pipeline = GaudiStableDiffusionPipeline.from_pretrained( + model_name, + scheduler=scheduler, + use_habana=True, + use_hpu_graphs=True, + gaudi_config=GaudiConfig.from_pretrained("Habana/stable-diffusion"), + torch_dtype=torch.bfloat16, + ) + if peft_adapter not in ["boft", "oft"]: + with torch.autocast(device_type="hpu", dtype=torch.bfloat16, enabled=True): + pipeline.unet = self.merge_peft_adapter(pipeline.unet, peft_adapter) + pipeline.text_encoder = self.merge_peft_adapter(pipeline.text_encoder, peft_adapter) + else: + # WA torch.inverse issue in Synapse AI 1.17 for oft and boft + pipeline.unet = pipeline.unet.to(torch.float32) + pipeline.unet = self.merge_peft_adapter(pipeline.unet, peft_adapter) + pipeline.unet = pipeline.unet.to(torch.bfloat16) + pipeline.text_encoder = pipeline.text_encoder.to(torch.float32) + pipeline.text_encoder = self.merge_peft_adapter(pipeline.text_encoder, peft_adapter) + pipeline.text_encoder = pipeline.text_encoder.to(torch.bfloat16) + + set_seed(27) + outputs = pipeline( + prompt=prompts, + num_images_per_prompt=num_images_per_prompt, + batch_size=batch_size, + ) + self.assertEqual(len(outputs.images), num_images_per_prompt * len(prompts)) + @slow def test_no_throughput_regression_bf16(self): prompts = [ @@ -2480,6 +2601,9 @@ def _test_dreambooth(self, extra_config, train_text_encoder=False): if train_text_encoder: test_args.append("--train_text_encoder") test_args.append(extra_config) + if "boft" in extra_config: + extra_args = "--unet_block_size 1 --te_block_size 1" + test_args.extend(extra_args.split()) p = subprocess.Popen(test_args) return_code = p.wait() @@ -2536,6 +2660,14 @@ def test_dreambooth_oft(self): def test_dreambooth_oft_with_text_encoder(self): self._test_dreambooth("oft", train_text_encoder=True) + @slow + def test_dreambooth_boft(self): + self._test_dreambooth("boft") + + @slow + def test_dreambooth_boft_with_text_encoder(self): + self._test_dreambooth("boft", train_text_encoder=True) + class DreamBoothLoRASDXL(TestCase): def _test_dreambooth_lora_sdxl(self, train_text_encoder=False):