Remove some predict call code duplication
yorickvP authored and daanelson committed Nov 1, 2024
1 parent ef54514 commit 0ad7832
Showing 1 changed file with 50 additions and 93 deletions.
predict.py: 143 changes (50 additions, 93 deletions)
@@ -505,6 +505,13 @@ def run_falcon_safety_checker(self, image):
 
         return result == "normal"
 
+    def shared_predict(self, go_fast, *args, **kwargs):
+        if go_fast and not self.disable_fp8:
+            return self.fp8_predict(*args, **kwargs)
+        if self.disable_fp8:
+            print("running bf16 model, fp8 disabled")
+        return self.base_predict(*args, **kwargs)
+
 
 class SchnellPredictor(Predictor):
     def setup(self) -> None:
@@ -529,26 +536,15 @@ def predict(
         megapixels: str = SHARED_INPUTS.megapixels,
     ) -> List[Path]:
         width, height = self.preprocess(aspect_ratio, megapixels)
-        if go_fast and not self.disable_fp8:
-            imgs, np_imgs = self.fp8_predict(
-                prompt,
-                num_outputs,
-                num_inference_steps=num_inference_steps,
-                seed=seed,
-                width=width,
-                height=height,
-            )
-        else:
-            if self.disable_fp8:
-                print("running bf16 model, fp8 disabled")
-            imgs, np_imgs = self.base_predict(
-                prompt,
-                num_outputs,
-                num_inference_steps=num_inference_steps,
-                seed=seed,
-                width=width,
-                height=height,
-            )
+        imgs, np_imgs = self.shared_predict(
+            go_fast,
+            prompt,
+            num_outputs,
+            num_inference_steps=num_inference_steps,
+            seed=seed,
+            width=width,
+            height=height,
+        )
 
         return self.postprocess(
             imgs,
@@ -598,33 +594,18 @@ def predict(
             print("img2img not supported with fp8 quantization; running with bf16")
             go_fast = False
         width, height = self.preprocess(aspect_ratio, megapixels)
-
-        if go_fast and not self.disable_fp8:
-            imgs, np_imgs = self.fp8_predict(
-                prompt,
-                num_outputs,
-                num_inference_steps,
-                guidance=guidance,
-                image=image,
-                prompt_strength=prompt_strength,
-                seed=seed,
-                width=width,
-                height=height,
-            )
-        else:
-            if self.disable_fp8:
-                print("running bf16 model, fp8 disabled")
-            imgs, np_imgs = self.base_predict(
-                prompt,
-                num_outputs,
-                num_inference_steps,
-                guidance=guidance,
-                image=image,
-                prompt_strength=prompt_strength,
-                seed=seed,
-                width=width,
-                height=height,
-            )
+        imgs, np_imgs = self.shared_predict(
+            go_fast,
+            prompt,
+            num_outputs,
+            num_inference_steps,
+            guidance=guidance,
+            image=image,
+            prompt_strength=prompt_strength,
+            seed=seed,
+            width=width,
+            height=height,
+        )
 
         return self.postprocess(
             imgs,
@@ -662,27 +643,16 @@ def predict(
         self.handle_loras(go_fast, lora_weights, lora_scale)
 
         width, height = self.preprocess(aspect_ratio)
-
-        if go_fast:
-            imgs, np_imgs = self.fp8_predict(
-                prompt,
-                aspect_ratio,
-                num_outputs,
-                num_inference_steps=num_inference_steps,
-                seed=seed,
-                width=width,
-                height=height,
-            )
-        else:
-            imgs, np_imgs = self.base_predict(
-                prompt,
-                aspect_ratio,
-                num_outputs,
-                num_inference_steps=num_inference_steps,
-                seed=seed,
-                width=width,
-                height=height,
-            )
+        imgs, np_imgs = self.shared_predict(
+            go_fast,
+            prompt,
+            aspect_ratio,
+            num_outputs,
+            num_inference_steps=self.num_steps,
+            seed=seed,
+            width=width,
+            height=height,
+        )
 
         return self.postprocess(
             imgs,
@@ -737,31 +707,18 @@ def predict(
         self.handle_loras(go_fast, lora_weights, lora_scale)
 
         width, height = self.preprocess(aspect_ratio)
-
-        if go_fast:
-            imgs, np_imgs = self.fp8_predict(
-                prompt,
-                num_outputs,
-                num_inference_steps,
-                guidance=guidance,
-                image=image,
-                prompt_strength=prompt_strength,
-                seed=seed,
-                width=width,
-                height=height,
-            )
-        else:
-            imgs, np_imgs = self.base_predict(
-                prompt,
-                num_outputs,
-                num_inference_steps,
-                guidance=guidance,
-                image=image,
-                prompt_strength=prompt_strength,
-                seed=seed,
-                width=width,
-                height=height,
-            )
+        imgs, np_imgs = self.shared_predict(
+            go_fast,
+            prompt,
+            num_outputs,
+            num_inference_steps,
+            guidance=guidance,
+            image=image,
+            prompt_strength=prompt_strength,
+            seed=seed,
+            width=width,
+            height=height,
+        )
 
         return self.postprocess(
             imgs,
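
Every hunk in this diff applies the same refactor: each predict() variant used to branch inline between self.fp8_predict and self.base_predict, and the commit funnels that branch through the new shared_predict helper, which forwards all arguments unchanged. The sketch below shows the pattern in isolation; the Predictor class and its fp8_predict/base_predict bodies here are simplified stand-ins for illustration, not the real implementations in predict.py.

# A minimal, self-contained sketch of the dispatch pattern this commit
# introduces. Only disable_fp8, fp8_predict, and base_predict are modeled,
# and both "pipelines" just return a label so the routing is visible.

class Predictor:
    def __init__(self, disable_fp8: bool = False):
        self.disable_fp8 = disable_fp8

    def fp8_predict(self, prompt, num_outputs, **kwargs):
        # stand-in for the quantized (fp8) pipeline
        return f"fp8: {prompt} x{num_outputs}", kwargs

    def base_predict(self, prompt, num_outputs, **kwargs):
        # stand-in for the full-precision (bf16) pipeline
        return f"bf16: {prompt} x{num_outputs}", kwargs

    def shared_predict(self, go_fast, *args, **kwargs):
        # the single place that now decides which pipeline runs;
        # every caller's arguments are forwarded unchanged
        if go_fast and not self.disable_fp8:
            return self.fp8_predict(*args, **kwargs)
        if self.disable_fp8:
            print("running bf16 model, fp8 disabled")
        return self.base_predict(*args, **kwargs)


if __name__ == "__main__":
    p = Predictor(disable_fp8=False)
    # each predict() variant now makes one shared_predict call instead of branching itself
    print(p.shared_predict(True, "a photo of a fox", 1, seed=42))
    print(p.shared_predict(False, "a photo of a fox", 1, seed=42))

Collapsing the branch into one helper is where the 93 deleted versus 50 added lines come from: each predict() body shrinks to a single shared_predict(go_fast, ...) call.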
