From fc6d018f54f6fc816b2379ae1a62e16f27994d09 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Fri, 4 Oct 2024 13:20:51 -0500 Subject: [PATCH] Decompose VAE for cpu --- models/turbine_models/tests/sdxl_test.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/models/turbine_models/tests/sdxl_test.py b/models/turbine_models/tests/sdxl_test.py index 3868f919..05abc70f 100644 --- a/models/turbine_models/tests/sdxl_test.py +++ b/models/turbine_models/tests/sdxl_test.py @@ -93,7 +93,11 @@ def test00_sdxl_pipe(self): decomp_attn = { "text_encoder": True, "unet": False, - "vae": False, + "vae": ( + False + if any(x in arguments["device"] for x in ["hip", "rocm"]) + else True + ), } self.pipe = SharkSDPipeline( arguments["hf_model_name"], @@ -377,7 +381,7 @@ def test04_ExportVaeModelDecode(self): "bs" + str(arguments["batch_size"]), str(arguments["height"]) + "x" + str(arguments["width"]), arguments["precision"], - "vae", + "vae" if arguments["device"] != "cpu" else "vae_decomp_attn", arguments["iree_target_triple"], ] )