Logger & trainer tests (#5)
* Add tensorboard logger

* Minor

* Update affine coupling

* Fix flow block
Add autoflake

* Minor fixes

* Add tests for trainer
kashperova authored Dec 21, 2024
1 parent fc55341 commit d28101a
Showing 27 changed files with 500 additions and 75 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -1 +1,4 @@
todo.py
./runs
./samples
./.misc/notebooks
5 changes: 5 additions & 0 deletions .misc/notes.md
@@ -0,0 +1,5 @@
- Normalizing flows can't work with discrete random variables, so we need to dequantize input image tensors.
The simplest solution is [implemented](../src/modules/utils/tensors.py) here: add a small amount of noise to each discrete value.
In general, though, it is better to use <a href="https://arxiv.org/abs/1902.00275">variational dequantization</a>.
- Read more about <a href="https://arxiv.org/abs/1605.08803v3">KL duality</a>.
- The Jacobian determinant can be interpreted as a measure of how the transformation changes the volume of the probability space (see the identity below).
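For reference, the change-of-variables identity behind both of these notes (standard normalizing-flow background, not part of this commit): dequantization makes the data density well-defined, and the Jacobian term below is exactly the volume change.

$$
\log p_X(x) = \log p_Z\bigl(f(x)\bigr) + \log \left|\det \frac{\partial f(x)}{\partial x}\right|
$$

Each invertible block in the model (affine coupling, invertible 1x1 convolution) contributes its own `log_det` to the second term.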
10 changes: 10 additions & 0 deletions .pre-commit-config.yaml
@@ -23,6 +23,16 @@ repos:
hooks:
- id: isort
args: [ "--profile", "black", "--filter-files" ]
- repo: https://github.com/PyCQA/autoflake
rev: v2.3.1
hooks:
- id: autoflake
args:
- '--remove-all-unused-imports'
- '--remove-unused-variables'
- '--exclude=__init__.py'
- '--in-place'
- '--recursive'
- repo: https://github.com/PyCQA/flake8
rev: 7.1.1
hooks:
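Assuming `pre-commit` itself is installed, the new hook can be exercised over the whole repository with `pre-commit run autoflake --all-files`; the `--in-place` and `--recursive` flags make autoflake rewrite files directly instead of only reporting unused imports and variables.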
197 changes: 196 additions & 1 deletion poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -14,6 +14,7 @@ hydra-core = "1.3.2"
tqdm = "4.66.5"
torchvision = "0.19.0"
natsort = "8.4.0"
tensorboard = "2.18.0"


[build-system]
10 changes: 6 additions & 4 deletions src/configs/celeba.yaml
@@ -5,7 +5,7 @@ defaults:
- _self_
optimizer:
_target_: torch.optim.Adam
lr: 2e-4
lr: 1e-4
lr_scheduler:
_target_: torch.optim.lr_scheduler.ReduceLROnPlateau
mode: min
@@ -15,16 +15,18 @@ lr_scheduler:
loss_func:
_target_: modules.utils.losses.GlowLoss
trainer:
run_name: baseline
n_bins: 32
n_epochs: 2
train_test_split: 0.85
train_batch_size: 16
test_batch_size: 16
image_size: 64
log_steps: 10
sampling_iters: 30
log_steps: 50
log_dir: ./runs
sampling_steps: 50
n_samples: 10
samples_dir: samples
save_dir: glow
save_dir: ./glow
seed: 42
use_ddp: false
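The `_target_` keys follow Hydra's instantiation convention. A minimal sketch of how such a block becomes a live object (illustrative values and a throwaway parameter list, not the repo's actual entry point):

```python
import torch
from hydra.utils import instantiate
from omegaconf import OmegaConf

# Mirrors the optimizer block above with a stand-in parameter list.
cfg = OmegaConf.create({"optimizer": {"_target_": "torch.optim.Adam", "lr": 1e-4}})
params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = instantiate(cfg.optimizer, params=params)
print(type(optimizer).__name__)  # Adam
```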
Empty file removed src/configs/model.yaml
19 changes: 10 additions & 9 deletions src/model/affine_coupling/affine_coupling.py
@@ -1,5 +1,6 @@
import torch
from torch import Tensor
from torch.nn import functional as F

from model.affine_coupling.net import NN
from model.invert_block import InvertBlock
@@ -23,22 +23,22 @@ class AffineCoupling(InvertBlock):
"""

def __init__(self, in_ch: int, hidden_ch: int):
super(AffineCoupling, self).__init__()
super().__init__()
self.net = NN(in_ch, hidden_ch)

def forward(self, x: Tensor) -> tuple[Tensor, Tensor]:
x_a, x_b = x.chunk(2, dim=1)
log_s, t = self.net(x_b)
s = torch.exp(log_s)
log_s, t = self.net(x_a)
s = F.sigmoid(log_s + 2)
log_det = torch.sum(torch.log(s).view(x.shape[0], -1), 1)
y_a = x_a * s + t
y_b = x_b
y_b = (x_b + t) * s
y_a = x_a
return torch.concat([y_a, y_b], dim=1), log_det

def reverse(self, y: Tensor) -> Tensor:
y_a, y_b = y.chunk(2, dim=1)
log_s, t = self.net(y_b)
s = torch.exp(log_s)
x_a = (y_a - t) / s
x_b = y_b
log_s, t = self.net(y_a)
s = F.sigmoid(log_s + 2)
x_b = y_b / s - t
x_a = y_a
return torch.concat([x_a, x_b], dim=1)
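A quick self-consistency check for the reworked coupling (a sketch with illustrative shapes, assuming `NN` splits its output into `log_s` and `t` as the forward pass implies): since `x_b = ((x_b + t) * s) / s - t`, reversing the forward pass should reproduce the input up to floating-point error.

```python
import torch

from model.affine_coupling.affine_coupling import AffineCoupling

coupling = AffineCoupling(in_ch=8, hidden_ch=64)  # channel counts are illustrative
x = torch.randn(2, 8, 16, 16)                     # (batch, channels, height, width)

y, log_det = coupling(x)
x_rec = coupling.reverse(y)

assert torch.allclose(x, x_rec, atol=1e-5)  # reverse undoes forward
assert log_det.shape == (2,)                # one log-det per batch element
```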
8 changes: 3 additions & 5 deletions src/model/affine_coupling/net.py
@@ -19,13 +19,11 @@ class NN(nn.Module):
"""

def __init__(self, in_ch: int, hidden_ch: int):
super(NN, self).__init__()
conv1 = nn.Conv2d(in_ch // 2, hidden_ch, 3, padding=1)
conv2 = nn.Conv2d(hidden_ch, hidden_ch, 1)
super().__init__()
self.net = nn.Sequential(
conv1,
nn.Conv2d(in_ch // 2, hidden_ch, 3, padding=1),
nn.ReLU(inplace=True),
conv2,
nn.Conv2d(hidden_ch, hidden_ch, 1),
nn.ReLU(inplace=True),
ZeroConv2d(hidden_ch, in_ch),
)
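Assuming `ZeroConv2d` is the usual Glow-style zero-initialized convolution, the refactored net still outputs zeros at initialization, so the coupling starts mild and well-conditioned: `s = sigmoid(0 + 2) ≈ 0.88` and `t = 0`, which keeps early training stable.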
2 changes: 1 addition & 1 deletion src/model/flow_block.py
@@ -80,7 +80,7 @@ def forward(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]:
# split the output into 2 parts
out, z_new = out.chunk(2, dim=1)
log_p = self.__get_prob_density(
prior_out=out, out=out, batch_size=batch_size
prior_out=out, out=z_new, batch_size=batch_size
)
else:
# for the last level prior distribution
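This one-line fix changes what the likelihood actually measures: the prior density must be evaluated on the latent half `z_new` that is factored out at this level, not on the half `out` that continues through the next block; before the change, the split-off latents never contributed to `log_p`.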
6 changes: 4 additions & 2 deletions src/model/glow.py
@@ -1,3 +1,5 @@
from copy import deepcopy

from torch import Tensor, nn

from model.flow_block import FlowBlock
@@ -30,8 +32,8 @@ def __init__(
coupling_hidden_ch: int = 512,
squeeze_factor: int = 2,
):
super(Glow, self).__init__()
self.in_ch = in_ch
super().__init__()
self.in_ch = deepcopy(in_ch)
self.n_flows = n_flows
self.num_blocks = num_blocks
self.squeeze_factor = squeeze_factor
4 changes: 4 additions & 0 deletions src/model/invert_conv.py
@@ -13,6 +13,10 @@ class InvertConv(InvertBlock):
so the determinant calculation has linear
rather than cubic complexity.
fixing P restricts the class of transformations we can express;
ensuring that the elements on the L & U diagonals stay positive
guarantees that the weight matrix is invertible
attrs (trainable)
----------
ut_matrix: nn.Parameter
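An illustration of the PLU trick the docstring describes (a standalone sketch, not the module's code): because P is a permutation with `|det| = 1` and L is unit lower-triangular with `det = 1`, the log-determinant of the 1x1 convolution kernel collapses to a sum over U's diagonal, replacing an O(c^3) computation with an O(c) one.

```python
import torch

c = 8
w = torch.linalg.qr(torch.randn(c, c)).Q  # random rotation, as in Glow's init
p, l, u = torch.linalg.lu(w)              # W = P @ L @ U

# |det P| = 1 and det L = 1, so |det W| = prod(|diag(U)|).
fast = torch.log(torch.abs(torch.diagonal(u))).sum()
full = torch.slogdet(w).logabsdet

assert torch.allclose(fast, full, atol=1e-5)
```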
1 change: 1 addition & 0 deletions src/modules/logger/__init__.py
@@ -0,0 +1 @@
from modules.logger.logger import TensorboardLogger
30 changes: 30 additions & 0 deletions src/modules/logger/logger.py
@@ -0,0 +1,30 @@
import os

from torch import Tensor
from torch.utils.tensorboard import SummaryWriter


class TensorboardLogger:
def __init__(self, log_dir: str, run_name: str, log_steps: int):
log_dir = os.path.join(log_dir, run_name)

if not os.path.exists(log_dir):
os.makedirs(log_dir)

self.log_dir = log_dir
self.log_steps = log_steps
self.writer = SummaryWriter(log_dir=log_dir)

def __del__(self):
self.writer.flush()
self.writer.close()

def log_train_loss(self, loss: float, step: int):
if step % self.log_steps == 0:
self.writer.add_scalar("Loss/train", loss, step)

def log_test_loss(self, loss: float, epoch: int):
self.writer.add_scalar("Loss/test", loss, epoch)

def log_images(self, grid: Tensor, step: int):
self.writer.add_image(tag="samples", img_tensor=grid, global_step=step)
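Hypothetical usage of the new logger; `log_train_loss` writes only when `step` is a multiple of `log_steps`, while the test loss is recorded unconditionally once per epoch:

```python
logger = TensorboardLogger(log_dir="./runs", run_name="baseline", log_steps=50)

logger.log_train_loss(loss=2.31, step=50)  # written: 50 % 50 == 0
logger.log_train_loss(loss=2.27, step=51)  # skipped: 51 % 50 != 0
logger.log_test_loss(loss=2.40, epoch=1)   # always written
```

One design note: relying on `__del__` to flush ties cleanup to garbage collection; an explicit `close()` at the end of training would be more deterministic.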
Empty file removed src/modules/trainer/ddp.py
5 changes: 2 additions & 3 deletions src/modules/trainer/ddp_trainer.py
@@ -65,6 +65,5 @@ def train(self):

for i in tqdm(range(self.train_config.epochs)):
self.ddp.set_train_epoch(i)
train_loss = self.train_epoch()
test_loss = self.test_epoch()
print(f"Train loss: {train_loss}, Test loss: {test_loss}", flush=True)
self.train_epoch()
self.test_epoch()
72 changes: 51 additions & 21 deletions src/modules/trainer/trainer.py
@@ -1,8 +1,10 @@
import logging
import os
from typing import Callable

import torch
from omegaconf import DictConfig
from PIL import Image
from torch import nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
@@ -11,10 +13,14 @@
from tqdm import tqdm

from model.glow import Glow
from modules.logger import TensorboardLogger
from modules.utils.sampling import get_z_list
from modules.utils.tensors import dequantize
from modules.utils.train import SizedDataset, train_test_split

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class Trainer:
def __init__(
@@ -41,6 +47,11 @@ def __init__(
)
self.train_loader = None
self.test_loader = None
self.logger = TensorboardLogger(
log_dir=self.train_config.log_dir,
run_name=self.train_config.run_name,
log_steps=self.train_config.log_steps,
)

self.z_list = get_z_list(
glow=self.model,
@@ -51,30 +62,34 @@ def __init__(

self.z_list = [z_i.to(self.device) for z_i in self.z_list]

def train_epoch(self) -> float:
def train_epoch(self, epoch: int) -> float:
self.model.train()
run_train_loss = 0.0
run_train_loss, n_iters = 0.0, 0
for i, images in enumerate(self.train_loader):
images = dequantize(images)
images = dequantize(images, n_bins=self.train_config.n_bins)
images = images.to(self.device)
self.optimizer.zero_grad()
outputs = self.model(images)
loss = self.loss_func(outputs, images)
loss.backward()
self.optimizer.step()
run_train_loss += loss.item()
n_iters += 1

if i % self.train_config.sampling_iters == 0:
self.save_sample(label=f"iter_{i}")
if i % self.train_config.sampling_steps == 0 and i != 0:
self.log_samples(step=epoch + i)
avg_loss = run_train_loss / n_iters
self.logger.log_train_loss(loss=avg_loss, step=epoch + i)
logger.info(f"Train avg loss: {avg_loss}")

return run_train_loss

@torch.inference_mode()
def test_epoch(self) -> float:
self.model.eval()
run_test_loss = 0.0
for images, _ in self.test_loader:
images = dequantize(images)
for images in self.test_loader:
images = dequantize(images, self.train_config.n_bins)
images = images.to(self.device)
outputs = self.model(images)
run_test_loss += self.loss_func(outputs, images).item()
@@ -97,20 +112,28 @@ def train(self):
self.model = nn.DataParallel(self.model).to(self.device)

with torch.no_grad():
images = dequantize(next(iter(self.test_loader)))
images = dequantize(
next(iter(self.test_loader)), n_bins=self.train_config.n_bins
)
images = images.to(self.device)
self.model.module(images)

for i in tqdm(range(self.train_config.n_epochs)):
train_loss = self.train_epoch()
train_loss = self.train_epoch(1)
train_loss /= len(self.train_dataset)

test_loss = self.test_epoch()
test_loss /= len(self.test_dataset)

self.logger.log_test_loss(loss=test_loss, epoch=i + 1)
self.lr_scheduler.step(test_loss)

self.save_checkpoint(epoch=i)

def save_checkpoint(self, epoch: int):
if not os.path.exists(self.train_config.save_dir):
os.makedirs(self.train_config.save_dir, exist_ok=True)

torch.save(
self.model.state_dict(), f"{self.train_config.save_dir}/model_{epoch}.pt"
)
@@ -120,15 +143,22 @@ def save_checkpoint(self, epoch: int):
)

@torch.inference_mode()
def save_sample(self, label: str):
# todo: change to logging (tensorboard)
if not os.path.exists(self.train_config.samples_dir):
os.makedirs(self.train_config.samples_dir)

utils.save_image(
self.model.module.reverse(self.z_list).cpu().data,
f"{self.train_config.samples_dir}/{label}.png",
normalize=True,
nrow=10,
value_range=(-0.5, 0.5),
)
def log_samples(self, step: int, save_png: bool = True):
data = self.model.module.reverse(self.z_list).cpu().data
grid = utils.make_grid(data, nrow=5, normalize=True, value_range=(-0.5, 0.5))
self.logger.log_images(grid=grid, step=step)

if save_png:
if not os.path.exists(self.train_config.samples_dir):
os.makedirs(self.train_config.samples_dir)

np_array = (
grid.mul(255)
.add_(0.5)
.clamp_(0, 255)
.permute(1, 2, 0)
.to("cpu", torch.uint8)
.numpy()
)
im = Image.fromarray(np_array)
im.save(f"{self.train_config.samples_dir}/{step}.png")
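The `mul(255).add_(0.5).clamp_(0, 255)` chain rounds each normalized pixel to the nearest integer before the `uint8` cast; it mirrors what `torchvision.utils.save_image` does internally, so the saved PNGs keep the same pixel values as the old `save_sample` output.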
7 changes: 7 additions & 0 deletions src/modules/utils/tensors.py
@@ -1,3 +1,5 @@
import math

import torch
from torch import Tensor

@@ -24,6 +26,11 @@ def reverse_squeeze(x: Tensor, factor: int = 2) -> Tensor:

def dequantize(x: Tensor, n_bins: int = 256) -> Tensor:
x = x * 255
n_bits = math.log(n_bins, 2)

if n_bits < 8:
x = torch.floor(x / 2 ** (8 - n_bits))

x = x / n_bins - 0.5
x = x + torch.rand_like(x) / n_bins
return x
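A worked example of the updated function at the config's `n_bins: 32` (a sketch; it assumes `src` is on the import path): pixels are mapped to 256 levels, floored down to 32 levels (5 bits), recentered around zero, and jittered with uniform noise of width 1/32.

```python
import torch

from modules.utils.tensors import dequantize

x = torch.tensor([[0.0, 0.5, 1.0]])  # pixels already scaled to [0, 1]
out = dequantize(x, n_bins=32)

# x * 255        -> [0, 255]
# floor(x / 8)   -> {0, ..., 31}, since 2 ** (8 - 5) == 8
# x / 32 - 0.5   -> [-0.5, 0.46875]
# + U(0, 1) / 32 -> values land in [-0.5, 0.5)
assert out.min() >= -0.5 and out.max() < 0.5
```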
1 change: 1 addition & 0 deletions tests/conftest.py
@@ -2,4 +2,5 @@
"fixtures.blocks",
"fixtures.config",
"fixtures.inputs",
"fixtures.trainer",
]
6 changes: 3 additions & 3 deletions tests/fixtures/blocks.py
@@ -22,14 +22,14 @@ def invert_conv():
@pytest.fixture(scope="function")
def affine_coupling():
return AffineCoupling(
in_ch=TestConfig.in_ch, hidden_ch=TestConfig.coupling_hidden_ch
in_ch=TestConfig.in_ch * 2, hidden_ch=TestConfig.coupling_hidden_ch
)


@pytest.fixture(scope="function")
def flow():
return Flow(
in_ch=TestConfig.in_ch, coupling_hidden_ch=TestConfig.coupling_hidden_ch
in_ch=TestConfig.in_ch * 2, coupling_hidden_ch=TestConfig.coupling_hidden_ch
)


@@ -54,7 +54,7 @@ def last_flow_block():
)


@pytest.fixture(scope="function")
@pytest.fixture(scope="session")
def glow():
return Glow(
in_ch=TestConfig.in_ch,
