Skip to content

Commit

Permalink
refactor code
Browse files Browse the repository at this point in the history
  • Loading branch information
andy1xx8 committed Sep 7, 2021
2 parents 25350f4 + 797237b commit d022dab
Show file tree
Hide file tree
Showing 46 changed files with 828 additions and 374 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
*__pycache__

static/output

logs
33 changes: 16 additions & 17 deletions CEIQ.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import pickle

import cv2
import os
import matplotlib.pyplot as plt
import numpy as np
from skimage.metrics import structural_similarity as ssim
from scipy.stats import entropy as ent
import pickle


class CEIQ:
def __init__(self):
with open('CEIQ_model_v1_1.pickle', 'rb') as f:
with open('models/CEIQ_model_v1_1.pickle', 'rb') as f:
self.model = pickle.load(f)

def entropy(self, hist, bit_instead_of_nat=False):
Expand All @@ -21,12 +20,12 @@ def entropy(self, hist, bit_instead_of_nat=False):
"""
# h = h[np.where(h!=0)[0]]
h = np.asarray(hist, dtype=np.float64)
if h.sum()<=0 or (h<0).any():
print("[entropy] WARNING, malformed/empty input %s. Returning None."%str(hist))
if h.sum() <= 0 or (h < 0).any():
print("[entropy] WARNING, malformed/empty input %s. Returning None." % str(hist))
return None
h = h/h.sum()
h = h / h.sum()
log_fn = np.ma.log2 if bit_instead_of_nat else np.ma.log
return -(h*log_fn(h)).sum()
return -(h * log_fn(h)).sum()

def cross_entropy(self, x, y):
""" Computes cross entropy between two distributions.
Expand All @@ -47,30 +46,30 @@ def cross_entropy(self, x, y):
# Ignore zero 'y' elements.
mask = y > 0
x = x[mask]
y = y[mask]
ce = -np.sum(x * np.log(y))
y = y[mask]
ce = -np.sum(x * np.log(y))
return ce

def generate_x(self, img_path, option=0):
if option == 0:
Ig = cv2.imread(img_path)
Ig = 0.299*Ig[:, :, 2] + 0.587*Ig[:, :, 1] + 0.114*Ig[:, :, 0]
Ig = 0.299 * Ig[:, :, 2] + 0.587 * Ig[:, :, 1] + 0.114 * Ig[:, :, 0]
Ig = Ig.astype('uint8')
else:
# Ig = cv2.cvtColor(img_path, cv2.COLOR_BGR2GRAY)
Ig = 0.299*img_path[:, :, 2] + 0.587*img_path[:, :, 1] + 0.114*img_path[:, :, 0]
Ig = 0.299 * img_path[:, :, 2] + 0.587 * img_path[:, :, 1] + 0.114 * img_path[:, :, 0]
Ig = Ig.astype('uint8')
Ie = cv2.equalizeHist(Ig)
### Calculate ssim ###
ssim_ig_ie, _ = ssim(Ig, Ie, full=True)

### Get histograms ###
histg = cv2.calcHist([Ig],[0],None,[128],[0,256])
histe = cv2.calcHist([Ie],[0],None,[128],[0,256])
histg = cv2.calcHist([Ig], [0], None, [128], [0, 256])
histe = cv2.calcHist([Ie], [0], None, [128], [0, 256])
histg = np.reshape(histg, (histg.shape[0]))
histe = np.reshape(histe, (histe.shape[0]))
zero_idsg = np.where(histg==0)[0]
zero_idse = np.where(histe==0)[0]
zero_idsg = np.where(histg == 0)[0]
zero_idse = np.where(histe == 0)[0]
zero_ids = np.unique(np.concatenate((zero_idsg, zero_idse)))
# print(zero_ids)
histg = np.delete(histg, zero_ids)
Expand Down
13 changes: 7 additions & 6 deletions Deblurring/Datasets/README.md
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
Download datasets from the google drive links and place them in this directory. Your directory tree should look like this
Download datasets from the google drive links and place them in this directory. Your directory tree should look like
this

`GoPro` <br/>
  `├──`[train](https://drive.google.com/drive/folders/1AsgIP9_X0bg0olu2-1N6karm2x15cJWE?usp=sharing) <br/>
  `└──`[test](https://drive.google.com/drive/folders/1a2qKfXWpNuTGOm2-Jex8kfNSzYJLbqkf?usp=sharing)
`├──`[train](https://drive.google.com/drive/folders/1AsgIP9_X0bg0olu2-1N6karm2x15cJWE?usp=sharing) <br/>
`└──`[test](https://drive.google.com/drive/folders/1a2qKfXWpNuTGOm2-Jex8kfNSzYJLbqkf?usp=sharing)

`HIDE` <br/>
  `└──`[test](https://drive.google.com/drive/folders/1nRsTXj4iTUkTvBhTcGg8cySK8nd3vlhK?usp=sharing)
`└──`[test](https://drive.google.com/drive/folders/1nRsTXj4iTUkTvBhTcGg8cySK8nd3vlhK?usp=sharing)

`RealBlur_J` <br/>
  `└──`[test](https://drive.google.com/drive/folders/1KYtzeKCiDRX9DSvC-upHrCqvC4sPAiJ1?usp=sharing)
`└──`[test](https://drive.google.com/drive/folders/1KYtzeKCiDRX9DSvC-upHrCqvC4sPAiJ1?usp=sharing)

`RealBlur_R` <br/>
  `└──`[test](https://drive.google.com/drive/folders/1EwDoajf5nStPIAcU4s9rdc8SPzfm3tW1?usp=sharing)
`└──`[test](https://drive.google.com/drive/folders/1EwDoajf5nStPIAcU4s9rdc8SPzfm3tW1?usp=sharing)
171 changes: 95 additions & 76 deletions Deblurring/MPRNet.py

Large diffs are not rendered by default.

26 changes: 19 additions & 7 deletions Deblurring/README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
## Training

- Download the [Datasets](Datasets/README.md)

- Train the model with default arguments by running
Expand All @@ -12,42 +13,53 @@ python train.py
### Download the [model](https://drive.google.com/file/d/1QwQUVbk6YVOJViCsOKYNykCsdJSVGRtb/view?usp=sharing) and place it in ./pretrained_models/

#### Testing on GoPro dataset
- Download [images](https://drive.google.com/drive/folders/1a2qKfXWpNuTGOm2-Jex8kfNSzYJLbqkf?usp=sharing) of GoPro and place them in `./Datasets/GoPro/test/`

- Download [images](https://drive.google.com/drive/folders/1a2qKfXWpNuTGOm2-Jex8kfNSzYJLbqkf?usp=sharing) of GoPro and
place them in `./Datasets/GoPro/test/`
- Run

```
python test.py --dataset GoPro
```

#### Testing on HIDE dataset
- Download [images](https://drive.google.com/drive/folders/1nRsTXj4iTUkTvBhTcGg8cySK8nd3vlhK?usp=sharing) of HIDE and place them in `./Datasets/HIDE/test/`

- Download [images](https://drive.google.com/drive/folders/1nRsTXj4iTUkTvBhTcGg8cySK8nd3vlhK?usp=sharing) of HIDE and
place them in `./Datasets/HIDE/test/`
- Run

```
python test.py --dataset HIDE
```


#### Testing on RealBlur-J dataset
- Download [images](https://drive.google.com/drive/folders/1KYtzeKCiDRX9DSvC-upHrCqvC4sPAiJ1?usp=sharing) of RealBlur-J and place them in `./Datasets/RealBlur_J/test/`

- Download [images](https://drive.google.com/drive/folders/1KYtzeKCiDRX9DSvC-upHrCqvC4sPAiJ1?usp=sharing) of RealBlur-J
and place them in `./Datasets/RealBlur_J/test/`
- Run

```
python test.py --dataset RealBlur_J
```



#### Testing on RealBlur-R dataset
- Download [images](https://drive.google.com/drive/folders/1EwDoajf5nStPIAcU4s9rdc8SPzfm3tW1?usp=sharing) of RealBlur-R and place them in `./Datasets/RealBlur_R/test/`

- Download [images](https://drive.google.com/drive/folders/1EwDoajf5nStPIAcU4s9rdc8SPzfm3tW1?usp=sharing) of RealBlur-R
and place them in `./Datasets/RealBlur_R/test/`
- Run

```
python test.py --dataset RealBlur_R
```

#### To reproduce PSNR/SSIM scores of the paper on GoPro and HIDE datasets, run this MATLAB script

```
evaluate_GOPRO_HIDE.m
```

#### To reproduce PSNR/SSIM scores of the paper on RealBlur dataset, run

```
evaluate_RealBlur.py
```
1 change: 0 additions & 1 deletion Deblurring/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ class Config(object):
"""

def __init__(self, config_yaml: str, config_override: List[Any] = []):

self._C = CN()
self._C.GPU = [0]
self._C.VERBOSE = False
Expand Down
4 changes: 4 additions & 0 deletions Deblurring/data_RGB.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
import os

from dataset_RGB import DataLoaderTrain, DataLoaderVal, DataLoaderTest


def get_training_data(rgb_dir, img_options):
assert os.path.exists(rgb_dir)
return DataLoaderTrain(rgb_dir, img_options)


def get_validation_data(rgb_dir, img_options):
assert os.path.exists(rgb_dir)
return DataLoaderVal(rgb_dir, img_options)


def get_test_data(rgb_dir, img_options):
assert os.path.exists(rgb_dir)
return DataLoaderTest(rgb_dir, img_options)
89 changes: 46 additions & 43 deletions Deblurring/dataset_RGB.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,29 @@
import os
import random

import numpy as np
from torch.utils.data import Dataset
import torch
from PIL import Image
import torchvision.transforms.functional as TF
from pdb import set_trace as stx
import random
from PIL import Image
from torch.utils.data import Dataset


def is_image_file(filename):
return any(filename.endswith(extension) for extension in ['jpeg', 'JPEG', 'jpg', 'png', 'JPG', 'PNG', 'gif'])


class DataLoaderTrain(Dataset):
def __init__(self, rgb_dir, img_options=None):
super(DataLoaderTrain, self).__init__()

inp_files = sorted(os.listdir(os.path.join(rgb_dir, 'input')))
tar_files = sorted(os.listdir(os.path.join(rgb_dir, 'target')))

self.inp_filenames = [os.path.join(rgb_dir, 'input', x) for x in inp_files if is_image_file(x)]
self.inp_filenames = [os.path.join(rgb_dir, 'input', x) for x in inp_files if is_image_file(x)]
self.tar_filenames = [os.path.join(rgb_dir, 'target', x) for x in tar_files if is_image_file(x)]

self.img_options = img_options
self.sizex = len(self.tar_filenames) # get the size of target
self.sizex = len(self.tar_filenames) # get the size of target

self.ps = self.img_options['patch_size']

Expand All @@ -38,23 +40,23 @@ def __getitem__(self, index):
inp_img = Image.open(inp_path)
tar_img = Image.open(tar_path)

w,h = tar_img.size
padw = ps-w if w<ps else 0
padh = ps-h if h<ps else 0
w, h = tar_img.size
padw = ps - w if w < ps else 0
padh = ps - h if h < ps else 0

# Reflect Pad in case image is smaller than patch_size
if padw!=0 or padh!=0:
inp_img = TF.pad(inp_img, (0,0,padw,padh), padding_mode='reflect')
tar_img = TF.pad(tar_img, (0,0,padw,padh), padding_mode='reflect')
if padw != 0 or padh != 0:
inp_img = TF.pad(inp_img, (0, 0, padw, padh), padding_mode='reflect')
tar_img = TF.pad(tar_img, (0, 0, padw, padh), padding_mode='reflect')

aug = random.randint(0, 2)
aug = random.randint(0, 2)
if aug == 1:
inp_img = TF.adjust_gamma(inp_img, 1)
tar_img = TF.adjust_gamma(tar_img, 1)

aug = random.randint(0, 2)
aug = random.randint(0, 2)
if aug == 1:
sat_factor = 1 + (0.2 - 0.4*np.random.rand())
sat_factor = 1 + (0.2 - 0.4 * np.random.rand())
inp_img = TF.adjust_saturation(inp_img, sat_factor)
tar_img = TF.adjust_saturation(tar_img, sat_factor)

Expand All @@ -63,53 +65,54 @@ def __getitem__(self, index):

hh, ww = tar_img.shape[1], tar_img.shape[2]

rr = random.randint(0, hh-ps)
cc = random.randint(0, ww-ps)
aug = random.randint(0, 8)
rr = random.randint(0, hh - ps)
cc = random.randint(0, ww - ps)
aug = random.randint(0, 8)

# Crop patch
inp_img = inp_img[:, rr:rr+ps, cc:cc+ps]
tar_img = tar_img[:, rr:rr+ps, cc:cc+ps]
inp_img = inp_img[:, rr:rr + ps, cc:cc + ps]
tar_img = tar_img[:, rr:rr + ps, cc:cc + ps]

# Data Augmentations
if aug==1:
if aug == 1:
inp_img = inp_img.flip(1)
tar_img = tar_img.flip(1)
elif aug==2:
elif aug == 2:
inp_img = inp_img.flip(2)
tar_img = tar_img.flip(2)
elif aug==3:
inp_img = torch.rot90(inp_img,dims=(1,2))
tar_img = torch.rot90(tar_img,dims=(1,2))
elif aug==4:
inp_img = torch.rot90(inp_img,dims=(1,2), k=2)
tar_img = torch.rot90(tar_img,dims=(1,2), k=2)
elif aug==5:
inp_img = torch.rot90(inp_img,dims=(1,2), k=3)
tar_img = torch.rot90(tar_img,dims=(1,2), k=3)
elif aug==6:
inp_img = torch.rot90(inp_img.flip(1),dims=(1,2))
tar_img = torch.rot90(tar_img.flip(1),dims=(1,2))
elif aug==7:
inp_img = torch.rot90(inp_img.flip(2),dims=(1,2))
tar_img = torch.rot90(tar_img.flip(2),dims=(1,2))
elif aug == 3:
inp_img = torch.rot90(inp_img, dims=(1, 2))
tar_img = torch.rot90(tar_img, dims=(1, 2))
elif aug == 4:
inp_img = torch.rot90(inp_img, dims=(1, 2), k=2)
tar_img = torch.rot90(tar_img, dims=(1, 2), k=2)
elif aug == 5:
inp_img = torch.rot90(inp_img, dims=(1, 2), k=3)
tar_img = torch.rot90(tar_img, dims=(1, 2), k=3)
elif aug == 6:
inp_img = torch.rot90(inp_img.flip(1), dims=(1, 2))
tar_img = torch.rot90(tar_img.flip(1), dims=(1, 2))
elif aug == 7:
inp_img = torch.rot90(inp_img.flip(2), dims=(1, 2))
tar_img = torch.rot90(tar_img.flip(2), dims=(1, 2))

filename = os.path.splitext(os.path.split(tar_path)[-1])[0]

return tar_img, inp_img, filename


class DataLoaderVal(Dataset):
def __init__(self, rgb_dir, img_options=None, rgb_dir2=None):
super(DataLoaderVal, self).__init__()

inp_files = sorted(os.listdir(os.path.join(rgb_dir, 'input')))
tar_files = sorted(os.listdir(os.path.join(rgb_dir, 'target')))

self.inp_filenames = [os.path.join(rgb_dir, 'input', x) for x in inp_files if is_image_file(x)]
self.inp_filenames = [os.path.join(rgb_dir, 'input', x) for x in inp_files if is_image_file(x)]
self.tar_filenames = [os.path.join(rgb_dir, 'target', x) for x in tar_files if is_image_file(x)]

self.img_options = img_options
self.sizex = len(self.tar_filenames) # get the size of target
self.sizex = len(self.tar_filenames) # get the size of target

self.ps = self.img_options['patch_size']

Expand All @@ -128,8 +131,8 @@ def __getitem__(self, index):

# Validate on center crop
if self.ps is not None:
inp_img = TF.center_crop(inp_img, (ps,ps))
tar_img = TF.center_crop(tar_img, (ps,ps))
inp_img = TF.center_crop(inp_img, (ps, ps))
tar_img = TF.center_crop(tar_img, (ps, ps))

inp_img = TF.to_tensor(inp_img)
tar_img = TF.to_tensor(tar_img)
Expand All @@ -138,6 +141,7 @@ def __getitem__(self, index):

return tar_img, inp_img, filename


class DataLoaderTest(Dataset):
def __init__(self, inp_dir, img_options):
super(DataLoaderTest, self).__init__()
Expand All @@ -152,7 +156,6 @@ def __len__(self):
return self.inp_size

def __getitem__(self, index):

path_inp = self.inp_filenames[index]
filename = os.path.splitext(os.path.split(path_inp)[-1])[0]
inp = Image.open(path_inp)
Expand Down
Loading

0 comments on commit d022dab

Please sign in to comment.