compiled_pipeline general support and split inference methods
eagarvey-amd committed Jul 18, 2024
1 parent 02705a9 commit c1e9195
Showing 3 changed files with 173 additions and 43 deletions.
12 changes: 9 additions & 3 deletions models/turbine_models/custom_models/pipeline_base.py
@@ -670,7 +670,8 @@ def export_submodel(
self.map[submodel]["export_args"]["precision"],
self.map[submodel]["export_args"]["batch_size"],
self.map[submodel]["export_args"]["max_length"],
"tokens_to_image",
"produce_img_split",
unet_module_name = self.map["unet"]["module_name"],
)
dims = [
self.map[submodel]["export_args"]["width"],
@@ -699,8 +700,8 @@ def export_submodel(
return_path=True,
mlir_source="str",
)
self.map[submodel]["vmfb"] = vmfb_path
self.map[submodel]["weights"] = None
self.map[submodel]["vmfb"] = [vmfb_path]
self.map[submodel]["weights"] = []
case _:
export_args = self.map[submodel].get("export_args", {})
if weights_only:
@@ -725,6 +726,11 @@ def load_map(self):
if not self.map[submodel]["load"]:
self.printer.print(f"Skipping load for {submodel}")
continue
+     elif self.map[submodel].get("wraps"):
+         for wrapped in self.map[submodel]["wraps"]:
+             self.map[submodel]["vmfb"].append(self.map[wrapped]["vmfb"])
+             self.map[submodel]["weights"].append(self.map[wrapped]["weights"])

self.load_submodel(submodel)

def load_submodel(self, submodel):
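In effect, load_map now lets a wrapper module declare the submodels it drives, and their compiled artifacts are folded into the wrapper's own lists before loading. A minimal sketch of that aggregation, using a hypothetical map (the entries and file names are placeholders, not the real map):

pipe_map = {
    "unet": {"vmfb": ["unet.vmfb"], "weights": ["unet.irpa"]},
    "scheduler": {"vmfb": ["scheduler.vmfb"], "weights": [None]},
    "fullpipeline": {
        "vmfb": ["sdxl_compiled_pipeline.vmfb"],
        "weights": [],
        "wraps": ["unet", "scheduler"],
    },
}

for submodel, entry in pipe_map.items():
    if entry.get("wraps"):
        for wrapped in entry["wraps"]:
            # As in the diff: append each wrapped submodel's artifacts so the
            # wrapper's loader receives every vmfb and weight file it needs.
            entry["vmfb"].append(pipe_map[wrapped]["vmfb"])
            entry["weights"].append(pipe_map[wrapped]["weights"])

This is also why export_submodel above now stores vmfb as [vmfb_path] and weights as []: both fields must be lists so wrapped artifacts can be appended.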
145 changes: 110 additions & 35 deletions models/turbine_models/custom_models/sd_inference/sd_pipeline.py
@@ -118,23 +118,11 @@
"decomp_attn": None,
},
},
"unetloop": {
"module_name": "sdxl_compiled_pipeline",
"load": False,
"keywords": ["unetloop"],
"wraps": ["unet", "scheduler"],
"export_args": {
"batch_size": 1,
"height": 1024,
"width": 1024,
"max_length": 64,
},
},
"fullpipeline": {
"module_name": "sdxl_compiled_pipeline",
"load": False,
"load": True,
"keywords": ["fullpipeline"],
"wraps": ["text_encoder", "unet", "scheduler", "vae"],
"wraps": ["unet", "scheduler", "vae"],
"export_args": {
"batch_size": 1,
"height": 1024,
@@ -234,6 +222,7 @@ def __init__(
benchmark: bool | dict[bool] = False,
verbose: bool = False,
batch_prompts: bool = False,
+         compiled_pipeline: bool = False,
):
common_export_args = {
"hf_model_name": None,
@@ -312,6 +301,7 @@ def __init__(
self.scheduler = None

self.split_scheduler = True
+         self.compiled_pipeline = compiled_pipeline

self.base_model_name = (
hf_model_name
@@ -322,11 +312,6 @@
self.is_sdxl = "xl" in self.base_model_name.lower()
self.is_sd3 = "stable-diffusion-3" in self.base_model_name
if self.is_sdxl:
-             if self.split_scheduler:
-                 if self.map.get("unetloop"):
-                     self.map.pop("unetloop")
-                 if self.map.get("fullpipeline"):
-                     self.map.pop("fullpipeline")
self.tokenizers = [
CLIPTokenizer.from_pretrained(
self.base_model_name, subfolder="tokenizer"
@@ -340,6 +325,20 @@
self.scheduler_device = self.map["unet"]["device"]
self.scheduler_driver = self.map["unet"]["driver"]
self.scheduler_target = self.map["unet"]["target"]
+             if not self.compiled_pipeline:
+                 if self.map.get("unetloop"):
+                     self.map.pop("unetloop")
+                 if self.map.get("fullpipeline"):
+                     self.map.pop("fullpipeline")
+             elif self.compiled_pipeline:
+                 self.map["unet"]["load"] = False
+                 self.map["vae"]["load"] = False
+                 self.load_scheduler(
+                     scheduler_id,
+                     num_inference_steps,
+                 )
+                 self.map["scheduler"]["runner"].unload()
+                 self.map["scheduler"]["load"] = False
elif not self.is_sd3:
self.tokenizer = CLIPTokenizer.from_pretrained(
self.base_model_name, subfolder="tokenizer"
@@ -381,10 +380,6 @@ def load_scheduler(
scheduler_id: str,
steps: int = 30,
):
-         if self.is_sd3:
-             scheduler_device = self.mmdit.device
-         else:
-             scheduler_device = self.unet.device
if not self.cpu_scheduling:
self.map["scheduler"] = {
"module_name": "compiled_scheduler",
@@ -430,7 +425,11 @@ def load_scheduler(
except:
print("JIT export of scheduler failed. Loading CPU scheduler.")
self.cpu_scheduling = True
-         if self.cpu_scheduling:
+         elif self.cpu_scheduling:
+             if self.is_sd3:
+                 scheduler_device = self.mmdit.device
+             else:
+                 scheduler_device = self.unet.device
scheduler = schedulers.get_scheduler(self.base_model_name, scheduler_id)
self.scheduler = schedulers.SharkSchedulerCPUWrapper(
scheduler,
@@ -615,6 +614,72 @@ def _produce_latents_sdxl(
latents = self.scheduler("run_step", [noise_pred, t, latents])
return latents

+     def produce_images_compiled(
+         self,
+         sample,
+         prompt_embeds,
+         text_embeds,
+         guidance_scale,
+     ):
+         pipe_inputs = [
+             sample,
+             prompt_embeds,
+             text_embeds,
+             guidance_scale,
+         ]
+         image = self.compiled_pipeline("produce_img_latents", pipe_inputs)
+         return image

+     def prepare_sampling_inputs(
+         self,
+         prompt: str,
+         negative_prompt: str = "",
+         steps: int = 30,
+         batch_count: int = 1,
+         guidance_scale: float = 7.5,
+         seed: float = -1,
+         cpu_scheduling: bool = True,
+         scheduler_id: str = "EulerDiscrete",
+         return_imgs: bool = False,
+     ):
+         needs_new_scheduler = (
+             (steps and steps != self.num_inference_steps)
+             or (cpu_scheduling != self.cpu_scheduling)
+             and self.split_scheduler
+         )
+         if not self.scheduler and not self.compiled_pipeline:
+             needs_new_scheduler = True
+
+         if guidance_scale == 0:
+             negative_prompt = prompt
+             prompt = ""
+
+         self.cpu_scheduling = cpu_scheduling
+         if steps and needs_new_scheduler:
+             self.num_inference_steps = steps
+             self.load_scheduler(scheduler_id, steps)
+
+         pipe_start = time.time()
+         numpy_images = []
+
+         samples = self.get_rand_latents(seed, batch_count)
+
+         # Encode the prompt and negative prompt.
+         if self.is_sdxl:
+             prompt_embeds, negative_embeds = self.encode_prompts_sdxl(
+                 prompt, negative_prompt
+             )
+         else:
+             prompt_embeds, negative_embeds = encode_prompt(
+                 self, prompt, negative_prompt
+             )
+         produce_latents_input = [
+             samples[0],
+             prompt_embeds,
+             negative_embeds,
+             steps,
+             guidance_scale,
+         ]
+         return produce_latents_input
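
prepare_sampling_inputs mirrors the front half of generate_images, so a caller can build embeddings and latents once and then drive the compiled module directly. A hedged usage sketch (the pipe object and prompt are hypothetical; the method names follow the diff):

# Assumes an SDXL pipeline constructed with compiled_pipeline=True.
sample, prompt_embeds, negative_embeds, steps, guidance_scale = (
    pipe.prepare_sampling_inputs(
        prompt="a photo of a corgi",
        negative_prompt="blurry",
        steps=30,
        guidance_scale=7.5,
    )
)
# Scheduling, denoising, and decode all happen inside the compiled vmfb.
image = pipe.produce_images_compiled(
    sample, prompt_embeds, negative_embeds, guidance_scale
)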

def generate_images(
self,
prompt: str,
@@ -660,18 +725,26 @@ def generate_images(
)

for i in range(batch_count):
-             produce_latents_input = [
-                 samples[i],
-                 prompt_embeds,
-                 negative_embeds,
-                 steps,
-                 guidance_scale,
-             ]
-             if self.is_sdxl:
-                 latents = self._produce_latents_sdxl(*produce_latents_input)
+             if self.compiled_pipeline:
+                 image = self.produce_images_compiled(
+                     samples[i],
+                     prompt_embeds,
+                     negative_embeds,
+                     guidance_scale,
+                 )
else:
-                 latents = self._produce_latents_sd(*produce_latents_input)
-             image = self.vae("decode", [latents])
+                 produce_latents_input = [
+                     samples[i],
+                     prompt_embeds,
+                     negative_embeds,
+                     steps,
+                     guidance_scale,
+                 ]
+                 if self.is_sdxl:
+                     latents = self._produce_latents_sdxl(*produce_latents_input)
+                 else:
+                     latents = self._produce_latents_sd(*produce_latents_input)
+                 image = self.vae("decode", [latents])
numpy_images.append(image)
pipe_end = time.time()

@@ -757,6 +830,8 @@ def numpy_to_pil_image(images):
args.use_i8_punet,
benchmark,
args.verbose,
+         False,  # batch_prompts
+         args.compiled_pipeline,
)
sd_pipe.prepare_all()
sd_pipe.load_map()
(Diff for the third changed file did not render.)
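
For reference, a hedged end-to-end sketch of the new flag in use; the class name SharkSDPipeline and the keyword names are assumptions drawn from this file's __init__ and CLI path, not verbatim from the diff:

pipe = SharkSDPipeline(
    hf_model_name="stabilityai/stable-diffusion-xl-base-1.0",
    compiled_pipeline=True,  # keep the fullpipeline wrapper in the map
)
pipe.prepare_all()  # exports submodels plus the sdxl_compiled_pipeline module
pipe.load_map()     # the wrapper appends unet/scheduler/vae vmfbs before load
images = pipe.generate_images("a photo of a corgi", steps=30)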
