Add PyTorch implementation for Sea-Pix-GAN #5

Open · wants to merge 36 commits into base: master

Commits
f98b5ae  add new readme (w4a2y4, Mar 16, 2024)
cb48577  rm tensorflow (w4a2y4, Mar 16, 2024)
506e352  add branch rule (w4a2y4, Mar 16, 2024)
397e1c9  propose dir structure (w4a2y4, Mar 16, 2024)
a046294  add layer structures for generator (w4a2y4, Mar 16, 2024)
f22bad9  add file structure to README (w4a2y4, Mar 16, 2024)
c594f6b  ignore .pyc files (w4a2y4, Mar 16, 2024)
48069e5  finish generator impl (w4a2y4, Mar 27, 2024)
4a87838  ignore pycache (w4a2y4, Mar 27, 2024)
defc44f  Implemented correct loss functions (I think) into train based on ugan (ejerez, Mar 27, 2024)
8de8ba5  Merge branch 'sea-pix-gan' into generator-charlotte (w4a2y4, Mar 28, 2024)
76c8476  add discriminator (ZIMUQIN-L, Mar 28, 2024)
07f2ffe  add discriminator (ZIMUQIN-L, Mar 28, 2024)
d188692  Merge pull request #2 from w4a2y4/discriminator-peijieli (ejerez, Mar 30, 2024)
5b5f84b  Merge branch 'sea-pix-gan' into generator-charlotte (ejerez, Mar 30, 2024)
811e776  Merge pull request #1 from w4a2y4/generator-charlotte (ejerez, Mar 30, 2024)
255de19  rm unused vars (w4a2y4, Mar 30, 2024)
a80fd69  finish generator training script (w4a2y4, Mar 30, 2024)
6051bc0  Merge pull request #3 from w4a2y4/generator-train (ZIMUQIN-L, Mar 31, 2024)
299eac7  add discriminator train (ZIMUQIN-L, Mar 31, 2024)
3d7367e  Merge pull request #4 from w4a2y4/discriminator_train (w4a2y4, Mar 31, 2024)
8b884fa  address comment (w4a2y4, Mar 31, 2024)
348874d  fix generator typo (ZIMUQIN-L, Apr 1, 2024)
689ea38  add qinxin's test (w4a2y4, Apr 3, 2024)
5eb01e4  try add sea-pix-gan nb (w4a2y4, Apr 3, 2024)
4dc97b4  merge networks file to avoid relative import (w4a2y4, Apr 3, 2024)
845347b  use same downlayer for G & D (w4a2y4, Apr 3, 2024)
a226e94  rm nb (w4a2y4, Apr 3, 2024)
85afeee  Revert "use same downlayer for G & D" (w4a2y4, Apr 3, 2024)
3253595  fix runtime errors (w4a2y4, Apr 3, 2024)
479e5a1  add logs (w4a2y4, Apr 6, 2024)
85f77cb  clean training code (w4a2y4, Apr 8, 2024)
7955296  rm log (w4a2y4, Apr 8, 2024)
08f5881  revert readmes (w4a2y4, Apr 8, 2024)
68e91da  revert data & TF (w4a2y4, Apr 8, 2024)
540dcfc  add test (Apr 8, 2024)
1 change: 1 addition & 0 deletions .gitignore
@@ -0,0 +1 @@
+**.pyc
4 changes: 2 additions & 2 deletions Evaluation/uqim_utils.py
mode change 100755 → 100644
@@ -61,8 +61,8 @@ def eme(x, window_size):
     x.shape[1] = width
     """
     # if 4 blocks, then 2x2...etc.
-    k1 = x.shape[1]/window_size
-    k2 = x.shape[0]/window_size
+    k1 = int(x.shape[1]/window_size)
+    k2 = int(x.shape[0]/window_size)
     # weight
     w = 2./(k1*k2)
     blocksize_x = window_size
3 changes: 2 additions & 1 deletion PyTorch/README.md
@@ -2,11 +2,12 @@
 ### Resources
 - Implementation of **FUnIE-GAN** (paired)
 - Simplified implementations of **UGAN / UGAN-P** ([original repo](https://github.com/cameronfabbri/Underwater-Color-Correction))
+- Implementation of **Sea-Pix-GAN**
 - Implementation: PyTorch 1.6 (Python 3.8)

 ### Usage
 - Download the data, setup data-paths in the [config files](configs)
-- Use the training scripts for paired training of FUnIE-GAN or UGAN/UGAN-P
+- Use the training scripts for paired training of FUnIE-GAN, UGAN/UGAN-P, or Sea-Pix-GAN
 - Use the [test.py](test.py) script for evaluation
 - A sample model is provided in [models](models)
 - *Note that the [TF-Keras implementation](/TF-Keras/) is the official one; use those weights to reproduce results in the paper*
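For reference, a minimal Sea-Pix-GAN training launch might look like this (an editorial sketch, not part of this diff; the flags mirror the argparse defaults in train_seapixgan.py):

```bash
cd PyTorch
python train_seapixgan.py --cfg_file configs/train_euvp.yaml --num_epochs 150 --batch_size 64 --lr 0.0002
```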
122 changes: 122 additions & 0 deletions PyTorch/nets/seapixgan.py
@@ -0,0 +1,122 @@
"""
> Network architecture of Sea-pix-GAN model
* Original paper: https://doi.org/10.1016/j.jvcir.2023.104021
"""
import torch
import torch.nn as nn
import torch.nn.functional as F

class SeaPixGAN_Nets:
def __init__(self, base_model='pix2pix'):
if base_model=='pix2pix': # default
self.netG = GeneratorSeaPixGan()
self.netD = DiscriminatorSeaPixGan()
elif base_model=='resnet':
#TODO: add ResNet support
raise NotImplementedError()
else:
raise NotImplementedError()

class GeneratorSeaPixGan(nn.Module):
def __init__(self):
super(GeneratorSeaPixGan, self).__init__()

self.e1 = _EncodeLayer(3, 64, batch_normalize=False)
self.e2 = _EncodeLayer(64, 128)
self.e3 = _EncodeLayer(128, 256)
self.e4 = _EncodeLayer(256, 512)
self.e5 = _EncodeLayer(512, 512)
self.e6 = _EncodeLayer(512, 512)
self.e7 = _EncodeLayer(512, 512)
self.e8 = _EncodeLayer(512, 512)

self.d1 = _DecodeLayer(512, 512, dropout=True)
self.d2 = _DecodeLayer(1024, 512, dropout=True)
self.d3 = _DecodeLayer(1024, 512, dropout=True)
self.d4 = _DecodeLayer(1024, 512)
self.d5 = _DecodeLayer(1024, 256)
self.d6 = _DecodeLayer(512, 128)
self.d7 = _DecodeLayer(256, 64)

self.deconv = nn.ConvTranspose2d(
in_channels=128, out_channels=3,
kernel_size=4, stride=2, padding=1, bias=False
)

def forward(self, x):
# x: (256×256×3)
e1 = self.e1(x) # (128×128×64)
e2 = self.e2(e1) # (64×64×128)
e3 = self.e3(e2) # (32×32×256)
e4 = self.e4(e3) # (16×16×512)
e5 = self.e5(e4) # (8×8×512)
e6 = self.e6(e5) # (4×4×512)
e7 = self.e7(e6) # (2×2×512)
e8 = self.e8(e7) # (1×1×512)

d1 = self.d1(e8, e7) # (2×2×(512+512))
d2 = self.d2(d1, e6) # (4×4×(512+512))
d3 = self.d3(d2, e5) # (8×8×(512+512))
d4 = self.d4(d3, e4) # (16×16×(512+512))
d5 = self.d5(d4, e3) # (32×32×(256+256))
d6 = self.d6(d5, e2) # (64×64×(128+128))
d7 = self.d7(d6, e1) # (128×128×(64+64))

final = self.deconv(d7) # (256×256×3)
return final

class _EncodeLayer(nn.Module):
def __init__(self, in_size, out_size, batch_normalize=True):
super(_EncodeLayer, self).__init__()
layers = [nn.Conv2d(in_channels=in_size, out_channels=out_size, kernel_size=4, stride=2, padding=1, bias=False)]
if batch_normalize:
layers.append(nn.BatchNorm2d(num_features=out_size))
layers.append(nn.LeakyReLU(negative_slope=0.2))
self.model = nn.Sequential(*layers)

def forward(self, x):
return self.model(x)

class _DecodeLayer(nn.Module):
def __init__(self, in_size, out_size, dropout=False):
super(_DecodeLayer, self).__init__()
layers = [
nn.ConvTranspose2d(in_channels=in_size, out_channels=out_size, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(num_features=out_size),
nn.ReLU(inplace=True),
]
if dropout:
layers.append(nn.Dropout(0.5))
self.model = nn.Sequential(*layers)

def forward(self, x, skip_input):
x = self.model(x)
x = torch.cat((x, skip_input), 1)
return x


class DiscriminatorSeaPixGan(nn.Module):
def __init__(self, in_channels=3):
super(DiscriminatorSeaPixGan, self).__init__()

def down_layer(in_filters, out_filters, normalization=True):
layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
if normalization:
layers.append(nn.BatchNorm2d(out_filters, momentum=0.8))
layers.append(nn.LeakyReLU(0.2, inplace=True))
return layers

self.model = nn.Sequential(
*down_layer(2*in_channels, 64),
*down_layer(64, 128),
*down_layer(128, 256),
nn.ZeroPad2d((1, 1, 1, 1)),
nn.Conv2d(256, 512, 4, padding=0, bias=False),
nn.BatchNorm2d(512, momentum=0.8),
nn.ZeroPad2d((1, 1, 1, 1)),
nn.Conv2d(512, 1, 4, padding=0, bias=False)
)

def forward(self, img_A, img_B):
img_input = torch.cat((img_A, img_B), 1)
return self.model(img_input)
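A quick shape check (an editorial sketch, not part of this diff) confirms the U-Net wiring above and the discriminator's patch-level output; the 30×30 patch size follows from the three stride-2 downsampling layers and the two valid 4×4 convolutions applied to a 256×256 input:

import torch
from nets.seapixgan import SeaPixGAN_Nets

nets = SeaPixGAN_Nets(base_model='pix2pix')
x = torch.randn(2, 3, 256, 256)  # batch of 2 so BatchNorm sees >1 value per channel at the 1x1 bottleneck
y = nets.netG(x)                 # enhanced output
print(y.shape)                   # torch.Size([2, 3, 256, 256])
patch = nets.netD(y, x)          # patch-level real/fake logits on the (output, input) pair
print(patch.shape)               # torch.Size([2, 1, 30, 30])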
184 changes: 184 additions & 0 deletions PyTorch/train_seapixgan.py
@@ -0,0 +1,184 @@
"""
> Training pipeline for Sea-pix-GAN models
* Original paper: https://doi.org/10.1016/j.jvcir.2023.104021
"""

# py libs
import os
import sys
import yaml
import argparse
from PIL import Image
# pytorch libs
import torch
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torch.autograd import Variable
import torchvision.transforms as transforms
# local libs
from nets.seapixgan import SeaPixGAN_Nets
from nets.commons import Weights_Normal
from utils.data_utils import GetTrainingPairs, GetValImage

## get configs and training options
parser = argparse.ArgumentParser()
parser.add_argument("--cfg_file", type=str, default="configs/train_euvp.yaml")
parser.add_argument("--epoch", type=int, default=0, help="which epoch to start from")
parser.add_argument("--num_epochs", type=int, default=150, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches, paper uses 64")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate, paper uses 0.0002")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of 1st order momentum, paper uses 0.5")
parser.add_argument("--b2", type=float, default=0.99, help="adam: decay of 2nd order momentum")
parser.add_argument("--l1_weight", type=float, default=100, help="Weight for L1 loss, paper uses 100")

args = parser.parse_args()

## training params
epoch = args.epoch
num_epochs = args.num_epochs
batch_size = args.batch_size
lr = args.lr
beta_1 = args.b1
beta_2 = args.b2
lambda_1 = args.l1_weight
model_v = "Sea-pix-GAN"

# load the data config file
with open(args.cfg_file) as f:
cfg = yaml.load(f, Loader=yaml.FullLoader)

# get info from config file
dataset_name = cfg["dataset_name"]
dataset_path = cfg["dataset_path"]
channels = cfg["chans"]
img_width = cfg["im_width"]
img_height = cfg["im_height"]
val_interval = cfg["val_interval"]
ckpt_interval = cfg["ckpt_interval"]
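# Sketch of the config keys this script reads (illustrative values, not from this PR):
#   dataset_name:  euvp
#   dataset_path:  /path/to/EUVP/
#   chans:         3
#   im_width:      256
#   im_height:     256
#   val_interval:  1000
#   ckpt_interval: 10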

## create dir for model and validation data
samples_dir = "samples/%s/%s" % (model_v, dataset_name)
checkpoint_dir = "checkpoints/%s/%s/" % (model_v, dataset_name)
os.makedirs(samples_dir, exist_ok=True)
os.makedirs(checkpoint_dir, exist_ok=True)


""" Sea-pix-GAN specifics: loss functions
-------------------------------------------------"""
L1_G = torch.nn.L1Loss() # l1 loss term
L_BCE = torch.nn.BCEWithLogitsLoss() # Binary cross entropy
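# Combined objective implemented below (Eq. 4 in the paper, pix2pix-style),
# with x the distorted input, y the ground truth, and lambda_1 = 100 by default:
#   L_D = BCE(D(y, x), 1) + BCE(D(G(x), x), 0)
#   L_G = BCE(D(G(x), x), 1) + lambda_1 * ||y - G(x)||_1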


# Initialize generator and discriminator
seapixgan_ = SeaPixGAN_Nets(base_model='pix2pix')
generator = seapixgan_.netG
discriminator = seapixgan_.netD

# see if cuda is available
if torch.cuda.is_available():
generator = generator.cuda()
discriminator = discriminator.cuda()
L1_G = L1_G.cuda()
L_BCE = L_BCE.cuda()
Tensor = torch.cuda.FloatTensor
else:
Tensor = torch.FloatTensor

# Initialize weights or load pretrained models
if args.epoch == 0:
generator.apply(Weights_Normal)
discriminator.apply(Weights_Normal)
else:
    generator.load_state_dict(torch.load("checkpoints/%s/%s/generator_%d.pth" % (model_v, dataset_name, epoch)))
    discriminator.load_state_dict(torch.load("checkpoints/%s/%s/discriminator_%d.pth" % (model_v, dataset_name, epoch)))
    print("Loaded model from epoch %d" % (epoch))

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(beta_1, beta_2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(beta_1, beta_2))


## Data pipeline
transforms_ = [
transforms.Resize((img_height, img_width), Image.BICUBIC),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]

dataloader = DataLoader(
GetTrainingPairs(dataset_path, dataset_name, transforms_=transforms_),
batch_size = batch_size,
shuffle = True,
num_workers = 4,
)

val_dataloader = DataLoader(
GetValImage(dataset_path, dataset_name, transforms_=transforms_, sub_dir='validation'),
batch_size=4,
shuffle=True,
num_workers=1,
)


## Training pipeline
all_loss = []
for epoch in range(epoch, num_epochs):
batch_loss = []
for i, batch in enumerate(dataloader):
# Model inputs
        imgs_distorted = Variable(batch["A"].type(Tensor))  # x: distorted underwater image
        imgs_good_gt = Variable(batch["B"].type(Tensor))    # y: ground-truth (good-quality) image

## Train Discriminator
optimizer_D.zero_grad()
        imgs_fake = generator(imgs_distorted)
        pred_real = discriminator(imgs_good_gt, imgs_distorted)
        pred_fake = discriminator(imgs_fake.detach(), imgs_distorted)  # detach so the D update does not backprop into G
loss_D_gen = L_BCE(pred_fake, torch.zeros_like(pred_fake))
loss_D_real = L_BCE(pred_real, torch.ones_like(pred_real))
loss_D = loss_D_gen + loss_D_real
loss_D.backward()
optimizer_D.step()

## Train Generator
optimizer_G.zero_grad()
        # regenerate imgs and score them; do NOT detach here, so the adversarial term can update G
        imgs_fake = generator(imgs_distorted)
        pred_fake = discriminator(imgs_fake, imgs_distorted)
# calculate loss function
loss_1 = L1_G(imgs_fake, imgs_good_gt)
loss_cgan = L_BCE(pred_fake, torch.ones_like(pred_fake))
loss_G = loss_cgan + lambda_1 * loss_1 # Total loss: Eq.4 in paper
# backward & steps
loss_G.backward()
optimizer_G.step()

batch_loss.append([loss_D.item(), loss_G.item(), loss_cgan.item(), loss_1.item()])

## If at sample interval save image
batches_done = epoch * len(dataloader) + i
if batches_done % val_interval == 0:
imgs = next(iter(val_dataloader))
imgs_val = Variable(imgs["val"].type(Tensor))
imgs_gen = generator(imgs_val)
img_sample = torch.cat((imgs_val.data, imgs_gen.data), -2)
save_image(img_sample, "samples/%s/%s/%s.png" % (model_v, dataset_name, batches_done), nrow=5, normalize=True)

epoch_loss = (torch.Tensor(batch_loss)).mean(dim=0).tolist()
all_loss.append(epoch_loss)
print("[Epoch %d/%d] [DLoss: %.3f, GLoss: %.3f, cGanLoss: %.3f, L1Loss: %.3f]"
%(
epoch, num_epochs,
epoch_loss[0], epoch_loss[1], epoch_loss[2], epoch_loss[3]
)
)

## Save model checkpoints
if (epoch % ckpt_interval == 0):
torch.save(generator.state_dict(), "checkpoints/%s/%s/generator_%d.pth" % (model_v, dataset_name, epoch))
torch.save(discriminator.state_dict(), "checkpoints/%s/%s/discriminator_%d.pth" % (model_v, dataset_name, epoch))

## result
torch.save(generator.state_dict(), "checkpoints/%s/%s/generator.pth" % (model_v, dataset_name))
torch.save(discriminator.state_dict(), "checkpoints/%s/%s/discriminator.pth" % (model_v, dataset_name))
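After training, the saved generator can be reused for single-image inference along these lines (an editorial sketch; the checkpoint path follows the save calls above with an assumed dataset name "euvp", and test.jpg is a placeholder):

import torch
from PIL import Image
import torchvision.transforms as transforms
from torchvision.utils import save_image
from nets.seapixgan import SeaPixGAN_Nets

generator = SeaPixGAN_Nets(base_model='pix2pix').netG
generator.load_state_dict(torch.load("checkpoints/Sea-pix-GAN/euvp/generator.pth", map_location="cpu"))
generator.eval()  # use running BatchNorm stats so a single-image batch works

transform = transforms.Compose([
    transforms.Resize((256, 256), Image.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # match training normalization
])
img = transform(Image.open("test.jpg").convert("RGB")).unsqueeze(0)
with torch.no_grad():
    enhanced = generator(img)
save_image(enhanced, "enhanced.jpg", normalize=True)  # rescale from [-1, 1] for viewing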