From 6e7f202b8df50c6a1d19e00802f9b85dcc188c59 Mon Sep 17 00:00:00 2001 From: Sean Sube <seansube@gmail.com> Date: Sat, 22 Apr 2023 00:11:33 -0500 Subject: [PATCH] fix(api): store both pre-parse and parsed prompts (#320) --- api/onnx_web/diffusers/run.py | 25 ++++++++++++------------- api/onnx_web/params.py | 5 +++++ 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/api/onnx_web/diffusers/run.py b/api/onnx_web/diffusers/run.py index 1f13b59a1..125878155 100644 --- a/api/onnx_web/diffusers/run.py +++ b/api/onnx_web/diffusers/run.py @@ -30,6 +30,14 @@ logger = getLogger(__name__) +def parse_prompt(params: ImageParams) -> Tuple[List[Tuple[str, float]], List[Tuple[str, float]]]: + prompt, loras = get_loras_from_prompt(params.input_prompt) + prompt, inversions = get_inversions_from_prompt(prompt) + params.prompt = prompt + + return loras, inversions + + def run_highres( job: WorkerContext, server: ServerContext, @@ -164,10 +172,7 @@ def run_txt2img_pipeline( highres: HighresParams, ) -> None: latents = get_latents_from_seed(params.seed, size, batch=params.batch) - - (prompt, loras) = get_loras_from_prompt(params.prompt) - (prompt, inversions) = get_inversions_from_prompt(prompt) - params.prompt = prompt + loras, inversions = parse_prompt(params) pipe_type = "lpw" if params.lpw() else "txt2img" pipe = load_pipeline( @@ -260,9 +265,7 @@ def run_img2img_pipeline( strength: float, source_filter: Optional[str] = None, ) -> None: - (prompt, loras) = get_loras_from_prompt(params.prompt) - (prompt, inversions) = get_inversions_from_prompt(prompt) - params.prompt = prompt + loras, inversions = parse_prompt(params) # filter the source image if source_filter is not None: @@ -376,9 +379,7 @@ def run_inpaint_pipeline( progress = job.get_progress_callback() stage = StageParams(tile_order=tile_order) - (prompt, loras) = get_loras_from_prompt(params.prompt) - (prompt, inversions) = get_inversions_from_prompt(prompt) - params.prompt = prompt + loras, inversions = parse_prompt(params) # calling the upscale_outpaint stage directly needs accumulating progress progress = ChainProgress.from_progress(progress) @@ -444,9 +445,7 @@ def run_upscale_pipeline( progress = job.get_progress_callback() stage = StageParams() - (prompt, loras) = get_loras_from_prompt(params.prompt) - (prompt, inversions) = get_inversions_from_prompt(prompt) - params.prompt = prompt + loras, inversions = parse_prompt(params) image = run_upscale_correction( job, server, stage, params, source, upscale=upscale, callback=progress diff --git a/api/onnx_web/params.py b/api/onnx_web/params.py index ddfdf4ac4..ac2c06455 100644 --- a/api/onnx_web/params.py +++ b/api/onnx_web/params.py @@ -173,6 +173,7 @@ class ImageParams: eta: float batch: int control: Optional[NetworkModel] + input_prompt: str def __init__( self, @@ -187,6 +188,7 @@ def __init__( eta: float = 0.0, batch: int = 1, control: Optional[NetworkModel] = None, + input_prompt: Optional[str] = None, ) -> None: self.model = model self.pipeline = pipeline @@ -199,6 +201,7 @@ def __init__( self.eta = eta self.batch = batch self.control = control + self.input_prompt = input_prompt or prompt def lpw(self): return self.pipeline == "lpw" @@ -216,6 +219,7 @@ def tojson(self) -> Dict[str, Optional[Param]]: "eta": self.eta, "batch": self.batch, "control": self.control.name if self.control is not None else "", + "input_prompt": self.input_prompt, } def with_args(self, **kwargs): @@ -231,6 +235,7 @@ def with_args(self, **kwargs): kwargs.get("eta", self.eta), kwargs.get("batch", self.batch), kwargs.get("control", self.control), + kwargs.get("input_prompt", self.input_prompt), )