Remove some predict call duplication #41

Merged · 4 commits · Nov 1, 2024
181 changes: 92 additions & 89 deletions predict.py
@@ -1,6 +1,6 @@
 import os
 import time
-from typing import Any, Dict, Optional
+from typing import Any, Tuple, Optional

 import torch
@@ -292,18 +292,12 @@ def handle_loras(
         elif cur_lora:
             unload_loras(model)

-    def preprocess(
-        self, aspect_ratio: str, seed: Optional[int], megapixels: str
-    ) -> Dict:
+    def preprocess(self, aspect_ratio: str, megapixels: str = "1") -> Tuple[int, int]:
         width, height = ASPECT_RATIOS.get(aspect_ratio)
         if megapixels == "0.25":
             width, height = width // 2, height // 2

-        if not seed:
-            seed = int.from_bytes(os.urandom(2), "big")
-            print(f"Using seed: {seed}")
-
-        return {"width": width, "height": height, "seed": seed}
+        return (width, height)

     @torch.inference_mode()
     def base_predict(
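
An aside on the preprocess change above: it now returns only (width, height), and seed handling moves into the predict path (next hunk). A minimal sketch of the new contract, with hypothetical ASPECT_RATIOS entries that are not taken from this repo:

ASPECT_RATIOS = {"1:1": (1024, 1024), "16:9": (1344, 768)}  # hypothetical entries

def preprocess(aspect_ratio: str, megapixels: str = "1"):
    width, height = ASPECT_RATIOS.get(aspect_ratio)
    if megapixels == "0.25":
        # halving each side quarters the pixel count: 1024x1024 -> 512x512
        width, height = width // 2, height // 2
    return (width, height)

print(preprocess("1:1"))          # (1024, 1024)
print(preprocess("1:1", "0.25"))  # (512, 512)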
@@ -322,6 +316,10 @@ def base_predict(
         torch_device = torch.device("cuda")
         init_image = None

+        if not seed:
Collaborator: hmm. fp8 predict does come up with a seed, but there's value in logging the seed s.t. the user can see it for subsequent predictions.

Collaborator: ah jk it's there isn't it. fair enough, this'll work.

+            seed = int.from_bytes(os.urandom(2), "big")
+            print(f"Using seed: {seed}")
+
         # img2img only works for flux-dev
         if image:
             print("Image detected - setting to img2img mode")
@@ -507,6 +505,45 @@ def run_falcon_safety_checker(self, image):

         return result == "normal"

+    def shared_predict(
+        self,
+        go_fast: bool,
+        prompt: str,
+        num_outputs: int,
+        num_inference_steps: int,
+        guidance: float = 3.5,  # schnell ignores guidance within the model, fine to have default
+        image: Path = None,  # img2img for flux-dev
+        prompt_strength: float = 0.8,
+        seed: int = None,
+        width: int = 1024,
+        height: int = 1024,
+    ):
+        if go_fast and not self.disable_fp8:
+            return self.fp8_predict(
+                prompt=prompt,
+                num_outputs=num_outputs,
+                num_inference_steps=num_inference_steps,
+                guidance=guidance,
+                image=image,
+                prompt_strength=prompt_strength,
+                seed=seed,
+                width=width,
+                height=height,
+            )
+        if self.disable_fp8:
+            print("running bf16 model, fp8 disabled")
+        return self.base_predict(
+            prompt=prompt,
+            num_outputs=num_outputs,
+            num_inference_steps=num_inference_steps,
+            guidance=guidance,
+            image=image,
+            prompt_strength=prompt_strength,
+            seed=seed,
+            width=width,
+            height=height,
+        )
+

 class SchnellPredictor(Predictor):
     def setup(self) -> None:
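
Every subclass predict() below now funnels through shared_predict instead of duplicating the fp8/bf16 branch. A minimal sketch of the dispatch it centralizes, with stub method bodies as hypothetical stand-ins for the real fp8 and bf16 paths:

class Predictor:
    disable_fp8 = False

    def fp8_predict(self, **kwargs):   # stand-in for the real fp8 path
        return ["fp8 image"], ["np array"]

    def base_predict(self, **kwargs):  # stand-in for the real bf16 path
        return ["bf16 image"], ["np array"]

    def shared_predict(self, go_fast, **kwargs):
        # fp8 only when requested and not disabled; otherwise fall back to bf16
        if go_fast and not self.disable_fp8:
            return self.fp8_predict(**kwargs)
        if self.disable_fp8:
            print("running bf16 model, fp8 disabled")
        return self.base_predict(**kwargs)

imgs, np_imgs = Predictor().shared_predict(go_fast=True, prompt="a cat")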
@@ -530,24 +567,16 @@ def predict(
         go_fast: bool = SHARED_INPUTS.go_fast,
         megapixels: str = SHARED_INPUTS.megapixels,
     ) -> List[Path]:
-        hws_kwargs = self.preprocess(aspect_ratio, seed, megapixels)
-
-        if go_fast and not self.disable_fp8:
-            imgs, np_imgs = self.fp8_predict(
-                prompt,
-                num_outputs,
-                num_inference_steps=num_inference_steps,
-                **hws_kwargs,
-            )
-        else:
-            if self.disable_fp8:
-                print("running bf16 model, fp8 disabled")
-            imgs, np_imgs = self.base_predict(
-                prompt,
-                num_outputs,
-                num_inference_steps=num_inference_steps,
-                **hws_kwargs,
-            )
+        width, height = self.preprocess(aspect_ratio, megapixels)
+        imgs, np_imgs = self.shared_predict(
+            go_fast,
+            prompt,
+            num_outputs,
+            num_inference_steps=num_inference_steps,
+            seed=seed,
+            width=width,
+            height=height,
+        )

         return self.postprocess(
             imgs,
@@ -596,30 +625,19 @@ def predict(
         if image and go_fast:
             print("img2img not supported with fp8 quantization; running with bf16")
             go_fast = False
-        hws_kwargs = self.preprocess(aspect_ratio, seed, megapixels)
-
-        if go_fast and not self.disable_fp8:
-            imgs, np_imgs = self.fp8_predict(
-                prompt,
-                num_outputs,
-                num_inference_steps,
-                guidance=guidance,
-                image=image,
-                prompt_strength=prompt_strength,
-                **hws_kwargs,
-            )
-        else:
-            if self.disable_fp8:
-                print("running bf16 model, fp8 disabled")
-            imgs, np_imgs = self.base_predict(
-                prompt,
-                num_outputs,
-                num_inference_steps,
-                guidance=guidance,
-                image=image,
-                prompt_strength=prompt_strength,
-                **hws_kwargs,
-            )
+        width, height = self.preprocess(aspect_ratio, megapixels)
+        imgs, np_imgs = self.shared_predict(
+            go_fast,
+            prompt,
+            num_outputs,
+            num_inference_steps,
+            guidance=guidance,
+            image=image,
+            prompt_strength=prompt_strength,
+            seed=seed,
+            width=width,
+            height=height,
+        )

         return self.postprocess(
             imgs,
@@ -656,22 +674,16 @@ def predict(
     ) -> List[Path]:
         self.handle_loras(go_fast, lora_weights, lora_scale)

-        if go_fast:
-            imgs, np_imgs = self.fp8_predict(
-                prompt,
-                aspect_ratio,
-                num_outputs,
-                num_inference_steps=num_inference_steps,
-                seed=seed,
-            )
-        else:
-            imgs, np_imgs = self.base_predict(
-                prompt,
-                aspect_ratio,
-                num_outputs,
-                num_inference_steps=num_inference_steps,
-                seed=seed,
-            )
+        width, height = self.preprocess(aspect_ratio)
+        imgs, np_imgs = self.shared_predict(
+            go_fast,
+            prompt,
+            num_outputs,
+            num_inference_steps=num_inference_steps,
+            seed=seed,
+            width=width,
+            height=height,
+        )

         return self.postprocess(
             imgs,
@@ -725,28 +737,19 @@ def predict(

         self.handle_loras(go_fast, lora_weights, lora_scale)

-        if go_fast:
-            imgs, np_imgs = self.fp8_predict(
-                prompt,
-                aspect_ratio,
-                num_outputs,
-                num_inference_steps,
-                guidance=guidance,
-                image=image,
-                prompt_strength=prompt_strength,
-                seed=seed,
-            )
-        else:
-            imgs, np_imgs = self.base_predict(
-                prompt,
-                aspect_ratio,
-                num_outputs,
-                num_inference_steps,
-                guidance=guidance,
-                image=image,
-                prompt_strength=prompt_strength,
-                seed=seed,
-            )
+        width, height = self.preprocess(aspect_ratio)
+        imgs, np_imgs = self.shared_predict(
+            go_fast,
+            prompt,
+            num_outputs,
+            num_inference_steps,
+            guidance=guidance,
+            image=image,
+            prompt_strength=prompt_strength,
+            seed=seed,
+            width=width,
+            height=height,
+        )

         return self.postprocess(
             imgs,