Skip to content

Commit

Permalink
fixed cpu error
Browse files Browse the repository at this point in the history
  • Loading branch information
kamwoh committed Dec 21, 2023
1 parent 6059cab commit d24f26a
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions src/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,11 @@ def prepare_pipeline(model_name):
if 'dpo' in OUTPUT_DIR:
args.unet_path = "mhdang/dpo-sd1.5-text2image-v1"

pipe = load_pipeline(args, torch.float16, 'cuda')
pipe = pipe.to(torch.float16)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
weight_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

pipe = load_pipeline(args, weight_dtype, device)
pipe = pipe.to(weight_dtype)

pipe.verbose = True
pipe.v = 're'
Expand Down Expand Up @@ -116,7 +119,7 @@ def prepare_pipeline(model_name):
ID2NAME = open('data/dogs/class_names.txt').readlines()
ID2NAME = [line.strip() for line in ID2NAME]

return pipe, MAPPING, ID2NAME
return pipe, MAPPING, ID2NAME, device


def download_file(url, local_path):
Expand Down Expand Up @@ -159,11 +162,11 @@ def process_text(text, MAPPING, ID2NAME):


def generate_images(model_name, prompt, negative_prompt, num_inference_steps, guidance_scale, num_images, seed):
generator = torch.Generator(device='cuda')
generator = generator.manual_seed(int(seed))

try:
pipe, MAPPING, ID2NAME = prepare_pipeline(model_name)
pipe, MAPPING, ID2NAME, device = prepare_pipeline(model_name)

generator = torch.Generator(device=device)
generator = generator.manual_seed(int(seed))

prompt, part2id = process_text(prompt, MAPPING, ID2NAME)
negative_prompt, _ = process_text(negative_prompt, MAPPING, ID2NAME)
Expand All @@ -179,7 +182,8 @@ def generate_images(model_name, prompt, negative_prompt, num_inference_steps, gu
f"The error message: {e}")
finally:
gc.collect()
torch.cuda.empty_cache()
if torch.cuda.is_available():
torch.cuda.empty_cache()

return images, '; '.join(part2id)

Expand Down

0 comments on commit d24f26a

Please sign in to comment.