Skip to content

Commit

Permalink
Merge pull request #111 from zzbuzzard/VQGAN-download-no-subprocess
Browse files Browse the repository at this point in the history
VQGANImage refactoring
  • Loading branch information
dmarx authored Apr 11, 2022
2 parents 1bb70c8 + 72b70e2 commit e1b1cba
Showing 1 changed file with 51 additions and 25 deletions.
76 changes: 51 additions & 25 deletions src/pytti/Image/VQGANImage.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from pathlib import Path
from os.path import exists as path_exists
import sys
import subprocess
import shutil
import os

from loguru import logger

Expand All @@ -16,6 +15,8 @@
from torchvision.transforms import functional as TF
from PIL import Image
from omegaconf import OmegaConf
import urllib.request
from tqdm import tqdm

VQGAN_MODEL = None
VQGAN_NAME = None
Expand All @@ -25,41 +26,70 @@
VQGAN_MODEL_NAMES = ["imagenet", "coco", "wikiart", "sflckr", "openimages"]
VQGAN_CONFIG_URLS = {
"imagenet": [
"curl -L -o imagenet.yaml -C - https://heibox.uni-heidelberg.de/f/274fb24ed38341bfa753/?dl=1"
"https://heibox.uni-heidelberg.de/f/274fb24ed38341bfa753/?dl=1"
],
"coco": ["curl -L -o coco.yaml -C - https://dl.nmkd.de/ai/clip/coco/coco.yaml"],
"coco": ["https://dl.nmkd.de/ai/clip/coco/coco.yaml"],
"wikiart": [
"curl -L -o wikiart.yaml -C - http://eaidata.bmk.sh/data/Wikiart_16384/wikiart_f16_16384_8145600.yaml"
"http://eaidata.bmk.sh/data/Wikiart_16384/wikiart_f16_16384_8145600.yaml"
],
"sflckr": [
"curl -L -o sflckr.yaml -C - https://heibox.uni-heidelberg.de/d/73487ab6e5314cb5adba/files/?p=%2Fconfigs%2F2020-11-09T13-31-51-project.yaml&dl=1"
"https://heibox.uni-heidelberg.de/d/73487ab6e5314cb5adba/files/?p=%2Fconfigs%2F2020-11-09T13-31-51-project.yaml&dl=1"
],
"faceshq": [
"curl -L -o faceshq.yaml -C - https://drive.google.com/uc?export=download&id=1fHwGx_hnBtC8nsq7hesJvs-Klv-P0gzT"
"https://drive.google.com/uc?export=download&id=1fHwGx_hnBtC8nsq7hesJvs-Klv-P0gzT"
],
"openimages": [
"curl -L -o openimages.yaml -C - https://heibox.uni-heidelberg.de/d/2e5662443a6b4307b470/files/?p=%2Fconfigs%2Fmodel.yaml&dl=1"
"https://heibox.uni-heidelberg.de/d/2e5662443a6b4307b470/files/?p=%2Fconfigs%2Fmodel.yaml&dl=1"
],
}
VQGAN_CHECKPOINT_URLS = {
"imagenet": [
"curl -L -o imagenet.ckpt -C - https://heibox.uni-heidelberg.de/f/867b05fc8c4841768640/?dl=1"
"https://heibox.uni-heidelberg.de/f/867b05fc8c4841768640/?dl=1"
],
"coco": ["curl -L -o coco.ckpt -C - https://dl.nmkd.de/ai/clip/coco/coco.ckpt"],
"coco": ["https://dl.nmkd.de/ai/clip/coco/coco.ckpt"],
"wikiart": [
"curl -L -o wikiart.ckpt -C - http://eaidata.bmk.sh/data/Wikiart_16384/wikiart_f16_16384_8145600.ckpt"
"http://eaidata.bmk.sh/data/Wikiart_16384/wikiart_f16_16384_8145600.ckpt"
],
"sflckr": [
"curl -L -o sflckr.ckpt -C - https://heibox.uni-heidelberg.de/d/73487ab6e5314cb5adba/files/?p=%2Fcheckpoints%2Flast.ckpt&dl=1"
"https://heibox.uni-heidelberg.de/d/73487ab6e5314cb5adba/files/?p=%2Fcheckpoints%2Flast.ckpt&dl=1"
],
"faceshq": [
"curl -L -o faceshq.ckpt -C - https://app.koofr.net/content/links/a04deec9-0c59-4673-8b37-3d696fe63a5d/files/get/last.ckpt?path=%2F2020-11-13T21-41-45_faceshq_transformer%2Fcheckpoints%2Flast.ckpt"
"https://app.koofr.net/content/links/a04deec9-0c59-4673-8b37-3d696fe63a5d/files/get/last.ckpt?path=%2F2020-11-13T21-41-45_faceshq_transformer%2Fcheckpoints%2Flast.ckpt"
],
"openimages": [
"curl -L -o openimages.ckpt -C - https://heibox.uni-heidelberg.de/d/2e5662443a6b4307b470/files/?p=%2Fckpts%2Flast.ckpt&dl=1"
"https://heibox.uni-heidelberg.de/d/2e5662443a6b4307b470/files/?p=%2Fckpts%2Flast.ckpt&dl=1"
],
}

def _download(url, dest):
os.makedirs(os.path.dirname(dest), exist_ok=True)

with urllib.request.urlopen(url) as source:
file_size = int(source.info().get("Content-Length"))

# Check if file already downloaded
if os.path.isfile(dest):
if os.path.getsize(dest) == file_size:
return True
else:
logger.warning(
f"WARNING: Pre-existing file at {dest} does not match the download size, overwriting."
)

print(f"Downloading {url} to {dest} ({file_size//1024}KB)")

with open(dest, "wb") as output, tqdm(total=file_size) as loop:
while True:
buffer = source.read(8192)
if not buffer:
break

output.write(buffer)
loop.update(len(buffer))

return os.path.getsize(dest) == file_size



def load_vqgan_model(config_path, checkpoint_path):
config = OmegaConf.load(config_path)
Expand Down Expand Up @@ -235,18 +265,15 @@ def init_vqgan(model_name, model_artifacts_path, device=DEVICE):
logger.debug(vqgan_config.absolute())
logger.debug(vqgan_checkpoint.absolute())
logger.debug(vqgan_checkpoint)
# good lord... the nested if statements and calling curl with subprocess... so much about this needs to change.
# for now, let's just use it the way it is and copy the file where it needs to go.
# if not path_exists(vqgan_config):

if not vqgan_config.exists():
logger.warning(
f"WARNING: VQGAN config file {vqgan_config} not found. Initializing download."
)
command = VQGAN_CONFIG_URLS[model_name][0].split(" ", 6)
subprocess.run(command)
shutil.move(vqgan_config.name, vqgan_config)

if not path_exists(vqgan_config):
url = VQGAN_CONFIG_URLS[model_name][0]

if not _download(url, vqgan_config):
logger.critical(
f"ERROR: VQGAN model {model_name} config failed to download! Please contact model host or find a new one."
)
Expand All @@ -256,11 +283,10 @@ def init_vqgan(model_name, model_artifacts_path, device=DEVICE):
logger.warning(
f"WARNING: VQGAN checkpoint file {vqgan_checkpoint} not found. Initializing download."
)
command = VQGAN_CHECKPOINT_URLS[model_name][0].split(" ", 6)
subprocess.run(command)
shutil.move(vqgan_checkpoint.name, vqgan_checkpoint)

if not path_exists(vqgan_checkpoint):
url = VQGAN_CHECKPOINT_URLS[model_name][0]

if not _download(url, vqgan_checkpoint):
logger.critical(
f"ERROR: VQGAN model {model_name} checkpoint failed to download! Please contact model host or find a new one."
)
Expand Down

0 comments on commit e1b1cba

Please sign in to comment.