Skip to content

Commit

Permalink
refactor: Moved image generation schema to base image generation model
Browse files Browse the repository at this point in the history
  • Loading branch information
Irozuku committed Nov 21, 2024
1 parent 30c85be commit 54a59e2
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 34 deletions.
36 changes: 2 additions & 34 deletions DashAI/back/models/hugging_face/stable_diffusion_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,49 +3,17 @@
import torch
from diffusers import DiffusionPipeline

from DashAI.back.core.schema_fields import (
enum_field,
float_field,
int_field,
schema_field,
)
from DashAI.back.core.schema_fields.base_schema import BaseSchema
from DashAI.back.models.image_generation_model import ImageGenerationModel


class ImageGenerationSchema(BaseSchema):
    """Schema for image generation models.

    Declares three configurable fields: the number of denoising steps, the
    guidance scale and the device the model runs on.
    """

    # Integer field constrained with ge=1; the placeholder shown is 5.
    num_inference_steps: schema_field(
        int_field(ge=1),
        placeholder=5,
        description=(
            "The number of denoising steps. More steps usually lead to a "
            "higher quality image at the expense of slower inference."
        ),
    )  # type: ignore

    # Float field constrained with ge=0.0; the placeholder shown is 7.5.
    guidance_scale: schema_field(
        float_field(ge=0.0),
        placeholder=7.5,
        description=(
            "Higher guidance scale encourages images that are closer to the "
            "prompt, usually at the expense of lower image quality."
        ),
    )  # type: ignore

    # Enumerated field restricted to "cuda" or "cpu"; the placeholder is "cuda".
    device: schema_field(
        enum_field(enum=["cuda", "cpu"]),
        placeholder="cuda",
        description=(
            "Device to run the model on. CUDA is recommended for faster "
            "generation if available."
        ),
    )  # type: ignore


class StableDiffusionModel(ImageGenerationModel):
"""Class for models associated to StableDiffusionProcess."""

SCHEMA = ImageGenerationSchema

def __init__(self, **kwargs):
"""Initialize the model."""
kwargs = self.validate_and_transform(kwargs)
self.model_name = "stabilityai/stable-diffusion-2-1"
self.num_inference_steps = kwargs.pop("num_inference_steps")
self.guidance_scale = kwargs.pop("guidance_scale")
self.device = kwargs.pop("device")
super().__init__(**kwargs)

self.model_name = "stabilityai/stable-diffusion-2-1"
self.model = DiffusionPipeline.from_pretrained(
self.model_name,
torch_dtype=torch.float32 if self.device == "cuda" else torch.float16,
Expand Down
35 changes: 35 additions & 0 deletions DashAI/back/models/image_generation_model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,42 @@
from DashAI.back.core.schema_fields import (
enum_field,
float_field,
int_field,
schema_field,
)
from DashAI.back.core.schema_fields.base_schema import BaseSchema
from DashAI.back.models.base_generative_model import BaseGenerativeModel


class ImageGenerationSchema(BaseSchema):
    """Schema for image generation models.

    Declares three configurable fields: the number of denoising steps, the
    guidance scale and the device the model runs on.
    """

    # Integer field constrained with ge=1; the placeholder shown is 5.
    num_inference_steps: schema_field(
        int_field(ge=1),
        placeholder=5,
        description=(
            "The number of denoising steps. More steps usually lead to a "
            "higher quality image at the expense of slower inference."
        ),
    )  # type: ignore

    # Float field constrained with ge=0.0; the placeholder shown is 7.5.
    guidance_scale: schema_field(
        float_field(ge=0.0),
        placeholder=7.5,
        description=(
            "Higher guidance scale encourages images that are closer to the "
            "prompt, usually at the expense of lower image quality."
        ),
    )  # type: ignore

    # Enumerated field restricted to "cuda" or "cpu"; the placeholder is "cuda".
    device: schema_field(
        enum_field(enum=["cuda", "cpu"]),
        placeholder="cuda",
        description=(
            "Device to run the model on. CUDA is recommended for faster "
            "generation if available."
        ),
    )  # type: ignore


class ImageGenerationModel(BaseGenerativeModel):
    """Class for models associated to ImageGenerationTasks.

    On construction the keyword arguments are passed through
    ``validate_and_transform`` and the generation settings are stored as
    instance attributes.
    """

    COMPATIBLE_COMPONENTS = ["ImageGenerationTask"]
    SCHEMA = ImageGenerationSchema

    def __init__(self, **kwargs):
        """Initialize the model.

        Parameters
        ----------
        **kwargs
            Model parameters. After ``validate_and_transform`` the result
            must contain ``num_inference_steps``, ``guidance_scale`` and
            ``device``; each is popped and stored on the instance.

        Raises
        ------
        KeyError
            If any of the three expected keys is missing from the
            transformed parameters.
        """
        params = self.validate_and_transform(kwargs)

        # Pop (rather than read) each setting so it is removed from the
        # transformed mapping once it has been stored on the instance.
        self.num_inference_steps = params.pop("num_inference_steps")
        self.guidance_scale = params.pop("guidance_scale")
        self.device = params.pop("device")

0 comments on commit 54a59e2

Please sign in to comment.