feat: add diffusion models
minkyu-choi07 committed Sep 16, 2023
1 parent 2c1fb96 commit 80135d9
Showing 142 changed files with 19,837 additions and 0 deletions.
ns_vfs/config/InstructPix2Pix.yaml (108 additions, 0 deletions)
@@ -0,0 +1,108 @@
# File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion).
# See more details in LICENSE.
general:
  resolution: 512
  steps: 100
  checkpoint: /opt/Neuro-Symbolic-Video-Frame-Search/artifacts/weights/instruct-pix2pix-00-22000.ckpt
  vae_ckpt:
  edit: "turn human face into a joker's face"
  cfg_text: 6.5
  cfg_image: 1.5
  seed:

model:
  base_learning_rate: 1.0e-04
  target: ns_vfs.model.diffusion.stable_diffusion.ldm.models.diffusion.ddpm_edit.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: edited
    cond_stage_key: edit
    # image_size: 64
    # image_size: 32
    image_size: 16
    channels: 4
    cond_stage_trainable: false # Note: different from the one we trained before
    conditioning_key: hybrid
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: true
    load_ema: true

    scheduler_config: # 10000 warmup steps
      target: ns_vfs.model.diffusion.stable_diffusion.ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 0 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ns_vfs.model.diffusion.stable_diffusion.ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 8
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ns_vfs.model.diffusion.stable_diffusion.ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ns_vfs.model.diffusion.stable_diffusion.ldm.modules.encoders.modules.FrozenCLIPEmbedder

data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 128
    num_workers: 1
    wrap: false
    validation:
      target: edit_dataset.EditDataset
      params:
        path: data/clip-filtered-dataset
        cache_dir: data/
        cache_name: data_10k
        split: val
        min_text_sim: 0.2
        min_image_sim: 0.75
        min_direction_sim: 0.2
        max_samples_per_prompt: 1
        min_resize_res: 512
        max_resize_res: 512
        crop_res: 512
        output_as_edit: False
        real_input: True
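
Aside: each target/params pair in this config is resolved by the ldm-style instantiate_from_config helper imported in pix2pix.py below. A simplified sketch of that resolution pattern (an illustration only, not the exact helper shipped in ns_vfs):

import importlib

def instantiate_from_config(config):
    # Import the dotted path in "target" and construct it with the "params" mapping.
    module_name, cls_name = config["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), cls_name)
    return cls(**config.get("params", dict()))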
Empty file.
ns_vfs/model/diffusion/_base.py (11 additions, 0 deletions)
@@ -0,0 +1,11 @@
from __future__ import annotations

import abc


class Diffusion(abc.ABC):
    """Abstract base class for diffusion models."""

    @abc.abstractmethod
    def diffuse(self, input: any):
        """Diffuse the input."""
ns_vfs/model/diffusion/pix2pix.py (131 additions, 0 deletions)
@@ -0,0 +1,131 @@
import math
import random

import einops
import k_diffusion as K
import numpy as np
import torch
from einops import rearrange
from omegaconf import OmegaConf
from PIL import Image, ImageOps
from torch import autocast, nn

from ns_vfs.model.diffusion.stable_diffusion.ldm.util import instantiate_from_config

from ._base import Diffusion


class PixToPix(Diffusion):
    def __init__(self, config: OmegaConf):
        self._config = config.general
        self._model_config = config.model
        self._model = self.load_model_from_config(
            self._config, self._config.checkpoint, self._config.vae_ckpt
        )
        self._model.eval().cuda()
        # Wrap the latent diffusion model for k-diffusion sampling with classifier-free guidance.
        self._model_wrap = K.external.CompVisDenoiser(self._model)
        self._model_wrap_cfg = CFGDenoiser(self._model_wrap)
        self._null_token = self._model.get_learned_conditioning([""])
        self._seed = random.randint(0, 100000) if self._config.seed is None else self._config.seed

    def image_process(self, image):
        """Resize so the long side is near the configured resolution and both sides are multiples of 64."""
        if isinstance(image, np.ndarray):
            input_image = Image.fromarray(image)
        else:
            input_image = Image.open(image).convert("RGB")
        width, height = input_image.size
        factor = self._config.resolution / max(width, height)
        factor = math.ceil(min(width, height) * factor / 64) * 64 / min(width, height)
        width = int((width * factor) // 64) * 64
        height = int((height * factor) // 64) * 64
        return ImageOps.fit(input_image, (width, height), method=Image.Resampling.LANCZOS)

    def diffuse(self, input: any):
        input_image = self.image_process(input)
        with torch.no_grad(), autocast("cuda"), self._model.ema_scope():
            # Conditioning: the edit instruction (cross-attention) and the encoded input image (concat).
            cond = {}
            cond["c_crossattn"] = [self._model.get_learned_conditioning([self._config.edit])]
            input_image = 2 * torch.tensor(np.array(input_image)).float() / 255 - 1
            input_image = rearrange(input_image, "h w c -> 1 c h w").to(self._model.device)
            cond["c_concat"] = [self._model.encode_first_stage(input_image).mode()]

            uncond = {}
            uncond["c_crossattn"] = [self._null_token]
            uncond["c_concat"] = [torch.zeros_like(cond["c_concat"][0])]

            sigmas = self._model_wrap.get_sigmas(self._config.steps)

            extra_args = {
                "cond": cond,
                "uncond": uncond,
                "text_cfg_scale": self._config.cfg_text,
                "image_cfg_scale": self._config.cfg_image,
            }
            torch.manual_seed(self._seed)
            z = torch.randn_like(cond["c_concat"][0]) * sigmas[0]
            z = K.sampling.sample_euler_ancestral(
                self._model_wrap_cfg, z, sigmas, extra_args=extra_args
            )
            x = self._model.decode_first_stage(z)
            x = torch.clamp((x + 1.0) / 2.0, min=0.0, max=1.0)
            x = 255.0 * rearrange(x, "1 c h w -> h w c")
            edited_img = Image.fromarray(x.type(torch.uint8).cpu().numpy())
        edited_img.save("output.jpg")
        return edited_img

    def load_model_from_config(self, config, ckpt, vae_ckpt=None, verbose=False):
        print(f"Loading model from {ckpt}")
        pl_sd = torch.load(ckpt, map_location="cpu")
        if "global_step" in pl_sd:
            print(f"Global Step: {pl_sd['global_step']}")
        sd = pl_sd["state_dict"]
        if vae_ckpt is not None:
            # Optionally swap in a separately trained VAE for the first-stage autoencoder weights.
            print(f"Loading VAE from {vae_ckpt}")
            vae_sd = torch.load(vae_ckpt, map_location="cpu")["state_dict"]
            sd = {
                k: vae_sd[k[len("first_stage_model.") :]]
                if k.startswith("first_stage_model.")
                else v
                for k, v in sd.items()
            }
        model = instantiate_from_config(self._model_config)
        m, u = model.load_state_dict(sd, strict=False)
        if len(m) > 0 and verbose:
            print("missing keys:")
            print(m)
        if len(u) > 0 and verbose:
            print("unexpected keys:")
            print(u)
        return model


class CFGDenoiser(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.inner_model = model

    def forward(self, z, sigma, cond, uncond, text_cfg_scale, image_cfg_scale):
        # Batch the latent three ways: text+image conditioned, image-only conditioned, and
        # unconditioned, then combine the three predictions with the two guidance scales.
        cfg_z = einops.repeat(z, "1 ... -> n ...", n=3)
        cfg_sigma = einops.repeat(sigma, "1 ... -> n ...", n=3)
        cfg_cond = {
            "c_crossattn": [
                torch.cat(
                    [
                        cond["c_crossattn"][0],
                        uncond["c_crossattn"][0],
                        uncond["c_crossattn"][0],
                    ]
                )
            ],
            "c_concat": [
                torch.cat([cond["c_concat"][0], cond["c_concat"][0], uncond["c_concat"][0]])
            ],
        }
        out_cond, out_img_cond, out_uncond = self.inner_model(
            cfg_z, cfg_sigma, cond=cfg_cond
        ).chunk(3)
        return (
            out_uncond
            + text_cfg_scale * (out_cond - out_img_cond)
            + image_cfg_scale * (out_img_cond - out_uncond)
        )
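
A rough usage sketch for the class above, assuming a CUDA-capable machine, the checkpoint path listed in InstructPix2Pix.yaml, and an input frame on disk (the file name example_frame.jpg is illustrative):

from omegaconf import OmegaConf

from ns_vfs.model.diffusion.pix2pix import PixToPix

config = OmegaConf.load("ns_vfs/config/InstructPix2Pix.yaml")  # config file added in this commit
model = PixToPix(config)                      # loads the InstructPix2Pix checkpoint onto the GPU
edited = model.diffuse("example_frame.jpg")   # applies config.general.edit; also saves output.jpg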