Vae encode example with test (#294)
aviator19941 authored Dec 27, 2023
1 parent 406b523 commit 18e8a41
Showing 5 changed files with 181 additions and 65 deletions.
56 changes: 3 additions & 53 deletions python/shark_turbine/dynamo/passes.py
@@ -1,9 +1,9 @@
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._decomp import get_decompositions, register_decomposition
from torch._decomp import get_decompositions
from shark_turbine.dynamo import utils
from torch.func import functionalize
from torch import Tensor
from typing import Dict, List, Tuple
from typing import List

# default decompositions pulled from SHARK / torch._decomp
DEFAULT_DECOMPOSITIONS = [
@@ -53,56 +53,6 @@
]


@register_decomposition(torch.ops.aten._scaled_dot_product_flash_attention.default)
def scaled_dot_product_flash_attention(
query,
key,
value,
dropout_p: float = 0.0,
is_causal: bool = False,
return_debug_mask: bool = False,
*,
scale: float = None,
) -> Tuple[Tensor, Tensor, Tensor, Tensor, int, int, Tensor, Tensor, Tensor]:
dtype = query.dtype
batchSize, num_head, qSize, headSize = (
query.shape[0],
query.shape[1],
query.shape[2],
query.shape[3],
)

logsumexp = torch.empty([batchSize, qSize, num_head, headSize], dtype=torch.float)
cum_seq_q, cum_seq_k = torch.empty([], dtype=torch.long), torch.empty(
[], dtype=torch.long
)
max_q, max_k = 0, 0
philox_seed, philox_offset = torch.empty([], dtype=torch.long), torch.empty(
[], dtype=torch.long
)
debug_attn_mask = torch.empty(
[],
dtype=query.dtype,
device="cpu",
requires_grad=query.requires_grad,
)
output, _ = torch.ops.aten._scaled_dot_product_attention_math.default(
query, key, value, None, dropout_p, is_causal, None, scale=scale
)
output = output.transpose(1, 2).contiguous(memory_format=torch.contiguous_format)
return (
output.transpose(1, 2),
logsumexp,
cum_seq_q,
cum_seq_k,
max_q,
max_k,
philox_seed,
philox_offset,
debug_attn_mask,
)


def apply_decompositions(
gm: torch.fx.GraphModule,
example_inputs,
92 changes: 92 additions & 0 deletions python/shark_turbine/dynamo/utils.py
@@ -0,0 +1,92 @@
import torch
from torch._prims_common.wrappers import out_wrapper
from torch._prims_common import (
DeviceLikeType,
TensorLikeType,
)
import torch._refs as _refs
from torch._decomp import get_decompositions, register_decomposition
from torch import Tensor
from typing import Dict, List, Tuple, Optional


@register_decomposition(torch.ops.aten._scaled_dot_product_flash_attention.default)
def scaled_dot_product_flash_attention(
query,
key,
value,
dropout_p: float = 0.0,
is_causal: bool = False,
return_debug_mask: bool = False,
*,
scale: float = None,
) -> Tuple[Tensor, Tensor, Tensor, Tensor, int, int, Tensor, Tensor, Tensor]:
dtype = query.dtype
batchSize, num_head, qSize, headSize = (
query.shape[0],
query.shape[1],
query.shape[2],
query.shape[3],
)

logsumexp = torch.empty([batchSize, qSize, num_head, headSize], dtype=torch.float)
cum_seq_q, cum_seq_k = torch.empty([], dtype=torch.long), torch.empty(
[], dtype=torch.long
)
max_q, max_k = 0, 0
philox_seed, philox_offset = torch.empty([], dtype=torch.long), torch.empty(
[], dtype=torch.long
)
debug_attn_mask = torch.empty(
[],
dtype=query.dtype,
device="cpu",
requires_grad=query.requires_grad,
)
output, _ = torch.ops.aten._scaled_dot_product_attention_math.default(
query, key, value, None, dropout_p, is_causal, None, scale=scale
)
output = output.transpose(1, 2).contiguous(memory_format=torch.contiguous_format)
return (
output.transpose(1, 2),
logsumexp,
cum_seq_q,
cum_seq_k,
max_q,
max_k,
philox_seed,
philox_offset,
debug_attn_mask,
)


# Manually add a decomposition to bypass the error that comes
# from VAE encode(inp).latent_dist.sample() failing to symbolically
# trace through torch.fx.
# Expected Torch stable version: > 2.1.0
# diffusers side issue: https://github.com/huggingface/diffusers/issues/6239
# temporary Torch fix: https://github.com/pytorch/pytorch/issues/107170
@register_decomposition(torch.ops.aten.randn.generator)
@out_wrapper()
def randn_generator(
*shape,
generator: Optional[torch.Generator] = None,
dtype: Optional[torch.dtype] = None,
device: Optional[DeviceLikeType] = None,
layout: Optional[torch.layout] = None,
requires_grad: bool = False,
pin_memory: bool = False,
) -> TensorLikeType:
# We should eventually support the generator overload.
# However, if someone passes in a None generator explicitly,
# we can just fall back to randn.default
if generator is None:
return _refs.randn(
*shape,
dtype=dtype,
device=device,
layout=layout,
requires_grad=requires_grad,
pin_memory=pin_memory,
)
return NotImplemented
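
For context, a minimal sketch of how a registration like the one above is typically consumed (not part of this commit; the sample_like helper is hypothetical, and the real call comes from diffusers' latent_dist.sample()). The decomposition is looked up with get_decompositions and handed to make_fx; depending on the PyTorch version, the randn call may lower to aten.randn.generator (which this decomposition rewrites to randn.default) or dispatch to randn.default directly.

import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._decomp import get_decompositions
from shark_turbine.dynamo import utils  # importing runs the @register_decomposition calls above

def sample_like(mean, std):
    # Hypothetical stand-in for latent_dist.sample(): noise drawn through the
    # generator-kwarg path that randn_generator is meant to handle.
    noise = torch.randn(mean.shape, generator=None, dtype=mean.dtype, device=mean.device)
    return mean + std * noise

table = get_decompositions([torch.ops.aten.randn.generator])
traced = make_fx(sample_like, decomposition_table=table)(
    torch.zeros(1, 4, 64, 64), torch.ones(1, 4, 64, 64)
)
print(traced.graph)
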
16 changes: 14 additions & 2 deletions python/turbine_models/custom_models/sd_inference/vae.py
@@ -53,6 +53,7 @@
help="Specify vulkan target triple or rocm/cuda target device.",
)
parser.add_argument("--vulkan_max_allocation", type=str, default="4294967296")
parser.add_argument("--variant", type=str, default="decode")


class VaeModel(torch.nn.Module):
@@ -64,11 +65,15 @@ def __init__(self, hf_model_name, hf_auth_token):
token=hf_auth_token,
)

def forward(self, inp):
def decode_inp(self, inp):
with torch.no_grad():
x = self.vae.decode(inp, return_dict=False)[0]
return x

def encode_inp(self, inp):
latents = self.vae.encode(inp).latent_dist.sample()
return 0.18215 * latents


def export_vae_model(
vae_model,
@@ -83,19 +88,25 @@ def export_vae_model(
device=None,
target_triple=None,
max_alloc=None,
variant="decode",
):
mapper = {}
utils.save_external_weights(
mapper, vae_model, external_weights, external_weight_path
)

sample = (batch_size, 4, height // 8, width // 8)
if variant == "encode":
sample = (batch_size, 3, height, width)

class CompiledVae(CompiledModule):
params = export_parameters(vae_model)

def main(self, inp=AbstractTensor(*sample, dtype=torch.float32)):
return jittable(vae_model.forward)(inp)
if variant == "decode":
return jittable(vae_model.decode_inp)(inp)
elif variant == "encode":
return jittable(vae_model.encode_inp)(inp)

import_to = "INPUT" if compile_to == "linalg" else "IMPORT"
inst = CompiledVae(context=Context(), import_to=import_to)
@@ -127,6 +138,7 @@ def main(self, inp=AbstractTensor(*sample, dtype=torch.float32)):
args.device,
args.iree_target_triple,
args.vulkan_max_allocation,
args.variant,
)
safe_name = utils.create_safe_name(args.hf_model_name, "-vae")
with open(f"{safe_name}.mlir", "w+") as f:
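
As a usage note, here is a short eager-mode sketch of the two variants and their shapes (assuming turbine_models is importable and the model weights can be downloaded; the 0.18215 factor matches encode_inp above, so decoding needs the inverse):

import torch
from turbine_models.custom_models.sd_inference.vae import VaeModel

vae_model = VaeModel("CompVis/stable-diffusion-v1-4", None)  # public model; assuming no auth token is needed
image = torch.rand(1, 3, 512, 512, dtype=torch.float32)      # encode input: (B, 3, H, W)
latents = vae_model.encode_inp(image)                        # scaled latents: (B, 4, H/8, W/8)
recon = vae_model.decode_inp(latents / 0.18215)              # undo the 0.18215 scaling before decoding
print(latents.shape, recon.shape)                            # expected: (1, 4, 64, 64) (1, 3, 512, 512)
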
29 changes: 21 additions & 8 deletions python/turbine_models/custom_models/sd_inference/vae_runner.py
@@ -45,6 +45,7 @@
"--height", type=int, default=512, help="Height of Stable Diffusion"
)
parser.add_argument("--width", type=int, default=512, help="Width of Stable Diffusion")
parser.add_argument("--variant", type=str, default="decode")


def run_vae(
@@ -57,7 +58,7 @@ def run_vae(
return results


def run_torch_vae(hf_model_name, hf_auth_token, example_input):
def run_torch_vae(hf_model_name, hf_auth_token, variant, example_input):
from diffusers import AutoencoderKL

class VaeModel(torch.nn.Module):
@@ -69,26 +70,38 @@ def __init__(self, hf_model_name, hf_auth_token):
token=hf_auth_token,
)

def forward(self, inp):
def decode_inp(self, inp):
with torch.no_grad():
x = self.vae.decode(inp, return_dict=False)[0]
return x

def encode_inp(self, inp):
latents = self.vae.encode(inp).latent_dist.sample()
return 0.18215 * latents

vae_model = VaeModel(
hf_model_name,
hf_auth_token,
)

results = vae_model.forward(example_input)
if variant == "decode":
results = vae_model.decode_inp(example_input)
elif variant == "encode":
results = vae_model.encode_inp(example_input)
np_torch_output = results.detach().cpu().numpy()
return np_torch_output


if __name__ == "__main__":
args = parser.parse_args()
example_input = torch.rand(
args.batch_size, 4, args.height // 8, args.width // 8, dtype=torch.float32
)
if args.variant == "decode":
example_input = torch.rand(
args.batch_size, 4, args.height // 8, args.width // 8, dtype=torch.float32
)
elif args.variant == "encode":
example_input = torch.rand(
args.batch_size, 3, args.height, args.width, dtype=torch.float32
)
print("generating turbine output:")
turbine_results = run_vae(
args.device,
@@ -109,12 +122,12 @@ def forward(self, inp):
from turbine_models.custom_models.sd_inference import utils

torch_output = run_torch_vae(
args.hf_model_name, args.hf_auth_token, example_input
args.hf_model_name, args.hf_auth_token, args.variant, example_input
)
print("TORCH OUTPUT:", torch_output, torch_output.shape, torch_output.dtype)
err = utils.largest_error(torch_output, turbine_results)
print("Largest Error: ", err)
assert err < 9e-5
assert err < 2e-3

# TODO: Figure out why we occasionally segfault without unlinking output variables
turbine_results = None
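
A rough end-to-end sketch of how the new encode path might be checked against the torch reference, mirroring the __main__ block above (the device string and file names are hypothetical; an encode-variant .vmfb and safetensors file exported by vae.py are assumed to exist):

import torch
from turbine_models.custom_models.sd_inference import utils, vae_runner

image = torch.rand(1, 3, 512, 512, dtype=torch.float32)
turbine_out = vae_runner.run_vae(
    "local-task",                           # hypothetical IREE device string
    image,
    "stable_diffusion_v1_4_vae.vmfb",
    "CompVis/stable-diffusion-v1-4",
    None,                                   # hf_auth_token; public model
    "stable_diffusion_v1_4_vae.safetensors",
)
torch_out = vae_runner.run_torch_vae(
    "CompVis/stable-diffusion-v1-4", None, "encode", image
)
print("Largest Error:", utils.largest_error(torch_out, turbine_out))  # the encode test accepts < 2e-3
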
53 changes: 51 additions & 2 deletions python/turbine_models/tests/sd_test.py
@@ -134,7 +134,7 @@ def testExportUnetModel(self):
os.remove("stable_diffusion_v1_4_unet.safetensors")
os.remove("stable_diffusion_v1_4_unet.vmfb")

def testExportVaeModel(self):
def testExportVaeModelDecode(self):
with self.assertRaises(SystemExit) as cm:
vae.export_vae_model(
vae_model,
@@ -148,6 +148,7 @@ def testExportVaeModel(self):
"safetensors",
"stable_diffusion_v1_4_vae.safetensors",
"cpu",
variant="decode",
)
self.assertEqual(cm.exception.code, None)
arguments["external_weight_path"] = "stable_diffusion_v1_4_vae.safetensors"
@@ -168,13 +169,61 @@ def testExportVaeModel(self):
arguments["external_weight_path"],
)
torch_output = vae_runner.run_torch_vae(
arguments["hf_model_name"], arguments["hf_auth_token"], example_input
arguments["hf_model_name"],
arguments["hf_auth_token"],
"decode",
example_input,
)
err = utils.largest_error(torch_output, turbine)
assert err < 9e-5
os.remove("stable_diffusion_v1_4_vae.safetensors")
os.remove("stable_diffusion_v1_4_vae.vmfb")

def testExportVaeModelEncode(self):
with self.assertRaises(SystemExit) as cm:
vae.export_vae_model(
vae_model,
# This is a public model, so no auth required
"CompVis/stable-diffusion-v1-4",
arguments["batch_size"],
arguments["height"],
arguments["width"],
None,
"vmfb",
"safetensors",
"stable_diffusion_v1_4_vae.safetensors",
"cpu",
variant="encode",
)
self.assertEqual(cm.exception.code, None)
arguments["external_weight_path"] = "stable_diffusion_v1_4_vae.safetensors"
arguments["vmfb_path"] = "stable_diffusion_v1_4_vae.vmfb"
example_input = torch.rand(
arguments["batch_size"],
3,
arguments["height"],
arguments["width"],
dtype=torch.float32,
)
turbine = vae_runner.run_vae(
arguments["device"],
example_input,
arguments["vmfb_path"],
arguments["hf_model_name"],
arguments["hf_auth_token"],
arguments["external_weight_path"],
)
torch_output = vae_runner.run_torch_vae(
arguments["hf_model_name"],
arguments["hf_auth_token"],
"encode",
example_input,
)
err = utils.largest_error(torch_output, turbine)
assert err < 2e-3
os.remove("stable_diffusion_v1_4_vae.safetensors")
os.remove("stable_diffusion_v1_4_vae.vmfb")


if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
