From 6e7f202b8df50c6a1d19e00802f9b85dcc188c59 Mon Sep 17 00:00:00 2001
From: Sean Sube <seansube@gmail.com>
Date: Sat, 22 Apr 2023 00:11:33 -0500
Subject: [PATCH] fix(api): store both pre-parse and parsed prompts (#320)

---
 api/onnx_web/diffusers/run.py | 25 ++++++++++++-------------
 api/onnx_web/params.py        |  5 +++++
 2 files changed, 17 insertions(+), 13 deletions(-)

diff --git a/api/onnx_web/diffusers/run.py b/api/onnx_web/diffusers/run.py
index 1f13b59a1..125878155 100644
--- a/api/onnx_web/diffusers/run.py
+++ b/api/onnx_web/diffusers/run.py
@@ -30,6 +30,14 @@
 logger = getLogger(__name__)
 
 
+def parse_prompt(params: ImageParams) -> Tuple[List[Tuple[str, float]], List[Tuple[str, float]]]:
+    prompt, loras = get_loras_from_prompt(params.input_prompt)
+    prompt, inversions = get_inversions_from_prompt(prompt)
+    params.prompt = prompt
+
+    return loras, inversions
+
+
 def run_highres(
     job: WorkerContext,
     server: ServerContext,
@@ -164,10 +172,7 @@ def run_txt2img_pipeline(
     highres: HighresParams,
 ) -> None:
     latents = get_latents_from_seed(params.seed, size, batch=params.batch)
-
-    (prompt, loras) = get_loras_from_prompt(params.prompt)
-    (prompt, inversions) = get_inversions_from_prompt(prompt)
-    params.prompt = prompt
+    loras, inversions = parse_prompt(params)
 
     pipe_type = "lpw" if params.lpw() else "txt2img"
     pipe = load_pipeline(
@@ -260,9 +265,7 @@ def run_img2img_pipeline(
     strength: float,
     source_filter: Optional[str] = None,
 ) -> None:
-    (prompt, loras) = get_loras_from_prompt(params.prompt)
-    (prompt, inversions) = get_inversions_from_prompt(prompt)
-    params.prompt = prompt
+    loras, inversions = parse_prompt(params)
 
     # filter the source image
     if source_filter is not None:
@@ -376,9 +379,7 @@ def run_inpaint_pipeline(
     progress = job.get_progress_callback()
     stage = StageParams(tile_order=tile_order)
 
-    (prompt, loras) = get_loras_from_prompt(params.prompt)
-    (prompt, inversions) = get_inversions_from_prompt(prompt)
-    params.prompt = prompt
+    loras, inversions = parse_prompt(params)
 
     # calling the upscale_outpaint stage directly needs accumulating progress
     progress = ChainProgress.from_progress(progress)
@@ -444,9 +445,7 @@ def run_upscale_pipeline(
     progress = job.get_progress_callback()
     stage = StageParams()
 
-    (prompt, loras) = get_loras_from_prompt(params.prompt)
-    (prompt, inversions) = get_inversions_from_prompt(prompt)
-    params.prompt = prompt
+    loras, inversions = parse_prompt(params)
 
     image = run_upscale_correction(
         job, server, stage, params, source, upscale=upscale, callback=progress
diff --git a/api/onnx_web/params.py b/api/onnx_web/params.py
index ddfdf4ac4..ac2c06455 100644
--- a/api/onnx_web/params.py
+++ b/api/onnx_web/params.py
@@ -173,6 +173,7 @@ class ImageParams:
     eta: float
     batch: int
     control: Optional[NetworkModel]
+    input_prompt: str
 
     def __init__(
         self,
@@ -187,6 +188,7 @@ def __init__(
         eta: float = 0.0,
         batch: int = 1,
         control: Optional[NetworkModel] = None,
+        input_prompt: Optional[str] = None,
     ) -> None:
         self.model = model
         self.pipeline = pipeline
@@ -199,6 +201,7 @@ def __init__(
         self.eta = eta
         self.batch = batch
         self.control = control
+        self.input_prompt = input_prompt or prompt
 
     def lpw(self):
         return self.pipeline == "lpw"
@@ -216,6 +219,7 @@ def tojson(self) -> Dict[str, Optional[Param]]:
             "eta": self.eta,
             "batch": self.batch,
             "control": self.control.name if self.control is not None else "",
+            "input_prompt": self.input_prompt,
         }
 
     def with_args(self, **kwargs):
@@ -231,6 +235,7 @@ def with_args(self, **kwargs):
             kwargs.get("eta", self.eta),
             kwargs.get("batch", self.batch),
             kwargs.get("control", self.control),
+            kwargs.get("input_prompt", self.input_prompt),
         )