Skip to content

Commit

Permalink
Merge pull request #110 - refactor loss configuration
Browse files Browse the repository at this point in the history
  • Loading branch information
dmarx authored Apr 4, 2022
2 parents f50c9a7 + 9f8ccd1 commit 1bb70c8
Show file tree
Hide file tree
Showing 7 changed files with 351 additions and 236 deletions.
10 changes: 7 additions & 3 deletions src/pytti/ImageGuide.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@
from pytti.Image.differentiable_image import DifferentiableImage
from pytti.Image.PixelImage import PixelImage
from pytti.Notebook import tqdm, make_hbox
from pytti.rotoscoper import update_rotoscopers

# from pytti.rotoscoper import update_rotoscopers
from pytti.rotoscoper import ROTOSCOPERS
from pytti.Transforms import (
animate_2d,
zoom_3d,
Expand Down Expand Up @@ -82,7 +84,8 @@ def __init__(
base_name=None,
fig=None,
axs=None,
video_frames=None,
#####################
video_frames=None, # # only need this to pass to animate_video_source
optical_flows=None,
stabilization_augs=None,
last_frame_semantic=None,
Expand Down Expand Up @@ -517,7 +520,8 @@ def update(
# next_step_pil = None
if (i - params.pre_animation_steps) % params.steps_per_frame == 0:
logger.debug(f"Time: {t:.4f} seconds")
update_rotoscopers(
# update_rotoscopers(
ROTOSCOPERS.update_rotoscopers(
((i - params.pre_animation_steps) // params.steps_per_frame + 1)
* params.frame_stride
)
Expand Down
33 changes: 12 additions & 21 deletions src/pytti/LossAug/DepthLossClass.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
# from infer import InferenceHelper
import gc
import math

from adabins.infer import InferenceHelper
from pytti.LossAug.MSELossClass import MSELoss
import gc, torch, os, math
from pytti import DEVICE, vram_usage_mode
from torchvision.transforms import functional as TF
from torch.nn import functional as F
from PIL import Image, ImageOps
from loguru import logger
from PIL import Image
import torch
from torch.nn import functional as F
from torchvision.transforms import functional as TF

from pytti import DEVICE, vram_usage_mode
from pytti.LossAug.MSELossClass import MSELoss


infer_helper = None

Expand All @@ -16,11 +20,6 @@ def init_AdaBins():
if infer_helper is None:
with vram_usage_mode("AdaBins"):
logger.debug("Loading AdaBins...")
# os.chdir('AdaBins')
# try:
# infer_helper = InferenceHelper(dataset='nyu')
# finally:
# os.chdir('..')
infer_helper = InferenceHelper(dataset="nyu")
logger.debug("AdaBins loaded.")

Expand All @@ -45,16 +44,13 @@ def get_loss(self, input, img):
depth_input = TF.resize(
input, (height, width), interpolation=TF.InterpolationMode.BILINEAR
)
depth_resized = True
else:
depth_input = input
depth_resized = False

_, depth_map = infer_helper.model(depth_input)
depth_map = F.interpolate(
depth_map, self.comp.shape[-2:], mode="bilinear", align_corners=True
)
# depth_map = F.interpolate(depth_map, (height, width), mode='bilinear', align_corners=True)
return super().get_loss(depth_map, img)

@classmethod
Expand All @@ -81,14 +77,9 @@ def get_depth(pil_image):
else:
depth_input = pil_image
depth_resized = False
# run the depth model (whatever that means)

gc.collect()
torch.cuda.empty_cache()
# os.chdir('AdaBins')
# try:
# _, depth_map = infer_helper.predict_pil(depth_input)
# finally:
# os.chdir('..')
_, depth_map = infer_helper.predict_pil(depth_input)
gc.collect()
torch.cuda.empty_cache()
Expand Down
254 changes: 254 additions & 0 deletions src/pytti/LossAug/LossOrchestratorClass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,254 @@
from IPython import display
from loguru import logger
from PIL import Image

from pytti.Image import PixelImage

# from pytti.LossAug import build_loss
from pytti.LossAug import TVLoss, HSVLoss, OpticalFlowLoss, TargetFlowLoss
from pytti.Perceptor.Prompt import parse_prompt

from pytti.LossAug.BaseLossClass import Loss
from pytti.LossAug.DepthLossClass import DepthLoss
from pytti.LossAug.EdgeLossClass import EdgeLoss


class LossBuilder:

LOSS_DICT = {"edge": EdgeLoss, "depth": DepthLoss}

def __init__(self, weight_name, weight, name, img, pil_target):
self.weight_name = weight_name
self.weight = weight
self.name = name
self.img = img
self.pil_target = pil_target

# uh.... should the places this is beind used maybe just use Loss.__init__?
# TO DO: let's make this a class attribute on something

@property
def weight_category(self):
return self.weight_name.split("_")[0]

@property
def loss_factory(self):
weight_name = self.weight_category
if weight_name == "direct":
Loss = type(self.img).get_preferred_loss()
else:
Loss = self.LOSS_DICT[weight_name]
return Loss

def build_loss(self) -> Loss:
"""
Given a weight name, weight, name, image, and target image, returns a loss object
:param weight_name: The name of the loss function
:param weight: The weight of the loss
:param name: The name of the loss function
:param img: The image to be optimized
:param pil_target: The target image
:return: The loss function.
"""
Loss = self.loss_factory
out = Loss.TargetImage(
f"{self.weight_category} {self.name}:{self.weight}",
self.img.image_shape,
self.pil_target,
)
out.set_enabled(self.pil_target is not None)
return out


def _standardize_null(weight):
weight = str(weight).strip()
if weight in ("", "None"):
weight = "0"
if float(weight) == 0:
weight = ""
return weight


class LossConfigurator:
"""
Groups together procedures for initializing losses
"""

def __init__(
self,
init_image_pil: Image.Image,
restore: bool,
img: PixelImage,
embedder,
prompts,
# params,
########
direct_image_prompts,
semantic_stabilization_weight,
init_image,
semantic_init_weight,
animation_mode,
flow_stabilization_weight,
flow_long_term_samples,
smoothing_weight,
###########
direct_init_weight,
direct_stabilization_weight,
depth_stabilization_weight,
edge_stabilization_weight,
):
self.init_image_pil = init_image_pil
self.img = img
self.embedder = embedder
self.prompts = prompts

self.init_augs = []
self.loss_augs = []
self.optical_flows = []
self.last_frame_semantic = None
self.semantic_init_prompt = None

# self.params = params
self.restore = restore

### params
self.direct_image_prompts = direct_image_prompts
self.semantic_stabilization_weight = _standardize_null(
semantic_stabilization_weight
)
self.init_image = init_image
self.semantic_init_weight = _standardize_null(semantic_init_weight)
self.animation_mode = animation_mode
self.flow_stabilization_weight = _standardize_null(flow_stabilization_weight)
self.flow_long_term_samples = flow_long_term_samples
self.smoothing_weight = _standardize_null(smoothing_weight)

######
self.direct_init_weight = _standardize_null(direct_init_weight)
self.direct_stabilization_weight = _standardize_null(
direct_stabilization_weight
)
self.depth_stabilization_weight = _standardize_null(depth_stabilization_weight)
self.edge_stabilization_weight = _standardize_null(edge_stabilization_weight)

def process_direct_image_prompts(self):
# prompt parsing shouldn't go here.
self.loss_augs.extend(
type(self.img)
.get_preferred_loss()
.TargetImage(p.strip(), self.img.image_shape, is_path=True)
for p in self.direct_image_prompts.split("|")
if p.strip()
)

def process_semantic_stabilization(self):
last_frame_pil = self.init_image_pil
if not last_frame_pil:
last_frame_pil = self.img.decode_image()
self.last_frame_semantic = parse_prompt(
self.embedder,
f"stabilization:{self.semantic_stabilization_weight}",
last_frame_pil,
)
self.last_frame_semantic.set_enabled(self.init_image_pil is not None)
for scene in self.prompts:
scene.append(self.last_frame_semantic)

def configure_losses(self):
if self.init_image_pil is not None:
self.configure_init_image()
self.process_direct_image_prompts()
if self.semantic_stabilization_weight:
self.process_semantic_stabilization()
self.configure_stabilization_augs()
self.configure_optical_flows()
self.configure_aesthetic_losses()

return (
self.loss_augs,
self.init_augs,
self.optical_flows,
self.semantic_init_prompt,
self.last_frame_semantic,
self.img,
)

def configure_init_image(self):

if not self.restore:
# move these logging statements into .encode_image()
logger.info("Encoding image...")
self.img.encode_image(self.init_image_pil)
logger.info("Encoded Image:")
# pretty sure this assumes we're in a notebook
display.display(self.img.decode_image())

## wrap this for the flexibility that the loop is pretending to provide...
# set up init image prompt
if self.direct_init_weight:
init_aug = LossBuilder(
"direct_init_weight",
self.direct_init_weight,
f"init image ({self.init_image})",
self.img,
self.init_image_pil,
).build_loss()
self.loss_augs.append(init_aug)
self.init_augs.append(init_aug)

########
if self.semantic_init_weight:
self.semantic_init_prompt = parse_prompt(
self.embedder,
f"init image [{self.init_image}]:{self.semantic_init_weight}",
self.init_image_pil,
)
self.prompts[0].append(self.semantic_init_prompt)

# stabilization
def configure_stabilization_augs(self):
d_augs = {
"direct_stabilization_weight": self.direct_stabilization_weight,
"depth_stabilization_weight": self.depth_stabilization_weight,
"edge_stabilization_weight": self.edge_stabilization_weight,
}
stabilization_augs = [
LossBuilder(
k, v, "stabilization", self.img, self.init_image_pil
).build_loss()
for k, v in d_augs.items()
if v
]
self.loss_augs.extend(stabilization_augs)

def configure_optical_flows(self):
optical_flows = None

if self.animation_mode == "Video Source":
if self.flow_stabilization_weight == "":
self.flow_stabilization_weight = "0"
optical_flows = [
OpticalFlowLoss.TargetImage(
f"optical flow stabilization (frame {-2**i}):{self.flow_stabilization_weight}",
self.img.image_shape,
)
for i in range(self.flow_long_term_samples + 1)
]

elif self.animation_mode == "3D" and self.flow_stabilization_weight:
optical_flows = [
TargetFlowLoss.TargetImage(
f"optical flow stabilization:{self.flow_stabilization_weight}",
self.img.image_shape,
)
]

if optical_flows is not None:
for optical_flow in optical_flows:
optical_flow.set_enabled(False)
self.loss_augs.extend(optical_flows)

def configure_aesthetic_losses(self):
if self.smoothing_weight != 0:
self.loss_augs.append(TVLoss(weight=self.smoothing_weight))
27 changes: 0 additions & 27 deletions src/pytti/LossAug/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,30 +15,3 @@

# yeesh the ordering fragility in here...
# TO DO: let's make this a class attribute on something
LOSS_DICT = {"edge": EdgeLoss, "depth": DepthLoss}


# uh.... should the places this is beind used maybe just use Loss.__init__?
# TO DO: let's make this a class attribute on something
def build_loss(weight_name, weight, name, img, pil_target) -> Loss:
"""
Given a weight name, weight, name, image, and target image, returns a loss object
:param weight_name: The name of the loss function
:param weight: The weight of the loss
:param name: The name of the loss function
:param img: The image to be optimized
:param pil_target: The target image
:return: The loss function.
"""

weight_name, suffix = weight_name.split("_", 1)
if weight_name == "direct":
Loss = type(img).get_preferred_loss()
else:
Loss = LOSS_DICT[weight_name]
out = Loss.TargetImage(
f"{weight_name} {name}:{weight}", img.image_shape, pil_target
)
out.set_enabled(pil_target is not None)
return out
Loading

0 comments on commit 1bb70c8

Please sign in to comment.