Porting Stable Video Diffusion ControlNet to HPU #1037

Merged · 3 commits · Oct 3, 2024
Changes from all commits
37 changes: 36 additions & 1 deletion examples/stable-diffusion/README.md
@@ -663,4 +663,39 @@ python image_to_video_generation.py \
```

> For improved performance of the image-to-video pipeline on Gaudi, it is recommended to configure the environment
> by setting PT_HPU_MAX_COMPOUND_OP_SIZE to 1.
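> For example, assuming a bash-compatible shell: `PT_HPU_MAX_COMPOUND_OP_SIZE=1 python image_to_video_generation.py ...`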

### Image-to-video ControlNet

Here is how to generate video conditioned by depth:

```bash
python image_to_video_generation.py \
--model_name_or_path "stabilityai/stable-video-diffusion-img2vid" \
--controlnet_model_name_or_path "CiaraRowles/temporal-controlnet-depth-svd-v1" \
--control_image_path "https://github.com/CiaraStrawberry/svd-temporal-controlnet/blob/main/validation_demo/depth/frame_0.png?raw=true" \
"https://github.com/CiaraStrawberry/svd-temporal-controlnet/blob/main/validation_demo/depth/frame_1.png?raw=true" \
"https://github.com/CiaraStrawberry/svd-temporal-controlnet/blob/main/validation_demo/depth/frame_2.png?raw=true" \
"https://github.com/CiaraStrawberry/svd-temporal-controlnet/blob/main/validation_demo/depth/frame_3.png?raw=true" \
"https://github.com/CiaraStrawberry/svd-temporal-controlnet/blob/main/validation_demo/depth/frame_4.png?raw=true" \
"https://github.com/CiaraStrawberry/svd-temporal-controlnet/blob/main/validation_demo/depth/frame_5.png?raw=true" \
"https://github.com/CiaraStrawberry/svd-temporal-controlnet/blob/main/validation_demo/depth/frame_6.png?raw=true" \
"https://github.com/CiaraStrawberry/svd-temporal-controlnet/blob/main/validation_demo/depth/frame_7.png?raw=true" \
"https://github.com/CiaraStrawberry/svd-temporal-controlnet/blob/main/validation_demo/depth/frame_8.png?raw=true" \
"https://github.com/CiaraStrawberry/svd-temporal-controlnet/blob/main/validation_demo/depth/frame_9.png?raw=true" \
"https://github.com/CiaraStrawberry/svd-temporal-controlnet/blob/main/validation_demo/depth/frame_10.png?raw=true" \
"https://github.com/CiaraStrawberry/svd-temporal-controlnet/blob/main/validation_demo/depth/frame_11.png?raw=true" \
"https://github.com/CiaraStrawberry/svd-temporal-controlnet/blob/main/validation_demo/depth/frame_12.png?raw=true" \
"https://github.com/CiaraStrawberry/svd-temporal-controlnet/blob/main/validation_demo/depth/frame_13.png?raw=true" \
--image_path "https://github.com/CiaraStrawberry/svd-temporal-controlnet/blob/main/validation_demo/chair.png?raw=true" \
--video_save_dir SVD_controlnet \
--save_frames_as_images \
--use_habana \
--use_hpu_graphs \
--gaudi_config Habana/stable-diffusion \
--bf16 \
--num_frames 14 \
--motion_bucket_id=14 \
--width=512 \
--height=512
```
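
Note that the command passes 14 depth control frames (`frame_0.png` through `frame_13.png`) to match `--num_frames 14`; the temporal ControlNet is expected to receive one control image per generated frame.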
127 changes: 91 additions & 36 deletions examples/stable-diffusion/image_to_video_generation.py
@@ -21,7 +21,7 @@
import torch
from diffusers.utils import export_to_video, load_image

from optimum.habana.diffusers import GaudiEulerDiscreteScheduler, GaudiStableVideoDiffusionPipeline
from optimum.habana.utils import set_seed


@@ -49,6 +49,12 @@ def main():
type=str,
help="Path to pre-trained model",
)
parser.add_argument(
"--controlnet_model_name_or_path",
default="CiaraRowles/temporal-controlnet-depth-svd-v1",
type=str,
help="Path to pre-trained controlnet model.",
)

# Pipeline arguments
parser.add_argument(
@@ -58,6 +64,13 @@
nargs="*",
help="Path to input image(s) to guide video generation",
)
parser.add_argument(
"--control_image_path",
type=str,
default=None,
nargs="*",
help="Path to controlnet input image(s) to guide video generation.",
)
parser.add_argument(
"--num_videos_per_prompt", type=int, default=1, help="The number of videos to generate per prompt image."
)
@@ -164,11 +177,9 @@ def main():
),
)
parser.add_argument("--bf16", action="store_true", help="Whether to perform generation in bf16 precision.")

parser.add_argument("--num_frames", type=int, default=25, help="The number of video frames to generate.")
args = parser.parse_args()

# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -177,6 +188,29 @@
)
logger.setLevel(logging.INFO)

# Load input image(s)
input = []
logger.info("Input image(s):")
if isinstance(args.image_path, str):
args.image_path = [args.image_path]
for image_path in args.image_path:
image = load_image(image_path)
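# NOTE: PIL's Image.resize expects (width, height); the (height, width) order
# below is harmless for the square 512x512 sizes used in the README example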
image = image.resize((args.height, args.width))
input.append(image)
logger.info(image_path)

# Load control input image
control_input = []
if args.control_image_path is not None:
logger.info("Input control image(s):")
if isinstance(args.control_image_path, str):
args.control_image_path = [args.control_image_path]
for control_image in args.control_image_path:
image = load_image(control_image)
image = image.resize((args.height, args.width))
control_input.append(image)
logger.info(control_image)

# Initialize the scheduler and the generation pipeline
scheduler = GaudiEulerDiscreteScheduler.from_pretrained(args.model_name_or_path, subfolder="scheduler")
kwargs = {
@@ -185,44 +219,65 @@
"use_hpu_graphs": args.use_hpu_graphs,
"gaudi_config": args.gaudi_config_name,
}

set_seed(args.seed)
if args.bf16:
kwargs["torch_dtype"] = torch.bfloat16

if args.control_image_path is not None:
from optimum.habana.diffusers import GaudiStableVideoDiffusionControlNetPipeline
from optimum.habana.diffusers.models import ControlNetSDVModel, UNetSpatioTemporalConditionControlNetModel

controlnet = ControlNetSDVModel.from_pretrained(
args.controlnet_model_name_or_path, subfolder="controlnet", **kwargs
)
unet = UNetSpatioTemporalConditionControlNetModel.from_pretrained(
args.model_name_or_path, subfolder="unet", **kwargs
)
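# Assemble the pipeline with the ControlNet-aware UNet substituted for the default SVD UNet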
pipeline = GaudiStableVideoDiffusionControlNetPipeline.from_pretrained(
args.model_name_or_path, controlnet=controlnet, unet=unet, **kwargs
)

# Generate images
outputs = pipeline(
image=input,
controlnet_condition=control_input,
num_videos_per_prompt=args.num_videos_per_prompt,
batch_size=args.batch_size,
height=args.height,
width=args.width,
num_inference_steps=args.num_inference_steps,
min_guidance_scale=args.min_guidance_scale,
max_guidance_scale=args.max_guidance_scale,
fps=args.fps,
motion_bucket_id=args.motion_bucket_id,
noise_aug_strength=args.noise_aug_strength,
decode_chunk_size=args.decode_chunk_size,
output_type=args.output_type,
num_frames=args.num_frames,
)
else:
pipeline = GaudiStableVideoDiffusionPipeline.from_pretrained(
args.model_name_or_path,
**kwargs,
)

# Generate images
outputs = pipeline(
image=input,
num_videos_per_prompt=args.num_videos_per_prompt,
batch_size=args.batch_size,
height=args.height,
width=args.width,
num_inference_steps=args.num_inference_steps,
min_guidance_scale=args.min_guidance_scale,
max_guidance_scale=args.max_guidance_scale,
fps=args.fps,
motion_bucket_id=args.motion_bucket_id,
noise_aug_strength=args.noise_aug_strength,
decode_chunk_size=args.decode_chunk_size,
output_type=args.output_type,
)

# Save the pipeline in the specified directory if not None
if args.pipeline_save_dir is not None:
3 changes: 3 additions & 0 deletions optimum/habana/diffusers/__init__.py
@@ -1,5 +1,8 @@
from .pipelines.auto_pipeline import AutoPipelineForInpainting, AutoPipelineForText2Image
from .pipelines.controlnet.pipeline_controlnet import GaudiStableDiffusionControlNetPipeline
from .pipelines.controlnet.pipeline_stable_video_diffusion_controlnet import (
GaudiStableVideoDiffusionControlNetPipeline,
)
from .pipelines.ddpm.pipeline_ddpm import GaudiDDPMPipeline
from .pipelines.pipeline_utils import GaudiDiffusionPipeline
from .pipelines.stable_diffusion.pipeline_stable_diffusion import GaudiStableDiffusionPipeline
2 changes: 2 additions & 0 deletions optimum/habana/diffusers/models/__init__.py
@@ -1,2 +1,4 @@
from .controlnet_sdv import ControlNetSDVModel
from .unet_2d import gaudi_unet_2d_model_forward
from .unet_2d_condition import gaudi_unet_2d_condition_model_forward
from .unet_spatio_temporal_condition_controlnet import UNetSpatioTemporalConditionControlNetModel
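
For reference, here is a minimal end-to-end sketch of how the newly exported classes fit together. It mirrors the ControlNet branch of `image_to_video_generation.py` above; treat the exact Gaudi kwargs, the `.frames` output attribute, and the `fps` choice as assumptions drawn from the stock Stable Video Diffusion pipeline rather than a verbatim API reference:

```python
import torch
from diffusers.utils import export_to_video, load_image

from optimum.habana.diffusers import (
    GaudiEulerDiscreteScheduler,
    GaudiStableVideoDiffusionControlNetPipeline,
)
from optimum.habana.diffusers.models import (
    ControlNetSDVModel,
    UNetSpatioTemporalConditionControlNetModel,
)

model_id = "stabilityai/stable-video-diffusion-img2vid"
controlnet_id = "CiaraRowles/temporal-controlnet-depth-svd-v1"

# Scheduler, temporal ControlNet, and ControlNet-aware spatio-temporal UNet,
# loaded the same way the example script loads them
scheduler = GaudiEulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
controlnet = ControlNetSDVModel.from_pretrained(
    controlnet_id, subfolder="controlnet", torch_dtype=torch.bfloat16
)
unet = UNetSpatioTemporalConditionControlNetModel.from_pretrained(
    model_id, subfolder="unet", torch_dtype=torch.bfloat16
)

pipeline = GaudiStableVideoDiffusionControlNetPipeline.from_pretrained(
    model_id,
    controlnet=controlnet,
    unet=unet,
    scheduler=scheduler,
    use_habana=True,
    use_hpu_graphs=True,
    gaudi_config="Habana/stable-diffusion",
    torch_dtype=torch.bfloat16,
)

# One conditioning image plus 14 depth control frames, all resized to 512x512
base = "https://github.com/CiaraStrawberry/svd-temporal-controlnet/blob/main/validation_demo"
image = load_image(f"{base}/chair.png?raw=true").resize((512, 512))
control_frames = [
    load_image(f"{base}/depth/frame_{i}.png?raw=true").resize((512, 512)) for i in range(14)
]

outputs = pipeline(
    image=[image],
    controlnet_condition=control_frames,
    num_frames=14,
    motion_bucket_id=14,
    height=512,
    width=512,
)
# Assumption: like the stock SVD pipeline, the output exposes the generated
# frames via `.frames`; the fps value here is illustrative
export_to_video(outputs.frames[0], "svd_controlnet.mp4", fps=7)
```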