Skip to content

Commit

Permalink
Fix ruff errors
Browse files: browse the repository at this point in the history
  • Loading branch information
yorickvP committed Sep 27, 2024
1 parent e718818 commit 99cecf1
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions predict.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Any, Optional

import torch
import torch._dynamo as dynamo

torch.set_float32_matmul_precision("high")
torch.backends.cuda.matmul.allow_tf32 = True
Expand Down Expand Up @@ -183,7 +184,7 @@ def compile_ae(self):
# torch.compile has to recompile if it makes invalid assumptions
# about the input sizes. Having higher input sizes first makes
# for fewer recompiles.
VAE_SIZES = [
vae_sizes = [
[1, 16, 192, 168],
[1, 16, 96, 96],
[1, 16, 96, 168],
Expand All @@ -202,21 +203,21 @@ def compile_ae(self):
]
print("compiling AE")
st = time.time()
device = torch.device('cuda')
device = torch.device("cuda")
if self.offload:
self.ae.decoder.to(device)

self.ae.decoder = torch.compile(self.ae.decoder)

# actual compilation happens when you give it inputs
for f in VAE_SIZES:
for f in vae_sizes:
print("Compiling AE for size", f)
x = torch.rand(f, device=device)
torch._dynamo.mark_dynamic(x, 0, min=1, max=4)
torch._dynamo.mark_dynamic(x, 2, min=80)
torch._dynamo.mark_dynamic(x, 3, min=80)
dynamo.mark_dynamic(x, 0, min=1, max=4)
dynamo.mark_dynamic(x, 2, min=80)
dynamo.mark_dynamic(x, 3, min=80)
with torch.autocast(
device_type=device.type, dtype=torch.bfloat16, cache_enabled=False
device_type=device.type, dtype=torch.bfloat16, cache_enabled=False
):
self.ae.decode(x)

Expand All @@ -225,7 +226,6 @@ def compile_ae(self):
torch.cuda.empty_cache()
print("compiled AE in ", time.time() - st)


def compile_fp8(self):
print("compiling fp8 model")
st = time.time()
Expand Down

0 comments on commit 99cecf1

Please sign in to comment.