diff --git a/.gitignore b/.gitignore index cb5263e..e3c02ab 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,4 @@ *__pycache__ static/output - +logs diff --git a/CEIQ.py b/CEIQ.py index 5752306..5c0682d 100644 --- a/CEIQ.py +++ b/CEIQ.py @@ -1,14 +1,13 @@ +import pickle + import cv2 -import os -import matplotlib.pyplot as plt import numpy as np from skimage.metrics import structural_similarity as ssim -from scipy.stats import entropy as ent -import pickle + class CEIQ: def __init__(self): - with open('CEIQ_model_v1_1.pickle', 'rb') as f: + with open('models/CEIQ_model_v1_1.pickle', 'rb') as f: self.model = pickle.load(f) def entropy(self, hist, bit_instead_of_nat=False): @@ -21,12 +20,12 @@ def entropy(self, hist, bit_instead_of_nat=False): """ # h = h[np.where(h!=0)[0]] h = np.asarray(hist, dtype=np.float64) - if h.sum()<=0 or (h<0).any(): - print("[entropy] WARNING, malformed/empty input %s. Returning None."%str(hist)) + if h.sum() <= 0 or (h < 0).any(): + print("[entropy] WARNING, malformed/empty input %s. Returning None." % str(hist)) return None - h = h/h.sum() + h = h / h.sum() log_fn = np.ma.log2 if bit_instead_of_nat else np.ma.log - return -(h*log_fn(h)).sum() + return -(h * log_fn(h)).sum() def cross_entropy(self, x, y): """ Computes cross entropy between two distributions. @@ -47,30 +46,30 @@ def cross_entropy(self, x, y): # Ignore zero 'y' elements. mask = y > 0 x = x[mask] - y = y[mask] - ce = -np.sum(x * np.log(y)) + y = y[mask] + ce = -np.sum(x * np.log(y)) return ce def generate_x(self, img_path, option=0): if option == 0: Ig = cv2.imread(img_path) - Ig = 0.299*Ig[:, :, 2] + 0.587*Ig[:, :, 1] + 0.114*Ig[:, :, 0] + Ig = 0.299 * Ig[:, :, 2] + 0.587 * Ig[:, :, 1] + 0.114 * Ig[:, :, 0] Ig = Ig.astype('uint8') else: # Ig = cv2.cvtColor(img_path, cv2.COLOR_BGR2GRAY) - Ig = 0.299*img_path[:, :, 2] + 0.587*img_path[:, :, 1] + 0.114*img_path[:, :, 0] + Ig = 0.299 * img_path[:, :, 2] + 0.587 * img_path[:, :, 1] + 0.114 * img_path[:, :, 0] Ig = Ig.astype('uint8') Ie = cv2.equalizeHist(Ig) ### Calculate ssim ### ssim_ig_ie, _ = ssim(Ig, Ie, full=True) ### Get histograms ### - histg = cv2.calcHist([Ig],[0],None,[128],[0,256]) - histe = cv2.calcHist([Ie],[0],None,[128],[0,256]) + histg = cv2.calcHist([Ig], [0], None, [128], [0, 256]) + histe = cv2.calcHist([Ie], [0], None, [128], [0, 256]) histg = np.reshape(histg, (histg.shape[0])) histe = np.reshape(histe, (histe.shape[0])) - zero_idsg = np.where(histg==0)[0] - zero_idse = np.where(histe==0)[0] + zero_idsg = np.where(histg == 0)[0] + zero_idse = np.where(histe == 0)[0] zero_ids = np.unique(np.concatenate((zero_idsg, zero_idse))) # print(zero_ids) histg = np.delete(histg, zero_ids) diff --git a/Deblurring/Datasets/README.md b/Deblurring/Datasets/README.md index 8ef3a6b..a9f0eb8 100644 --- a/Deblurring/Datasets/README.md +++ b/Deblurring/Datasets/README.md @@ -1,14 +1,15 @@ -Download datasets from the google drive links and place them in this directory. Your directory tree should look like this +Download datasets from the google drive links and place them in this directory. Your directory tree should look like +this `GoPro`
-  `├──`[train](https://drive.google.com/drive/folders/1AsgIP9_X0bg0olu2-1N6karm2x15cJWE?usp=sharing)
-  `└──`[test](https://drive.google.com/drive/folders/1a2qKfXWpNuTGOm2-Jex8kfNSzYJLbqkf?usp=sharing) +`├──`[train](https://drive.google.com/drive/folders/1AsgIP9_X0bg0olu2-1N6karm2x15cJWE?usp=sharing)
+`└──`[test](https://drive.google.com/drive/folders/1a2qKfXWpNuTGOm2-Jex8kfNSzYJLbqkf?usp=sharing) `HIDE`
-   `└──`[test](https://drive.google.com/drive/folders/1nRsTXj4iTUkTvBhTcGg8cySK8nd3vlhK?usp=sharing) +`└──`[test](https://drive.google.com/drive/folders/1nRsTXj4iTUkTvBhTcGg8cySK8nd3vlhK?usp=sharing) `RealBlur_J`
-  `└──`[test](https://drive.google.com/drive/folders/1KYtzeKCiDRX9DSvC-upHrCqvC4sPAiJ1?usp=sharing) +`└──`[test](https://drive.google.com/drive/folders/1KYtzeKCiDRX9DSvC-upHrCqvC4sPAiJ1?usp=sharing) `RealBlur_R`
-  `└──`[test](https://drive.google.com/drive/folders/1EwDoajf5nStPIAcU4s9rdc8SPzfm3tW1?usp=sharing) +`└──`[test](https://drive.google.com/drive/folders/1EwDoajf5nStPIAcU4s9rdc8SPzfm3tW1?usp=sharing) diff --git a/Deblurring/MPRNet.py b/Deblurring/MPRNet.py index 5e1b040..fb9d6ab 100644 --- a/Deblurring/MPRNet.py +++ b/Deblurring/MPRNet.py @@ -6,14 +6,13 @@ import torch import torch.nn as nn -import torch.nn.functional as F -from pdb import set_trace as stx + ########################################################################## -def conv(in_channels, out_channels, kernel_size, bias=False, stride = 1): +def conv(in_channels, out_channels, kernel_size, bias=False, stride=1): return nn.Conv2d( in_channels, out_channels, kernel_size, - padding=(kernel_size//2), bias=bias, stride = stride) + padding=(kernel_size // 2), bias=bias, stride=stride) ########################################################################## @@ -25,10 +24,10 @@ def __init__(self, channel, reduction=16, bias=False): self.avg_pool = nn.AdaptiveAvgPool2d(1) # feature channel downscale and upscale --> channel weight self.conv_du = nn.Sequential( - nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=bias), - nn.ReLU(inplace=True), - nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=bias), - nn.Sigmoid() + nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=bias), + nn.ReLU(inplace=True), + nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=bias), + nn.Sigmoid() ) def forward(self, x): @@ -56,6 +55,7 @@ def forward(self, x): res += x return res + ########################################################################## ## Supervised Attention Module class SAM(nn.Module): @@ -69,10 +69,11 @@ def forward(self, x, x_img): x1 = self.conv1(x) img = self.conv2(x) + x_img x2 = torch.sigmoid(self.conv3(img)) - x1 = x1*x2 - x1 = x1+x + x1 = x1 * x2 + x1 = x1 + x return x1, img + ########################################################################## ## U-Net @@ -80,26 +81,30 @@ class Encoder(nn.Module): def __init__(self, n_feat, kernel_size, reduction, act, bias, scale_unetfeats, csff): super(Encoder, self).__init__() - self.encoder_level1 = [CAB(n_feat, kernel_size, reduction, bias=bias, act=act) for _ in range(2)] - self.encoder_level2 = [CAB(n_feat+scale_unetfeats, kernel_size, reduction, bias=bias, act=act) for _ in range(2)] - self.encoder_level3 = [CAB(n_feat+(scale_unetfeats*2), kernel_size, reduction, bias=bias, act=act) for _ in range(2)] + self.encoder_level1 = [CAB(n_feat, kernel_size, reduction, bias=bias, act=act) for _ in range(2)] + self.encoder_level2 = [CAB(n_feat + scale_unetfeats, kernel_size, reduction, bias=bias, act=act) for _ in + range(2)] + self.encoder_level3 = [CAB(n_feat + (scale_unetfeats * 2), kernel_size, reduction, bias=bias, act=act) for _ in + range(2)] self.encoder_level1 = nn.Sequential(*self.encoder_level1) self.encoder_level2 = nn.Sequential(*self.encoder_level2) self.encoder_level3 = nn.Sequential(*self.encoder_level3) - self.down12 = DownSample(n_feat, scale_unetfeats) - self.down23 = DownSample(n_feat+scale_unetfeats, scale_unetfeats) + self.down12 = DownSample(n_feat, scale_unetfeats) + self.down23 = DownSample(n_feat + scale_unetfeats, scale_unetfeats) # Cross Stage Feature Fusion (CSFF) if csff: - self.csff_enc1 = nn.Conv2d(n_feat, n_feat, kernel_size=1, bias=bias) - self.csff_enc2 = nn.Conv2d(n_feat+scale_unetfeats, n_feat+scale_unetfeats, kernel_size=1, bias=bias) - self.csff_enc3 = nn.Conv2d(n_feat+(scale_unetfeats*2), n_feat+(scale_unetfeats*2), kernel_size=1, bias=bias) + self.csff_enc1 = nn.Conv2d(n_feat, n_feat, kernel_size=1, bias=bias) + self.csff_enc2 = nn.Conv2d(n_feat + scale_unetfeats, n_feat + scale_unetfeats, kernel_size=1, bias=bias) + self.csff_enc3 = nn.Conv2d(n_feat + (scale_unetfeats * 2), n_feat + (scale_unetfeats * 2), kernel_size=1, + bias=bias) - self.csff_dec1 = nn.Conv2d(n_feat, n_feat, kernel_size=1, bias=bias) - self.csff_dec2 = nn.Conv2d(n_feat+scale_unetfeats, n_feat+scale_unetfeats, kernel_size=1, bias=bias) - self.csff_dec3 = nn.Conv2d(n_feat+(scale_unetfeats*2), n_feat+(scale_unetfeats*2), kernel_size=1, bias=bias) + self.csff_dec1 = nn.Conv2d(n_feat, n_feat, kernel_size=1, bias=bias) + self.csff_dec2 = nn.Conv2d(n_feat + scale_unetfeats, n_feat + scale_unetfeats, kernel_size=1, bias=bias) + self.csff_dec3 = nn.Conv2d(n_feat + (scale_unetfeats * 2), n_feat + (scale_unetfeats * 2), kernel_size=1, + bias=bias) def forward(self, x, encoder_outs=None, decoder_outs=None): enc1 = self.encoder_level1(x) @@ -117,26 +122,29 @@ def forward(self, x, encoder_outs=None, decoder_outs=None): enc3 = self.encoder_level3(x) if (encoder_outs is not None) and (decoder_outs is not None): enc3 = enc3 + self.csff_enc3(encoder_outs[2]) + self.csff_dec3(decoder_outs[2]) - + return [enc1, enc2, enc3] + class Decoder(nn.Module): def __init__(self, n_feat, kernel_size, reduction, act, bias, scale_unetfeats): super(Decoder, self).__init__() - self.decoder_level1 = [CAB(n_feat, kernel_size, reduction, bias=bias, act=act) for _ in range(2)] - self.decoder_level2 = [CAB(n_feat+scale_unetfeats, kernel_size, reduction, bias=bias, act=act) for _ in range(2)] - self.decoder_level3 = [CAB(n_feat+(scale_unetfeats*2), kernel_size, reduction, bias=bias, act=act) for _ in range(2)] + self.decoder_level1 = [CAB(n_feat, kernel_size, reduction, bias=bias, act=act) for _ in range(2)] + self.decoder_level2 = [CAB(n_feat + scale_unetfeats, kernel_size, reduction, bias=bias, act=act) for _ in + range(2)] + self.decoder_level3 = [CAB(n_feat + (scale_unetfeats * 2), kernel_size, reduction, bias=bias, act=act) for _ in + range(2)] self.decoder_level1 = nn.Sequential(*self.decoder_level1) self.decoder_level2 = nn.Sequential(*self.decoder_level2) self.decoder_level3 = nn.Sequential(*self.decoder_level3) - self.skip_attn1 = CAB(n_feat, kernel_size, reduction, bias=bias, act=act) - self.skip_attn2 = CAB(n_feat+scale_unetfeats, kernel_size, reduction, bias=bias, act=act) + self.skip_attn1 = CAB(n_feat, kernel_size, reduction, bias=bias, act=act) + self.skip_attn2 = CAB(n_feat + scale_unetfeats, kernel_size, reduction, bias=bias, act=act) - self.up21 = SkipUpSample(n_feat, scale_unetfeats) - self.up32 = SkipUpSample(n_feat+scale_unetfeats, scale_unetfeats) + self.up21 = SkipUpSample(n_feat, scale_unetfeats) + self.up32 = SkipUpSample(n_feat + scale_unetfeats, scale_unetfeats) def forward(self, outs): enc1, enc2, enc3 = outs @@ -148,41 +156,45 @@ def forward(self, outs): x = self.up21(dec2, self.skip_attn1(enc1)) dec1 = self.decoder_level1(x) - return [dec1,dec2,dec3] + return [dec1, dec2, dec3] + ########################################################################## ##---------- Resizing Modules ---------- class DownSample(nn.Module): - def __init__(self, in_channels,s_factor): + def __init__(self, in_channels, s_factor): super(DownSample, self).__init__() self.down = nn.Sequential(nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=False), - nn.Conv2d(in_channels, in_channels+s_factor, 1, stride=1, padding=0, bias=False)) + nn.Conv2d(in_channels, in_channels + s_factor, 1, stride=1, padding=0, bias=False)) def forward(self, x): x = self.down(x) return x + class UpSample(nn.Module): - def __init__(self, in_channels,s_factor): + def __init__(self, in_channels, s_factor): super(UpSample, self).__init__() self.up = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), - nn.Conv2d(in_channels+s_factor, in_channels, 1, stride=1, padding=0, bias=False)) + nn.Conv2d(in_channels + s_factor, in_channels, 1, stride=1, padding=0, bias=False)) def forward(self, x): x = self.up(x) return x + class SkipUpSample(nn.Module): - def __init__(self, in_channels,s_factor): + def __init__(self, in_channels, s_factor): super(SkipUpSample, self).__init__() self.up = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), - nn.Conv2d(in_channels+s_factor, in_channels, 1, stride=1, padding=0, bias=False)) + nn.Conv2d(in_channels + s_factor, in_channels, 1, stride=1, padding=0, bias=False)) def forward(self, x, y): x = self.up(x) x = x + y return x + ########################################################################## ## Original Resolution Block (ORB) class ORB(nn.Module): @@ -198,28 +210,31 @@ def forward(self, x): res += x return res + ########################################################################## class ORSNet(nn.Module): def __init__(self, n_feat, scale_orsnetfeats, kernel_size, reduction, act, bias, scale_unetfeats, num_cab): super(ORSNet, self).__init__() - self.orb1 = ORB(n_feat+scale_orsnetfeats, kernel_size, reduction, act, bias, num_cab) - self.orb2 = ORB(n_feat+scale_orsnetfeats, kernel_size, reduction, act, bias, num_cab) - self.orb3 = ORB(n_feat+scale_orsnetfeats, kernel_size, reduction, act, bias, num_cab) + self.orb1 = ORB(n_feat + scale_orsnetfeats, kernel_size, reduction, act, bias, num_cab) + self.orb2 = ORB(n_feat + scale_orsnetfeats, kernel_size, reduction, act, bias, num_cab) + self.orb3 = ORB(n_feat + scale_orsnetfeats, kernel_size, reduction, act, bias, num_cab) self.up_enc1 = UpSample(n_feat, scale_unetfeats) self.up_dec1 = UpSample(n_feat, scale_unetfeats) - self.up_enc2 = nn.Sequential(UpSample(n_feat+scale_unetfeats, scale_unetfeats), UpSample(n_feat, scale_unetfeats)) - self.up_dec2 = nn.Sequential(UpSample(n_feat+scale_unetfeats, scale_unetfeats), UpSample(n_feat, scale_unetfeats)) + self.up_enc2 = nn.Sequential(UpSample(n_feat + scale_unetfeats, scale_unetfeats), + UpSample(n_feat, scale_unetfeats)) + self.up_dec2 = nn.Sequential(UpSample(n_feat + scale_unetfeats, scale_unetfeats), + UpSample(n_feat, scale_unetfeats)) - self.conv_enc1 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias) - self.conv_enc2 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias) - self.conv_enc3 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias) + self.conv_enc1 = nn.Conv2d(n_feat, n_feat + scale_orsnetfeats, kernel_size=1, bias=bias) + self.conv_enc2 = nn.Conv2d(n_feat, n_feat + scale_orsnetfeats, kernel_size=1, bias=bias) + self.conv_enc3 = nn.Conv2d(n_feat, n_feat + scale_orsnetfeats, kernel_size=1, bias=bias) - self.conv_dec1 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias) - self.conv_dec2 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias) - self.conv_dec3 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias) + self.conv_dec1 = nn.Conv2d(n_feat, n_feat + scale_orsnetfeats, kernel_size=1, bias=bias) + self.conv_dec2 = nn.Conv2d(n_feat, n_feat + scale_orsnetfeats, kernel_size=1, bias=bias) + self.conv_dec3 = nn.Conv2d(n_feat, n_feat + scale_orsnetfeats, kernel_size=1, bias=bias) def forward(self, x, encoder_outs, decoder_outs): x = self.orb1(x) @@ -236,13 +251,17 @@ def forward(self, x, encoder_outs, decoder_outs): ########################################################################## class MPRNet(nn.Module): - def __init__(self, in_c=3, out_c=3, n_feat=96, scale_unetfeats=48, scale_orsnetfeats=32, num_cab=8, kernel_size=3, reduction=4, bias=False): + def __init__(self, in_c=3, out_c=3, n_feat=96, scale_unetfeats=48, scale_orsnetfeats=32, num_cab=8, kernel_size=3, + reduction=4, bias=False): super(MPRNet, self).__init__() - act=nn.PReLU() - self.shallow_feat1 = nn.Sequential(conv(in_c, n_feat, kernel_size, bias=bias), CAB(n_feat,kernel_size, reduction, bias=bias, act=act)) - self.shallow_feat2 = nn.Sequential(conv(in_c, n_feat, kernel_size, bias=bias), CAB(n_feat,kernel_size, reduction, bias=bias, act=act)) - self.shallow_feat3 = nn.Sequential(conv(in_c, n_feat, kernel_size, bias=bias), CAB(n_feat,kernel_size, reduction, bias=bias, act=act)) + act = nn.PReLU() + self.shallow_feat1 = nn.Sequential(conv(in_c, n_feat, kernel_size, bias=bias), + CAB(n_feat, kernel_size, reduction, bias=bias, act=act)) + self.shallow_feat2 = nn.Sequential(conv(in_c, n_feat, kernel_size, bias=bias), + CAB(n_feat, kernel_size, reduction, bias=bias, act=act)) + self.shallow_feat3 = nn.Sequential(conv(in_c, n_feat, kernel_size, bias=bias), + CAB(n_feat, kernel_size, reduction, bias=bias, act=act)) # Cross Stage Feature Fusion (CSFF) self.stage1_encoder = Encoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats, csff=False) @@ -251,14 +270,15 @@ def __init__(self, in_c=3, out_c=3, n_feat=96, scale_unetfeats=48, scale_orsnetf self.stage2_encoder = Encoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats, csff=True) self.stage2_decoder = Decoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats) - self.stage3_orsnet = ORSNet(n_feat, scale_orsnetfeats, kernel_size, reduction, act, bias, scale_unetfeats, num_cab) + self.stage3_orsnet = ORSNet(n_feat, scale_orsnetfeats, kernel_size, reduction, act, bias, scale_unetfeats, + num_cab) self.sam12 = SAM(n_feat, kernel_size=1, bias=bias) self.sam23 = SAM(n_feat, kernel_size=1, bias=bias) - - self.concat12 = conv(n_feat*2, n_feat, kernel_size, bias=bias) - self.concat23 = conv(n_feat*2, n_feat+scale_orsnetfeats, kernel_size, bias=bias) - self.tail = conv(n_feat+scale_orsnetfeats, out_c, kernel_size, bias=bias) + + self.concat12 = conv(n_feat * 2, n_feat, kernel_size, bias=bias) + self.concat23 = conv(n_feat * 2, n_feat + scale_orsnetfeats, kernel_size, bias=bias) + self.tail = conv(n_feat + scale_orsnetfeats, out_c, kernel_size, bias=bias) def forward(self, x3_img): # Original-resolution Image for Stage 3 @@ -268,14 +288,14 @@ def forward(self, x3_img): # Multi-Patch Hierarchy: Split Image into four non-overlapping patches # Two Patches for Stage 2 - x2top_img = x3_img[:,:,0:int(H/2),:] - x2bot_img = x3_img[:,:,int(H/2):H,:] + x2top_img = x3_img[:, :, 0:int(H / 2), :] + x2bot_img = x3_img[:, :, int(H / 2):H, :] # Four Patches for Stage 1 - x1ltop_img = x2top_img[:,:,:,0:int(W/2)] - x1rtop_img = x2top_img[:,:,:,int(W/2):W] - x1lbot_img = x2bot_img[:,:,:,0:int(W/2)] - x1rbot_img = x2bot_img[:,:,:,int(W/2):W] + x1ltop_img = x2top_img[:, :, :, 0:int(W / 2)] + x1rtop_img = x2top_img[:, :, :, int(W / 2):W] + x1lbot_img = x2bot_img[:, :, :, 0:int(W / 2)] + x1rbot_img = x2bot_img[:, :, :, int(W / 2):W] ##------------------------------------------- ##-------------- Stage 1--------------------- @@ -285,17 +305,17 @@ def forward(self, x3_img): x1rtop = self.shallow_feat1(x1rtop_img) x1lbot = self.shallow_feat1(x1lbot_img) x1rbot = self.shallow_feat1(x1rbot_img) - + ## Process features of all 4 patches with Encoder of Stage 1 feat1_ltop = self.stage1_encoder(x1ltop) feat1_rtop = self.stage1_encoder(x1rtop) feat1_lbot = self.stage1_encoder(x1lbot) feat1_rbot = self.stage1_encoder(x1rbot) - + ## Concat deep features - feat1_top = [torch.cat((k,v), 3) for k,v in zip(feat1_ltop,feat1_rtop)] - feat1_bot = [torch.cat((k,v), 3) for k,v in zip(feat1_lbot,feat1_rbot)] - + feat1_top = [torch.cat((k, v), 3) for k, v in zip(feat1_ltop, feat1_rtop)] + feat1_bot = [torch.cat((k, v), 3) for k, v in zip(feat1_lbot, feat1_rbot)] + ## Pass features through Decoder of Stage 1 res1_top = self.stage1_decoder(feat1_top) res1_bot = self.stage1_decoder(feat1_bot) @@ -305,13 +325,13 @@ def forward(self, x3_img): x2bot_samfeats, stage1_img_bot = self.sam12(res1_bot[0], x2bot_img) ## Output image at Stage 1 - stage1_img = torch.cat([stage1_img_top, stage1_img_bot],2) + stage1_img = torch.cat([stage1_img_top, stage1_img_bot], 2) ##------------------------------------------- ##-------------- Stage 2--------------------- ##------------------------------------------- ## Compute Shallow Features - x2top = self.shallow_feat2(x2top_img) - x2bot = self.shallow_feat2(x2bot_img) + x2top = self.shallow_feat2(x2top_img) + x2bot = self.shallow_feat2(x2bot_img) ## Concatenate SAM features of Stage 1 with shallow features of Stage 2 x2top_cat = self.concat12(torch.cat([x2top, x2top_samfeats], 1)) @@ -322,7 +342,7 @@ def forward(self, x3_img): feat2_bot = self.stage2_encoder(x2bot_cat, feat1_bot, res1_bot) ## Concat deep features - feat2 = [torch.cat((k,v), 2) for k,v in zip(feat2_top,feat2_bot)] + feat2 = [torch.cat((k, v), 2) for k, v in zip(feat2_top, feat2_bot)] ## Pass features through Decoder of Stage 2 res2 = self.stage2_decoder(feat2) @@ -330,18 +350,17 @@ def forward(self, x3_img): ## Apply SAM x3_samfeats, stage2_img = self.sam23(res2[0], x3_img) - ##------------------------------------------- ##-------------- Stage 3--------------------- ##------------------------------------------- ## Compute Shallow Features - x3 = self.shallow_feat3(x3_img) + x3 = self.shallow_feat3(x3_img) ## Concatenate SAM features of Stage 2 with shallow features of Stage 3 x3_cat = self.concat23(torch.cat([x3, x3_samfeats], 1)) - + x3_cat = self.stage3_orsnet(x3_cat, feat2, res2) stage3_img = self.tail(x3_cat) - return [stage3_img+x3_img, stage2_img, stage1_img] + return [stage3_img + x3_img, stage2_img, stage1_img] diff --git a/Deblurring/README.md b/Deblurring/README.md index 312d311..2756478 100644 --- a/Deblurring/README.md +++ b/Deblurring/README.md @@ -1,4 +1,5 @@ ## Training + - Download the [Datasets](Datasets/README.md) - Train the model with default arguments by running @@ -12,42 +13,53 @@ python train.py ### Download the [model](https://drive.google.com/file/d/1QwQUVbk6YVOJViCsOKYNykCsdJSVGRtb/view?usp=sharing) and place it in ./pretrained_models/ #### Testing on GoPro dataset -- Download [images](https://drive.google.com/drive/folders/1a2qKfXWpNuTGOm2-Jex8kfNSzYJLbqkf?usp=sharing) of GoPro and place them in `./Datasets/GoPro/test/` + +- Download [images](https://drive.google.com/drive/folders/1a2qKfXWpNuTGOm2-Jex8kfNSzYJLbqkf?usp=sharing) of GoPro and + place them in `./Datasets/GoPro/test/` - Run + ``` python test.py --dataset GoPro ``` #### Testing on HIDE dataset -- Download [images](https://drive.google.com/drive/folders/1nRsTXj4iTUkTvBhTcGg8cySK8nd3vlhK?usp=sharing) of HIDE and place them in `./Datasets/HIDE/test/` + +- Download [images](https://drive.google.com/drive/folders/1nRsTXj4iTUkTvBhTcGg8cySK8nd3vlhK?usp=sharing) of HIDE and + place them in `./Datasets/HIDE/test/` - Run + ``` python test.py --dataset HIDE ``` - #### Testing on RealBlur-J dataset -- Download [images](https://drive.google.com/drive/folders/1KYtzeKCiDRX9DSvC-upHrCqvC4sPAiJ1?usp=sharing) of RealBlur-J and place them in `./Datasets/RealBlur_J/test/` + +- Download [images](https://drive.google.com/drive/folders/1KYtzeKCiDRX9DSvC-upHrCqvC4sPAiJ1?usp=sharing) of RealBlur-J + and place them in `./Datasets/RealBlur_J/test/` - Run + ``` python test.py --dataset RealBlur_J ``` - - #### Testing on RealBlur-R dataset -- Download [images](https://drive.google.com/drive/folders/1EwDoajf5nStPIAcU4s9rdc8SPzfm3tW1?usp=sharing) of RealBlur-R and place them in `./Datasets/RealBlur_R/test/` + +- Download [images](https://drive.google.com/drive/folders/1EwDoajf5nStPIAcU4s9rdc8SPzfm3tW1?usp=sharing) of RealBlur-R + and place them in `./Datasets/RealBlur_R/test/` - Run + ``` python test.py --dataset RealBlur_R ``` #### To reproduce PSNR/SSIM scores of the paper on GoPro and HIDE datasets, run this MATLAB script + ``` evaluate_GOPRO_HIDE.m ``` #### To reproduce PSNR/SSIM scores of the paper on RealBlur dataset, run + ``` evaluate_RealBlur.py ``` diff --git a/Deblurring/config.py b/Deblurring/config.py index 199d693..0fb8c39 100644 --- a/Deblurring/config.py +++ b/Deblurring/config.py @@ -54,7 +54,6 @@ class Config(object): """ def __init__(self, config_yaml: str, config_override: List[Any] = []): - self._C = CN() self._C.GPU = [0] self._C.VERBOSE = False diff --git a/Deblurring/data_RGB.py b/Deblurring/data_RGB.py index bd10f32..4f3f620 100644 --- a/Deblurring/data_RGB.py +++ b/Deblurring/data_RGB.py @@ -1,14 +1,18 @@ import os + from dataset_RGB import DataLoaderTrain, DataLoaderVal, DataLoaderTest + def get_training_data(rgb_dir, img_options): assert os.path.exists(rgb_dir) return DataLoaderTrain(rgb_dir, img_options) + def get_validation_data(rgb_dir, img_options): assert os.path.exists(rgb_dir) return DataLoaderVal(rgb_dir, img_options) + def get_test_data(rgb_dir, img_options): assert os.path.exists(rgb_dir) return DataLoaderTest(rgb_dir, img_options) diff --git a/Deblurring/dataset_RGB.py b/Deblurring/dataset_RGB.py index ba06bc9..9cfff2d 100644 --- a/Deblurring/dataset_RGB.py +++ b/Deblurring/dataset_RGB.py @@ -1,15 +1,17 @@ import os +import random + import numpy as np -from torch.utils.data import Dataset import torch -from PIL import Image import torchvision.transforms.functional as TF -from pdb import set_trace as stx -import random +from PIL import Image +from torch.utils.data import Dataset + def is_image_file(filename): return any(filename.endswith(extension) for extension in ['jpeg', 'JPEG', 'jpg', 'png', 'JPG', 'PNG', 'gif']) + class DataLoaderTrain(Dataset): def __init__(self, rgb_dir, img_options=None): super(DataLoaderTrain, self).__init__() @@ -17,11 +19,11 @@ def __init__(self, rgb_dir, img_options=None): inp_files = sorted(os.listdir(os.path.join(rgb_dir, 'input'))) tar_files = sorted(os.listdir(os.path.join(rgb_dir, 'target'))) - self.inp_filenames = [os.path.join(rgb_dir, 'input', x) for x in inp_files if is_image_file(x)] + self.inp_filenames = [os.path.join(rgb_dir, 'input', x) for x in inp_files if is_image_file(x)] self.tar_filenames = [os.path.join(rgb_dir, 'target', x) for x in tar_files if is_image_file(x)] self.img_options = img_options - self.sizex = len(self.tar_filenames) # get the size of target + self.sizex = len(self.tar_filenames) # get the size of target self.ps = self.img_options['patch_size'] @@ -38,23 +40,23 @@ def __getitem__(self, index): inp_img = Image.open(inp_path) tar_img = Image.open(tar_path) - w,h = tar_img.size - padw = ps-w if wTesting using weights: ",args.weights) +utils.load_checkpoint(model_restoration, args.weights) +print("===>Testing using weights: ", args.weights) model_restoration.cuda() model_restoration = nn.DataParallel(model_restoration) model_restoration.eval() @@ -44,9 +43,10 @@ dataset = args.dataset rgb_dir_test = os.path.join(args.input_dir, dataset, 'test', 'input') test_dataset = get_test_data(rgb_dir_test, img_options={}) -test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False, num_workers=4, drop_last=False, pin_memory=True) +test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False, num_workers=4, drop_last=False, + pin_memory=True) -result_dir = os.path.join(args.result_dir, dataset) +result_dir = os.path.join(args.result_dir, dataset) utils.mkdir(result_dir) with torch.no_grad(): @@ -54,27 +54,27 @@ torch.cuda.ipc_collect() torch.cuda.empty_cache() - input_ = data_test[0].cuda() + input_ = data_test[0].cuda() filenames = data_test[1] # Padding in case images are not multiples of 8 if dataset == 'RealBlur_J' or dataset == 'RealBlur_R': factor = 8 - h,w = input_.shape[2], input_.shape[3] - H,W = ((h+factor)//factor)*factor, ((w+factor)//factor)*factor - padh = H-h if h%factor!=0 else 0 - padw = W-w if w%factor!=0 else 0 - input_ = F.pad(input_, (0,padw,0,padh), 'reflect') + h, w = input_.shape[2], input_.shape[3] + H, W = ((h + factor) // factor) * factor, ((w + factor) // factor) * factor + padh = H - h if h % factor != 0 else 0 + padw = W - w if w % factor != 0 else 0 + input_ = F.pad(input_, (0, padw, 0, padh), 'reflect') restored = model_restoration(input_) - restored = torch.clamp(restored[0],0,1) + restored = torch.clamp(restored[0], 0, 1) # Unpad images to original dimensions if dataset == 'RealBlur_J' or dataset == 'RealBlur_R': - restored = restored[:,:,:h,:w] + restored = restored[:, :, :h, :w] restored = restored.permute(0, 2, 3, 1).cpu().detach().numpy() for batch in range(len(restored)): restored_img = img_as_ubyte(restored[batch]) - utils.save_img((os.path.join(result_dir, filenames[batch]+'.png')), restored_img) + utils.save_img((os.path.join(result_dir, filenames[batch] + '.png')), restored_img) diff --git a/Deblurring/train.py b/Deblurring/train.py index 972b433..7bd9c3b 100644 --- a/Deblurring/train.py +++ b/Deblurring/train.py @@ -1,5 +1,7 @@ import os -from config import Config + +from config import Config + opt = Config('training.yml') gpus = ','.join([str(i) for i in opt.GPU]) @@ -7,10 +9,10 @@ os.environ["CUDA_VISIBLE_DEVICES"] = gpus import torch + torch.backends.cudnn.benchmark = True import torch.nn as nn -import torch.nn.functional as F import torch.optim as optim from torch.utils.data import DataLoader @@ -24,7 +26,6 @@ import losses from warmup_scheduler import GradualWarmupScheduler from tqdm import tqdm -from pdb import set_trace as stx ######### Set Seeds ########### random.seed(1234) @@ -37,13 +38,13 @@ session = opt.MODEL.SESSION result_dir = os.path.join(opt.TRAINING.SAVE_DIR, mode, 'results', session) -model_dir = os.path.join(opt.TRAINING.SAVE_DIR, mode, 'models', session) +model_dir = os.path.join(opt.TRAINING.SAVE_DIR, mode, 'models', session) utils.mkdir(result_dir) utils.mkdir(model_dir) train_dir = opt.TRAINING.TRAIN_DIR -val_dir = opt.TRAINING.VAL_DIR +val_dir = opt.TRAINING.VAL_DIR ######### Model ########### model_restoration = MPRNet() @@ -51,24 +52,23 @@ device_ids = [i for i in range(torch.cuda.device_count())] if torch.cuda.device_count() > 1: - print("\n\nLet's use", torch.cuda.device_count(), "GPUs!\n\n") - + print("\n\nLet's use", torch.cuda.device_count(), "GPUs!\n\n") new_lr = opt.OPTIM.LR_INITIAL -optimizer = optim.Adam(model_restoration.parameters(), lr=new_lr, betas=(0.9, 0.999),eps=1e-8) - +optimizer = optim.Adam(model_restoration.parameters(), lr=new_lr, betas=(0.9, 0.999), eps=1e-8) ######### Scheduler ########### warmup_epochs = 3 -scheduler_cosine = optim.lr_scheduler.CosineAnnealingLR(optimizer, opt.OPTIM.NUM_EPOCHS-warmup_epochs, eta_min=opt.OPTIM.LR_MIN) +scheduler_cosine = optim.lr_scheduler.CosineAnnealingLR(optimizer, opt.OPTIM.NUM_EPOCHS - warmup_epochs, + eta_min=opt.OPTIM.LR_MIN) scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_epochs, after_scheduler=scheduler_cosine) scheduler.step() ######### Resume ########### if opt.TRAINING.RESUME: - path_chk_rest = utils.get_last_path(model_dir, '_latest.pth') - utils.load_checkpoint(model_restoration,path_chk_rest) + path_chk_rest = utils.get_last_path(model_dir, '_latest.pth') + utils.load_checkpoint(model_restoration, path_chk_rest) start_epoch = utils.load_start_epoch(path_chk_rest) + 1 utils.load_optim(optimizer, path_chk_rest) @@ -79,21 +79,23 @@ print("==> Resuming Training with learning rate:", new_lr) print('------------------------------------------------------------------------------') -if len(device_ids)>1: - model_restoration = nn.DataParallel(model_restoration, device_ids = device_ids) +if len(device_ids) > 1: + model_restoration = nn.DataParallel(model_restoration, device_ids=device_ids) ######### Loss ########### criterion_char = losses.CharbonnierLoss() criterion_edge = losses.EdgeLoss() ######### DataLoaders ########### -train_dataset = get_training_data(train_dir, {'patch_size':opt.TRAINING.TRAIN_PS}) -train_loader = DataLoader(dataset=train_dataset, batch_size=opt.OPTIM.BATCH_SIZE, shuffle=True, num_workers=16, drop_last=False, pin_memory=True) +train_dataset = get_training_data(train_dir, {'patch_size': opt.TRAINING.TRAIN_PS}) +train_loader = DataLoader(dataset=train_dataset, batch_size=opt.OPTIM.BATCH_SIZE, shuffle=True, num_workers=16, + drop_last=False, pin_memory=True) -val_dataset = get_validation_data(val_dir, {'patch_size':opt.TRAINING.VAL_PS}) -val_loader = DataLoader(dataset=val_dataset, batch_size=16, shuffle=False, num_workers=8, drop_last=False, pin_memory=True) +val_dataset = get_validation_data(val_dir, {'patch_size': opt.TRAINING.VAL_PS}) +val_loader = DataLoader(dataset=val_dataset, batch_size=16, shuffle=False, num_workers=8, drop_last=False, + pin_memory=True) -print('===> Start Epoch {} End Epoch {}'.format(start_epoch,opt.OPTIM.NUM_EPOCHS + 1)) +print('===> Start Epoch {} End Epoch {}'.format(start_epoch, opt.OPTIM.NUM_EPOCHS + 1)) print('===> Loading datasets') best_psnr = 0 @@ -115,18 +117,18 @@ input_ = data[1].cuda() restored = model_restoration(input_) - + # Compute loss at each stage - loss_char = np.sum([criterion_char(restored[j],target) for j in range(len(restored))]) - loss_edge = np.sum([criterion_edge(restored[j],target) for j in range(len(restored))]) - loss = (loss_char) + (0.05*loss_edge) - + loss_char = np.sum([criterion_char(restored[j], target) for j in range(len(restored))]) + loss_edge = np.sum([criterion_edge(restored[j], target) for j in range(len(restored))]) + loss = (loss_char) + (0.05 * loss_edge) + loss.backward() optimizer.step() - epoch_loss +=loss.item() + epoch_loss += loss.item() #### Evaluation #### - if epoch%opt.TRAINING.VAL_AFTER_EVERY == 0: + if epoch % opt.TRAINING.VAL_AFTER_EVERY == 0: model_restoration.eval() psnr_val_rgb = [] for ii, data_val in enumerate((val_loader), 0): @@ -137,34 +139,34 @@ restored = model_restoration(input_) restored = restored[0] - for res,tar in zip(restored,target): + for res, tar in zip(restored, target): psnr_val_rgb.append(utils.torchPSNR(res, tar)) - psnr_val_rgb = torch.stack(psnr_val_rgb).mean().item() + psnr_val_rgb = torch.stack(psnr_val_rgb).mean().item() if psnr_val_rgb > best_psnr: best_psnr = psnr_val_rgb best_epoch = epoch - torch.save({'epoch': epoch, + torch.save({'epoch': epoch, 'state_dict': model_restoration.state_dict(), - 'optimizer' : optimizer.state_dict() - }, os.path.join(model_dir,"model_best.pth")) + 'optimizer': optimizer.state_dict() + }, os.path.join(model_dir, "model_best.pth")) print("[epoch %d PSNR: %.4f --- best_epoch %d Best_PSNR %.4f]" % (epoch, psnr_val_rgb, best_epoch, best_psnr)) - torch.save({'epoch': epoch, + torch.save({'epoch': epoch, 'state_dict': model_restoration.state_dict(), - 'optimizer' : optimizer.state_dict() - }, os.path.join(model_dir,f"model_epoch_{epoch}.pth")) + 'optimizer': optimizer.state_dict() + }, os.path.join(model_dir, f"model_epoch_{epoch}.pth")) scheduler.step() - + print("------------------------------------------------------------------") - print("Epoch: {}\tTime: {:.4f}\tLoss: {:.4f}\tLearningRate {:.6f}".format(epoch, time.time()-epoch_start_time, epoch_loss, scheduler.get_lr()[0])) + print("Epoch: {}\tTime: {:.4f}\tLoss: {:.4f}\tLearningRate {:.6f}".format(epoch, time.time() - epoch_start_time, + epoch_loss, scheduler.get_lr()[0])) print("------------------------------------------------------------------") - torch.save({'epoch': epoch, + torch.save({'epoch': epoch, 'state_dict': model_restoration.state_dict(), - 'optimizer' : optimizer.state_dict() - }, os.path.join(model_dir,"model_latest.pth")) - + 'optimizer': optimizer.state_dict() + }, os.path.join(model_dir, "model_latest.pth")) diff --git a/Deblurring/utils/__init__.py b/Deblurring/utils/__init__.py index b2e9d2d..cc5d1e3 100644 --- a/Deblurring/utils/__init__.py +++ b/Deblurring/utils/__init__.py @@ -1,4 +1,4 @@ +from .dataset_utils import * from .dir_utils import * from .image_utils import * from .model_utils import * -from .dataset_utils import * diff --git a/Deblurring/utils/dataset_utils.py b/Deblurring/utils/dataset_utils.py index b57f474..e85cd67 100644 --- a/Deblurring/utils/dataset_utils.py +++ b/Deblurring/utils/dataset_utils.py @@ -1,5 +1,6 @@ import torch + class MixUp_AUG: def __init__(self): self.dist = torch.distributions.beta.Beta(torch.tensor([0.6]), torch.tensor([0.6])) @@ -10,9 +11,9 @@ def aug(self, rgb_gt, rgb_noisy): rgb_gt2 = rgb_gt[indices] rgb_noisy2 = rgb_noisy[indices] - lam = self.dist.rsample((bs,1)).view(-1,1,1,1).cuda() + lam = self.dist.rsample((bs, 1)).view(-1, 1, 1, 1).cuda() - rgb_gt = lam * rgb_gt + (1-lam) * rgb_gt2 - rgb_noisy = lam * rgb_noisy + (1-lam) * rgb_noisy2 + rgb_gt = lam * rgb_gt + (1 - lam) * rgb_gt2 + rgb_noisy = lam * rgb_noisy + (1 - lam) * rgb_noisy2 - return rgb_gt, rgb_noisy \ No newline at end of file + return rgb_gt, rgb_noisy diff --git a/Deblurring/utils/dir_utils.py b/Deblurring/utils/dir_utils.py index 3be7063..fb653e5 100644 --- a/Deblurring/utils/dir_utils.py +++ b/Deblurring/utils/dir_utils.py @@ -1,7 +1,9 @@ import os -from natsort import natsorted from glob import glob +from natsort import natsorted + + def mkdirs(paths): if isinstance(paths, list) and not isinstance(paths, str): for path in paths: @@ -9,10 +11,12 @@ def mkdirs(paths): else: mkdir(paths) + def mkdir(path): if not os.path.exists(path): os.makedirs(path) + def get_last_path(path, session): - x = natsorted(glob(os.path.join(path,'*%s'%session)))[-1] - return x \ No newline at end of file + x = natsorted(glob(os.path.join(path, '*%s' % session)))[-1] + return x diff --git a/Deblurring/utils/image_utils.py b/Deblurring/utils/image_utils.py index 29e3f98..5ec2ff5 100644 --- a/Deblurring/utils/image_utils.py +++ b/Deblurring/utils/image_utils.py @@ -1,18 +1,21 @@ -import torch -import numpy as np import cv2 +import numpy as np +import torch + def torchPSNR(tar_img, prd_img): - imdff = torch.clamp(prd_img,0,1) - torch.clamp(tar_img,0,1) - rmse = (imdff**2).mean().sqrt() - ps = 20*torch.log10(1/rmse) + imdff = torch.clamp(prd_img, 0, 1) - torch.clamp(tar_img, 0, 1) + rmse = (imdff ** 2).mean().sqrt() + ps = 20 * torch.log10(1 / rmse) return ps + def save_img(filepath, img): - cv2.imwrite(filepath,cv2.cvtColor(img, cv2.COLOR_RGB2BGR)) + cv2.imwrite(filepath, cv2.cvtColor(img, cv2.COLOR_RGB2BGR)) + def numpyPSNR(tar_img, prd_img): imdff = np.float32(prd_img) - np.float32(tar_img) - rmse = np.sqrt(np.mean(imdff**2)) - ps = 20*np.log10(255/rmse) + rmse = np.sqrt(np.mean(imdff ** 2)) + ps = 20 * np.log10(255 / rmse) return ps diff --git a/Deblurring/utils/model_utils.py b/Deblurring/utils/model_utils.py index 154c5d8..58cfc00 100644 --- a/Deblurring/utils/model_utils.py +++ b/Deblurring/utils/model_utils.py @@ -1,24 +1,30 @@ -import torch import os from collections import OrderedDict +import torch + + def freeze(model): for p in model.parameters(): - p.requires_grad=False + p.requires_grad = False + def unfreeze(model): for p in model.parameters(): - p.requires_grad=True + p.requires_grad = True + def is_frozen(model): x = [p.requires_grad for p in model.parameters()] return not all(x) + def save_checkpoint(model_dir, state, session): epoch = state['epoch'] - model_out_path = os.path.join(model_dir,"model_epoch_{}_{}.pth".format(epoch,session)) + model_out_path = os.path.join(model_dir, "model_epoch_{}_{}.pth".format(epoch, session)) torch.save(state, model_out_path) + def load_checkpoint(model, weights): checkpoint = torch.load(weights) try: @@ -27,7 +33,7 @@ def load_checkpoint(model, weights): state_dict = checkpoint["state_dict"] new_state_dict = OrderedDict() for k, v in state_dict.items(): - name = k[7:] # remove `module.` + name = k[7:] # remove `module.` new_state_dict[name] = v model.load_state_dict(new_state_dict) @@ -37,15 +43,17 @@ def load_checkpoint_multigpu(model, weights): state_dict = checkpoint["state_dict"] new_state_dict = OrderedDict() for k, v in state_dict.items(): - name = k[7:] # remove `module.` + name = k[7:] # remove `module.` new_state_dict[name] = v model.load_state_dict(new_state_dict) + def load_start_epoch(weights): checkpoint = torch.load(weights) epoch = checkpoint["epoch"] return epoch + def load_optim(optimizer, weights): checkpoint = torch.load(weights) optimizer.load_state_dict(checkpoint['optimizer']) diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..d9f8437 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,52 @@ +FROM python:3.7-slim as build + +ENV PYTHONFAULTHANDLER=1 \ + PYTHONUNBUFFERED=1 \ + PYTHONHASHSEED=random \ + PIP_NO_CACHE_DIR=off \ + PIP_DISABLE_PIP_VERSION_CHECK=on \ + PIP_DEFAULT_TIMEOUT=100 + +WORKDIR /app + +COPY ./requirements.txt /app/requirements.txt +RUN pip3 install --upgrade pip && pip3 install \ + --user \ + --no-warn-script-location \ +# --no-cache-dir \ + -r requirements.txt + +# Copy current folder to docker working dir +COPY . . + + +FROM python:3.7-slim as output +#Install shared object & library for cv2 (opencv) +RUN apt-get update -y && apt-get install -y --no-install-recommends libgl1-mesa-dev libglib2.0-0 \ + && useradd -m r3v3r \ + && mkdir /app \ + && chown -R r3v3r:r3v3r /app + +#RUN apt-get update -y && useradd -m r3v3r \ +# && mkdir /app \ +# && chown -R r3v3r:r3v3r /app + + +USER r3v3r +WORKDIR /app + +ENV PYTHONPATH=/app \ + HOME=/home/r3v3r \ + PATH="/home/r3v3r/.local/bin:${PATH}" + +COPY --from=build --chown=r3v3r:r3v3r /usr/local/lib/python3.7/site-packages /usr/local/lib/python3.7/site-packages +COPY --from=build --chown=r3v3r:r3v3r /root/.local /home/r3v3r/.local + +#COPY --from=build --chown=r3v3r:r3v3r /app/src /app/src +##COPY --from=build --chown=r3v3r:r3v3r /app/models /app/models +#COPY --from=build --chown=r3v3r:r3v3r /app/conf /app/conf +#COPY --from=build --chown=r3v3r:r3v3r /app/main.py /app/main.py + +COPY --from=build --chown=r3v3r:r3v3r /app /app + +CMD ["python3","/app/main.py", "development"] diff --git a/README.md b/README.md index 3e4f0c8..a5a706a 100644 --- a/README.md +++ b/README.md @@ -1,20 +1,31 @@ # Instruction + ## Install linraries + ``` pip install requirements.txt ``` + ## Testing API + - First, running the ./app.py file. + ``` python app.py ``` + - Then runing the ./request.py to send request to the API. + ``` python request.py ``` + ## About the api + - The API receive raw JSON data. -- You have to post raw JSON data, which contains the element 'urls' having an array of images need to be enhanced. Here is an example. +- You have to post raw JSON data, which contains the element 'urls' having an array of images need to be enhanced. Here + is an example. + ``` { "urls":[ @@ -23,9 +34,12 @@ python request.py ] } ``` + - In Postman, we will post data like this. -![alt text](images/1.png) + ![alt text](images/1.png) + ## About the result + ``` { "result": [ @@ -41,5 +55,6 @@ python request.py "time": 25.083990812301636 } ``` + ![](images/2.png) diff --git a/app.py b/app.py deleted file mode 100644 index 88a23d6..0000000 --- a/app.py +++ /dev/null @@ -1,41 +0,0 @@ -import time - -from flask import Flask, request -from flask_cors import CORS -from py_profiler.profiler_controller import profiler_blueprint - -from enhance_service import * - -app = Flask(__name__) -app.register_blueprint(profiler_blueprint) -CORS(app) - -folder_out = "static/output/" -os.makedirs(folder_out, exist_ok=True) - -enhance_service = EnhanceService( - 'Deblurring/pretrained_models/model_deblurring.pth', - use_cpu=True -) - - -@app.route('/') -def index(): - return "hello" - - -@app.route('/enhance', methods=['POST']) -def enhance(): - urls = dict(request.get_json())['urls'] - # - begin = time.time() - output_result = enhance_service.process(urls, folder_out) - print(f'Output: {output_result}') - return { - 'time': time.time() - begin, - 'result': output_result - } - - -if __name__ == '__main__': - app.run() diff --git a/conf/development.yml b/conf/development.yml new file mode 100644 index 0000000..e2af202 --- /dev/null +++ b/conf/development.yml @@ -0,0 +1,11 @@ +server: + http: + nthreads: 4 + port: 31558 +# port: 8080 +downloader: + num_threads: 8 + +run_on_cpu: True +use_deblur_model: False +deblur_model_path: "Deblurring/pretrained_models/model_deblurring.pth" diff --git a/conf/production.yml b/conf/production.yml new file mode 100644 index 0000000..23b994b --- /dev/null +++ b/conf/production.yml @@ -0,0 +1,10 @@ +server: + http: + nthreads: 4 + port: 8080 + +downloader: + num_threads: 8 +run_on_cpu: False +use_deblur_model: True +deblur_model_path: "Deblurring/pretrained_models/model_deblurring.pth" \ No newline at end of file diff --git a/conf/staging.yml b/conf/staging.yml new file mode 100644 index 0000000..fd759b3 --- /dev/null +++ b/conf/staging.yml @@ -0,0 +1,11 @@ +server: + http: + nthreads: 8 + port: 8080 + +downloader: + num_threads: 8 + +run_on_cpu: True +use_deblur_model: False +deblur_model_path: "Deblurring/pretrained_models/model_deblurring.pth" \ No newline at end of file diff --git a/controller/__init__.py b/controller/__init__.py new file mode 100644 index 0000000..113b895 --- /dev/null +++ b/controller/__init__.py @@ -0,0 +1 @@ +from .enhance_controller import * diff --git a/controller/enhance_controller.py b/controller/enhance_controller.py new file mode 100644 index 0000000..c7b12da --- /dev/null +++ b/controller/enhance_controller.py @@ -0,0 +1,21 @@ +import os + +from dependency_injector.wiring import Provide, inject +from flask import Blueprint, request + +from enhance_service import EnhanceService +from .error_handlers import build_response +from module.application_container import ApplicationContainer + +folder_out = "static/output/" +os.makedirs(folder_out, exist_ok=True) + +enhance_blueprint = Blueprint("enhance", __name__) + + +@enhance_blueprint.route("/enhance", methods=["POST"]) +@inject +def detect(enhance_service: EnhanceService = Provide[ApplicationContainer.enhance_service]): + urls = dict(request.get_json())['urls'] + output_result = enhance_service.process(urls, folder_out) + return build_response(result=output_result) diff --git a/controller/error_handlers.py b/controller/error_handlers.py new file mode 100644 index 0000000..6f57f96 --- /dev/null +++ b/controller/error_handlers.py @@ -0,0 +1,31 @@ +import json +import logging +import traceback + +from domain.errors import RError, InternalError + + +def handle_defined_errors(error: RError): + return build_error_response(error) + + +def handle_other_exceptions(error: Exception): + return build_error_response(InternalError(str(error), error)) + + +def build_response(**kwargs): + response = {} + response.update(kwargs) + return json.dumps(response) + + +def build_error_response(error: RError): + print(error) + if error.__cause__ is not None: + logging.error("".join(traceback.TracebackException.from_exception(error.__cause__).format())) + else: + logging.error(error.message) + return { + "error": error.get_error(), + "message": error.message + } diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..1470854 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,13 @@ +version: "3.3" +services: + image_enhancement: +# image: registry.rever.vn/ai-research/room-detection:master +# image: registry.rever.vn/ai-research/room-detection:v1.0.0 + image: image_enhancement + build: . + command: ["python3","/app/main.py", "staging"] + ports: + - 31558:8080 + volumes: + - ./logs:/app/logs + restart: always diff --git a/domain/__init__.py b/domain/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/domain/errors.py b/domain/errors.py new file mode 100644 index 0000000..29ae760 --- /dev/null +++ b/domain/errors.py @@ -0,0 +1,45 @@ +REQUEST_INVALID_ERR = "request_invalid_error" +INTERNAL_ERR = "internal_error" + + +class RError(Exception): + def __init__( + self, + message, + cause: Exception = None + ): + super(RError, self).__init__(message, cause) + self.message = message + self.__cause__ = cause + + def get_message(self) -> str: + return self.message + + def get_error(self) -> str: + pass + + +class RequestInvalidError(RError): + def __init__( + self, + message, + cause: Exception = None): + super(RequestInvalidError, self).__init__( + message, cause + ) + + def get_error(self) -> str: + return REQUEST_INVALID_ERR + + +class InternalError(RError): + def __init__( + self, + message, + cause: Exception = None): + super(InternalError, self).__init__( + message, cause + ) + + def get_error(self) -> str: + return INTERNAL_ERR diff --git a/domain/jsonable.py b/domain/jsonable.py new file mode 100644 index 0000000..d7b893c --- /dev/null +++ b/domain/jsonable.py @@ -0,0 +1,10 @@ +import json + + +class Jsonable: + + def to_json(self) -> str: + return json.dumps(self.to_dict(), ensure_ascii=False) + + def to_dict(self) -> dict: + pass diff --git a/enhance_service.py b/enhance_service.py index 912b6a5..b9798d3 100644 --- a/enhance_service.py +++ b/enhance_service.py @@ -1,3 +1,4 @@ +import logging import os # from enlighten_inference import EnlightenOnnxModel import uuid @@ -26,8 +27,9 @@ class EnhanceService: - def __init__(self, deblur_model_path, use_cpu: bool = True): + def __init__(self, deblur_model_path, use_cpu: bool = True, use_deblur_model: bool = False): self.use_cpu = use_cpu + self.use_deblur_model = use_deblur_model # Executor to run enhance process concurrently self.executor = ThreadPoolExecutor(max_workers=8) # A downloader to download image using a thread pool with 16 threads @@ -46,18 +48,17 @@ def __init__(self, deblur_model_path, use_cpu: bool = True): self.white_balancer = WhiteBalancer() self.ceiq_scoring_model = CEIQ() - print(f'white_balancer: {type(self.white_balancer)}') - print(f'CEIQ_model: {type(self.ceiq_scoring_model)}') - print(f'Model: {type(self.deblur_model)}') - - print("Init successfully") + logging.info(f'white_balancer: {type(self.white_balancer)}') + logging.info(f'CEIQ_model: {type(self.ceiq_scoring_model)}') + logging.info(f'Model: {type(self.deblur_model)}') + logging.info("Init successfully") @profiler() def process(self, image_urls: List[str], enhanced_out_dir): image_dict = self.image_downloader.bulk_download_as_image(image_urls) if len(image_urls) == 0: raise Exception(f"No image urls found at {image_urls}") - print('Number of files: ', len(image_dict)) + logging.info('Number of files: %d', len(image_dict)) future_to_checks = { self.executor.submit(self._enhance_image, image, 8, enhanced_out_dir): url @@ -69,27 +70,30 @@ def process(self, image_urls: List[str], enhanced_out_dir): # The try-except-else clause is omitted here for future in futures.as_completed(future_to_checks): url = future_to_checks[future] - output_path = future.result() - result_dict[url] = output_path + output_path, enhanced_score = future.result() + result_dict[url] = { + 'enhanced_url': output_path, + 'enhanced_score': enhanced_score + } return result_dict @profiler() - def _enhance_image(self, image, factor, out_dir) -> str: - restored = self._deblur_image(image, factor) + def _enhance_image(self, image, factor, out_dir) -> [str, float]: + restored = self._deblur_image(image, factor) if self.use_deblur_model else np.asarray(image) # processed = enlighten_model.predict(cv2.cvtColor(restored, cv2.COLOR_RGB2BGR)) # processed = img_as_ubyte(processed) img_output = self._process_white_balancing(restored) restored = cv2.cvtColor(restored, cv2.COLOR_RGB2BGR) - scores = self._calc_score([restored, img_output]) - if scores[0] > scores[1]: + origin_score, improved_score = self._calc_score([restored, img_output]) + if origin_score > improved_score: img_output = restored - output_path = os.path.join(out_dir, f'{uuid.uuid1()}.png') + output_path = os.path.join(out_dir, f'{uuid.uuid1()}.jpg') save_img(output_path, img_output) - return output_path + return [output_path, ((improved_score - origin_score) / origin_score)] @profiler() def _deblur_image(self, img, factor: int = 8): @@ -134,12 +138,12 @@ def _process_white_balancing(self, input_image, threshold: float = 0.3): # Process image for gamma correction output_image = None if t < -threshold: # Dimmed Image - print('Dimmed') + logging.info('Dimmed') result = self.white_balancer.process_dimmed(Y) YCrCb[:, :, 0] = result output_image = cv2.cvtColor(YCrCb, cv2.COLOR_YCrCb2BGR) elif t > threshold: - print('Bright Image') # Bright Image + logging.info('Bright Image') # Bright Image result = self.white_balancer.process_bright(Y) YCrCb[:, :, 0] = result output_image = cv2.cvtColor(YCrCb, cv2.COLOR_YCrCb2BGR) @@ -154,5 +158,5 @@ def _calc_score(self, images): # org_score = CEIQ_model.predict(np.expand_dims(restored, axis=0), 1)[0] # imp_score = CEIQ_model.predict(np.expand_dims(img_output, axis=0), 1)[0] scores = self.ceiq_scoring_model.predict(images, option=1) - print(f"Scores: {scores[0]} -> {scores[1]}: Improved: {((scores[1] - scores[0]) * 100 / scores[0])} %") + logging.info(f"Scores: {scores[0]} -> {scores[1]}: Improved: {((scores[1] - scores[0]) * 100 / scores[0])} %") return scores diff --git a/images/1.png b/images/1.png deleted file mode 100644 index 8042a7e..0000000 Binary files a/images/1.png and /dev/null differ diff --git a/images/2.png b/images/2.png deleted file mode 100644 index 61f59fd..0000000 Binary files a/images/2.png and /dev/null differ diff --git a/main.py b/main.py new file mode 100644 index 0000000..da4a1dd --- /dev/null +++ b/main.py @@ -0,0 +1,50 @@ +import logging +import os +import sys + +from dependency_injector import containers +from dependency_injector.wiring import Provide, inject +from flask import Flask +from py_profiler.profiler_controller import profiler_blueprint +from waitress import serve + +from controller import enhance_controller, enhance_blueprint +from controller.error_handlers import handle_defined_errors, handle_other_exceptions +from domain.errors import RError +from module import ApplicationContainer +from module.injector import create_injector +from utils.setup_logging import setup_logging + + +@inject +def setup_http_server( + injector: containers.DeclarativeContainer, + port: int = Provide[ApplicationContainer.config.server.http.port.as_int()], + nthreads: int = Provide[ApplicationContainer.config.server.http.nthreads.as_int()], +) -> None: + app = Flask(__name__) + app.url_map.strict_slashes = False + app.debug = True + + app.register_blueprint(enhance_blueprint) + app.register_blueprint(profiler_blueprint) + + app.register_error_handler(RError, handle_defined_errors) + app.register_error_handler(Exception, handle_other_exceptions) + + logging.info(f"Created http server at port = {port} with {nthreads} concurrent threads.") + serve( + app, host="0.0.0.0", + port=port, + threads=nthreads if nthreads is not None else 4 + ) + + +if __name__ == "__main__": + os.environ['APP_MODE'] = sys.argv[1] if len(sys.argv) > 1 else 'development' + setup_logging() + injector = create_injector([ + sys.modules[__name__], + enhance_controller, + ], mode=os.environ['APP_MODE']) + setup_http_server(injector) diff --git a/CEIQ_model_v1_1.pickle b/models/CEIQ_model_v1_1.pickle similarity index 100% rename from CEIQ_model_v1_1.pickle rename to models/CEIQ_model_v1_1.pickle diff --git a/module/__init__.py b/module/__init__.py new file mode 100644 index 0000000..58409e7 --- /dev/null +++ b/module/__init__.py @@ -0,0 +1 @@ +from .application_container import * diff --git a/module/application_container.py b/module/application_container.py new file mode 100644 index 0000000..17ac2c6 --- /dev/null +++ b/module/application_container.py @@ -0,0 +1,21 @@ +from dependency_injector import containers, providers + +# +# @author: anhlt +# +from enhance_service import EnhanceService +from image_downloader import ImageDownloader + + +class ApplicationContainer(containers.DeclarativeContainer): + config = providers.Configuration() + + downloader = providers.Singleton( + ImageDownloader, + num_threads=config.downloader.num_threads + ) + enhance_service = providers.Singleton( + EnhanceService, + deblur_model_path=config.deblur_model_path, + use_cpu=config.run_on_cpu + ) diff --git a/module/injector.py b/module/injector.py new file mode 100644 index 0000000..56740a8 --- /dev/null +++ b/module/injector.py @@ -0,0 +1,21 @@ +import os + +from .application_container import ApplicationContainer + + +def create_injector( + modules: list, + mode: str = 'development' +): + os.environ['APP_MODE'] = mode if mode is not None else os.environ['APP_MODE'] + mode = os.environ['APP_MODE'] + + print(f"Mode: {mode}") + print(f"Loaded config: config/{mode}.yml") + print(f"Modules: {modules}") + injector = ApplicationContainer() + injector.config.from_yaml(f"conf/{mode}.yml") + injector.wire(modules) + print("Wire completed") + + return injector diff --git a/request.py b/request.py index 9fcc58a..952fe79 100644 --- a/request.py +++ b/request.py @@ -1,24 +1,25 @@ -import requests import json -url = "http://127.0.0.1:5000/enhance" +import requests + +url = "http://127.0.0.1:8080/enhance" payload = json.dumps({ - "urls": [ - "https://genk.mediacdn.vn/k:thumb_w/640/2014/mg-1235-3-1416328703227/cuu-sang-anh-bang-vai-thao-tac-don-gian-trong-photoshop.jpg", - "https://th.bing.com/th/id/R.7924efd2efd3b67225ebb6cc82331bc8?rik=TKz8a7gW1Ur9zw&riu=http%3a%2f%2fpalistudio.com%2fupload%2fimage%2fdata%2fTin-tuc%2fChup-anh-cuoi%2fHa-Noi%2fZEN1242-d23a92.jpg&ehk=HgnL8rnD6HnJRBawFWnFTJC8HacfoC%2f75d5xrPwGDxI%3d&risl=&pid=ImgRaw&r=0", - "https://genk.mediacdn.vn/k:thumb_w/640/2014/mg-1235-3-1416328703227/cuu-sang-anh-bang-vai-thao-tac-don-gian-trong-photoshop.jpg", - "https://th.bing.com/th/id/R.7924efd2efd3b67225ebb6cc82331bc8?rik=TKz8a7gW1Ur9zw&riu=http%3a%2f%2fpalistudio.com%2fupload%2fimage%2fdata%2fTin-tuc%2fChup-anh-cuoi%2fHa-Noi%2fZEN1242-d23a92.jpg&ehk=HgnL8rnD6HnJRBawFWnFTJC8HacfoC%2f75d5xrPwGDxI%3d&risl=&pid=ImgRaw&r=0", - "https://genk.mediacdn.vn/k:thumb_w/640/2014/mg-1235-3-1416328703227/cuu-sang-anh-bang-vai-thao-tac-don-gian-trong-photoshop.jpg", - "https://th.bing.com/th/id/R.7924efd2efd3b67225ebb6cc82331bc8?rik=TKz8a7gW1Ur9zw&riu=http%3a%2f%2fpalistudio.com%2fupload%2fimage%2fdata%2fTin-tuc%2fChup-anh-cuoi%2fHa-Noi%2fZEN1242-d23a92.jpg&ehk=HgnL8rnD6HnJRBawFWnFTJC8HacfoC%2f75d5xrPwGDxI%3d&risl=&pid=ImgRaw&r=0", - "https://genk.mediacdn.vn/k:thumb_w/640/2014/mg-1235-3-1416328703227/cuu-sang-anh-bang-vai-thao-tac-don-gian-trong-photoshop.jpg", - "https://th.bing.com/th/id/R.7924efd2efd3b67225ebb6cc82331bc8?rik=TKz8a7gW1Ur9zw&riu=http%3a%2f%2fpalistudio.com%2fupload%2fimage%2fdata%2fTin-tuc%2fChup-anh-cuoi%2fHa-Noi%2fZEN1242-d23a92.jpg&ehk=HgnL8rnD6HnJRBawFWnFTJC8HacfoC%2f75d5xrPwGDxI%3d&risl=&pid=ImgRaw&r=0", - "https://genk.mediacdn.vn/k:thumb_w/640/2014/mg-1235-3-1416328703227/cuu-sang-anh-bang-vai-thao-tac-don-gian-trong-photoshop.jpg", - "https://th.bing.com/th/id/R.7924efd2efd3b67225ebb6cc82331bc8?rik=TKz8a7gW1Ur9zw&riu=http%3a%2f%2fpalistudio.com%2fupload%2fimage%2fdata%2fTin-tuc%2fChup-anh-cuoi%2fHa-Noi%2fZEN1242-d23a92.jpg&ehk=HgnL8rnD6HnJRBawFWnFTJC8HacfoC%2f75d5xrPwGDxI%3d&risl=&pid=ImgRaw&r=0" - ] + "urls": [ + "https://genk.mediacdn.vn/k:thumb_w/640/2014/mg-1235-3-1416328703227/cuu-sang-anh-bang-vai-thao-tac-don-gian-trong-photoshop.jpg", + "https://th.bing.com/th/id/R.7924efd2efd3b67225ebb6cc82331bc8?rik=TKz8a7gW1Ur9zw&riu=http%3a%2f%2fpalistudio.com%2fupload%2fimage%2fdata%2fTin-tuc%2fChup-anh-cuoi%2fHa-Noi%2fZEN1242-d23a92.jpg&ehk=HgnL8rnD6HnJRBawFWnFTJC8HacfoC%2f75d5xrPwGDxI%3d&risl=&pid=ImgRaw&r=0", + "https://genk.mediacdn.vn/k:thumb_w/640/2014/mg-1235-3-1416328703227/cuu-sang-anh-bang-vai-thao-tac-don-gian-trong-photoshop.jpg", + "https://th.bing.com/th/id/R.7924efd2efd3b67225ebb6cc82331bc8?rik=TKz8a7gW1Ur9zw&riu=http%3a%2f%2fpalistudio.com%2fupload%2fimage%2fdata%2fTin-tuc%2fChup-anh-cuoi%2fHa-Noi%2fZEN1242-d23a92.jpg&ehk=HgnL8rnD6HnJRBawFWnFTJC8HacfoC%2f75d5xrPwGDxI%3d&risl=&pid=ImgRaw&r=0", + "https://genk.mediacdn.vn/k:thumb_w/640/2014/mg-1235-3-1416328703227/cuu-sang-anh-bang-vai-thao-tac-don-gian-trong-photoshop.jpg", + "https://th.bing.com/th/id/R.7924efd2efd3b67225ebb6cc82331bc8?rik=TKz8a7gW1Ur9zw&riu=http%3a%2f%2fpalistudio.com%2fupload%2fimage%2fdata%2fTin-tuc%2fChup-anh-cuoi%2fHa-Noi%2fZEN1242-d23a92.jpg&ehk=HgnL8rnD6HnJRBawFWnFTJC8HacfoC%2f75d5xrPwGDxI%3d&risl=&pid=ImgRaw&r=0", + "https://genk.mediacdn.vn/k:thumb_w/640/2014/mg-1235-3-1416328703227/cuu-sang-anh-bang-vai-thao-tac-don-gian-trong-photoshop.jpg", + "https://th.bing.com/th/id/R.7924efd2efd3b67225ebb6cc82331bc8?rik=TKz8a7gW1Ur9zw&riu=http%3a%2f%2fpalistudio.com%2fupload%2fimage%2fdata%2fTin-tuc%2fChup-anh-cuoi%2fHa-Noi%2fZEN1242-d23a92.jpg&ehk=HgnL8rnD6HnJRBawFWnFTJC8HacfoC%2f75d5xrPwGDxI%3d&risl=&pid=ImgRaw&r=0", + "https://genk.mediacdn.vn/k:thumb_w/640/2014/mg-1235-3-1416328703227/cuu-sang-anh-bang-vai-thao-tac-don-gian-trong-photoshop.jpg", + "https://th.bing.com/th/id/R.7924efd2efd3b67225ebb6cc82331bc8?rik=TKz8a7gW1Ur9zw&riu=http%3a%2f%2fpalistudio.com%2fupload%2fimage%2fdata%2fTin-tuc%2fChup-anh-cuoi%2fHa-Noi%2fZEN1242-d23a92.jpg&ehk=HgnL8rnD6HnJRBawFWnFTJC8HacfoC%2f75d5xrPwGDxI%3d&risl=&pid=ImgRaw&r=0" + ] }) headers = { - 'Content-Type': 'application/json' + 'Content-Type': 'application/json' } response = requests.request("POST", url, headers=headers, data=payload) diff --git a/requirements.txt b/requirements.txt index 51b8790..06a406e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -26,4 +26,7 @@ torch==1.9.0 torchvision==0.10.0 typing-extensions==3.10.0.0 Werkzeug==2.0.1 +dependency-injector +pyyaml +waitress py_profiler \ No newline at end of file diff --git a/test_imgs/1.png b/static/test_imgs/1.png similarity index 100% rename from test_imgs/1.png rename to static/test_imgs/1.png diff --git a/test_imgs/2.png b/static/test_imgs/2.png similarity index 100% rename from test_imgs/2.png rename to static/test_imgs/2.png diff --git a/test_imgs/3.png b/static/test_imgs/3.png similarity index 100% rename from test_imgs/3.png rename to static/test_imgs/3.png diff --git a/utils/model_utils.py b/utils/model_utils.py index 4cc0c8f..24e86f5 100644 --- a/utils/model_utils.py +++ b/utils/model_utils.py @@ -11,6 +11,7 @@ def bypass_ssl_verify(): if (not os.environ.get('PYTHONHTTPSVERIFY', '') and getattr(ssl, '_create_unverified_context', None)): ssl._create_default_https_context = ssl._create_unverified_context + def save_img(filepath, img): import cv2 # cv2.imwrite(filepath,cv2.cvtColor(img, cv2.COLOR_RGB2BGR)) diff --git a/utils/setup_logging.py b/utils/setup_logging.py new file mode 100644 index 0000000..4a761fe --- /dev/null +++ b/utils/setup_logging.py @@ -0,0 +1,109 @@ +import logging +import os +from logging.config import dictConfig + +# # for sending error logs to slack +# import json +# import requests +# class HTTPSlackHandler(logging.Handler): +# def emit(self, record): +# log_entry = self.format(record) +# json_text = json.dumps({"text": log_entry}) +# url = 'https://hooks.slack.com/services//' +# return requests.post(url, json_text, headers={"Content-type": "application/json"}).content + + +# debug settings + +MAX_LOG_FILE_SIZE = 10 * 1024 * 1024 + + +def setup_logging(): + os.makedirs("logs", exist_ok=True) + debug_mode = os.environ.get("APP_MODE", "development") == "development" + + dictConfig({ + "version": 1, + "disable_existing_loggers": True, + "formatters": { + "default": { + "format": "[%(asctime)s] %(levelname)s in %(module)s: %(message)s", + }, + "access": { + "format": "%(message)s", + } + }, + "handlers": { + "console": { + "level": "INFO", + "class": "logging.StreamHandler", + "formatter": "default", + "stream": "ext://sys.stdout", + }, + # "email": { + # "class": "logging.handlers.SMTPHandler", + # "formatter": "default", + # "level": "ERROR", + # "mailhost": ("smtp.example.com", 587), + # "fromaddr": "devops@example.com", + # "toaddrs": ["receiver@example.com", "receiver2@example.com"], + # "subject": "Error Logs", + # "credentials": ("username", "password"), + # }, + # "slack": { + # "class": "app.HTTPSlackHandler", + # "formatter": "default", + # "level": "ERROR", + # }, + "service_file": { + "class": "logging.handlers.RotatingFileHandler", + "formatter": "default", + "filename": "logs/service.log", + "maxBytes": MAX_LOG_FILE_SIZE, + "backupCount": 5, + "delay": "True", + }, + "error_file": { + "class": "logging.handlers.RotatingFileHandler", + "formatter": "default", + "filename": "logs/error.log", + "maxBytes": MAX_LOG_FILE_SIZE, + "backupCount": 10, + "delay": "True", + }, + "access_file": { + "class": "logging.handlers.RotatingFileHandler", + "formatter": "access", + "filename": "logs/access.log", + "maxBytes": MAX_LOG_FILE_SIZE, + "backupCount": 10, + "delay": "True", + } + }, + "loggers": { + "error": { + "handlers": ["console"] if debug_mode else [ + "console", + # "slack", + "error_file" + ], + "level": "INFO", + "propagate": False, + }, + "access": { + "handlers": ["console"] if debug_mode else ["console", "access_file"], + "level": "INFO", + "propagate": False, + } + }, + "root": { + "level": "DEBUG" if debug_mode else "INFO", + "handlers": ["console", "service_file"] if debug_mode else [ + "console", + # "slack" + "service_file" + ], + } + }) + + logging.info(f"Setting up logging: DEBUG = {debug_mode}")