Porting Stable Video Diffusion ControlNet to HPU
Signed-off-by: Wenbin Chen <[email protected]>
wenbinc-Bin committed Jun 19, 2024
1 parent 9aa739b commit b6b55af
Showing 7 changed files with 1,858 additions and 33 deletions.
32 changes: 32 additions & 0 deletions examples/stable-diffusion/README.md
@@ -323,3 +323,35 @@ python image_to_video_generation.py \
--gaudi_config Habana/stable-diffusion \
--bf16
```

### Image-to-video ControlNet

Here is how to generate a video conditioned on depth frames:
```bash
python image_to_video_generation.py \
--model_name_or_path "stabilityai/stable-video-diffusion-img2vid" \
--controlnet_model_name_or_path "CiaraRowles/temporal-controlnet-depth-svd-v1" \
--control_image_path "https://github.com/CiaraStrawberry/svd-temporal-controlnet/blob/main/validation_demo/depth/frame_0.png?raw=true" \
"https://github.com/CiaraStrawberry/svd-temporal-controlnet/blob/main/validation_demo/depth/frame_1.png?raw=true" \
"https://github.com/CiaraStrawberry/svd-temporal-controlnet/blob/main/validation_demo/depth/frame_2.png?raw=true" \
"https://github.com/CiaraStrawberry/svd-temporal-controlnet/blob/main/validation_demo/depth/frame_3.png?raw=true" \
"https://github.com/CiaraStrawberry/svd-temporal-controlnet/blob/main/validation_demo/depth/frame_4.png?raw=true" \
"https://github.com/CiaraStrawberry/svd-temporal-controlnet/blob/main/validation_demo/depth/frame_5.png?raw=true" \
"https://github.com/CiaraStrawberry/svd-temporal-controlnet/blob/main/validation_demo/depth/frame_6.png?raw=true" \
"https://github.com/CiaraStrawberry/svd-temporal-controlnet/blob/main/validation_demo/depth/frame_7.png?raw=true" \
"https://github.com/CiaraStrawberry/svd-temporal-controlnet/blob/main/validation_demo/depth/frame_8.png?raw=true" \
"https://github.com/CiaraStrawberry/svd-temporal-controlnet/blob/main/validation_demo/depth/frame_9.png?raw=true" \
"https://github.com/CiaraStrawberry/svd-temporal-controlnet/blob/main/validation_demo/depth/frame_10.png?raw=true" \
"https://github.com/CiaraStrawberry/svd-temporal-controlnet/blob/main/validation_demo/depth/frame_11.png?raw=true" \
"https://github.com/CiaraStrawberry/svd-temporal-controlnet/blob/main/validation_demo/depth/frame_12.png?raw=true" \
"https://github.com/CiaraStrawberry/svd-temporal-controlnet/blob/main/validation_demo/depth/frame_13.png?raw=true" \
--image_path "https://github.com/CiaraStrawberry/svd-temporal-controlnet/blob/main/validation_demo/chair.png?raw=true" \
--video_save_dir SVD_controlnet \
--save_frames_as_images \
--use_habana \
--use_hpu_graphs \
--gaudi_config Habana/stable-diffusion \
--bf16 \
--num_frames 14 \
--motion_bucket_id=14 \
--width=512 \
--height=512
```
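
If you prefer to drive the new pipeline from Python instead of the example script, here is a minimal sketch that mirrors the command above. The class names, `subfolder` arguments, and the `controlnet_condition`/`num_frames` call arguments are taken from this commit; the Gaudi-specific keyword arguments (`use_habana`, `use_hpu_graphs`, `gaudi_config`) and the `outputs.frames` output layout follow the existing `GaudiStableVideoDiffusionPipeline` usage and are assumptions, not part of this diff.

```python
# Sketch only: mirrors the CLI example above on a Gaudi/HPU device.
import torch
from diffusers.utils import load_image

from optimum.habana.diffusers import (
    GaudiEulerDiscreteScheduler,
    GaudiStableVideoDiffusionPipelineControlNet,
)
from optimum.habana.diffusers.models import (
    ControlNetSDVModel,
    UNetSpatioTemporalConditionControlNetModel,
)

model_id = "stabilityai/stable-video-diffusion-img2vid"
controlnet_id = "CiaraRowles/temporal-controlnet-depth-svd-v1"
base_url = "https://github.com/CiaraStrawberry/svd-temporal-controlnet/blob/main/validation_demo"

# Components introduced by this commit: a ControlNet for SVD and a matching spatio-temporal UNet.
controlnet = ControlNetSDVModel.from_pretrained(controlnet_id, subfolder="controlnet")
unet = UNetSpatioTemporalConditionControlNetModel.from_pretrained(model_id, subfolder="unet")
scheduler = GaudiEulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")

pipeline = GaudiStableVideoDiffusionPipelineControlNet.from_pretrained(
    model_id,
    controlnet=controlnet,
    unet=unet,
    scheduler=scheduler,
    use_habana=True,            # assumed Gaudi kwargs, as in the other SVD examples
    use_hpu_graphs=True,
    gaudi_config="Habana/stable-diffusion",
    torch_dtype=torch.bfloat16,
)

# Conditioning image plus one depth control image per generated video frame (14 here).
image = load_image(f"{base_url}/chair.png?raw=true").resize((512, 512))
depth_frames = [
    load_image(f"{base_url}/depth/frame_{i}.png?raw=true").resize((512, 512)) for i in range(14)
]

outputs = pipeline(
    image=[image],
    controlnet_condition=depth_frames,
    num_frames=14,
    motion_bucket_id=14,
    height=512,
    width=512,
)
frames = outputs.frames[0]  # assumed: a list of PIL frames, as in the non-ControlNet SVD pipeline
for i, frame in enumerate(frames):
    frame.save(f"frame_{i:02d}.png")
```
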
135 changes: 102 additions & 33 deletions examples/stable-diffusion/image_to_video_generation.py
@@ -49,6 +49,12 @@ def main():
        type=str,
        help="Path to pre-trained model",
    )
    parser.add_argument(
        "--controlnet_model_name_or_path",
        default="CiaraRowles/temporal-controlnet-depth-svd-v1",
        type=str,
        help="Path to pre-trained controlnet model",
    )

    # Pipeline arguments
    parser.add_argument(
@@ -58,6 +64,13 @@
        nargs="*",
        help="Path to input image(s) to guide video generation",
    )
    parser.add_argument(
        "--control_image_path",
        type=str,
        default=None,
        nargs="*",
        help="Path to controlnet input image(s) to guide video generation",
    )
    parser.add_argument(
        "--num_videos_per_prompt", type=int, default=1, help="The number of videos to generate per prompt image."
    )
@@ -164,7 +177,12 @@
        ),
    )
    parser.add_argument("--bf16", action="store_true", help="Whether to perform generation in bf16 precision.")

    parser.add_argument(
        "--num_frames",
        type=int,
        default=25,
        help="The number of video frames to generate.",
    )
    args = parser.parse_args()

    from optimum.habana.diffusers import GaudiStableVideoDiffusionPipeline
@@ -177,6 +195,30 @@
    )
    logger.setLevel(logging.INFO)

    # Load input image(s)
    input = []
    logger.info("Input image(s):")
    if isinstance(args.image_path, str):
        args.image_path = [args.image_path]
    for image_path in args.image_path:
        print(image_path)
        image = load_image(image_path)
        image = image.resize((args.height, args.width))
        input.append(image)
        logger.info(image_path)

    # Load control input image(s)
    control_input = []
    if args.control_image_path is not None:
        logger.info("Input control image(s):")
        if isinstance(args.control_image_path, str):
            args.control_image_path = [args.control_image_path]
        for control_image in args.control_image_path:
            image = load_image(control_image)
            image = image.resize((args.height, args.width))
            control_input.append(image)
            logger.info(control_image)

    # Initialize the scheduler and the generation pipeline
    scheduler = GaudiEulerDiscreteScheduler.from_pretrained(args.model_name_or_path, subfolder="scheduler")
    kwargs = {
@@ -188,41 +230,68 @@
    if args.bf16:
        kwargs["torch_dtype"] = torch.bfloat16

    if args.control_image_path is not None:
        from optimum.habana.diffusers import GaudiStableVideoDiffusionPipelineControlNet
        from optimum.habana.diffusers.models import ControlNetSDVModel
        from optimum.habana.diffusers.models import UNetSpatioTemporalConditionControlNetModel

        controlnet = ControlNetSDVModel.from_pretrained(
            args.controlnet_model_name_or_path,
            subfolder="controlnet")
        unet = UNetSpatioTemporalConditionControlNetModel.from_pretrained(
            args.model_name_or_path,
            subfolder="unet")
        pipeline = GaudiStableVideoDiffusionPipelineControlNet.from_pretrained(
            args.model_name_or_path,
            controlnet=controlnet,
            unet=unet,
            **kwargs)

        # Set seed before running the model
        set_seed(args.seed)

        # Generate images
        outputs = pipeline(
            image=input,
            controlnet_condition=control_input,
            num_videos_per_prompt=args.num_videos_per_prompt,
            batch_size=args.batch_size,
            height=args.height,
            width=args.width,
            num_inference_steps=args.num_inference_steps,
            min_guidance_scale=args.min_guidance_scale,
            max_guidance_scale=args.max_guidance_scale,
            fps=args.fps,
            motion_bucket_id=args.motion_bucket_id,
            noise_aug_strength=args.noise_aug_strength,
            decode_chunk_size=args.decode_chunk_size,
            output_type=args.output_type,
            num_frames=args.num_frames,
        )
    else:
        pipeline = GaudiStableVideoDiffusionPipeline.from_pretrained(
            args.model_name_or_path,
            **kwargs,
        )

        # Set seed before running the model
        set_seed(args.seed)

        # Generate images
        outputs = pipeline(
            image=input,
            num_videos_per_prompt=args.num_videos_per_prompt,
            batch_size=args.batch_size,
            height=args.height,
            width=args.width,
            num_inference_steps=args.num_inference_steps,
            min_guidance_scale=args.min_guidance_scale,
            max_guidance_scale=args.max_guidance_scale,
            fps=args.fps,
            motion_bucket_id=args.motion_bucket_id,
            noise_aug_strength=args.noise_aug_strength,
            decode_chunk_size=args.decode_chunk_size,
            output_type=args.output_type,
        )

    # Save the pipeline in the specified directory if not None
    if args.pipeline_save_dir is not None:
1 change: 1 addition & 0 deletions optimum/habana/diffusers/__init__.py
@@ -5,4 +5,5 @@
from .pipelines.stable_diffusion.pipeline_stable_diffusion_upscale import GaudiStableDiffusionUpscalePipeline
from .pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import GaudiStableDiffusionXLPipeline
from .pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import GaudiStableVideoDiffusionPipeline
from .pipelines.controlnet.pipeline_stable_video_diffusion_controlnet import GaudiStableVideoDiffusionPipelineControlNet
from .schedulers import GaudiDDIMScheduler, GaudiEulerAncestralDiscreteScheduler, GaudiEulerDiscreteScheduler
2 changes: 2 additions & 0 deletions optimum/habana/diffusers/models/__init__.py
@@ -1 +1,3 @@
from .unet_2d_condition import gaudi_unet_2d_condition_model_forward
from .controlnet_sdv import ControlNetSDVModel
from .unet_spatio_temporal_condition_controlnet import UNetSpatioTemporalConditionControlNetModel
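
As a quick sanity check that the new exports are wired up, the following minimal snippet (an editor-added sketch, assuming optimum-habana is installed from a tree containing this commit) should import and run cleanly:

```python
# Import check for the names re-exported by this commit (sketch, not part of the diff).
from optimum.habana.diffusers import GaudiStableVideoDiffusionPipelineControlNet
from optimum.habana.diffusers.models import (
    ControlNetSDVModel,
    UNetSpatioTemporalConditionControlNetModel,
)

print(GaudiStableVideoDiffusionPipelineControlNet.__name__)
print(ControlNetSDVModel.__name__, UNetSpatioTemporalConditionControlNetModel.__name__)
```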
