diff --git a/src/pytti/Image/VQGANImage.py b/src/pytti/Image/VQGANImage.py index 30ca279..b840d60 100644 --- a/src/pytti/Image/VQGANImage.py +++ b/src/pytti/Image/VQGANImage.py @@ -1,8 +1,7 @@ from pathlib import Path from os.path import exists as path_exists import sys -import subprocess -import shutil +import os from loguru import logger @@ -16,6 +15,8 @@ from torchvision.transforms import functional as TF from PIL import Image from omegaconf import OmegaConf +import urllib.request +from tqdm import tqdm VQGAN_MODEL = None VQGAN_NAME = None @@ -25,41 +26,70 @@ VQGAN_MODEL_NAMES = ["imagenet", "coco", "wikiart", "sflckr", "openimages"] VQGAN_CONFIG_URLS = { "imagenet": [ - "curl -L -o imagenet.yaml -C - https://heibox.uni-heidelberg.de/f/274fb24ed38341bfa753/?dl=1" + "https://heibox.uni-heidelberg.de/f/274fb24ed38341bfa753/?dl=1" ], - "coco": ["curl -L -o coco.yaml -C - https://dl.nmkd.de/ai/clip/coco/coco.yaml"], + "coco": ["https://dl.nmkd.de/ai/clip/coco/coco.yaml"], "wikiart": [ - "curl -L -o wikiart.yaml -C - http://eaidata.bmk.sh/data/Wikiart_16384/wikiart_f16_16384_8145600.yaml" + "http://eaidata.bmk.sh/data/Wikiart_16384/wikiart_f16_16384_8145600.yaml" ], "sflckr": [ - "curl -L -o sflckr.yaml -C - https://heibox.uni-heidelberg.de/d/73487ab6e5314cb5adba/files/?p=%2Fconfigs%2F2020-11-09T13-31-51-project.yaml&dl=1" + "https://heibox.uni-heidelberg.de/d/73487ab6e5314cb5adba/files/?p=%2Fconfigs%2F2020-11-09T13-31-51-project.yaml&dl=1" ], "faceshq": [ - "curl -L -o faceshq.yaml -C - https://drive.google.com/uc?export=download&id=1fHwGx_hnBtC8nsq7hesJvs-Klv-P0gzT" + "https://drive.google.com/uc?export=download&id=1fHwGx_hnBtC8nsq7hesJvs-Klv-P0gzT" ], "openimages": [ - "curl -L -o openimages.yaml -C - https://heibox.uni-heidelberg.de/d/2e5662443a6b4307b470/files/?p=%2Fconfigs%2Fmodel.yaml&dl=1" + "https://heibox.uni-heidelberg.de/d/2e5662443a6b4307b470/files/?p=%2Fconfigs%2Fmodel.yaml&dl=1" ], } VQGAN_CHECKPOINT_URLS = { "imagenet": [ - "curl -L -o imagenet.ckpt -C - https://heibox.uni-heidelberg.de/f/867b05fc8c4841768640/?dl=1" + "https://heibox.uni-heidelberg.de/f/867b05fc8c4841768640/?dl=1" ], - "coco": ["curl -L -o coco.ckpt -C - https://dl.nmkd.de/ai/clip/coco/coco.ckpt"], + "coco": ["https://dl.nmkd.de/ai/clip/coco/coco.ckpt"], "wikiart": [ - "curl -L -o wikiart.ckpt -C - http://eaidata.bmk.sh/data/Wikiart_16384/wikiart_f16_16384_8145600.ckpt" + "http://eaidata.bmk.sh/data/Wikiart_16384/wikiart_f16_16384_8145600.ckpt" ], "sflckr": [ - "curl -L -o sflckr.ckpt -C - https://heibox.uni-heidelberg.de/d/73487ab6e5314cb5adba/files/?p=%2Fcheckpoints%2Flast.ckpt&dl=1" + "https://heibox.uni-heidelberg.de/d/73487ab6e5314cb5adba/files/?p=%2Fcheckpoints%2Flast.ckpt&dl=1" ], "faceshq": [ - "curl -L -o faceshq.ckpt -C - https://app.koofr.net/content/links/a04deec9-0c59-4673-8b37-3d696fe63a5d/files/get/last.ckpt?path=%2F2020-11-13T21-41-45_faceshq_transformer%2Fcheckpoints%2Flast.ckpt" + "https://app.koofr.net/content/links/a04deec9-0c59-4673-8b37-3d696fe63a5d/files/get/last.ckpt?path=%2F2020-11-13T21-41-45_faceshq_transformer%2Fcheckpoints%2Flast.ckpt" ], "openimages": [ - "curl -L -o openimages.ckpt -C - https://heibox.uni-heidelberg.de/d/2e5662443a6b4307b470/files/?p=%2Fckpts%2Flast.ckpt&dl=1" + "https://heibox.uni-heidelberg.de/d/2e5662443a6b4307b470/files/?p=%2Fckpts%2Flast.ckpt&dl=1" ], } +def _download(url, dest): + os.makedirs(os.path.dirname(dest), exist_ok=True) + + with urllib.request.urlopen(url) as source: + file_size = int(source.info().get("Content-Length")) + + # Check if file already downloaded + if os.path.isfile(dest): + if os.path.getsize(dest) == file_size: + return True + else: + logger.warning( + f"WARNING: Pre-existing file at {dest} does not match the download size, overwriting." + ) + + print(f"Downloading {url} to {dest} ({file_size//1024}KB)") + + with open(dest, "wb") as output, tqdm(total=file_size) as loop: + while True: + buffer = source.read(8192) + if not buffer: + break + + output.write(buffer) + loop.update(len(buffer)) + + return os.path.getsize(dest) == file_size + + def load_vqgan_model(config_path, checkpoint_path): config = OmegaConf.load(config_path) @@ -235,18 +265,15 @@ def init_vqgan(model_name, model_artifacts_path, device=DEVICE): logger.debug(vqgan_config.absolute()) logger.debug(vqgan_checkpoint.absolute()) logger.debug(vqgan_checkpoint) - # good lord... the nested if statements and calling curl with subprocess... so much about this needs to change. - # for now, let's just use it the way it is and copy the file where it needs to go. - # if not path_exists(vqgan_config): + if not vqgan_config.exists(): logger.warning( f"WARNING: VQGAN config file {vqgan_config} not found. Initializing download." ) - command = VQGAN_CONFIG_URLS[model_name][0].split(" ", 6) - subprocess.run(command) - shutil.move(vqgan_config.name, vqgan_config) - if not path_exists(vqgan_config): + url = VQGAN_CONFIG_URLS[model_name][0] + + if not _download(url, vqgan_config): logger.critical( f"ERROR: VQGAN model {model_name} config failed to download! Please contact model host or find a new one." ) @@ -256,11 +283,10 @@ def init_vqgan(model_name, model_artifacts_path, device=DEVICE): logger.warning( f"WARNING: VQGAN checkpoint file {vqgan_checkpoint} not found. Initializing download." ) - command = VQGAN_CHECKPOINT_URLS[model_name][0].split(" ", 6) - subprocess.run(command) - shutil.move(vqgan_checkpoint.name, vqgan_checkpoint) - if not path_exists(vqgan_checkpoint): + url = VQGAN_CHECKPOINT_URLS[model_name][0] + + if not _download(url, vqgan_checkpoint): logger.critical( f"ERROR: VQGAN model {model_name} checkpoint failed to download! Please contact model host or find a new one." )