From 18e8a4100b61adfd9425dd32f780dc5f90017813 Mon Sep 17 00:00:00 2001 From: Avinash Sharma Date: Wed, 27 Dec 2023 15:23:14 -0800 Subject: [PATCH] Vae encode example with test (#294) --- python/shark_turbine/dynamo/passes.py | 56 +---------- python/shark_turbine/dynamo/utils.py | 92 +++++++++++++++++++ .../custom_models/sd_inference/vae.py | 16 +++- .../custom_models/sd_inference/vae_runner.py | 29 ++++-- python/turbine_models/tests/sd_test.py | 53 ++++++++++- 5 files changed, 181 insertions(+), 65 deletions(-) create mode 100644 python/shark_turbine/dynamo/utils.py diff --git a/python/shark_turbine/dynamo/passes.py b/python/shark_turbine/dynamo/passes.py index 91ea40211..88c08f6ad 100644 --- a/python/shark_turbine/dynamo/passes.py +++ b/python/shark_turbine/dynamo/passes.py @@ -1,9 +1,9 @@ import torch from torch.fx.experimental.proxy_tensor import make_fx -from torch._decomp import get_decompositions, register_decomposition +from torch._decomp import get_decompositions +from shark_turbine.dynamo import utils from torch.func import functionalize -from torch import Tensor -from typing import Dict, List, Tuple +from typing import List # default decompositions pulled from SHARK / torch._decomp DEFAULT_DECOMPOSITIONS = [ @@ -53,56 +53,6 @@ ] -@register_decomposition(torch.ops.aten._scaled_dot_product_flash_attention.default) -def scaled_dot_product_flash_attention( - query, - key, - value, - dropout_p: float = 0.0, - is_causal: bool = False, - return_debug_mask: bool = False, - *, - scale: float = None, -) -> Tuple[Tensor, Tensor, Tensor, Tensor, int, int, Tensor, Tensor, Tensor]: - dtype = query.dtype - batchSize, num_head, qSize, headSize = ( - query.shape[0], - query.shape[1], - query.shape[2], - query.shape[3], - ) - - logsumexp = torch.empty([batchSize, qSize, num_head, headSize], dtype=torch.float) - cum_seq_q, cum_seq_k = torch.empty([], dtype=torch.long), torch.empty( - [], dtype=torch.long - ) - max_q, max_k = 0, 0 - philox_seed, philox_offset = torch.empty([], dtype=torch.long), torch.empty( - [], dtype=torch.long - ) - debug_attn_mask = torch.empty( - [], - dtype=query.dtype, - device="cpu", - requires_grad=query.requires_grad, - ) - output, _ = torch.ops.aten._scaled_dot_product_attention_math.default( - query, key, value, None, dropout_p, is_causal, None, scale=scale - ) - output = output.transpose(1, 2).contiguous(memory_format=torch.contiguous_format) - return ( - output.transpose(1, 2), - logsumexp, - cum_seq_q, - cum_seq_k, - max_q, - max_k, - philox_seed, - philox_offset, - debug_attn_mask, - ) - - def apply_decompositions( gm: torch.fx.GraphModule, example_inputs, diff --git a/python/shark_turbine/dynamo/utils.py b/python/shark_turbine/dynamo/utils.py new file mode 100644 index 000000000..6429c2444 --- /dev/null +++ b/python/shark_turbine/dynamo/utils.py @@ -0,0 +1,92 @@ +import torch +from torch._prims_common.wrappers import out_wrapper +from torch._prims_common import ( + DeviceLikeType, + TensorLikeType, +) +import torch._refs as _refs +from torch._decomp import get_decompositions, register_decomposition +from torch import Tensor +from typing import Dict, List, Tuple, Optional + + +@register_decomposition(torch.ops.aten._scaled_dot_product_flash_attention.default) +def scaled_dot_product_flash_attention( + query, + key, + value, + dropout_p: float = 0.0, + is_causal: bool = False, + return_debug_mask: bool = False, + *, + scale: float = None, +) -> Tuple[Tensor, Tensor, Tensor, Tensor, int, int, Tensor, Tensor, Tensor]: + dtype = query.dtype + batchSize, num_head, qSize, headSize = ( + query.shape[0], + query.shape[1], + query.shape[2], + query.shape[3], + ) + + logsumexp = torch.empty([batchSize, qSize, num_head, headSize], dtype=torch.float) + cum_seq_q, cum_seq_k = torch.empty([], dtype=torch.long), torch.empty( + [], dtype=torch.long + ) + max_q, max_k = 0, 0 + philox_seed, philox_offset = torch.empty([], dtype=torch.long), torch.empty( + [], dtype=torch.long + ) + debug_attn_mask = torch.empty( + [], + dtype=query.dtype, + device="cpu", + requires_grad=query.requires_grad, + ) + output, _ = torch.ops.aten._scaled_dot_product_attention_math.default( + query, key, value, None, dropout_p, is_causal, None, scale=scale + ) + output = output.transpose(1, 2).contiguous(memory_format=torch.contiguous_format) + return ( + output.transpose(1, 2), + logsumexp, + cum_seq_q, + cum_seq_k, + max_q, + max_k, + philox_seed, + philox_offset, + debug_attn_mask, + ) + + +# manually add decomposition to bypass the error that comes +# from VAE encode(inp).latent_dist.sample() failing to symbolically +# trace from torch fx. +# Expected Torch stable version: > 2.1.0 +# diffusers side issue: https://github.com/huggingface/diffusers/issues/6239 +# temporary Torch fix: https://github.com/pytorch/pytorch/issues/107170 +@register_decomposition(torch.ops.aten.randn.generator) +@out_wrapper() +def randn_generator( + *shape, + generator: Optional[torch.Generator] = None, + dtype: Optional[torch.dtype] = None, + device: Optional[DeviceLikeType] = None, + layout: Optional[torch.layout] = None, + requires_grad: bool = False, + pin_memory: bool = False, +) -> TensorLikeType: + # We should eventually support the generator overload. + # However, if someone passes in a None generator explicitly, + # we can jut fall back to randn.default + if generator is None: + return _refs.randn( + *shape, + dtype=dtype, + device=device, + layout=layout, + requires_grad=requires_grad, + pin_memory=pin_memory, + ) + return NotImplemented diff --git a/python/turbine_models/custom_models/sd_inference/vae.py b/python/turbine_models/custom_models/sd_inference/vae.py index 50a788f64..03ef85556 100644 --- a/python/turbine_models/custom_models/sd_inference/vae.py +++ b/python/turbine_models/custom_models/sd_inference/vae.py @@ -53,6 +53,7 @@ help="Specify vulkan target triple or rocm/cuda target device.", ) parser.add_argument("--vulkan_max_allocation", type=str, default="4294967296") +parser.add_argument("--variant", type=str, default="decode") class VaeModel(torch.nn.Module): @@ -64,11 +65,15 @@ def __init__(self, hf_model_name, hf_auth_token): token=hf_auth_token, ) - def forward(self, inp): + def decode_inp(self, inp): with torch.no_grad(): x = self.vae.decode(inp, return_dict=False)[0] return x + def encode_inp(self, inp): + latents = self.vae.encode(inp).latent_dist.sample() + return 0.18215 * latents + def export_vae_model( vae_model, @@ -83,6 +88,7 @@ def export_vae_model( device=None, target_triple=None, max_alloc=None, + variant="decode", ): mapper = {} utils.save_external_weights( @@ -90,12 +96,17 @@ def export_vae_model( ) sample = (batch_size, 4, height // 8, width // 8) + if variant == "encode": + sample = (batch_size, 3, height, width) class CompiledVae(CompiledModule): params = export_parameters(vae_model) def main(self, inp=AbstractTensor(*sample, dtype=torch.float32)): - return jittable(vae_model.forward)(inp) + if variant == "decode": + return jittable(vae_model.decode_inp)(inp) + elif variant == "encode": + return jittable(vae_model.encode_inp)(inp) import_to = "INPUT" if compile_to == "linalg" else "IMPORT" inst = CompiledVae(context=Context(), import_to=import_to) @@ -127,6 +138,7 @@ def main(self, inp=AbstractTensor(*sample, dtype=torch.float32)): args.device, args.iree_target_triple, args.vulkan_max_allocation, + args.variant, ) safe_name = utils.create_safe_name(args.hf_model_name, "-vae") with open(f"{safe_name}.mlir", "w+") as f: diff --git a/python/turbine_models/custom_models/sd_inference/vae_runner.py b/python/turbine_models/custom_models/sd_inference/vae_runner.py index 77b196ac0..77acaedcb 100644 --- a/python/turbine_models/custom_models/sd_inference/vae_runner.py +++ b/python/turbine_models/custom_models/sd_inference/vae_runner.py @@ -45,6 +45,7 @@ "--height", type=int, default=512, help="Height of Stable Diffusion" ) parser.add_argument("--width", type=int, default=512, help="Width of Stable Diffusion") +parser.add_argument("--variant", type=str, default="decode") def run_vae( @@ -57,7 +58,7 @@ def run_vae( return results -def run_torch_vae(hf_model_name, hf_auth_token, example_input): +def run_torch_vae(hf_model_name, hf_auth_token, variant, example_input): from diffusers import AutoencoderKL class VaeModel(torch.nn.Module): @@ -69,26 +70,38 @@ def __init__(self, hf_model_name, hf_auth_token): token=hf_auth_token, ) - def forward(self, inp): + def decode_inp(self, inp): with torch.no_grad(): x = self.vae.decode(inp, return_dict=False)[0] return x + def encode_inp(self, inp): + latents = self.vae.encode(inp).latent_dist.sample() + return 0.18215 * latents + vae_model = VaeModel( hf_model_name, hf_auth_token, ) - results = vae_model.forward(example_input) + if variant == "decode": + results = vae_model.decode_inp(example_input) + elif variant == "encode": + results = vae_model.encode_inp(example_input) np_torch_output = results.detach().cpu().numpy() return np_torch_output if __name__ == "__main__": args = parser.parse_args() - example_input = torch.rand( - args.batch_size, 4, args.height // 8, args.width // 8, dtype=torch.float32 - ) + if args.variant == "decode": + example_input = torch.rand( + args.batch_size, 4, args.height // 8, args.width // 8, dtype=torch.float32 + ) + elif args.variant == "encode": + example_input = torch.rand( + args.batch_size, 3, args.height, args.width, dtype=torch.float32 + ) print("generating turbine output:") turbine_results = run_vae( args.device, @@ -109,12 +122,12 @@ def forward(self, inp): from turbine_models.custom_models.sd_inference import utils torch_output = run_torch_vae( - args.hf_model_name, args.hf_auth_token, example_input + args.hf_model_name, args.hf_auth_token, args.variant, example_input ) print("TORCH OUTPUT:", torch_output, torch_output.shape, torch_output.dtype) err = utils.largest_error(torch_output, turbine_results) print("Largest Error: ", err) - assert err < 9e-5 + assert err < 2e-3 # TODO: Figure out why we occasionally segfault without unlinking output variables turbine_results = None diff --git a/python/turbine_models/tests/sd_test.py b/python/turbine_models/tests/sd_test.py index b8dca64f5..125f97d82 100644 --- a/python/turbine_models/tests/sd_test.py +++ b/python/turbine_models/tests/sd_test.py @@ -134,7 +134,7 @@ def testExportUnetModel(self): os.remove("stable_diffusion_v1_4_unet.safetensors") os.remove("stable_diffusion_v1_4_unet.vmfb") - def testExportVaeModel(self): + def testExportVaeModelDecode(self): with self.assertRaises(SystemExit) as cm: vae.export_vae_model( vae_model, @@ -148,6 +148,7 @@ def testExportVaeModel(self): "safetensors", "stable_diffusion_v1_4_vae.safetensors", "cpu", + variant="decode", ) self.assertEqual(cm.exception.code, None) arguments["external_weight_path"] = "stable_diffusion_v1_4_vae.safetensors" @@ -168,13 +169,61 @@ def testExportVaeModel(self): arguments["external_weight_path"], ) torch_output = vae_runner.run_torch_vae( - arguments["hf_model_name"], arguments["hf_auth_token"], example_input + arguments["hf_model_name"], + arguments["hf_auth_token"], + "decode", + example_input, ) err = utils.largest_error(torch_output, turbine) assert err < 9e-5 os.remove("stable_diffusion_v1_4_vae.safetensors") os.remove("stable_diffusion_v1_4_vae.vmfb") + def testExportVaeModelEncode(self): + with self.assertRaises(SystemExit) as cm: + vae.export_vae_model( + vae_model, + # This is a public model, so no auth required + "CompVis/stable-diffusion-v1-4", + arguments["batch_size"], + arguments["height"], + arguments["width"], + None, + "vmfb", + "safetensors", + "stable_diffusion_v1_4_vae.safetensors", + "cpu", + variant="encode", + ) + self.assertEqual(cm.exception.code, None) + arguments["external_weight_path"] = "stable_diffusion_v1_4_vae.safetensors" + arguments["vmfb_path"] = "stable_diffusion_v1_4_vae.vmfb" + example_input = torch.rand( + arguments["batch_size"], + 3, + arguments["height"], + arguments["width"], + dtype=torch.float32, + ) + turbine = vae_runner.run_vae( + arguments["device"], + example_input, + arguments["vmfb_path"], + arguments["hf_model_name"], + arguments["hf_auth_token"], + arguments["external_weight_path"], + ) + torch_output = vae_runner.run_torch_vae( + arguments["hf_model_name"], + arguments["hf_auth_token"], + "encode", + example_input, + ) + err = utils.largest_error(torch_output, turbine) + assert err < 2e-3 + os.remove("stable_diffusion_v1_4_vae.safetensors") + os.remove("stable_diffusion_v1_4_vae.vmfb") + if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG)