Skip to content

Commit

Permalink
Decouple resolution and seed preprocessing
Browse files Browse the repository at this point in the history
  • Loading branch information
yorickvP authored and daanelson committed Nov 1, 2024
1 parent aa16cd9 commit ef54514
Showing 1 changed file with 33 additions and 18 deletions.
51 changes: 33 additions & 18 deletions predict.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
import time
from typing import Any, Dict, Optional
from typing import Any, Tuple, Optional

import torch

Expand Down Expand Up @@ -292,18 +292,12 @@ def handle_loras(
elif cur_lora:
unload_loras(model)

def preprocess(self, aspect_ratio: str, megapixels: str = "1") -> Tuple[int, int]:
    """Resolve an aspect-ratio name to concrete (width, height) pixel dimensions.

    Args:
        aspect_ratio: key into the module-level ``ASPECT_RATIOS`` table
            (e.g. "1:1", "16:9" — presumably; confirm against the table).
        megapixels: "1" (default) for full resolution, or "0.25" to halve
            each dimension, quartering the pixel count. Any other value is
            treated as full resolution.

    Returns:
        ``(width, height)`` tuple of ints.

    Raises:
        ValueError: if ``aspect_ratio`` is not a known ratio.
    """
    dims = ASPECT_RATIOS.get(aspect_ratio)
    if dims is None:
        # A bare .get() on an unknown key previously surfaced as a cryptic
        # "cannot unpack non-sequence NoneType" TypeError at the unpack
        # below; fail with a clear, actionable message instead.
        raise ValueError(f"Unknown aspect ratio: {aspect_ratio}")
    width, height = dims
    if megapixels == "0.25":
        width, height = width // 2, height // 2

    return (width, height)

@torch.inference_mode()
def base_predict(
Expand All @@ -322,6 +316,10 @@ def base_predict(
torch_device = torch.device("cuda")
init_image = None

if not seed:
seed = int.from_bytes(os.urandom(2), "big")
print(f"Using seed: {seed}")

# img2img only works for flux-dev
if image:
print("Image detected - setting to img2img mode")
Expand Down Expand Up @@ -530,14 +528,15 @@ def predict(
go_fast: bool = SHARED_INPUTS.go_fast,
megapixels: str = SHARED_INPUTS.megapixels,
) -> List[Path]:
hws_kwargs = self.preprocess(aspect_ratio, seed, megapixels)

width, height = self.preprocess(aspect_ratio, megapixels)
if go_fast and not self.disable_fp8:
imgs, np_imgs = self.fp8_predict(
prompt,
num_outputs,
num_inference_steps=num_inference_steps,
**hws_kwargs,
seed=seed,
width=width,
height=height,
)
else:
if self.disable_fp8:
Expand All @@ -546,7 +545,9 @@ def predict(
prompt,
num_outputs,
num_inference_steps=num_inference_steps,
**hws_kwargs,
seed=seed,
width=width,
height=height,
)

return self.postprocess(
Expand Down Expand Up @@ -596,7 +597,7 @@ def predict(
if image and go_fast:
print("img2img not supported with fp8 quantization; running with bf16")
go_fast = False
hws_kwargs = self.preprocess(aspect_ratio, seed, megapixels)
width, height = self.preprocess(aspect_ratio, megapixels)

if go_fast and not self.disable_fp8:
imgs, np_imgs = self.fp8_predict(
Expand All @@ -606,7 +607,9 @@ def predict(
guidance=guidance,
image=image,
prompt_strength=prompt_strength,
**hws_kwargs,
seed=seed,
width=width,
height=height,
)
else:
if self.disable_fp8:
Expand All @@ -618,7 +621,9 @@ def predict(
guidance=guidance,
image=image,
prompt_strength=prompt_strength,
**hws_kwargs,
seed=seed,
width=width,
height=height,
)

return self.postprocess(
Expand Down Expand Up @@ -656,13 +661,17 @@ def predict(
) -> List[Path]:
self.handle_loras(go_fast, lora_weights, lora_scale)

width, height = self.preprocess(aspect_ratio)

if go_fast:
imgs, np_imgs = self.fp8_predict(
prompt,
aspect_ratio,
num_outputs,
num_inference_steps=num_inference_steps,
seed=seed,
width=width,
height=height,
)
else:
imgs, np_imgs = self.base_predict(
Expand All @@ -671,6 +680,8 @@ def predict(
num_outputs,
num_inference_steps=num_inference_steps,
seed=seed,
width=width,
height=height,
)

return self.postprocess(
Expand Down Expand Up @@ -725,27 +736,31 @@ def predict(

self.handle_loras(go_fast, lora_weights, lora_scale)

width, height = self.preprocess(aspect_ratio)

if go_fast:
imgs, np_imgs = self.fp8_predict(
prompt,
aspect_ratio,
num_outputs,
num_inference_steps,
guidance=guidance,
image=image,
prompt_strength=prompt_strength,
seed=seed,
width=width,
height=height,
)
else:
imgs, np_imgs = self.base_predict(
prompt,
aspect_ratio,
num_outputs,
num_inference_steps,
guidance=guidance,
image=image,
prompt_strength=prompt_strength,
seed=seed,
width=width,
height=height,
)

return self.postprocess(
Expand Down

0 comments on commit ef54514

Please sign in to comment.