Skip to content

Commit

Permalink
feat(runner): add support for SD3-medium model
Browse files Browse the repository at this point in the history
This commit introduces support for the Stable Diffusion 3 Medium model
from Hugging Face:
[https://huggingface.co/stabilityai/stable-diffusion-3-medium](https://huggingface.co/stabilityai/stable-diffusion-3-medium).

Please be aware that this model has restrictive licensing at the time of
writing and is not yet advised for public use. Ensure you read and
understand the [licensing
terms](https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE)
before enabling this model on your orchestrator.
  • Loading branch information
rickstaa committed Jul 16, 2024
1 parent b059e9b commit cd1feb4
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 7 deletions.
26 changes: 21 additions & 5 deletions runner/app/pipelines/text_to_image.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import os
from typing import List, Tuple, Optional
from enum import Enum

import PIL
import torch
Expand All @@ -9,6 +10,7 @@
EulerDiscreteScheduler,
StableDiffusionXLPipeline,
UNet2DConditionModel,
StableDiffusion3Pipeline,
)
from huggingface_hub import file_download, hf_hub_download
from safetensors.torch import load_file
Expand All @@ -24,7 +26,17 @@

logger = logging.getLogger(__name__)

SDXL_LIGHTNING_MODEL_ID = "ByteDance/SDXL-Lightning"

class ModelName(Enum):
"""Enumeration mapping model names to their corresponding IDs."""

SDXL_LIGHTNING = "ByteDance/SDXL-Lightning"
SD3_MEDIUM = "stabilityai/stable-diffusion-3-medium-diffusers"

@classmethod
def list(cls):
"""Return a list of all model IDs."""
return list(map(lambda c: c.value, cls))


class TextToImagePipeline(Pipeline):
Expand All @@ -46,7 +58,7 @@ def __init__(self, model_id: str):
for _, _, files in os.walk(folder_path)
for fname in files
)
or SDXL_LIGHTNING_MODEL_ID in model_id
or ModelName.SDXL_LIGHTNING.value in model_id
)
if torch_device != "cpu" and has_fp16_variant:
logger.info("TextToImagePipeline loading fp16 variant for %s", model_id)
Expand All @@ -59,7 +71,7 @@ def __init__(self, model_id: str):
kwargs["torch_dtype"] = torch.bfloat16

# Special case SDXL-Lightning because the unet for SDXL needs to be swapped
if SDXL_LIGHTNING_MODEL_ID in model_id:
if ModelName.SDXL_LIGHTNING.value in model_id:
base = "stabilityai/stable-diffusion-xl-base-1.0"

# ByteDance/SDXL-Lightning-2step
Expand All @@ -81,7 +93,7 @@ def __init__(self, model_id: str):
unet.load_state_dict(
load_file(
hf_hub_download(
SDXL_LIGHTNING_MODEL_ID,
ModelName.SDXL_LIGHTNING.value,
f"{unet_id}.safetensors",
cache_dir=kwargs["cache_dir"],
),
Expand All @@ -96,6 +108,10 @@ def __init__(self, model_id: str):
self.ldm.scheduler = EulerDiscreteScheduler.from_config(
self.ldm.scheduler.config, timestep_spacing="trailing"
)
elif ModelName.SD3_MEDIUM.value in model_id:
self.ldm = StableDiffusion3Pipeline.from_pretrained(model_id, **kwargs).to(
torch_device
)
else:
self.ldm = AutoPipelineForText2Image.from_pretrained(model_id, **kwargs).to(
torch_device
Expand Down Expand Up @@ -190,7 +206,7 @@ def __call__(
# SD turbo models were trained without guidance_scale so
# it should be set to 0
kwargs["guidance_scale"] = 0.0
elif SDXL_LIGHTNING_MODEL_ID in self.model_id:
elif ModelName.SDXL_LIGHTNING.value in self.model_id:
# SDXL-Lightning models should have guidance_scale = 0 and use
# the correct number of inference steps for the unet checkpoint loaded
kwargs["guidance_scale"] = 0.0
Expand Down
3 changes: 2 additions & 1 deletion runner/app/routes/text_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ async def text_to_image(
for seed in seeds:
try:
params.seed = seed
imgs, nsfw_check = pipeline(**params.model_dump())
kwargs = {k: v for k,v in params.model_dump().items() if k != "model_id"}
imgs, nsfw_check = pipeline(**kwargs)
images.extend(imgs)
has_nsfw_concept.extend(nsfw_check)
except Exception as e:
Expand Down
1 change: 1 addition & 0 deletions runner/dl_checkpoints.sh
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ function download_all_models() {
huggingface-cli download stabilityai/stable-diffusion-xl-base-1.0 --include "*.fp16.safetensors" "*.json" "*.txt" --exclude ".onnx" ".onnx_data" --cache-dir models
huggingface-cli download prompthero/openjourney-v4 --include "*.safetensors" "*.json" "*.txt" --exclude ".onnx" ".onnx_data" --cache-dir models
huggingface-cli download SG161222/RealVisXL_V4.0 --include "*.fp16.safetensors" "*.json" "*.txt" --exclude ".onnx" ".onnx_data" --cache-dir models
huggingface-cli download stabilityai/stable-diffusion-3-medium-diffusers --include "*.fp16*.safetensors" "*.model" "*.json" "*.txt" --cache-dir models ${TOKEN_FLAG:+"$TOKEN_FLAG"}

# Download image-to-video models.
huggingface-cli download stabilityai/stable-video-diffusion-img2vid-xt --include "*.fp16.safetensors" "*.json" --cache-dir models
Expand Down
4 changes: 3 additions & 1 deletion runner/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
diffusers==0.28.0
diffusers==0.29.2
accelerate==0.30.1
transformers==4.41.1
fastapi==0.111.0
Expand All @@ -14,3 +14,5 @@ deepcache==0.1.1
safetensors==0.4.3
scipy==1.13.0
numpy==1.26.4
sentencepiece== 0.2.0
protobuf==5.27.2

0 comments on commit cd1feb4

Please sign in to comment.