diff --git a/.gitignore b/.gitignore
index cb5263e..e3c02ab 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,4 +2,4 @@
*__pycache__
static/output
-
+logs
diff --git a/CEIQ.py b/CEIQ.py
index 5752306..5c0682d 100644
--- a/CEIQ.py
+++ b/CEIQ.py
@@ -1,14 +1,13 @@
+import pickle
+
import cv2
-import os
-import matplotlib.pyplot as plt
import numpy as np
from skimage.metrics import structural_similarity as ssim
-from scipy.stats import entropy as ent
-import pickle
+
class CEIQ:
def __init__(self):
- with open('CEIQ_model_v1_1.pickle', 'rb') as f:
+ with open('models/CEIQ_model_v1_1.pickle', 'rb') as f:
self.model = pickle.load(f)
def entropy(self, hist, bit_instead_of_nat=False):
@@ -21,12 +20,12 @@ def entropy(self, hist, bit_instead_of_nat=False):
"""
# h = h[np.where(h!=0)[0]]
h = np.asarray(hist, dtype=np.float64)
- if h.sum()<=0 or (h<0).any():
- print("[entropy] WARNING, malformed/empty input %s. Returning None."%str(hist))
+ if h.sum() <= 0 or (h < 0).any():
+ print("[entropy] WARNING, malformed/empty input %s. Returning None." % str(hist))
return None
- h = h/h.sum()
+ h = h / h.sum()
log_fn = np.ma.log2 if bit_instead_of_nat else np.ma.log
- return -(h*log_fn(h)).sum()
+ return -(h * log_fn(h)).sum()
def cross_entropy(self, x, y):
""" Computes cross entropy between two distributions.
@@ -47,30 +46,30 @@ def cross_entropy(self, x, y):
# Ignore zero 'y' elements.
mask = y > 0
x = x[mask]
- y = y[mask]
- ce = -np.sum(x * np.log(y))
+ y = y[mask]
+ ce = -np.sum(x * np.log(y))
return ce
def generate_x(self, img_path, option=0):
if option == 0:
Ig = cv2.imread(img_path)
- Ig = 0.299*Ig[:, :, 2] + 0.587*Ig[:, :, 1] + 0.114*Ig[:, :, 0]
+ Ig = 0.299 * Ig[:, :, 2] + 0.587 * Ig[:, :, 1] + 0.114 * Ig[:, :, 0]
Ig = Ig.astype('uint8')
else:
# Ig = cv2.cvtColor(img_path, cv2.COLOR_BGR2GRAY)
- Ig = 0.299*img_path[:, :, 2] + 0.587*img_path[:, :, 1] + 0.114*img_path[:, :, 0]
+ Ig = 0.299 * img_path[:, :, 2] + 0.587 * img_path[:, :, 1] + 0.114 * img_path[:, :, 0]
Ig = Ig.astype('uint8')
Ie = cv2.equalizeHist(Ig)
### Calculate ssim ###
ssim_ig_ie, _ = ssim(Ig, Ie, full=True)
### Get histograms ###
- histg = cv2.calcHist([Ig],[0],None,[128],[0,256])
- histe = cv2.calcHist([Ie],[0],None,[128],[0,256])
+ histg = cv2.calcHist([Ig], [0], None, [128], [0, 256])
+ histe = cv2.calcHist([Ie], [0], None, [128], [0, 256])
histg = np.reshape(histg, (histg.shape[0]))
histe = np.reshape(histe, (histe.shape[0]))
- zero_idsg = np.where(histg==0)[0]
- zero_idse = np.where(histe==0)[0]
+ zero_idsg = np.where(histg == 0)[0]
+ zero_idse = np.where(histe == 0)[0]
zero_ids = np.unique(np.concatenate((zero_idsg, zero_idse)))
# print(zero_ids)
histg = np.delete(histg, zero_ids)
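For orientation, a minimal sketch of the features `generate_x` derives before they reach the pickled regressor: SSIM between a grayscale image and its histogram-equalized version, plus 128-bin histograms of both (the zero-bin cleanup and entropy terms are omitted here). The input path is a placeholder.

```
import cv2
from skimage.metrics import structural_similarity as ssim

img = cv2.imread('example.jpg')  # hypothetical BGR input
# Same luma weights as generate_x (note OpenCV's BGR channel order)
gray = (0.299 * img[:, :, 2] + 0.587 * img[:, :, 1] + 0.114 * img[:, :, 0]).astype('uint8')
equalized = cv2.equalizeHist(gray)

ssim_score, _ = ssim(gray, equalized, full=True)  # contrast-change feature
hist_gray = cv2.calcHist([gray], [0], None, [128], [0, 256]).ravel()
hist_eq = cv2.calcHist([equalized], [0], None, [128], [0, 256]).ravel()
print(ssim_score, hist_gray.shape, hist_eq.shape)
```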
diff --git a/Deblurring/Datasets/README.md b/Deblurring/Datasets/README.md
index 8ef3a6b..a9f0eb8 100644
--- a/Deblurring/Datasets/README.md
+++ b/Deblurring/Datasets/README.md
@@ -1,14 +1,15 @@
-Download datasets from the google drive links and place them in this directory. Your directory tree should look like this
+Download datasets from the google drive links and place them in this directory. Your directory tree should look like
+this
`GoPro`
- `├──`[train](https://drive.google.com/drive/folders/1AsgIP9_X0bg0olu2-1N6karm2x15cJWE?usp=sharing)
- `└──`[test](https://drive.google.com/drive/folders/1a2qKfXWpNuTGOm2-Jex8kfNSzYJLbqkf?usp=sharing)
+`├──`[train](https://drive.google.com/drive/folders/1AsgIP9_X0bg0olu2-1N6karm2x15cJWE?usp=sharing)
+`└──`[test](https://drive.google.com/drive/folders/1a2qKfXWpNuTGOm2-Jex8kfNSzYJLbqkf?usp=sharing)
`HIDE`
- `└──`[test](https://drive.google.com/drive/folders/1nRsTXj4iTUkTvBhTcGg8cySK8nd3vlhK?usp=sharing)
+`└──`[test](https://drive.google.com/drive/folders/1nRsTXj4iTUkTvBhTcGg8cySK8nd3vlhK?usp=sharing)
`RealBlur_J`
- `└──`[test](https://drive.google.com/drive/folders/1KYtzeKCiDRX9DSvC-upHrCqvC4sPAiJ1?usp=sharing)
+`└──`[test](https://drive.google.com/drive/folders/1KYtzeKCiDRX9DSvC-upHrCqvC4sPAiJ1?usp=sharing)
`RealBlur_R`
- `└──`[test](https://drive.google.com/drive/folders/1EwDoajf5nStPIAcU4s9rdc8SPzfm3tW1?usp=sharing)
+`└──`[test](https://drive.google.com/drive/folders/1EwDoajf5nStPIAcU4s9rdc8SPzfm3tW1?usp=sharing)
diff --git a/Deblurring/MPRNet.py b/Deblurring/MPRNet.py
index 5e1b040..fb9d6ab 100644
--- a/Deblurring/MPRNet.py
+++ b/Deblurring/MPRNet.py
@@ -6,14 +6,13 @@
import torch
import torch.nn as nn
-import torch.nn.functional as F
-from pdb import set_trace as stx
+
##########################################################################
-def conv(in_channels, out_channels, kernel_size, bias=False, stride = 1):
+def conv(in_channels, out_channels, kernel_size, bias=False, stride=1):
return nn.Conv2d(
in_channels, out_channels, kernel_size,
- padding=(kernel_size//2), bias=bias, stride = stride)
+ padding=(kernel_size // 2), bias=bias, stride=stride)
##########################################################################
@@ -25,10 +24,10 @@ def __init__(self, channel, reduction=16, bias=False):
self.avg_pool = nn.AdaptiveAvgPool2d(1)
# feature channel downscale and upscale --> channel weight
self.conv_du = nn.Sequential(
- nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=bias),
- nn.ReLU(inplace=True),
- nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=bias),
- nn.Sigmoid()
+ nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=bias),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=bias),
+ nn.Sigmoid()
)
def forward(self, x):
@@ -56,6 +55,7 @@ def forward(self, x):
res += x
return res
+
##########################################################################
## Supervised Attention Module
class SAM(nn.Module):
@@ -69,10 +69,11 @@ def forward(self, x, x_img):
x1 = self.conv1(x)
img = self.conv2(x) + x_img
x2 = torch.sigmoid(self.conv3(img))
- x1 = x1*x2
- x1 = x1+x
+ x1 = x1 * x2
+ x1 = x1 + x
return x1, img
+
##########################################################################
## U-Net
@@ -80,26 +81,30 @@ class Encoder(nn.Module):
def __init__(self, n_feat, kernel_size, reduction, act, bias, scale_unetfeats, csff):
super(Encoder, self).__init__()
- self.encoder_level1 = [CAB(n_feat, kernel_size, reduction, bias=bias, act=act) for _ in range(2)]
- self.encoder_level2 = [CAB(n_feat+scale_unetfeats, kernel_size, reduction, bias=bias, act=act) for _ in range(2)]
- self.encoder_level3 = [CAB(n_feat+(scale_unetfeats*2), kernel_size, reduction, bias=bias, act=act) for _ in range(2)]
+ self.encoder_level1 = [CAB(n_feat, kernel_size, reduction, bias=bias, act=act) for _ in range(2)]
+ self.encoder_level2 = [CAB(n_feat + scale_unetfeats, kernel_size, reduction, bias=bias, act=act) for _ in
+ range(2)]
+ self.encoder_level3 = [CAB(n_feat + (scale_unetfeats * 2), kernel_size, reduction, bias=bias, act=act) for _ in
+ range(2)]
self.encoder_level1 = nn.Sequential(*self.encoder_level1)
self.encoder_level2 = nn.Sequential(*self.encoder_level2)
self.encoder_level3 = nn.Sequential(*self.encoder_level3)
- self.down12 = DownSample(n_feat, scale_unetfeats)
- self.down23 = DownSample(n_feat+scale_unetfeats, scale_unetfeats)
+ self.down12 = DownSample(n_feat, scale_unetfeats)
+ self.down23 = DownSample(n_feat + scale_unetfeats, scale_unetfeats)
# Cross Stage Feature Fusion (CSFF)
if csff:
- self.csff_enc1 = nn.Conv2d(n_feat, n_feat, kernel_size=1, bias=bias)
- self.csff_enc2 = nn.Conv2d(n_feat+scale_unetfeats, n_feat+scale_unetfeats, kernel_size=1, bias=bias)
- self.csff_enc3 = nn.Conv2d(n_feat+(scale_unetfeats*2), n_feat+(scale_unetfeats*2), kernel_size=1, bias=bias)
+ self.csff_enc1 = nn.Conv2d(n_feat, n_feat, kernel_size=1, bias=bias)
+ self.csff_enc2 = nn.Conv2d(n_feat + scale_unetfeats, n_feat + scale_unetfeats, kernel_size=1, bias=bias)
+ self.csff_enc3 = nn.Conv2d(n_feat + (scale_unetfeats * 2), n_feat + (scale_unetfeats * 2), kernel_size=1,
+ bias=bias)
- self.csff_dec1 = nn.Conv2d(n_feat, n_feat, kernel_size=1, bias=bias)
- self.csff_dec2 = nn.Conv2d(n_feat+scale_unetfeats, n_feat+scale_unetfeats, kernel_size=1, bias=bias)
- self.csff_dec3 = nn.Conv2d(n_feat+(scale_unetfeats*2), n_feat+(scale_unetfeats*2), kernel_size=1, bias=bias)
+ self.csff_dec1 = nn.Conv2d(n_feat, n_feat, kernel_size=1, bias=bias)
+ self.csff_dec2 = nn.Conv2d(n_feat + scale_unetfeats, n_feat + scale_unetfeats, kernel_size=1, bias=bias)
+ self.csff_dec3 = nn.Conv2d(n_feat + (scale_unetfeats * 2), n_feat + (scale_unetfeats * 2), kernel_size=1,
+ bias=bias)
def forward(self, x, encoder_outs=None, decoder_outs=None):
enc1 = self.encoder_level1(x)
@@ -117,26 +122,29 @@ def forward(self, x, encoder_outs=None, decoder_outs=None):
enc3 = self.encoder_level3(x)
if (encoder_outs is not None) and (decoder_outs is not None):
enc3 = enc3 + self.csff_enc3(encoder_outs[2]) + self.csff_dec3(decoder_outs[2])
-
+
return [enc1, enc2, enc3]
+
class Decoder(nn.Module):
def __init__(self, n_feat, kernel_size, reduction, act, bias, scale_unetfeats):
super(Decoder, self).__init__()
- self.decoder_level1 = [CAB(n_feat, kernel_size, reduction, bias=bias, act=act) for _ in range(2)]
- self.decoder_level2 = [CAB(n_feat+scale_unetfeats, kernel_size, reduction, bias=bias, act=act) for _ in range(2)]
- self.decoder_level3 = [CAB(n_feat+(scale_unetfeats*2), kernel_size, reduction, bias=bias, act=act) for _ in range(2)]
+ self.decoder_level1 = [CAB(n_feat, kernel_size, reduction, bias=bias, act=act) for _ in range(2)]
+ self.decoder_level2 = [CAB(n_feat + scale_unetfeats, kernel_size, reduction, bias=bias, act=act) for _ in
+ range(2)]
+ self.decoder_level3 = [CAB(n_feat + (scale_unetfeats * 2), kernel_size, reduction, bias=bias, act=act) for _ in
+ range(2)]
self.decoder_level1 = nn.Sequential(*self.decoder_level1)
self.decoder_level2 = nn.Sequential(*self.decoder_level2)
self.decoder_level3 = nn.Sequential(*self.decoder_level3)
- self.skip_attn1 = CAB(n_feat, kernel_size, reduction, bias=bias, act=act)
- self.skip_attn2 = CAB(n_feat+scale_unetfeats, kernel_size, reduction, bias=bias, act=act)
+ self.skip_attn1 = CAB(n_feat, kernel_size, reduction, bias=bias, act=act)
+ self.skip_attn2 = CAB(n_feat + scale_unetfeats, kernel_size, reduction, bias=bias, act=act)
- self.up21 = SkipUpSample(n_feat, scale_unetfeats)
- self.up32 = SkipUpSample(n_feat+scale_unetfeats, scale_unetfeats)
+ self.up21 = SkipUpSample(n_feat, scale_unetfeats)
+ self.up32 = SkipUpSample(n_feat + scale_unetfeats, scale_unetfeats)
def forward(self, outs):
enc1, enc2, enc3 = outs
@@ -148,41 +156,45 @@ def forward(self, outs):
x = self.up21(dec2, self.skip_attn1(enc1))
dec1 = self.decoder_level1(x)
- return [dec1,dec2,dec3]
+ return [dec1, dec2, dec3]
+
##########################################################################
##---------- Resizing Modules ----------
class DownSample(nn.Module):
- def __init__(self, in_channels,s_factor):
+ def __init__(self, in_channels, s_factor):
super(DownSample, self).__init__()
self.down = nn.Sequential(nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=False),
- nn.Conv2d(in_channels, in_channels+s_factor, 1, stride=1, padding=0, bias=False))
+ nn.Conv2d(in_channels, in_channels + s_factor, 1, stride=1, padding=0, bias=False))
def forward(self, x):
x = self.down(x)
return x
+
class UpSample(nn.Module):
- def __init__(self, in_channels,s_factor):
+ def __init__(self, in_channels, s_factor):
super(UpSample, self).__init__()
self.up = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
- nn.Conv2d(in_channels+s_factor, in_channels, 1, stride=1, padding=0, bias=False))
+ nn.Conv2d(in_channels + s_factor, in_channels, 1, stride=1, padding=0, bias=False))
def forward(self, x):
x = self.up(x)
return x
+
class SkipUpSample(nn.Module):
- def __init__(self, in_channels,s_factor):
+ def __init__(self, in_channels, s_factor):
super(SkipUpSample, self).__init__()
self.up = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
- nn.Conv2d(in_channels+s_factor, in_channels, 1, stride=1, padding=0, bias=False))
+ nn.Conv2d(in_channels + s_factor, in_channels, 1, stride=1, padding=0, bias=False))
def forward(self, x, y):
x = self.up(x)
x = x + y
return x
+
##########################################################################
## Original Resolution Block (ORB)
class ORB(nn.Module):
@@ -198,28 +210,31 @@ def forward(self, x):
res += x
return res
+
##########################################################################
class ORSNet(nn.Module):
def __init__(self, n_feat, scale_orsnetfeats, kernel_size, reduction, act, bias, scale_unetfeats, num_cab):
super(ORSNet, self).__init__()
- self.orb1 = ORB(n_feat+scale_orsnetfeats, kernel_size, reduction, act, bias, num_cab)
- self.orb2 = ORB(n_feat+scale_orsnetfeats, kernel_size, reduction, act, bias, num_cab)
- self.orb3 = ORB(n_feat+scale_orsnetfeats, kernel_size, reduction, act, bias, num_cab)
+ self.orb1 = ORB(n_feat + scale_orsnetfeats, kernel_size, reduction, act, bias, num_cab)
+ self.orb2 = ORB(n_feat + scale_orsnetfeats, kernel_size, reduction, act, bias, num_cab)
+ self.orb3 = ORB(n_feat + scale_orsnetfeats, kernel_size, reduction, act, bias, num_cab)
self.up_enc1 = UpSample(n_feat, scale_unetfeats)
self.up_dec1 = UpSample(n_feat, scale_unetfeats)
- self.up_enc2 = nn.Sequential(UpSample(n_feat+scale_unetfeats, scale_unetfeats), UpSample(n_feat, scale_unetfeats))
- self.up_dec2 = nn.Sequential(UpSample(n_feat+scale_unetfeats, scale_unetfeats), UpSample(n_feat, scale_unetfeats))
+ self.up_enc2 = nn.Sequential(UpSample(n_feat + scale_unetfeats, scale_unetfeats),
+ UpSample(n_feat, scale_unetfeats))
+ self.up_dec2 = nn.Sequential(UpSample(n_feat + scale_unetfeats, scale_unetfeats),
+ UpSample(n_feat, scale_unetfeats))
- self.conv_enc1 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias)
- self.conv_enc2 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias)
- self.conv_enc3 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias)
+ self.conv_enc1 = nn.Conv2d(n_feat, n_feat + scale_orsnetfeats, kernel_size=1, bias=bias)
+ self.conv_enc2 = nn.Conv2d(n_feat, n_feat + scale_orsnetfeats, kernel_size=1, bias=bias)
+ self.conv_enc3 = nn.Conv2d(n_feat, n_feat + scale_orsnetfeats, kernel_size=1, bias=bias)
- self.conv_dec1 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias)
- self.conv_dec2 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias)
- self.conv_dec3 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias)
+ self.conv_dec1 = nn.Conv2d(n_feat, n_feat + scale_orsnetfeats, kernel_size=1, bias=bias)
+ self.conv_dec2 = nn.Conv2d(n_feat, n_feat + scale_orsnetfeats, kernel_size=1, bias=bias)
+ self.conv_dec3 = nn.Conv2d(n_feat, n_feat + scale_orsnetfeats, kernel_size=1, bias=bias)
def forward(self, x, encoder_outs, decoder_outs):
x = self.orb1(x)
@@ -236,13 +251,17 @@ def forward(self, x, encoder_outs, decoder_outs):
##########################################################################
class MPRNet(nn.Module):
- def __init__(self, in_c=3, out_c=3, n_feat=96, scale_unetfeats=48, scale_orsnetfeats=32, num_cab=8, kernel_size=3, reduction=4, bias=False):
+ def __init__(self, in_c=3, out_c=3, n_feat=96, scale_unetfeats=48, scale_orsnetfeats=32, num_cab=8, kernel_size=3,
+ reduction=4, bias=False):
super(MPRNet, self).__init__()
- act=nn.PReLU()
- self.shallow_feat1 = nn.Sequential(conv(in_c, n_feat, kernel_size, bias=bias), CAB(n_feat,kernel_size, reduction, bias=bias, act=act))
- self.shallow_feat2 = nn.Sequential(conv(in_c, n_feat, kernel_size, bias=bias), CAB(n_feat,kernel_size, reduction, bias=bias, act=act))
- self.shallow_feat3 = nn.Sequential(conv(in_c, n_feat, kernel_size, bias=bias), CAB(n_feat,kernel_size, reduction, bias=bias, act=act))
+ act = nn.PReLU()
+ self.shallow_feat1 = nn.Sequential(conv(in_c, n_feat, kernel_size, bias=bias),
+ CAB(n_feat, kernel_size, reduction, bias=bias, act=act))
+ self.shallow_feat2 = nn.Sequential(conv(in_c, n_feat, kernel_size, bias=bias),
+ CAB(n_feat, kernel_size, reduction, bias=bias, act=act))
+ self.shallow_feat3 = nn.Sequential(conv(in_c, n_feat, kernel_size, bias=bias),
+ CAB(n_feat, kernel_size, reduction, bias=bias, act=act))
# Cross Stage Feature Fusion (CSFF)
self.stage1_encoder = Encoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats, csff=False)
@@ -251,14 +270,15 @@ def __init__(self, in_c=3, out_c=3, n_feat=96, scale_unetfeats=48, scale_orsnetf
self.stage2_encoder = Encoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats, csff=True)
self.stage2_decoder = Decoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats)
- self.stage3_orsnet = ORSNet(n_feat, scale_orsnetfeats, kernel_size, reduction, act, bias, scale_unetfeats, num_cab)
+ self.stage3_orsnet = ORSNet(n_feat, scale_orsnetfeats, kernel_size, reduction, act, bias, scale_unetfeats,
+ num_cab)
self.sam12 = SAM(n_feat, kernel_size=1, bias=bias)
self.sam23 = SAM(n_feat, kernel_size=1, bias=bias)
-
- self.concat12 = conv(n_feat*2, n_feat, kernel_size, bias=bias)
- self.concat23 = conv(n_feat*2, n_feat+scale_orsnetfeats, kernel_size, bias=bias)
- self.tail = conv(n_feat+scale_orsnetfeats, out_c, kernel_size, bias=bias)
+
+ self.concat12 = conv(n_feat * 2, n_feat, kernel_size, bias=bias)
+ self.concat23 = conv(n_feat * 2, n_feat + scale_orsnetfeats, kernel_size, bias=bias)
+ self.tail = conv(n_feat + scale_orsnetfeats, out_c, kernel_size, bias=bias)
def forward(self, x3_img):
# Original-resolution Image for Stage 3
@@ -268,14 +288,14 @@ def forward(self, x3_img):
# Multi-Patch Hierarchy: Split Image into four non-overlapping patches
# Two Patches for Stage 2
- x2top_img = x3_img[:,:,0:int(H/2),:]
- x2bot_img = x3_img[:,:,int(H/2):H,:]
+ x2top_img = x3_img[:, :, 0:int(H / 2), :]
+ x2bot_img = x3_img[:, :, int(H / 2):H, :]
# Four Patches for Stage 1
- x1ltop_img = x2top_img[:,:,:,0:int(W/2)]
- x1rtop_img = x2top_img[:,:,:,int(W/2):W]
- x1lbot_img = x2bot_img[:,:,:,0:int(W/2)]
- x1rbot_img = x2bot_img[:,:,:,int(W/2):W]
+ x1ltop_img = x2top_img[:, :, :, 0:int(W / 2)]
+ x1rtop_img = x2top_img[:, :, :, int(W / 2):W]
+ x1lbot_img = x2bot_img[:, :, :, 0:int(W / 2)]
+ x1rbot_img = x2bot_img[:, :, :, int(W / 2):W]
##-------------------------------------------
##-------------- Stage 1---------------------
@@ -285,17 +305,17 @@ def forward(self, x3_img):
x1rtop = self.shallow_feat1(x1rtop_img)
x1lbot = self.shallow_feat1(x1lbot_img)
x1rbot = self.shallow_feat1(x1rbot_img)
-
+
## Process features of all 4 patches with Encoder of Stage 1
feat1_ltop = self.stage1_encoder(x1ltop)
feat1_rtop = self.stage1_encoder(x1rtop)
feat1_lbot = self.stage1_encoder(x1lbot)
feat1_rbot = self.stage1_encoder(x1rbot)
-
+
## Concat deep features
- feat1_top = [torch.cat((k,v), 3) for k,v in zip(feat1_ltop,feat1_rtop)]
- feat1_bot = [torch.cat((k,v), 3) for k,v in zip(feat1_lbot,feat1_rbot)]
-
+ feat1_top = [torch.cat((k, v), 3) for k, v in zip(feat1_ltop, feat1_rtop)]
+ feat1_bot = [torch.cat((k, v), 3) for k, v in zip(feat1_lbot, feat1_rbot)]
+
## Pass features through Decoder of Stage 1
res1_top = self.stage1_decoder(feat1_top)
res1_bot = self.stage1_decoder(feat1_bot)
@@ -305,13 +325,13 @@ def forward(self, x3_img):
x2bot_samfeats, stage1_img_bot = self.sam12(res1_bot[0], x2bot_img)
## Output image at Stage 1
- stage1_img = torch.cat([stage1_img_top, stage1_img_bot],2)
+ stage1_img = torch.cat([stage1_img_top, stage1_img_bot], 2)
##-------------------------------------------
##-------------- Stage 2---------------------
##-------------------------------------------
## Compute Shallow Features
- x2top = self.shallow_feat2(x2top_img)
- x2bot = self.shallow_feat2(x2bot_img)
+ x2top = self.shallow_feat2(x2top_img)
+ x2bot = self.shallow_feat2(x2bot_img)
## Concatenate SAM features of Stage 1 with shallow features of Stage 2
x2top_cat = self.concat12(torch.cat([x2top, x2top_samfeats], 1))
@@ -322,7 +342,7 @@ def forward(self, x3_img):
feat2_bot = self.stage2_encoder(x2bot_cat, feat1_bot, res1_bot)
## Concat deep features
- feat2 = [torch.cat((k,v), 2) for k,v in zip(feat2_top,feat2_bot)]
+ feat2 = [torch.cat((k, v), 2) for k, v in zip(feat2_top, feat2_bot)]
## Pass features through Decoder of Stage 2
res2 = self.stage2_decoder(feat2)
@@ -330,18 +350,17 @@ def forward(self, x3_img):
## Apply SAM
x3_samfeats, stage2_img = self.sam23(res2[0], x3_img)
-
##-------------------------------------------
##-------------- Stage 3---------------------
##-------------------------------------------
## Compute Shallow Features
- x3 = self.shallow_feat3(x3_img)
+ x3 = self.shallow_feat3(x3_img)
## Concatenate SAM features of Stage 2 with shallow features of Stage 3
x3_cat = self.concat23(torch.cat([x3, x3_samfeats], 1))
-
+
x3_cat = self.stage3_orsnet(x3_cat, feat2, res2)
stage3_img = self.tail(x3_cat)
- return [stage3_img+x3_img, stage2_img, stage1_img]
+ return [stage3_img + x3_img, stage2_img, stage1_img]
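As a quick check on the multi-patch hierarchy in `MPRNet.forward`, the sketch below splits an input into the stage-2 halves and stage-1 quarters, then reassembles them with the same `torch.cat` dimensions used above; the shapes are assumptions.

```
import torch

x3_img = torch.randn(1, 3, 256, 256)  # assumed (N, C, H, W) with even H and W
H, W = x3_img.shape[2], x3_img.shape[3]

# Stage 2: two half-height patches; Stage 1: four quarter patches
x2top, x2bot = x3_img[:, :, :H // 2, :], x3_img[:, :, H // 2:, :]
x1ltop, x1rtop = x2top[:, :, :, :W // 2], x2top[:, :, :, W // 2:]

# Reassembly mirrors the forward pass: concat along width (dim 3), then height (dim 2)
top = torch.cat((x1ltop, x1rtop), 3)
full = torch.cat([top, x2bot], 2)
assert full.shape == x3_img.shape
```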
diff --git a/Deblurring/README.md b/Deblurring/README.md
index 312d311..2756478 100644
--- a/Deblurring/README.md
+++ b/Deblurring/README.md
@@ -1,4 +1,5 @@
## Training
+
- Download the [Datasets](Datasets/README.md)
- Train the model with default arguments by running
@@ -12,42 +13,53 @@ python train.py
### Download the [model](https://drive.google.com/file/d/1QwQUVbk6YVOJViCsOKYNykCsdJSVGRtb/view?usp=sharing) and place it in ./pretrained_models/
#### Testing on GoPro dataset
-- Download [images](https://drive.google.com/drive/folders/1a2qKfXWpNuTGOm2-Jex8kfNSzYJLbqkf?usp=sharing) of GoPro and place them in `./Datasets/GoPro/test/`
+
+- Download [images](https://drive.google.com/drive/folders/1a2qKfXWpNuTGOm2-Jex8kfNSzYJLbqkf?usp=sharing) of GoPro and
+ place them in `./Datasets/GoPro/test/`
- Run
+
```
python test.py --dataset GoPro
```
#### Testing on HIDE dataset
-- Download [images](https://drive.google.com/drive/folders/1nRsTXj4iTUkTvBhTcGg8cySK8nd3vlhK?usp=sharing) of HIDE and place them in `./Datasets/HIDE/test/`
+
+- Download [images](https://drive.google.com/drive/folders/1nRsTXj4iTUkTvBhTcGg8cySK8nd3vlhK?usp=sharing) of HIDE and
+ place them in `./Datasets/HIDE/test/`
- Run
+
```
python test.py --dataset HIDE
```
-
#### Testing on RealBlur-J dataset
-- Download [images](https://drive.google.com/drive/folders/1KYtzeKCiDRX9DSvC-upHrCqvC4sPAiJ1?usp=sharing) of RealBlur-J and place them in `./Datasets/RealBlur_J/test/`
+
+- Download [images](https://drive.google.com/drive/folders/1KYtzeKCiDRX9DSvC-upHrCqvC4sPAiJ1?usp=sharing) of RealBlur-J
+ and place them in `./Datasets/RealBlur_J/test/`
- Run
+
```
python test.py --dataset RealBlur_J
```
-
-
#### Testing on RealBlur-R dataset
-- Download [images](https://drive.google.com/drive/folders/1EwDoajf5nStPIAcU4s9rdc8SPzfm3tW1?usp=sharing) of RealBlur-R and place them in `./Datasets/RealBlur_R/test/`
+
+- Download [images](https://drive.google.com/drive/folders/1EwDoajf5nStPIAcU4s9rdc8SPzfm3tW1?usp=sharing) of RealBlur-R
+ and place them in `./Datasets/RealBlur_R/test/`
- Run
+
```
python test.py --dataset RealBlur_R
```
#### To reproduce PSNR/SSIM scores of the paper on GoPro and HIDE datasets, run this MATLAB script
+
```
evaluate_GOPRO_HIDE.m
```
#### To reproduce PSNR/SSIM scores of the paper on RealBlur dataset, run
+
```
evaluate_RealBlur.py
```
diff --git a/Deblurring/config.py b/Deblurring/config.py
index 199d693..0fb8c39 100644
--- a/Deblurring/config.py
+++ b/Deblurring/config.py
@@ -54,7 +54,6 @@ class Config(object):
"""
def __init__(self, config_yaml: str, config_override: List[Any] = []):
-
self._C = CN()
self._C.GPU = [0]
self._C.VERBOSE = False
diff --git a/Deblurring/data_RGB.py b/Deblurring/data_RGB.py
index bd10f32..4f3f620 100644
--- a/Deblurring/data_RGB.py
+++ b/Deblurring/data_RGB.py
@@ -1,14 +1,18 @@
import os
+
from dataset_RGB import DataLoaderTrain, DataLoaderVal, DataLoaderTest
+
def get_training_data(rgb_dir, img_options):
assert os.path.exists(rgb_dir)
return DataLoaderTrain(rgb_dir, img_options)
+
def get_validation_data(rgb_dir, img_options):
assert os.path.exists(rgb_dir)
return DataLoaderVal(rgb_dir, img_options)
+
def get_test_data(rgb_dir, img_options):
assert os.path.exists(rgb_dir)
return DataLoaderTest(rgb_dir, img_options)
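A usage sketch for these helpers, assuming the GoPro test images sit under `Datasets/GoPro/test/input` as laid out in the datasets README:

```
from torch.utils.data import DataLoader

from data_RGB import get_test_data

test_dataset = get_test_data('Datasets/GoPro/test/input', img_options={})
test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)
```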
diff --git a/Deblurring/dataset_RGB.py b/Deblurring/dataset_RGB.py
index ba06bc9..9cfff2d 100644
--- a/Deblurring/dataset_RGB.py
+++ b/Deblurring/dataset_RGB.py
@@ -1,15 +1,17 @@
import os
+import random
+
import numpy as np
-from torch.utils.data import Dataset
import torch
-from PIL import Image
import torchvision.transforms.functional as TF
-from pdb import set_trace as stx
-import random
+from PIL import Image
+from torch.utils.data import Dataset
+
def is_image_file(filename):
return any(filename.endswith(extension) for extension in ['jpeg', 'JPEG', 'jpg', 'png', 'JPG', 'PNG', 'gif'])
+
class DataLoaderTrain(Dataset):
def __init__(self, rgb_dir, img_options=None):
super(DataLoaderTrain, self).__init__()
@@ -17,11 +19,11 @@ def __init__(self, rgb_dir, img_options=None):
inp_files = sorted(os.listdir(os.path.join(rgb_dir, 'input')))
tar_files = sorted(os.listdir(os.path.join(rgb_dir, 'target')))
- self.inp_filenames = [os.path.join(rgb_dir, 'input', x) for x in inp_files if is_image_file(x)]
+ self.inp_filenames = [os.path.join(rgb_dir, 'input', x) for x in inp_files if is_image_file(x)]
self.tar_filenames = [os.path.join(rgb_dir, 'target', x) for x in tar_files if is_image_file(x)]
self.img_options = img_options
- self.sizex = len(self.tar_filenames) # get the size of target
+ self.sizex = len(self.tar_filenames) # get the size of target
self.ps = self.img_options['patch_size']
@@ -38,23 +40,23 @@ def __getitem__(self, index):
inp_img = Image.open(inp_path)
tar_img = Image.open(tar_path)
-        w,h = tar_img.size
-        padw = ps-w if w<ps else 0
+        w, h = tar_img.size
+        padw = ps - w if w < ps else 0
diff --git a/Deblurring/test.py b/Deblurring/test.py
--- a/Deblurring/test.py
+++ b/Deblurring/test.py
-utils.load_checkpoint(model_restoration,args.weights)
-print("===>Testing using weights: ",args.weights)
+utils.load_checkpoint(model_restoration, args.weights)
+print("===>Testing using weights: ", args.weights)
model_restoration.cuda()
model_restoration = nn.DataParallel(model_restoration)
model_restoration.eval()
@@ -44,9 +43,10 @@
dataset = args.dataset
rgb_dir_test = os.path.join(args.input_dir, dataset, 'test', 'input')
test_dataset = get_test_data(rgb_dir_test, img_options={})
-test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False, num_workers=4, drop_last=False, pin_memory=True)
+test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False, num_workers=4, drop_last=False,
+ pin_memory=True)
-result_dir = os.path.join(args.result_dir, dataset)
+result_dir = os.path.join(args.result_dir, dataset)
utils.mkdir(result_dir)
with torch.no_grad():
@@ -54,27 +54,27 @@
torch.cuda.ipc_collect()
torch.cuda.empty_cache()
- input_ = data_test[0].cuda()
+ input_ = data_test[0].cuda()
filenames = data_test[1]
# Padding in case images are not multiples of 8
if dataset == 'RealBlur_J' or dataset == 'RealBlur_R':
factor = 8
- h,w = input_.shape[2], input_.shape[3]
- H,W = ((h+factor)//factor)*factor, ((w+factor)//factor)*factor
- padh = H-h if h%factor!=0 else 0
- padw = W-w if w%factor!=0 else 0
- input_ = F.pad(input_, (0,padw,0,padh), 'reflect')
+ h, w = input_.shape[2], input_.shape[3]
+ H, W = ((h + factor) // factor) * factor, ((w + factor) // factor) * factor
+ padh = H - h if h % factor != 0 else 0
+ padw = W - w if w % factor != 0 else 0
+ input_ = F.pad(input_, (0, padw, 0, padh), 'reflect')
restored = model_restoration(input_)
- restored = torch.clamp(restored[0],0,1)
+ restored = torch.clamp(restored[0], 0, 1)
# Unpad images to original dimensions
if dataset == 'RealBlur_J' or dataset == 'RealBlur_R':
- restored = restored[:,:,:h,:w]
+ restored = restored[:, :, :h, :w]
restored = restored.permute(0, 2, 3, 1).cpu().detach().numpy()
for batch in range(len(restored)):
restored_img = img_as_ubyte(restored[batch])
- utils.save_img((os.path.join(result_dir, filenames[batch]+'.png')), restored_img)
+ utils.save_img((os.path.join(result_dir, filenames[batch] + '.png')), restored_img)
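The pad/unpad step above is self-contained enough to verify on its own. A sketch with assumed shapes: reflect-pad so H and W become multiples of `factor`, then crop back after inference.

```
import torch
import torch.nn.functional as F

input_ = torch.randn(1, 3, 123, 250)  # assumed (N, C, H, W)
factor = 8
h, w = input_.shape[2], input_.shape[3]
H, W = ((h + factor) // factor) * factor, ((w + factor) // factor) * factor
padh = H - h if h % factor != 0 else 0
padw = W - w if w % factor != 0 else 0
padded = F.pad(input_, (0, padw, 0, padh), 'reflect')
assert padded.shape[2] % factor == 0 and padded.shape[3] % factor == 0

restored = padded[:, :, :h, :w]  # unpad back to the original size
assert restored.shape == input_.shape
```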
diff --git a/Deblurring/train.py b/Deblurring/train.py
index 972b433..7bd9c3b 100644
--- a/Deblurring/train.py
+++ b/Deblurring/train.py
@@ -1,5 +1,7 @@
import os
-from config import Config
+
+from config import Config
+
opt = Config('training.yml')
gpus = ','.join([str(i) for i in opt.GPU])
@@ -7,10 +9,10 @@
os.environ["CUDA_VISIBLE_DEVICES"] = gpus
import torch
+
torch.backends.cudnn.benchmark = True
import torch.nn as nn
-import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
@@ -24,7 +26,6 @@
import losses
from warmup_scheduler import GradualWarmupScheduler
from tqdm import tqdm
-from pdb import set_trace as stx
######### Set Seeds ###########
random.seed(1234)
@@ -37,13 +38,13 @@
session = opt.MODEL.SESSION
result_dir = os.path.join(opt.TRAINING.SAVE_DIR, mode, 'results', session)
-model_dir = os.path.join(opt.TRAINING.SAVE_DIR, mode, 'models', session)
+model_dir = os.path.join(opt.TRAINING.SAVE_DIR, mode, 'models', session)
utils.mkdir(result_dir)
utils.mkdir(model_dir)
train_dir = opt.TRAINING.TRAIN_DIR
-val_dir = opt.TRAINING.VAL_DIR
+val_dir = opt.TRAINING.VAL_DIR
######### Model ###########
model_restoration = MPRNet()
@@ -51,24 +52,23 @@
device_ids = [i for i in range(torch.cuda.device_count())]
if torch.cuda.device_count() > 1:
- print("\n\nLet's use", torch.cuda.device_count(), "GPUs!\n\n")
-
+ print("\n\nLet's use", torch.cuda.device_count(), "GPUs!\n\n")
new_lr = opt.OPTIM.LR_INITIAL
-optimizer = optim.Adam(model_restoration.parameters(), lr=new_lr, betas=(0.9, 0.999),eps=1e-8)
-
+optimizer = optim.Adam(model_restoration.parameters(), lr=new_lr, betas=(0.9, 0.999), eps=1e-8)
######### Scheduler ###########
warmup_epochs = 3
-scheduler_cosine = optim.lr_scheduler.CosineAnnealingLR(optimizer, opt.OPTIM.NUM_EPOCHS-warmup_epochs, eta_min=opt.OPTIM.LR_MIN)
+scheduler_cosine = optim.lr_scheduler.CosineAnnealingLR(optimizer, opt.OPTIM.NUM_EPOCHS - warmup_epochs,
+ eta_min=opt.OPTIM.LR_MIN)
scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_epochs, after_scheduler=scheduler_cosine)
scheduler.step()
######### Resume ###########
if opt.TRAINING.RESUME:
- path_chk_rest = utils.get_last_path(model_dir, '_latest.pth')
- utils.load_checkpoint(model_restoration,path_chk_rest)
+ path_chk_rest = utils.get_last_path(model_dir, '_latest.pth')
+ utils.load_checkpoint(model_restoration, path_chk_rest)
start_epoch = utils.load_start_epoch(path_chk_rest) + 1
utils.load_optim(optimizer, path_chk_rest)
@@ -79,21 +79,23 @@
print("==> Resuming Training with learning rate:", new_lr)
print('------------------------------------------------------------------------------')
-if len(device_ids)>1:
- model_restoration = nn.DataParallel(model_restoration, device_ids = device_ids)
+if len(device_ids) > 1:
+ model_restoration = nn.DataParallel(model_restoration, device_ids=device_ids)
######### Loss ###########
criterion_char = losses.CharbonnierLoss()
criterion_edge = losses.EdgeLoss()
######### DataLoaders ###########
-train_dataset = get_training_data(train_dir, {'patch_size':opt.TRAINING.TRAIN_PS})
-train_loader = DataLoader(dataset=train_dataset, batch_size=opt.OPTIM.BATCH_SIZE, shuffle=True, num_workers=16, drop_last=False, pin_memory=True)
+train_dataset = get_training_data(train_dir, {'patch_size': opt.TRAINING.TRAIN_PS})
+train_loader = DataLoader(dataset=train_dataset, batch_size=opt.OPTIM.BATCH_SIZE, shuffle=True, num_workers=16,
+ drop_last=False, pin_memory=True)
-val_dataset = get_validation_data(val_dir, {'patch_size':opt.TRAINING.VAL_PS})
-val_loader = DataLoader(dataset=val_dataset, batch_size=16, shuffle=False, num_workers=8, drop_last=False, pin_memory=True)
+val_dataset = get_validation_data(val_dir, {'patch_size': opt.TRAINING.VAL_PS})
+val_loader = DataLoader(dataset=val_dataset, batch_size=16, shuffle=False, num_workers=8, drop_last=False,
+ pin_memory=True)
-print('===> Start Epoch {} End Epoch {}'.format(start_epoch,opt.OPTIM.NUM_EPOCHS + 1))
+print('===> Start Epoch {} End Epoch {}'.format(start_epoch, opt.OPTIM.NUM_EPOCHS + 1))
print('===> Loading datasets')
best_psnr = 0
@@ -115,18 +117,18 @@
input_ = data[1].cuda()
restored = model_restoration(input_)
-
+
# Compute loss at each stage
- loss_char = np.sum([criterion_char(restored[j],target) for j in range(len(restored))])
- loss_edge = np.sum([criterion_edge(restored[j],target) for j in range(len(restored))])
- loss = (loss_char) + (0.05*loss_edge)
-
+ loss_char = np.sum([criterion_char(restored[j], target) for j in range(len(restored))])
+ loss_edge = np.sum([criterion_edge(restored[j], target) for j in range(len(restored))])
+ loss = (loss_char) + (0.05 * loss_edge)
+
loss.backward()
optimizer.step()
- epoch_loss +=loss.item()
+ epoch_loss += loss.item()
#### Evaluation ####
- if epoch%opt.TRAINING.VAL_AFTER_EVERY == 0:
+ if epoch % opt.TRAINING.VAL_AFTER_EVERY == 0:
model_restoration.eval()
psnr_val_rgb = []
for ii, data_val in enumerate((val_loader), 0):
@@ -137,34 +139,34 @@
restored = model_restoration(input_)
restored = restored[0]
- for res,tar in zip(restored,target):
+ for res, tar in zip(restored, target):
psnr_val_rgb.append(utils.torchPSNR(res, tar))
- psnr_val_rgb = torch.stack(psnr_val_rgb).mean().item()
+ psnr_val_rgb = torch.stack(psnr_val_rgb).mean().item()
if psnr_val_rgb > best_psnr:
best_psnr = psnr_val_rgb
best_epoch = epoch
- torch.save({'epoch': epoch,
+ torch.save({'epoch': epoch,
'state_dict': model_restoration.state_dict(),
- 'optimizer' : optimizer.state_dict()
- }, os.path.join(model_dir,"model_best.pth"))
+ 'optimizer': optimizer.state_dict()
+ }, os.path.join(model_dir, "model_best.pth"))
print("[epoch %d PSNR: %.4f --- best_epoch %d Best_PSNR %.4f]" % (epoch, psnr_val_rgb, best_epoch, best_psnr))
- torch.save({'epoch': epoch,
+ torch.save({'epoch': epoch,
'state_dict': model_restoration.state_dict(),
- 'optimizer' : optimizer.state_dict()
- }, os.path.join(model_dir,f"model_epoch_{epoch}.pth"))
+ 'optimizer': optimizer.state_dict()
+ }, os.path.join(model_dir, f"model_epoch_{epoch}.pth"))
scheduler.step()
-
+
print("------------------------------------------------------------------")
- print("Epoch: {}\tTime: {:.4f}\tLoss: {:.4f}\tLearningRate {:.6f}".format(epoch, time.time()-epoch_start_time, epoch_loss, scheduler.get_lr()[0]))
+ print("Epoch: {}\tTime: {:.4f}\tLoss: {:.4f}\tLearningRate {:.6f}".format(epoch, time.time() - epoch_start_time,
+ epoch_loss, scheduler.get_lr()[0]))
print("------------------------------------------------------------------")
- torch.save({'epoch': epoch,
+ torch.save({'epoch': epoch,
'state_dict': model_restoration.state_dict(),
- 'optimizer' : optimizer.state_dict()
- }, os.path.join(model_dir,"model_latest.pth"))
-
+ 'optimizer': optimizer.state_dict()
+ }, os.path.join(model_dir, "model_latest.pth"))
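To see what the warmup-plus-cosine schedule above produces, here is a sketch with assumed hyperparameters (100 epochs, initial LR 2e-4, minimum 1e-6); it keeps the same `scheduler.step()` placement as `train.py`.

```
import torch
import torch.optim as optim
from warmup_scheduler import GradualWarmupScheduler

model = torch.nn.Linear(8, 8)  # stand-in for MPRNet
optimizer = optim.Adam(model.parameters(), lr=2e-4, betas=(0.9, 0.999), eps=1e-8)

warmup_epochs = 3
scheduler_cosine = optim.lr_scheduler.CosineAnnealingLR(optimizer, 100 - warmup_epochs, eta_min=1e-6)
scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_epochs,
                                   after_scheduler=scheduler_cosine)
scheduler.step()  # as in train.py: advance once before the first epoch

for epoch in range(1, 6):
    optimizer.step()   # training step elided
    scheduler.step()
    print(epoch, scheduler.get_lr()[0])  # ramps up for 3 epochs, then anneals
```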
diff --git a/Deblurring/utils/__init__.py b/Deblurring/utils/__init__.py
index b2e9d2d..cc5d1e3 100644
--- a/Deblurring/utils/__init__.py
+++ b/Deblurring/utils/__init__.py
@@ -1,4 +1,4 @@
+from .dataset_utils import *
from .dir_utils import *
from .image_utils import *
from .model_utils import *
-from .dataset_utils import *
diff --git a/Deblurring/utils/dataset_utils.py b/Deblurring/utils/dataset_utils.py
index b57f474..e85cd67 100644
--- a/Deblurring/utils/dataset_utils.py
+++ b/Deblurring/utils/dataset_utils.py
@@ -1,5 +1,6 @@
import torch
+
class MixUp_AUG:
def __init__(self):
self.dist = torch.distributions.beta.Beta(torch.tensor([0.6]), torch.tensor([0.6]))
@@ -10,9 +11,9 @@ def aug(self, rgb_gt, rgb_noisy):
rgb_gt2 = rgb_gt[indices]
rgb_noisy2 = rgb_noisy[indices]
- lam = self.dist.rsample((bs,1)).view(-1,1,1,1).cuda()
+ lam = self.dist.rsample((bs, 1)).view(-1, 1, 1, 1).cuda()
- rgb_gt = lam * rgb_gt + (1-lam) * rgb_gt2
- rgb_noisy = lam * rgb_noisy + (1-lam) * rgb_noisy2
+ rgb_gt = lam * rgb_gt + (1 - lam) * rgb_gt2
+ rgb_noisy = lam * rgb_noisy + (1 - lam) * rgb_noisy2
- return rgb_gt, rgb_noisy
\ No newline at end of file
+ return rgb_gt, rgb_noisy
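A CPU-only sketch of the mixup step above (the class itself moves `lam` to CUDA): blend two randomly paired batches with a Beta(0.6, 0.6)-sampled weight, applied identically to the ground-truth and noisy inputs. The batch shapes are assumptions.

```
import torch

dist = torch.distributions.beta.Beta(torch.tensor([0.6]), torch.tensor([0.6]))
rgb_gt = torch.rand(4, 3, 32, 32)     # assumed ground-truth patches
rgb_noisy = torch.rand(4, 3, 32, 32)  # matching degraded patches

indices = torch.randperm(rgb_gt.size(0))
lam = dist.rsample((rgb_gt.size(0), 1)).view(-1, 1, 1, 1)  # one weight per sample

rgb_gt = lam * rgb_gt + (1 - lam) * rgb_gt[indices]
rgb_noisy = lam * rgb_noisy + (1 - lam) * rgb_noisy[indices]
```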
diff --git a/Deblurring/utils/dir_utils.py b/Deblurring/utils/dir_utils.py
index 3be7063..fb653e5 100644
--- a/Deblurring/utils/dir_utils.py
+++ b/Deblurring/utils/dir_utils.py
@@ -1,7 +1,9 @@
import os
-from natsort import natsorted
from glob import glob
+from natsort import natsorted
+
+
def mkdirs(paths):
if isinstance(paths, list) and not isinstance(paths, str):
for path in paths:
@@ -9,10 +11,12 @@ def mkdirs(paths):
else:
mkdir(paths)
+
def mkdir(path):
if not os.path.exists(path):
os.makedirs(path)
+
def get_last_path(path, session):
- x = natsorted(glob(os.path.join(path,'*%s'%session)))[-1]
- return x
\ No newline at end of file
+ x = natsorted(glob(os.path.join(path, '*%s' % session)))[-1]
+ return x
diff --git a/Deblurring/utils/image_utils.py b/Deblurring/utils/image_utils.py
index 29e3f98..5ec2ff5 100644
--- a/Deblurring/utils/image_utils.py
+++ b/Deblurring/utils/image_utils.py
@@ -1,18 +1,21 @@
-import torch
-import numpy as np
import cv2
+import numpy as np
+import torch
+
def torchPSNR(tar_img, prd_img):
- imdff = torch.clamp(prd_img,0,1) - torch.clamp(tar_img,0,1)
- rmse = (imdff**2).mean().sqrt()
- ps = 20*torch.log10(1/rmse)
+ imdff = torch.clamp(prd_img, 0, 1) - torch.clamp(tar_img, 0, 1)
+ rmse = (imdff ** 2).mean().sqrt()
+ ps = 20 * torch.log10(1 / rmse)
return ps
+
def save_img(filepath, img):
- cv2.imwrite(filepath,cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
+ cv2.imwrite(filepath, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
+
def numpyPSNR(tar_img, prd_img):
imdff = np.float32(prd_img) - np.float32(tar_img)
- rmse = np.sqrt(np.mean(imdff**2))
- ps = 20*np.log10(255/rmse)
+ rmse = np.sqrt(np.mean(imdff ** 2))
+ ps = 20 * np.log10(255 / rmse)
return ps
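A quick sanity check that the two PSNR helpers roughly agree (up to 8-bit quantization), assuming it is run from the `Deblurring` directory so the import resolves:

```
import numpy as np
import torch

from utils.image_utils import torchPSNR, numpyPSNR

tar = torch.rand(3, 16, 16)
prd = (tar + 0.05 * torch.randn_like(tar)).clamp(0, 1)

print(torchPSNR(tar, prd))  # dB, computed on [0, 1] tensors
print(numpyPSNR((tar.numpy() * 255).astype(np.uint8),
                (prd.numpy() * 255).astype(np.uint8)))  # dB, on [0, 255] arrays
```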
diff --git a/Deblurring/utils/model_utils.py b/Deblurring/utils/model_utils.py
index 154c5d8..58cfc00 100644
--- a/Deblurring/utils/model_utils.py
+++ b/Deblurring/utils/model_utils.py
@@ -1,24 +1,30 @@
-import torch
import os
from collections import OrderedDict
+import torch
+
+
def freeze(model):
for p in model.parameters():
- p.requires_grad=False
+ p.requires_grad = False
+
def unfreeze(model):
for p in model.parameters():
- p.requires_grad=True
+ p.requires_grad = True
+
def is_frozen(model):
x = [p.requires_grad for p in model.parameters()]
return not all(x)
+
def save_checkpoint(model_dir, state, session):
epoch = state['epoch']
- model_out_path = os.path.join(model_dir,"model_epoch_{}_{}.pth".format(epoch,session))
+ model_out_path = os.path.join(model_dir, "model_epoch_{}_{}.pth".format(epoch, session))
torch.save(state, model_out_path)
+
def load_checkpoint(model, weights):
checkpoint = torch.load(weights)
try:
@@ -27,7 +33,7 @@ def load_checkpoint(model, weights):
state_dict = checkpoint["state_dict"]
new_state_dict = OrderedDict()
for k, v in state_dict.items():
- name = k[7:] # remove `module.`
+ name = k[7:] # remove `module.`
new_state_dict[name] = v
model.load_state_dict(new_state_dict)
@@ -37,15 +43,17 @@ def load_checkpoint_multigpu(model, weights):
state_dict = checkpoint["state_dict"]
new_state_dict = OrderedDict()
for k, v in state_dict.items():
- name = k[7:] # remove `module.`
+ name = k[7:] # remove `module.`
new_state_dict[name] = v
model.load_state_dict(new_state_dict)
+
def load_start_epoch(weights):
checkpoint = torch.load(weights)
epoch = checkpoint["epoch"]
return epoch
+
def load_optim(optimizer, weights):
checkpoint = torch.load(weights)
optimizer.load_state_dict(checkpoint['optimizer'])
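The `k[7:]` slice in `load_checkpoint` strips the `module.` prefix that `nn.DataParallel` prepends to every key; a toy illustration of that rewrite:

```
from collections import OrderedDict

# Toy checkpoint saved from an nn.DataParallel-wrapped model
state_dict = {'module.conv.weight': 0, 'module.conv.bias': 1}
new_state_dict = OrderedDict((k[7:], v) for k, v in state_dict.items())
print(list(new_state_dict))  # ['conv.weight', 'conv.bias']
```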
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..d9f8437
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,52 @@
+FROM python:3.7-slim as build
+
+ENV PYTHONFAULTHANDLER=1 \
+ PYTHONUNBUFFERED=1 \
+ PYTHONHASHSEED=random \
+ PIP_NO_CACHE_DIR=off \
+ PIP_DISABLE_PIP_VERSION_CHECK=on \
+ PIP_DEFAULT_TIMEOUT=100
+
+WORKDIR /app
+
+COPY ./requirements.txt /app/requirements.txt
+RUN pip3 install --upgrade pip && pip3 install \
+ --user \
+ --no-warn-script-location \
+# --no-cache-dir \
+ -r requirements.txt
+
+# Copy current folder to docker working dir
+COPY . .
+
+
+FROM python:3.7-slim as output
+# Install shared objects and libraries for cv2 (OpenCV)
+RUN apt-get update -y && apt-get install -y --no-install-recommends libgl1-mesa-dev libglib2.0-0 \
+ && useradd -m r3v3r \
+ && mkdir /app \
+ && chown -R r3v3r:r3v3r /app
+
+#RUN apt-get update -y && useradd -m r3v3r \
+# && mkdir /app \
+# && chown -R r3v3r:r3v3r /app
+
+
+USER r3v3r
+WORKDIR /app
+
+ENV PYTHONPATH=/app \
+ HOME=/home/r3v3r \
+ PATH="/home/r3v3r/.local/bin:${PATH}"
+
+COPY --from=build --chown=r3v3r:r3v3r /usr/local/lib/python3.7/site-packages /usr/local/lib/python3.7/site-packages
+COPY --from=build --chown=r3v3r:r3v3r /root/.local /home/r3v3r/.local
+
+#COPY --from=build --chown=r3v3r:r3v3r /app/src /app/src
+##COPY --from=build --chown=r3v3r:r3v3r /app/models /app/models
+#COPY --from=build --chown=r3v3r:r3v3r /app/conf /app/conf
+#COPY --from=build --chown=r3v3r:r3v3r /app/main.py /app/main.py
+
+COPY --from=build --chown=r3v3r:r3v3r /app /app
+
+CMD ["python3","/app/main.py", "development"]
diff --git a/README.md b/README.md
index 3e4f0c8..a5a706a 100644
--- a/README.md
+++ b/README.md
@@ -1,20 +1,31 @@
# Instructions
+
## Install libraries
+
```
pip install -r requirements.txt
```
+
## Testing API
+
- First, run the ./app.py file.
+
```
python app.py
```
+
- Then run the ./request.py file to send a request to the API.
+
```
python request.py
```
+
## About the API
+
- The API receives raw JSON data.
-- You have to post raw JSON data, which contains the element 'urls' having an array of images need to be enhanced. Here is an example.
+- You have to post raw JSON data, which contains the element 'urls' holding an array of image URLs to be enhanced.
+  Here is an example.
+
```
{
"urls":[
@@ -23,9 +34,12 @@ python request.py
]
}
```
+
- In Postman, we will post data like this.
-![alt text](images/1.png)
+ ![alt text](images/1.png)
+
## About the result
+
```
{
"result": [
@@ -41,5 +55,6 @@ python request.py
"time": 25.083990812301636
}
```
+
![](images/2.png)
diff --git a/app.py b/app.py
deleted file mode 100644
index 88a23d6..0000000
--- a/app.py
+++ /dev/null
@@ -1,41 +0,0 @@
-import time
-
-from flask import Flask, request
-from flask_cors import CORS
-from py_profiler.profiler_controller import profiler_blueprint
-
-from enhance_service import *
-
-app = Flask(__name__)
-app.register_blueprint(profiler_blueprint)
-CORS(app)
-
-folder_out = "static/output/"
-os.makedirs(folder_out, exist_ok=True)
-
-enhance_service = EnhanceService(
- 'Deblurring/pretrained_models/model_deblurring.pth',
- use_cpu=True
-)
-
-
-@app.route('/')
-def index():
- return "hello"
-
-
-@app.route('/enhance', methods=['POST'])
-def enhance():
- urls = dict(request.get_json())['urls']
- #
- begin = time.time()
- output_result = enhance_service.process(urls, folder_out)
- print(f'Output: {output_result}')
- return {
- 'time': time.time() - begin,
- 'result': output_result
- }
-
-
-if __name__ == '__main__':
- app.run()
diff --git a/conf/development.yml b/conf/development.yml
new file mode 100644
index 0000000..e2af202
--- /dev/null
+++ b/conf/development.yml
@@ -0,0 +1,11 @@
+server:
+ http:
+ nthreads: 4
+ port: 31558
+# port: 8080
+downloader:
+ num_threads: 8
+
+run_on_cpu: True
+use_deblur_model: False
+deblur_model_path: "Deblurring/pretrained_models/model_deblurring.pth"
diff --git a/conf/production.yml b/conf/production.yml
new file mode 100644
index 0000000..23b994b
--- /dev/null
+++ b/conf/production.yml
@@ -0,0 +1,10 @@
+server:
+ http:
+ nthreads: 4
+ port: 8080
+
+downloader:
+ num_threads: 8
+run_on_cpu: False
+use_deblur_model: True
+deblur_model_path: "Deblurring/pretrained_models/model_deblurring.pth"
\ No newline at end of file
diff --git a/conf/staging.yml b/conf/staging.yml
new file mode 100644
index 0000000..fd759b3
--- /dev/null
+++ b/conf/staging.yml
@@ -0,0 +1,11 @@
+server:
+ http:
+ nthreads: 8
+ port: 8080
+
+downloader:
+ num_threads: 8
+
+run_on_cpu: True
+use_deblur_model: False
+deblur_model_path: "Deblurring/pretrained_models/model_deblurring.pth"
\ No newline at end of file
diff --git a/controller/__init__.py b/controller/__init__.py
new file mode 100644
index 0000000..113b895
--- /dev/null
+++ b/controller/__init__.py
@@ -0,0 +1 @@
+from .enhance_controller import *
diff --git a/controller/enhance_controller.py b/controller/enhance_controller.py
new file mode 100644
index 0000000..c7b12da
--- /dev/null
+++ b/controller/enhance_controller.py
@@ -0,0 +1,21 @@
+import os
+
+from dependency_injector.wiring import Provide, inject
+from flask import Blueprint, request
+
+from enhance_service import EnhanceService
+from .error_handlers import build_response
+from module.application_container import ApplicationContainer
+
+folder_out = "static/output/"
+os.makedirs(folder_out, exist_ok=True)
+
+enhance_blueprint = Blueprint("enhance", __name__)
+
+
+@enhance_blueprint.route("/enhance", methods=["POST"])
+@inject
+def detect(enhance_service: EnhanceService = Provide[ApplicationContainer.enhance_service]):
+ urls = dict(request.get_json())['urls']
+ output_result = enhance_service.process(urls, folder_out)
+ return build_response(result=output_result)
diff --git a/controller/error_handlers.py b/controller/error_handlers.py
new file mode 100644
index 0000000..6f57f96
--- /dev/null
+++ b/controller/error_handlers.py
@@ -0,0 +1,31 @@
+import json
+import logging
+import traceback
+
+from domain.errors import RError, InternalError
+
+
+def handle_defined_errors(error: RError):
+ return build_error_response(error)
+
+
+def handle_other_exceptions(error: Exception):
+ return build_error_response(InternalError(str(error), error))
+
+
+def build_response(**kwargs):
+ response = {}
+ response.update(kwargs)
+ return json.dumps(response)
+
+
+def build_error_response(error: RError):
+ print(error)
+ if error.__cause__ is not None:
+ logging.error("".join(traceback.TracebackException.from_exception(error.__cause__).format()))
+ else:
+ logging.error(error.message)
+ return {
+ "error": error.get_error(),
+ "message": error.message
+ }
diff --git a/docker-compose.yml b/docker-compose.yml
new file mode 100644
index 0000000..1470854
--- /dev/null
+++ b/docker-compose.yml
@@ -0,0 +1,13 @@
+version: "3.3"
+services:
+ image_enhancement:
+# image: registry.rever.vn/ai-research/room-detection:master
+# image: registry.rever.vn/ai-research/room-detection:v1.0.0
+ image: image_enhancement
+ build: .
+ command: ["python3","/app/main.py", "staging"]
+ ports:
+ - 31558:8080
+ volumes:
+ - ./logs:/app/logs
+ restart: always
diff --git a/domain/__init__.py b/domain/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/domain/errors.py b/domain/errors.py
new file mode 100644
index 0000000..29ae760
--- /dev/null
+++ b/domain/errors.py
@@ -0,0 +1,45 @@
+REQUEST_INVALID_ERR = "request_invalid_error"
+INTERNAL_ERR = "internal_error"
+
+
+class RError(Exception):
+ def __init__(
+ self,
+ message,
+ cause: Exception = None
+ ):
+ super(RError, self).__init__(message, cause)
+ self.message = message
+ self.__cause__ = cause
+
+ def get_message(self) -> str:
+ return self.message
+
+ def get_error(self) -> str:
+ pass
+
+
+class RequestInvalidError(RError):
+ def __init__(
+ self,
+ message,
+ cause: Exception = None):
+ super(RequestInvalidError, self).__init__(
+ message, cause
+ )
+
+ def get_error(self) -> str:
+ return REQUEST_INVALID_ERR
+
+
+class InternalError(RError):
+ def __init__(
+ self,
+ message,
+ cause: Exception = None):
+ super(InternalError, self).__init__(
+ message, cause
+ )
+
+ def get_error(self) -> str:
+ return INTERNAL_ERR
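A hypothetical round trip showing how these errors are meant to be consumed: raise a concrete subclass, catch the base `RError`, and render the same fields the error handlers use.

```
from domain.errors import RError, RequestInvalidError

try:
    raise RequestInvalidError("missing 'urls' field")
except RError as e:
    print({"error": e.get_error(), "message": e.get_message()})
# {'error': 'request_invalid_error', 'message': "missing 'urls' field"}
```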
diff --git a/domain/jsonable.py b/domain/jsonable.py
new file mode 100644
index 0000000..d7b893c
--- /dev/null
+++ b/domain/jsonable.py
@@ -0,0 +1,10 @@
+import json
+
+
+class Jsonable:
+
+ def to_json(self) -> str:
+ return json.dumps(self.to_dict(), ensure_ascii=False)
+
+ def to_dict(self) -> dict:
+ pass
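`to_dict` is the extension point here; a hypothetical subclass showing the intended contract:

```
from domain.jsonable import Jsonable


class EnhanceResult(Jsonable):  # illustrative subclass, not part of the codebase
    def __init__(self, url: str, score: float):
        self.url = url
        self.score = score

    def to_dict(self) -> dict:
        return {"enhanced_url": self.url, "enhanced_score": self.score}


print(EnhanceResult("static/output/a.jpg", 0.12).to_json())
# {"enhanced_url": "static/output/a.jpg", "enhanced_score": 0.12}
```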
diff --git a/enhance_service.py b/enhance_service.py
index 912b6a5..b9798d3 100644
--- a/enhance_service.py
+++ b/enhance_service.py
@@ -1,3 +1,4 @@
+import logging
import os
# from enlighten_inference import EnlightenOnnxModel
import uuid
@@ -26,8 +27,9 @@
class EnhanceService:
- def __init__(self, deblur_model_path, use_cpu: bool = True):
+ def __init__(self, deblur_model_path, use_cpu: bool = True, use_deblur_model: bool = False):
self.use_cpu = use_cpu
+ self.use_deblur_model = use_deblur_model
# Executor to run enhance process concurrently
self.executor = ThreadPoolExecutor(max_workers=8)
# A downloader to download image using a thread pool with 16 threads
@@ -46,18 +48,17 @@ def __init__(self, deblur_model_path, use_cpu: bool = True):
self.white_balancer = WhiteBalancer()
self.ceiq_scoring_model = CEIQ()
- print(f'white_balancer: {type(self.white_balancer)}')
- print(f'CEIQ_model: {type(self.ceiq_scoring_model)}')
- print(f'Model: {type(self.deblur_model)}')
-
- print("Init successfully")
+ logging.info(f'white_balancer: {type(self.white_balancer)}')
+ logging.info(f'CEIQ_model: {type(self.ceiq_scoring_model)}')
+ logging.info(f'Model: {type(self.deblur_model)}')
+ logging.info("Init successfully")
@profiler()
def process(self, image_urls: List[str], enhanced_out_dir):
image_dict = self.image_downloader.bulk_download_as_image(image_urls)
if len(image_urls) == 0:
raise Exception(f"No image urls found at {image_urls}")
- print('Number of files: ', len(image_dict))
+ logging.info('Number of files: %d', len(image_dict))
future_to_checks = {
self.executor.submit(self._enhance_image, image, 8, enhanced_out_dir): url
@@ -69,27 +70,30 @@ def process(self, image_urls: List[str], enhanced_out_dir):
# The try-except-else clause is omitted here
for future in futures.as_completed(future_to_checks):
url = future_to_checks[future]
- output_path = future.result()
- result_dict[url] = output_path
+ output_path, enhanced_score = future.result()
+ result_dict[url] = {
+ 'enhanced_url': output_path,
+ 'enhanced_score': enhanced_score
+ }
return result_dict
@profiler()
- def _enhance_image(self, image, factor, out_dir) -> str:
- restored = self._deblur_image(image, factor)
+    def _enhance_image(self, image, factor, out_dir) -> List:  # [output_path, enhanced_score]
+ restored = self._deblur_image(image, factor) if self.use_deblur_model else np.asarray(image)
# processed = enlighten_model.predict(cv2.cvtColor(restored, cv2.COLOR_RGB2BGR))
# processed = img_as_ubyte(processed)
img_output = self._process_white_balancing(restored)
restored = cv2.cvtColor(restored, cv2.COLOR_RGB2BGR)
- scores = self._calc_score([restored, img_output])
- if scores[0] > scores[1]:
+ origin_score, improved_score = self._calc_score([restored, img_output])
+ if origin_score > improved_score:
img_output = restored
- output_path = os.path.join(out_dir, f'{uuid.uuid1()}.png')
+ output_path = os.path.join(out_dir, f'{uuid.uuid1()}.jpg')
save_img(output_path, img_output)
- return output_path
+ return [output_path, ((improved_score - origin_score) / origin_score)]
@profiler()
def _deblur_image(self, img, factor: int = 8):
@@ -134,12 +138,12 @@ def _process_white_balancing(self, input_image, threshold: float = 0.3):
# Process image for gamma correction
output_image = None
if t < -threshold: # Dimmed Image
- print('Dimmed')
+ logging.info('Dimmed')
result = self.white_balancer.process_dimmed(Y)
YCrCb[:, :, 0] = result
output_image = cv2.cvtColor(YCrCb, cv2.COLOR_YCrCb2BGR)
elif t > threshold:
- print('Bright Image') # Bright Image
+ logging.info('Bright Image') # Bright Image
result = self.white_balancer.process_bright(Y)
YCrCb[:, :, 0] = result
output_image = cv2.cvtColor(YCrCb, cv2.COLOR_YCrCb2BGR)
@@ -154,5 +158,5 @@ def _calc_score(self, images):
# org_score = CEIQ_model.predict(np.expand_dims(restored, axis=0), 1)[0]
# imp_score = CEIQ_model.predict(np.expand_dims(img_output, axis=0), 1)[0]
scores = self.ceiq_scoring_model.predict(images, option=1)
- print(f"Scores: {scores[0]} -> {scores[1]}: Improved: {((scores[1] - scores[0]) * 100 / scores[0])} %")
+ logging.info(f"Scores: {scores[0]} -> {scores[1]}: Improved: {((scores[1] - scores[0]) * 100 / scores[0])} %")
return scores
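With this diff, each URL in the response maps to a dict carrying the output path and the relative CEIQ improvement rather than a bare path. A sketch of the new payload with made-up numbers:

```
origin_score, improved_score = 2.80, 3.15  # illustrative CEIQ scores before/after
enhanced_score = (improved_score - origin_score) / origin_score

result_dict = {
    "https://example.com/in.jpg": {            # hypothetical source URL
        "enhanced_url": "static/output/abc.jpg",
        "enhanced_score": enhanced_score,      # 0.125, i.e. +12.5%
    }
}
print(result_dict)
```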
diff --git a/images/1.png b/images/1.png
deleted file mode 100644
index 8042a7e..0000000
Binary files a/images/1.png and /dev/null differ
diff --git a/images/2.png b/images/2.png
deleted file mode 100644
index 61f59fd..0000000
Binary files a/images/2.png and /dev/null differ
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..da4a1dd
--- /dev/null
+++ b/main.py
@@ -0,0 +1,50 @@
+import logging
+import os
+import sys
+
+from dependency_injector import containers
+from dependency_injector.wiring import Provide, inject
+from flask import Flask
+from py_profiler.profiler_controller import profiler_blueprint
+from waitress import serve
+
+from controller import enhance_controller, enhance_blueprint
+from controller.error_handlers import handle_defined_errors, handle_other_exceptions
+from domain.errors import RError
+from module import ApplicationContainer
+from module.injector import create_injector
+from utils.setup_logging import setup_logging
+
+
+@inject
+def setup_http_server(
+ injector: containers.DeclarativeContainer,
+ port: int = Provide[ApplicationContainer.config.server.http.port.as_int()],
+ nthreads: int = Provide[ApplicationContainer.config.server.http.nthreads.as_int()],
+) -> None:
+ app = Flask(__name__)
+ app.url_map.strict_slashes = False
+ app.debug = True
+
+ app.register_blueprint(enhance_blueprint)
+ app.register_blueprint(profiler_blueprint)
+
+ app.register_error_handler(RError, handle_defined_errors)
+ app.register_error_handler(Exception, handle_other_exceptions)
+
+ logging.info(f"Created http server at port = {port} with {nthreads} concurrent threads.")
+ serve(
+ app, host="0.0.0.0",
+ port=port,
+ threads=nthreads if nthreads is not None else 4
+ )
+
+
+if __name__ == "__main__":
+ os.environ['APP_MODE'] = sys.argv[1] if len(sys.argv) > 1 else 'development'
+ setup_logging()
+ injector = create_injector([
+ sys.modules[__name__],
+ enhance_controller,
+ ], mode=os.environ['APP_MODE'])
+ setup_http_server(injector)
diff --git a/CEIQ_model_v1_1.pickle b/models/CEIQ_model_v1_1.pickle
similarity index 100%
rename from CEIQ_model_v1_1.pickle
rename to models/CEIQ_model_v1_1.pickle
diff --git a/module/__init__.py b/module/__init__.py
new file mode 100644
index 0000000..58409e7
--- /dev/null
+++ b/module/__init__.py
@@ -0,0 +1 @@
+from .application_container import *
diff --git a/module/application_container.py b/module/application_container.py
new file mode 100644
index 0000000..17ac2c6
--- /dev/null
+++ b/module/application_container.py
@@ -0,0 +1,21 @@
+from dependency_injector import containers, providers
+
+#
+# @author: anhlt
+#
+from enhance_service import EnhanceService
+from image_downloader import ImageDownloader
+
+
+class ApplicationContainer(containers.DeclarativeContainer):
+ config = providers.Configuration()
+
+ downloader = providers.Singleton(
+ ImageDownloader,
+ num_threads=config.downloader.num_threads
+ )
+    enhance_service = providers.Singleton(
+        EnhanceService,
+        deblur_model_path=config.deblur_model_path,
+        use_cpu=config.run_on_cpu,
+        use_deblur_model=config.use_deblur_model
+    )
diff --git a/module/injector.py b/module/injector.py
new file mode 100644
index 0000000..56740a8
--- /dev/null
+++ b/module/injector.py
@@ -0,0 +1,21 @@
+import os
+
+from .application_container import ApplicationContainer
+
+
+def create_injector(
+ modules: list,
+ mode: str = 'development'
+):
+ os.environ['APP_MODE'] = mode if mode is not None else os.environ['APP_MODE']
+ mode = os.environ['APP_MODE']
+
+ print(f"Mode: {mode}")
+    print(f"Loaded config: conf/{mode}.yml")
+ print(f"Modules: {modules}")
+ injector = ApplicationContainer()
+ injector.config.from_yaml(f"conf/{mode}.yml")
+ injector.wire(modules)
+ print("Wire completed")
+
+ return injector
diff --git a/request.py b/request.py
index 9fcc58a..952fe79 100644
--- a/request.py
+++ b/request.py
@@ -1,24 +1,25 @@
-import requests
import json
-url = "http://127.0.0.1:5000/enhance"
+import requests
+
+url = "http://127.0.0.1:8080/enhance"
payload = json.dumps({
- "urls": [
- "https://genk.mediacdn.vn/k:thumb_w/640/2014/mg-1235-3-1416328703227/cuu-sang-anh-bang-vai-thao-tac-don-gian-trong-photoshop.jpg",
- "https://th.bing.com/th/id/R.7924efd2efd3b67225ebb6cc82331bc8?rik=TKz8a7gW1Ur9zw&riu=http%3a%2f%2fpalistudio.com%2fupload%2fimage%2fdata%2fTin-tuc%2fChup-anh-cuoi%2fHa-Noi%2fZEN1242-d23a92.jpg&ehk=HgnL8rnD6HnJRBawFWnFTJC8HacfoC%2f75d5xrPwGDxI%3d&risl=&pid=ImgRaw&r=0",
- "https://genk.mediacdn.vn/k:thumb_w/640/2014/mg-1235-3-1416328703227/cuu-sang-anh-bang-vai-thao-tac-don-gian-trong-photoshop.jpg",
- "https://th.bing.com/th/id/R.7924efd2efd3b67225ebb6cc82331bc8?rik=TKz8a7gW1Ur9zw&riu=http%3a%2f%2fpalistudio.com%2fupload%2fimage%2fdata%2fTin-tuc%2fChup-anh-cuoi%2fHa-Noi%2fZEN1242-d23a92.jpg&ehk=HgnL8rnD6HnJRBawFWnFTJC8HacfoC%2f75d5xrPwGDxI%3d&risl=&pid=ImgRaw&r=0",
- "https://genk.mediacdn.vn/k:thumb_w/640/2014/mg-1235-3-1416328703227/cuu-sang-anh-bang-vai-thao-tac-don-gian-trong-photoshop.jpg",
- "https://th.bing.com/th/id/R.7924efd2efd3b67225ebb6cc82331bc8?rik=TKz8a7gW1Ur9zw&riu=http%3a%2f%2fpalistudio.com%2fupload%2fimage%2fdata%2fTin-tuc%2fChup-anh-cuoi%2fHa-Noi%2fZEN1242-d23a92.jpg&ehk=HgnL8rnD6HnJRBawFWnFTJC8HacfoC%2f75d5xrPwGDxI%3d&risl=&pid=ImgRaw&r=0",
- "https://genk.mediacdn.vn/k:thumb_w/640/2014/mg-1235-3-1416328703227/cuu-sang-anh-bang-vai-thao-tac-don-gian-trong-photoshop.jpg",
- "https://th.bing.com/th/id/R.7924efd2efd3b67225ebb6cc82331bc8?rik=TKz8a7gW1Ur9zw&riu=http%3a%2f%2fpalistudio.com%2fupload%2fimage%2fdata%2fTin-tuc%2fChup-anh-cuoi%2fHa-Noi%2fZEN1242-d23a92.jpg&ehk=HgnL8rnD6HnJRBawFWnFTJC8HacfoC%2f75d5xrPwGDxI%3d&risl=&pid=ImgRaw&r=0",
- "https://genk.mediacdn.vn/k:thumb_w/640/2014/mg-1235-3-1416328703227/cuu-sang-anh-bang-vai-thao-tac-don-gian-trong-photoshop.jpg",
- "https://th.bing.com/th/id/R.7924efd2efd3b67225ebb6cc82331bc8?rik=TKz8a7gW1Ur9zw&riu=http%3a%2f%2fpalistudio.com%2fupload%2fimage%2fdata%2fTin-tuc%2fChup-anh-cuoi%2fHa-Noi%2fZEN1242-d23a92.jpg&ehk=HgnL8rnD6HnJRBawFWnFTJC8HacfoC%2f75d5xrPwGDxI%3d&risl=&pid=ImgRaw&r=0"
- ]
+ "urls": [
+ "https://genk.mediacdn.vn/k:thumb_w/640/2014/mg-1235-3-1416328703227/cuu-sang-anh-bang-vai-thao-tac-don-gian-trong-photoshop.jpg",
+ "https://th.bing.com/th/id/R.7924efd2efd3b67225ebb6cc82331bc8?rik=TKz8a7gW1Ur9zw&riu=http%3a%2f%2fpalistudio.com%2fupload%2fimage%2fdata%2fTin-tuc%2fChup-anh-cuoi%2fHa-Noi%2fZEN1242-d23a92.jpg&ehk=HgnL8rnD6HnJRBawFWnFTJC8HacfoC%2f75d5xrPwGDxI%3d&risl=&pid=ImgRaw&r=0",
+ "https://genk.mediacdn.vn/k:thumb_w/640/2014/mg-1235-3-1416328703227/cuu-sang-anh-bang-vai-thao-tac-don-gian-trong-photoshop.jpg",
+ "https://th.bing.com/th/id/R.7924efd2efd3b67225ebb6cc82331bc8?rik=TKz8a7gW1Ur9zw&riu=http%3a%2f%2fpalistudio.com%2fupload%2fimage%2fdata%2fTin-tuc%2fChup-anh-cuoi%2fHa-Noi%2fZEN1242-d23a92.jpg&ehk=HgnL8rnD6HnJRBawFWnFTJC8HacfoC%2f75d5xrPwGDxI%3d&risl=&pid=ImgRaw&r=0",
+ "https://genk.mediacdn.vn/k:thumb_w/640/2014/mg-1235-3-1416328703227/cuu-sang-anh-bang-vai-thao-tac-don-gian-trong-photoshop.jpg",
+ "https://th.bing.com/th/id/R.7924efd2efd3b67225ebb6cc82331bc8?rik=TKz8a7gW1Ur9zw&riu=http%3a%2f%2fpalistudio.com%2fupload%2fimage%2fdata%2fTin-tuc%2fChup-anh-cuoi%2fHa-Noi%2fZEN1242-d23a92.jpg&ehk=HgnL8rnD6HnJRBawFWnFTJC8HacfoC%2f75d5xrPwGDxI%3d&risl=&pid=ImgRaw&r=0",
+ "https://genk.mediacdn.vn/k:thumb_w/640/2014/mg-1235-3-1416328703227/cuu-sang-anh-bang-vai-thao-tac-don-gian-trong-photoshop.jpg",
+ "https://th.bing.com/th/id/R.7924efd2efd3b67225ebb6cc82331bc8?rik=TKz8a7gW1Ur9zw&riu=http%3a%2f%2fpalistudio.com%2fupload%2fimage%2fdata%2fTin-tuc%2fChup-anh-cuoi%2fHa-Noi%2fZEN1242-d23a92.jpg&ehk=HgnL8rnD6HnJRBawFWnFTJC8HacfoC%2f75d5xrPwGDxI%3d&risl=&pid=ImgRaw&r=0",
+ "https://genk.mediacdn.vn/k:thumb_w/640/2014/mg-1235-3-1416328703227/cuu-sang-anh-bang-vai-thao-tac-don-gian-trong-photoshop.jpg",
+ "https://th.bing.com/th/id/R.7924efd2efd3b67225ebb6cc82331bc8?rik=TKz8a7gW1Ur9zw&riu=http%3a%2f%2fpalistudio.com%2fupload%2fimage%2fdata%2fTin-tuc%2fChup-anh-cuoi%2fHa-Noi%2fZEN1242-d23a92.jpg&ehk=HgnL8rnD6HnJRBawFWnFTJC8HacfoC%2f75d5xrPwGDxI%3d&risl=&pid=ImgRaw&r=0"
+ ]
})
headers = {
- 'Content-Type': 'application/json'
+ 'Content-Type': 'application/json'
}
response = requests.request("POST", url, headers=headers, data=payload)
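+
+# Print the reply so this smoke-test script shows what the server returned.
+print(response.status_code, response.text)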
diff --git a/requirements.txt b/requirements.txt
index 51b8790..06a406e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -26,4 +26,7 @@ torch==1.9.0
torchvision==0.10.0
typing-extensions==3.10.0.0
Werkzeug==2.0.1
+dependency-injector
+pyyaml
+waitress
py_profiler
\ No newline at end of file
diff --git a/test_imgs/1.png b/static/test_imgs/1.png
similarity index 100%
rename from test_imgs/1.png
rename to static/test_imgs/1.png
diff --git a/test_imgs/2.png b/static/test_imgs/2.png
similarity index 100%
rename from test_imgs/2.png
rename to static/test_imgs/2.png
diff --git a/test_imgs/3.png b/static/test_imgs/3.png
similarity index 100%
rename from test_imgs/3.png
rename to static/test_imgs/3.png
diff --git a/utils/model_utils.py b/utils/model_utils.py
index 4cc0c8f..24e86f5 100644
--- a/utils/model_utils.py
+++ b/utils/model_utils.py
@@ -11,6 +11,7 @@ def bypass_ssl_verify():
if (not os.environ.get('PYTHONHTTPSVERIFY', '') and getattr(ssl, '_create_unverified_context', None)):
ssl._create_default_https_context = ssl._create_unverified_context
+
def save_img(filepath, img):
import cv2
# cv2.imwrite(filepath,cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
diff --git a/utils/setup_logging.py b/utils/setup_logging.py
new file mode 100644
index 0000000..4a761fe
--- /dev/null
+++ b/utils/setup_logging.py
@@ -0,0 +1,114 @@
+import logging
+import os
+from logging.config import dictConfig
+
+# # for sending error logs to slack
+# import json
+# import requests
+# class HTTPSlackHandler(logging.Handler):
+# def emit(self, record):
+# log_entry = self.format(record)
+# json_text = json.dumps({"text": log_entry})
+# url = 'https://hooks.slack.com/services//'
+# return requests.post(url, json_text, headers={"Content-type": "application/json"}).content
+
+
+# Rotate each log file once it reaches 10 MiB.
+MAX_LOG_FILE_SIZE = 10 * 1024 * 1024
+
+
+def setup_logging():
+ os.makedirs("logs", exist_ok=True)
+ debug_mode = os.environ.get("APP_MODE", "development") == "development"
+
+ dictConfig({
+ "version": 1,
+ "disable_existing_loggers": True,
+ "formatters": {
+ "default": {
+ "format": "[%(asctime)s] %(levelname)s in %(module)s: %(message)s",
+ },
+ "access": {
+ "format": "%(message)s",
+ }
+ },
+ "handlers": {
+ "console": {
+ "level": "INFO",
+ "class": "logging.StreamHandler",
+ "formatter": "default",
+ "stream": "ext://sys.stdout",
+ },
+ # "email": {
+ # "class": "logging.handlers.SMTPHandler",
+ # "formatter": "default",
+ # "level": "ERROR",
+ # "mailhost": ("smtp.example.com", 587),
+ # "fromaddr": "devops@example.com",
+ # "toaddrs": ["receiver@example.com", "receiver2@example.com"],
+ # "subject": "Error Logs",
+ # "credentials": ("username", "password"),
+ # },
+ # "slack": {
+ # "class": "app.HTTPSlackHandler",
+ # "formatter": "default",
+ # "level": "ERROR",
+ # },
+ "service_file": {
+ "class": "logging.handlers.RotatingFileHandler",
+ "formatter": "default",
+ "filename": "logs/service.log",
+ "maxBytes": MAX_LOG_FILE_SIZE,
+ "backupCount": 5,
+                "delay": True,
+ },
+ "error_file": {
+ "class": "logging.handlers.RotatingFileHandler",
+ "formatter": "default",
+ "filename": "logs/error.log",
+ "maxBytes": MAX_LOG_FILE_SIZE,
+ "backupCount": 10,
+                "delay": True,
+ },
+ "access_file": {
+ "class": "logging.handlers.RotatingFileHandler",
+ "formatter": "access",
+ "filename": "logs/access.log",
+ "maxBytes": MAX_LOG_FILE_SIZE,
+ "backupCount": 10,
+                "delay": True,
+ }
+ },
+ "loggers": {
+ "error": {
+ "handlers": ["console"] if debug_mode else [
+ "console",
+ # "slack",
+ "error_file"
+ ],
+ "level": "INFO",
+ "propagate": False,
+ },
+ "access": {
+ "handlers": ["console"] if debug_mode else ["console", "access_file"],
+ "level": "INFO",
+ "propagate": False,
+ }
+ },
+ "root": {
+ "level": "DEBUG" if debug_mode else "INFO",
+            "handlers": [
+                "console",
+                # "slack",
+                "service_file"
+            ],
+ }
+ })
+
+    logging.info(f"Logging configured (debug_mode={debug_mode}).")
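+
+# Usage sketch: the named loggers configured above can be fetched anywhere in
+# the application; handler routing follows the dictConfig above:
+#
+#   logging.getLogger("access").info('GET /enhance 200')
+#   logging.getLogger("error").exception("enhancement failed")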