Stable Diffusion VAE fine-tuning (backport AutoencoderKL and its config.yaml to taming-transformers) #222

Open · wants to merge 3 commits into master
43 changes: 43 additions & 0 deletions configs/finetune_vae.yaml
@@ -0,0 +1,43 @@
model:
  base_learning_rate: 4.5e-6
  target: taming.models.vqgan.AutoencoderKL
  params:
    embed_dim: 4
    ckpt_path: "path/to/some/vae"
    ddconfig:
      double_z: True
      z_channels: 4
      resolution: 256
      in_channels: 3
      out_ch: 3
      ch: 128
      ch_mult: [1,2,4,4]  # num_down = len(ch_mult)-1
      num_res_blocks: 2
      attn_resolutions: []
      dropout: 0.0

    lossconfig:
      target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
      params:
        disc_conditional: False
        disc_in_channels: 3
        disc_start: 10000
        disc_weight: 0.8
        #codebook_weight: 1.0

data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 1
    num_workers: 2
    train:
      target: taming.data.custom.CustomTrain
      params:
        training_images_list_file: train_img.txt
        size: 256
    validation:
      target: taming.data.custom.CustomTest
      params:
        test_images_list_file: val_img.txt
        size: 256
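
For reference, the file follows the repo's usual target/params layout, so it should be consumable like the existing VQGAN configs (e.g. via python main.py --base configs/finetune_vae.yaml -t True --gpus 0, mirroring the README's training commands). A minimal sketch of resolving it programmatically, assuming the PR is merged and ckpt_path points at a real VAE checkpoint:

# Sketch only: resolve the target/params pairs in finetune_vae.yaml.
from omegaconf import OmegaConf
from main import instantiate_from_config  # same helper taming/models/vqgan.py already uses

config = OmegaConf.load("configs/finetune_vae.yaml")
vae = instantiate_from_config(config.model)    # -> taming.models.vqgan.AutoencoderKL
data = instantiate_from_config(config.data)    # -> main.DataModuleFromConfig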

238 changes: 238 additions & 0 deletions taming/models/vqgan.py
@@ -9,6 +9,244 @@
from taming.modules.vqvae.quantize import GumbelQuantize
from taming.modules.vqvae.quantize import EMAVectorQuantizer


class DiagonalGaussianDistribution(object):
    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)

    def sample(self):
        x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
        return x

    def kl(self, other=None):
        if self.deterministic:
            return torch.Tensor([0.])
        else:
            if other is None:
                # KL(N(mean, var) || N(0, 1)) = 0.5 * sum(mean^2 + var - 1 - logvar)
                return 0.5 * torch.sum(torch.pow(self.mean, 2)
                                       + self.var - 1.0 - self.logvar,
                                       dim=[1, 2, 3])
            else:
                return 0.5 * torch.sum(
                    torch.pow(self.mean - other.mean, 2) / other.var
                    + self.var / other.var - 1.0 - self.logvar + other.logvar,
                    dim=[1, 2, 3])

    def nll(self, sample, dims=[1, 2, 3]):
        if self.deterministic:
            return torch.Tensor([0.])
        # note: uses np (numpy), which vqgan.py does not import by default; an
        # `import numpy as np` would be needed at the top of the file if nll() is used
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims)

    def mode(self):
        return self.mean


class AutoencoderKL(pl.LightningModule):
    def __init__(self,
                 ddconfig,
                 lossconfig,
                 embed_dim,
                 ckpt_path=None,
                 ignore_keys=[],
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,
                 train_decoder_only=True,
                 #ema_decay=None,
                 #learn_logvar=False
                 ):
        super().__init__()
        #self.learn_logvar = learn_logvar
        self.train_decoder_only = train_decoder_only
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.loss = instantiate_from_config(lossconfig)
        assert ddconfig["double_z"]
        self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)  # factor 2: mean and variance
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        self.embed_dim = embed_dim
        if colorize_nlabels is not None:
            assert type(colorize_nlabels)==int
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor

        #self.use_ema = ema_decay is not None
        #if self.use_ema:
        #    self.ema_decay = ema_decay
        #    assert 0. < ema_decay < 1.
        #    self.model_ema = LitEma(self, decay=ema_decay)
        #    print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    def init_from_ckpt(self, path, ignore_keys=list()):
        if path.endswith(".safetensors"):
            from safetensors import safe_open
            with safe_open(path, framework="pt", device=0) as f:
                sd = {k: f.get_tensor(k) for k in f.keys()}
        else:
            sd = torch.load(path, map_location="cpu")["state_dict"]

        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        self.load_state_dict(sd, strict=False)
        print(f"Restored from {path}")

    #@contextmanager
    #def ema_scope(self, context=None):
    #    if self.use_ema:
    #        self.model_ema.store(self.parameters())
    #        self.model_ema.copy_to(self)
    #        if context is not None:
    #            print(f"{context}: Switched to EMA weights")
    #    try:
    #        yield None
    #    finally:
    #        if self.use_ema:
    #            self.model_ema.restore(self.parameters())
    #            if context is not None:
    #                print(f"{context}: Restored training weights")

    def on_train_batch_end(self, *args, **kwargs):
        #if self.use_ema:
        #    self.model_ema(self)
        pass

    def encode(self, x):
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z):
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        return dec

    def forward(self, input, sample_posterior=True):
        posterior = self.encode(input)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        dec = self.decode(z)
        return dec, posterior

    def get_input(self, batch, k):
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
        return x

    def training_step(self, batch, batch_idx, optimizer_idx):
        inputs = self.get_input(batch, self.image_key)
        reconstructions, posterior = self(inputs)

        if optimizer_idx == 0:
            # train encoder+decoder+logvar
            # the first argument is a zero codebook-loss placeholder so the reused
            # VQLPIPSWithDiscriminator signature (codebook_loss, inputs, ...) is satisfied
            aeloss, log_dict_ae = self.loss(torch.zeros_like(inputs), inputs, reconstructions, optimizer_idx, self.global_step,
                                            last_layer=self.get_last_layer(), split="train")
            self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
            return aeloss

        if optimizer_idx == 1:
            # train the discriminator
            discloss, log_dict_disc = self.loss(torch.zeros_like(inputs), inputs, reconstructions, optimizer_idx, self.global_step,
                                                last_layer=self.get_last_layer(), split="train")

            self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
            return discloss

    def validation_step(self, batch, batch_idx):
        log_dict = self._validation_step(batch, batch_idx)
        #with self.ema_scope():
        #    log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
        return log_dict

    def _validation_step(self, batch, batch_idx, postfix=""):
        inputs = self.get_input(batch, self.image_key)
        reconstructions, posterior = self(inputs)
        aeloss, log_dict_ae = self.loss(torch.zeros_like(inputs), inputs, reconstructions, 0, self.global_step,
                                        last_layer=self.get_last_layer(), split="val"+postfix)

        discloss, log_dict_disc = self.loss(torch.zeros_like(inputs), inputs, reconstructions, 1, self.global_step,
                                            last_layer=self.get_last_layer(), split="val"+postfix)

        self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        return self.log_dict

    def configure_optimizers(self):
        lr = self.learning_rate
        # by default only the decoder (and post_quant_conv) are optimized; set
        # train_decoder_only: False in the config to also fine-tune the encoder
        ae_params_list = list(self.decoder.parameters()) + list(self.post_quant_conv.parameters())
        if not self.train_decoder_only:
            ae_params_list += list(self.encoder.parameters()) + list(self.quant_conv.parameters())
        #if self.learn_logvar:
        #    print(f"{self.__class__.__name__}: Learning logvar")
        #    ae_params_list.append(self.loss.logvar)
        opt_ae = torch.optim.Adam(ae_params_list,
                                  lr=lr, betas=(0.5, 0.9))
        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
                                    lr=lr, betas=(0.5, 0.9))
        return [opt_ae, opt_disc], []

    def get_last_layer(self):
        return self.decoder.conv_out.weight

    @torch.no_grad()
    def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        if not only_inputs:
            xrec, posterior = self(x)
            #if x.shape[1] > 3:
            #    # colorize with random projection
            #    assert xrec.shape[1] > 3
            #    x = self.to_rgb(x)
            #    xrec = self.to_rgb(xrec)
            log["samples"] = self.decode(torch.randn_like(posterior.sample()))
            log["reconstructions"] = xrec
            #if log_ema or self.use_ema:
            #    with self.ema_scope():
            #        xrec_ema, posterior_ema = self(x)
            #        if x.shape[1] > 3:
            #            # colorize with random projection
            #            assert xrec_ema.shape[1] > 3
            #            xrec_ema = self.to_rgb(xrec_ema)
            #        log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample()))
            #        log["reconstructions_ema"] = xrec_ema
        log["inputs"] = x
        return log

    #def to_rgb(self, x):
    #    assert self.image_key == "segmentation"
    #    if not hasattr(self, "colorize"):
    #        self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
    #    x = F.conv2d(x, weight=self.colorize)
    #    x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
    #    return x

class VQModel(pl.LightningModule):
    def __init__(self,
                 ddconfig,
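
As a quick sanity check of what the backported class exposes, the round trip below sketches encode, sample, and decode; the shapes follow from finetune_vae.yaml (256 px inputs, ch_mult [1,2,4,4] gives three downsampling steps, embed_dim 4 latent channels). The checkpoint and input tensor are placeholders, not part of this PR.

import torch
from omegaconf import OmegaConf
from main import instantiate_from_config

# Sketch only: assumes ckpt_path in finetune_vae.yaml points at a real VAE checkpoint.
config = OmegaConf.load("configs/finetune_vae.yaml")
vae = instantiate_from_config(config.model).eval()

x = torch.randn(1, 3, 256, 256)      # stand-in for an image batch (NCHW, values in [-1, 1])
with torch.no_grad():
    posterior = vae.encode(x)        # DiagonalGaussianDistribution over the latents
    z = posterior.sample()           # (1, 4, 32, 32): 256 / 2**3 = 32 spatial, embed_dim = 4 channels
    rec = vae.decode(z)              # reconstruction with the same shape as x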