torch.compile ae.decode
yorickvP committed Sep 27, 2024
1 parent 99cfbb7 commit e718818
54 changes: 54 additions & 0 deletions predict.py
@@ -166,12 +166,66 @@ def base_setup(
            shared_models=shared_models,
        )

        if compile_fp8 or compile_bf16:
            self.compile_ae()

        if compile_fp8:
            self.compile_fp8()

        if compile_bf16:
            self.compile_bf16()

    @torch.inference_mode()
    def compile_ae(self):
        # helpful: export TORCH_COMPILE_DEBUG=1 TORCH_LOGS=dynamic,dynamo

        # the order is important: torch.compile has to recompile if it makes
        # invalid assumptions about the input sizes, so putting larger input
        # sizes first makes for fewer recompiles.
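        # each entry is [batch, latent channels, latent height, latent width];
        # for this AE the latent spatial dims are the image dims divided by 8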
        VAE_SIZES = [
            [1, 16, 192, 168],
            [1, 16, 96, 96],
            [1, 16, 96, 168],
            [1, 16, 128, 128],
            [1, 16, 80, 192],
            [1, 16, 104, 152],
            [1, 16, 152, 104],
            [1, 16, 136, 112],
            [1, 16, 112, 136],
            [1, 16, 144, 112],
            [1, 16, 112, 144],
            [1, 16, 168, 96],
            [1, 16, 192, 80],
            [4, 16, 128, 128],
        ]
print("compiling AE")
st = time.time()
device = torch.device('cuda')
if self.offload:
self.ae.decoder.to(device)

self.ae.decoder = torch.compile(self.ae.decoder)

        # actual compilation happens when you give it inputs
        for f in VAE_SIZES:
            print("Compiling AE for size", f)
            x = torch.rand(f, device=device)
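            # mark the batch and spatial dims dynamic (within these bounds) so
            # one compiled graph can serve the whole size range instead of
            # recompiling per shape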
            torch._dynamo.mark_dynamic(x, 0, min=1, max=4)
            torch._dynamo.mark_dynamic(x, 2, min=80)
            torch._dynamo.mark_dynamic(x, 3, min=80)
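            # cache_enabled=False disables autocast's weight-cast cache for
            # these warm-up calls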
            with torch.autocast(
                device_type=device.type, dtype=torch.bfloat16, cache_enabled=False
            ):
                self.ae.decode(x)

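        # when offloading, move the decoder back to CPU and release the GPU
        # memory the warm-up runs allocated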
        if self.offload:
            self.ae.decoder.cpu()
            torch.cuda.empty_cache()
        print("compiled AE in", time.time() - st)


    def compile_fp8(self):
        print("compiling fp8 model")
        st = time.time()
