Skip to content

Commit

Permalink
Merge pull request #174 from pytti-tools/test
Browse files Browse the repository at this point in the history
* added support for device configuration (resolving forced usage of cuda:0)
* refactored image_model's
* added tests, docstrings, typehints
* isolated cutouts code to facilitate increased cutout control, future support for Dango's method, etc.
* POC'd backwards-compatible approach for adding config options using open_dict context manager
* fixed tests that used local paths
  • Loading branch information
dmarx authored Jun 10, 2022
2 parents 326bcba + 38562f3 commit 0b28f62
Show file tree
Hide file tree
Showing 31 changed files with 629 additions and 208 deletions.
54 changes: 54 additions & 0 deletions Pipfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
[[source]]
url = "https://pypi.org/simple"
verify_ssl = true
name = "pypi"

[[source]]
url = "https://download.pytorch.org/whl/cu113/"
verify_ssl = false
name = "pytorch"

[packages]
transformers = "==4.15.0"
gdown = "===4.2.0"
ftfy = "==6.0.3"
regex = "*"
tqdm = "==4.62.3"
omegaconf = "==2.1.1"
pytorch-lightning = "==1.5.7"
kornia = "==0.6.2"
einops = "==0.3.2"
imageio-ffmpeg = "==0.4.5"
exrex = "*"
matplotlib-label-lines = "==0.4.3"
pandas = "==1.3.4"
seaborn = "==0.11.2"
scikit-learn = "*"
loguru = "*"
hydra-core = "*"
jupyter = "*"
imageio = "==2.4.1"
PyGLM = "==2.5.7"
adjustText = "*"
Pillow = "*"
torch = "*"
torchvision = "*"
torchaudio = "*"
requests = "*"
pyttitools-adabins = {path = "./vendor/AdaBins"}
pyttitools-gma = {path = "./vendor/GMA"}
clip = {path = "./vendor/CLIP"}
pyttitools-taming-transformers = {path = "./vendor/taming-transformers"}
tensorflow = "*"
protobuf = "==3.9.2"
pyttitools-core = {path = "."}
mmc = {git = "https://github.com/dmarx/Multi-Modal-Comparators"}

[dev-packages]
pytest = "*"
pre-commit = "*"
click = "==8.0.4"
black = "*"

[requires]
python_version = "3.9"
11 changes: 0 additions & 11 deletions src/pytti/Image/__init__.py

This file was deleted.

4 changes: 2 additions & 2 deletions src/pytti/ImageGuide.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
vram_usage_mode,
)
from pytti.AudioParse import SpectralAudioParser
from pytti.Image.differentiable_image import DifferentiableImage
from pytti.Image.PixelImage import PixelImage
from pytti.image_models.differentiable_image import DifferentiableImage
from pytti.image_models.pixel import PixelImage
from pytti.Notebook import tqdm, make_hbox

# from pytti.rotoscoper import update_rotoscopers
Expand Down
12 changes: 8 additions & 4 deletions src/pytti/LossAug/BaseLossClass.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import torch
from torch import nn
from pytti import DEVICE, replace_grad, parametric_eval
from pytti import replace_grad, parametric_eval


class Loss(nn.Module):
def __init__(self, weight, stop, name):
def __init__(self, weight, stop, name, device=None):
super().__init__()
# self.register_buffer('weight', torch.as_tensor(weight))
# self.register_buffer('stop', torch.as_tensor(stop))
Expand All @@ -13,6 +13,9 @@ def __init__(self, weight, stop, name):
self.input_axes = ("n", "s", "y", "x")
self.name = name
self.enabled = True
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.device = device

def get_loss(self, input, img):
raise NotImplementedError
Expand All @@ -29,10 +32,11 @@ def set_stop(stop):
def __str__(self):
return self.name

def forward(self, input, img, device=DEVICE):
def forward(self, input, img, device=None):
if not self.enabled or self.weight in [0, "0"]:
return 0, 0

if device is None:
device = self.device
weight = torch.as_tensor(parametric_eval(self.weight), device=device)
stop = torch.as_tensor(parametric_eval(self.stop), device=device)
loss_raw = self.get_loss(input, img)
Expand Down
16 changes: 10 additions & 6 deletions src/pytti/LossAug/DepthLossClass.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@
infer_helper = None


def init_AdaBins():
def init_AdaBins(device=None):
global infer_helper
if infer_helper is None:
with vram_usage_mode("AdaBins"):
logger.debug("Loading AdaBins...")
infer_helper = InferenceHelper(dataset="nyu")
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
infer_helper = InferenceHelper(dataset="nyu", device=device)
logger.debug("AdaBins loaded.")


Expand Down Expand Up @@ -55,13 +57,15 @@ def get_loss(self, input, img):

@classmethod
@vram_usage_mode("Depth Loss")
def make_comp(cls, pil_image, device=DEVICE):
depth, _ = DepthLoss.get_depth(pil_image)
def make_comp(cls, pil_image, device=None):
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
depth, _ = DepthLoss.get_depth(pil_image, device=device)
return torch.from_numpy(depth).to(device)

@staticmethod
def get_depth(pil_image):
init_AdaBins()
def get_depth(pil_image, device=None):
init_AdaBins(device=device)
width, height = pil_image.size

# if the area of an image is above this, the depth model fails
Expand Down
3 changes: 2 additions & 1 deletion src/pytti/LossAug/LossOrchestratorClass.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from loguru import logger
from PIL import Image

from pytti.Image import PixelImage
from pytti.image_models import PixelImage

# from pytti.LossAug import build_loss
from pytti.LossAug import TVLoss, HSVLoss, OpticalFlowLoss, TargetFlowLoss
Expand Down Expand Up @@ -125,6 +125,7 @@ def configure_optical_flows(img, params, loss_augs):
TargetFlowLoss.TargetImage(
f"optical flow stabilization:{params.flow_stabilization_weight}",
img.image_shape,
device="cuda",
)
]
for optical_flow in optical_flows:
Expand Down
28 changes: 18 additions & 10 deletions src/pytti/LossAug/MSELossClass.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

# from pytti.Notebook import Rotoscoper
from pytti.rotoscoper import Rotoscoper
from pytti import DEVICE, fetch, parse, vram_usage_mode
from pytti import fetch, parse, vram_usage_mode
import torch


Expand All @@ -19,22 +19,22 @@ def __init__(
stop=-math.inf,
name="direct target loss",
image_shape=None,
device=DEVICE,
device=None,
):
super().__init__(weight, stop, name)
super().__init__(weight, stop, name, device)
self.register_buffer("comp", comp)
if image_shape is None:
height, width = comp.shape[-2:]
image_shape = (width, height)
self.image_shape = image_shape
self.register_buffer("mask", torch.ones(1, 1, 1, 1, device=device))
self.register_buffer("mask", torch.ones(1, 1, 1, 1, device=self.device))
self.use_mask = False

@classmethod
@vram_usage_mode("Loss Augs")
@torch.no_grad()
def TargetImage(
cls, prompt_string, image_shape, pil_image=None, is_path=False, device=DEVICE
cls, prompt_string, image_shape, pil_image=None, is_path=False, device=None
):
# Why is this prompt parsing stuff here? Deprecate in favor of centralized
# parsing functions (if feasible)
Expand All @@ -44,6 +44,8 @@ def TargetImage(
weight, mask = parse(weight, r"_", ["1", ""])
text = text.strip()
mask = mask.strip()
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if pil_image is None and text != "" and is_path:
pil_image = Image.open(fetch(text)).convert("RGB")
im = pil_image.resize(image_shape, Image.LANCZOS)
Expand All @@ -55,12 +57,14 @@ def TargetImage(
comp = cls.make_comp(im)
if image_shape is None:
image_shape = pil_image.size
out = cls(comp, weight, stop, text + " (direct)", image_shape)
out = cls(comp, weight, stop, text + " (direct)", image_shape, device=device)
out.set_mask(mask)
return out

@torch.no_grad()
def set_mask(self, mask, inverted=False, device=DEVICE):
def set_mask(self, mask, inverted=False, device=None):
if device is None:
device = self.device
if isinstance(mask, str) and mask != "":
if mask[0] == "-":
mask = mask[1:]
Expand All @@ -86,16 +90,20 @@ def convert_input(cls, input, img):
return input

@classmethod
def make_comp(cls, pil_image, device=DEVICE):
def make_comp(cls, pil_image, device=None):
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
out = (
TF.to_tensor(pil_image)
.unsqueeze(0)
.to(device, memory_format=torch.channels_last)
)
return cls.convert_input(out, None)

def set_comp(self, pil_image, device=DEVICE):
self.comp.set_(type(self).make_comp(pil_image))
def set_comp(self, pil_image, device=None):
if device is None:
device = self.device
self.comp.set_(type(self).make_comp(pil_image, device=device))

def get_loss(self, input, img):
input = type(self).convert_input(input, img)
Expand Down
Loading

0 comments on commit 0b28f62

Please sign in to comment.