Support for Ldm3d #304

Merged 19 commits on Aug 14, 2023

Changes from 4 commits
67 changes: 67 additions & 0 deletions docs/source/tutorials/stable_diffusion_ldm3d.mdx
@@ -0,0 +1,67 @@
<!---
Copyright 2022 The Intel Authors and HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->

# Text-to-(RGB, depth)

LDM3D was proposed in [LDM3D: Latent Diffusion Model for 3D](https://huggingface.co/papers/2305.10853) by Gabriela Ben Melech Stan, Diana Wofk, Scottie Fox, Alex Redden, Will Saxton, Jean Yu, Estelle Aflalo, Shao-Yen Tseng, Fabio Nonato, Matthias Muller, and Vasudev Lal. Unlike existing text-to-image diffusion models such as [Stable Diffusion](./stable_diffusion/overview), which generate only an image, LDM3D generates both an image and a depth map from a given text prompt. With almost the same number of parameters, LDM3D learns a latent space that can compress both the RGB images and the depth maps.

The abstract from the paper is:

*This research paper proposes a Latent Diffusion Model for 3D (LDM3D) that generates both image and depth map data from a given text prompt, allowing users to generate RGBD images from text prompts. The LDM3D model is fine-tuned on a dataset of tuples containing an RGB image, depth map and caption, and validated through extensive experiments. We also develop an application called DepthFusion, which uses the generated RGB images and depth maps to create immersive and interactive 360-degree-view experiences using TouchDesigner. This technology has the potential to transform a wide range of industries, from entertainment and gaming to architecture and design. Overall, this paper presents a significant contribution to the field of generative AI and computer vision, and showcases the potential of LDM3D and DepthFusion to revolutionize content creation and digital experiences. A short video summarizing the approach can be found at [this url](https://t.ly/tdi2).*


## How to generate RGB and depth images?

To generate RGB and depth images with Stable Diffusion LDM3D on Gaudi, you need to instantiate two components:
- A pipeline with [`GaudiStableDiffusionLDM3DPipeline`]. This pipeline supports *text-to-(rgb, depth) generation*.
- A scheduler with [`GaudiDDIMScheduler`](https://huggingface.co/docs/optimum/habana/package_reference/stable_diffusion_pipeline#optimum.habana.diffusers.GaudiDDIMScheduler). This scheduler has been optimized for Gaudi.

When initializing the pipeline, you have to specify `use_habana=True` to deploy it on HPUs.
Furthermore, to get the fastest possible generations you should enable **HPU graphs** with `use_hpu_graphs=True`.
Finally, you will need to specify a [Gaudi configuration](https://huggingface.co/docs/optimum/habana/package_reference/gaudi_config) which can be downloaded from the Hugging Face Hub.

```python
from optimum.habana.diffusers import GaudiDDIMScheduler, GaudiStableDiffusionLDM3DPipeline
from optimum.habana.utils import set_seed

model_name = "Intel/ldm3d-4c"

scheduler = GaudiDDIMScheduler.from_pretrained(model_name, subfolder="scheduler")

set_seed(42)

pipeline = GaudiStableDiffusionLDM3DPipeline.from_pretrained(
    model_name,
    scheduler=scheduler,
    use_habana=True,
    use_hpu_graphs=True,
    gaudi_config="Habana/stable-diffusion",
)
outputs = pipeline(
    prompt=["High quality photo of an astronaut riding a horse in space"],
    num_images_per_prompt=1,
    batch_size=1,
    output_type="pil",
    num_inference_steps=40,
    guidance_scale=5.0,
    negative_prompt=None,
)

rgb_image, depth_image = outputs.rgb, outputs.depth
rgb_image[0].save("astronaut_ldm3d_rgb.png")
depth_image[0].save("astronaut_ldm3d_depth.png")
```
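
The pipeline also accepts several prompts at once. The snippet below is a minimal sketch rather than part of the original tutorial: the second prompt and the output file names are illustrative, and it assumes the `pipeline` instance created above with `output_type="pil"`, so that `outputs.rgb` and `outputs.depth` are parallel lists of PIL images.

```python
# Illustrative multi-prompt generation, reusing `pipeline` from the snippet above.
prompts = [
    "High quality photo of an astronaut riding a horse in space",
    "A photo of a cozy cabin in a snowy forest",  # hypothetical second prompt
]

outputs = pipeline(
    prompt=prompts,
    num_images_per_prompt=1,
    batch_size=1,
    output_type="pil",
    num_inference_steps=40,
    guidance_scale=5.0,
)

# Entry i of outputs.rgb and outputs.depth corresponds to prompt i.
for i, (rgb, depth) in enumerate(zip(outputs.rgb, outputs.depth)):
    rgb.save(f"rgb_{i}.png")
    depth.save(f"depth_{i}.png")
```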
186 changes: 186 additions & 0 deletions examples/stable-diffusion/text_to_image_generation_ldm3d.py
@@ -0,0 +1,186 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging
import sys
from pathlib import Path

import torch

from optimum.habana.diffusers import GaudiDDIMScheduler, GaudiStableDiffusionLDM3DPipeline
from optimum.habana.utils import set_seed


logger = logging.getLogger(__name__)


def main():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--model_name_or_path",
        default="Intel/ldm3d-4c",
        type=str,
        help="Path to pre-trained model",
    )

    # Pipeline arguments
    parser.add_argument(
        "--prompts",
        type=str,
        nargs="*",
        default="An image of a squirrel in Picasso style",
        help="The prompt or prompts to guide the image generation.",
    )
    parser.add_argument(
        "--num_images_per_prompt", type=int, default=1, help="The number of images to generate per prompt."
    )
    parser.add_argument("--batch_size", type=int, default=1, help="The number of images in a batch.")
    parser.add_argument("--height", type=int, default=512, help="The height in pixels of the generated images.")
    parser.add_argument("--width", type=int, default=512, help="The width in pixels of the generated images.")
    parser.add_argument(
        "--num_inference_steps",
        type=int,
        default=50,
        help=(
            "The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense"
            " of slower inference."
        ),
    )
    parser.add_argument(
        "--guidance_scale",
        type=float,
        default=5.0,
        help=(
            "Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598)."
            " Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,"
            " usually at the expense of lower image quality."
        ),
    )
    parser.add_argument(
        "--negative_prompts",
        type=str,
        nargs="*",
        default=None,
        help="The prompt or prompts not to guide the image generation.",
    )
    parser.add_argument(
        "--eta",
        type=float,
        default=0.0,
        help="Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502.",
    )
    parser.add_argument(
        "--output_type",
        type=str,
        choices=["pil", "np"],
        default="pil",
        help="Whether to return PIL images or Numpy arrays.",
    )

    parser.add_argument(
        "--pipeline_save_dir",
        type=str,
        default=None,
        help="The directory where the generation pipeline will be saved.",
    )
    parser.add_argument(
        "--image_save_dir",
        type=str,
        default="./stable-diffusion-generated-images",
        help="The directory where images will be saved.",
    )

    parser.add_argument("--seed", type=int, default=42, help="Random seed for initialization.")

    # HPU-specific arguments
    parser.add_argument("--use_habana", action="store_true", help="Use HPU.")
    parser.add_argument(
        "--use_hpu_graphs", action="store_true", help="Use HPU graphs on HPU. This should lead to faster generations."
    )
    parser.add_argument(
        "--gaudi_config_name",
        type=str,
        default="Habana/stable-diffusion",
        help=(
            "Name or path of the Gaudi configuration. In particular, it enables to specify how to apply Habana Mixed"
            " Precision."
        ),
    )
    parser.add_argument("--bf16", action="store_true", help="Whether to perform generation in bf16 precision.")

    args = parser.parse_args()

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    logger.setLevel(logging.INFO)

    # Initialize the scheduler and the generation pipeline
    scheduler = GaudiDDIMScheduler.from_pretrained(args.model_name_or_path, subfolder="scheduler")
    kwargs = {
        "scheduler": scheduler,
        "use_habana": args.use_habana,
        "use_hpu_graphs": args.use_hpu_graphs,
        "gaudi_config": args.gaudi_config_name,
    }
    if args.bf16:
        kwargs["torch_dtype"] = torch.bfloat16
    pipeline = GaudiStableDiffusionLDM3DPipeline.from_pretrained(
        args.model_name_or_path,
        **kwargs,
    )

    # Set seed before running the model
    set_seed(args.seed)

    # Generate images
    outputs = pipeline(
        prompt=args.prompts,
        num_images_per_prompt=args.num_images_per_prompt,
        batch_size=args.batch_size,
        height=args.height,
        width=args.width,
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale,
        negative_prompt=args.negative_prompts,
        eta=args.eta,
        output_type=args.output_type,
    )

    # Save the pipeline in the specified directory if not None
    if args.pipeline_save_dir is not None:
        pipeline.save_pretrained(args.pipeline_save_dir)

    # Save images in the specified directory if not None and if they are in PIL format
    if args.image_save_dir is not None:
        if args.output_type == "pil":
            image_save_dir = Path(args.image_save_dir)
            image_save_dir.mkdir(parents=True, exist_ok=True)
            logger.info(f"Saving images in {image_save_dir.resolve()}...")
            for i, rgb in enumerate(outputs.rgb):
                rgb.save(image_save_dir / f"rgb_{i+1}.png")
            for i, depth in enumerate(outputs.depth):
                depth.save(image_save_dir / f"depth_{i+1}.png")
        else:
            logger.warning("--output_type should be equal to 'pil' to save images in --image_save_dir.")


if __name__ == "__main__":
    main()
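
To show how the flags defined above fit together, here is a minimal sketch of one possible invocation of this script, written as Python for consistency with the rest of this page. The flag values are arbitrary illustrations, not settings prescribed by the PR; only flags defined in `main()` above are used.

```python
import subprocess

# Illustrative run of the example script above (values are arbitrary choices).
# --use_habana and --use_hpu_graphs enable HPU execution and HPU graphs,
# and --bf16 switches generation to bfloat16 precision.
subprocess.run(
    [
        "python",
        "text_to_image_generation_ldm3d.py",
        "--model_name_or_path", "Intel/ldm3d-4c",
        "--prompts", "An image of a squirrel in Picasso style",
        "--num_images_per_prompt", "2",
        "--use_habana",
        "--use_hpu_graphs",
        "--bf16",
        "--image_save_dir", "./ldm3d_images",
    ],
    check=True,
)
```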
1 change: 1 addition & 0 deletions optimum/habana/diffusers/__init__.py
@@ -1,3 +1,4 @@
from .pipelines.pipeline_utils import GaudiDiffusionPipeline
from .pipelines.stable_diffusion.pipeline_stable_diffusion import GaudiStableDiffusionPipeline
from .pipelines.stable_diffusion.pipeline_stable_diffusion_ldm3d import GaudiStableDiffusionLDM3DPipeline
from .schedulers import GaudiDDIMScheduler
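
This one-line export is what makes the top-level import used in the tutorial above resolve. A trivial sanity check, not part of the PR:

```python
# Confirms the new pipeline is exposed at the package top level.
from optimum.habana.diffusers import GaudiStableDiffusionLDM3DPipeline

print(GaudiStableDiffusionLDM3DPipeline.__name__)
```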