From 9bafe6b4febbb37f6331db509ff19b77dfc2f18a Mon Sep 17 00:00:00 2001 From: Xintao Date: Tue, 22 Sep 2020 23:16:27 +0800 Subject: [PATCH 01/23] fixbug: PSNR performs on uint8 type --- basicsr/metrics/metric_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/basicsr/metrics/metric_util.py b/basicsr/metrics/metric_util.py index cac7026..55a2b3a 100644 --- a/basicsr/metrics/metric_util.py +++ b/basicsr/metrics/metric_util.py @@ -25,9 +25,9 @@ def reorder_image(img, input_order='HWC'): "'HWC' and 'CHW'") if len(img.shape) == 2: img = img[..., None] - return img if input_order == 'CHW': img = img.transpose(1, 2, 0) + img = img.astype(np.float64) return img From 7d03ae2a8d6960b7fe1bdb4d35607d51a8eb3ee6 Mon Sep 17 00:00:00 2001 From: Xintao Date: Tue, 29 Sep 2020 20:41:53 +0800 Subject: [PATCH 02/23] Updates for lints (#297) * format: fused_bias_act * do not use list in arguments * use multiple inheritance for VideoGANModel --- basicsr/models/archs/stylegan2_arch.py | 32 ++-- basicsr/models/lr_scheduler.py | 6 +- .../ops/fused_act/src/fused_bias_act.cpp | 12 +- basicsr/models/video_gan_model.py | 149 ++---------------- setup.py | 5 +- 5 files changed, 40 insertions(+), 164 deletions(-) diff --git a/basicsr/models/archs/stylegan2_arch.py b/basicsr/models/archs/stylegan2_arch.py index f0d3453..26c2ea3 100644 --- a/basicsr/models/archs/stylegan2_arch.py +++ b/basicsr/models/archs/stylegan2_arch.py @@ -211,7 +211,7 @@ class ModulatedConv2d(nn.Module): sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None. resample_kernel (list[int]): A list indicating the 1D resample kernel - magnitude. Default: [1, 3, 3, 1]. + magnitude. Default: (1, 3, 3, 1). eps (float): A value added to the denominator for numerical stability. Default: 1e-8. """ @@ -223,7 +223,7 @@ def __init__(self, num_style_feat, demodulate=True, sample_mode=None, - resample_kernel=[1, 3, 3, 1], + resample_kernel=(1, 3, 3, 1), eps=1e-8): super(ModulatedConv2d, self).__init__() self.in_channels = in_channels @@ -333,7 +333,7 @@ class StyleConv(nn.Module): sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None. resample_kernel (list[int]): A list indicating the 1D resample kernel - magnitude. Default: [1, 3, 3, 1]. + magnitude. Default: (1, 3, 3, 1). """ def __init__(self, @@ -343,7 +343,7 @@ def __init__(self, num_style_feat, demodulate=True, sample_mode=None, - resample_kernel=[1, 3, 3, 1]): + resample_kernel=(1, 3, 3, 1)): super(StyleConv, self).__init__() self.modulated_conv = ModulatedConv2d( in_channels, @@ -377,14 +377,14 @@ class ToRGB(nn.Module): num_style_feat (int): Channel number of style features. upsample (bool): Whether to upsample. Default: True. resample_kernel (list[int]): A list indicating the 1D resample kernel - magnitude. Default: [1, 3, 3, 1]. + magnitude. Default: (1, 3, 3, 1). """ def __init__(self, in_channels, num_style_feat, upsample=True, - resample_kernel=[1, 3, 3, 1]): + resample_kernel=(1, 3, 3, 1)): super(ToRGB, self).__init__() if upsample: self.upsample = UpFirDnUpsample(resample_kernel, factor=2) @@ -447,7 +447,7 @@ class StyleGAN2Generator(nn.Module): StyleGAN2. Default: 2. resample_kernel (list[int]): A list indicating the 1D resample kernel magnitude. A cross production will be applied to extent 1D resample - kenrel to 2D resample kernel. Default: [1, 3, 3, 1]. + kenrel to 2D resample kernel. Default: (1, 3, 3, 1). lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01. 
""" @@ -456,7 +456,7 @@ def __init__(self, num_style_feat=512, num_mlp=8, channel_multiplier=2, - resample_kernel=[1, 3, 3, 1], + resample_kernel=(1, 3, 3, 1), lr_mlp=0.01): super(StyleGAN2Generator, self).__init__() # Style MLP layers @@ -736,7 +736,7 @@ class ConvLayer(nn.Sequential): resample_kernel (list[int]): A list indicating the 1D resample kernel magnitude. A cross production will be applied to extent 1D resample kenrel to 2D resample kernel. - Default: [1, 3, 3, 1]. + Default: (1, 3, 3, 1). bias (bool): Whether with bias. Default: True. activate (bool): Whether use activateion. Default: True. """ @@ -746,7 +746,7 @@ def __init__(self, out_channels, kernel_size, downsample=False, - resample_kernel=[1, 3, 3, 1], + resample_kernel=(1, 3, 3, 1), bias=True, activate=True): layers = [] @@ -791,13 +791,11 @@ class ResBlock(nn.Module): resample_kernel (list[int]): A list indicating the 1D resample kernel magnitude. A cross production will be applied to extent 1D resample kenrel to 2D resample kernel. - Default: [1, 3, 3, 1]. + Default: (1, 3, 3, 1). """ - def __init__(self, - in_channels, - out_channels, - resample_kernel=[1, 3, 3, 1]): + def __init__(self, in_channels, out_channels, + resample_kernel=(1, 3, 3, 1)): super(ResBlock, self).__init__() self.conv1 = ConvLayer( @@ -836,13 +834,13 @@ class StyleGAN2Discriminator(nn.Module): StyleGAN2. Default: 2. resample_kernel (list[int]): A list indicating the 1D resample kernel magnitude. A cross production will be applied to extent 1D resample - kenrel to 2D resample kernel. Default: [1, 3, 3, 1]. + kenrel to 2D resample kernel. Default: (1, 3, 3, 1). """ def __init__(self, out_size, channel_multiplier=2, - resample_kernel=[1, 3, 3, 1]): + resample_kernel=(1, 3, 3, 1)): super(StyleGAN2Discriminator, self).__init__() channels = { diff --git a/basicsr/models/lr_scheduler.py b/basicsr/models/lr_scheduler.py index eaa0b53..8cb63b7 100644 --- a/basicsr/models/lr_scheduler.py +++ b/basicsr/models/lr_scheduler.py @@ -20,8 +20,8 @@ def __init__(self, optimizer, milestones, gamma=0.1, - restarts=[0], - restart_weights=[1], + restarts=(0), + restart_weights=(1), last_epoch=-1): self.milestones = Counter(milestones) self.gamma = gamma @@ -90,7 +90,7 @@ class CosineAnnealingRestartLR(_LRScheduler): def __init__(self, optimizer, periods, - restart_weights=[1], + restart_weights=(1), eta_min=0, last_epoch=-1): self.periods = periods diff --git a/basicsr/models/ops/fused_act/src/fused_bias_act.cpp b/basicsr/models/ops/fused_act/src/fused_bias_act.cpp index cc9b8f7..85ed0a7 100755 --- a/basicsr/models/ops/fused_act/src/fused_bias_act.cpp +++ b/basicsr/models/ops/fused_act/src/fused_bias_act.cpp @@ -2,15 +2,19 @@ #include -torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, - int act, int grad, float alpha, float scale); +torch::Tensor fused_bias_act_op(const torch::Tensor& input, + const torch::Tensor& bias, + const torch::Tensor& refer, + int act, int grad, float alpha, float scale); #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) -torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, - int act, int grad, float alpha, float scale) { +torch::Tensor fused_bias_act(const torch::Tensor& input, + const torch::Tensor& bias, + const torch::Tensor& refer, + int act, 
int grad, float alpha, float scale) { CHECK_CUDA(input); CHECK_CUDA(bias); diff --git a/basicsr/models/video_gan_model.py b/basicsr/models/video_gan_model.py index 94ccf4b..290434b 100644 --- a/basicsr/models/video_gan_model.py +++ b/basicsr/models/video_gan_model.py @@ -1,142 +1,15 @@ -import importlib -import torch -from collections import OrderedDict -from copy import deepcopy - -from basicsr.models.archs import define_network +from basicsr.models.srgan_model import SRGANModel from basicsr.models.video_base_model import VideoBaseModel -loss_module = importlib.import_module('basicsr.models.losses') - - -class VideoGANModel(VideoBaseModel): - """Video GAN model.""" - - def init_training_settings(self): - train_opt = self.opt['train'] - - # define network net_d - self.net_d = define_network(deepcopy(self.opt['network_d'])) - self.net_d = self.model_to_device(self.net_d) - self.print_network(self.net_d) - - # load pretrained models - load_path = self.opt['path'].get('pretrain_model_d', None) - if load_path is not None: - self.load_network(self.net_d, load_path, - self.opt['path']['strict_load']) - - self.net_g.train() - self.net_d.train() - - # define losses - if train_opt.get('pixel_opt'): - pixel_type = train_opt['pixel_opt'].pop('type') - cri_pix_cls = getattr(loss_module, pixel_type) - self.cri_pix = cri_pix_cls(**train_opt['pixel_opt']).to( - self.device) - else: - self.cri_pix = None - - if train_opt.get('perceptual_opt'): - percep_type = train_opt['perceptual_opt'].pop('type') - cri_perceptual_cls = getattr(loss_module, percep_type) - self.cri_perceptual = cri_perceptual_cls( - **train_opt['perceptual_opt']).to(self.device) - else: - self.cri_perceptual = None - - if train_opt.get('gan_opt'): - gan_type = train_opt['gan_opt'].pop('type') - cri_gan_cls = getattr(loss_module, gan_type) - self.cri_gan = cri_gan_cls(**train_opt['gan_opt']).to(self.device) - - self.net_d_iters = train_opt.get('net_d_iters', 1) - self.net_d_init_iters = train_opt.get('net_d_init_iters', 0) - - # set up optimizers and schedulers - self.setup_optimizers() - self.setup_schedulers() - - def setup_optimizers(self): - train_opt = self.opt['train'] - # optimizer g - optim_type = train_opt['optim_g'].pop('type') - if optim_type == 'Adam': - self.optimizer_g = torch.optim.Adam(self.net_g.parameters(), - **train_opt['optim_g']) - else: - raise NotImplementedError( - f'optimizer {optim_type} is not supperted yet.') - self.optimizers.append(self.optimizer_g) - # optimizer d - optim_type = train_opt['optim_d'].pop('type') - if optim_type == 'Adam': - self.optimizer_d = torch.optim.Adam(self.net_d.parameters(), - **train_opt['optim_d']) - else: - raise NotImplementedError( - f'optimizer {optim_type} is not supperted yet.') - self.optimizers.append(self.optimizer_d) - - def optimize_parameters(self, current_iter): - # optimize net_g - for p in self.net_d.parameters(): - p.requires_grad = False - - self.optimizer_g.zero_grad() - self.output = self.net_g(self.lq) - - l_g_total = 0 - loss_dict = OrderedDict() - if (current_iter % self.net_d_iters == 0 - and current_iter > self.net_d_init_iters): - # pixel loss - if self.cri_pix: - l_g_pix = self.cri_pix(self.output, self.gt) - l_g_total += l_g_pix - loss_dict['l_g_pix'] = l_g_pix - # perceptual loss - if self.cri_perceptual: - l_g_percep, l_g_style = self.cri_perceptual( - self.output, self.gt) - if l_g_percep is not None: - l_g_total += l_g_percep - loss_dict['l_g_percep'] = l_g_percep - if l_g_style is not None: - l_g_total += l_g_style - loss_dict['l_g_style'] = 
l_g_style - # gan loss - fake_g_pred = self.net_d(self.output) - l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False) - l_g_total += l_g_gan - loss_dict['l_g_gan'] = l_g_gan - - l_g_total.backward() - self.optimizer_g.step() - - # optimize net_d - for p in self.net_d.parameters(): - p.requires_grad = True - - self.optimizer_d.zero_grad() - # real - real_d_pred = self.net_d(self.gt) - l_d_real = self.cri_gan(real_d_pred, True, is_disc=True) - loss_dict['l_d_real'] = l_d_real - loss_dict['out_d_real'] = torch.mean(real_d_pred.detach()) - l_d_real.backward() - # fake - fake_d_pred = self.net_d(self.output.detach()) - l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True) - loss_dict['l_d_fake'] = l_d_fake - loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach()) - l_d_fake.backward() - self.optimizer_d.step() - self.log_dict = self.reduce_loss_dict(loss_dict) +class VideoGANModel(SRGANModel, VideoBaseModel): + """Video GAN model. - def save(self, epoch, current_iter): - self.save_network(self.net_g, 'net_g', current_iter) - self.save_network(self.net_d, 'net_d', current_iter) - self.save_training_state(epoch, current_iter) + Use multiple inheritance. + It will first use the functions of SRGANModel: + init_training_settings + setup_optimizers + optimize_parameters + save + Then find functions in VideoBaseModel. + """ diff --git a/setup.py b/setup.py index 0a339ff..3050e9b 100644 --- a/setup.py +++ b/setup.py @@ -85,8 +85,9 @@ def get_version(): return locals()['__version__'] -def make_cuda_ext(name, module, sources, sources_cuda=[]): - +def make_cuda_ext(name, module, sources, sources_cuda=None): + if sources_cuda is None: + sources_cuda = [] define_macros = [] extra_compile_args = {'cxx': []} From cedf2caddd84592416b1e487690f28e04c77e94f Mon Sep 17 00:00:00 2001 From: Xintao Date: Sun, 4 Oct 2020 00:06:56 +0800 Subject: [PATCH 03/23] Updates (#300) * update resume pretrained paths * test_scripts * add dist util * use os * add matlab functions * scandir and bgr2rgb replace * cv2.flip * replace imwrite * update file client * update utils.util * update utils.download * use relative import * add img_util * update strict_load * updat train.py * update test.py * add flow util * update requirements.txt * fix bugs * fix bugs --- basicsr/data/__init__.py | 7 +- basicsr/data/ffhq_dataset.py | 10 +- basicsr/data/paired_image_dataset.py | 14 +- basicsr/data/reds_dataset.py | 27 +- basicsr/data/single_image_dataset.py | 14 +- basicsr/data/transforms.py | 46 +-- basicsr/data/util.py | 17 +- basicsr/data/video_test_dataset.py | 15 +- basicsr/data/vimeo90k_dataset.py | 12 +- basicsr/metrics/metric_util.py | 5 +- basicsr/models/__init__.py | 5 +- basicsr/models/archs/__init__.py | 5 +- basicsr/models/base_model.py | 2 +- basicsr/models/lr_scheduler.py | 6 +- basicsr/models/sr_model.py | 9 +- basicsr/models/srgan_model.py | 4 +- basicsr/models/stylegan2_model.py | 22 +- basicsr/models/video_base_model.py | 7 +- basicsr/test.py | 54 +--- basicsr/train.py | 116 ++++--- basicsr/utils/__init__.py | 31 +- basicsr/utils/dist_util.py | 83 +++++ basicsr/utils/download.py | 2 +- basicsr/utils/file_client.py | 296 +++++++++++------- basicsr/utils/flow_util.py | 180 +++++++++++ basicsr/utils/img_util.py | 162 ++++++++++ basicsr/utils/lmdb.py | 7 +- basicsr/utils/logger.py | 7 +- basicsr/utils/matlab_functions.py | 192 ++++++++++++ basicsr/utils/options.py | 2 +- basicsr/utils/util.py | 183 +++++------ docs/Config.md | 10 +- docs/Config_CN.md | 10 +- docs/DatasetPreparation.md | 2 +- 
docs/DatasetPreparation_CN.md | 2 +- docs/DesignConvention.md | 2 +- docs/DesignConvention_CN.md | 2 +- options/test/DUF/test_DUF_official.yml | 4 +- options/test/EDSR/test_EDSR_Lx2.yml | 4 +- options/test/EDSR/test_EDSR_Lx3.yml | 4 +- options/test/EDSR/test_EDSR_Lx4.yml | 4 +- options/test/EDSR/test_EDSR_Mx2.yml | 4 +- options/test/EDSR/test_EDSR_Mx3.yml | 4 +- options/test/EDSR/test_EDSR_Mx4.yml | 4 +- options/test/EDVR/test_EDVR_L_deblur_REDS.yml | 4 +- .../test/EDVR/test_EDVR_L_deblurcomp_REDS.yml | 4 +- options/test/EDVR/test_EDVR_L_x4_SR_REDS.yml | 4 +- options/test/EDVR/test_EDVR_L_x4_SR_Vid4.yml | 4 +- .../test/EDVR/test_EDVR_L_x4_SR_Vimeo90K.yml | 4 +- .../test/EDVR/test_EDVR_L_x4_SRblur_REDS.yml | 4 +- options/test/EDVR/test_EDVR_M_x4_SR_REDS.yml | 4 +- options/test/ESRGAN/test_ESRGAN_x4.yml | 4 +- options/test/ESRGAN/test_ESRGAN_x4_woGT.yml | 4 +- options/test/ESRGAN/test_RRDBNet_PSNR_x4.yml | 4 +- options/test/RCAN/test_RCAN.yml | 4 +- .../test/SRResNet_SRGAN/test_MSRGAN_x4.yml | 4 +- .../test/SRResNet_SRGAN/test_MSRResNet_x2.yml | 4 +- .../test/SRResNet_SRGAN/test_MSRResNet_x3.yml | 4 +- .../test/SRResNet_SRGAN/test_MSRResNet_x4.yml | 4 +- .../SRResNet_SRGAN/test_MSRResNet_x4_woGT.yml | 4 +- options/test/TOF/test_TOF_official.yml | 4 +- options/train/EDSR/train_EDSR_Lx2.yml | 4 +- options/train/EDSR/train_EDSR_Lx3.yml | 4 +- options/train/EDSR/train_EDSR_Lx4.yml | 4 +- options/train/EDSR/train_EDSR_Mx2.yml | 4 +- options/train/EDSR/train_EDSR_Mx3.yml | 4 +- options/train/EDSR/train_EDSR_Mx4.yml | 4 +- .../train/EDVR/train_EDVRM_woTSA_GAN_TODO.yml | 4 +- .../train/EDVR/train_EDVR_L_x4_SR_REDS.yml | 4 +- .../EDVR/train_EDVR_L_x4_SR_REDS_woTSA.yml | 4 +- .../train/EDVR/train_EDVR_M_x4_SR_REDS.yml | 4 +- .../EDVR/train_EDVR_M_x4_SR_REDS_woTSA.yml | 4 +- options/train/ESRGAN/train_ESRGAN_x4.yml | 4 +- .../train/ESRGAN/train_RRDBNet_PSNR_x4.yml | 4 +- options/train/RCAN/train_RCAN_x2.yml | 4 +- .../train/SRResNet_SRGAN/train_MSRGAN_x4.yml | 4 +- .../SRResNet_SRGAN/train_MSRResNet_x2.yml | 4 +- .../SRResNet_SRGAN/train_MSRResNet_x3.yml | 4 +- .../SRResNet_SRGAN/train_MSRResNet_x4.yml | 4 +- .../train_StyleGAN2_256_Cmul2_FFHQ.yml | 4 +- requirements.txt | 3 +- scripts/calculate_psnr_ssim.py | 19 +- scripts/create_lmdb.py | 6 +- scripts/download_pretrained_models.py | 4 +- scripts/extract_subimages.py | 6 +- scripts/generate_meta_info.py | 5 +- setup.cfg | 2 +- {tests => test_scripts}/test_face_dfdnet.py | 12 +- {tests => test_scripts}/test_stylegan2.py | 4 +- tests/test_ffhq_dataset.py | 4 +- tests/test_paired_image_dataset.py | 4 +- tests/test_reds_dataset.py | 4 +- tests/test_vimeo90k_dataset.py | 4 +- 93 files changed, 1216 insertions(+), 603 deletions(-) create mode 100644 basicsr/utils/dist_util.py create mode 100644 basicsr/utils/flow_util.py create mode 100644 basicsr/utils/img_util.py create mode 100644 basicsr/utils/matlab_functions.py rename {tests => test_scripts}/test_face_dfdnet.py (97%) rename {tests => test_scripts}/test_stylegan2.py (97%) diff --git a/basicsr/data/__init__.py b/basicsr/data/__init__.py index c7b09bc..8232845 100644 --- a/basicsr/data/__init__.py +++ b/basicsr/data/__init__.py @@ -1,15 +1,14 @@ import importlib -import mmcv import numpy as np import random import torch import torch.utils.data from functools import partial -from mmcv.runner import get_dist_info from os import path as osp from basicsr.data.prefetch_dataloader import PrefetchDataLoader -from basicsr.utils import get_root_logger +from basicsr.utils import get_root_logger, scandir +from 
basicsr.utils.dist_util import get_dist_info __all__ = ['create_dataset', 'create_dataloader'] @@ -17,7 +16,7 @@ # scan all the files under the data folder with '_dataset' in file names data_folder = osp.dirname(osp.abspath(__file__)) dataset_filenames = [ - osp.splitext(osp.basename(v))[0] for v in mmcv.scandir(data_folder) + osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py') ] # import all the dataset modules diff --git a/basicsr/data/ffhq_dataset.py b/basicsr/data/ffhq_dataset.py index b4cabab..ef93ed6 100644 --- a/basicsr/data/ffhq_dataset.py +++ b/basicsr/data/ffhq_dataset.py @@ -1,11 +1,9 @@ -import mmcv -import numpy as np from os import path as osp from torch.utils import data as data from torchvision.transforms.functional import normalize -from basicsr.data.transforms import augment, totensor -from basicsr.utils import FileClient +from basicsr.data.transforms import augment +from basicsr.utils import FileClient, imfrombytes, img2tensor class FFHQDataset(data.Dataset): @@ -53,12 +51,12 @@ def __getitem__(self, index): # load gt image gt_path = self.paths[index] img_bytes = self.file_client.get(gt_path) - img_gt = mmcv.imfrombytes(img_bytes).astype(np.float32) / 255. + img_gt = imfrombytes(img_bytes, float32=True) # random horizontal flip img_gt = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False) # BGR to RGB, HWC to CHW, numpy to tensor - img_gt = totensor(img_gt, bgr2rgb=True, float32=True) + img_gt = img2tensor(img_gt, bgr2rgb=True, float32=True) # normalize normalize(img_gt, self.mean, self.std, inplace=True) return {'gt': img_gt, 'gt_path': gt_path} diff --git a/basicsr/data/paired_image_dataset.py b/basicsr/data/paired_image_dataset.py index c5b01a8..0e2de96 100644 --- a/basicsr/data/paired_image_dataset.py +++ b/basicsr/data/paired_image_dataset.py @@ -1,12 +1,10 @@ -import mmcv -import numpy as np from torch.utils import data as data -from basicsr.data.transforms import augment, paired_random_crop, totensor +from basicsr.data.transforms import augment, paired_random_crop from basicsr.data.util import (paired_paths_from_folder, paired_paths_from_lmdb, paired_paths_from_meta_info_file) -from basicsr.utils import FileClient +from basicsr.utils import FileClient, imfrombytes, img2tensor class PairedImageDataset(data.Dataset): @@ -79,10 +77,10 @@ def __getitem__(self, index): # image range: [0, 1], float32. gt_path = self.paths[index]['gt_path'] img_bytes = self.file_client.get(gt_path, 'gt') - img_gt = mmcv.imfrombytes(img_bytes).astype(np.float32) / 255. + img_gt = imfrombytes(img_bytes, float32=True) lq_path = self.paths[index]['lq_path'] img_bytes = self.file_client.get(lq_path, 'lq') - img_lq = mmcv.imfrombytes(img_bytes).astype(np.float32) / 255. 
+ img_lq = imfrombytes(img_bytes, float32=True) # augmentation for training if self.opt['phase'] == 'train': @@ -96,7 +94,9 @@ def __getitem__(self, index): # TODO: color space transform # BGR to RGB, HWC to CHW, numpy to tensor - img_gt, img_lq = totensor([img_gt, img_lq], bgr2rgb=True, float32=True) + img_gt, img_lq = img2tensor([img_gt, img_lq], + bgr2rgb=True, + float32=True) return { 'lq': img_lq, diff --git a/basicsr/data/reds_dataset.py b/basicsr/data/reds_dataset.py index 8f5f1de..7f7db26 100644 --- a/basicsr/data/reds_dataset.py +++ b/basicsr/data/reds_dataset.py @@ -1,12 +1,12 @@ -import mmcv import numpy as np import random import torch from pathlib import Path from torch.utils import data as data -from basicsr.data.transforms import augment, paired_random_crop, totensor -from basicsr.utils import FileClient, get_root_logger +from basicsr.data.transforms import augment, paired_random_crop +from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor +from basicsr.utils.flow_util import dequantize_flow class REDSDataset(data.Dataset): @@ -144,7 +144,7 @@ def __getitem__(self, index): else: img_gt_path = self.gt_root / clip_name / f'{frame_name}.png' img_bytes = self.file_client.get(img_gt_path, 'gt') - img_gt = mmcv.imfrombytes(img_bytes).astype(np.float32) / 255. + img_gt = imfrombytes(img_bytes, float32=True) # get the neighboring LQ frames img_lqs = [] @@ -154,7 +154,7 @@ def __getitem__(self, index): else: img_lq_path = self.lq_root / clip_name / f'{neighbor:08d}.png' img_bytes = self.file_client.get(img_lq_path, 'lq') - img_lq = mmcv.imfrombytes(img_bytes).astype(np.float32) / 255. + img_lq = imfrombytes(img_bytes, float32=True) img_lqs.append(img_lq) # get flows @@ -168,10 +168,11 @@ def __getitem__(self, index): flow_path = ( self.flow_root / clip_name / f'{frame_name}_p{i}.png') img_bytes = self.file_client.get(flow_path, 'flow') - cat_flow = mmcv.imfrombytes( - img_bytes, flag='grayscale') # uint8, [0, 255] + cat_flow = imfrombytes( + img_bytes, flag='grayscale', + float32=False) # uint8, [0, 255] dx, dy = np.split(cat_flow, 2, axis=0) - flow = mmcv.video.dequantize_flow( + flow = dequantize_flow( dx, dy, max_val=20, denorm=False) # we use max_val 20 here. img_flows.append(flow) @@ -183,9 +184,11 @@ def __getitem__(self, index): flow_path = ( self.flow_root / clip_name / f'{frame_name}_n{i}.png') img_bytes = self.file_client.get(flow_path, 'flow') - cat_flow = mmcv.imfrombytes(img_bytes, flag='grayscale') + cat_flow = imfrombytes( + img_bytes, flag='grayscale', + float32=False) # uint8, [0, 255] dx, dy = np.split(cat_flow, 2, axis=0) - flow = mmcv.video.dequantize_flow( + flow = dequantize_flow( dx, dy, max_val=20, denorm=False) # we use max_val 20 here. 
img_flows.append(flow) @@ -210,12 +213,12 @@ def __getitem__(self, index): img_results = augment(img_lqs, self.opt['use_flip'], self.opt['use_rot']) - img_results = totensor(img_results) + img_results = img2tensor(img_results) img_lqs = torch.stack(img_results[0:-1], dim=0) img_gt = img_results[-1] if self.flow_root is not None: - img_flows = totensor(img_flows) + img_flows = img2tensor(img_flows) # add the zero center flow img_flows.insert(self.num_half_frames, torch.zeros_like(img_flows[0])) diff --git a/basicsr/data/single_image_dataset.py b/basicsr/data/single_image_dataset.py index 3a15934..cb1bc01 100644 --- a/basicsr/data/single_image_dataset.py +++ b/basicsr/data/single_image_dataset.py @@ -1,11 +1,8 @@ -import mmcv -import numpy as np from os import path as osp from torch.utils import data as data from torchvision.transforms.functional import normalize -from basicsr.data.transforms import totensor -from basicsr.utils import FileClient +from basicsr.utils import FileClient, imfrombytes, img2tensor, scandir class SingleImageDataset(data.Dataset): @@ -40,10 +37,7 @@ def __init__(self, opt): line.split(' ')[0]) for line in fin ] else: - self.paths = [ - osp.join(self.lq_folder, v) - for v in mmcv.scandir(self.lq_folder) - ] + self.paths = sorted(list(scandir(self.lq_folder, full_path=True))) def __getitem__(self, index): if self.file_client is None: @@ -53,11 +47,11 @@ def __getitem__(self, index): # load lq image lq_path = self.paths[index] img_bytes = self.file_client.get(lq_path) - img_lq = mmcv.imfrombytes(img_bytes).astype(np.float32) / 255. + img_lq = imfrombytes(img_bytes, float32=True) # TODO: color space transform # BGR to RGB, HWC to CHW, numpy to tensor - img_lq = totensor(img_lq, bgr2rgb=True, float32=True) + img_lq = img2tensor(img_lq, bgr2rgb=True, float32=True) # normalize if self.mean is not None or self.std is not None: normalize(img_lq, self.mean, self.std, inplace=True) diff --git a/basicsr/data/transforms.py b/basicsr/data/transforms.py index 6c7eb80..d7d1e7a 100644 --- a/basicsr/data/transforms.py +++ b/basicsr/data/transforms.py @@ -1,6 +1,5 @@ -import mmcv +import cv2 import random -import torch def mod_crop(img, scale): @@ -110,20 +109,20 @@ def augment(imgs, hflip=True, rotation=True, flows=None): rot90 = rotation and random.random() < 0.5 def _augment(img): - if hflip: - mmcv.imflip_(img, 'horizontal') - if vflip: - mmcv.imflip_(img, 'vertical') + if hflip: # horizontal + cv2.flip(img, 1, img) + if vflip: # vertical + cv2.flip(img, 0, img) if rot90: img = img.transpose(1, 0, 2) return img def _augment_flow(flow): - if hflip: - mmcv.imflip_(flow, 'horizontal') + if hflip: # horizontal + cv2.flip(flow, 1, flow) flow[:, :, 0] *= -1 - if vflip: - mmcv.imflip_(flow, 'vertical') + if vflip: # vertical + cv2.flip(flow, 0, flow) flow[:, :, 1] *= -1 if rot90: flow = flow.transpose(1, 0, 2) @@ -145,30 +144,3 @@ def _augment_flow(flow): return imgs, flows else: return imgs - - -def totensor(imgs, bgr2rgb=True, float32=True): - """Numpy array to tensor. - - Args: - imgs (list[ndarray] | ndarray): Input images. - bgr2rgb (bool): Whether to change bgr to rgb. - float32 (bool): Whether to change to float32. - - Returns: - list[tensor] | tensor: Tensor images. If returned results only have - one element, just return tensor. 
- """ - - def _totensor(img, bgr2rgb, float32): - if img.shape[2] == 3 and bgr2rgb: - img = mmcv.bgr2rgb(img) - img = torch.from_numpy(img.transpose(2, 0, 1)) - if float32: - img = img.float() - return img - - if isinstance(imgs, list): - return [_totensor(img, bgr2rgb, float32) for img in imgs] - else: - return _totensor(imgs, bgr2rgb, float32) diff --git a/basicsr/data/util.py b/basicsr/data/util.py index 50245ac..b4a14e9 100644 --- a/basicsr/data/util.py +++ b/basicsr/data/util.py @@ -1,10 +1,11 @@ -import mmcv +import cv2 import numpy as np import torch from os import path as osp from torch.nn import functional as F -from basicsr.data.transforms import mod_crop, totensor +from basicsr.data.transforms import mod_crop +from basicsr.utils import img2tensor, scandir def read_img_seq(path, require_mod_crop=False, scale=1): @@ -22,11 +23,11 @@ def read_img_seq(path, require_mod_crop=False, scale=1): if isinstance(path, list): img_paths = path else: - img_paths = sorted([osp.join(path, v) for v in mmcv.scandir(path)]) - imgs = [mmcv.imread(v).astype(np.float32) / 255. for v in img_paths] + img_paths = sorted(list(scandir(path, full_path=True))) + imgs = [cv2.imread(v).astype(np.float32) / 255. for v in img_paths] if require_mod_crop: imgs = [mod_crop(img, scale) for img in imgs] - imgs = totensor(imgs, bgr2rgb=True, float32=True) + imgs = img2tensor(imgs, bgr2rgb=True, float32=True) imgs = torch.stack(imgs, dim=0) return imgs @@ -227,8 +228,8 @@ def paired_paths_from_folder(folders, keys, filename_tmpl): input_folder, gt_folder = folders input_key, gt_key = keys - input_paths = list(mmcv.scandir(input_folder)) - gt_paths = list(mmcv.scandir(gt_folder)) + input_paths = list(scandir(input_folder)) + gt_paths = list(scandir(gt_folder)) assert len(input_paths) == len(gt_paths), ( f'{input_key} and {gt_key} datasets have different number of images: ' f'{len(input_paths)}, {len(gt_paths)}.') @@ -256,7 +257,7 @@ def paths_from_folder(folder): list[str]: Returned path list. 
""" - paths = list(mmcv.scandir(folder)) + paths = list(scandir(folder)) paths = [osp.join(folder, path) for path in paths] return paths diff --git a/basicsr/data/video_test_dataset.py b/basicsr/data/video_test_dataset.py index 0ab7d99..d3d21f9 100644 --- a/basicsr/data/video_test_dataset.py +++ b/basicsr/data/video_test_dataset.py @@ -1,12 +1,11 @@ import glob -import mmcv import torch from os import path as osp from torch.utils import data as data from basicsr.data import util as util from basicsr.data.util import duf_downsample -from basicsr.utils import get_root_logger +from basicsr.utils import get_root_logger, scandir class VideoTestDataset(data.Dataset): @@ -81,14 +80,10 @@ def __init__(self, opt): subfolders_gt): # get frame list for lq and gt subfolder_name = osp.basename(subfolder_lq) - img_paths_lq = sorted([ - osp.join(subfolder_lq, v) - for v in mmcv.scandir(subfolder_lq) - ]) - img_paths_gt = sorted([ - osp.join(subfolder_gt, v) - for v in mmcv.scandir(subfolder_gt) - ]) + img_paths_lq = sorted( + list(scandir(subfolder_lq, full_path=True))) + img_paths_gt = sorted( + list(scandir(subfolder_gt, full_path=True))) max_idx = len(img_paths_lq) assert max_idx == len(img_paths_gt), ( diff --git a/basicsr/data/vimeo90k_dataset.py b/basicsr/data/vimeo90k_dataset.py index a88216e..71d5d11 100644 --- a/basicsr/data/vimeo90k_dataset.py +++ b/basicsr/data/vimeo90k_dataset.py @@ -1,12 +1,10 @@ -import mmcv -import numpy as np import random import torch from pathlib import Path from torch.utils import data as data -from basicsr.data.transforms import augment, paired_random_crop, totensor -from basicsr.utils import FileClient, get_root_logger +from basicsr.data.transforms import augment, paired_random_crop +from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor class Vimeo90KDataset(data.Dataset): @@ -97,7 +95,7 @@ def __getitem__(self, index): else: img_gt_path = self.gt_root / clip / seq / 'im4.png' img_bytes = self.file_client.get(img_gt_path, 'gt') - img_gt = mmcv.imfrombytes(img_bytes).astype(np.float32) / 255. + img_gt = imfrombytes(img_bytes, float32=True) # get the neighboring LQ frames img_lqs = [] @@ -107,7 +105,7 @@ def __getitem__(self, index): else: img_lq_path = self.lq_root / clip / seq / f'im{neighbor}.png' img_bytes = self.file_client.get(img_lq_path, 'lq') - img_lq = mmcv.imfrombytes(img_bytes).astype(np.float32) / 255. + img_lq = imfrombytes(img_bytes, float32=True) img_lqs.append(img_lq) # randomly crop @@ -119,7 +117,7 @@ def __getitem__(self, index): img_results = augment(img_lqs, self.opt['use_flip'], self.opt['use_rot']) - img_results = totensor(img_results) + img_results = img2tensor(img_results) img_lqs = torch.stack(img_results[0:-1], dim=0) img_gt = img_results[-1] diff --git a/basicsr/metrics/metric_util.py b/basicsr/metrics/metric_util.py index 55a2b3a..4258781 100644 --- a/basicsr/metrics/metric_util.py +++ b/basicsr/metrics/metric_util.py @@ -1,6 +1,7 @@ -import mmcv import numpy as np +from basicsr.utils.matlab_functions import bgr2ycbcr + def reorder_image(img, input_order='HWC'): """Reorder images to 'HWC' order. @@ -42,6 +43,6 @@ def to_y_channel(img): """ img = img.astype(np.float32) / 255. if img.ndim == 3 and img.shape[2] == 3: - img = mmcv.bgr2ycbcr(img, y_only=True) + img = bgr2ycbcr(img, y_only=True) img = img[..., None] return img * 255. 
diff --git a/basicsr/models/__init__.py b/basicsr/models/__init__.py index 4f173de..10f3b9f 100644 --- a/basicsr/models/__init__.py +++ b/basicsr/models/__init__.py @@ -1,15 +1,14 @@ import importlib -import mmcv from os import path as osp -from basicsr.utils import get_root_logger +from basicsr.utils import get_root_logger, scandir # automatically scan and import model modules # scan all the files under the 'models' folder and collect files ending with # '_model.py' model_folder = osp.dirname(osp.abspath(__file__)) model_filenames = [ - osp.splitext(osp.basename(v))[0] for v in mmcv.scandir(model_folder) + osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py') ] # import all the model modules diff --git a/basicsr/models/archs/__init__.py b/basicsr/models/archs/__init__.py index a00982a..40410be 100644 --- a/basicsr/models/archs/__init__.py +++ b/basicsr/models/archs/__init__.py @@ -1,13 +1,14 @@ import importlib -import mmcv from os import path as osp +from basicsr.utils import scandir + # automatically scan and import arch modules # scan all the files under the 'archs' folder and collect files ending with # '_arch.py' arch_folder = osp.dirname(osp.abspath(__file__)) arch_filenames = [ - osp.splitext(osp.basename(v))[0] for v in mmcv.scandir(arch_folder) + osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py') ] # import all the arch modules diff --git a/basicsr/models/base_model.py b/basicsr/models/base_model.py index 5baa524..3bb89b2 100644 --- a/basicsr/models/base_model.py +++ b/basicsr/models/base_model.py @@ -3,10 +3,10 @@ import torch from collections import OrderedDict from copy import deepcopy -from mmcv.runner import master_only from torch.nn.parallel import DataParallel, DistributedDataParallel from basicsr.models import lr_scheduler as lr_scheduler +from basicsr.utils.dist_util import master_only logger = logging.getLogger('basicsr') diff --git a/basicsr/models/lr_scheduler.py b/basicsr/models/lr_scheduler.py index 8cb63b7..a2b4d35 100644 --- a/basicsr/models/lr_scheduler.py +++ b/basicsr/models/lr_scheduler.py @@ -20,8 +20,8 @@ def __init__(self, optimizer, milestones, gamma=0.1, - restarts=(0), - restart_weights=(1), + restarts=(0, ), + restart_weights=(1, ), last_epoch=-1): self.milestones = Counter(milestones) self.gamma = gamma @@ -90,7 +90,7 @@ class CosineAnnealingRestartLR(_LRScheduler): def __init__(self, optimizer, periods, - restart_weights=(1), + restart_weights=(1, ), eta_min=0, last_epoch=-1): self.periods = periods diff --git a/basicsr/models/sr_model.py b/basicsr/models/sr_model.py index 66a98b6..32bb794 100644 --- a/basicsr/models/sr_model.py +++ b/basicsr/models/sr_model.py @@ -1,5 +1,4 @@ import importlib -import mmcv import torch from collections import OrderedDict from copy import deepcopy @@ -7,7 +6,7 @@ from basicsr.models.archs import define_network from basicsr.models.base_model import BaseModel -from basicsr.utils import ProgressBar, get_root_logger, tensor2img +from basicsr.utils import ProgressBar, get_root_logger, imwrite, tensor2img loss_module = importlib.import_module('basicsr.models.losses') metric_module = importlib.import_module('basicsr.metrics') @@ -25,10 +24,10 @@ def __init__(self, opt): self.print_network(self.net_g) # load pretrained models - load_path = self.opt['path'].get('pretrain_model_g', None) + load_path = self.opt['path'].get('pretrain_network_g', None) if load_path is not None: self.load_network(self.net_g, load_path, - self.opt['path']['strict_load']) + 
self.opt['path'].get('strict_load_g', True)) if self.is_train: self.init_training_settings() @@ -163,7 +162,7 @@ def nondist_validation(self, dataloader, current_iter, tb_logger, save_img_path = osp.join( self.opt['path']['visualization'], dataset_name, f'{img_name}_{self.opt["name"]}.png') - mmcv.imwrite(sr_img, save_img_path) + imwrite(sr_img, save_img_path) if with_metrics: # calculate metrics diff --git a/basicsr/models/srgan_model.py b/basicsr/models/srgan_model.py index d927773..7d08d7b 100644 --- a/basicsr/models/srgan_model.py +++ b/basicsr/models/srgan_model.py @@ -21,10 +21,10 @@ def init_training_settings(self): self.print_network(self.net_d) # load pretrained models - load_path = self.opt['path'].get('pretrain_model_d', None) + load_path = self.opt['path'].get('pretrain_network_d', None) if load_path is not None: self.load_network(self.net_d, load_path, - self.opt['path']['strict_load']) + self.opt['path'].get('strict_load_d', True)) self.net_g.train() self.net_d.train() diff --git a/basicsr/models/stylegan2_model.py b/basicsr/models/stylegan2_model.py index 7cf7aec..c1ac6cf 100644 --- a/basicsr/models/stylegan2_model.py +++ b/basicsr/models/stylegan2_model.py @@ -1,6 +1,6 @@ +import cv2 import importlib import math -import mmcv import numpy as np import random import torch @@ -11,7 +11,7 @@ from basicsr.models.archs import define_network from basicsr.models.base_model import BaseModel from basicsr.models.losses.losses import g_path_regularize, r1_penalty -from basicsr.utils import tensor2img +from basicsr.utils import imwrite, tensor2img loss_module = importlib.import_module('basicsr.models.losses') @@ -27,11 +27,12 @@ def __init__(self, opt): self.net_g = self.model_to_device(self.net_g) self.print_network(self.net_g) # load pretrained model - load_path = self.opt['path'].get('pretrain_model_g', None) + load_path = self.opt['path'].get('pretrain_network_g', None) if load_path is not None: param_key = self.opt['path'].get('param_key_g', 'params') self.load_network(self.net_g, load_path, - self.opt['path']['strict_load'], param_key) + self.opt['path'].get('strict_load_g', + True), param_key) # latent dimension: self.num_style_feat self.num_style_feat = opt['network_g']['num_style_feat'] @@ -51,10 +52,10 @@ def init_training_settings(self): self.print_network(self.net_d) # load pretrained model - load_path = self.opt['path'].get('pretrain_model_d', None) + load_path = self.opt['path'].get('pretrain_network_d', None) if load_path is not None: self.load_network(self.net_d, load_path, - self.opt['path']['strict_load']) + self.opt['path'].get('strict_load_d', True)) # define network net_g with Exponential Moving Average (EMA) # net_g_ema only used for testing on one GPU and saving, do not need to @@ -62,10 +63,11 @@ def init_training_settings(self): self.net_g_ema = define_network(deepcopy(self.opt['network_g'])).to( self.device) # load pretrained model - load_path = self.opt['path'].get('pretrain_model_g', None) + load_path = self.opt['path'].get('pretrain_network_g', None) if load_path is not None: self.load_network(self.net_g_ema, load_path, - self.opt['path']['strict_load'], 'params_ema') + self.opt['path'].get('strict_load_g', + True), 'params_ema') else: self.model_ema(0) # copy net_g weight @@ -311,10 +313,10 @@ def nondist_validation(self, dataloader, current_iter, tb_logger, else: save_img_path = osp.join(self.opt['path']['visualization'], 'test', f'test_{self.opt["name"]}.png') - mmcv.imwrite(result, save_img_path) + imwrite(result, save_img_path) # add sample images to 
tb_logger result = (result / 255.).astype(np.float32) - result = mmcv.bgr2rgb(result) + result = cv2.cvtColor(result, cv2.COLOR_BGR2RGB) if tb_logger is not None: tb_logger.add_image( 'samples', result, global_step=current_iter, dataformats='HWC') diff --git a/basicsr/models/video_base_model.py b/basicsr/models/video_base_model.py index 6e70eed..cbacd99 100644 --- a/basicsr/models/video_base_model.py +++ b/basicsr/models/video_base_model.py @@ -1,14 +1,13 @@ import importlib -import mmcv import torch from collections import Counter from copy import deepcopy -from mmcv.runner import get_dist_info from os import path as osp from torch import distributed as dist from basicsr.models.sr_model import SRModel -from basicsr.utils import ProgressBar, get_root_logger, tensor2img +from basicsr.utils import ProgressBar, get_root_logger, imwrite, tensor2img +from basicsr.utils.dist_util import get_dist_info metric_module = importlib.import_module('basicsr.metrics') @@ -83,7 +82,7 @@ def dist_validation(self, dataloader, current_iter, tb_logger, save_img): save_img_path = osp.join( self.opt['path']['visualization'], dataset_name, folder, f'{img_name}_{self.opt["name"]}.png') - mmcv.imwrite(result_img, save_img_path) + imwrite(result_img, save_img_path) if with_metrics: # calculate metrics diff --git a/basicsr/test.py b/basicsr/test.py index 7bdae15..622df4e 100644 --- a/basicsr/test.py +++ b/basicsr/test.py @@ -1,46 +1,23 @@ -import argparse import logging -import random import torch -from mmcv.runner import get_dist_info, get_time_str, init_dist from os import path as osp from basicsr.data import create_dataloader, create_dataset from basicsr.models import create_model -from basicsr.utils import (get_env_info, get_root_logger, make_exp_dirs, - set_random_seed) -from basicsr.utils.options import dict2str, parse +from basicsr.train import parse_options +from basicsr.utils import (get_env_info, get_root_logger, get_time_str, + make_exp_dirs) +from basicsr.utils.options import dict2str def main(): - # options - parser = argparse.ArgumentParser() - parser.add_argument( - '-opt', type=str, required=True, help='Path to option YAML file.') - parser.add_argument( - '--launcher', - choices=['none', 'pytorch', 'slurm'], - default='none', - help='job launcher') - parser.add_argument('--local_rank', type=int, default=0) - args = parser.parse_args() - opt = parse(args.opt, is_train=False) + # parse options, set distributed setting, set ramdom seed + opt = parse_options(is_train=False) - # distributed testing settings - if args.launcher == 'none': # non-distributed testing - opt['dist'] = False - print('Disable distributed testing.', flush=True) - else: - opt['dist'] = True - if args.launcher == 'slurm' and 'dist_params' in opt: - init_dist(args.launcher, **opt['dist_params']) - else: - init_dist(args.launcher) - - rank, world_size = get_dist_info() - opt['rank'] = rank - opt['world_size'] = world_size + torch.backends.cudnn.benchmark = True + # torch.backends.cudnn.deterministic = True + # mkdir and initialize loggers make_exp_dirs(opt) log_file = osp.join(opt['path']['log'], f"test_{opt['name']}_{get_time_str()}.log") @@ -49,17 +26,6 @@ def main(): logger.info(get_env_info()) logger.info(dict2str(opt)) - # random seed - seed = opt['manual_seed'] - if seed is None: - seed = random.randint(1, 10000) - opt['manual_seed'] = seed - logger.info(f'Random seed: {seed}') - set_random_seed(seed + rank) - - torch.backends.cudnn.benchmark = True - # torch.backends.cudnn.deterministic = True - # create test dataset and 
dataloader test_loaders = [] for phase, dataset_opt in sorted(opt['datasets'].items()): @@ -70,7 +36,7 @@ def main(): num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=None, - seed=seed) + seed=opt['manual_seed']) logger.info( f"Number of test images in {dataset_opt['name']}: {len(test_set)}") test_loaders.append(test_loader) diff --git a/basicsr/train.py b/basicsr/train.py index 0d769c8..02a460f 100644 --- a/basicsr/train.py +++ b/basicsr/train.py @@ -5,7 +5,6 @@ import random import time import torch -from mmcv.runner import get_dist_info, get_time_str, init_dist from os import path as osp from basicsr.data import create_dataloader, create_dataset @@ -13,13 +12,14 @@ from basicsr.data.prefetch_dataloader import CPUPrefetcher, CUDAPrefetcher from basicsr.models import create_model from basicsr.utils import (MessageLogger, check_resume, get_env_info, - get_root_logger, init_tb_logger, init_wandb_logger, - make_exp_dirs, mkdir_and_rename, set_random_seed) + get_root_logger, get_time_str, init_tb_logger, + init_wandb_logger, make_exp_dirs, mkdir_and_rename, + set_random_seed) +from basicsr.utils.dist_util import get_dist_info, init_dist from basicsr.utils.options import dict2str, parse -def main(): - # options +def parse_options(is_train=True): parser = argparse.ArgumentParser() parser.add_argument( '-opt', type=str, required=True, help='Path to option YAML file.') @@ -30,12 +30,12 @@ def main(): help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args() - opt = parse(args.opt, is_train=True) + opt = parse(args.opt, is_train=is_train) - # distributed training settings - if args.launcher == 'none': # non-distributed training + # distributed settings + if args.launcher == 'none': opt['dist'] = False - print('Disable distributed training.', flush=True) + print('Disable distributed.', flush=True) else: opt['dist'] = True if args.launcher == 'slurm' and 'dist_params' in opt: @@ -43,68 +43,55 @@ def main(): else: init_dist(args.launcher) - rank, world_size = get_dist_info() - opt['rank'] = rank - opt['world_size'] = world_size + opt['rank'], opt['world_size'] = get_dist_info() - # load resume states if exists - if opt['path'].get('resume_state'): - device_id = torch.cuda.current_device() - resume_state = torch.load( - opt['path']['resume_state'], - map_location=lambda storage, loc: storage.cuda(device_id)) - else: - resume_state = None + # random seed + seed = opt.get('manual_seed') + if seed is None: + seed = random.randint(1, 10000) + opt['manual_seed'] = seed + set_random_seed(seed + opt['rank']) - # mkdir and loggers - if resume_state is None: - make_exp_dirs(opt) + return opt + + +def init_loggers(opt): log_file = osp.join(opt['path']['log'], f"train_{opt['name']}_{get_time_str()}.log") logger = get_root_logger( logger_name='basicsr', log_level=logging.INFO, log_file=log_file) logger.info(get_env_info()) logger.info(dict2str(opt)) + # initialize tensorboard logger and wandb logger tb_logger = None if opt['logger'].get('use_tb_logger') and 'debug' not in opt['name']: - log_dir = './tb_logger/' + opt['name'] - if resume_state is None and opt['rank'] == 0: - mkdir_and_rename(log_dir) - tb_logger = init_tb_logger(log_dir=log_dir) + tb_logger = init_tb_logger(log_dir=osp.join('tb_logger', opt['name'])) if (opt['logger'].get('wandb') is not None) and (opt['logger']['wandb'].get('project') is not None) and ('debug' not in opt['name']): assert opt['logger'].get('use_tb_logger') is True, ( 'should turn on tensorboard when using wandb') 
init_wandb_logger(opt) + return logger, tb_logger - # random seed - seed = opt['manual_seed'] - if seed is None: - seed = random.randint(1, 10000) - opt['manual_seed'] = seed - logger.info(f'Random seed: {seed}') - set_random_seed(seed + rank) - - torch.backends.cudnn.benchmark = True - # torch.backends.cudnn.deterministic = True +def create_train_val_dataloader(opt, logger): # create train and val dataloaders train_loader, val_loader = None, None for phase, dataset_opt in opt['datasets'].items(): if phase == 'train': dataset_enlarge_ratio = dataset_opt.get('dataset_enlarge_ratio', 1) train_set = create_dataset(dataset_opt) - train_sampler = EnlargedSampler(train_set, world_size, rank, - dataset_enlarge_ratio) + train_sampler = EnlargedSampler(train_set, opt['world_size'], + opt['rank'], dataset_enlarge_ratio) train_loader = create_dataloader( train_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=train_sampler, - seed=seed) + seed=opt['manual_seed']) num_iter_per_epoch = math.ceil( len(train_set) * dataset_enlarge_ratio / @@ -119,6 +106,7 @@ def main(): f'\n\tWorld size (gpu number): {opt["world_size"]}' f'\n\tRequire iter number per epoch: {num_iter_per_epoch}' f'\n\tTotal epochs: {total_epochs}; iters: {total_iters}.') + elif phase == 'val': val_set = create_dataset(dataset_opt) val_loader = create_dataloader( @@ -127,27 +115,57 @@ def main(): num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=None, - seed=seed) + seed=opt['manual_seed']) logger.info( f'Number of val images/folders in {dataset_opt["name"]}: ' f'{len(val_set)}') else: raise ValueError(f'Dataset phase {phase} is not recognized.') - assert train_loader is not None - # create model - if resume_state: - check_resume(opt, resume_state['iter']) # modify pretrain_model paths - model = create_model(opt) + return train_loader, train_sampler, val_loader, total_epochs, total_iters + + +def main(): + # parse options, set distributed setting, set ramdom seed + opt = parse_options(is_train=True) + + torch.backends.cudnn.benchmark = True + # torch.backends.cudnn.deterministic = True + + # load resume states if necessary + if opt['path'].get('resume_state'): + device_id = torch.cuda.current_device() + resume_state = torch.load( + opt['path']['resume_state'], + map_location=lambda storage, loc: storage.cuda(device_id)) + else: + resume_state = None + + # mkdir for experiments and logger + if resume_state is None: + make_exp_dirs(opt) + if opt['logger'].get('use_tb_logger') and 'debug' not in opt[ + 'name'] and opt['rank'] == 0: + mkdir_and_rename(osp.join('tb_logger', opt['name'])) + + # initialize loggers + logger, tb_logger = init_loggers(opt) + + # create train and validation dataloaders + result = create_train_val_dataloader(opt, logger) + train_loader, train_sampler, val_loader, total_epochs, total_iters = result - # resume training - if resume_state: + # create model + if resume_state: # resume training + check_resume(opt, resume_state['iter']) + model = create_model(opt) + model.resume_training(resume_state) # handle optimizers and schedulers logger.info(f"Resuming training from epoch: {resume_state['epoch']}, " f"iter: {resume_state['iter']}.") start_epoch = resume_state['epoch'] current_iter = resume_state['iter'] - model.resume_training(resume_state) # handle optimizers and schedulers else: + model = create_model(opt) start_epoch = 0 current_iter = 0 diff --git a/basicsr/utils/__init__.py b/basicsr/utils/__init__.py index 95f7a50..e547f67 100644 --- a/basicsr/utils/__init__.py +++ b/basicsr/utils/__init__.py 
@@ -1,12 +1,31 @@ from .file_client import FileClient +from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img from .logger import (MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger) -from .util import (ProgressBar, check_resume, crop_border, make_exp_dirs, - mkdir_and_rename, set_random_seed, tensor2img) +from .util import (ProgressBar, check_resume, get_time_str, make_exp_dirs, + mkdir_and_rename, scandir, set_random_seed) __all__ = [ - 'FileClient', 'MessageLogger', 'get_root_logger', 'make_exp_dirs', - 'init_tb_logger', 'init_wandb_logger', 'set_random_seed', 'ProgressBar', - 'tensor2img', 'crop_border', 'check_resume', 'mkdir_and_rename', - 'get_env_info' + # file_client.py + 'FileClient', + # img_util.py + 'img2tensor', + 'tensor2img', + 'imfrombytes', + 'imwrite', + 'crop_border', + # logger.py + 'MessageLogger', + 'init_tb_logger', + 'init_wandb_logger', + 'get_root_logger', + 'get_env_info', + # util.py + 'set_random_seed', + 'get_time_str', + 'mkdir_and_rename', + 'make_exp_dirs', + 'scandir', + 'check_resume', + 'ProgressBar' ] diff --git a/basicsr/utils/dist_util.py b/basicsr/utils/dist_util.py new file mode 100644 index 0000000..43cf4cd --- /dev/null +++ b/basicsr/utils/dist_util.py @@ -0,0 +1,83 @@ +# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501 +import functools +import os +import subprocess +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + + +def init_dist(launcher, backend='nccl', **kwargs): + if mp.get_start_method(allow_none=True) is None: + mp.set_start_method('spawn') + if launcher == 'pytorch': + _init_dist_pytorch(backend, **kwargs) + elif launcher == 'slurm': + _init_dist_slurm(backend, **kwargs) + else: + raise ValueError(f'Invalid launcher type: {launcher}') + + +def _init_dist_pytorch(backend, **kwargs): + rank = int(os.environ['RANK']) + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(rank % num_gpus) + dist.init_process_group(backend=backend, **kwargs) + + +def _init_dist_slurm(backend, port=None): + """Initialize slurm distributed training environment. + + If argument ``port`` is not specified, then the master port will be system + environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system + environment variable, then a default port ``29500`` will be used. + + Args: + backend (str): Backend of torch.distributed. + port (int, optional): Master port. Defaults to None. 
+ """ + proc_id = int(os.environ['SLURM_PROCID']) + ntasks = int(os.environ['SLURM_NTASKS']) + node_list = os.environ['SLURM_NODELIST'] + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(proc_id % num_gpus) + addr = subprocess.getoutput( + f'scontrol show hostname {node_list} | head -n1') + # specify master port + if port is not None: + os.environ['MASTER_PORT'] = str(port) + elif 'MASTER_PORT' in os.environ: + pass # use MASTER_PORT in the environment variable + else: + # 29500 is torch.distributed default port + os.environ['MASTER_PORT'] = '29500' + os.environ['MASTER_ADDR'] = addr + os.environ['WORLD_SIZE'] = str(ntasks) + os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) + os.environ['RANK'] = str(proc_id) + dist.init_process_group(backend=backend) + + +def get_dist_info(): + if dist.is_available(): + initialized = dist.is_initialized() + else: + initialized = False + if initialized: + rank = dist.get_rank() + world_size = dist.get_world_size() + else: + rank = 0 + world_size = 1 + return rank, world_size + + +def master_only(func): + + @functools.wraps(func) + def wrapper(*args, **kwargs): + rank, _ = get_dist_info() + if rank == 0: + return func(*args, **kwargs) + + return wrapper diff --git a/basicsr/utils/download.py b/basicsr/utils/download.py index e03516c..3cd696c 100644 --- a/basicsr/utils/download.py +++ b/basicsr/utils/download.py @@ -1,7 +1,7 @@ import math import requests -from basicsr.utils import ProgressBar +from .util import ProgressBar def download_file_from_google_drive(file_id, save_path): diff --git a/basicsr/utils/file_client.py b/basicsr/utils/file_client.py index 1d8e5cf..066b22f 100644 --- a/basicsr/utils/file_client.py +++ b/basicsr/utils/file_client.py @@ -1,113 +1,183 @@ -from mmcv.fileio.file_client import (BaseStorageBackend, CephBackend, - HardDiskBackend, MemcachedBackend) - - -class LmdbBackend(BaseStorageBackend): - """Lmdb storage backend. - - Args: - db_paths (str | list[str]): Lmdb database paths. - client_keys (str | list[str]): Lmdb client keys. Default: 'default'. - readonly (bool, optional): Lmdb environment parameter. If True, - disallow any write operations. Default: True. - lock (bool, optional): Lmdb environment parameter. If False, when - concurrent access occurs, do not lock the database. Default: False. - readahead (bool, optional): Lmdb environment parameter. If False, - disable the OS filesystem readahead mechanism, which may improve - random read performance when a database is larger than RAM. - Default: False. - - Attributes: - db_paths (list): Lmdb database path. - _client (list): A list of several lmdb envs. 
- """ - - def __init__(self, - db_paths, - client_keys='default', - readonly=True, - lock=False, - readahead=False, - **kwargs): - try: - import lmdb - except ImportError: - raise ImportError('Please install lmdb to enable LmdbBackend.') - - if isinstance(client_keys, str): - client_keys = [client_keys] - - if isinstance(db_paths, list): - self.db_paths = [str(v) for v in db_paths] - elif isinstance(db_paths, str): - self.db_paths = [str(db_paths)] - assert len(client_keys) == len(self.db_paths), ( - 'client_keys and db_paths should have the same length, ' - f'but received {len(client_keys)} and {len(self.db_paths)}.') - - self._client = {} - for client, path in zip(client_keys, self.db_paths): - self._client[client] = lmdb.open( - path, - readonly=readonly, - lock=lock, - readahead=readahead, - **kwargs) - - def get(self, filepath, client_key): - """Get values according to the filepath from one lmdb named client_key. - - Args: - filepath (str | obj:`Path`): Here, filepath is the lmdb key. - client_key (str): Used for distinguishing differnet lmdb envs. - """ - filepath = str(filepath) - assert client_key in self._client, (f'client_key {client_key} is not ' - 'in lmdb clients.') - client = self._client[client_key] - with client.begin(write=False) as txn: - value_buf = txn.get(filepath.encode('ascii')) - return value_buf - - def get_text(self, filepath): - raise NotImplementedError - - -class FileClient(object): - """A general file client to access files in different backend. - - The client loads a file or text in a specified backend from its path - and return it as a binary file. it can also register other backend - accessor with a given name and backend class. - - Attributes: - backend (str): The storage backend type. Options are "disk", "ceph", - "memcached" and "lmdb". - client (:obj:`BaseStorageBackend`): The backend object. - """ - - _backends = { - 'disk': HardDiskBackend, - 'ceph': CephBackend, - 'memcached': MemcachedBackend, - 'lmdb': LmdbBackend, - } - - def __init__(self, backend='disk', **kwargs): - if backend not in self._backends: - raise ValueError( - f'Backend {backend} is not supported. Currently supported ones' - f' are {list(self._backends.keys())}') - self.backend = backend - self.client = self._backends[backend](**kwargs) - - def get(self, filepath, client_key='default'): - # client_key is used only for lmdb, where different fileclients have - # different lmdb environments. - if self.backend == 'lmdb': - return self.client.get(filepath, client_key) - else: - return self.client.get(filepath) - - def get_text(self, filepath): - return self.client.get_text(filepath) +# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501 +from abc import ABCMeta, abstractmethod + + +class BaseStorageBackend(metaclass=ABCMeta): + """Abstract class of storage backends. + + All backends need to implement two apis: ``get()`` and ``get_text()``. + ``get()`` reads the file as a byte stream and ``get_text()`` reads the file + as texts. + """ + + @abstractmethod + def get(self, filepath): + pass + + @abstractmethod + def get_text(self, filepath): + pass + + +class MemcachedBackend(BaseStorageBackend): + """Memcached storage backend. + + Attributes: + server_list_cfg (str): Config file for memcached server list. + client_cfg (str): Config file for memcached client. + sys_path (str | None): Additional path to be appended to `sys.path`. + Default: None. 
+ """ + + def __init__(self, server_list_cfg, client_cfg, sys_path=None): + if sys_path is not None: + import sys + sys.path.append(sys_path) + try: + import mc + except ImportError: + raise ImportError( + 'Please install memcached to enable MemcachedBackend.') + + self.server_list_cfg = server_list_cfg + self.client_cfg = client_cfg + self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, + self.client_cfg) + # mc.pyvector servers as a point which points to a memory cache + self._mc_buffer = mc.pyvector() + + def get(self, filepath): + filepath = str(filepath) + import mc + self._client.Get(filepath, self._mc_buffer) + value_buf = mc.ConvertBuffer(self._mc_buffer) + return value_buf + + def get_text(self, filepath): + raise NotImplementedError + + +class HardDiskBackend(BaseStorageBackend): + """Raw hard disks storage backend.""" + + def get(self, filepath): + filepath = str(filepath) + with open(filepath, 'rb') as f: + value_buf = f.read() + return value_buf + + def get_text(self, filepath): + filepath = str(filepath) + with open(filepath, 'r') as f: + value_buf = f.read() + return value_buf + + +class LmdbBackend(BaseStorageBackend): + """Lmdb storage backend. + + Args: + db_paths (str | list[str]): Lmdb database paths. + client_keys (str | list[str]): Lmdb client keys. Default: 'default'. + readonly (bool, optional): Lmdb environment parameter. If True, + disallow any write operations. Default: True. + lock (bool, optional): Lmdb environment parameter. If False, when + concurrent access occurs, do not lock the database. Default: False. + readahead (bool, optional): Lmdb environment parameter. If False, + disable the OS filesystem readahead mechanism, which may improve + random read performance when a database is larger than RAM. + Default: False. + + Attributes: + db_paths (list): Lmdb database path. + _client (list): A list of several lmdb envs. + """ + + def __init__(self, + db_paths, + client_keys='default', + readonly=True, + lock=False, + readahead=False, + **kwargs): + try: + import lmdb + except ImportError: + raise ImportError('Please install lmdb to enable LmdbBackend.') + + if isinstance(client_keys, str): + client_keys = [client_keys] + + if isinstance(db_paths, list): + self.db_paths = [str(v) for v in db_paths] + elif isinstance(db_paths, str): + self.db_paths = [str(db_paths)] + assert len(client_keys) == len(self.db_paths), ( + 'client_keys and db_paths should have the same length, ' + f'but received {len(client_keys)} and {len(self.db_paths)}.') + + self._client = {} + for client, path in zip(client_keys, self.db_paths): + self._client[client] = lmdb.open( + path, + readonly=readonly, + lock=lock, + readahead=readahead, + **kwargs) + + def get(self, filepath, client_key): + """Get values according to the filepath from one lmdb named client_key. + + Args: + filepath (str | obj:`Path`): Here, filepath is the lmdb key. + client_key (str): Used for distinguishing differnet lmdb envs. + """ + filepath = str(filepath) + assert client_key in self._client, (f'client_key {client_key} is not ' + 'in lmdb clients.') + client = self._client[client_key] + with client.begin(write=False) as txn: + value_buf = txn.get(filepath.encode('ascii')) + return value_buf + + def get_text(self, filepath): + raise NotImplementedError + + +class FileClient(object): + """A general file client to access files in different backend. + + The client loads a file or text in a specified backend from its path + and return it as a binary file. 
it can also register other backend + accessor with a given name and backend class. + + Attributes: + backend (str): The storage backend type. Options are "disk", + "memcached" and "lmdb". + client (:obj:`BaseStorageBackend`): The backend object. + """ + + _backends = { + 'disk': HardDiskBackend, + 'memcached': MemcachedBackend, + 'lmdb': LmdbBackend, + } + + def __init__(self, backend='disk', **kwargs): + if backend not in self._backends: + raise ValueError( + f'Backend {backend} is not supported. Currently supported ones' + f' are {list(self._backends.keys())}') + self.backend = backend + self.client = self._backends[backend](**kwargs) + + def get(self, filepath, client_key='default'): + # client_key is used only for lmdb, where different fileclients have + # different lmdb environments. + if self.backend == 'lmdb': + return self.client.get(filepath, client_key) + else: + return self.client.get(filepath) + + def get_text(self, filepath): + return self.client.get_text(filepath) diff --git a/basicsr/utils/flow_util.py b/basicsr/utils/flow_util.py new file mode 100644 index 0000000..2b052cc --- /dev/null +++ b/basicsr/utils/flow_util.py @@ -0,0 +1,180 @@ +# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/video/optflow.py # noqa: E501 +import cv2 +import numpy as np +import os + + +def flowread(flow_path, quantize=False, concat_axis=0, *args, **kwargs): + """Read an optical flow map. + + Args: + flow_path (ndarray or str): Flow path. + quantize (bool): whether to read quantized pair, if set to True, + remaining args will be passed to :func:`dequantize_flow`. + concat_axis (int): The axis that dx and dy are concatenated, + can be either 0 or 1. Ignored if quantize is False. + + Returns: + ndarray: Optical flow represented as a (h, w, 2) numpy array + """ + if quantize: + assert concat_axis in [0, 1] + cat_flow = cv2.imread(flow_path, cv2.IMREAD_UNCHANGED) + if cat_flow.ndim != 2: + raise IOError(f'{flow_path} is not a valid quantized flow file, ' + f'its dimension is {cat_flow.ndim}.') + assert cat_flow.shape[concat_axis] % 2 == 0 + dx, dy = np.split(cat_flow, 2, axis=concat_axis) + flow = dequantize_flow(dx, dy, *args, **kwargs) + else: + with open(flow_path, 'rb') as f: + try: + header = f.read(4).decode('utf-8') + except Exception: + raise IOError(f'Invalid flow file: {flow_path}') + else: + if header != 'PIEH': + raise IOError(f'Invalid flow file: {flow_path}, ' + 'header does not contain PIEH') + + w = np.fromfile(f, np.int32, 1).squeeze() + h = np.fromfile(f, np.int32, 1).squeeze() + flow = np.fromfile(f, np.float32, w * h * 2).reshape((h, w, 2)) + + return flow.astype(np.float32) + + +def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs): + """Write optical flow to file. + + If the flow is not quantized, it will be saved as a .flo file losslessly, + otherwise a jpeg image which is lossy but of much smaller size. (dx and dy + will be concatenated horizontally into a single image if quantize is True.) + + Args: + flow (ndarray): (h, w, 2) array of optical flow. + filename (str): Output filepath. + quantize (bool): Whether to quantize the flow and save it to 2 jpeg + images. If set to True, remaining args will be passed to + :func:`quantize_flow`. + concat_axis (int): The axis that dx and dy are concatenated, + can be either 0 or 1. Ignored if quantize is False. 
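# Usage sketch (illustrative, not part of the diff): reading raw bytes through
# the FileClient defined above with the disk and lmdb backends. Paths and lmdb
# keys are placeholders; the returned bytes can later be decoded with
# imfrombytes from basicsr/utils/img_util.py, added further down in this patch.
from basicsr.utils.file_client import FileClient

disk_client = FileClient('disk')
img_bytes = disk_client.get('datasets/DIV2K/0001.png')

lmdb_client = FileClient(
    'lmdb', db_paths=['datasets/DIV2K/train_HR.lmdb'], client_keys=['gt'])
img_bytes = lmdb_client.get('0001_s001', client_key='gt')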
+ """ + if not quantize: + with open(filename, 'wb') as f: + f.write('PIEH'.encode('utf-8')) + np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f) + flow = flow.astype(np.float32) + flow.tofile(f) + f.flush() + else: + assert concat_axis in [0, 1] + dx, dy = quantize_flow(flow, *args, **kwargs) + dxdy = np.concatenate((dx, dy), axis=concat_axis) + os.makedirs(filename, exist_ok=True) + cv2.imwrite(dxdy, filename) + + +def quantize_flow(flow, max_val=0.02, norm=True): + """Quantize flow to [0, 255]. + + After this step, the size of flow will be much smaller, and can be + dumped as jpeg images. + + Args: + flow (ndarray): (h, w, 2) array of optical flow. + max_val (float): Maximum value of flow, values beyond + [-max_val, max_val] will be truncated. + norm (bool): Whether to divide flow values by image width/height. + + Returns: + tuple[ndarray]: Quantized dx and dy. + """ + h, w, _ = flow.shape + dx = flow[..., 0] + dy = flow[..., 1] + if norm: + dx = dx / w # avoid inplace operations + dy = dy / h + # use 255 levels instead of 256 to make sure 0 is 0 after dequantization. + flow_comps = [ + quantize(d, -max_val, max_val, 255, np.uint8) for d in [dx, dy] + ] + return tuple(flow_comps) + + +def dequantize_flow(dx, dy, max_val=0.02, denorm=True): + """Recover from quantized flow. + + Args: + dx (ndarray): Quantized dx. + dy (ndarray): Quantized dy. + max_val (float): Maximum value used when quantizing. + denorm (bool): Whether to multiply flow values with width/height. + + Returns: + ndarray: Dequantized flow. + """ + assert dx.shape == dy.shape + assert dx.ndim == 2 or (dx.ndim == 3 and dx.shape[-1] == 1) + + dx, dy = [dequantize(d, -max_val, max_val, 255) for d in [dx, dy]] + + if denorm: + dx *= dx.shape[1] + dy *= dx.shape[0] + flow = np.dstack((dx, dy)) + return flow + + +def quantize(arr, min_val, max_val, levels, dtype=np.int64): + """Quantize an array of (-inf, inf) to [0, levels-1]. + + Args: + arr (ndarray): Input array. + min_val (scalar): Minimum value to be clipped. + max_val (scalar): Maximum value to be clipped. + levels (int): Quantization levels. + dtype (np.type): The type of the quantized array. + + Returns: + tuple: Quantized array. + """ + if not (isinstance(levels, int) and levels > 1): + raise ValueError( + f'levels must be a positive integer, but got {levels}') + if min_val >= max_val: + raise ValueError( + f'min_val ({min_val}) must be smaller than max_val ({max_val})') + + arr = np.clip(arr, min_val, max_val) - min_val + quantized_arr = np.minimum( + np.floor(levels * arr / (max_val - min_val)).astype(dtype), levels - 1) + + return quantized_arr + + +def dequantize(arr, min_val, max_val, levels, dtype=np.float64): + """Dequantize an array. + + Args: + arr (ndarray): Input array. + min_val (scalar): Minimum value to be clipped. + max_val (scalar): Maximum value to be clipped. + levels (int): Quantization levels. + dtype (np.type): The type of the dequantized array. + + Returns: + tuple: Dequantized array. 
+ """ + if not (isinstance(levels, int) and levels > 1): + raise ValueError( + f'levels must be a positive integer, but got {levels}') + if min_val >= max_val: + raise ValueError( + f'min_val ({min_val}) must be smaller than max_val ({max_val})') + + dequantized_arr = (arr + 0.5).astype(dtype) * (max_val - + min_val) / levels + min_val + + return dequantized_arr diff --git a/basicsr/utils/img_util.py b/basicsr/utils/img_util.py new file mode 100644 index 0000000..4096cfd --- /dev/null +++ b/basicsr/utils/img_util.py @@ -0,0 +1,162 @@ +import cv2 +import math +import numpy as np +import os +import torch +from torchvision.utils import make_grid + + +def img2tensor(imgs, bgr2rgb=True, float32=True): + """Numpy array to tensor. + + Args: + imgs (list[ndarray] | ndarray): Input images. + bgr2rgb (bool): Whether to change bgr to rgb. + float32 (bool): Whether to change to float32. + + Returns: + list[tensor] | tensor: Tensor images. If returned results only have + one element, just return tensor. + """ + + def _totensor(img, bgr2rgb, float32): + if img.shape[2] == 3 and bgr2rgb: + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = torch.from_numpy(img.transpose(2, 0, 1)) + if float32: + img = img.float() + return img + + if isinstance(imgs, list): + return [_totensor(img, bgr2rgb, float32) for img in imgs] + else: + return _totensor(imgs, bgr2rgb, float32) + + +def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)): + """Convert torch Tensors into image numpy arrays. + + After clamping to [min, max], values will be normalized to [0, 1]. + + Args: + tensor (Tensor or list[Tensor]): Accept shapes: + 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W); + 2) 3D Tensor of shape (3/1 x H x W); + 3) 2D Tensor of shape (H x W). + Tensor channel should be in RGB order. + rgb2bgr (bool): Whether to change rgb to bgr. + out_type (numpy type): output types. If ``np.uint8``, transform outputs + to uint8 type with range [0, 255]; otherwise, float type with + range [0, 1]. Default: ``np.uint8``. + min_max (tuple[int]): min and max values for clamp. + + Returns: + (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of + shape (H x W). The channel order is BGR. + """ + if not (torch.is_tensor(tensor) or + (isinstance(tensor, list) + and all(torch.is_tensor(t) for t in tensor))): + raise TypeError( + f'tensor or list of tensors expected, got {type(tensor)}') + + if torch.is_tensor(tensor): + tensor = [tensor] + result = [] + for _tensor in tensor: + _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max) + _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0]) + + n_dim = _tensor.dim() + if n_dim == 4: + img_np = make_grid( + _tensor, nrow=int(math.sqrt(_tensor.size(0))), + normalize=False).numpy() + img_np = img_np.transpose(1, 2, 0) + if rgb2bgr: + img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) + elif n_dim == 3: + img_np = _tensor.numpy() + img_np = img_np.transpose(1, 2, 0) + if rgb2bgr: + img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) + elif n_dim == 2: + img_np = _tensor.numpy() + else: + raise TypeError('Only support 4D, 3D or 2D tensor. ' + f'But received with dimension: {n_dim}') + if out_type == np.uint8: + # Unlike MATLAB, numpy.unit8() WILL NOT round by default. + img_np = (img_np * 255.0).round() + img_np = img_np.astype(out_type) + result.append(img_np) + if len(result) == 1: + result = result[0] + return result + + +def imfrombytes(content, flag='color', float32=False): + """Read an image from bytes. 
+ + Args: + content (bytes): Image bytes got from files or other streams. + flag (str): Flags specifying the color type of a loaded image, + candidates are `color`, `grayscale` and `unchanged`. + float32 (bool): Whether to change to float32., If True, will also norm + to [0, 1]. Default: False. + + Returns: + ndarray: Loaded image array. + """ + img_np = np.frombuffer(content, np.uint8) + imread_flags = { + 'color': cv2.IMREAD_COLOR, + 'grayscale': cv2.IMREAD_GRAYSCALE, + 'unchanged': cv2.IMREAD_UNCHANGED + } + img = cv2.imdecode(img_np, imread_flags[flag]) + if float32: + img = img.astype(np.float32) / 255. + return img + + +def imwrite(img, file_path, params=None, auto_mkdir=True): + """Write image to file. + + Args: + img (ndarray): Image array to be written. + file_path (str): Image file path. + params (None or list): Same as opencv's :func:`imwrite` interface. + auto_mkdir (bool): If the parent folder of `file_path` does not exist, + whether to create it automatically. + + Returns: + bool: Successful or not. + """ + if auto_mkdir: + dir_name = os.path.abspath(os.path.dirname(file_path)) + os.makedirs(dir_name, exist_ok=True) + return cv2.imwrite(file_path, img, params) + + +def crop_border(imgs, crop_border): + """Crop borders of images. + + Args: + imgs (list[ndarray] | ndarray): Images with shape (h, w, c). + crop_border (int): Crop border for each end of height and weight. + + Returns: + list[ndarray]: Cropped images. + """ + if crop_border == 0: + return imgs + else: + if isinstance(imgs, list): + return [ + v[crop_border:-crop_border, crop_border:-crop_border, ...] + for v in imgs + ] + else: + return imgs[crop_border:-crop_border, crop_border:-crop_border, + ...] diff --git a/basicsr/utils/lmdb.py b/basicsr/utils/lmdb.py index 8e3e99d..c99f50b 100644 --- a/basicsr/utils/lmdb.py +++ b/basicsr/utils/lmdb.py @@ -1,6 +1,5 @@ import cv2 import lmdb -import mmcv import sys from multiprocessing import Pool from os import path as osp @@ -96,8 +95,8 @@ def callback(arg): # create lmdb environment if map_size is None: # obtain data size for one image - img = mmcv.imread( - osp.join(data_path, img_path_list[0]), flag='unchanged') + img = cv2.imread( + osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED) _, img_byte = cv2.imencode( '.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level]) data_size_per_img = img_byte.nbytes @@ -148,7 +147,7 @@ def read_img_worker(path, key, compress_level): tuple[int]: Image shape. """ - img = mmcv.imread(path, flag='unchanged') + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) if img.ndim == 2: h, w = img.shape c = 1 diff --git a/basicsr/utils/logger.py b/basicsr/utils/logger.py index 6aee50b..48671ed 100644 --- a/basicsr/utils/logger.py +++ b/basicsr/utils/logger.py @@ -1,7 +1,8 @@ import datetime import logging import time -from mmcv.runner import get_dist_info, master_only + +from .dist_util import get_dist_info, master_only class MessageLogger(): @@ -153,7 +154,6 @@ def get_env_info(): Currently, only log the software version. 
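# Pipeline sketch (illustrative, not part of the diff): decode raw bytes with
# imfrombytes, trim a border with crop_border, and save with imwrite, which
# creates the target folder when auto_mkdir is left enabled. Paths are
# placeholders.
from basicsr.utils.img_util import crop_border, imfrombytes, imwrite

with open('datasets/Set5/baby.png', 'rb') as f:
    img = imfrombytes(f.read(), flag='color', float32=False)
cropped = crop_border(img, crop_border=4)  # drop 4 px on each side
imwrite(cropped, 'results/Set5/baby_cropped.png')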
""" - import mmcv import torch import torchvision @@ -173,6 +173,5 @@ def get_env_info(): msg += ('\nVersion Information: ' f'\n\tBasicSR: {__version__}' f'\n\tPyTorch: {torch.__version__}' - f'\n\tTorchVision: {torchvision.__version__}' - f'\n\tMMCV: {mmcv.__version__}') + f'\n\tTorchVision: {torchvision.__version__}') return msg diff --git a/basicsr/utils/matlab_functions.py b/basicsr/utils/matlab_functions.py new file mode 100644 index 0000000..ad487f2 --- /dev/null +++ b/basicsr/utils/matlab_functions.py @@ -0,0 +1,192 @@ +import numpy as np + + +def rgb2ycbcr(img, y_only=False): + """Convert a RGB image to YCbCr image. + + This function produces the same results as Matlab's `rgb2ycbcr` function. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + y_only (bool): Whether to only return Y channel. Default: False. + + Returns: + ndarray: The converted YCbCr image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) + if y_only: + out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0 + else: + out_img = np.matmul( + img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], + [24.966, 112.0, -18.214]]) + [16, 128, 128] + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def bgr2ycbcr(img, y_only=False): + """Convert a BGR image to YCbCr image. + + The bgr version of rgb2ycbcr. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + y_only (bool): Whether to only return Y channel. Default: False. + + Returns: + ndarray: The converted YCbCr image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) + if y_only: + out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0 + else: + out_img = np.matmul( + img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], + [65.481, -37.797, 112.0]]) + [16, 128, 128] + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def ycbcr2rgb(img): + """Convert a YCbCr image to RGB image. + + This function produces the same results as Matlab's ycbcr2rgb function. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `YCrCb <-> RGB`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. 
+ + Returns: + ndarray: The converted RGB image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) * 255 + out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], + [0, -0.00153632, 0.00791071], + [0.00625893, -0.00318811, 0]]) * 255.0 + [ + -222.921, 135.576, -276.836 + ] # noqa: E126 + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def ycbcr2bgr(img): + """Convert a YCbCr image to BGR image. + + The bgr version of ycbcr2rgb. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `YCrCb <-> BGR`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + + Returns: + ndarray: The converted BGR image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) * 255 + out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], + [0.00791071, -0.00153632, 0], + [0, -0.00318811, 0.00625893]]) * 255.0 + [ + -276.836, 135.576, -222.921 + ] # noqa: E126 + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def _convert_input_type_range(img): + """Convert the type and range of the input image. + + It converts the input image to np.float32 type and range of [0, 1]. + It is mainly used for pre-processing the input image in colorspace + convertion functions such as rgb2ycbcr and ycbcr2rgb. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + + Returns: + (ndarray): The converted image with type of np.float32 and range of + [0, 1]. + """ + img_type = img.dtype + img = img.astype(np.float32) + if img_type == np.float32: + pass + elif img_type == np.uint8: + img /= 255. + else: + raise TypeError('The img type should be np.float32 or np.uint8, ' + f'but got {img_type}') + return img + + +def _convert_output_type_range(img, dst_type): + """Convert the type and range of the image according to dst_type. + + It converts the image to desired type and range. If `dst_type` is np.uint8, + images will be converted to np.uint8 type with range [0, 255]. If + `dst_type` is np.float32, it converts the image to np.float32 type with + range [0, 1]. + It is mainly used for post-processing images in colorspace convertion + functions such as rgb2ycbcr and ycbcr2rgb. + + Args: + img (ndarray): The image to be converted with np.float32 type and + range [0, 255]. + dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it + converts the image to np.uint8 type with range [0, 255]. If + dst_type is np.float32, it converts the image to np.float32 type + with range [0, 1]. + + Returns: + (ndarray): The converted image with desired type and range. + """ + if dst_type not in (np.uint8, np.float32): + raise TypeError('The dst_type should be np.float32 or np.uint8, ' + f'but got {dst_type}') + if dst_type == np.uint8: + img = img.round() + else: + img /= 255. 
+ return img.astype(dst_type) diff --git a/basicsr/utils/options.py b/basicsr/utils/options.py index f7717f0..7042b85 100644 --- a/basicsr/utils/options.py +++ b/basicsr/utils/options.py @@ -58,7 +58,7 @@ def parse(opt_path, is_train=True): # paths for key, path in opt['path'].items(): - if path and key != 'strict_load': + if path and 'strict_load' not in key: opt['path'][key] = osp.expanduser(path) opt['path']['root'] = osp.abspath( osp.join(__file__, osp.pardir, osp.pardir, osp.pardir)) diff --git a/basicsr/utils/util.py b/basicsr/utils/util.py index 7419e7b..f0440e4 100644 --- a/basicsr/utils/util.py +++ b/basicsr/utils/util.py @@ -1,44 +1,27 @@ -import math -import mmcv import numpy as np import os import random import sys import time import torch -from mmcv.runner import get_time_str, master_only from os import path as osp from shutil import get_terminal_size -from torchvision.utils import make_grid -from basicsr.utils import get_root_logger +from .dist_util import master_only +from .logger import get_root_logger -def check_resume(opt, resume_iter): - """Check resume states and pretrain_model paths. - - Args: - opt (dict): Options. - resume_iter (int): Resume iteration. - """ - logger = get_root_logger() - if opt['path']['resume_state']: - # ignore pretrained model paths - if opt['path'].get('pretrain_model_g') is not None or opt['path'].get( - 'pretrain_model_d') is not None: - logger.warning( - 'pretrain_model path will be ignored during resuming.') +def set_random_seed(seed): + """Set random seeds.""" + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) - # set pretrained model paths - opt['path']['pretrain_model_g'] = osp.join(opt['path']['models'], - f'net_g_{resume_iter}.pth') - logger.info( - f"Set pretrain_model_g to {opt['path']['pretrain_model_g']}") - opt['path']['pretrain_model_d'] = osp.join(opt['path']['models'], - f'net_d_{resume_iter}.pth') - logger.info( - f"Set pretrain_model_d to {opt['path']['pretrain_model_d']}") +def get_time_str(): + return time.strftime('%Y%m%d_%H%M%S', time.localtime()) def mkdir_and_rename(path): @@ -51,7 +34,7 @@ def mkdir_and_rename(path): new_name = path + '_archived_' + get_time_str() print(f'Path already exists. Rename it to {new_name}', flush=True) os.rename(path, new_name) - mmcv.mkdir_or_exist(path) + os.makedirs(path, exist_ok=True) @master_only @@ -62,101 +45,81 @@ def make_exp_dirs(opt): mkdir_and_rename(path_opt.pop('experiments_root')) else: mkdir_and_rename(path_opt.pop('results_root')) - path_opt.pop('strict_load') for key, path in path_opt.items(): - if 'pretrain_model' not in key and 'resume' not in key: - mmcv.mkdir_or_exist(path) - - -def set_random_seed(seed): - """Set random seeds.""" - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) + if ('strict_load' not in key) and ('pretrain_network' + not in key) and ('resume' + not in key): + os.makedirs(path, exist_ok=True) -def crop_border(imgs, crop_border): - """Crop borders of images. +def scandir(dir_path, suffix=None, recursive=False, full_path=False): + """Scan a directory to find the interested files. Args: - imgs (list[ndarray] | ndarray): Images with shape (h, w, c). - crop_border (int): Crop border for each end of height and weight. + dir_path (str): Path of the directory. + suffix (str | tuple(str), optional): File suffix that we are + interested in. Default: None. 
+ recursive (bool, optional): If set to True, recursively scan the + directory. Default: False. + full_path (bool, optional): If set to True, include the dir_path. + Default: False. Returns: - list[ndarray]: Cropped images. + A generator for all the interested files with relative pathes. """ - if crop_border == 0: - return imgs - else: - if isinstance(imgs, list): - return [ - v[crop_border:-crop_border, crop_border:-crop_border, ...] - for v in imgs - ] - else: - return imgs[crop_border:-crop_border, crop_border:-crop_border, - ...] + if (suffix is not None) and not isinstance(suffix, (str, tuple)): + raise TypeError('"suffix" must be a string or tuple of strings') -def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)): - """Convert torch Tensors into image numpy arrays. + root = dir_path - After clamping to [min, max], values will be normalized to [0, 1]. + def _scandir(dir_path, suffix, recursive): + for entry in os.scandir(dir_path): + if not entry.name.startswith('.') and entry.is_file(): + if full_path: + return_path = entry.path + else: + return_path = osp.relpath(entry.path, root) - Args: - tensor (Tensor or list[Tensor]): Accept shapes: - 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W); - 2) 3D Tensor of shape (3/1 x H x W); - 3) 2D Tensor of shape (H x W). - Tensor channel should be in RGB order. - out_type (numpy type): output types. If ``np.uint8``, transform outputs - to uint8 type with range [0, 255]; otherwise, float type with - range [0, 1]. Default: ``np.uint8``. - min_max (tuple[int]): min and max values for clamp. + if suffix is None: + yield return_path + elif return_path.endswith(suffix): + yield return_path + else: + if recursive: + yield from _scandir( + entry.path, suffix=suffix, recursive=recursive) + else: + continue - Returns: - (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of - shape (H x W). The channel order is BGR. + return _scandir(dir_path, suffix=suffix, recursive=recursive) + + +def check_resume(opt, resume_iter): + """Check resume states and pretrain_network paths. + + Args: + opt (dict): Options. + resume_iter (int): Resume iteration. """ - if not (torch.is_tensor(tensor) or - (isinstance(tensor, list) - and all(torch.is_tensor(t) for t in tensor))): - raise TypeError( - f'tensor or list of tensors expected, got {type(tensor)}') - - if torch.is_tensor(tensor): - tensor = [tensor] - result = [] - for _tensor in tensor: - _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max) - _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0]) - - n_dim = _tensor.dim() - if n_dim == 4: - img_np = make_grid( - _tensor, nrow=int(math.sqrt(_tensor.size(0))), - normalize=False).numpy() - img_np = np.transpose(img_np[[2, 1, 0], :, :], - (1, 2, 0)) # HWC, BGR - elif n_dim == 3: - img_np = _tensor.numpy() - img_np = np.transpose(img_np[[2, 1, 0], :, :], - (1, 2, 0)) # HWC, BGR - elif n_dim == 2: - img_np = _tensor.numpy() - else: - raise TypeError('Only support 4D, 3D or 2D tensor. ' - f'But received with dimension: {n_dim}') - if out_type == np.uint8: - # Unlike MATLAB, numpy.unit8() WILL NOT round by default. 
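# Usage sketch (illustrative, not part of the diff): scandir replaces
# mmcv.scandir throughout this series and returns a generator, so the scripts
# below wrap it in sorted() or list(). Folder names are placeholders.
from basicsr.utils import scandir

png_paths = sorted(scandir('datasets/DIV2K/DIV2K_train_HR_sub', suffix='png'))
all_imgs = list(
    scandir('datasets', suffix=('.png', '.jpg'), recursive=True,
            full_path=True))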
- img_np = (img_np * 255.0).round() - img_np = img_np.astype(out_type) - result.append(img_np) - if len(result) == 1: - result = result[0] - return result + logger = get_root_logger() + if opt['path']['resume_state']: + # get all the networks + networks = [key for key in opt.keys() if key.startswith('network_')] + flag_pretrain = False + for network in networks: + if opt['path'].get(f'pretrain_{network}') is not None: + flag_pretrain = True + if flag_pretrain: + logger.warning( + 'pretrain_network path will be ignored during resuming.') + # set pretrained model paths + for network in networks: + name = f'pretrain_{network}' + basename = network.replace('network_', '') + opt['path'][name] = osp.join(opt['path']['models'], + f'net_{basename}_{resume_iter}.pth') + logger.info(f"Set {name} to {opt['path'][name]}") class ProgressBar(object): diff --git a/docs/Config.md b/docs/Config.md index f2a0775..5a3b04f 100644 --- a/docs/Config.md +++ b/docs/Config.md @@ -127,11 +127,11 @@ network_g: ######################################################### path: # Path for pretrained models, usually end with pth - pretrain_model_g: ~ + pretrain_network_g: ~ # Whether to load pretrained models strictly, that is the corresponding parameter names should be the same - strict_load: true + strict_load_g: true # Path for resume state. Usually in the `experiments/exp_name/training_states` folder - # This argument will over-write the `pretrain_model_g` + # This argument will over-write the `pretrain_network_g` resume_state: ~ @@ -302,9 +302,9 @@ network_g: ################################################# path: ## Path for pretrained models, usually end with pth - pretrain_model_g: experiments/001_MSRResNet_x4_f64b16_DIV2K_1000k_B16G1_wandb/models/net_g_1000000.pth + pretrain_network_g: experiments/001_MSRResNet_x4_f64b16_DIV2K_1000k_B16G1_wandb/models/net_g_1000000.pth # Whether to load pretrained models strictly, that is the corresponding parameter names should be the same - strict_load: true + strict_load_g: true ########################################################## # The following are validation settings (Also for testing) diff --git a/docs/Config_CN.md b/docs/Config_CN.md index 6fa159d..6517110 100644 --- a/docs/Config_CN.md +++ b/docs/Config_CN.md @@ -126,11 +126,11 @@ network_g: ###################################### path: # 预训练模型的路径, 需要以pth结尾的模型 - pretrain_model_g: ~ + pretrain_network_g: ~ # 加载预训练模型的时候, 是否需要网络参数的名称严格对应 - strict_load: true + strict_load_g: true # 重启训练的状态路径, 一般在`experiments/exp_name/training_states`目录下 - # 这个设置了, 会覆盖 pretrain_model_g 的设定 + # 这个设置了, 会覆盖 pretrain_network_g 的设定 resume_state: ~ @@ -299,9 +299,9 @@ network_g: ############################# path: # 预训练模型的路径, 需要以pth结尾的模型 - pretrain_model_g: experiments/001_MSRResNet_x4_f64b16_DIV2K_1000k_B16G1_wandb/models/net_g_1000000.pth + pretrain_network_g: experiments/001_MSRResNet_x4_f64b16_DIV2K_1000k_B16G1_wandb/models/net_g_1000000.pth # 加载预训练模型的时候, 是否需要网络参数的名称严格对应 - strict_load: true + strict_load_g: true ################################## # 以下为Validation (也是测试)的设置 diff --git a/docs/DatasetPreparation.md b/docs/DatasetPreparation.md index b579f58..31a9616 100644 --- a/docs/DatasetPreparation.md +++ b/docs/DatasetPreparation.md @@ -24,7 +24,7 @@ At present, there are three types of data storage formats supported: 1. Store in `hard disk` directly in the format of images / video frames. 1. Make [LMDB](https://lmdb.readthedocs.io/en/release/), which could accelerate the IO and decompression speed during training. -1. 
[memcached](https://memcached.org/) or [CEPH](https://ceph.io/) are also supported, if they are installed (usually on clusters). +1. [memcached](https://memcached.org/) is also supported, if they are installed (usually on clusters). #### How to Use diff --git a/docs/DatasetPreparation_CN.md b/docs/DatasetPreparation_CN.md index 7600256..2582422 100644 --- a/docs/DatasetPreparation_CN.md +++ b/docs/DatasetPreparation_CN.md @@ -24,7 +24,7 @@ 1. 直接以图像/视频帧的格式存放在硬盘 2. 制作成 [LMDB](https://lmdb.readthedocs.io/en/release/). 训练数据使用这种形式, 一般会加快读取速度. -3. 若是支持 [Memcached](https://memcached.org/) 或 [Ceph](https://ceph.io/), 则可以使用. 它们一般应用在集群上. +3. 若是支持 [Memcached](https://memcached.org/), 则可以使用. 它们一般应用在集群上. #### 如何使用 diff --git a/docs/DesignConvention.md b/docs/DesignConvention.md index 35ee55f..10d737a 100644 --- a/docs/DesignConvention.md +++ b/docs/DesignConvention.md @@ -34,7 +34,7 @@ Specifically, we implement it through `importlib` and `getattr`. Taking the data # scan all the files under the data folder with '_dataset' in file names data_folder = osp.dirname(osp.abspath(__file__)) dataset_filenames = [ - osp.splitext(osp.basename(v))[0] for v in mmcv.scandir(data_folder) + osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py') ] # import all the dataset modules diff --git a/docs/DesignConvention_CN.md b/docs/DesignConvention_CN.md index d3c16d3..536d6a6 100644 --- a/docs/DesignConvention_CN.md +++ b/docs/DesignConvention_CN.md @@ -36,7 +36,7 @@ # scan all the files under the data folder with '_dataset' in file names data_folder = osp.dirname(osp.abspath(__file__)) dataset_filenames = [ - osp.splitext(osp.basename(v))[0] for v in mmcv.scandir(data_folder) + osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py') ] # import all the dataset modules diff --git a/options/test/DUF/test_DUF_official.yml b/options/test/DUF/test_DUF_official.yml index 5d16dd2..0710058 100644 --- a/options/test/DUF/test_DUF_official.yml +++ b/options/test/DUF/test_DUF_official.yml @@ -28,8 +28,8 @@ network_g: # path path: - pretrain_model_g: experiments/pretrained_models/DUF_x4_52L_official-483d2c78.pth - strict_load: true + pretrain_network_g: experiments/pretrained_models/DUF_x4_52L_official-483d2c78.pth + strict_load_g: true # validation settings val: diff --git a/options/test/EDSR/test_EDSR_Lx2.yml b/options/test/EDSR/test_EDSR_Lx2.yml index 05a1398..1aa4ba8 100644 --- a/options/test/EDSR/test_EDSR_Lx2.yml +++ b/options/test/EDSR/test_EDSR_Lx2.yml @@ -43,8 +43,8 @@ network_g: # path path: - pretrain_model_g: experiments/pretrained_models/EDSR_Lx2_f256b32_DIV2K_official-be38e77d.pth - strict_load: true + pretrain_network_g: experiments/pretrained_models/EDSR_Lx2_f256b32_DIV2K_official-be38e77d.pth + strict_load_g: true # validation settings val: diff --git a/options/test/EDSR/test_EDSR_Lx3.yml b/options/test/EDSR/test_EDSR_Lx3.yml index c7c951c..0a04329 100644 --- a/options/test/EDSR/test_EDSR_Lx3.yml +++ b/options/test/EDSR/test_EDSR_Lx3.yml @@ -43,8 +43,8 @@ network_g: # path path: - pretrain_model_g: experiments/pretrained_models/EDSR_Lx3_f256b32_DIV2K_official-3660f70d.pth - strict_load: true + pretrain_network_g: experiments/pretrained_models/EDSR_Lx3_f256b32_DIV2K_official-3660f70d.pth + strict_load_g: true # validation settings val: diff --git a/options/test/EDSR/test_EDSR_Lx4.yml b/options/test/EDSR/test_EDSR_Lx4.yml index e9a55e0..371c08e 100644 --- a/options/test/EDSR/test_EDSR_Lx4.yml +++ b/options/test/EDSR/test_EDSR_Lx4.yml @@ -43,8 
+43,8 @@ network_g: # path path: - pretrain_model_g: experiments/pretrained_models/EDSR_Lx4_f256b32_DIV2K_official-76ee1c8f.pth - strict_load: true + pretrain_network_g: experiments/pretrained_models/EDSR_Lx4_f256b32_DIV2K_official-76ee1c8f.pth + strict_load_g: true # validation settings val: diff --git a/options/test/EDSR/test_EDSR_Mx2.yml b/options/test/EDSR/test_EDSR_Mx2.yml index f18dae7..1d07e2c 100644 --- a/options/test/EDSR/test_EDSR_Mx2.yml +++ b/options/test/EDSR/test_EDSR_Mx2.yml @@ -43,8 +43,8 @@ network_g: # path path: - pretrain_model_g: experiments/pretrained_models/EDSR_Mx2_f64b16_DIV2K_official-3ba7b086.pth - strict_load: true + pretrain_network_g: experiments/pretrained_models/EDSR_Mx2_f64b16_DIV2K_official-3ba7b086.pth + strict_load_g: true # validation settings val: diff --git a/options/test/EDSR/test_EDSR_Mx3.yml b/options/test/EDSR/test_EDSR_Mx3.yml index 612f213..08319da 100644 --- a/options/test/EDSR/test_EDSR_Mx3.yml +++ b/options/test/EDSR/test_EDSR_Mx3.yml @@ -43,8 +43,8 @@ network_g: # path path: - pretrain_model_g: experiments/pretrained_models/EDSR_Mx3_f64b16_DIV2K_official-6908f88a.pth - strict_load: true + pretrain_network_g: experiments/pretrained_models/EDSR_Mx3_f64b16_DIV2K_official-6908f88a.pth + strict_load_g: true # validation settings val: diff --git a/options/test/EDSR/test_EDSR_Mx4.yml b/options/test/EDSR/test_EDSR_Mx4.yml index 0d52ef1..744c77e 100644 --- a/options/test/EDSR/test_EDSR_Mx4.yml +++ b/options/test/EDSR/test_EDSR_Mx4.yml @@ -43,8 +43,8 @@ network_g: # path path: - pretrain_model_g: experiments/pretrained_models/EDSR_Mx4_f64b16_DIV2K_official-0c287733.pth - strict_load: true + pretrain_network_g: experiments/pretrained_models/EDSR_Mx4_f64b16_DIV2K_official-0c287733.pth + strict_load_g: true # validation settings val: diff --git a/options/test/EDVR/test_EDVR_L_deblur_REDS.yml b/options/test/EDVR/test_EDVR_L_deblur_REDS.yml index 1576fc1..e008b99 100644 --- a/options/test/EDVR/test_EDVR_L_deblur_REDS.yml +++ b/options/test/EDVR/test_EDVR_L_deblur_REDS.yml @@ -35,8 +35,8 @@ network_g: # path path: - pretrain_model_g: experiments/pretrained_models/EDVR_L_deblur_REDS_official-ca46bd8c.pth - strict_load: true + pretrain_network_g: experiments/pretrained_models/EDVR_L_deblur_REDS_official-ca46bd8c.pth + strict_load_g: true # validation settings val: diff --git a/options/test/EDVR/test_EDVR_L_deblurcomp_REDS.yml b/options/test/EDVR/test_EDVR_L_deblurcomp_REDS.yml index fbb243d..a233e39 100644 --- a/options/test/EDVR/test_EDVR_L_deblurcomp_REDS.yml +++ b/options/test/EDVR/test_EDVR_L_deblurcomp_REDS.yml @@ -35,8 +35,8 @@ network_g: # path path: - pretrain_model_g: experiments/pretrained_models/EDVR_L_deblurcomp_REDS_official-0e988e5c.pth - strict_load: true + pretrain_network_g: experiments/pretrained_models/EDVR_L_deblurcomp_REDS_official-0e988e5c.pth + strict_load_g: true # validation settings val: diff --git a/options/test/EDVR/test_EDVR_L_x4_SR_REDS.yml b/options/test/EDVR/test_EDVR_L_x4_SR_REDS.yml index bd75815..b0e7470 100644 --- a/options/test/EDVR/test_EDVR_L_x4_SR_REDS.yml +++ b/options/test/EDVR/test_EDVR_L_x4_SR_REDS.yml @@ -35,8 +35,8 @@ network_g: # path path: - pretrain_model_g: experiments/pretrained_models/EDVR_L_x4_SR_REDS_official-9f5f5039.pth - strict_load: true + pretrain_network_g: experiments/pretrained_models/EDVR_L_x4_SR_REDS_official-9f5f5039.pth + strict_load_g: true # validation settings val: diff --git a/options/test/EDVR/test_EDVR_L_x4_SR_Vid4.yml b/options/test/EDVR/test_EDVR_L_x4_SR_Vid4.yml index 7428355..10181bf 
100644 --- a/options/test/EDVR/test_EDVR_L_x4_SR_Vid4.yml +++ b/options/test/EDVR/test_EDVR_L_x4_SR_Vid4.yml @@ -34,8 +34,8 @@ network_g: # path path: - pretrain_model_g: experiments/pretrained_models/EDVR_L_x4_SR_Vimeo90K_official-162b54e4.pth - strict_load: true + pretrain_network_g: experiments/pretrained_models/EDVR_L_x4_SR_Vimeo90K_official-162b54e4.pth + strict_load_g: true # validation settings val: diff --git a/options/test/EDVR/test_EDVR_L_x4_SR_Vimeo90K.yml b/options/test/EDVR/test_EDVR_L_x4_SR_Vimeo90K.yml index 21cf0bf..dc652d1 100644 --- a/options/test/EDVR/test_EDVR_L_x4_SR_Vimeo90K.yml +++ b/options/test/EDVR/test_EDVR_L_x4_SR_Vimeo90K.yml @@ -35,8 +35,8 @@ network_g: # path path: - pretrain_model_g: experiments/pretrained_models/EDVR_L_x4_SR_Vimeo90K_official-162b54e4.pth - strict_load: true + pretrain_network_g: experiments/pretrained_models/EDVR_L_x4_SR_Vimeo90K_official-162b54e4.pth + strict_load_g: true # validation settings val: diff --git a/options/test/EDVR/test_EDVR_L_x4_SRblur_REDS.yml b/options/test/EDVR/test_EDVR_L_x4_SRblur_REDS.yml index ed4ed55..8a287bb 100644 --- a/options/test/EDVR/test_EDVR_L_x4_SRblur_REDS.yml +++ b/options/test/EDVR/test_EDVR_L_x4_SRblur_REDS.yml @@ -35,8 +35,8 @@ network_g: # path path: - pretrain_model_g: experiments/pretrained_models/EDVR_L_x4_SRblur_REDS_official-983d7b8e.pth - strict_load: true + pretrain_network_g: experiments/pretrained_models/EDVR_L_x4_SRblur_REDS_official-983d7b8e.pth + strict_load_g: true # validation settings val: diff --git a/options/test/EDVR/test_EDVR_M_x4_SR_REDS.yml b/options/test/EDVR/test_EDVR_M_x4_SR_REDS.yml index 95271f8..0e92575 100644 --- a/options/test/EDVR/test_EDVR_M_x4_SR_REDS.yml +++ b/options/test/EDVR/test_EDVR_M_x4_SR_REDS.yml @@ -35,8 +35,8 @@ network_g: # path path: - pretrain_model_g: experiments/pretrained_models/EDVR_M_x4_SR_REDS_official-32075921.pth - strict_load: true + pretrain_network_g: experiments/pretrained_models/EDVR_M_x4_SR_REDS_official-32075921.pth + strict_load_g: true # validation settings val: diff --git a/options/test/ESRGAN/test_ESRGAN_x4.yml b/options/test/ESRGAN/test_ESRGAN_x4.yml index 1d23fb8..13327e5 100644 --- a/options/test/ESRGAN/test_ESRGAN_x4.yml +++ b/options/test/ESRGAN/test_ESRGAN_x4.yml @@ -40,8 +40,8 @@ network_g: # path path: - pretrain_model_g: experiments/pretrained_models/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth - strict_load: true + pretrain_network_g: experiments/pretrained_models/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth + strict_load_g: true # validation settings val: diff --git a/options/test/ESRGAN/test_ESRGAN_x4_woGT.yml b/options/test/ESRGAN/test_ESRGAN_x4_woGT.yml index 997381d..5637a19 100644 --- a/options/test/ESRGAN/test_ESRGAN_x4_woGT.yml +++ b/options/test/ESRGAN/test_ESRGAN_x4_woGT.yml @@ -29,8 +29,8 @@ network_g: # path path: - pretrain_model_g: experiments/pretrained_models/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth - strict_load: true + pretrain_network_g: experiments/pretrained_models/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth + strict_load_g: true # validation settings val: diff --git a/options/test/ESRGAN/test_RRDBNet_PSNR_x4.yml b/options/test/ESRGAN/test_RRDBNet_PSNR_x4.yml index 7c39a50..9904e20 100644 --- a/options/test/ESRGAN/test_RRDBNet_PSNR_x4.yml +++ b/options/test/ESRGAN/test_RRDBNet_PSNR_x4.yml @@ -40,8 +40,8 @@ network_g: # path path: - pretrain_model_g: experiments/pretrained_models/ESRGAN_PSNR_SRx4_DF2K_official-150ff491.pth - strict_load: true + pretrain_network_g: 
experiments/pretrained_models/ESRGAN_PSNR_SRx4_DF2K_official-150ff491.pth + strict_load_g: true # validation settings val: diff --git a/options/test/RCAN/test_RCAN.yml b/options/test/RCAN/test_RCAN.yml index 7734d91..78ea747 100644 --- a/options/test/RCAN/test_RCAN.yml +++ b/options/test/RCAN/test_RCAN.yml @@ -49,5 +49,5 @@ save_img: true # path path: - pretrain_model_g: ./experiments/pretrained_models/RCAN_BIX4-official.pth - strict_load: true + pretrain_network_g: ./experiments/pretrained_models/RCAN_BIX4-official.pth + strict_load_g: true diff --git a/options/test/SRResNet_SRGAN/test_MSRGAN_x4.yml b/options/test/SRResNet_SRGAN/test_MSRGAN_x4.yml index 517150c..5fef091 100644 --- a/options/test/SRResNet_SRGAN/test_MSRGAN_x4.yml +++ b/options/test/SRResNet_SRGAN/test_MSRGAN_x4.yml @@ -40,8 +40,8 @@ network_g: # path path: - pretrain_model_g: experiments/004_MSRGAN_x4_f64b16_DIV2K_400k_B16G1_wandb/models/net_g_400000.pth - strict_load: true + pretrain_network_g: experiments/004_MSRGAN_x4_f64b16_DIV2K_400k_B16G1_wandb/models/net_g_400000.pth + strict_load_g: true # validation settings val: diff --git a/options/test/SRResNet_SRGAN/test_MSRResNet_x2.yml b/options/test/SRResNet_SRGAN/test_MSRResNet_x2.yml index 29c09f8..d76411d 100644 --- a/options/test/SRResNet_SRGAN/test_MSRResNet_x2.yml +++ b/options/test/SRResNet_SRGAN/test_MSRResNet_x2.yml @@ -40,8 +40,8 @@ network_g: # path path: - pretrain_model_g: experiments/002_MSRResNet_x2_f64b16_DIV2K_1000k_B16G1_001pretrain_wandb/models/net_g_1000000.pth - strict_load: true + pretrain_network_g: experiments/002_MSRResNet_x2_f64b16_DIV2K_1000k_B16G1_001pretrain_wandb/models/net_g_1000000.pth + strict_load_g: true # validation settings val: diff --git a/options/test/SRResNet_SRGAN/test_MSRResNet_x3.yml b/options/test/SRResNet_SRGAN/test_MSRResNet_x3.yml index 91b4e7f..0e8dc78 100644 --- a/options/test/SRResNet_SRGAN/test_MSRResNet_x3.yml +++ b/options/test/SRResNet_SRGAN/test_MSRResNet_x3.yml @@ -40,8 +40,8 @@ network_g: # path path: - pretrain_model_g: experiments/003_MSRResNet_x3_f64b16_DIV2K_1000k_B16G1_001pretrain_wandb/models/net_g_1000000.pth - strict_load: true + pretrain_network_g: experiments/003_MSRResNet_x3_f64b16_DIV2K_1000k_B16G1_001pretrain_wandb/models/net_g_1000000.pth + strict_load_g: true # validation settings val: diff --git a/options/test/SRResNet_SRGAN/test_MSRResNet_x4.yml b/options/test/SRResNet_SRGAN/test_MSRResNet_x4.yml index c5b0e32..ce5e1cf 100644 --- a/options/test/SRResNet_SRGAN/test_MSRResNet_x4.yml +++ b/options/test/SRResNet_SRGAN/test_MSRResNet_x4.yml @@ -40,8 +40,8 @@ network_g: # path path: - pretrain_model_g: experiments/001_MSRResNet_x4_f64b16_DIV2K_1000k_B16G1_wandb/models/net_g_1000000.pth - strict_load: true + pretrain_network_g: experiments/001_MSRResNet_x4_f64b16_DIV2K_1000k_B16G1_wandb/models/net_g_1000000.pth + strict_load_g: true # validation settings val: diff --git a/options/test/SRResNet_SRGAN/test_MSRResNet_x4_woGT.yml b/options/test/SRResNet_SRGAN/test_MSRResNet_x4_woGT.yml index 8e499cf..cdc8ea7 100644 --- a/options/test/SRResNet_SRGAN/test_MSRResNet_x4_woGT.yml +++ b/options/test/SRResNet_SRGAN/test_MSRResNet_x4_woGT.yml @@ -29,8 +29,8 @@ network_g: # path path: - pretrain_model_g: experiments/001_MSRResNet_x4_f64b16_DIV2K_1000k_B16G1_wandb/models/net_g_1000000.pth - strict_load: true + pretrain_network_g: experiments/001_MSRResNet_x4_f64b16_DIV2K_1000k_B16G1_wandb/models/net_g_1000000.pth + strict_load_g: true # validation settings val: diff --git a/options/test/TOF/test_TOF_official.yml 
b/options/test/TOF/test_TOF_official.yml index ab916c7..9206634 100644 --- a/options/test/TOF/test_TOF_official.yml +++ b/options/test/TOF/test_TOF_official.yml @@ -26,8 +26,8 @@ save_img: true # path path: - pretrain_model_g: experiments/pretrained_models/tof_official-e81c455f.pth - strict_load: true + pretrain_network_g: experiments/pretrained_models/tof_official-e81c455f.pth + strict_load_g: true # validation settings val: diff --git a/options/train/EDSR/train_EDSR_Lx2.yml b/options/train/EDSR/train_EDSR_Lx2.yml index da645b7..bb3167e 100644 --- a/options/train/EDSR/train_EDSR_Lx2.yml +++ b/options/train/EDSR/train_EDSR_Lx2.yml @@ -54,8 +54,8 @@ network_g: # path path: - pretrain_model_g: ~ - strict_load: true + pretrain_network_g: ~ + strict_load_g: true resume_state: ~ # training settings diff --git a/options/train/EDSR/train_EDSR_Lx3.yml b/options/train/EDSR/train_EDSR_Lx3.yml index 7b6ae45..326d95e 100644 --- a/options/train/EDSR/train_EDSR_Lx3.yml +++ b/options/train/EDSR/train_EDSR_Lx3.yml @@ -54,8 +54,8 @@ network_g: # path path: - pretrain_model_g: experiments/204_EDSR_Lx2_f256b32_DIV2K_300k_B16G1_wandb/models/net_g_300000.pth - strict_load: false + pretrain_network_g: experiments/204_EDSR_Lx2_f256b32_DIV2K_300k_B16G1_wandb/models/net_g_300000.pth + strict_load_g: false resume_state: ~ # training settings diff --git a/options/train/EDSR/train_EDSR_Lx4.yml b/options/train/EDSR/train_EDSR_Lx4.yml index 6fe945c..ffd3a60 100644 --- a/options/train/EDSR/train_EDSR_Lx4.yml +++ b/options/train/EDSR/train_EDSR_Lx4.yml @@ -54,8 +54,8 @@ network_g: # path path: - pretrain_model_g: experiments/204_EDSR_Lx2_f256b32_DIV2K_300k_B16G1_wandb/models/net_g_300000.pth - strict_load: false + pretrain_network_g: experiments/204_EDSR_Lx2_f256b32_DIV2K_300k_B16G1_wandb/models/net_g_300000.pth + strict_load_g: false resume_state: ~ # training settings diff --git a/options/train/EDSR/train_EDSR_Mx2.yml b/options/train/EDSR/train_EDSR_Mx2.yml index 37410f0..b8c81f9 100644 --- a/options/train/EDSR/train_EDSR_Mx2.yml +++ b/options/train/EDSR/train_EDSR_Mx2.yml @@ -54,8 +54,8 @@ network_g: # path path: - pretrain_model_g: ~ - strict_load: true + pretrain_network_g: ~ + strict_load_g: true resume_state: ~ # training settings diff --git a/options/train/EDSR/train_EDSR_Mx3.yml b/options/train/EDSR/train_EDSR_Mx3.yml index 7f473a0..bd44e87 100644 --- a/options/train/EDSR/train_EDSR_Mx3.yml +++ b/options/train/EDSR/train_EDSR_Mx3.yml @@ -54,8 +54,8 @@ network_g: # path path: - pretrain_model_g: experiments/201_EDSR_Mx2_f64b16_DIV2K_300k_B16G1_wandb/models/net_g_300000.pth - strict_load: false + pretrain_network_g: experiments/201_EDSR_Mx2_f64b16_DIV2K_300k_B16G1_wandb/models/net_g_300000.pth + strict_load_g: false resume_state: ~ # training settings diff --git a/options/train/EDSR/train_EDSR_Mx4.yml b/options/train/EDSR/train_EDSR_Mx4.yml index aa12b57..0f5e583 100644 --- a/options/train/EDSR/train_EDSR_Mx4.yml +++ b/options/train/EDSR/train_EDSR_Mx4.yml @@ -54,8 +54,8 @@ network_g: # path path: - pretrain_model_g: experiments/201_EDSR_Mx2_f64b16_DIV2K_300k_B16G1_wandb/models/net_g_300000.pth - strict_load: false + pretrain_network_g: experiments/201_EDSR_Mx2_f64b16_DIV2K_300k_B16G1_wandb/models/net_g_300000.pth + strict_load_g: false resume_state: ~ # training settings diff --git a/options/train/EDVR/train_EDVRM_woTSA_GAN_TODO.yml b/options/train/EDVR/train_EDVRM_woTSA_GAN_TODO.yml index 0623d4c..c59be62 100644 --- a/options/train/EDVR/train_EDVRM_woTSA_GAN_TODO.yml +++ 
b/options/train/EDVR/train_EDVRM_woTSA_GAN_TODO.yml @@ -71,8 +71,8 @@ network_d: # path path: - pretrain_model_g: experiments/101_EDVR_M_x4_SR_REDS_woTSA_600k_B4G8_valREDS4_wandb/models/net_g_600000.pth - strict_load: true + pretrain_network_g: experiments/101_EDVR_M_x4_SR_REDS_woTSA_600k_B4G8_valREDS4_wandb/models/net_g_600000.pth + strict_load_g: true resume_state: ~ # training settings diff --git a/options/train/EDVR/train_EDVR_L_x4_SR_REDS.yml b/options/train/EDVR/train_EDVR_L_x4_SR_REDS.yml index d0bb472..bcc6418 100644 --- a/options/train/EDVR/train_EDVR_L_x4_SR_REDS.yml +++ b/options/train/EDVR/train_EDVR_L_x4_SR_REDS.yml @@ -63,8 +63,8 @@ network_g: # path path: - pretrain_model_g: experiments/103_EDVR_L_x4_SR_REDS_woTSA_600k_B4G8_valREDS4_wandb/models/net_g_600000.pth - strict_load: false + pretrain_network_g: experiments/103_EDVR_L_x4_SR_REDS_woTSA_600k_B4G8_valREDS4_wandb/models/net_g_600000.pth + strict_load_g: false resume_state: ~ # training settings diff --git a/options/train/EDVR/train_EDVR_L_x4_SR_REDS_woTSA.yml b/options/train/EDVR/train_EDVR_L_x4_SR_REDS_woTSA.yml index becefe2..32645a4 100644 --- a/options/train/EDVR/train_EDVR_L_x4_SR_REDS_woTSA.yml +++ b/options/train/EDVR/train_EDVR_L_x4_SR_REDS_woTSA.yml @@ -63,8 +63,8 @@ network_g: # path path: - pretrain_model_g: ~ - strict_load: true + pretrain_network_g: ~ + strict_load_g: true resume_state: ~ # training settings diff --git a/options/train/EDVR/train_EDVR_M_x4_SR_REDS.yml b/options/train/EDVR/train_EDVR_M_x4_SR_REDS.yml index c463310..d79c8d3 100644 --- a/options/train/EDVR/train_EDVR_M_x4_SR_REDS.yml +++ b/options/train/EDVR/train_EDVR_M_x4_SR_REDS.yml @@ -63,8 +63,8 @@ network_g: # path path: - pretrain_model_g: experiments/101_EDVR_M_x4_SR_REDS_woTSA_600k_B4G8_valREDS4_wandb/models/net_g_600000.pth - strict_load: false + pretrain_network_g: experiments/101_EDVR_M_x4_SR_REDS_woTSA_600k_B4G8_valREDS4_wandb/models/net_g_600000.pth + strict_load_g: false resume_state: ~ # training settings diff --git a/options/train/EDVR/train_EDVR_M_x4_SR_REDS_woTSA.yml b/options/train/EDVR/train_EDVR_M_x4_SR_REDS_woTSA.yml index bb8dba3..75552e9 100644 --- a/options/train/EDVR/train_EDVR_M_x4_SR_REDS_woTSA.yml +++ b/options/train/EDVR/train_EDVR_M_x4_SR_REDS_woTSA.yml @@ -63,8 +63,8 @@ network_g: # path path: - pretrain_model_g: ~ - strict_load: true + pretrain_network_g: ~ + strict_load_g: true resume_state: ~ # training settings diff --git a/options/train/ESRGAN/train_ESRGAN_x4.yml b/options/train/ESRGAN/train_ESRGAN_x4.yml index 23acff2..7d8faba 100644 --- a/options/train/ESRGAN/train_ESRGAN_x4.yml +++ b/options/train/ESRGAN/train_ESRGAN_x4.yml @@ -55,8 +55,8 @@ network_d: # path path: - pretrain_model_g: experiments/051_RRDBNet_PSNR_x4_f64b23_DIV2K_1000k_B16G1_wandb/models/net_g_1000000.pth - strict_load: true + pretrain_network_g: experiments/051_RRDBNet_PSNR_x4_f64b23_DIV2K_1000k_B16G1_wandb/models/net_g_1000000.pth + strict_load_g: true resume_state: ~ # training settings diff --git a/options/train/ESRGAN/train_RRDBNet_PSNR_x4.yml b/options/train/ESRGAN/train_RRDBNet_PSNR_x4.yml index a4ede70..c5882c8 100644 --- a/options/train/ESRGAN/train_RRDBNet_PSNR_x4.yml +++ b/options/train/ESRGAN/train_RRDBNet_PSNR_x4.yml @@ -51,8 +51,8 @@ network_g: # path path: - pretrain_model_g: ~ - strict_load: true + pretrain_network_g: ~ + strict_load_g: true resume_state: ~ # training settings diff --git a/options/train/RCAN/train_RCAN_x2.yml b/options/train/RCAN/train_RCAN_x2.yml index c525c0d..531b142 100644 --- 
a/options/train/RCAN/train_RCAN_x2.yml +++ b/options/train/RCAN/train_RCAN_x2.yml @@ -57,8 +57,8 @@ network_g: # path path: - pretrain_model_g: ~ - strict_load: true + pretrain_network_g: ~ + strict_load_g: true resume_state: ~ # training settings diff --git a/options/train/SRResNet_SRGAN/train_MSRGAN_x4.yml b/options/train/SRResNet_SRGAN/train_MSRGAN_x4.yml index 978b28e..a0d4ace 100644 --- a/options/train/SRResNet_SRGAN/train_MSRGAN_x4.yml +++ b/options/train/SRResNet_SRGAN/train_MSRGAN_x4.yml @@ -60,8 +60,8 @@ network_d: # path path: - pretrain_model_g: experiments/001_MSRResNet_x4_f64b16_DIV2K_1000k_B16G1_wandb/models/net_g_1000000.pth - strict_load: true + pretrain_network_g: experiments/001_MSRResNet_x4_f64b16_DIV2K_1000k_B16G1_wandb/models/net_g_1000000.pth + strict_load_g: true resume_state: ~ # training settings diff --git a/options/train/SRResNet_SRGAN/train_MSRResNet_x2.yml b/options/train/SRResNet_SRGAN/train_MSRResNet_x2.yml index f7e6014..3688a1a 100644 --- a/options/train/SRResNet_SRGAN/train_MSRResNet_x2.yml +++ b/options/train/SRResNet_SRGAN/train_MSRResNet_x2.yml @@ -54,8 +54,8 @@ network_g: # path path: - pretrain_model_g: experiments/001_MSRResNet_x4_f64b16_DIV2K_1000k_B16G1_wandb/models/net_g_1000000.pth - strict_load: false + pretrain_network_g: experiments/001_MSRResNet_x4_f64b16_DIV2K_1000k_B16G1_wandb/models/net_g_1000000.pth + strict_load_g: false resume_state: ~ # training settings diff --git a/options/train/SRResNet_SRGAN/train_MSRResNet_x3.yml b/options/train/SRResNet_SRGAN/train_MSRResNet_x3.yml index 9b94d29..5c414ad 100644 --- a/options/train/SRResNet_SRGAN/train_MSRResNet_x3.yml +++ b/options/train/SRResNet_SRGAN/train_MSRResNet_x3.yml @@ -54,8 +54,8 @@ network_g: # path path: - pretrain_model_g: experiments/001_MSRResNet_x4_f64b16_DIV2K_1000k_B16G1_wandb/models/net_g_1000000.pth - strict_load: false + pretrain_network_g: experiments/001_MSRResNet_x4_f64b16_DIV2K_1000k_B16G1_wandb/models/net_g_1000000.pth + strict_load_g: false resume_state: ~ # training settings diff --git a/options/train/SRResNet_SRGAN/train_MSRResNet_x4.yml b/options/train/SRResNet_SRGAN/train_MSRResNet_x4.yml index 647b334..1fa782f 100644 --- a/options/train/SRResNet_SRGAN/train_MSRResNet_x4.yml +++ b/options/train/SRResNet_SRGAN/train_MSRResNet_x4.yml @@ -54,8 +54,8 @@ network_g: # path path: - pretrain_model_g: ~ - strict_load: true + pretrain_network_g: ~ + strict_load_g: true resume_state: ~ # training settings diff --git a/options/train/StyleGAN/train_StyleGAN2_256_Cmul2_FFHQ.yml b/options/train/StyleGAN/train_StyleGAN2_256_Cmul2_FFHQ.yml index 00b77ba..e112d44 100644 --- a/options/train/StyleGAN/train_StyleGAN2_256_Cmul2_FFHQ.yml +++ b/options/train/StyleGAN/train_StyleGAN2_256_Cmul2_FFHQ.yml @@ -42,8 +42,8 @@ network_d: # path path: - pretrain_model_g: ~ - strict_load: true + pretrain_network_g: ~ + strict_load_g: true resume_state: ~ # training settings diff --git a/requirements.txt b/requirements.txt index 8202611..8f169cd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,10 +2,11 @@ addict future lmdb matplotlib -mmcv>=0.6 numpy opencv-python +Pillow pyyaml +requests scikit-image scipy tb-nightly diff --git a/scripts/calculate_psnr_ssim.py b/scripts/calculate_psnr_ssim.py index 7e802d1..9aff9c9 100644 --- a/scripts/calculate_psnr_ssim.py +++ b/scripts/calculate_psnr_ssim.py @@ -1,8 +1,10 @@ -import mmcv +import cv2 import numpy as np from os import path as osp from basicsr.metrics import calculate_psnr, calculate_ssim +from basicsr.utils import scandir +from 
basicsr.utils.matlab_functions import bgr2ycbcr def main(): @@ -27,7 +29,7 @@ def main(): psnr_all = [] ssim_all = [] - img_list = sorted(mmcv.scandir(folder_gt, recursive=True)) + img_list = sorted(scandir(folder_gt, recursive=True, full_path=True)) if test_y_channel: print('Testing Y channel.') @@ -36,16 +38,15 @@ def main(): for i, img_path in enumerate(img_list): basename, ext = osp.splitext(osp.basename(img_path)) - img_gt = mmcv.imread( - osp.join(folder_gt, img_path), flag='unchanged').astype( - np.float32) / 255. - img_restored = mmcv.imread( + img_gt = cv2.imread(img_path, cv2.IMREAD_UNCHANGED).astype( + np.float32) / 255. + img_restored = cv2.imread( osp.join(folder_restored, basename + suffix + ext), - flag='unchanged').astype(np.float32) / 255. + cv2.IMREAD_UNCHANGED).astype(np.float32) / 255. if test_y_channel and img_gt.ndim == 3 and img_gt.shape[2] == 3: - img_gt = mmcv.bgr2ycbcr(img_gt, y_only=True) - img_restored = mmcv.bgr2ycbcr(img_restored, y_only=True) + img_gt = bgr2ycbcr(img_gt, y_only=True) + img_restored = bgr2ycbcr(img_restored, y_only=True) # calculate PSNR and SSIM psnr = calculate_psnr( diff --git a/scripts/create_lmdb.py b/scripts/create_lmdb.py index 4fa359b..6c9787b 100644 --- a/scripts/create_lmdb.py +++ b/scripts/create_lmdb.py @@ -1,6 +1,6 @@ -import mmcv from os import path as osp +from basicsr.utils import scandir from basicsr.utils.lmdb import make_lmdb_from_imgs @@ -53,7 +53,7 @@ def prepare_keys_div2k(folder_path): """ print('Reading image path list ...') img_path_list = sorted( - list(mmcv.scandir(folder_path, suffix='png', recursive=False))) + list(scandir(folder_path, suffix='png', recursive=False))) keys = [img_path.split('.png')[0] for img_path in sorted(img_path_list)] return img_path_list, keys @@ -96,7 +96,7 @@ def prepare_keys_reds(folder_path): """ print('Reading image path list ...') img_path_list = sorted( - list(mmcv.scandir(folder_path, suffix='png', recursive=True))) + list(scandir(folder_path, suffix='png', recursive=True))) keys = [v.split('.png')[0] for v in img_path_list] # example: 000/00000000 return img_path_list, keys diff --git a/scripts/download_pretrained_models.py b/scripts/download_pretrained_models.py index cc26218..3e5dd3d 100644 --- a/scripts/download_pretrained_models.py +++ b/scripts/download_pretrained_models.py @@ -1,5 +1,5 @@ import argparse -import mmcv +import os from os import path as osp from basicsr.utils.download import download_file_from_google_drive @@ -7,7 +7,7 @@ def download_pretrained_models(method, file_ids): save_path_root = f'./experiments/pretrained_models/{method}' - mmcv.mkdir_or_exist(save_path_root) + os.makedirs(save_path_root, exist_ok=True) for file_name, file_id in file_ids.items(): save_path = osp.abspath(osp.join(save_path_root, file_name)) diff --git a/scripts/extract_subimages.py b/scripts/extract_subimages.py index 6cf06b1..e2b2af2 100644 --- a/scripts/extract_subimages.py +++ b/scripts/extract_subimages.py @@ -1,12 +1,11 @@ import cv2 -import mmcv import numpy as np import os import sys from multiprocessing import Pool from os import path as osp -from basicsr.utils.util import ProgressBar +from basicsr.utils.util import ProgressBar, scandir def main(): @@ -94,8 +93,7 @@ def extract_subimages(opt): print(f'Folder {save_folder} already exists. 
Exit.') sys.exit(1) - img_list = list(mmcv.scandir(input_folder)) - img_list = [osp.join(input_folder, v) for v in img_list] + img_list = list(scandir(input_folder, full_path=True)) pbar = ProgressBar(len(img_list)) pool = Pool(opt['n_thread']) diff --git a/scripts/generate_meta_info.py b/scripts/generate_meta_info.py index 22d851e..7bb1aed 100644 --- a/scripts/generate_meta_info.py +++ b/scripts/generate_meta_info.py @@ -1,7 +1,8 @@ -import mmcv from os import path as osp from PIL import Image +from basicsr.utils import scandir + def generate_meta_info_div2k(): """Generate meta info for DIV2K dataset. @@ -10,7 +11,7 @@ def generate_meta_info_div2k(): gt_folder = 'datasets/DIV2K/DIV2K_train_HR_sub/' meta_info_txt = 'basicsr/data/meta_info/meta_info_DIV2K800sub_GT.txt' - img_list = sorted(list(mmcv.scandir(gt_folder))) + img_list = sorted(list(scandir(gt_folder))) with open(meta_info_txt, 'w') as f: for idx, img_path in enumerate(img_list): diff --git a/setup.cfg b/setup.cfg index dccb00b..78b88e2 100644 --- a/setup.cfg +++ b/setup.cfg @@ -16,6 +16,6 @@ line_length = 79 multi_line_output = 0 known_standard_library = pkg_resources,setuptools known_first_party = basicsr -known_third_party = PIL,cv2,lmdb,matplotlib,mmcv,numpy,requests,scipy,skimage,torch,torchvision,yaml +known_third_party = PIL,cv2,lmdb,matplotlib,numpy,requests,scipy,skimage,torch,torchvision,yaml no_lines_before = STDLIB,LOCALFOLDER default_section = THIRDPARTY diff --git a/tests/test_face_dfdnet.py b/test_scripts/test_face_dfdnet.py similarity index 97% rename from tests/test_face_dfdnet.py rename to test_scripts/test_face_dfdnet.py index ab2d6db..3e44848 100644 --- a/tests/test_face_dfdnet.py +++ b/test_scripts/test_face_dfdnet.py @@ -1,7 +1,6 @@ import argparse import cv2 import glob -import mmcv import numpy as np import os import torch @@ -10,7 +9,7 @@ from skimage import transform as trans from basicsr.models.archs.dfdnet_arch import DFDNet -from basicsr.utils import tensor2img +from basicsr.utils import imwrite, tensor2img try: import dlib @@ -116,7 +115,8 @@ def warp_crop_faces(self, save_cropped_path=None): if save_cropped_path is not None: path, ext = os.path.splitext(save_cropped_path) save_path = f'{path}_{idx:02d}{ext}' - mmcv.imwrite(mmcv.rgb2bgr(cropped_face), save_path) + imwrite( + cv2.cvtColor(cropped_face, cv2.COLOR_RGB2BGR), save_path) # get inverse affine matrix self.similarity_trans.estimate(self.face_template, @@ -129,7 +129,7 @@ def add_restored_face(self, face): def paste_faces_to_input_image(self, save_path): # operate in the BGR order - input_img = mmcv.rgb2bgr(self.input_img) + input_img = cv2.cvtColor(self.input_img, cv2.COLOR_RGB2BGR) h, w, _ = input_img.shape h_up, w_up = h * self.upscale_factor, w * self.upscale_factor # simply resize the background @@ -158,7 +158,7 @@ def paste_faces_to_input_image(self, save_path): (blur_size + 1, blur_size + 1), 0) upsample_img = inv_soft_mask * inv_restored_remove_border + ( 1 - inv_soft_mask) * upsample_img - mmcv.imwrite(upsample_img.astype(np.uint8), save_path) + imwrite(upsample_img.astype(np.uint8), save_path) def clean_all(self): self.all_landmarks_5 = [] @@ -339,7 +339,7 @@ def get_part_location(landmarks): path, ext = os.path.splitext( os.path.join(save_restore_root, img_name)) save_path = f'{path}_{idx:02d}{ext}' - mmcv.imwrite(im, save_path) + imwrite(im, save_path) face_helper.add_restored_face(im) print('\tGenerate the final result ...') diff --git a/tests/test_stylegan2.py b/test_scripts/test_stylegan2.py similarity index 97% rename from 
tests/test_stylegan2.py rename to test_scripts/test_stylegan2.py index c166e64..b93a3ce 100644 --- a/tests/test_stylegan2.py +++ b/test_scripts/test_stylegan2.py @@ -1,6 +1,6 @@ import argparse import math -import mmcv +import os import torch from torchvision import utils @@ -52,7 +52,7 @@ def generate(args, g_ema, device, mean_latent, randomize_noise): args.latent = 512 args.n_mlp = 8 - mmcv.mkdir_or_exist('samples') + os.makedirs('samples', exist_ok=True) set_random_seed(2020) g_ema = StyleGAN2Generator( diff --git a/tests/test_ffhq_dataset.py b/tests/test_ffhq_dataset.py index 5486385..655e402 100644 --- a/tests/test_ffhq_dataset.py +++ b/tests/test_ffhq_dataset.py @@ -1,5 +1,5 @@ import math -import mmcv +import os import torch import torchvision.utils @@ -29,7 +29,7 @@ def main(): opt['dataset_enlarge_ratio'] = 1 - mmcv.mkdir_or_exist('tmp') + os.makedirs('tmp', exist_ok=True) dataset = create_dataset(opt) data_loader = create_dataloader( diff --git a/tests/test_paired_image_dataset.py b/tests/test_paired_image_dataset.py index 3c415a3..a133a36 100644 --- a/tests/test_paired_image_dataset.py +++ b/tests/test_paired_image_dataset.py @@ -1,5 +1,5 @@ import math -import mmcv +import os import torchvision.utils from basicsr.data import create_dataloader, create_dataset @@ -44,7 +44,7 @@ def main(mode='folder'): opt['dataset_enlarge_ratio'] = 1 - mmcv.mkdir_or_exist('tmp') + os.makedirs('tmp', exist_ok=True) dataset = create_dataset(opt) data_loader = create_dataloader( diff --git a/tests/test_reds_dataset.py b/tests/test_reds_dataset.py index 7863fe0..cbf23a6 100644 --- a/tests/test_reds_dataset.py +++ b/tests/test_reds_dataset.py @@ -1,5 +1,5 @@ import math -import mmcv +import os import torchvision.utils from basicsr.data import create_dataloader, create_dataset @@ -45,7 +45,7 @@ def main(mode='folder'): opt['dataset_enlarge_ratio'] = 1 - mmcv.mkdir_or_exist('tmp') + os.makedirs('tmp', exist_ok=True) dataset = create_dataset(opt) data_loader = create_dataloader( diff --git a/tests/test_vimeo90k_dataset.py b/tests/test_vimeo90k_dataset.py index 8a9661a..80bb45a 100644 --- a/tests/test_vimeo90k_dataset.py +++ b/tests/test_vimeo90k_dataset.py @@ -1,5 +1,5 @@ import math -import mmcv +import os import torchvision.utils from basicsr.data import create_dataloader, create_dataset @@ -41,7 +41,7 @@ def main(mode='folder'): opt['dataset_enlarge_ratio'] = 1 - mmcv.mkdir_or_exist('tmp') + os.makedirs('tmp', exist_ok=True) dataset = create_dataset(opt) data_loader = create_dataloader( From 2e0068029d901bd1504811a364b460c124098f53 Mon Sep 17 00:00:00 2001 From: Xintao Date: Sun, 4 Oct 2020 00:25:22 +0800 Subject: [PATCH 04/23] add basicsr.png --- assets/basicsr.png | Bin 0 -> 16896 bytes 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 assets/basicsr.png diff --git a/assets/basicsr.png b/assets/basicsr.png new file mode 100644 index 0000000000000000000000000000000000000000..cb3577048b0c98d3c48d4dec891e3b7c2e5568f7 GIT binary patch literal 16896 zcmchYlq z*~CveLVppv%j&snIa|AXnYmeEXgS%sn0<7&GNK$=LK6vG6!br0V34@~_rYAYruRl) zj8RwCmStjlGdw(8Qc`k#eXXRew@bowWls3--#;`RKl2Zt3^ z)jIDDzNBSv^9lO;2a=JJy1IK=SXxd^PyhPegHbry+4;Sxxg})zI=`^Q-pNJr86F1X zn*_S&y~hs%gBSM;OZzMW-4f645cC+z=8w(bv+v>#?%r43;*(W=$#3{tFfd+XD9T7_ zdt)ACM~7~()m0rNtG%`F7}RfiTc15OC=8is-hRkE%6}1k{Yp~di`H6|DkJAB{8cTP z8d&X@cm+wOmn6rZyMAsM3hSVMJ}0Qa1&!m&U@>5F6X1h8elfFt!jZwk0@`9TU?v2o z;^2Uc*AtUu(ew~_1Nw?W*E=+o+qJ*{Ke{AzJ{S$T%8aI_wrpCW>5|A-j{hShF<7$u zouOOZec3G5RfC;YJn< zfwT1lvERGvzsSkMYVaK&oI$~uH}}pIj?!C}5EN+4kNC7MiR8VDW59wj@o>Wo)l!p9 
[... base85-encoded GIT binary patch data for assets/basicsr.png (16896 bytes) omitted ...]

From f33864a7538496a6ded8aa1d75d68df6a91f5193 Mon Sep 17
00:00:00 2001 From: Xintao Date: Mon, 5 Oct 2020 00:22:12 +0800 Subject: [PATCH 05/23] add colab --- README.md | 45 +++++++++--- README_CN.md | 45 +++++++++--- basicsr/models/archs/arch_util.py | 11 ++- basicsr/models/archs/stylegan2_arch.py | 9 ++- basicsr/models/sr_model.py | 9 ++- basicsr/models/video_base_model.py | 13 ++-- basicsr/utils/__init__.py | 7 +- basicsr/utils/download.py | 13 ++-- basicsr/utils/lmdb.py | 15 ++-- basicsr/utils/util.py | 61 ---------------- colab/README.md | 7 ++ options/test/DUF/test_DUF_official.yml | 2 +- options/test/EDSR/test_EDSR_Lx2.yml | 2 +- options/test/EDSR/test_EDSR_Lx3.yml | 2 +- options/test/EDSR/test_EDSR_Lx4.yml | 2 +- options/test/EDSR/test_EDSR_Mx2.yml | 2 +- options/test/EDSR/test_EDSR_Mx3.yml | 2 +- options/test/EDSR/test_EDSR_Mx4.yml | 2 +- options/test/EDVR/test_EDVR_L_deblur_REDS.yml | 2 +- .../test/EDVR/test_EDVR_L_deblurcomp_REDS.yml | 2 +- options/test/EDVR/test_EDVR_L_x4_SR_REDS.yml | 2 +- options/test/EDVR/test_EDVR_L_x4_SR_Vid4.yml | 2 +- .../test/EDVR/test_EDVR_L_x4_SR_Vimeo90K.yml | 2 +- .../test/EDVR/test_EDVR_L_x4_SRblur_REDS.yml | 2 +- options/test/EDVR/test_EDVR_M_x4_SR_REDS.yml | 2 +- options/test/ESRGAN/test_ESRGAN_x4.yml | 2 +- options/test/ESRGAN/test_ESRGAN_x4_woGT.yml | 2 +- options/test/ESRGAN/test_RRDBNet_PSNR_x4.yml | 2 +- options/test/RCAN/test_RCAN.yml | 2 +- options/test/TOF/test_TOF_official.yml | 2 +- requirements.txt | 1 + scripts/download_datasets.py | 71 +++++++++++++++++++ scripts/download_pretrained_models.py | 2 +- scripts/extract_subimages.py | 8 ++- scripts/publish_models.py | 3 +- setup.cfg | 2 +- setup.py | 47 ++++++------ test_scripts/test_esrgan.py | 58 +++++++++++++++ test_scripts/test_stylegan2.py | 2 +- 39 files changed, 312 insertions(+), 155 deletions(-) create mode 100644 colab/README.md create mode 100644 scripts/download_datasets.py create mode 100644 test_scripts/test_esrgan.py diff --git a/README.md b/README.md index d17f7bb..027b579 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,7 @@ [English](README.md) **|** [简体中文](README_CN.md)   [GitHub](https://github.com/xinntao/BasicSR) **|** [Gitee码云](https://gitee.com/xinntao/BasicSR) +google colab logo Google Colab: [GitHub Link](colab) **|** [Google Drive Link](https://drive.google.com/drive/folders/1G_qcpvkT5ixmw5XoN6MupkOzcK1km625?usp=sharing)
:arrow_double_down: Google Drive: [Pretrained Models](https://drive.google.com/drive/folders/15DgDtfaLASQ3iAPJEVHQF49g9msexECG?usp=sharing) **|** [Reproduced Experiments](https://drive.google.com/drive/folders/1XN4WXKJ53KQ0Cu0Yv-uCt8DZWq6uufaP?usp=sharing) :arrow_double_down: 百度网盘: [预训练模型](https://pan.baidu.com/s/1R6Nc4v3cl79XPAiK0Toe7g) **|** [复现实验](https://pan.baidu.com/s/1UElD6q8sVAgn_cxeBDOlvQ)
:chart_with_upwards_trend: [Training curves in wandb](https://app.wandb.ai/xintao/basicsr)
@@ -16,13 +17,11 @@ BasicSR is an **open source** image and video super-resolution toolbox based on ## :sparkles: New Feature - Sep 8, 2020. Add **blind face restoration inference codes: [DFDNet](https://github.com/csxmli2016/DFDNet)**. Note that it is slightly different from the official testing codes. - > Blind Face Restoration via Deep Multi-scale Component Dictionaries
+ > ECCV20: Blind Face Restoration via Deep Multi-scale Component Dictionaries
> Xiaoming Li, Chaofeng Chen, Shangchen Zhou, Xianhui Lin, Wangmeng Zuo and Lei Zhang
- > European Conference on Computer Vision (ECCV), 2020 - Aug 27, 2020. Add **StyleGAN2 training and testing** codes: [StyleGAN2](https://github.com/rosinality/stylegan2-pytorch). - > Analyzing and Improving the Image Quality of StyleGAN
+ > CVPR20: Analyzing and Improving the Image Quality of StyleGAN
> Tero Karras, Samuli Laine, Miika Aittala, Janne Hellsten, Jaakko Lehtinen and Timo Aila
- > Computer Vision and Pattern Recognition (CVPR), 2020
More @@ -46,13 +45,37 @@ These pipelines/commands cannot cover all the cases and more details are in the - [PyTorch >= 1.3](https://pytorch.org/) - NVIDIA GPU + [CUDA](https://developer.nvidia.com/cuda-downloads) -Please run the following commands in the **BasicSR root path** to install BasicSR:
-(Make sure that your GCC version: gcc >= 5) +1. Clone repo -```bash -pip install -r requirements.txt -python setup.py develop -``` + ```bash + git clone https://github.com/xinntao/BasicSR.git + ``` + +1. Install dependent packages + + ```bash + cd BasicSR + pip install -r requirements.txt + ``` + +1. Install BasicSR + + Please run the following commands in the **BasicSR root path** to install BasicSR:
+ (Make sure that your GCC version: gcc >= 5)
+ If you do not need the cuda extensions:
+  [*dcn* for EDVR](basicsr/models/ops)
+  [*upfirdn2d* and *fused_act* for StyleGAN2](basicsr/models/ops)
+ please add `--no_cuda_ext` when installing + + ```bash + python setup.py develop --no_cuda_ext + ``` + + If you use the EDVR and StyleGAN2 model, the above cuda extensions are necessary. + + ```bash + python setup.py develop + ``` Note that BasicSR is only tested in Ubuntu, and may be not suitable for Windows. You may try [Windows WSL with CUDA supports](https://docs.microsoft.com/en-us/windows/win32/direct3d12/gpu-cuda-in-wsl) :-) (It is now only available for insider build with Fast ring). @@ -71,7 +94,7 @@ Please see [project boards](https://github.com/xinntao/BasicSR/projects). - **Options/Configs**: Please refer to [Config.md](docs/Config.md). - **Logging**: Please refer to [Logging.md](docs/Logging.md). -## :card_file_box: Model Zoo and Baselines +## :european_castle: Model Zoo and Baselines - The descriptions of currently supported models are in [Models.md](docs/Models.md). - **Pre-trained models and log examples** are available in **[ModelZoo.md](docs/ModelZoo.md)**. diff --git a/README_CN.md b/README_CN.md index f7afe8a..c2b8b55 100644 --- a/README_CN.md +++ b/README_CN.md @@ -2,6 +2,7 @@ [English](README.md) **|** [简体中文](README_CN.md)   [GitHub](https://github.com/xinntao/BasicSR) **|** [Gitee码云](https://gitee.com/xinntao/BasicSR) +google colab logo Google Colab: [GitHub Link](colab) **|** [Google Drive Link](https://drive.google.com/drive/folders/1G_qcpvkT5ixmw5XoN6MupkOzcK1km625?usp=sharing)
:arrow_double_down: 百度网盘: [预训练模型](https://pan.baidu.com/s/1R6Nc4v3cl79XPAiK0Toe7g) **|** [复现实验](https://pan.baidu.com/s/1UElD6q8sVAgn_cxeBDOlvQ) :arrow_double_down: Google Drive: [Pretrained Models](https://drive.google.com/drive/folders/15DgDtfaLASQ3iAPJEVHQF49g9msexECG?usp=sharing) **|** [Reproduced Experiments](https://drive.google.com/drive/folders/1XN4WXKJ53KQ0Cu0Yv-uCt8DZWq6uufaP?usp=sharing)
:chart_with_upwards_trend: [wandb的训练曲线](https://app.wandb.ai/xintao/basicsr)
@@ -16,13 +17,11 @@ BasicSR 是一个基于 PyTorch 的**开源**图像视频超分辨率 (Super-Res ## :sparkles: 新的特性 - Sep 8, 2020. 添加 **盲人脸复原推理代码: [DFDNet](https://github.com/csxmli2016/DFDNet)**. 注意和官方代码有些微差异. - > Blind Face Restoration via Deep Multi-scale Component Dictionaries
+ > ECCV20: Blind Face Restoration via Deep Multi-scale Component Dictionaries
> Xiaoming Li, Chaofeng Chen, Shangchen Zhou, Xianhui Lin, Wangmeng Zuo and Lei Zhang
- > European Conference on Computer Vision (ECCV), 2020 - Aug 27, 2020. 添加 **StyleGAN2 训练和测试** 代码: [StyleGAN2](https://github.com/rosinality/stylegan2-pytorch). - > Analyzing and Improving the Image Quality of StyleGAN
+ > CVPR20: Analyzing and Improving the Image Quality of StyleGAN
> Tero Karras, Samuli Laine, Miika Aittala, Janne Hellsten, Jaakko Lehtinen and Timo Aila
- > Computer Vision and Pattern Recognition (CVPR), 2020
更多 @@ -45,13 +44,37 @@ BasicSR 是一个基于 PyTorch 的**开源**图像视频超分辨率 (Super-Res - [PyTorch >= 1.3](https://pytorch.org/) - NVIDIA GPU + [CUDA](https://developer.nvidia.com/cuda-downloads) -在BasicSR的**根目录**下运行以下命令:
-(确保 GCC 版本: gcc >= 5) +1. Clone repo -```bash -pip install -r requirements.txt -python setup.py develop -``` + ```bash + git clone https://github.com/xinntao/BasicSR.git + ``` + +1. 安装依赖包 + + ```bash + cd BasicSR + pip install -r requirements.txt + ``` + +1. 安装 BasicSR + + 在BasicSR的**根目录**下运行以下命令:
+ (确保 GCC 版本: gcc >= 5)
+ 如果你不需要以下 cuda 扩展算子:
+  [*dcn* for EDVR](basicsr/models/ops)
+  [*upfirdn2d* and *fused_act* for StyleGAN2](basicsr/models/ops)
+ 在安装命令后添加 `--no_cuda_ext` + + ```bash + python setup.py develop --no_cuda_ext + ``` + + 如果使用 EDVR 和 StyleGAN2 模型, 则需要使用上面的 cuda 扩展算子. + + ```bash + python setup.py develop + ``` 注意: BasicSR 仅在 Ubuntu 下进行测试,或许不支持Windows. 可以在Windows下尝试[支持CUDA的Windows WSL](https://docs.microsoft.com/en-us/windows/win32/direct3d12/gpu-cuda-in-wsl) :-) (目前只有Fast ring的预览版系统可以安装). @@ -70,7 +93,7 @@ python setup.py develop - **Options/Configs**配置文件的说明, 参见 [Config_CN.md](docs/Config_CN.md). - **Logging**日志系统的说明, 参见 [Logging_CN.md](docs/Logging_CN.md). -## :card_file_box: 模型库和基准 +## :european_castle: 模型库和基准 - 目前支持的模型描述, 参见 [Models_CN.md](docs/Models_CN.md). - **预训练模型和log样例**, 参见 **[ModelZoo_CN.md](docs/ModelZoo_CN.md)**. diff --git a/basicsr/models/archs/arch_util.py b/basicsr/models/archs/arch_util.py index 9aebf99..961fcd4 100644 --- a/basicsr/models/archs/arch_util.py +++ b/basicsr/models/archs/arch_util.py @@ -5,10 +5,17 @@ from torch.nn import init as init from torch.nn.modules.batchnorm import _BatchNorm -from basicsr.models.ops.dcn import (ModulatedDeformConvPack, - modulated_deform_conv) from basicsr.utils import get_root_logger +try: + from basicsr.models.ops.dcn import (ModulatedDeformConvPack, + modulated_deform_conv) +except ImportError: + print('Cannot import dcn. Ignore this warning if dcn is not used. ' + 'Otherwise install BasicSR with compiling dcn.') + ModulatedDeformConvPack = object + modulated_deform_conv = None + @torch.no_grad() def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs): diff --git a/basicsr/models/archs/stylegan2_arch.py b/basicsr/models/archs/stylegan2_arch.py index 26c2ea3..2b308b4 100644 --- a/basicsr/models/archs/stylegan2_arch.py +++ b/basicsr/models/archs/stylegan2_arch.py @@ -4,8 +4,13 @@ from torch import nn from torch.nn import functional as F -from basicsr.models.ops.fused_act import FusedLeakyReLU, fused_leaky_relu -from basicsr.models.ops.upfirdn2d import upfirdn2d +try: + from basicsr.models.ops.fused_act import FusedLeakyReLU, fused_leaky_relu + from basicsr.models.ops.upfirdn2d import upfirdn2d +except ImportError: + print('Cannot import fused_act and upfirdn2d. Ignore this warning if ' + 'they are not used. 
Otherwise install BasicSR with compiling them.') + FusedLeakyReLU, fused_leaky_relu, upfirdn2d = None, None, None class NormStyleCode(nn.Module): diff --git a/basicsr/models/sr_model.py b/basicsr/models/sr_model.py index 32bb794..92bbf4b 100644 --- a/basicsr/models/sr_model.py +++ b/basicsr/models/sr_model.py @@ -3,10 +3,11 @@ from collections import OrderedDict from copy import deepcopy from os import path as osp +from tqdm import tqdm from basicsr.models.archs import define_network from basicsr.models.base_model import BaseModel -from basicsr.utils import ProgressBar, get_root_logger, imwrite, tensor2img +from basicsr.utils import get_root_logger, imwrite, tensor2img loss_module = importlib.import_module('basicsr.models.losses') metric_module = importlib.import_module('basicsr.metrics') @@ -130,7 +131,7 @@ def nondist_validation(self, dataloader, current_iter, tb_logger, metric: 0 for metric in self.opt['val']['metrics'].keys() } - pbar = ProgressBar(len(dataloader)) + pbar = tqdm(total=len(dataloader), unit='image') for idx, val_data in enumerate(dataloader): img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0] @@ -171,7 +172,9 @@ def nondist_validation(self, dataloader, current_iter, tb_logger, metric_type = opt_.pop('type') self.metric_results[name] += getattr( metric_module, metric_type)(sr_img, gt_img, **opt_) - pbar.update(f'Test {img_name}') + pbar.update(1) + pbar.set_description(f'Test {img_name}') + pbar.close() if with_metrics: for metric in self.metric_results.keys(): diff --git a/basicsr/models/video_base_model.py b/basicsr/models/video_base_model.py index cbacd99..203fb7f 100644 --- a/basicsr/models/video_base_model.py +++ b/basicsr/models/video_base_model.py @@ -4,9 +4,10 @@ from copy import deepcopy from os import path as osp from torch import distributed as dist +from tqdm import tqdm from basicsr.models.sr_model import SRModel -from basicsr.utils import ProgressBar, get_root_logger, imwrite, tensor2img +from basicsr.utils import get_root_logger, imwrite, tensor2img from basicsr.utils.dist_util import get_dist_info metric_module = importlib.import_module('basicsr.metrics') @@ -39,7 +40,7 @@ def dist_validation(self, dataloader, current_iter, tb_logger, save_img): tensor.zero_() # record all frames (border and center frames) if rank == 0: - pbar = ProgressBar(len(dataset)) + pbar = tqdm(total=len(dataset), unit='frame') for idx in range(rank, len(dataset), world_size): val_data = dataset[idx] val_data['lq'].unsqueeze_(0) @@ -97,8 +98,12 @@ def dist_validation(self, dataloader, current_iter, tb_logger, save_img): # progress bar if rank == 0: for _ in range(world_size): - pbar.update(f'Test {folder} - ' - f'{int(frame_idx) + world_size}/{max_idx}') + pbar.update(1) + pbar.set_description( + f'Test {folder}:' + f'{int(frame_idx) + world_size}/{max_idx}') + if rank == 0: + pbar.close() if with_metrics: if self.opt['dist']: diff --git a/basicsr/utils/__init__.py b/basicsr/utils/__init__.py index e547f67..554f433 100644 --- a/basicsr/utils/__init__.py +++ b/basicsr/utils/__init__.py @@ -2,8 +2,8 @@ from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img from .logger import (MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger) -from .util import (ProgressBar, check_resume, get_time_str, make_exp_dirs, - mkdir_and_rename, scandir, set_random_seed) +from .util import (check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, + scandir, set_random_seed) __all__ = [ # file_client.py @@ -26,6 +26,5 @@ 
'mkdir_and_rename', 'make_exp_dirs', 'scandir', - 'check_resume', - 'ProgressBar' + 'check_resume' ] diff --git a/basicsr/utils/download.py b/basicsr/utils/download.py index 3cd696c..07be86d 100644 --- a/basicsr/utils/download.py +++ b/basicsr/utils/download.py @@ -1,7 +1,6 @@ import math import requests - -from .util import ProgressBar +from tqdm import tqdm def download_file_from_google_drive(file_id, save_path): @@ -49,7 +48,8 @@ def save_response_content(response, file_size=None, chunk_size=32768): if file_size is not None: - pbar = ProgressBar(math.ceil(file_size / chunk_size)) + pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk') + readable_file_size = sizeof_fmt(file_size) else: pbar = None @@ -59,10 +59,13 @@ def save_response_content(response, for chunk in response.iter_content(chunk_size): downloaded_size += chunk_size if pbar is not None: - pbar.update(f'Downloading {sizeof_fmt(downloaded_size)} ' - f'/ {readable_file_size}') + pbar.update(1) + pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} ' + f'/ {readable_file_size}') if chunk: # filter out keep-alive new chunks f.write(chunk) + if pbar is not None: + pbar.close() def sizeof_fmt(size, suffix='B'): diff --git a/basicsr/utils/lmdb.py b/basicsr/utils/lmdb.py index c99f50b..a81278f 100644 --- a/basicsr/utils/lmdb.py +++ b/basicsr/utils/lmdb.py @@ -3,8 +3,7 @@ import sys from multiprocessing import Pool from os import path as osp - -from .util import ProgressBar +from tqdm import tqdm def make_lmdb_from_imgs(data_path, @@ -75,12 +74,13 @@ def make_lmdb_from_imgs(data_path, dataset = {} # use dict to keep the order for multiprocessing shapes = {} print(f'Read images with multiprocessing, #thread: {n_thread} ...') - pbar = ProgressBar(len(img_path_list)) + pbar = tqdm(total=len(img_path_list), unit='image') def callback(arg): """get the image data and update pbar.""" key, dataset[key], shapes[key] = arg - pbar.update('Reading {}'.format(key)) + pbar.update(1) + pbar.set_description(f'Read {key}') pool = Pool(n_thread) for path, key in zip(img_path_list, keys): @@ -90,6 +90,7 @@ def callback(arg): callback=callback) pool.close() pool.join() + pbar.close() print(f'Finish reading {len(img_path_list)} images.') # create lmdb environment @@ -107,11 +108,12 @@ def callback(arg): env = lmdb.open(lmdb_path, map_size=map_size) # write data to lmdb - pbar = ProgressBar(len(img_path_list)) + pbar = tqdm(total=len(img_path_list), unit='chunk') txn = env.begin(write=True) txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w') for idx, (path, key) in enumerate(zip(img_path_list, keys)): - pbar.update(f'Write {key}') + pbar.update(1) + pbar.set_description(f'Write {key}') key_byte = key.encode('ascii') if multiprocessing_read: img_byte = dataset[key] @@ -127,6 +129,7 @@ def callback(arg): if idx % batch == 0: txn.commit() txn = env.begin(write=True) + pbar.close() txn.commit() env.close() txt_file.close() diff --git a/basicsr/utils/util.py b/basicsr/utils/util.py index f0440e4..26f3370 100644 --- a/basicsr/utils/util.py +++ b/basicsr/utils/util.py @@ -1,11 +1,9 @@ import numpy as np import os import random -import sys import time import torch from os import path as osp -from shutil import get_terminal_size from .dist_util import master_only from .logger import get_root_logger @@ -120,62 +118,3 @@ def check_resume(opt, resume_iter): opt['path'][name] = osp.join(opt['path']['models'], f'net_{basename}_{resume_iter}.pth') logger.info(f"Set {name} to {opt['path'][name]}") - - -class ProgressBar(object): - """A 
progress bar that can print the progress. - - Modified from: - https://github.com/hellock/cvbase/blob/master/cvbase/progress.py - """ - - def __init__(self, task_num=0, bar_width=50, start=True): - self.task_num = task_num - max_bar_width = self._get_max_bar_width() - self.bar_width = ( - bar_width if bar_width <= max_bar_width else max_bar_width) - self.completed = 0 - if start: - self.start() - - def _get_max_bar_width(self): - terminal_width, _ = get_terminal_size() - max_bar_width = min(int(terminal_width * 0.6), terminal_width - 50) - if max_bar_width < 10: - print(f'terminal width is too small ({terminal_width}), ' - 'please consider widen the terminal for better ' - 'progressbar visualization') - max_bar_width = 10 - return max_bar_width - - def start(self): - if self.task_num > 0: - sys.stdout.write(f"[{' ' * self.bar_width}] 0/{self.task_num}, " - f'elapsed: 0s, ETA:\nStart...\n') - else: - sys.stdout.write('completed: 0, elapsed: 0s') - sys.stdout.flush() - self.start_time = time.time() - - def update(self, msg='In progress...'): - self.completed += 1 - elapsed = time.time() - self.start_time + 1e-8 - fps = self.completed / elapsed - if self.task_num > 0: - percentage = self.completed / float(self.task_num) - eta = int(elapsed * (1 - percentage) / percentage + 0.5) - mark_width = int(self.bar_width * percentage) - bar_chars = '>' * mark_width + '-' * (self.bar_width - mark_width) - sys.stdout.write('\033[2F') # cursor up 2 lines - sys.stdout.write( - '\033[J' - ) # clean the output (remove extra chars since last display) - sys.stdout.write( - f'[{bar_chars}] {self.completed}/{self.task_num}, ' - f'{fps:.1f} task/s, elapsed: {int(elapsed + 0.5)}s, ' - f'ETA: {eta:5}s\n{msg}\n') - else: - sys.stdout.write( - f'completed: {self.completed}, elapsed: {int(elapsed + 0.5)}s,' - f' {fps:.1f} tasks/s') - sys.stdout.flush() diff --git a/colab/README.md b/colab/README.md new file mode 100644 index 0000000..729eb3e --- /dev/null +++ b/colab/README.md @@ -0,0 +1,7 @@ +# Colab + +google colab logo + +To maintain a small size of BasicSR repo, we do not include the original colab notebooks in this repo, but provide links to the google colab. 
+ +- [BasicSR_inference_ESRGAN.ipynb](https://colab.research.google.com/drive/1JQScYICvEC3VqaabLu-lxvq9h7kSV1ML?usp=sharing) diff --git a/options/test/DUF/test_DUF_official.yml b/options/test/DUF/test_DUF_official.yml index 0710058..d0bc81c 100644 --- a/options/test/DUF/test_DUF_official.yml +++ b/options/test/DUF/test_DUF_official.yml @@ -28,7 +28,7 @@ network_g: # path path: - pretrain_network_g: experiments/pretrained_models/DUF_x4_52L_official-483d2c78.pth + pretrain_network_g: experiments/pretrained_models/DUF/DUF_x4_52L_official-483d2c78.pth strict_load_g: true # validation settings diff --git a/options/test/EDSR/test_EDSR_Lx2.yml b/options/test/EDSR/test_EDSR_Lx2.yml index 1aa4ba8..82dcb49 100644 --- a/options/test/EDSR/test_EDSR_Lx2.yml +++ b/options/test/EDSR/test_EDSR_Lx2.yml @@ -43,7 +43,7 @@ network_g: # path path: - pretrain_network_g: experiments/pretrained_models/EDSR_Lx2_f256b32_DIV2K_official-be38e77d.pth + pretrain_network_g: experiments/pretrained_models/EDSR/EDSR_Lx2_f256b32_DIV2K_official-be38e77d.pth strict_load_g: true # validation settings diff --git a/options/test/EDSR/test_EDSR_Lx3.yml b/options/test/EDSR/test_EDSR_Lx3.yml index 0a04329..6053ba6 100644 --- a/options/test/EDSR/test_EDSR_Lx3.yml +++ b/options/test/EDSR/test_EDSR_Lx3.yml @@ -43,7 +43,7 @@ network_g: # path path: - pretrain_network_g: experiments/pretrained_models/EDSR_Lx3_f256b32_DIV2K_official-3660f70d.pth + pretrain_network_g: experiments/pretrained_models/EDSR/EDSR_Lx3_f256b32_DIV2K_official-3660f70d.pth strict_load_g: true # validation settings diff --git a/options/test/EDSR/test_EDSR_Lx4.yml b/options/test/EDSR/test_EDSR_Lx4.yml index 371c08e..37bb209 100644 --- a/options/test/EDSR/test_EDSR_Lx4.yml +++ b/options/test/EDSR/test_EDSR_Lx4.yml @@ -43,7 +43,7 @@ network_g: # path path: - pretrain_network_g: experiments/pretrained_models/EDSR_Lx4_f256b32_DIV2K_official-76ee1c8f.pth + pretrain_network_g: experiments/pretrained_models/EDSR/EDSR_Lx4_f256b32_DIV2K_official-76ee1c8f.pth strict_load_g: true # validation settings diff --git a/options/test/EDSR/test_EDSR_Mx2.yml b/options/test/EDSR/test_EDSR_Mx2.yml index 1d07e2c..b6ab304 100644 --- a/options/test/EDSR/test_EDSR_Mx2.yml +++ b/options/test/EDSR/test_EDSR_Mx2.yml @@ -43,7 +43,7 @@ network_g: # path path: - pretrain_network_g: experiments/pretrained_models/EDSR_Mx2_f64b16_DIV2K_official-3ba7b086.pth + pretrain_network_g: experiments/pretrained_models/EDSR/EDSR_Mx2_f64b16_DIV2K_official-3ba7b086.pth strict_load_g: true # validation settings diff --git a/options/test/EDSR/test_EDSR_Mx3.yml b/options/test/EDSR/test_EDSR_Mx3.yml index 08319da..c799603 100644 --- a/options/test/EDSR/test_EDSR_Mx3.yml +++ b/options/test/EDSR/test_EDSR_Mx3.yml @@ -43,7 +43,7 @@ network_g: # path path: - pretrain_network_g: experiments/pretrained_models/EDSR_Mx3_f64b16_DIV2K_official-6908f88a.pth + pretrain_network_g: experiments/pretrained_models/EDSR/EDSR_Mx3_f64b16_DIV2K_official-6908f88a.pth strict_load_g: true # validation settings diff --git a/options/test/EDSR/test_EDSR_Mx4.yml b/options/test/EDSR/test_EDSR_Mx4.yml index 744c77e..2686861 100644 --- a/options/test/EDSR/test_EDSR_Mx4.yml +++ b/options/test/EDSR/test_EDSR_Mx4.yml @@ -43,7 +43,7 @@ network_g: # path path: - pretrain_network_g: experiments/pretrained_models/EDSR_Mx4_f64b16_DIV2K_official-0c287733.pth + pretrain_network_g: experiments/pretrained_models/EDSR/EDSR_Mx4_f64b16_DIV2K_official-0c287733.pth strict_load_g: true # validation settings diff --git a/options/test/EDVR/test_EDVR_L_deblur_REDS.yml 
b/options/test/EDVR/test_EDVR_L_deblur_REDS.yml index e008b99..6982ab8 100644 --- a/options/test/EDVR/test_EDVR_L_deblur_REDS.yml +++ b/options/test/EDVR/test_EDVR_L_deblur_REDS.yml @@ -35,7 +35,7 @@ network_g: # path path: - pretrain_network_g: experiments/pretrained_models/EDVR_L_deblur_REDS_official-ca46bd8c.pth + pretrain_network_g: experiments/pretrained_models/EDVR/EDVR_L_deblur_REDS_official-ca46bd8c.pth strict_load_g: true # validation settings diff --git a/options/test/EDVR/test_EDVR_L_deblurcomp_REDS.yml b/options/test/EDVR/test_EDVR_L_deblurcomp_REDS.yml index a233e39..4108a2a 100644 --- a/options/test/EDVR/test_EDVR_L_deblurcomp_REDS.yml +++ b/options/test/EDVR/test_EDVR_L_deblurcomp_REDS.yml @@ -35,7 +35,7 @@ network_g: # path path: - pretrain_network_g: experiments/pretrained_models/EDVR_L_deblurcomp_REDS_official-0e988e5c.pth + pretrain_network_g: experiments/pretrained_models/EDVR/EDVR_L_deblurcomp_REDS_official-0e988e5c.pth strict_load_g: true # validation settings diff --git a/options/test/EDVR/test_EDVR_L_x4_SR_REDS.yml b/options/test/EDVR/test_EDVR_L_x4_SR_REDS.yml index b0e7470..768c173 100644 --- a/options/test/EDVR/test_EDVR_L_x4_SR_REDS.yml +++ b/options/test/EDVR/test_EDVR_L_x4_SR_REDS.yml @@ -35,7 +35,7 @@ network_g: # path path: - pretrain_network_g: experiments/pretrained_models/EDVR_L_x4_SR_REDS_official-9f5f5039.pth + pretrain_network_g: experiments/pretrained_models/EDVR/EDVR_L_x4_SR_REDS_official-9f5f5039.pth strict_load_g: true # validation settings diff --git a/options/test/EDVR/test_EDVR_L_x4_SR_Vid4.yml b/options/test/EDVR/test_EDVR_L_x4_SR_Vid4.yml index 10181bf..9929067 100644 --- a/options/test/EDVR/test_EDVR_L_x4_SR_Vid4.yml +++ b/options/test/EDVR/test_EDVR_L_x4_SR_Vid4.yml @@ -34,7 +34,7 @@ network_g: # path path: - pretrain_network_g: experiments/pretrained_models/EDVR_L_x4_SR_Vimeo90K_official-162b54e4.pth + pretrain_network_g: experiments/pretrained_models/EDVR/EDVR_L_x4_SR_Vimeo90K_official-162b54e4.pth strict_load_g: true # validation settings diff --git a/options/test/EDVR/test_EDVR_L_x4_SR_Vimeo90K.yml b/options/test/EDVR/test_EDVR_L_x4_SR_Vimeo90K.yml index dc652d1..dff07d8 100644 --- a/options/test/EDVR/test_EDVR_L_x4_SR_Vimeo90K.yml +++ b/options/test/EDVR/test_EDVR_L_x4_SR_Vimeo90K.yml @@ -35,7 +35,7 @@ network_g: # path path: - pretrain_network_g: experiments/pretrained_models/EDVR_L_x4_SR_Vimeo90K_official-162b54e4.pth + pretrain_network_g: experiments/pretrained_models/EDVR/EDVR_L_x4_SR_Vimeo90K_official-162b54e4.pth strict_load_g: true # validation settings diff --git a/options/test/EDVR/test_EDVR_L_x4_SRblur_REDS.yml b/options/test/EDVR/test_EDVR_L_x4_SRblur_REDS.yml index 8a287bb..fbe2b1b 100644 --- a/options/test/EDVR/test_EDVR_L_x4_SRblur_REDS.yml +++ b/options/test/EDVR/test_EDVR_L_x4_SRblur_REDS.yml @@ -35,7 +35,7 @@ network_g: # path path: - pretrain_network_g: experiments/pretrained_models/EDVR_L_x4_SRblur_REDS_official-983d7b8e.pth + pretrain_network_g: experiments/pretrained_models/EDVR/EDVR_L_x4_SRblur_REDS_official-983d7b8e.pth strict_load_g: true # validation settings diff --git a/options/test/EDVR/test_EDVR_M_x4_SR_REDS.yml b/options/test/EDVR/test_EDVR_M_x4_SR_REDS.yml index 0e92575..773286d 100644 --- a/options/test/EDVR/test_EDVR_M_x4_SR_REDS.yml +++ b/options/test/EDVR/test_EDVR_M_x4_SR_REDS.yml @@ -35,7 +35,7 @@ network_g: # path path: - pretrain_network_g: experiments/pretrained_models/EDVR_M_x4_SR_REDS_official-32075921.pth + pretrain_network_g: 
experiments/pretrained_models/EDVR/EDVR_M_x4_SR_REDS_official-32075921.pth strict_load_g: true # validation settings diff --git a/options/test/ESRGAN/test_ESRGAN_x4.yml b/options/test/ESRGAN/test_ESRGAN_x4.yml index 13327e5..845789c 100644 --- a/options/test/ESRGAN/test_ESRGAN_x4.yml +++ b/options/test/ESRGAN/test_ESRGAN_x4.yml @@ -40,7 +40,7 @@ network_g: # path path: - pretrain_network_g: experiments/pretrained_models/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth + pretrain_network_g: experiments/pretrained_models/ESRGAN/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth strict_load_g: true # validation settings diff --git a/options/test/ESRGAN/test_ESRGAN_x4_woGT.yml b/options/test/ESRGAN/test_ESRGAN_x4_woGT.yml index 5637a19..d428740 100644 --- a/options/test/ESRGAN/test_ESRGAN_x4_woGT.yml +++ b/options/test/ESRGAN/test_ESRGAN_x4_woGT.yml @@ -29,7 +29,7 @@ network_g: # path path: - pretrain_network_g: experiments/pretrained_models/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth + pretrain_network_g: experiments/pretrained_models/ESRGAN/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth strict_load_g: true # validation settings diff --git a/options/test/ESRGAN/test_RRDBNet_PSNR_x4.yml b/options/test/ESRGAN/test_RRDBNet_PSNR_x4.yml index 9904e20..9636f22 100644 --- a/options/test/ESRGAN/test_RRDBNet_PSNR_x4.yml +++ b/options/test/ESRGAN/test_RRDBNet_PSNR_x4.yml @@ -40,7 +40,7 @@ network_g: # path path: - pretrain_network_g: experiments/pretrained_models/ESRGAN_PSNR_SRx4_DF2K_official-150ff491.pth + pretrain_network_g: experiments/pretrained_models/ESRGAN/ESRGAN_PSNR_SRx4_DF2K_official-150ff491.pth strict_load_g: true # validation settings diff --git a/options/test/RCAN/test_RCAN.yml b/options/test/RCAN/test_RCAN.yml index 78ea747..3f22dd2 100644 --- a/options/test/RCAN/test_RCAN.yml +++ b/options/test/RCAN/test_RCAN.yml @@ -49,5 +49,5 @@ save_img: true # path path: - pretrain_network_g: ./experiments/pretrained_models/RCAN_BIX4-official.pth + pretrain_network_g: ./experiments/pretrained_models/RCAN/RCAN_BIX4-official.pth strict_load_g: true diff --git a/options/test/TOF/test_TOF_official.yml b/options/test/TOF/test_TOF_official.yml index 9206634..f61dbaf 100644 --- a/options/test/TOF/test_TOF_official.yml +++ b/options/test/TOF/test_TOF_official.yml @@ -26,7 +26,7 @@ save_img: true # path path: - pretrain_network_g: experiments/pretrained_models/tof_official-e81c455f.pth + pretrain_network_g: experiments/pretrained_models/TOF/tof_official-e81c455f.pth strict_load_g: true # validation settings diff --git a/requirements.txt b/requirements.txt index 8f169cd..afd2ca1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,4 +12,5 @@ scipy tb-nightly torch>=1.3 torchvision +tqdm yapf diff --git a/scripts/download_datasets.py b/scripts/download_datasets.py new file mode 100644 index 0000000..bd4ebd9 --- /dev/null +++ b/scripts/download_datasets.py @@ -0,0 +1,71 @@ +import argparse +import glob +import os +from os import path as osp + +from basicsr.utils.download import download_file_from_google_drive + + +def download_dataset(dataset, file_ids): + save_path_root = './datasets/' + os.makedirs(save_path_root, exist_ok=True) + + for file_name, file_id in file_ids.items(): + save_path = osp.abspath(osp.join(save_path_root, file_name)) + if osp.exists(save_path): + user_response = input( + f'{file_name} already exist. Do you want to cover it? 
Y/N\n') + if user_response.lower() == 'y': + print(f'Covering {file_name} to {save_path}') + download_file_from_google_drive(file_id, save_path) + elif user_response.lower() == 'n': + print(f'Skipping {file_name}') + else: + raise ValueError('Wrong input. Only accpets Y/N.') + else: + print(f'Downloading {file_name} to {save_path}') + download_file_from_google_drive(file_id, save_path) + + # unzip + if save_path.endswith('.zip'): + extracted_path = save_path.replace('.zip', '') + print(f'Extract {save_path} to {extracted_path}') + import zipfile + with zipfile.ZipFile(save_path, 'r') as zip_ref: + zip_ref.extractall(extracted_path) + + file_name = file_name.replace('.zip', '') + subfolder = osp.join(extracted_path, file_name) + if osp.isdir(subfolder): + print(f'Move {subfolder} to {extracted_path}') + import shutil + for path in glob.glob(osp.join(subfolder, '*')): + shutil.move(path, extracted_path) + shutil.rmtree(subfolder) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + + parser.add_argument( + 'dataset', + type=str, + help=("Options: 'Set5', 'Set14'. " + "Set to 'all' if you want to download all the dataset.")) + args = parser.parse_args() + + file_ids = { + 'Set5': { + 'Set5.zip': # file name + '1RtyIeUFTyW8u7oa4z7a0lSzT3T1FwZE9', # file id + }, + 'Set14': { + 'Set14.zip': '1vsw07sV8wGrRQ8UARe2fO5jjgy9QJy_E', + } + } + + if args.dataset == 'all': + for dataset in file_ids.keys(): + download_dataset(dataset, file_ids[dataset]) + else: + download_dataset(args.dataset, file_ids[args.dataset]) diff --git a/scripts/download_pretrained_models.py b/scripts/download_pretrained_models.py index 3e5dd3d..e6eb06f 100644 --- a/scripts/download_pretrained_models.py +++ b/scripts/download_pretrained_models.py @@ -13,7 +13,7 @@ def download_pretrained_models(method, file_ids): save_path = osp.abspath(osp.join(save_path_root, file_name)) if osp.exists(save_path): user_response = input( - f'{file_name} already exist. Do you want to cover it? Y/N') + f'{file_name} already exist. Do you want to cover it? 
Y/N\n') if user_response.lower() == 'y': print(f'Covering {file_name} to {save_path}') download_file_from_google_drive(file_id, save_path) diff --git a/scripts/extract_subimages.py b/scripts/extract_subimages.py index e2b2af2..4845ca8 100644 --- a/scripts/extract_subimages.py +++ b/scripts/extract_subimages.py @@ -4,8 +4,9 @@ import sys from multiprocessing import Pool from os import path as osp +from tqdm import tqdm -from basicsr.utils.util import ProgressBar, scandir +from basicsr.utils.util import scandir def main(): @@ -95,13 +96,14 @@ def extract_subimages(opt): img_list = list(scandir(input_folder, full_path=True)) - pbar = ProgressBar(len(img_list)) + pbar = tqdm(total=len(img_list), unit='image', desc='Extract') pool = Pool(opt['n_thread']) for path in img_list: pool.apply_async( - worker, args=(path, opt), callback=lambda arg: pbar.update(arg)) + worker, args=(path, opt), callback=lambda arg: pbar.update(1)) pool.close() pool.join() + pbar.close() print('All processes done.') diff --git a/scripts/publish_models.py b/scripts/publish_models.py index ea2b5f4..ea4ae79 100644 --- a/scripts/publish_models.py +++ b/scripts/publish_models.py @@ -53,6 +53,7 @@ def convert_to_backward_compatible_models(paths): if __name__ == '__main__': - paths = glob.glob('experiments/pretrained_models/*.pth') + paths = glob.glob('experiments/pretrained_models/*.pth') + glob.glob( + 'experiments/pretrained_models/**/*.pth') convert_to_backward_compatible_models(paths) update_sha(paths) diff --git a/setup.cfg b/setup.cfg index 78b88e2..62caaa6 100644 --- a/setup.cfg +++ b/setup.cfg @@ -16,6 +16,6 @@ line_length = 79 multi_line_output = 0 known_standard_library = pkg_resources,setuptools known_first_party = basicsr -known_third_party = PIL,cv2,lmdb,matplotlib,numpy,requests,scipy,skimage,torch,torchvision,yaml +known_third_party = PIL,cv2,lmdb,matplotlib,numpy,requests,scipy,skimage,torch,torchvision,tqdm,yaml no_lines_before = STDLIB,LOCALFOLDER default_section = THIRDPARTY diff --git a/setup.py b/setup.py index 3050e9b..621007f 100644 --- a/setup.py +++ b/setup.py @@ -4,6 +4,7 @@ import os import subprocess +import sys import time import torch from torch.utils.cpp_extension import (BuildExtension, CppExtension, @@ -119,6 +120,31 @@ def get_requirements(filename='requirements.txt'): if __name__ == '__main__': + if '--no_cuda_ext' in sys.argv: + ext_modules = [] + sys.argv.remove('--no_cuda_ext') + else: + ext_modules = [ + make_cuda_ext( + name='deform_conv_ext', + module='basicsr.models.ops.dcn', + sources=['src/deform_conv_ext.cpp'], + sources_cuda=[ + 'src/deform_conv_cuda.cpp', + 'src/deform_conv_cuda_kernel.cu' + ]), + make_cuda_ext( + name='fused_act_ext', + module='basicsr.models.ops.fused_act', + sources=['src/fused_bias_act.cpp'], + sources_cuda=['src/fused_bias_act_kernel.cu']), + make_cuda_ext( + name='upfirdn2d_ext', + module='basicsr.models.ops.upfirdn2d', + sources=['src/upfirdn2d.cpp'], + sources_cuda=['src/upfirdn2d_kernel.cu']), + ] + write_version_py() setup( name='basicsr', @@ -143,25 +169,6 @@ def get_requirements(filename='requirements.txt'): license='Apache License 2.0', setup_requires=['cython', 'numpy'], install_requires=get_requirements(), - ext_modules=[ - make_cuda_ext( - name='deform_conv_ext', - module='basicsr.models.ops.dcn', - sources=['src/deform_conv_ext.cpp'], - sources_cuda=[ - 'src/deform_conv_cuda.cpp', - 'src/deform_conv_cuda_kernel.cu' - ]), - make_cuda_ext( - name='fused_act_ext', - module='basicsr.models.ops.fused_act', - sources=['src/fused_bias_act.cpp'], - 
sources_cuda=['src/fused_bias_act_kernel.cu']), - make_cuda_ext( - name='upfirdn2d_ext', - module='basicsr.models.ops.upfirdn2d', - sources=['src/upfirdn2d.cpp'], - sources_cuda=['src/upfirdn2d_kernel.cu']), - ], + ext_modules=ext_modules, cmdclass={'build_ext': BuildExtension}, zip_safe=False) diff --git a/test_scripts/test_esrgan.py b/test_scripts/test_esrgan.py new file mode 100644 index 0000000..1c70d82 --- /dev/null +++ b/test_scripts/test_esrgan.py @@ -0,0 +1,58 @@ +import argparse +import cv2 +import glob +import numpy as np +import os +import torch + +from basicsr.models.archs.rrdbnet_arch import RRDBNet + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + '--model_path', + type=str, + default= # noqa: E251 + 'experiments/pretrained_models/ESRGAN/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth' # noqa: E501 + ) + parser.add_argument( + '--folder', + type=str, + default='datasets/Set14/LRbicx4', + help='input test image folder') + parser.add_argument( + '--device', type=str, default='cuda', help='Options: cuda, cpu.') + args = parser.parse_args() + + device = torch.device(args.device) + + # set up model + model = RRDBNet( + num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32) + model.load_state_dict(torch.load(args.model_path)['params'], strict=True) + model.eval() + model = model.to(device) + + os.makedirs('results/ESRGAN', exist_ok=True) + for idx, path in enumerate( + sorted(glob.glob(os.path.join(args.folder, '*')))): + imgname = os.path.splitext(os.path.basename(path))[0] + print('Testing', idx, imgname) + # read image + img = cv2.imread(path, cv2.IMREAD_COLOR).astype(np.float32) / 255. + img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], + (2, 0, 1))).float() + img = img.unsqueeze(0).to(device) + # inference + with torch.no_grad(): + output = model(img) + # save image + output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy() + output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)) + output = (output * 255.0).round().astype(np.uint8) + cv2.imwrite(f'results/ESRGAN/{imgname}_ESRGAN.png', output) + + +if __name__ == '__main__': + main() diff --git a/test_scripts/test_stylegan2.py b/test_scripts/test_stylegan2.py index b93a3ce..38e5ff2 100644 --- a/test_scripts/test_stylegan2.py +++ b/test_scripts/test_stylegan2.py @@ -43,7 +43,7 @@ def generate(args, g_ema, device, mean_latent, randomize_noise): '--ckpt', type=str, default= # noqa: E251 - 'experiments/pretrained_models/stylegan2_ffhq_config_f_1024_official-b09c3668.pth' # noqa: E501 + 'experiments/pretrained_models/StyleGAN/stylegan2_ffhq_config_f_1024_official-b09c3668.pth' # noqa: E501 ) parser.add_argument('--channel_multiplier', type=int, default=2) parser.add_argument('--randomize_noise', type=bool, default=True) From e4ebae87008bd77978c7e400b2adae017314948e Mon Sep 17 00:00:00 2001 From: Xintao Date: Sun, 18 Oct 2020 17:40:37 +0800 Subject: [PATCH 06/23] add fid metric --- basicsr/metrics/fid.py | 102 ++++++ basicsr/models/archs/inception.py | 309 +++++++++++++++++++ scripts/calculate_fid_stats_from_datasets.py | 72 +++++ scripts/calculate_stylegan2_fid.py | 79 +++++ 4 files changed, 562 insertions(+) create mode 100644 basicsr/metrics/fid.py create mode 100644 basicsr/models/archs/inception.py create mode 100644 scripts/calculate_fid_stats_from_datasets.py create mode 100644 scripts/calculate_stylegan2_fid.py diff --git a/basicsr/metrics/fid.py b/basicsr/metrics/fid.py new file mode 100644 index 0000000..35fc23d --- /dev/null +++ b/basicsr/metrics/fid.py @@ -0,0 
+1,102 @@ +import numpy as np +import torch +import torch.nn as nn +from scipy import linalg +from tqdm import tqdm + +from basicsr.models.archs.inception import InceptionV3 + + +def load_patched_inception_v3(device='cuda', + resize_input=True, + normalize_input=False): + # we may not resize the input, but in [rosinality/stylegan2-pytorch] it + # does resize the input. + inception = InceptionV3([3], + resize_input=resize_input, + normalize_input=normalize_input) + inception = nn.DataParallel(inception).eval().to(device) + return inception + + +@torch.no_grad() +def extract_inception_features(data_generator, + inception, + len_generator=None, + device='cuda'): + """Extract inception features. + + Args: + data_generator (generator): A data generator. + inception (nn.Module): Inception model. + len_generator (int): Length of the data_generator to show the + progressbar. Default: None. + device (str): Device. Default: cuda. + + Returns: + Tensor: Extracted features. + """ + if len_generator is not None: + pbar = tqdm(total=len_generator, unit='batch', desc='Extract') + else: + pbar = None + features = [] + + for data in data_generator: + if pbar: + pbar.update(1) + data = data.to(device) + feature = inception(data)[0].view(data.shape[0], -1) + features.append(feature.to('cpu')) + if pbar: + pbar.close() + features = torch.cat(features, 0) + return features + + +def calculate_fid(mu1, sigma1, mu2, sigma2, eps=1e-6): + """Numpy implementation of the Frechet Distance. + + The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) + and X_2 ~ N(mu_2, C_2) is + d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). + Stable version by Dougal J. Sutherland. + + Args: + mu1 (np.array): The sample mean over activations. + sigma1 (np.array): The covariance matrix over activations for + generated samples. + mu2 (np.array): The sample mean over activations, precalculated on an + representative data set. + sigma2 (np.array): The covariance matrix over activations, + precalculated on an representative data set. + + Returns: + float: The Frechet Distance. + """ + assert mu1.shape == mu2.shape, 'Two mean vectors have different lengths' + assert sigma1.shape == sigma2.shape, ( + 'Two covariances have different dimensions') + + cov_sqrt, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False) + + # Product might be almost singular + if not np.isfinite(cov_sqrt).all(): + print('Product of cov matrices is singular. 
Adding {eps} to diagonal ' + 'of cov estimates') + offset = np.eye(sigma1.shape[0]) * eps + cov_sqrt = linalg.sqrtm((sigma1 + offset) @ (sigma2 + offset)) + + # Numerical error might give slight imaginary component + if np.iscomplexobj(cov_sqrt): + if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3): + m = np.max(np.abs(cov_sqrt.imag)) + raise ValueError(f'Imaginary component {m}') + cov_sqrt = cov_sqrt.real + + mean_diff = mu1 - mu2 + mean_norm = mean_diff @ mean_diff + trace = np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(cov_sqrt) + fid = mean_norm + trace + + return fid diff --git a/basicsr/models/archs/inception.py b/basicsr/models/archs/inception.py new file mode 100644 index 0000000..d706bab --- /dev/null +++ b/basicsr/models/archs/inception.py @@ -0,0 +1,309 @@ +# Modified from https://github.com/mseitzer/pytorch-fid/blob/master/pytorch_fid/inception.py # noqa: E501 +# For FID metric + +import os +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.model_zoo import load_url +from torchvision import models + +# Inception weights ported to Pytorch from +# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz +FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' # noqa: E501 +LOCAL_FID_WEIGHTS = 'experiments/pretrained_models/pt_inception-2015-12-05-6726825d.pth' # noqa: E501 + + +class InceptionV3(nn.Module): + """Pretrained InceptionV3 network returning feature maps""" + + # Index of default block of inception to return, + # corresponds to output of final average pooling + DEFAULT_BLOCK_INDEX = 3 + + # Maps feature dimensionality to their output blocks indices + BLOCK_INDEX_BY_DIM = { + 64: 0, # First max pooling features + 192: 1, # Second max pooling features + 768: 2, # Pre-aux classifier features + 2048: 3 # Final average pooling features + } + + def __init__(self, + output_blocks=[DEFAULT_BLOCK_INDEX], + resize_input=True, + normalize_input=True, + requires_grad=False, + use_fid_inception=True): + """Build pretrained InceptionV3. + + Args: + output_blocks (list[int]): Indices of blocks to return features of. + Possible values are: + - 0: corresponds to output of first max pooling + - 1: corresponds to output of second max pooling + - 2: corresponds to output which is fed to aux classifier + - 3: corresponds to output of final average pooling + resize_input (bool): If true, bilinearly resizes input to width and + height 299 before feeding input to model. As the network + without fully connected layers is fully convolutional, it + should be able to handle inputs of arbitrary size, so resizing + might not be strictly needed. Default: True. + normalize_input (bool): If true, scales the input from range (0, 1) + to the range the pretrained Inception network expects, + namely (-1, 1). Default: True. + requires_grad (bool): If true, parameters of the model require + gradients. Possibly useful for finetuning the network. + Default: False. + use_fid_inception (bool): If true, uses the pretrained Inception + model used in Tensorflow's FID implementation. + If false, uses the pretrained Inception model available in + torchvision. The FID Inception model has different weights + and a slightly different structure from torchvision's + Inception model. If you want to compute FID scores, you are + strongly advised to set this parameter to true to get + comparable results. Default: True. 
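Editor's note — as a quick, self-contained illustration of the `calculate_fid` helper added above in `basicsr/metrics/fid.py` (this sketch is not part of the patch; the feature matrices, their sizes and the random seed are made up for the example, and real FID uses 2048-dim Inception pool features over many thousands of samples):

```python
import numpy as np

from basicsr.metrics.fid import calculate_fid

# Toy "activations": 1000 samples x 64-dim features per distribution.
# Small shapes keep the sqrtm call fast; they are illustrative only.
rng = np.random.RandomState(0)
feat_real = rng.randn(1000, 64)
feat_fake = rng.randn(1000, 64) + 0.1  # slightly shifted distribution

# FID only needs the Gaussian statistics (mean, covariance) of the features.
mu_real, sigma_real = feat_real.mean(axis=0), np.cov(feat_real, rowvar=False)
mu_fake, sigma_fake = feat_fake.mean(axis=0), np.cov(feat_fake, rowvar=False)

print(calculate_fid(mu_real, sigma_real, mu_fake, sigma_fake))
```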
+ """ + super(InceptionV3, self).__init__() + + self.resize_input = resize_input + self.normalize_input = normalize_input + self.output_blocks = sorted(output_blocks) + self.last_needed_block = max(output_blocks) + + assert self.last_needed_block <= 3, ( + 'Last possible output block index is 3') + + self.blocks = nn.ModuleList() + + if use_fid_inception: + inception = fid_inception_v3() + else: + inception = models.inception_v3(pretrained=True) + + # Block 0: input to maxpool1 + block0 = [ + inception.Conv2d_1a_3x3, inception.Conv2d_2a_3x3, + inception.Conv2d_2b_3x3, + nn.MaxPool2d(kernel_size=3, stride=2) + ] + self.blocks.append(nn.Sequential(*block0)) + + # Block 1: maxpool1 to maxpool2 + if self.last_needed_block >= 1: + block1 = [ + inception.Conv2d_3b_1x1, inception.Conv2d_4a_3x3, + nn.MaxPool2d(kernel_size=3, stride=2) + ] + self.blocks.append(nn.Sequential(*block1)) + + # Block 2: maxpool2 to aux classifier + if self.last_needed_block >= 2: + block2 = [ + inception.Mixed_5b, + inception.Mixed_5c, + inception.Mixed_5d, + inception.Mixed_6a, + inception.Mixed_6b, + inception.Mixed_6c, + inception.Mixed_6d, + inception.Mixed_6e, + ] + self.blocks.append(nn.Sequential(*block2)) + + # Block 3: aux classifier to final avgpool + if self.last_needed_block >= 3: + block3 = [ + inception.Mixed_7a, inception.Mixed_7b, inception.Mixed_7c, + nn.AdaptiveAvgPool2d(output_size=(1, 1)) + ] + self.blocks.append(nn.Sequential(*block3)) + + for param in self.parameters(): + param.requires_grad = requires_grad + + def forward(self, x): + """Get Inception feature maps. + + Args: + x (Tensor): Input tensor of shape (b, 3, h, w). + Values are expected to be in range (-1, 1). You can also input + (0, 1) with setting normalize_input = True. + + Returns: + list[Tensor]: Corresponding to the selected output block, sorted + ascending by index. + """ + output = [] + + if self.resize_input: + x = F.interpolate( + x, size=(299, 299), mode='bilinear', align_corners=False) + + if self.normalize_input: + x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1) + + for idx, block in enumerate(self.blocks): + x = block(x) + if idx in self.output_blocks: + output.append(x) + + if idx == self.last_needed_block: + break + + return output + + +def fid_inception_v3(): + """Build pretrained Inception model for FID computation. + + The Inception model for FID computation uses a different set of weights + and has a slightly different structure than torchvision's Inception. + + This method first constructs torchvision's Inception and then patches the + necessary parts that are different in the FID Inception model. 
+ """ + inception = models.inception_v3( + num_classes=1008, aux_logits=False, pretrained=False) + inception.Mixed_5b = FIDInceptionA(192, pool_features=32) + inception.Mixed_5c = FIDInceptionA(256, pool_features=64) + inception.Mixed_5d = FIDInceptionA(288, pool_features=64) + inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128) + inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160) + inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160) + inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192) + inception.Mixed_7b = FIDInceptionE_1(1280) + inception.Mixed_7c = FIDInceptionE_2(2048) + + if os.path.exists(LOCAL_FID_WEIGHTS): + state_dict = torch.load( + LOCAL_FID_WEIGHTS, map_location=lambda storage, loc: storage) + else: + state_dict = load_url(FID_WEIGHTS_URL, progress=True) + + inception.load_state_dict(state_dict) + return inception + + +class FIDInceptionA(models.inception.InceptionA): + """InceptionA block patched for FID computation""" + + def __init__(self, in_channels, pool_features): + super(FIDInceptionA, self).__init__(in_channels, pool_features) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch5x5 = self.branch5x5_1(x) + branch5x5 = self.branch5x5_2(branch5x5) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) + + # Patch: Tensorflow's average pool does not use the padded zero's in + # its average calculation + branch_pool = F.avg_pool2d( + x, kernel_size=3, stride=1, padding=1, count_include_pad=False) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] + return torch.cat(outputs, 1) + + +class FIDInceptionC(models.inception.InceptionC): + """InceptionC block patched for FID computation""" + + def __init__(self, in_channels, channels_7x7): + super(FIDInceptionC, self).__init__(in_channels, channels_7x7) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch7x7 = self.branch7x7_1(x) + branch7x7 = self.branch7x7_2(branch7x7) + branch7x7 = self.branch7x7_3(branch7x7) + + branch7x7dbl = self.branch7x7dbl_1(x) + branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) + + # Patch: Tensorflow's average pool does not use the padded zero's in + # its average calculation + branch_pool = F.avg_pool2d( + x, kernel_size=3, stride=1, padding=1, count_include_pad=False) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] + return torch.cat(outputs, 1) + + +class FIDInceptionE_1(models.inception.InceptionE): + """First InceptionE block patched for FID computation""" + + def __init__(self, in_channels): + super(FIDInceptionE_1, self).__init__(in_channels) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch3x3 = self.branch3x3_1(x) + branch3x3 = [ + self.branch3x3_2a(branch3x3), + self.branch3x3_2b(branch3x3), + ] + branch3x3 = torch.cat(branch3x3, 1) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = [ + self.branch3x3dbl_3a(branch3x3dbl), + self.branch3x3dbl_3b(branch3x3dbl), + ] + branch3x3dbl = torch.cat(branch3x3dbl, 1) + + # Patch: Tensorflow's average pool does not use the padded zero's in + # its average calculation + branch_pool = F.avg_pool2d( + x, kernel_size=3, stride=1, padding=1, count_include_pad=False) + 
branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] + return torch.cat(outputs, 1) + + +class FIDInceptionE_2(models.inception.InceptionE): + """Second InceptionE block patched for FID computation""" + + def __init__(self, in_channels): + super(FIDInceptionE_2, self).__init__(in_channels) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch3x3 = self.branch3x3_1(x) + branch3x3 = [ + self.branch3x3_2a(branch3x3), + self.branch3x3_2b(branch3x3), + ] + branch3x3 = torch.cat(branch3x3, 1) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = [ + self.branch3x3dbl_3a(branch3x3dbl), + self.branch3x3dbl_3b(branch3x3dbl), + ] + branch3x3dbl = torch.cat(branch3x3dbl, 1) + + # Patch: The FID Inception model uses max pooling instead of average + # pooling. This is likely an error in this specific Inception + # implementation, as other Inception models use average pooling here + # (which matches the description in the paper). + branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] + return torch.cat(outputs, 1) diff --git a/scripts/calculate_fid_stats_from_datasets.py b/scripts/calculate_fid_stats_from_datasets.py new file mode 100644 index 0000000..8b61f5c --- /dev/null +++ b/scripts/calculate_fid_stats_from_datasets.py @@ -0,0 +1,72 @@ +import argparse +import math +import numpy as np +import torch +from torch.utils.data import DataLoader + +from basicsr.data import create_dataset +from basicsr.metrics.fid import (extract_inception_features, + load_patched_inception_v3) + + +def calculate_stats_from_dataset(): + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + parser = argparse.ArgumentParser() + parser.add_argument('--num_sample', type=int, default=50000) + parser.add_argument('--batch_size', type=int, default=64) + parser.add_argument('--size', type=int, default=512) + parser.add_argument('--dataroot', type=str, default='datasets/ffhq') + args = parser.parse_args() + + # inception model + inception = load_patched_inception_v3(device) + + # create dataset + opt = {} + opt['name'] = 'FFHQ' + opt['type'] = 'FFHQDataset' + opt['dataroot_gt'] = f'datasets/ffhq/ffhq_{args.size}.lmdb' + opt['io_backend'] = dict(type='lmdb') + opt['use_hflip'] = False + opt['mean'] = [0.5, 0.5, 0.5] + opt['std'] = [0.5, 0.5, 0.5] + dataset = create_dataset(opt) + + # create dataloader + data_loader = DataLoader( + dataset=dataset, + batch_size=args.batch_size, + shuffle=False, + num_workers=4, + sampler=None, + drop_last=False) + total_batch = math.ceil(args.num_sample / args.batch_size) + + def data_generator(data_loader, total_batch): + for idx, data in enumerate(data_loader): + if idx >= total_batch: + break + else: + yield data['gt'] + + features = extract_inception_features( + data_generator(data_loader, total_batch), inception, total_batch, + device) + features = features.numpy() + total_len = features.shape[0] + features = features[:args.num_sample] + print(f'Extracted {total_len} features, ' + f'use the first {features.shape[0]} features to calculate stats.') + mean = np.mean(features, 0) + cov = np.cov(features, rowvar=False) + + save_path = f'inception_{opt["name"]}_{args.size}.pth' + torch.save( + dict(name=opt['name'], size=args.size, mean=mean, cov=cov), + save_path, + _use_new_zipfile_serialization=False) + + +if __name__ == '__main__': + 
calculate_stats_from_dataset() diff --git a/scripts/calculate_stylegan2_fid.py b/scripts/calculate_stylegan2_fid.py new file mode 100644 index 0000000..bd3acb1 --- /dev/null +++ b/scripts/calculate_stylegan2_fid.py @@ -0,0 +1,79 @@ +import argparse +import math +import numpy as np +import torch +from torch import nn + +from basicsr.metrics.fid import (calculate_fid, extract_inception_features, + load_patched_inception_v3) +from basicsr.models.archs.stylegan2_arch import StyleGAN2Generator + + +def calculate_stylegan2_fid(): + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + parser = argparse.ArgumentParser() + parser.add_argument( + 'ckpt', type=str, help='Path to the stylegan2 checkpoint.') + parser.add_argument( + 'fid_stats', type=str, help='Path to the dataset fid statistics.') + parser.add_argument('--size', type=int, default=256) + parser.add_argument('--channel_multiplier', type=int, default=2) + parser.add_argument('--batch_size', type=int, default=64) + parser.add_argument('--num_sample', type=int, default=50000) + parser.add_argument('--truncation', type=float, default=1) + parser.add_argument('--truncation_mean', type=int, default=4096) + args = parser.parse_args() + + # create stylegan2 model + generator = StyleGAN2Generator( + out_size=args.size, + num_style_feat=512, + num_mlp=8, + channel_multiplier=args.channel_multiplier, + resample_kernel=(1, 3, 3, 1)) + generator.load_state_dict(torch.load(args.ckpt)['params_ema']) + generator = nn.DataParallel(generator).eval().to(device) + + if args.truncation < 1: + with torch.no_grad(): + truncation_latent = generator.mean_latent(args.truncation_mean) + else: + truncation_latent = None + + # inception model + inception = load_patched_inception_v3(device) + + total_batch = math.ceil(args.num_sample / args.batch_size) + + def sample_generator(total_batch): + for i in range(total_batch): + with torch.no_grad(): + latent = torch.randn(args.batch_size, 512, device=device) + samples, _ = generator([latent], + truncation=args.truncation, + truncation_latent=truncation_latent) + yield samples + + features = extract_inception_features( + sample_generator(total_batch), inception, total_batch, device) + features = features.numpy() + total_len = features.shape[0] + features = features[:args.num_sample] + print(f'Extracted {total_len} features, ' + f'use the first {features.shape[0]} features to calculate stats.') + sample_mean = np.mean(features, 0) + sample_cov = np.cov(features, rowvar=False) + + # load the dataset stats + stats = torch.load(args.fid_stats) + real_mean = stats['mean'] + real_cov = stats['cov'] + + # calculate FID metric + fid = calculate_fid(sample_mean, sample_cov, real_mean, real_cov) + print('fid:', fid) + + +if __name__ == '__main__': + calculate_stylegan2_fid() From b500cee9103904b946f10f6bde0735feff46667c Mon Sep 17 00:00:00 2001 From: Xintao Date: Sun, 18 Oct 2020 17:43:17 +0800 Subject: [PATCH 07/23] minor updates --- .gitignore | 1 + LICENSE/README.md | 2 + basicsr/data/paired_image_dataset.py | 7 + basicsr/models/archs/arch_util.py | 2 +- basicsr/models/archs/vgg_arch.py | 14 +- basicsr/utils/__init__.py | 5 +- basicsr/utils/crawler_util.py | 35 ++++ basicsr/utils/download.py | 19 +-- basicsr/utils/options.py | 7 +- basicsr/utils/util.py | 25 ++- scripts/extract_images_from_tfrecords.py | 206 +++++++++++++++-------- test_scripts/test_esrgan.py | 5 +- test_scripts/test_face_dfdnet.py | 42 ++++- test_scripts/test_stylegan2.py | 2 +- 14 files changed, 263 insertions(+), 109 deletions(-) create 
mode 100644 basicsr/utils/crawler_util.py diff --git a/.gitignore b/.gitignore index 5d63054..4e4055d 100644 --- a/.gitignore +++ b/.gitignore @@ -14,6 +14,7 @@ version.py # ignored files with suffix *.html *.png +*.jpeg *.jpg *.gif *.pth diff --git a/LICENSE/README.md b/LICENSE/README.md index 159c492..a5d8b6f 100644 --- a/LICENSE/README.md +++ b/LICENSE/README.md @@ -13,3 +13,5 @@ This BasicSR project is released under the Apache 2.0 license. 1. NIQE metric: the codes are translated from the [official MATLAB codes](http://live.ece.utexas.edu/research/quality/niqe_release.zip) > A. Mittal, R. Soundararajan and A. C. Bovik, "Making a Completely Blind Image Quality Analyzer", IEEE Signal Processing Letters, 2012. + +1. FID metric: the codes are modified from [pytorch-fid](https://github.com/mseitzer/pytorch-fid) and [stylegan2-pytorch](https://github.com/rosinality/stylegan2-pytorch). diff --git a/basicsr/data/paired_image_dataset.py b/basicsr/data/paired_image_dataset.py index 0e2de96..7775d0b 100644 --- a/basicsr/data/paired_image_dataset.py +++ b/basicsr/data/paired_image_dataset.py @@ -1,4 +1,5 @@ from torch.utils import data as data +from torchvision.transforms.functional import normalize from basicsr.data.transforms import augment, paired_random_crop from basicsr.data.util import (paired_paths_from_folder, @@ -44,6 +45,8 @@ def __init__(self, opt): # file client (io backend) self.file_client = None self.io_backend_opt = opt['io_backend'] + self.mean = opt['mean'] if 'mean' in opt else None + self.std = opt['std'] if 'std' in opt else None self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq'] if 'filename_tmpl' in opt: @@ -97,6 +100,10 @@ def __getitem__(self, index): img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True) + # normalize + if self.mean is not None or self.std is not None: + normalize(img_lq, self.mean, self.std, inplace=True) + normalize(img_gt, self.mean, self.std, inplace=True) return { 'lq': img_lq, diff --git a/basicsr/models/archs/arch_util.py b/basicsr/models/archs/arch_util.py index 961fcd4..19b4ed8 100644 --- a/basicsr/models/archs/arch_util.py +++ b/basicsr/models/archs/arch_util.py @@ -157,7 +157,7 @@ def flow_warp(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, - align_corners=True) + align_corners=align_corners) # TODO, what if align_corners=False return output diff --git a/basicsr/models/archs/vgg_arch.py b/basicsr/models/archs/vgg_arch.py index 89c8772..251b794 100644 --- a/basicsr/models/archs/vgg_arch.py +++ b/basicsr/models/archs/vgg_arch.py @@ -1,8 +1,10 @@ +import os import torch from collections import OrderedDict from torch import nn as nn from torchvision.models import vgg as vgg +VGG_PRETRAIN_PATH = 'experiments/pretrained_models/vgg19-dcbb9e9d.pth' NAMES = { 'vgg11': [ 'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2', @@ -97,8 +99,16 @@ def __init__(self, idx = self.names.index(v) if idx > max_idx: max_idx = idx - features = getattr(vgg, - vgg_type)(pretrained=True).features[:max_idx + 1] + + if os.path.exists(VGG_PRETRAIN_PATH): + vgg_net = getattr(vgg, vgg_type)(pretrained=False) + state_dict = torch.load( + VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage) + vgg_net.load_state_dict(state_dict) + else: + vgg_net = getattr(vgg, vgg_type)(pretrained=True) + + features = vgg_net.features[:max_idx + 1] modified_net = OrderedDict() for k, v in zip(self.names, features): diff --git a/basicsr/utils/__init__.py b/basicsr/utils/__init__.py index 554f433..fa3401e 100644 --- 
a/basicsr/utils/__init__.py +++ b/basicsr/utils/__init__.py @@ -3,7 +3,7 @@ from .logger import (MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger) from .util import (check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, - scandir, set_random_seed) + scandir, set_random_seed, sizeof_fmt) __all__ = [ # file_client.py @@ -26,5 +26,6 @@ 'mkdir_and_rename', 'make_exp_dirs', 'scandir', - 'check_resume' + 'check_resume', + 'sizeof_fmt' ] diff --git a/basicsr/utils/crawler_util.py b/basicsr/utils/crawler_util.py new file mode 100644 index 0000000..b65dda3 --- /dev/null +++ b/basicsr/utils/crawler_util.py @@ -0,0 +1,35 @@ +import requests + + +def baidu_decode_url(encrypted_url): + """Decrypt baidu ecrypted url.""" + url = encrypted_url + map1 = {'_z2C$q': ':', '_z&e3B': '.', 'AzdH3F': '/'} + map2 = { + 'w': 'a', 'k': 'b', 'v': 'c', '1': 'd', 'j': 'e', + 'u': 'f', '2': 'g', 'i': 'h', 't': 'i', '3': 'j', + 'h': 'k', 's': 'l', '4': 'm', 'g': 'n', '5': 'o', + 'r': 'p', 'q': 'q', '6': 'r', 'f': 's', 'p': 't', + '7': 'u', 'e': 'v', 'o': 'w', '8': '1', 'd': '2', + 'n': '3', '9': '4', 'c': '5', 'm': '6', '0': '7', + 'b': '8', 'l': '9', 'a': '0' + } # yapf: disable + for (ciphertext, plaintext) in map1.items(): + url = url.replace(ciphertext, plaintext) + char_list = [char for char in url] + for i in range(len(char_list)): + if char_list[i] in map2: + char_list[i] = map2[char_list[i]] + url = ''.join(char_list) + return url + + +def setup_session(): + headers = { + 'User-Agent': ('Mozilla/5.0 (Macintosh; Intel Mac OS X 10_11_3)' + ' AppleWebKit/537.36 (KHTML, like Gecko) ' + 'Chrome/48.0.2564.116 Safari/537.36') + } + session = requests.Session() + session.headers.update(headers) + return session diff --git a/basicsr/utils/download.py b/basicsr/utils/download.py index 07be86d..d27266e 100644 --- a/basicsr/utils/download.py +++ b/basicsr/utils/download.py @@ -2,6 +2,8 @@ import requests from tqdm import tqdm +from .util import sizeof_fmt + def download_file_from_google_drive(file_id, save_path): """Download files from google drive. @@ -66,20 +68,3 @@ def save_response_content(response, f.write(chunk) if pbar is not None: pbar.close() - - -def sizeof_fmt(size, suffix='B'): - """Get human readable file size. - - Args: - size (int): File size. - suffix (str): Suffix. Default: 'B'. - - Return: - str: Formated file siz. 
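Editor's note — a short usage sketch for the new `basicsr/utils/crawler_util.py` helpers added above (illustrative only: the URL is a placeholder, and `baidu_decode_url` only produces a meaningful result on a real Baidu-encrypted image URL):

```python
from basicsr.utils.crawler_util import baidu_decode_url, setup_session

# Reuse one session with a browser-like User-Agent for all crawler requests.
session = setup_session()
response = session.get('https://www.example.com')  # placeholder URL
print(response.status_code)

# Decode a Baidu "objURL"-style encrypted URL (placeholder input string).
print(baidu_decode_url('ippr_z2C$qAzdH3FAzdH3Ft42_z&e3B'))
```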
- """ - for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']: - if abs(size) < 1024.0: - return f'{size:3.1f} {unit}{suffix}' - size /= 1024.0 - return f'{size:3.1f} Y{suffix}' diff --git a/basicsr/utils/options.py b/basicsr/utils/options.py index 7042b85..3670d17 100644 --- a/basicsr/utils/options.py +++ b/basicsr/utils/options.py @@ -57,9 +57,10 @@ def parse(opt_path, is_train=True): dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq']) # paths - for key, path in opt['path'].items(): - if path and 'strict_load' not in key: - opt['path'][key] = osp.expanduser(path) + for key, val in opt['path'].items(): + if (val is not None) and ('resume_state' in key + or 'pretrain_network' in key): + opt['path'][key] = osp.expanduser(val) opt['path']['root'] = osp.abspath( osp.join(__file__, osp.pardir, osp.pardir, osp.pardir)) if is_train: diff --git a/basicsr/utils/util.py b/basicsr/utils/util.py index 26f3370..200527c 100644 --- a/basicsr/utils/util.py +++ b/basicsr/utils/util.py @@ -115,6 +115,25 @@ def check_resume(opt, resume_iter): for network in networks: name = f'pretrain_{network}' basename = network.replace('network_', '') - opt['path'][name] = osp.join(opt['path']['models'], - f'net_{basename}_{resume_iter}.pth') - logger.info(f"Set {name} to {opt['path'][name]}") + if opt['path'].get('ignore_resume_networks') is None or ( + basename not in opt['path']['ignore_resume_networks']): + opt['path'][name] = osp.join( + opt['path']['models'], f'net_{basename}_{resume_iter}.pth') + logger.info(f"Set {name} to {opt['path'][name]}") + + +def sizeof_fmt(size, suffix='B'): + """Get human readable file size. + + Args: + size (int): File size. + suffix (str): Suffix. Default: 'B'. + + Return: + str: Formated file siz. + """ + for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']: + if abs(size) < 1024.0: + return f'{size:3.1f} {unit}{suffix}' + size /= 1024.0 + return f'{size:3.1f} Y{suffix}' diff --git a/scripts/extract_images_from_tfrecords.py b/scripts/extract_images_from_tfrecords.py index 3ee902a..e1311cc 100644 --- a/scripts/extract_images_from_tfrecords.py +++ b/scripts/extract_images_from_tfrecords.py @@ -1,9 +1,4 @@ -"""Read tfrecords w/o define a graph. - -Ref: -http://warmspringwinds.github.io/tensorflow/tf-slim/2016/12/21/tfrecords-guide/ -""" - +import argparse import cv2 import glob import numpy as np @@ -12,25 +7,46 @@ from basicsr.utils.lmdb import LmdbMaker -def celeba_tfrecords(): - # Configurations - file_pattern = '/home/xtwang/datasets/CelebA_tfrecords/celeba-full-tfr/train/train-r08-s-*-of-*.tfrecords' # noqa:E501 - # r08: resolution 2^8 = 256 - resolution = 128 - save_path = f'/home/xtwang/datasets/CelebA_tfrecords/tmptrain_{resolution}' +def convert_celeba_tfrecords(tf_file, + log_resolution, + save_root, + save_type='img', + compress_level=1): + """Convert CelebA tfrecords to images or lmdb files. + + Args: + tf_file (str): Input tfrecords file in glob pattern. + Example: 'datasets/celeba/celeba_tfrecords/validation/validation-r08-s-*-of-*.tfrecords' # noqa:E501 + log_resolution (int): Log scale of resolution. + save_root (str): Path root to save. + save_type (str): Save type. Options: img | lmdb. Default: img. + compress_level (int): Compress level when encoding images. Default: 1. 
+ """ + if 'validation' in tf_file: + phase = 'validation' + else: + phase = 'train' + if save_type == 'lmdb': + save_path = os.path.join(save_root, + f'celeba_{2**log_resolution}_{phase}.lmdb') + lmdb_maker = LmdbMaker(save_path) + elif save_type == 'img': + save_path = os.path.join(save_root, + f'celeba_{2**log_resolution}_{phase}') + else: + raise ValueError('Wrong save type.') - save_all_path = os.path.join(save_path, f'all_{resolution}') - os.makedirs(save_all_path) + os.makedirs(save_path, exist_ok=True) idx = 0 - print(glob.glob(file_pattern)) - for record in glob.glob(file_pattern): + for record in sorted(glob.glob(tf_file)): + print('Processing record: ', record) record_iterator = tf.python_io.tf_record_iterator(record) for string_record in record_iterator: example = tf.train.Example() example.ParseFromString(string_record) - # label = example.features.feature['label'].int64_list.value[0] + # label = example.features.feature['label'].int64_list.value[0] # attr = example.features.feature['attr'].int64_list.value # male = attr[20] # young = attr[39] @@ -40,24 +56,51 @@ def celeba_tfrecords(): img_str = example.features.feature['data'].bytes_list.value[0] img = np.fromstring(img_str, dtype=np.uint8).reshape((h, w, c)) - # save image img = img[:, :, [2, 1, 0]] - cv2.imwrite(os.path.join(save_all_path, f'{idx:08d}.png'), img) + + if save_type == 'img': + cv2.imwrite(os.path.join(save_path, f'{idx:08d}.png'), img) + elif save_type == 'lmdb': + _, img_byte = cv2.imencode( + '.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level]) + key = f'{idx:08d}/r{log_resolution:02d}' + lmdb_maker.put(img_byte, key, (h, w, c)) idx += 1 print(idx) - -def ffhq_tfrecords(): - # Configurations - file_pattern = '/home/xtwang/datasets/ffhq/ffhq-r10.tfrecords' - resolution = 1024 - save_path = f'/home/xtwang/datasets/ffhq/ffhq_imgs/ffhq_{resolution}' + if save_type == 'lmdb': + lmdb_maker.close() + + +def convert_ffhq_tfrecords(tf_file, + log_resolution, + save_root, + save_type='img', + compress_level=1): + """Convert FFHQ tfrecords to images or lmdb files. + + Args: + tf_file (str): Input tfrecords file. + log_resolution (int): Log scale of resolution. + save_root (str): Path root to save. + save_type (str): Save type. Options: img | lmdb. Default: img. + compress_level (int): Compress level when encoding images. Default: 1. 
+ """ + + if save_type == 'lmdb': + save_path = os.path.join(save_root, f'ffhq_{2**log_resolution}.lmdb') + lmdb_maker = LmdbMaker(save_path) + elif save_type == 'img': + save_path = os.path.join(save_root, f'ffhq_{2**log_resolution}') + else: + raise ValueError('Wrong save type.') os.makedirs(save_path, exist_ok=True) + idx = 0 - print(glob.glob(file_pattern)) - for record in glob.glob(file_pattern): + for record in sorted(glob.glob(tf_file)): + print('Processing record: ', record) record_iterator = tf.python_io.tf_record_iterator(record) for string_record in record_iterator: example = tf.train.Example() @@ -68,56 +111,85 @@ def ffhq_tfrecords(): img_str = example.features.feature['data'].bytes_list.value[0] img = np.fromstring(img_str, dtype=np.uint8).reshape((c, h, w)) - # save image img = img.transpose(1, 2, 0) img = img[:, :, [2, 1, 0]] - cv2.imwrite(os.path.join(save_path, f'{idx:08d}.png'), img) + if save_type == 'img': + cv2.imwrite(os.path.join(save_path, f'{idx:08d}.png'), img) + elif save_type == 'lmdb': + _, img_byte = cv2.imencode( + '.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level]) + key = f'{idx:08d}/r{log_resolution:02d}' + lmdb_maker.put(img_byte, key, (h, w, c)) idx += 1 print(idx) - -def ffhq_tfrecords_to_lmdb(): - # Configurations - file_pattern = '/home/xtwang/datasets/ffhq/ffhq-r10.tfrecords' - log_resolution = 10 - compress_level = 1 - lmdb_path = f'/home/xtwang/datasets/ffhq/ffhq_{2**log_resolution}.lmdb' - - idx = 0 - print(glob.glob(file_pattern)) - - lmdb_maker = LmdbMaker(lmdb_path) - for record in glob.glob(file_pattern): - record_iterator = tf.python_io.tf_record_iterator(record) - for string_record in record_iterator: - example = tf.train.Example() - example.ParseFromString(string_record) - - shape = example.features.feature['shape'].int64_list.value - c, h, w = shape - img_str = example.features.feature['data'].bytes_list.value[0] - img = np.fromstring(img_str, dtype=np.uint8).reshape((c, h, w)) - - # write image to lmdb - img = img.transpose(1, 2, 0) - img = img[:, :, [2, 1, 0]] - _, img_byte = cv2.imencode( - '.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level]) - key = f'{idx:08d}/r{log_resolution:02d}' - lmdb_maker.put(img_byte, key, (h, w, c)) - - idx += 1 - print(key) - lmdb_maker.close() + if save_type == 'lmdb': + lmdb_maker.close() if __name__ == '__main__': - # we have test on TensorFlow 1.15 + """Read tfrecords w/o define a graph. + + We have tested it on on TensorFlow 1.15 + + Ref: + http://warmspringwinds.github.io/tensorflow/tf-slim/2016/12/21/tfrecords-guide/ + """ + parser = argparse.ArgumentParser() + parser.add_argument( + '--dataset', + type=str, + default='ffhq', + help="Dataset name. Options: 'ffhq' | 'celeba'. Default: 'ffhq'.") + parser.add_argument( + '--tf_file', + type=str, + default='datasets/ffhq/ffhq-r10.tfrecords', + help=( + 'Input tfrecords file. For celeba, it should be glob pattern. ' + 'Put quotes around the wildcard argument to prevent the shell ' + 'from expanding it.' + "Example: 'datasets/celeba/celeba_tfrecords/validation/validation-r08-s-*-of-*.tfrecords'" # noqa:E501 + )) + parser.add_argument( + '--log_resolution', + type=int, + default=10, + help='Log scale of resolution.') + parser.add_argument( + '--save_root', + type=str, + default='datasets/ffhq/', + help='Save root path.') + parser.add_argument( + '--save_type', + type=str, + default='img', + help="Save type. Options: 'img' | 'lmdb'. 
Default: 'img'.") + parser.add_argument( + '--compress_level', + type=int, + default=1, + help='Compress level when encoding images. Default: 1.') + args = parser.parse_args() + try: import tensorflow as tf except Exception: raise ImportError('You need to install tensorflow to read tfrecords.') - # celeba_tfrecords() - # ffhq_tfrecords() - ffhq_tfrecords_to_lmdb() + + if args.dataset == 'ffhq': + convert_ffhq_tfrecords( + args.tf_file, + args.log_resolution, + args.save_root, + save_type=args.save_type, + compress_level=args.compress_level) + else: + convert_celeba_tfrecords( + args.tf_file, + args.log_resolution, + args.save_root, + save_type=args.save_type, + compress_level=args.compress_level) diff --git a/test_scripts/test_esrgan.py b/test_scripts/test_esrgan.py index 1c70d82..8c64966 100644 --- a/test_scripts/test_esrgan.py +++ b/test_scripts/test_esrgan.py @@ -21,12 +21,9 @@ def main(): type=str, default='datasets/Set14/LRbicx4', help='input test image folder') - parser.add_argument( - '--device', type=str, default='cuda', help='Options: cuda, cpu.') args = parser.parse_args() - device = torch.device(args.device) - + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # set up model model = RRDBNet( num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32) diff --git a/test_scripts/test_face_dfdnet.py b/test_scripts/test_face_dfdnet.py index 3e44848..a3cdb27 100644 --- a/test_scripts/test_face_dfdnet.py +++ b/test_scripts/test_face_dfdnet.py @@ -36,6 +36,7 @@ def __init__(self, upscale_factor, face_template_path, out_size=512): self.inverse_affine_matrices = [] self.cropped_faces = [] self.restored_faces = [] + self.save_png = True def init_dlib(self, detection_path, landmark5_path, landmark68_path): """Initialize the dlib detectors and predictors.""" @@ -97,7 +98,9 @@ def get_face_landmarks_68(self): print('Should only have one face at most.') return num_detected_face - def warp_crop_faces(self, save_cropped_path=None): + def warp_crop_faces(self, + save_cropped_path=None, + save_inverse_affine_path=None): """Get affine matrix, warp and cropped faces. Also get inverse affine matrix for post-processing. 
@@ -114,7 +117,11 @@ def warp_crop_faces(self, save_cropped_path=None): # save the cropped face if save_cropped_path is not None: path, ext = os.path.splitext(save_cropped_path) - save_path = f'{path}_{idx:02d}{ext}' + if self.save_png: + save_path = f'{path}_{idx:02d}.png' + else: + save_path = f'{path}_{idx:02d}{ext}' + imwrite( cv2.cvtColor(cropped_face, cv2.COLOR_RGB2BGR), save_path) @@ -123,6 +130,11 @@ def warp_crop_faces(self, save_cropped_path=None): landmark * self.upscale_factor) inverse_affine = self.similarity_trans.params[0:2, :] self.inverse_affine_matrices.append(inverse_affine) + # save inverse affine matrices + if save_inverse_affine_path is not None: + path, _ = os.path.splitext(save_inverse_affine_path) + save_path = f'{path}_{idx:02d}.pth' + torch.save(inverse_affine, save_path) def add_restored_face(self, face): self.restored_faces.append(face) @@ -158,6 +170,9 @@ def paste_faces_to_input_image(self, save_path): (blur_size + 1, blur_size + 1), 0) upsample_img = inv_soft_mask * inv_restored_remove_border + ( 1 - inv_soft_mask) * upsample_img + if self.save_png: + save_path = save_path.replace('.jpg', + '.png').replace('.jpeg', '.png') imwrite(upsample_img.astype(np.uint8), save_path) def clean_all(self): @@ -220,7 +235,7 @@ def get_part_location(landmarks): differences: 1) we use dlib for 68 landmark detection; 2) the used image package are different (especially for reading and writing.) """ - device = 'cuda' + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') parser = argparse.ArgumentParser() parser.add_argument('--upscale_factor', type=int, default=2) @@ -236,6 +251,7 @@ def get_part_location(landmarks): 'experiments/pretrained_models/DFDNet/DFDNet_dict_512-f79685f0.pth') parser.add_argument('--test_path', type=str, default='datasets/TestWhole') parser.add_argument('--upsample_num_times', type=int, default=1) + parser.add_argument('--save_inverse_affine', action='store_true') # The official codes use skimage.io to read the cropped images from disk # instead of directly using the intermediate results in the memory (as we # do). 
Such a different operation brings slight differences due to @@ -280,6 +296,8 @@ def get_part_location(landmarks): net.eval() save_crop_root = os.path.join(result_root, 'cropped_faces') + save_inverse_affine_root = os.path.join(result_root, 'inverse_affine') + os.makedirs(save_inverse_affine_root, exist_ok=True) save_restore_root = os.path.join(result_root, 'restored_faces') save_final_root = os.path.join(result_root, 'final_results') @@ -287,10 +305,16 @@ def get_part_location(landmarks): args.upscale_factor, args.face_template_path, out_size=512) # scan all the jpg and png images - for img_path in glob.glob(os.path.join(args.test_path, '*.[jp][pn]g')): + for img_path in sorted( + glob.glob(os.path.join(args.test_path, '*.[jp][pn]g'))): img_name = os.path.basename(img_path) print(f'Processing {img_name} image ...') save_crop_path = os.path.join(save_crop_root, img_name) + if args.save_inverse_affine: + save_inverse_affine_path = os.path.join(save_inverse_affine_root, + img_name) + else: + save_inverse_affine_path = None face_helper.init_dlib(args.detection_path, args.landmark5_path, args.landmark68_path) @@ -301,11 +325,11 @@ def get_part_location(landmarks): num_landmarks = face_helper.get_face_landmarks_5() print(f'\tDetect {num_det_faces} faces, {num_landmarks} landmarks.') # warp and crop each face - face_helper.warp_crop_faces(save_crop_path) + face_helper.warp_crop_faces(save_crop_path, save_inverse_affine_path) if args.official_adaption: path, ext = os.path.splitext(save_crop_path) - pathes = sorted(glob.glob(f'{path}_[0-9]*{ext}')) + pathes = sorted(glob.glob(f'{path}_[0-9]*.png')) cropped_faces = [io.imread(path) for path in pathes] else: cropped_faces = face_helper.cropped_faces @@ -336,9 +360,9 @@ def get_part_location(landmarks): im = tensor2img(output, min_max=(-1, 1)) del output torch.cuda.empty_cache() - path, ext = os.path.splitext( - os.path.join(save_restore_root, img_name)) - save_path = f'{path}_{idx:02d}{ext}' + path = os.path.splitext( + os.path.join(save_restore_root, img_name))[0] + save_path = f'{path}_{idx:02d}.png' imwrite(im, save_path) face_helper.add_restored_face(im) diff --git a/test_scripts/test_stylegan2.py b/test_scripts/test_stylegan2.py index 38e5ff2..47bbe47 100644 --- a/test_scripts/test_stylegan2.py +++ b/test_scripts/test_stylegan2.py @@ -30,7 +30,7 @@ def generate(args, g_ema, device, mean_latent, randomize_noise): if __name__ == '__main__': - device = 'cuda' + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') parser = argparse.ArgumentParser() From 6a58f8894b55d08e95948d8d2225c12976a17803 Mon Sep 17 00:00:00 2001 From: Xintao Date: Mon, 26 Oct 2020 22:41:04 +0800 Subject: [PATCH 08/23] github workflow add pull_request --- .github/workflows/pylint.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pylint.yml b/.github/workflows/pylint.yml index dc4f3cf..ee4ebce 100644 --- a/.github/workflows/pylint.yml +++ b/.github/workflows/pylint.yml @@ -1,6 +1,6 @@ name: Python Lint -on: [push] +on: [push, pull_request] jobs: build: From c9d1043ab66653709082d0bbbbb8856948752735 Mon Sep 17 00:00:00 2001 From: zenjieli <44634971+zenjieli@users.noreply.github.com> Date: Mon, 26 Oct 2020 19:06:03 +0100 Subject: [PATCH 09/23] Fix path string in regroup_reds_dataset (#313) * github workflow add pull_request * Fix path string in regroup_reds_dataset Co-authored-by: Xintao Co-authored-by: ZLI --- scripts/regroup_reds_dataset.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git 
a/scripts/regroup_reds_dataset.py b/scripts/regroup_reds_dataset.py index 3ce71fa..7d3ddbf 100644 --- a/scripts/regroup_reds_dataset.py +++ b/scripts/regroup_reds_dataset.py @@ -18,8 +18,9 @@ def regroup_reds_dataset(train_path, val_path): # move the validation data to the train folder val_folders = glob.glob(os.path.join(val_path, '*')) for folder in val_folders: - new_folder_idx = int(folder.split(' / ')[-1]) + 240 - os.system(f'cp -r {folder} {os.path.join(train_path, new_folder_idx)}') + new_folder_idx = int(folder.split('/')[-1]) + 240 + os.system( + f'cp -r {folder} {os.path.join(train_path, str(new_folder_idx))}') if __name__ == '__main__': From df5816f41c9bcc6d80286222e68e67e32974c2b1 Mon Sep 17 00:00:00 2001 From: wenlong Date: Thu, 29 Oct 2020 00:24:19 +0800 Subject: [PATCH 10/23] niqe support nancov (#316) * Update niqe.py Refer to Matlab Help center, remove the row with Nan value. * Update niqe.py * Update niqe.py * Update niqe.py --- basicsr/metrics/niqe.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/basicsr/metrics/niqe.py b/basicsr/metrics/niqe.py index 7447bb4..a16c0d6 100644 --- a/basicsr/metrics/niqe.py +++ b/basicsr/metrics/niqe.py @@ -141,7 +141,9 @@ def niqe(img, # fit a MVG (multivariate Gaussian) model to distorted patch features mu_distparam = np.nanmean(distparam, axis=0) - cov_distparam = np.cov(distparam, rowvar=False) # TODO: use nancov + # use nancov. ref: https://ww2.mathworks.cn/help/stats/nancov.html + distparam_no_nan = distparam[~np.isnan(distparam).any(axis=1)] + cov_distparam = np.cov(distparam_no_nan, rowvar=False) # compute niqe quality, Eq. 10 in the paper invcov_param = np.linalg.pinv((cov_pris_param + cov_distparam) / 2) From 9bcd835862d0b89da0e53f85ad274c2cfb815935 Mon Sep 17 00:00:00 2001 From: Xintao Date: Thu, 29 Oct 2020 03:27:42 +0800 Subject: [PATCH 11/23] add matlab imresize bicubic (#317) --- basicsr/utils/matlab_functions.py | 169 ++++++++++++++++++++++++++++++ 1 file changed, 169 insertions(+) diff --git a/basicsr/utils/matlab_functions.py b/basicsr/utils/matlab_functions.py index ad487f2..cd96a2c 100644 --- a/basicsr/utils/matlab_functions.py +++ b/basicsr/utils/matlab_functions.py @@ -1,4 +1,173 @@ +import math import numpy as np +import torch + + +def cubic(x): + """cubic function used for calculate_weights_indices.""" + absx = torch.abs(x) + absx2 = absx**2 + absx3 = absx**3 + return (1.5 * absx3 - 2.5 * absx2 + 1) * ( + (absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + + 2) * (((absx > 1) * + (absx <= 2)).type_as(absx)) + + +def calculate_weights_indices(in_length, out_length, scale, kernel, + kernel_width, antialiasing): + """Calculate weights and indices, used for imresize function. + + Args: + in_length (int): Input length. + out_length (int): Output length. + scale (float): Scale factor. + kernel_width (int): Kernel width. + antialisaing (bool): Whether to apply anti-aliasing when downsampling. + """ + + if (scale < 1) and antialiasing: + # Use a modified kernel (larger kernel width) to simultaneously + # interpolate and antialias + kernel_width = kernel_width / scale + + # Output-space coordinates + x = torch.linspace(1, out_length, out_length) + + # Input-space coordinates. Calculate the inverse mapping such that 0.5 + # in output space maps to 0.5 in input space, and 0.5 + scale in output + # space maps to 1.5 in input space. + u = x / scale + 0.5 * (1 - 1 / scale) + + # What is the left-most pixel that can be involved in the computation? 
+ left = torch.floor(u - kernel_width / 2) + + # What is the maximum number of pixels that can be involved in the + # computation? Note: it's OK to use an extra pixel here; if the + # corresponding weights are all zero, it will be eliminated at the end + # of this function. + p = math.ceil(kernel_width) + 2 + + # The indices of the input pixels involved in computing the k-th output + # pixel are in row k of the indices matrix. + indices = left.view(out_length, 1).expand(out_length, p) + torch.linspace( + 0, p - 1, p).view(1, p).expand(out_length, p) + + # The weights used to compute the k-th output pixel are in row k of the + # weights matrix. + distance_to_center = u.view(out_length, 1).expand(out_length, p) - indices + + # apply cubic kernel + if (scale < 1) and antialiasing: + weights = scale * cubic(distance_to_center * scale) + else: + weights = cubic(distance_to_center) + + # Normalize the weights matrix so that each row sums to 1. + weights_sum = torch.sum(weights, 1).view(out_length, 1) + weights = weights / weights_sum.expand(out_length, p) + + # If a column in weights is all zero, get rid of it. only consider the + # first and last column. + weights_zero_tmp = torch.sum((weights == 0), 0) + if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6): + indices = indices.narrow(1, 1, p - 2) + weights = weights.narrow(1, 1, p - 2) + if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6): + indices = indices.narrow(1, 0, p - 2) + weights = weights.narrow(1, 0, p - 2) + weights = weights.contiguous() + indices = indices.contiguous() + sym_len_s = -indices.min() + 1 + sym_len_e = indices.max() - in_length + indices = indices + sym_len_s - 1 + return weights, indices, int(sym_len_s), int(sym_len_e) + + +@torch.no_grad() +def imresize(img, scale, antialiasing=True): + """imresize function same as MATLAB. + + It now only supports bicubic. + The same scale applies for both height and width. + + Args: + img (Tensor | Numpy array): + Tensor: Input image with shape (c, h, w), [0, 1] range. + Numpy: Input image with shape (h, w, c), [0, 1] range. + scale (float): Scale factor. The same scale applies for both height + and width. + antialisaing (bool): Whether to apply anti-aliasing when downsampling. + Default: True. + + Returns: + Tensor: Output image with shape (c, h, w), [0, 1] range, w/o round. 
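Editor's note — based on the `imresize` docstring above (the implementation follows below), a minimal usage sketch; the random array stands in for a real image in the [0, 1] range:

```python
import numpy as np

from basicsr.utils.matlab_functions import imresize

# A dummy (h, w, c) image in [0, 1]; MATLAB-style bicubic 4x downscale.
img = np.random.rand(64, 48, 3).astype(np.float32)
out = imresize(img, 1 / 4, antialiasing=True)
print(out.shape)  # (16, 12, 3)
```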
+ """ + if type(img).__module__ == np.__name__: # numpy type + numpy_type = True + img = torch.from_numpy(img.transpose(2, 0, 1)).float() + else: + numpy_type = False + + in_c, in_h, in_w = img.size() + out_h, out_w = math.ceil(in_h * scale), math.ceil(in_w * scale) + kernel_width = 4 + kernel = 'cubic' + + # get weights and indices + weights_h, indices_h, sym_len_hs, sym_len_he = calculate_weights_indices( + in_h, out_h, scale, kernel, kernel_width, antialiasing) + weights_w, indices_w, sym_len_ws, sym_len_we = calculate_weights_indices( + in_w, out_w, scale, kernel, kernel_width, antialiasing) + # process H dimension + # symmetric copying + img_aug = torch.FloatTensor(in_c, in_h + sym_len_hs + sym_len_he, in_w) + img_aug.narrow(1, sym_len_hs, in_h).copy_(img) + + sym_patch = img[:, :sym_len_hs, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, 0, sym_len_hs).copy_(sym_patch_inv) + + sym_patch = img[:, -sym_len_he:, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, sym_len_hs + in_h, sym_len_he).copy_(sym_patch_inv) + + out_1 = torch.FloatTensor(in_c, out_h, in_w) + kernel_width = weights_h.size(1) + for i in range(out_h): + idx = int(indices_h[i][0]) + for j in range(in_c): + out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose( + 0, 1).mv(weights_h[i]) + + # process W dimension + # symmetric copying + out_1_aug = torch.FloatTensor(in_c, out_h, in_w + sym_len_ws + sym_len_we) + out_1_aug.narrow(2, sym_len_ws, in_w).copy_(out_1) + + sym_patch = out_1[:, :, :sym_len_ws] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, 0, sym_len_ws).copy_(sym_patch_inv) + + sym_patch = out_1[:, :, -sym_len_we:] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, sym_len_ws + in_w, sym_len_we).copy_(sym_patch_inv) + + out_2 = torch.FloatTensor(in_c, out_h, out_w) + kernel_width = weights_w.size(1) + for i in range(out_w): + idx = int(indices_w[i][0]) + for j in range(in_c): + out_2[j, :, i] = out_1_aug[j, :, + idx:idx + kernel_width].mv(weights_w[i]) + + if numpy_type: + out_2 = out_2.numpy().transpose(1, 2, 0) + return out_2 def rgb2ycbcr(img, y_only=False): From 77ac51801123d769b97dd71b7b047b4343e26723 Mon Sep 17 00:00:00 2001 From: Xintao Date: Fri, 30 Oct 2020 17:01:02 +0800 Subject: [PATCH 12/23] README add datasets download links (#318) * add dataset download * update readme * update readme * update readme --- README.md | 3 ++- README_CN.md | 3 ++- docs/DatasetPreparation.md | 2 +- docs/DatasetPreparation_CN.md | 2 +- docs/ModelZoo.md | 11 ++++++----- docs/ModelZoo_CN.md | 10 +++++----- 6 files changed, 17 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index 027b579..820af1d 100644 --- a/README.md +++ b/README.md @@ -3,8 +3,9 @@ [English](README.md) **|** [简体中文](README_CN.md)   [GitHub](https://github.com/xinntao/BasicSR) **|** [Gitee码云](https://gitee.com/xinntao/BasicSR) google colab logo Google Colab: [GitHub Link](colab) **|** [Google Drive Link](https://drive.google.com/drive/folders/1G_qcpvkT5ixmw5XoN6MupkOzcK1km625?usp=sharing)
-:arrow_double_down: Google Drive: [Pretrained Models](https://drive.google.com/drive/folders/15DgDtfaLASQ3iAPJEVHQF49g9msexECG?usp=sharing) **|** [Reproduced Experiments](https://drive.google.com/drive/folders/1XN4WXKJ53KQ0Cu0Yv-uCt8DZWq6uufaP?usp=sharing) +:m: [Model Zoo](docs/ModelZoo.md) :arrow_double_down: Google Drive: [Pretrained Models](https://drive.google.com/drive/folders/15DgDtfaLASQ3iAPJEVHQF49g9msexECG?usp=sharing) **|** [Reproduced Experiments](https://drive.google.com/drive/folders/1XN4WXKJ53KQ0Cu0Yv-uCt8DZWq6uufaP?usp=sharing) :arrow_double_down: 百度网盘: [预训练模型](https://pan.baidu.com/s/1R6Nc4v3cl79XPAiK0Toe7g) **|** [复现实验](https://pan.baidu.com/s/1UElD6q8sVAgn_cxeBDOlvQ)
+:file_folder: [Datasets](docs/DatasetPreparation.md) :arrow_double_down: [Google Drive](https://drive.google.com/drive/folders/1gt5eT293esqY0yr1Anbm36EdnxWW_5oH?usp=sharing) :arrow_double_down: [百度网盘](https://pan.baidu.com/s/baidutodo) (提取码:wdix)
:chart_with_upwards_trend: [Training curves in wandb](https://app.wandb.ai/xintao/basicsr)
:computer: [Commands for training and testing](docs/TrainTest.md)
:zap: [HOWTOs](#zap-howtos) diff --git a/README_CN.md b/README_CN.md index c2b8b55..feac57c 100644 --- a/README_CN.md +++ b/README_CN.md @@ -3,8 +3,9 @@ [English](README.md) **|** [简体中文](README_CN.md)   [GitHub](https://github.com/xinntao/BasicSR) **|** [Gitee码云](https://gitee.com/xinntao/BasicSR) google colab logo Google Colab: [GitHub Link](colab) **|** [Google Drive Link](https://drive.google.com/drive/folders/1G_qcpvkT5ixmw5XoN6MupkOzcK1km625?usp=sharing)
-:arrow_double_down: 百度网盘: [预训练模型](https://pan.baidu.com/s/1R6Nc4v3cl79XPAiK0Toe7g) **|** [复现实验](https://pan.baidu.com/s/1UElD6q8sVAgn_cxeBDOlvQ) +:m: [模型库](docs/ModelZoo_CN.md) :arrow_double_down: 百度网盘: [预训练模型](https://pan.baidu.com/s/1R6Nc4v3cl79XPAiK0Toe7g) **|** [复现实验](https://pan.baidu.com/s/1UElD6q8sVAgn_cxeBDOlvQ) :arrow_double_down: Google Drive: [Pretrained Models](https://drive.google.com/drive/folders/15DgDtfaLASQ3iAPJEVHQF49g9msexECG?usp=sharing) **|** [Reproduced Experiments](https://drive.google.com/drive/folders/1XN4WXKJ53KQ0Cu0Yv-uCt8DZWq6uufaP?usp=sharing)
+:file_folder: [数据](docs/DatasetPreparation_CN.md) :arrow_double_down: [百度网盘](https://pan.baidu.com/s/baidutodo) (提取码:wdix) :arrow_double_down: [Google Drive](https://drive.google.com/drive/folders/1gt5eT293esqY0yr1Anbm36EdnxWW_5oH?usp=sharing)
:chart_with_upwards_trend: [wandb的训练曲线](https://app.wandb.ai/xintao/basicsr)
:computer: [训练和测试的命令](docs/TrainTest_CN.md)
:zap: [HOWTOs](#zap-howtos) diff --git a/docs/DatasetPreparation.md b/docs/DatasetPreparation.md index 31a9616..8434cb1 100644 --- a/docs/DatasetPreparation.md +++ b/docs/DatasetPreparation.md @@ -182,7 +182,7 @@ We provide a list of common image super-resolution datasets. Classical SR Training T91 91 images for training - Google Drive / Baidu Drive + Google Drive / Baidu Drive BSDS200 diff --git a/docs/DatasetPreparation_CN.md b/docs/DatasetPreparation_CN.md index 2582422..a71adf2 100644 --- a/docs/DatasetPreparation_CN.md +++ b/docs/DatasetPreparation_CN.md @@ -182,7 +182,7 @@ DIV2K 数据集被广泛使用在图像复原的任务中. Classical SR Training T91 91 images for training - Google Drive / Baidu Drive + Google Drive / Baidu Drive BSDS200 diff --git a/docs/ModelZoo.md b/docs/ModelZoo.md index af6579c..4dd25aa 100644 --- a/docs/ModelZoo.md +++ b/docs/ModelZoo.md @@ -2,6 +2,11 @@ [English](ModelZoo.md) **|** [简体中文](ModelZoo_CN.md) +:arrow_double_down: Google Drive: [Pretrained Models](https://drive.google.com/drive/folders/15DgDtfaLASQ3iAPJEVHQF49g9msexECG?usp=sharing) **|** [Reproduced Experiments](https://drive.google.com/drive/folders/1XN4WXKJ53KQ0Cu0Yv-uCt8DZWq6uufaP?usp=sharing) +:arrow_double_down: 百度网盘: [预训练模型](https://pan.baidu.com/s/1R6Nc4v3cl79XPAiK0Toe7g) **|** [复现实验](https://pan.baidu.com/s/1UElD6q8sVAgn_cxeBDOlvQ) + +--- + We provide: 1. Official models converted directly from official released models @@ -9,7 +14,7 @@ We provide: You can put the downloaded models in the `experiments/pretrained_models` folder. -**[Download official pre-trained models]** ([Google Drive](https://drive.google.com/drive/folders/15DgDtfaLASQ3iAPJEVHQF49g9msexECG?usp=sharing), [百度网盘](https://pan.baidu.com/s/1R6Nc4v3cl79XPAiK0Toe7g))(https://drive.google.com/drive/folders/15DgDtfaLASQ3iAPJEVHQF49g9msexECG?usp=sharing)) +**[Download official pre-trained models]** ([Google Drive](https://drive.google.com/drive/folders/15DgDtfaLASQ3iAPJEVHQF49g9msexECG?usp=sharing), [百度网盘](https://pan.baidu.com/s/1R6Nc4v3cl79XPAiK0Toe7g)) You can use the scrip to download pre-trained models from Google Drive. @@ -93,7 +98,6 @@ EDVR\_(training dataset)\_(track name)\_(model complexity) - **L** (Large): # of channels = 128, # of back residual blocks = 40. This setting is used in our competition submission. - **M** (Moderate): # of channels = 64, # of back residual blocks = 10. -[Download Models from Google Drive](https://drive.google.com/open?id=1WfROVUqKOBS5gGvQzBfU1DNZ4XwPA3LD) | Model name |[Test Set] PSNR/SSIM | |:----------:|:----------:| @@ -107,7 +111,6 @@ EDVR\_(training dataset)\_(track name)\_(model complexity) 1 Y or RGB denotes the evaluation on Y (luminance) or RGB channels. #### Stage 2 models for the NTIRE19 Competition -[Download Models from Google Drive](https://drive.google.com/drive/folders/1PMoy1cKlIYWly6zY0tG2Q4YAH7V_HZns?usp=sharing) | Model name |[Test Set] PSNR/SSIM | |:----------:|:----------:| @@ -119,7 +122,6 @@ EDVR\_(training dataset)\_(track name)\_(model complexity) ## DUF The models are converted from the [officially released models](https://github.com/yhjo09/VSR-DUF).
-[Download Models from Google Drive](https://drive.google.com/open?id=1seY9nclMuwk_SpqKQhx1ItTcQShM5R50) | Model name | [Test Set] PSNR/SSIM1 | Official Results2 | |:----------:|:----------:|:----------:| @@ -136,7 +138,6 @@ The models are converted from the [officially released models](https://github.co ## TOF The models are converted from the [officially released models](https://github.com/anchen1011/toflow).
-[Download Models from Google Drive](https://drive.google.com/open?id=18kJcxPLeNK8e0kYEiwmsnu9wVmhdMFFG) | Model name | [Test Set] PSNR/SSIM | Official Results1 | |:----------:|:----------:|:----------:| diff --git a/docs/ModelZoo_CN.md b/docs/ModelZoo_CN.md index 9290a91..b192ee1 100644 --- a/docs/ModelZoo_CN.md +++ b/docs/ModelZoo_CN.md @@ -2,6 +2,11 @@ [English](ModelZoo.md) **|** [简体中文](ModelZoo_CN.md) +:arrow_double_down: 百度网盘: [预训练模型](https://pan.baidu.com/s/1R6Nc4v3cl79XPAiK0Toe7g) **|** [复现实验](https://pan.baidu.com/s/1UElD6q8sVAgn_cxeBDOlvQ) +:arrow_double_down: Google Drive: [Pretrained Models](https://drive.google.com/drive/folders/15DgDtfaLASQ3iAPJEVHQF49g9msexECG?usp=sharing) **|** [Reproduced Experiments](https://drive.google.com/drive/folders/1XN4WXKJ53KQ0Cu0Yv-uCt8DZWq6uufaP?usp=sharing) + +--- + 我们提供了: 1. 官方的模型, 它们是从官方release的models直接转化过来的 @@ -92,8 +97,6 @@ EDVR\_(training dataset)\_(track name)\_(model complexity) - **L** (Large): # of channels = 128, # of back residual blocks = 40. This setting is used in our competition submission. - **M** (Moderate): # of channels = 64, # of back residual blocks = 10. -[Download Models from Google Drive](https://drive.google.com/open?id=1WfROVUqKOBS5gGvQzBfU1DNZ4XwPA3LD) - | Model name |[Test Set] PSNR/SSIM | |:----------:|:----------:| | EDVR_Vimeo90K_SR_L | [Vid4] (Y1) 27.35/0.8264 [[↓Results]](https://drive.google.com/open?id=14nozpSfe9kC12dVuJ9mspQH5ZqE4mT9K)
(RGB) 25.83/0.8077| @@ -106,7 +109,6 @@ EDVR\_(training dataset)\_(track name)\_(model complexity) 1 Y or RGB denotes the evaluation on Y (luminance) or RGB channels. #### Stage 2 models for the NTIRE19 Competition -[Download Models from Google Drive](https://drive.google.com/drive/folders/1PMoy1cKlIYWly6zY0tG2Q4YAH7V_HZns?usp=sharing) | Model name |[Test Set] PSNR/SSIM | |:----------:|:----------:| @@ -118,7 +120,6 @@ EDVR\_(training dataset)\_(track name)\_(model complexity) ## DUF The models are converted from the [officially released models](https://github.com/yhjo09/VSR-DUF).
-[Download Models from Google Drive](https://drive.google.com/open?id=1seY9nclMuwk_SpqKQhx1ItTcQShM5R50) | Model name | [Test Set] PSNR/SSIM1 | Official Results2 | |:----------:|:----------:|:----------:| @@ -135,7 +136,6 @@ The models are converted from the [officially released models](https://github.co ## TOF The models are converted from the [officially released models](https://github.com/anchen1011/toflow).
-[Download Models from Google Drive](https://drive.google.com/open?id=18kJcxPLeNK8e0kYEiwmsnu9wVmhdMFFG) | Model name | [Test Set] PSNR/SSIM | Official Results1 | |:----------:|:----------:|:----------:| From 8bb7a9dff8558a13adc1bd57c362b17a8ace8ddf Mon Sep 17 00:00:00 2001 From: Xintao Date: Sat, 31 Oct 2020 12:21:11 +0800 Subject: [PATCH 13/23] add citation (#319) --- README.md | 17 +++++++++++++++++ README_CN.md | 17 +++++++++++++++++ 2 files changed, 34 insertions(+) diff --git a/README.md b/README.md index 820af1d..7965c6c 100644 --- a/README.md +++ b/README.md @@ -119,6 +119,23 @@ The figure below shows the overall framework. More descriptions for each compone This project is released under the Apache 2.0 license. More details about license and acknowledgement are in [LICENSE](LICENSE/README.md). +## :earth_asia: Citations + +If BasicSR helps your research or work, please consider citing BasicSR.
+The following is a BibTeX reference. The BibTeX entry requires the `url` LaTeX package. + +``` latex +@misc{wang2020basicsr, + author = {Xintao Wang and Ke Yu and Kelvin C.K. Chan and + Chao Dong and Chen Change Loy}, + title = {BasicSR}, + howpublished = {\url{https://github.com/xinntao/BasicSR}}, + year = {2020} +} +``` + +> Xintao Wang, Ke Yu, Kelvin C.K. Chan, Chao Dong and Chen Change Loy. BasicSR. https://github.com/xinntao/BasicSR, 2020. + ## :e-mail: Contact If you have any question, please email `xintao.wang@outlook.com`. diff --git a/README_CN.md b/README_CN.md index feac57c..058cffe 100644 --- a/README_CN.md +++ b/README_CN.md @@ -118,6 +118,23 @@ BasicSR 是一个基于 PyTorch 的**开源**图像视频超分辨率 (Super-Res 本项目使用 Apache 2.0 license. 更多细节参见 [LICENSE](LICENSE/README.md). +## :earth_asia: 引用 + +如果 BasicSR 对你有所帮助, 可以考虑引用BasicSR.
+下面是一个 BibTex 引用条目, 它需要 `url` LaTeX package. + +``` latex +@misc{wang2020basicsr, + author = {Xintao Wang and Ke Yu and Kelvin C.K. Chan and + Chao Dong and Chen Change Loy}, + title = {BasicSR}, + howpublished = {\url{https://github.com/xinntao/BasicSR}}, + year = {2020} +} +``` + +> Xintao Wang, Ke Yu, Kelvin C.K. Chan, Chao Dong and Chen Change Loy. BasicSR. https://github.com/xinntao/BasicSR, 2020. + ## :e-mail: 联系 若有任何问题, 请电邮 `xintao.wang@outlook.com`. From d37ab0a36e7b9b52072edb721bfb48db4ff264d4 Mon Sep 17 00:00:00 2001 From: Mingyan Zhu Date: Sat, 31 Oct 2020 15:04:39 +0800 Subject: [PATCH 14/23] Fix metrics bug in video_base_model.py (#314) * Update video_base_model.py Fix the bug when metrics is "None" * Update video_base_model.py * Update video_base_model.py Co-authored-by: Xintao --- basicsr/models/video_base_model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/basicsr/models/video_base_model.py b/basicsr/models/video_base_model.py index 203fb7f..c8e8d26 100644 --- a/basicsr/models/video_base_model.py +++ b/basicsr/models/video_base_model.py @@ -34,10 +34,10 @@ def dist_validation(self, dataloader, current_iter, tb_logger, save_img): len(self.opt['val']['metrics']), dtype=torch.float32, device='cuda') - rank, world_size = get_dist_info() - for _, tensor in self.metric_results.items(): - tensor.zero_() + if with_metrics: + for _, tensor in self.metric_results.items(): + tensor.zero_() # record all frames (border and center frames) if rank == 0: pbar = tqdm(total=len(dataset), unit='frame') From 58d75552c1027de843a9d3f4a94987d137fbd522 Mon Sep 17 00:00:00 2001 From: Xintao Date: Sat, 31 Oct 2020 18:05:17 +0800 Subject: [PATCH 15/23] update baidupan link --- README.md | 2 +- README_CN.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 7965c6c..dc94ee7 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ google colab logo Google Colab: [GitHub Link](colab) **|** [Google Drive Link](https://drive.google.com/drive/folders/1G_qcpvkT5ixmw5XoN6MupkOzcK1km625?usp=sharing)
:m: [Model Zoo](docs/ModelZoo.md) :arrow_double_down: Google Drive: [Pretrained Models](https://drive.google.com/drive/folders/15DgDtfaLASQ3iAPJEVHQF49g9msexECG?usp=sharing) **|** [Reproduced Experiments](https://drive.google.com/drive/folders/1XN4WXKJ53KQ0Cu0Yv-uCt8DZWq6uufaP?usp=sharing) :arrow_double_down: 百度网盘: [预训练模型](https://pan.baidu.com/s/1R6Nc4v3cl79XPAiK0Toe7g) **|** [复现实验](https://pan.baidu.com/s/1UElD6q8sVAgn_cxeBDOlvQ)
-:file_folder: [Datasets](docs/DatasetPreparation.md) :arrow_double_down: [Google Drive](https://drive.google.com/drive/folders/1gt5eT293esqY0yr1Anbm36EdnxWW_5oH?usp=sharing) :arrow_double_down: [百度网盘](https://pan.baidu.com/s/baidutodo) (提取码:wdix)
+:file_folder: [Datasets](docs/DatasetPreparation.md) :arrow_double_down: [Google Drive](https://drive.google.com/drive/folders/1gt5eT293esqY0yr1Anbm36EdnxWW_5oH?usp=sharing) :arrow_double_down: [百度网盘](https://pan.baidu.com/s/1AZDcEAFwwc1OC3KCd7EDnQ) (提取码:basr)
:chart_with_upwards_trend: [Training curves in wandb](https://app.wandb.ai/xintao/basicsr)
:computer: [Commands for training and testing](docs/TrainTest.md)
:zap: [HOWTOs](#zap-howtos) diff --git a/README_CN.md b/README_CN.md index 058cffe..3ccd53a 100644 --- a/README_CN.md +++ b/README_CN.md @@ -5,7 +5,7 @@ google colab logo Google Colab: [GitHub Link](colab) **|** [Google Drive Link](https://drive.google.com/drive/folders/1G_qcpvkT5ixmw5XoN6MupkOzcK1km625?usp=sharing)
:m: [模型库](docs/ModelZoo_CN.md) :arrow_double_down: 百度网盘: [预训练模型](https://pan.baidu.com/s/1R6Nc4v3cl79XPAiK0Toe7g) **|** [复现实验](https://pan.baidu.com/s/1UElD6q8sVAgn_cxeBDOlvQ) :arrow_double_down: Google Drive: [Pretrained Models](https://drive.google.com/drive/folders/15DgDtfaLASQ3iAPJEVHQF49g9msexECG?usp=sharing) **|** [Reproduced Experiments](https://drive.google.com/drive/folders/1XN4WXKJ53KQ0Cu0Yv-uCt8DZWq6uufaP?usp=sharing)
-:file_folder: [数据](docs/DatasetPreparation_CN.md) :arrow_double_down: [百度网盘](https://pan.baidu.com/s/baidutodo) (提取码:wdix) :arrow_double_down: [Google Drive](https://drive.google.com/drive/folders/1gt5eT293esqY0yr1Anbm36EdnxWW_5oH?usp=sharing)
+:file_folder: [数据](docs/DatasetPreparation_CN.md) :arrow_double_down: [百度网盘](https://pan.baidu.com/s/1AZDcEAFwwc1OC3KCd7EDnQ) (提取码:basr) :arrow_double_down: [Google Drive](https://drive.google.com/drive/folders/1gt5eT293esqY0yr1Anbm36EdnxWW_5oH?usp=sharing)
:chart_with_upwards_trend: [wandb的训练曲线](https://app.wandb.ai/xintao/basicsr)
:computer: [训练和测试的命令](docs/TrainTest_CN.md)
:zap: [HOWTOs](#zap-howtos) From ee1a026a7ad52bd8e807024e41d70276028c59e2 Mon Sep 17 00:00:00 2001 From: Xintao Date: Sat, 28 Nov 2020 15:07:27 +0800 Subject: [PATCH 16/23] Merge private version (#337) * add paths_from_lmdb * add img_rotate and add return_status to augment * support pretrained networks without paramkey * pytorch < 1.5 does not have init_weights for inception_v3 * add narrow to stylegan2 * tensor2img support gray images * add metrics doc * add correct_mean_var to calculate_psnr_ssim * add make_ffhq_lmdb_from_imgs * add gdrive_download * add only_keep_largest to test_face_dfdnet * add fid lpips calculation --- basicsr/data/__init__.py | 2 +- basicsr/data/single_image_dataset.py | 10 ++- basicsr/data/transforms.py | 30 ++++++++- basicsr/data/util.py | 16 +++++ basicsr/models/archs/inception.py | 20 +++++- basicsr/models/archs/stylegan2_arch.py | 50 ++++++++------ basicsr/models/base_model.py | 7 +- basicsr/utils/img_util.py | 7 +- docs/Metrics.md | 35 ++++++++++ docs/Metrics_CN.md | 36 ++++++++++ scripts/calculate_fid_folder.py | 83 ++++++++++++++++++++++++ scripts/calculate_lpips.py | 56 ++++++++++++++++ scripts/calculate_psnr_ssim.py | 23 +++++++ scripts/extract_images_from_tfrecords.py | 40 ++++++++++++ scripts/gdrive_download.py | 12 ++++ test_scripts/test_face_dfdnet.py | 82 ++++++++++++++++++----- 16 files changed, 458 insertions(+), 51 deletions(-) create mode 100644 docs/Metrics.md create mode 100644 docs/Metrics_CN.md create mode 100644 scripts/calculate_fid_folder.py create mode 100644 scripts/calculate_lpips.py create mode 100644 scripts/gdrive_download.py diff --git a/basicsr/data/__init__.py b/basicsr/data/__init__.py index 8232845..22b0c8b 100644 --- a/basicsr/data/__init__.py +++ b/basicsr/data/__init__.py @@ -98,7 +98,7 @@ def create_dataloader(dataset, seed=seed) if seed is not None else None elif phase in ['val', 'test']: # validation dataloader_args = dict( - dataset=dataset, batch_size=1, shuffle=False, num_workers=1) + dataset=dataset, batch_size=1, shuffle=False, num_workers=0) else: raise ValueError(f'Wrong dataset phase: {phase}. 
' "Supported ones are 'train', 'val' and 'test'.") diff --git a/basicsr/data/single_image_dataset.py b/basicsr/data/single_image_dataset.py index cb1bc01..dede4fc 100644 --- a/basicsr/data/single_image_dataset.py +++ b/basicsr/data/single_image_dataset.py @@ -2,6 +2,7 @@ from torch.utils import data as data from torchvision.transforms.functional import normalize +from basicsr.data.util import paths_from_lmdb from basicsr.utils import FileClient, imfrombytes, img2tensor, scandir @@ -30,7 +31,12 @@ def __init__(self, opt): self.mean = opt['mean'] if 'mean' in opt else None self.std = opt['std'] if 'std' in opt else None self.lq_folder = opt['dataroot_lq'] - if 'meta_info_file' in self.opt: + + if self.io_backend_opt['type'] == 'lmdb': + self.io_backend_opt['db_paths'] = [self.lq_folder] + self.io_backend_opt['client_keys'] = ['lq'] + self.paths = paths_from_lmdb(self.lq_folder) + elif 'meta_info_file' in self.opt: with open(self.opt['meta_info_file'], 'r') as fin: self.paths = [ osp.join(self.lq_folder, @@ -46,7 +52,7 @@ def __getitem__(self, index): # load lq image lq_path = self.paths[index] - img_bytes = self.file_client.get(lq_path) + img_bytes = self.file_client.get(lq_path, 'lq') img_lq = imfrombytes(img_bytes, float32=True) # TODO: color space transform diff --git a/basicsr/data/transforms.py b/basicsr/data/transforms.py index d7d1e7a..b6d04ff 100644 --- a/basicsr/data/transforms.py +++ b/basicsr/data/transforms.py @@ -84,7 +84,7 @@ def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path): return img_gts, img_lqs -def augment(imgs, hflip=True, rotation=True, flows=None): +def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False): """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees). We use vertical flip and transpose for rotation implementation. @@ -98,6 +98,8 @@ def augment(imgs, hflip=True, rotation=True, flows=None): flows (list[ndarray]: Flows to be augmented. If the input is an ndarray, it will be transformed to a list. Dimension is (h, w, 2). Default: None. + return_status (bool): Return the status of flip and rotation. + Default: False. Returns: list[ndarray] | ndarray: Augmented images and flows. If returned @@ -143,4 +145,28 @@ def _augment_flow(flow): flows = flows[0] return imgs, flows else: - return imgs + if return_status: + return imgs, (hflip, vflip, rot90) + else: + return imgs + + +def img_rotate(img, angle, center=None, scale=1.0): + """Rotate image. + + Args: + img (ndarray): Image to be rotated. + angle (float): Rotation angle in degrees. Positive values mean + counter-clockwise rotation. + center (tuple[int]): Rotation center. If the center is None, + initialize it as the center of the image. Default: None. + scale (float): Isotropic scale factor. Default: 1.0. + """ + (h, w) = img.shape[:2] + + if center is None: + center = (w // 2, h // 2) + + matrix = cv2.getRotationMatrix2D(center, angle, scale) + rotated_img = cv2.warpAffine(img, matrix, (w, h)) + return rotated_img diff --git a/basicsr/data/util.py b/basicsr/data/util.py index b4a14e9..975c0c0 100644 --- a/basicsr/data/util.py +++ b/basicsr/data/util.py @@ -262,6 +262,22 @@ def paths_from_folder(folder): return paths +def paths_from_lmdb(folder): + """Generate paths from lmdb. + + Args: + folder (str): Folder path. + + Returns: + list[str]: Returned path list. 
+ """ + if not folder.endswith('.lmdb'): + raise ValueError(f'Folder {folder}folder should in lmdb format.') + with open(osp.join(folder, 'meta_info.txt')) as fin: + paths = [line.split('.')[0] for line in fin] + return paths + + def generate_gaussian_kernel(kernel_size=13, sigma=1.6): """Generate Gaussian kernel used in `duf_downsample`. diff --git a/basicsr/models/archs/inception.py b/basicsr/models/archs/inception.py index d706bab..3efdf62 100644 --- a/basicsr/models/archs/inception.py +++ b/basicsr/models/archs/inception.py @@ -79,7 +79,12 @@ def __init__(self, if use_fid_inception: inception = fid_inception_v3() else: - inception = models.inception_v3(pretrained=True) + try: + inception = models.inception_v3( + pretrained=True, init_weights=False) + except TypeError: + # pytorch < 1.5 does not have init_weights for inception_v3 + inception = models.inception_v3(pretrained=True) # Block 0: input to maxpool1 block0 = [ @@ -163,8 +168,17 @@ def fid_inception_v3(): This method first constructs torchvision's Inception and then patches the necessary parts that are different in the FID Inception model. """ - inception = models.inception_v3( - num_classes=1008, aux_logits=False, pretrained=False) + try: + inception = models.inception_v3( + num_classes=1008, + aux_logits=False, + pretrained=False, + init_weights=False) + except TypeError: + # pytorch < 1.5 does not have init_weights for inception_v3 + inception = models.inception_v3( + num_classes=1008, aux_logits=False, pretrained=False) + inception.Mixed_5b = FIDInceptionA(192, pool_features=32) inception.Mixed_5c = FIDInceptionA(256, pool_features=64) inception.Mixed_5d = FIDInceptionA(288, pool_features=64) diff --git a/basicsr/models/archs/stylegan2_arch.py b/basicsr/models/archs/stylegan2_arch.py index 2b308b4..3d53cff 100644 --- a/basicsr/models/archs/stylegan2_arch.py +++ b/basicsr/models/archs/stylegan2_arch.py @@ -454,6 +454,7 @@ class StyleGAN2Generator(nn.Module): magnitude. A cross production will be applied to extent 1D resample kenrel to 2D resample kernel. Default: (1, 3, 3, 1). lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01. + narrow (float): Narrow ratio for channels. Default: 1.0. """ def __init__(self, @@ -462,7 +463,8 @@ def __init__(self, num_mlp=8, channel_multiplier=2, resample_kernel=(1, 3, 3, 1), - lr_mlp=0.01): + lr_mlp=0.01, + narrow=1): super(StyleGAN2Generator, self).__init__() # Style MLP layers self.num_style_feat = num_style_feat @@ -479,16 +481,17 @@ def __init__(self, self.style_mlp = nn.Sequential(*style_mlp_layers) channels = { - '4': 512, - '8': 512, - '16': 512, - '32': 512, - '64': 256 * channel_multiplier, - '128': 128 * channel_multiplier, - '256': 64 * channel_multiplier, - '512': 32 * channel_multiplier, - '1024': 16 * channel_multiplier + '4': int(512 * narrow), + '8': int(512 * narrow), + '16': int(512 * narrow), + '32': int(512 * narrow), + '64': int(256 * channel_multiplier * narrow), + '128': int(128 * channel_multiplier * narrow), + '256': int(64 * channel_multiplier * narrow), + '512': int(32 * channel_multiplier * narrow), + '1024': int(16 * channel_multiplier * narrow) } + self.channels = channels self.constant_input = ConstantInput(channels['4'], size=4) self.style_conv1 = StyleConv( @@ -840,25 +843,30 @@ class StyleGAN2Discriminator(nn.Module): resample_kernel (list[int]): A list indicating the 1D resample kernel magnitude. A cross production will be applied to extent 1D resample kenrel to 2D resample kernel. Default: (1, 3, 3, 1). 
+ stddev_group (int): For group stddev statistics. Default: 4. + narrow (float): Narrow ratio for channels. Default: 1.0. """ def __init__(self, out_size, channel_multiplier=2, - resample_kernel=(1, 3, 3, 1)): + resample_kernel=(1, 3, 3, 1), + stddev_group=4, + narrow=1): super(StyleGAN2Discriminator, self).__init__() channels = { - '4': 512, - '8': 512, - '16': 512, - '32': 512, - '64': 256 * channel_multiplier, - '128': 128 * channel_multiplier, - '256': 64 * channel_multiplier, - '512': 32 * channel_multiplier, - '1024': 16 * channel_multiplier + '4': int(512 * narrow), + '8': int(512 * narrow), + '16': int(512 * narrow), + '32': int(512 * narrow), + '64': int(256 * channel_multiplier * narrow), + '128': int(128 * channel_multiplier * narrow), + '256': int(64 * channel_multiplier * narrow), + '512': int(32 * channel_multiplier * narrow), + '1024': int(16 * channel_multiplier * narrow) } + log_size = int(math.log(out_size, 2)) conv_body = [ @@ -891,7 +899,7 @@ def __init__(self, lr_mul=1, activation=None), ) - self.stddev_group = 4 + self.stddev_group = stddev_group self.stddev_feat = 1 def forward(self, x): diff --git a/basicsr/models/base_model.py b/basicsr/models/base_model.py index 3bb89b2..f8987bf 100644 --- a/basicsr/models/base_model.py +++ b/basicsr/models/base_model.py @@ -242,14 +242,17 @@ def load_network(self, net, load_path, strict=True, param_key='params'): load_path (str): The path of networks to be loaded. net (nn.Module): Network. strict (bool): Whether strictly loaded. - param_key (str): The parameter key of loaded network. + param_key (str): The parameter key of loaded network. If set to + None, use the root 'path'. Default: 'params'. """ net = self.get_bare_model(net) logger.info( f'Loading {net.__class__.__name__} model from {load_path}.') load_net = torch.load( - load_path, map_location=lambda storage, loc: storage)[param_key] + load_path, map_location=lambda storage, loc: storage) + if param_key is not None: + load_net = load_net[param_key] # remove unnecessary 'module.' for k, v in deepcopy(load_net).items(): if k.startswith('module.'): diff --git a/basicsr/utils/img_util.py b/basicsr/utils/img_util.py index 4096cfd..152be01 100644 --- a/basicsr/utils/img_util.py +++ b/basicsr/utils/img_util.py @@ -78,8 +78,11 @@ def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)): elif n_dim == 3: img_np = _tensor.numpy() img_np = img_np.transpose(1, 2, 0) - if rgb2bgr: - img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) + if img_np.shape[2] == 1: # gray image + img_np = np.squeeze(img_np, axis=2) + else: + if rgb2bgr: + img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) elif n_dim == 2: img_np = _tensor.numpy() else: diff --git a/docs/Metrics.md b/docs/Metrics.md new file mode 100644 index 0000000..c4f0cb1 --- /dev/null +++ b/docs/Metrics.md @@ -0,0 +1,35 @@ +# Metrics + +[English](Metrics.md) **|** [简体中文](Metrics_CN.md) + +## PSNR and SSIM + +## NIQE + +## FID + +> FID measures the similarity between two datasets of images. It was shown to correlate well with human judgement of visual quality and is most often used to evaluate the quality of samples of Generative Adversarial Networks. +> FID is calculated by computing the [Fréchet distance](https://en.wikipedia.org/wiki/Fr%C3%A9chet_distance) between two Gaussians fitted to feature representations of the Inception network. 
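
In formula form, for Gaussians N(mu1, Sigma1) and N(mu2, Sigma2) fitted to the two sets of Inception features, FID = ||mu1 - mu2||^2 + Tr(Sigma1 + Sigma2 - 2 (Sigma1 Sigma2)^(1/2)), where mu and Sigma are typically the mean and covariance of the 2048-dimensional pool3 activations over each dataset. The toolbox's own implementation lives in `basicsr/metrics/fid.py` (used by `scripts/calculate_fid_folder.py`, added in this same patch); the snippet below is only a minimal NumPy/SciPy sketch of that distance for illustration, and it assumes SciPy is available.

```python
# A minimal sketch of the Frechet distance between two Gaussians
# N(mu1, sigma1) and N(mu2, sigma2); an illustration only, not the
# basicsr.metrics.fid implementation. Requires numpy and scipy.
import numpy as np
from scipy import linalg


def frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrtm(sigma1 @ sigma2))."""
    diff = mu1 - mu2
    # matrix square root of the covariance product
    covmean, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False)
    if not np.isfinite(covmean).all():
        # the product can be singular; retry with a small diagonal offset
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset) @ (sigma2 + offset))
    if np.iscomplexobj(covmean):
        # drop the tiny imaginary parts introduced by numerical error
        covmean = covmean.real
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2)
                 - 2 * np.trace(covmean))
```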
+ +References + +- https://github.com/mseitzer/pytorch-fid +- [GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium](https://arxiv.org/abs/1706.08500) +- [Are GANs Created Equal? A Large-Scale Study](https://arxiv.org/abs/1711.10337) + +### Pre-calculated FFHQ inception feature statistics + +Usually, we put the downloaded inception feature statistics in `basicsr/metrics`. + +:arrow_double_down: Google Drive: [metrics data](https://drive.google.com/drive/folders/13cWIQyHX3iNmZRJ5v8v3kdyrT9RBTAi9?usp=sharing) +:arrow_double_down: 百度网盘: [评价指标数据](https://pan.baidu.com/s/10mMKXSEgrC5y7m63W5vbMQ)
+ +| File Name | Dataset | Image Shape | Sample Numbers| +| :------------- | :----------:|:----------:|:----------:| +| inception_FFHQ_256-0948f50d.pth | FFHQ | 256 x 256 | 50,000 | +| inception_FFHQ_512-f7b384ab.pth | FFHQ | 512 x 512 | 50,000 | +| inception_FFHQ_1024-75f195dc.pth | FFHQ | 1024 x 1024 | 50,000 | +| inception_FFHQ_256_stylegan2_pytorch-abba9d31.pth | FFHQ | 256 x 256 | 50,000 | + +- All the FFHQ inception feature statistics calculated on the resized 299 x 299 size. +- `inception_FFHQ_256_stylegan2_pytorch-abba9d31.pth` is converted from the statistics in [stylegan2-pytorch](https://github.com/rosinality/stylegan2-pytorch). diff --git a/docs/Metrics_CN.md b/docs/Metrics_CN.md new file mode 100644 index 0000000..c5f518c --- /dev/null +++ b/docs/Metrics_CN.md @@ -0,0 +1,36 @@ +# 评价指标 + +[English](Metrics.md) **|** [简体中文](Metrics_CN.md) + +## PSNR and SSIM + +## NIQE + +## FID + +> FID measures the similarity between two datasets of images. It was shown to correlate well with human judgement of visual quality and is most often used to evaluate the quality of samples of Generative Adversarial Networks. +> FID is calculated by computing the [Fréchet distance](https://en.wikipedia.org/wiki/Fr%C3%A9chet_distance) between two Gaussians fitted to feature representations of the Inception network. + +参考 + +- https://github.com/mseitzer/pytorch-fid +- [GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium](https://arxiv.org/abs/1706.08500) +- [Are GANs Created Equal? A Large-Scale Study](https://arxiv.org/abs/1711.10337) + +### Pre-calculated FFHQ inception feature statistics + +通常, 我们把下载的 inception 网络的特征统计数据 (用于计算FID) 放在 `basicsr/metrics`. + + +:arrow_double_down: 百度网盘: [评价指标数据](https://pan.baidu.com/s/10mMKXSEgrC5y7m63W5vbMQ) +:arrow_double_down: Google Drive: [metrics data](https://drive.google.com/drive/folders/13cWIQyHX3iNmZRJ5v8v3kdyrT9RBTAi9?usp=sharing)
+ +| File Name | Dataset | Image Shape | Sample Numbers| +| :------------- | :----------:|:----------:|:----------:| +| inception_FFHQ_256-0948f50d.pth | FFHQ | 256 x 256 | 50,000 | +| inception_FFHQ_512-f7b384ab.pth | FFHQ | 512 x 512 | 50,000 | +| inception_FFHQ_1024-75f195dc.pth | FFHQ | 1024 x 1024 | 50,000 | +| inception_FFHQ_256_stylegan2_pytorch-abba9d31.pth | FFHQ | 256 x 256 | 50,000 | + +- All the FFHQ inception feature statistics calculated on the resized 299 x 299 size. +- `inception_FFHQ_256_stylegan2_pytorch-abba9d31.pth` is converted from the statistics in [stylegan2-pytorch](https://github.com/rosinality/stylegan2-pytorch). diff --git a/scripts/calculate_fid_folder.py b/scripts/calculate_fid_folder.py new file mode 100644 index 0000000..b903160 --- /dev/null +++ b/scripts/calculate_fid_folder.py @@ -0,0 +1,83 @@ +import argparse +import math +import numpy as np +import torch +from torch.utils.data import DataLoader + +from basicsr.data import create_dataset +from basicsr.metrics.fid import (calculate_fid, extract_inception_features, + load_patched_inception_v3) + + +def calculate_fid_folder(): + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + parser = argparse.ArgumentParser() + parser.add_argument('folder', type=str, help='Path to the folder.') + parser.add_argument( + '--fid_stats', type=str, help='Path to the dataset fid statistics.') + parser.add_argument('--batch_size', type=int, default=64) + parser.add_argument('--num_sample', type=int, default=50000) + parser.add_argument('--num_workers', type=int, default=4) + parser.add_argument( + '--backend', + type=str, + default='disk', + help='io backend for dataset. Option: disk, lmdb') + args = parser.parse_args() + + # inception model + inception = load_patched_inception_v3(device) + + # create dataset + opt = {} + opt['name'] = 'SingleImageDataset' + opt['type'] = 'SingleImageDataset' + opt['dataroot_lq'] = args.folder + opt['io_backend'] = dict(type=args.backend) + opt['mean'] = [0.5, 0.5, 0.5] + opt['std'] = [0.5, 0.5, 0.5] + dataset = create_dataset(opt) + + # create dataloader + data_loader = DataLoader( + dataset=dataset, + batch_size=args.batch_size, + shuffle=False, + num_workers=args.num_workers, + sampler=None, + drop_last=False) + args.num_sample = min(args.num_sample, len(dataset)) + total_batch = math.ceil(args.num_sample / args.batch_size) + + def data_generator(data_loader, total_batch): + for idx, data in enumerate(data_loader): + if idx >= total_batch: + break + else: + yield data['lq'] + + features = extract_inception_features( + data_generator(data_loader, total_batch), inception, total_batch, + device) + features = features.numpy() + total_len = features.shape[0] + features = features[:args.num_sample] + print(f'Extracted {total_len} features, ' + f'use the first {features.shape[0]} features to calculate stats.') + + sample_mean = np.mean(features, 0) + sample_cov = np.cov(features, rowvar=False) + + # load the dataset stats + stats = torch.load(args.fid_stats) + real_mean = stats['mean'] + real_cov = stats['cov'] + + # calculate FID metric + fid = calculate_fid(sample_mean, sample_cov, real_mean, real_cov) + print('fid:', fid) + + +if __name__ == '__main__': + calculate_fid_folder() diff --git a/scripts/calculate_lpips.py b/scripts/calculate_lpips.py new file mode 100644 index 0000000..d9fbd3c --- /dev/null +++ b/scripts/calculate_lpips.py @@ -0,0 +1,56 @@ +import cv2 +import glob +import numpy as np +import os.path as osp +from torchvision.transforms.functional import 
normalize + +from basicsr.utils import img2tensor + +try: + import lpips +except ImportError: + print('Please install lpips: pip install lpips') + + +def main(): + # Configurations + # ------------------------------------------------------------------------- + folder_gt = 'datasets/celeba/celeba_512_validation' + folder_restored = 'datasets/celeba/celeba_512_validation_lq' + # crop_border = 4 + suffix = '' + # ------------------------------------------------------------------------- + loss_fn_vgg = lpips.LPIPS(net='vgg').cuda() # RGB, normalized to [-1,1] + lpips_all = [] + img_list = sorted(glob.glob(osp.join(folder_gt, '*'))) + + mean = [0.5, 0.5, 0.5] + std = [0.5, 0.5, 0.5] + for i, img_path in enumerate(img_list): + basename, ext = osp.splitext(osp.basename(img_path)) + img_gt = cv2.imread(img_path, cv2.IMREAD_UNCHANGED).astype( + np.float32) / 255. + img_restored = cv2.imread( + osp.join(folder_restored, basename + suffix + ext), + cv2.IMREAD_UNCHANGED).astype(np.float32) / 255. + + img_gt, img_restored = img2tensor([img_gt, img_restored], + bgr2rgb=True, + float32=True) + # norm to [-1, 1] + normalize(img_gt, mean, std, inplace=True) + normalize(img_restored, mean, std, inplace=True) + + # calculate lpips + lpips_val = loss_fn_vgg( + img_restored.unsqueeze(0).cuda(), + img_gt.unsqueeze(0).cuda()) + + print(f'{i+1:3d}: {basename:25}. \tLPIPS: {lpips_val:.6f}.') + lpips_all.append(lpips_val) + + print(f'Average: LPIPS: {sum(lpips_all) / len(lpips_all):.6f}') + + +if __name__ == '__main__': + main() diff --git a/scripts/calculate_psnr_ssim.py b/scripts/calculate_psnr_ssim.py index 9aff9c9..f73ccde 100644 --- a/scripts/calculate_psnr_ssim.py +++ b/scripts/calculate_psnr_ssim.py @@ -25,6 +25,7 @@ def main(): crop_border = 4 suffix = '_expname' test_y_channel = False + correct_mean_var = True # ------------------------------------------------------------------------- psnr_all = [] @@ -44,6 +45,26 @@ def main(): osp.join(folder_restored, basename + suffix + ext), cv2.IMREAD_UNCHANGED).astype(np.float32) / 255. + if correct_mean_var: + mean_l = [] + std_l = [] + for j in range(3): + mean_l.append(np.mean(img_gt[:, :, j])) + std_l.append(np.std(img_gt[:, :, j])) + for j in range(3): + # correct twice + mean = np.mean(img_restored[:, :, j]) + img_restored[:, :, + j] = img_restored[:, :, j] - mean + mean_l[j] + std = np.std(img_restored[:, :, j]) + img_restored[:, :, j] = img_restored[:, :, j] / std * std_l[j] + + mean = np.mean(img_restored[:, :, j]) + img_restored[:, :, + j] = img_restored[:, :, j] - mean + mean_l[j] + std = np.std(img_restored[:, :, j]) + img_restored[:, :, j] = img_restored[:, :, j] / std * std_l[j] + if test_y_channel and img_gt.ndim == 3 and img_gt.shape[2] == 3: img_gt = bgr2ycbcr(img_gt, y_only=True) img_restored = bgr2ycbcr(img_restored, y_only=True) @@ -63,6 +84,8 @@ def main(): f'\tSSIM: {ssim:.6f}') psnr_all.append(psnr) ssim_all.append(ssim) + print(folder_gt) + print(folder_restored) print(f'Average: PSNR: {sum(psnr_all) / len(psnr_all):.6f} dB, ' f'SSIM: {sum(ssim_all) / len(ssim_all):.6f}') diff --git a/scripts/extract_images_from_tfrecords.py b/scripts/extract_images_from_tfrecords.py index e1311cc..8e0706b 100644 --- a/scripts/extract_images_from_tfrecords.py +++ b/scripts/extract_images_from_tfrecords.py @@ -128,6 +128,46 @@ def convert_ffhq_tfrecords(tf_file, lmdb_maker.close() +def make_ffhq_lmdb_from_imgs(folder_path, + log_resolution, + save_root, + save_type='lmdb', + compress_level=1): + """Make FFHQ lmdb from images. 
+ + Args: + folder_path (str): Folder path. + log_resolution (int): Log scale of resolution. + save_root (str): Path root to save. + save_type (str): Save type. Options: img | lmdb. Default: img. + compress_level (int): Compress level when encoding images. Default: 1. + """ + + if save_type == 'lmdb': + save_path = os.path.join(save_root, + f'ffhq_{2**log_resolution}_crop1.2.lmdb') + lmdb_maker = LmdbMaker(save_path) + else: + raise ValueError('Wrong save type.') + + os.makedirs(save_path, exist_ok=True) + + img_list = sorted(glob.glob(os.path.join(folder_path, '*'))) + for idx, img_path in enumerate(img_list): + print(f'Processing {idx}: ', img_path) + img = cv2.imread(img_path) + h, w, c = img.shape + + if save_type == 'lmdb': + _, img_byte = cv2.imencode( + '.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level]) + key = f'{idx:08d}/r{log_resolution:02d}' + lmdb_maker.put(img_byte, key, (h, w, c)) + + if save_type == 'lmdb': + lmdb_maker.close() + + if __name__ == '__main__': """Read tfrecords w/o define a graph. diff --git a/scripts/gdrive_download.py b/scripts/gdrive_download.py new file mode 100644 index 0000000..e67cad9 --- /dev/null +++ b/scripts/gdrive_download.py @@ -0,0 +1,12 @@ +import argparse + +from basicsr.utils.download import download_file_from_google_drive + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + + parser.add_argument('--id', type=str, help='File id') + parser.add_argument('--output', type=str, help='Save path') + args = parser.parse_args() + + download_file_from_google_drive(args.id, args.save_path) diff --git a/test_scripts/test_face_dfdnet.py b/test_scripts/test_face_dfdnet.py index a3cdb27..331e731 100644 --- a/test_scripts/test_face_dfdnet.py +++ b/test_scripts/test_face_dfdnet.py @@ -53,7 +53,10 @@ def read_input_image(self, img_path): # self.input_img is Numpy array, (h, w, c) with RGB order self.input_img = dlib.load_rgb_image(img_path) - def detect_faces(self, img_path, upsample_num_times=1): + def detect_faces(self, + img_path, + upsample_num_times=1, + only_keep_largest=False): """ Args: img_path (str): Image path. @@ -64,9 +67,23 @@ def detect_faces(self, img_path, upsample_num_times=1): int: Number of detected faces. """ self.read_input_image(img_path) - self.det_faces = self.face_detector(self.input_img, upsample_num_times) - if len(self.det_faces) == 0: + det_faces = self.face_detector(self.input_img, upsample_num_times) + if len(det_faces) == 0: print('No face detected. Try to increase upsample_num_times.') + else: + if only_keep_largest: + print('Detect several faces and only keep the largest.') + face_areas = [] + for i in range(len(det_faces)): + face_area = (det_faces[i].rect.right() - + det_faces[i].rect.left()) * ( + det_faces[i].rect.bottom() - + det_faces[i].rect.top()) + face_areas.append(face_area) + largest_idx = face_areas.index(max(face_areas)) + self.det_faces = [det_faces[largest_idx]] + else: + self.det_faces = det_faces return len(self.det_faces) def get_face_landmarks_5(self): @@ -88,14 +105,28 @@ def get_face_landmarks_68(self): if len(det_face) == 0: print(f'Cannot find faces in cropped image with index {idx}.') self.all_landmarks_68.append(None) - elif len(det_face) == 1: - shape = self.shape_predictor_68(face, det_face[0].rect) + else: + if len(det_face) > 1: + print('Detect several faces in the cropped face. Use the ' + ' largest one. 
Note that it will also cause overlap ' + 'during paste_faces_to_input_image.') + face_areas = [] + for i in range(len(det_face)): + face_area = (det_face[i].rect.right() - + det_face[i].rect.left()) * ( + det_face[i].rect.bottom() - + det_face[i].rect.top()) + face_areas.append(face_area) + largest_idx = face_areas.index(max(face_areas)) + face_rect = det_face[largest_idx].rect + else: + face_rect = det_face[0].rect + shape = self.shape_predictor_68(face, face_rect) landmark = np.array([[part.x, part.y] for part in shape.parts()]) self.all_landmarks_68.append(landmark) num_detected_face += 1 - else: - print('Should only have one face at most.') + return num_detected_face def warp_crop_faces(self, @@ -146,6 +177,8 @@ def paste_faces_to_input_image(self, save_path): h_up, w_up = h * self.upscale_factor, w * self.upscale_factor # simply resize the background upsample_img = cv2.resize(input_img, (w_up, h_up)) + assert len(self.restored_faces) == len(self.inverse_affine_matrices), ( + 'length of restored_faces and affine_matrices are different.') for restored_face, inverse_affine in zip(self.restored_faces, self.inverse_affine_matrices): inv_restored = cv2.warpAffine(restored_face, inverse_affine, @@ -252,6 +285,7 @@ def get_part_location(landmarks): parser.add_argument('--test_path', type=str, default='datasets/TestWhole') parser.add_argument('--upsample_num_times', type=int, default=1) parser.add_argument('--save_inverse_affine', action='store_true') + parser.add_argument('--only_keep_largest', action='store_true') # The official codes use skimage.io to read the cropped images from disk # instead of directly using the intermediate results in the memory (as we # do). Such a different operation brings slight differences due to @@ -286,7 +320,9 @@ def get_part_location(landmarks): ) args = parser.parse_args() - result_root = f'results/DFDNet/{args.test_path.split("/")[-1]}' + if args.test_path.endswith('/'): # solve when path ends with / + args.test_path = args.test_path[:-1] + result_root = f'results/DFDNet/{os.path.basename(args.test_path)}' # set up the DFDNet net = DFDNet(64, dict_path=args.dict_path).to(device) @@ -320,7 +356,9 @@ def get_part_location(landmarks): args.landmark68_path) # detect faces num_det_faces = face_helper.detect_faces( - img_path, upsample_num_times=args.upsample_num_times) + img_path, + upsample_num_times=args.upsample_num_times, + only_keep_largest=args.only_keep_largest) # get 5 face landmarks for each face num_landmarks = face_helper.get_face_landmarks_5() print(f'\tDetect {num_det_faces} faces, {num_landmarks} landmarks.') @@ -342,10 +380,13 @@ def get_part_location(landmarks): print('\tFace restoration ...') # face restoration for each cropped face + assert len(cropped_faces) == len(face_helper.all_landmarks_68) for idx, (cropped_face, landmarks) in enumerate( zip(cropped_faces, face_helper.all_landmarks_68)): if landmarks is None: print(f'Landmarks is None, skip cropped faces with idx {idx}.') + # just copy the cropped faces to the restored faces + restored_face = cropped_face else: # prepare data part_locations = get_part_location(landmarks) @@ -355,16 +396,21 @@ def get_part_location(landmarks): cropped_face) cropped_face = cropped_face.unsqueeze(0).to(device) - with torch.no_grad(): - output = net(cropped_face, part_locations) - im = tensor2img(output, min_max=(-1, 1)) + try: + with torch.no_grad(): + output = net(cropped_face, part_locations) + restored_face = tensor2img(output, min_max=(-1, 1)) del output - torch.cuda.empty_cache() - path = 
os.path.splitext( - os.path.join(save_restore_root, img_name))[0] - save_path = f'{path}_{idx:02d}.png' - imwrite(im, save_path) - face_helper.add_restored_face(im) + torch.cuda.empty_cache() + except Exception as e: + print(f'DFDNet inference fail: {e}') + restored_face = tensor2img(cropped_face, min_max=(-1, 1)) + + path = os.path.splitext(os.path.join(save_restore_root, + img_name))[0] + save_path = f'{path}_{idx:02d}.png' + imwrite(restored_face, save_path) + face_helper.add_restored_face(restored_face) print('\tGenerate the final result ...') # paste each restored face to the input image From 192b1a9f1078d998d7e70d37526822cc39ff13da Mon Sep 17 00:00:00 2001 From: Xintao Date: Sat, 28 Nov 2020 21:25:23 +0800 Subject: [PATCH 17/23] re-organize codes (#338) * re organize * remove requirement: matplotlib * update readme * update readme_cn * update vgg scale_norm to range_norm * dfdnet VGGFaceFeatureExtractor -> VGG Extractor * update psnr ssim * remove crawler_util * update * change name * rename * update readme * update scripts --- ...gitee-repo-mirror.yml => gitee-mirror.yml} | 0 .github/workflows/pylint.yml | 4 +-- README.md | 25 +++++++++---- README_CN.md | 23 ++++++++---- basicsr/data/{util.py => data_util.py} | 0 basicsr/data/paired_image_dataset.py | 6 ++-- basicsr/data/single_image_dataset.py | 2 +- basicsr/data/video_test_dataset.py | 32 ++++++++--------- basicsr/metrics/metric_util.py | 1 - basicsr/metrics/psnr_ssim.py | 4 +++ basicsr/models/archs/dfdnet_arch.py | 33 ++++------------- basicsr/models/archs/vgg_arch.py | 7 +++- .../losses/{loss_utils.py => loss_util.py} | 0 basicsr/models/losses/losses.py | 20 ++++------- basicsr/utils/__init__.py | 4 +-- basicsr/utils/crawler_util.py | 35 ------------------- .../utils/{download.py => download_util.py} | 2 +- basicsr/utils/{lmdb.py => lmdb_util.py} | 0 basicsr/utils/{util.py => misc.py} | 0 docs/DatasetPreparation.md | 18 +++++----- docs/DatasetPreparation_CN.md | 18 +++++----- docs/HOWTOs.md | 10 +++--- docs/HOWTOs_CN.md | 6 ++-- .../inference_dfdnet.py | 0 .../inference_esrgan.py | 0 .../inference_stylegan2.py | 0 make.sh | 7 ---- .../train/EDVR/train_EDVRM_woTSA_GAN_TODO.yml | 2 +- options/train/ESRGAN/train_ESRGAN_x4.yml | 2 +- .../train/SRResNet_SRGAN/train_MSRGAN_x4.yml | 2 +- requirements.txt | 1 - scripts/{ => data_preparation}/create_lmdb.py | 25 ++++++++++--- .../download_datasets.py | 2 +- .../extract_images_from_tfrecords.py | 4 +-- .../extract_subimages.py | 2 +- .../generate_meta_info.py | 0 .../regroup_reds_dataset.py | 0 ...{gdrive_download.py => download_gdrive.py} | 2 +- scripts/download_pretrained_models.py | 2 +- scripts/{ => metrics}/calculate_fid_folder.py | 0 .../calculate_fid_stats_from_datasets.py | 0 scripts/{ => metrics}/calculate_lpips.py | 0 scripts/{ => metrics}/calculate_psnr_ssim.py | 2 +- .../{ => metrics}/calculate_stylegan2_fid.py | 0 .../{ => model_conversion}/convert_dfdnet.py | 0 .../{ => model_conversion}/convert_models.py | 0 .../convert_stylegan.py | 0 setup.cfg | 2 +- tests/test_lr_scheduler.py | 10 ++++-- 49 files changed, 146 insertions(+), 169 deletions(-) rename .github/workflows/{gitee-repo-mirror.yml => gitee-mirror.yml} (100%) rename basicsr/data/{util.py => data_util.py} (100%) rename basicsr/models/losses/{loss_utils.py => loss_util.py} (100%) delete mode 100644 basicsr/utils/crawler_util.py rename basicsr/utils/{download.py => download_util.py} (98%) rename basicsr/utils/{lmdb.py => lmdb_util.py} (100%) rename basicsr/utils/{util.py => misc.py} (100%) rename 
test_scripts/test_face_dfdnet.py => inference/inference_dfdnet.py (100%) rename test_scripts/test_esrgan.py => inference/inference_esrgan.py (100%) rename test_scripts/test_stylegan2.py => inference/inference_stylegan2.py (100%) delete mode 100644 make.sh rename scripts/{ => data_preparation}/create_lmdb.py (89%) rename scripts/{ => data_preparation}/download_datasets.py (97%) rename scripts/{ => data_preparation}/extract_images_from_tfrecords.py (98%) rename scripts/{ => data_preparation}/extract_subimages.py (99%) rename scripts/{ => data_preparation}/generate_meta_info.py (100%) rename scripts/{ => data_preparation}/regroup_reds_dataset.py (100%) rename scripts/{gdrive_download.py => download_gdrive.py} (80%) rename scripts/{ => metrics}/calculate_fid_folder.py (100%) rename scripts/{ => metrics}/calculate_fid_stats_from_datasets.py (100%) rename scripts/{ => metrics}/calculate_lpips.py (100%) rename scripts/{ => metrics}/calculate_psnr_ssim.py (99%) rename scripts/{ => metrics}/calculate_stylegan2_fid.py (100%) rename scripts/{ => model_conversion}/convert_dfdnet.py (100%) rename scripts/{ => model_conversion}/convert_models.py (100%) rename scripts/{ => model_conversion}/convert_stylegan.py (100%) diff --git a/.github/workflows/gitee-repo-mirror.yml b/.github/workflows/gitee-mirror.yml similarity index 100% rename from .github/workflows/gitee-repo-mirror.yml rename to .github/workflows/gitee-mirror.yml diff --git a/.github/workflows/pylint.yml b/.github/workflows/pylint.yml index ee4ebce..0b61a71 100644 --- a/.github/workflows/pylint.yml +++ b/.github/workflows/pylint.yml @@ -25,5 +25,5 @@ jobs: - name: Lint run: | flake8 . - isort --check-only --diff basicsr/ options/ scripts/ tests/ setup.py - yapf -r -d basicsr/ options/ scripts/ tests/ setup.py + isort --check-only --diff basicsr/ options/ scripts/ tests/ inference/ setup.py + yapf -r -d basicsr/ options/ scripts/ tests/ inference/ setup.py diff --git a/README.md b/README.md index dc94ee7..684bcb1 100644 --- a/README.md +++ b/README.md @@ -12,10 +12,11 @@ --- -BasicSR is an **open source** image and video super-resolution toolbox based on PyTorch (will extend to more restoration tasks in the future).
+BasicSR (**Basic** **S**uper **R**estoration) is an open source **image and video restoration** toolbox based on PyTorch, covering tasks such as super-resolution, denoising, deblurring, JPEG artifacts removal, *etc*.<br>
([ESRGAN](https://github.com/xinntao/ESRGAN), [EDVR](https://github.com/xinntao/EDVR), [DNI](https://github.com/xinntao/DNI), [SFTGAN](https://github.com/xinntao/SFTGAN)) +([HandyView](https://github.com/xinntao/HandyView), [HandyFigure](https://github.com/xinntao/HandyFigure), [HandyCrawler](https://github.com/xinntao/HandyCrawler), [HandyWriting](https://github.com/xinntao/HandyWriting)) -## :sparkles: New Feature +## :sparkles: New Features - Sep 8, 2020. Add **blind face restoration inference codes: [DFDNet](https://github.com/csxmli2016/DFDNet)**. Note that it is slightly different from the official testing codes. > ECCV20: Blind Face Restoration via Deep Multi-scale Component Dictionaries
@@ -36,9 +37,10 @@ BasicSR is an **open source** image and video super-resolution toolbox based on We provides simple pipelines to train/test/inference models for quick start. These pipelines/commands cannot cover all the cases and more details are in the following sections. -- [How to train StyleGAN2](docs/HOWTOs.md#How-to-train-StyleGAN2) -- [How to test StyleGAN2](docs/HOWTOs.md#How-to-test-StyleGAN2) -- [How to test DFDNet](docs/HOWTOs.md#How-to-test-DFDNet) +| | | | +| :--- | :---: | :---: | +| StyleGAN2 | [How to train StyleGAN2](docs/HOWTOs.md#How-to-train-StyleGAN2) | [How to inference StyleGAN2](docs/HOWTOs.md#How-to-inference-StyleGAN2) | +| DFDNet | *TODO* | [How to inference DFDNet](docs/HOWTOs.md#How-to-inference-DFDNet) | ## :wrench: Dependencies and Installation @@ -78,6 +80,15 @@ These pipelines/commands cannot cover all the cases and more details are in the python setup.py develop ``` + You may also want to specify the CUDA paths: + + ```bash + CUDA_HOME=/usr/local/cuda \ + CUDNN_INCLUDE_DIR=/usr/local/cuda \ + CUDNN_LIB_DIR=/usr/local/cuda \ + python setup.py develop + ``` + Note that BasicSR is only tested in Ubuntu, and may be not suitable for Windows. You may try [Windows WSL with CUDA supports](https://docs.microsoft.com/en-us/windows/win32/direct3d12/gpu-cuda-in-wsl) :-) (It is now only available for insider build with Fast ring). ## :hourglass_flowing_sand: TODO List @@ -116,8 +127,8 @@ The figure below shows the overall framework. More descriptions for each compone ## :scroll: License and Acknowledgement -This project is released under the Apache 2.0 license. -More details about license and acknowledgement are in [LICENSE](LICENSE/README.md). +This project is released under the Apache 2.0 license.
+More details about **license** and **acknowledgement** are in [LICENSE](LICENSE/README.md). ## :earth_asia: Citations diff --git a/README_CN.md b/README_CN.md index 3ccd53a..57ff71b 100644 --- a/README_CN.md +++ b/README_CN.md @@ -12,8 +12,9 @@ --- -BasicSR 是一个基于 PyTorch 的**开源**图像视频超分辨率 (Super-Resolution) 工具箱 (之后会支持更多的 Restoration 任务).
+BasicSR (**Basic** **S**uper **R**estoration) 是一个基于 PyTorch 的开源图像视频复原工具箱, 比如 超分辨率, 去噪, 去模糊, 去 JPEG 压缩噪声等.
([ESRGAN](https://github.com/xinntao/ESRGAN), [EDVR](https://github.com/xinntao/EDVR), [DNI](https://github.com/xinntao/DNI), [SFTGAN](https://github.com/xinntao/SFTGAN)) +([HandyView](https://gitee.com/xinntao/HandyView), [HandyFigure](https://gitee.com/xinntao/HandyFigure), [HandyCrawler](https://gitee.com/xinntao/HandyCrawler), [HandyWriting](https://gitee.com/xinntao/HandyWriting)) ## :sparkles: 新的特性 @@ -35,9 +36,10 @@ BasicSR 是一个基于 PyTorch 的**开源**图像视频超分辨率 (Super-Res 我们提供了简单的流程来快速上手 训练/测试/推理 模型. 这些命令并不能涵盖所有用法, 更多的细节参见下面的部分. -- [如何训练 StyleGAN2](docs/HOWTOs_CN.md#如何训练-StyleGAN2) -- [如何测试 StyleGAN2](docs/HOWTOs_CN.md#如何测试-StyleGAN2) -- [如何测试 DFDNet](docs/HOWTOs_CN.md#如何测试-DFDNet) +| | | | +| :--- | :---: | :---: | +| StyleGAN2 | [如何训练 StyleGAN2](docs/HOWTOs_CN.md#如何训练-StyleGAN2) | [如何测试 StyleGAN2](docs/HOWTOs_CN.md#如何测试-StyleGAN2) | +| DFDNet | - | [如何测试 DFDNet](docs/HOWTOs_CN.md#如何测试-DFDNet) | ## :wrench: 依赖和安装 @@ -77,6 +79,15 @@ BasicSR 是一个基于 PyTorch 的**开源**图像视频超分辨率 (Super-Res python setup.py develop ``` + 你或许需要指定 CUDA 路径: + + ```bash + CUDA_HOME=/usr/local/cuda \ + CUDNN_INCLUDE_DIR=/usr/local/cuda \ + CUDNN_LIB_DIR=/usr/local/cuda \ + python setup.py develop + ``` + 注意: BasicSR 仅在 Ubuntu 下进行测试,或许不支持Windows. 可以在Windows下尝试[支持CUDA的Windows WSL](https://docs.microsoft.com/en-us/windows/win32/direct3d12/gpu-cuda-in-wsl) :-) (目前只有Fast ring的预览版系统可以安装). ## :hourglass_flowing_sand: TODO 清单 @@ -115,8 +126,8 @@ BasicSR 是一个基于 PyTorch 的**开源**图像视频超分辨率 (Super-Res ## :scroll: 许可 -本项目使用 Apache 2.0 license. -更多细节参见 [LICENSE](LICENSE/README.md). +本项目使用 Apache 2.0 license.
+更多关于**许可**和**致谢**, 请参见 [LICENSE](LICENSE/README.md). ## :earth_asia: 引用 diff --git a/basicsr/data/util.py b/basicsr/data/data_util.py similarity index 100% rename from basicsr/data/util.py rename to basicsr/data/data_util.py diff --git a/basicsr/data/paired_image_dataset.py b/basicsr/data/paired_image_dataset.py index 7775d0b..66c042f 100644 --- a/basicsr/data/paired_image_dataset.py +++ b/basicsr/data/paired_image_dataset.py @@ -1,10 +1,10 @@ from torch.utils import data as data from torchvision.transforms.functional import normalize +from basicsr.data.data_util import (paired_paths_from_folder, + paired_paths_from_lmdb, + paired_paths_from_meta_info_file) from basicsr.data.transforms import augment, paired_random_crop -from basicsr.data.util import (paired_paths_from_folder, - paired_paths_from_lmdb, - paired_paths_from_meta_info_file) from basicsr.utils import FileClient, imfrombytes, img2tensor diff --git a/basicsr/data/single_image_dataset.py b/basicsr/data/single_image_dataset.py index dede4fc..b752b00 100644 --- a/basicsr/data/single_image_dataset.py +++ b/basicsr/data/single_image_dataset.py @@ -2,7 +2,7 @@ from torch.utils import data as data from torchvision.transforms.functional import normalize -from basicsr.data.util import paths_from_lmdb +from basicsr.data.data_util import paths_from_lmdb from basicsr.utils import FileClient, imfrombytes, img2tensor, scandir diff --git a/basicsr/data/video_test_dataset.py b/basicsr/data/video_test_dataset.py index d3d21f9..01b876a 100644 --- a/basicsr/data/video_test_dataset.py +++ b/basicsr/data/video_test_dataset.py @@ -3,8 +3,8 @@ from os import path as osp from torch.utils import data as data -from basicsr.data import util as util -from basicsr.data.util import duf_downsample +from basicsr.data.data_util import (duf_downsample, generate_frame_indices, + read_img_seq) from basicsr.utils import get_root_logger, scandir @@ -105,10 +105,8 @@ def __init__(self, opt): if self.cache_data: logger.info( f'Cache {subfolder_name} for VideoTestDataset...') - self.imgs_lq[subfolder_name] = util.read_img_seq( - img_paths_lq) - self.imgs_gt[subfolder_name] = util.read_img_seq( - img_paths_gt) + self.imgs_lq[subfolder_name] = read_img_seq(img_paths_lq) + self.imgs_gt[subfolder_name] = read_img_seq(img_paths_gt) else: self.imgs_lq[subfolder_name] = img_paths_lq self.imgs_gt[subfolder_name] = img_paths_gt @@ -123,7 +121,7 @@ def __getitem__(self, index): border = self.data_info['border'][index] lq_path = self.data_info['lq_path'][index] - select_idx = util.generate_frame_indices( + select_idx = generate_frame_indices( idx, max_idx, self.opt['num_frame'], padding=self.opt['padding']) if self.cache_data: @@ -132,8 +130,8 @@ def __getitem__(self, index): img_gt = self.imgs_gt[folder][idx] else: img_paths_lq = [self.imgs_lq[folder][i] for i in select_idx] - imgs_lq = util.read_img_seq(img_paths_lq) - img_gt = util.read_img_seq([self.imgs_gt[folder][idx]]) + imgs_lq = read_img_seq(img_paths_lq) + img_gt = read_img_seq([self.imgs_gt[folder][idx]]) img_gt.squeeze_(0) return { @@ -213,8 +211,8 @@ def __init__(self, opt): def __getitem__(self, index): lq_path = self.data_info['lq_path'][index] gt_path = self.data_info['gt_path'][index] - imgs_lq = util.read_img_seq(lq_path) - img_gt = util.read_img_seq([gt_path]) + imgs_lq = read_img_seq(lq_path) + img_gt = read_img_seq([gt_path]) img_gt.squeeze_(0) return { @@ -250,7 +248,7 @@ def __getitem__(self, index): border = self.data_info['border'][index] lq_path = self.data_info['lq_path'][index] - select_idx = 
util.generate_frame_indices( + select_idx = generate_frame_indices( idx, max_idx, self.opt['num_frame'], padding=self.opt['padding']) if self.cache_data: @@ -268,7 +266,7 @@ def __getitem__(self, index): if self.opt['use_duf_downsampling']: img_paths_lq = [self.imgs_gt[folder][i] for i in select_idx] # read imgs_gt to generate low-resolution frames - imgs_lq = util.read_img_seq( + imgs_lq = read_img_seq( img_paths_lq, require_mod_crop=True, scale=self.opt['scale']) @@ -276,10 +274,10 @@ def __getitem__(self, index): imgs_lq, kernel_size=13, scale=self.opt['scale']) else: img_paths_lq = [self.imgs_lq[folder][i] for i in select_idx] - imgs_lq = util.read_img_seq(img_paths_lq) - img_gt = util.read_img_seq([self.imgs_gt[folder][idx]], - require_mod_crop=True, - scale=self.opt['scale']) + imgs_lq = read_img_seq(img_paths_lq) + img_gt = read_img_seq([self.imgs_gt[folder][idx]], + require_mod_crop=True, + scale=self.opt['scale']) img_gt.squeeze_(0) return { diff --git a/basicsr/metrics/metric_util.py b/basicsr/metrics/metric_util.py index 4258781..fb38e1b 100644 --- a/basicsr/metrics/metric_util.py +++ b/basicsr/metrics/metric_util.py @@ -28,7 +28,6 @@ def reorder_image(img, input_order='HWC'): img = img[..., None] if input_order == 'CHW': img = img.transpose(1, 2, 0) - img = img.astype(np.float64) return img diff --git a/basicsr/metrics/psnr_ssim.py b/basicsr/metrics/psnr_ssim.py index 22dc0fd..faef700 100644 --- a/basicsr/metrics/psnr_ssim.py +++ b/basicsr/metrics/psnr_ssim.py @@ -34,6 +34,8 @@ def calculate_psnr(img1, '"HWC" and "CHW"') img1 = reorder_image(img1, input_order=input_order) img2 = reorder_image(img2, input_order=input_order) + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) if crop_border != 0: img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...] @@ -122,6 +124,8 @@ def calculate_ssim(img1, '"HWC" and "CHW"') img1 = reorder_image(img1, input_order=input_order) img2 = reorder_image(img2, input_order=input_order) + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) if crop_border != 0: img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...] diff --git a/basicsr/models/archs/dfdnet_arch.py b/basicsr/models/archs/dfdnet_arch.py index e03dfc0..e4d3c9f 100644 --- a/basicsr/models/archs/dfdnet_arch.py +++ b/basicsr/models/archs/dfdnet_arch.py @@ -54,28 +54,6 @@ def forward(self, x, updated_feat): return out -class VGGFaceFeatureExtractor(VGGFeatureExtractor): - - def preprocess(self, x): - # norm to [0, 1] - x = (x + 1) / 2 - if self.use_input_norm: - x = (x - self.mean) / self.std - if x.shape[3] < 224: - x = torch.nn.functional.interpolate( - x, size=(224, 224), mode='bilinear', align_corners=False) - return x - - def forward(self, x): - x = self.preprocess(x) - features = [] - for key, layer in self.vgg_net._modules.items(): - x = layer(x) - if key in self.layer_name_list: - features.append(x) - return features - - class DFDNet(nn.Module): """DFDNet: Deep Face Dictionary Network. 
@@ -88,14 +66,15 @@ def __init__(self, num_feat, dict_path): # part_sizes: [80, 80, 50, 110] channel_sizes = [128, 256, 512, 512] self.feature_sizes = np.array([256, 128, 64, 32]) + self.vgg_layers = ['conv2_2', 'conv3_4', 'conv4_4', 'conv5_4'] self.flag_dict_device = False # dict self.dict = torch.load(dict_path) # vgg face extractor - self.vgg_extractor = VGGFaceFeatureExtractor( - layer_name_list=['conv2_2', 'conv3_4', 'conv4_4', 'conv5_4'], + self.vgg_extractor = VGGFeatureExtractor( + layer_name_list=self.vgg_layers, vgg_type='vgg19', use_input_norm=True, requires_grad=False) @@ -175,9 +154,9 @@ def forward(self, x, part_locations): # update vggface features using the dictionary for each part updated_vgg_features = [] batch = 0 # only supports testing with batch size = 0 - for i, f_size in enumerate(self.feature_sizes): + for vgg_layer, f_size in zip(self.vgg_layers, self.feature_sizes): dict_features = self.dict[f'{f_size}'] - vgg_feat = vgg_features[i] + vgg_feat = vgg_features[vgg_layer] updated_feat = vgg_feat.clone() # swap features from dictionary @@ -190,7 +169,7 @@ def forward(self, x, part_locations): updated_vgg_features.append(updated_feat) - vgg_feat_dilation = self.multi_scale_dilation(vgg_features[3]) + vgg_feat_dilation = self.multi_scale_dilation(vgg_features['conv5_4']) # use updated vgg features to modulate the upsampled features with # SFT (Spatial Feature Transform) scaling and shifting manner. upsampled_feat = self.upsample0(vgg_feat_dilation, diff --git a/basicsr/models/archs/vgg_arch.py b/basicsr/models/archs/vgg_arch.py index 251b794..5b1574a 100644 --- a/basicsr/models/archs/vgg_arch.py +++ b/basicsr/models/archs/vgg_arch.py @@ -70,6 +70,8 @@ class VGGFeatureExtractor(nn.Module): vgg_type (str): Set the type of vgg network. Default: 'vgg19'. use_input_norm (bool): If True, normalize the input image. Importantly, the input feature must in the range [0, 1]. Default: True. + range_norm (bool): If True, norm images with range [-1, 1] to [0, 1]. + Default: False. requires_grad (bool): If true, the parameters of VGG network will be optimized. Default: False. remove_pooling (bool): If true, the max pooling operations in VGG net @@ -81,6 +83,7 @@ def __init__(self, layer_name_list, vgg_type='vgg19', use_input_norm=True, + range_norm=False, requires_grad=False, remove_pooling=False, pooling_stride=2): @@ -88,6 +91,7 @@ def __init__(self, self.layer_name_list = layer_name_list self.use_input_norm = use_input_norm + self.range_norm = range_norm self.names = NAMES[vgg_type.replace('_bn', '')] if 'bn' in vgg_type: @@ -153,7 +157,8 @@ def forward(self, x): Returns: Tensor: Forward results. """ - + if self.range_norm: + x = (x + 1) / 2 if self.use_input_norm: x = (x - self.mean) / self.std diff --git a/basicsr/models/losses/loss_utils.py b/basicsr/models/losses/loss_util.py similarity index 100% rename from basicsr/models/losses/loss_utils.py rename to basicsr/models/losses/loss_util.py diff --git a/basicsr/models/losses/losses.py b/basicsr/models/losses/losses.py index 4cbc5e8..2df8d75 100644 --- a/basicsr/models/losses/losses.py +++ b/basicsr/models/losses/losses.py @@ -5,7 +5,7 @@ from torch.nn import functional as F from basicsr.models.archs.vgg_arch import VGGFeatureExtractor -from basicsr.models.losses.loss_utils import weighted_loss +from basicsr.models.losses.loss_util import weighted_loss _reduction_modes = ['none', 'mean', 'sum'] @@ -155,17 +155,14 @@ class PerceptualLoss(nn.Module): Default: 'vgg19'. 
use_input_norm (bool): If True, normalize the input image in vgg. Default: True. + range_norm (bool): If True, norm images with range [-1, 1] to [0, 1]. + Default: False. perceptual_weight (float): If `perceptual_weight > 0`, the perceptual loss will be calculated and the loss will multiplied by the weight. Default: 1.0. style_weight (float): If `style_weight > 0`, the style loss will be calculated and the loss will multiplied by the weight. Default: 0. - norm_img (bool): If True, the image will be normed to [0, 1]. Note that - this is different from the `use_input_norm` which norm the input in - in forward function of vgg according to the statistics of dataset. - Importantly, the input image must be in range [-1, 1]. - Default: False. criterion (str): Criterion used for perceptual loss. Default: 'l1'. """ @@ -173,19 +170,19 @@ def __init__(self, layer_weights, vgg_type='vgg19', use_input_norm=True, + range_norm=False, perceptual_weight=1.0, style_weight=0., - norm_img=False, criterion='l1'): super(PerceptualLoss, self).__init__() - self.norm_img = norm_img self.perceptual_weight = perceptual_weight self.style_weight = style_weight self.layer_weights = layer_weights self.vgg = VGGFeatureExtractor( layer_name_list=list(layer_weights.keys()), vgg_type=vgg_type, - use_input_norm=use_input_norm) + use_input_norm=use_input_norm, + range_norm=range_norm) self.criterion_type = criterion if self.criterion_type == 'l1': @@ -208,11 +205,6 @@ def forward(self, x, gt): Returns: Tensor: Forward results. """ - - if self.norm_img: - x = (x + 1.) * 0.5 - gt = (gt + 1.) * 0.5 - # extract vgg features x_features = self.vgg(x) gt_features = self.vgg(gt.detach()) diff --git a/basicsr/utils/__init__.py b/basicsr/utils/__init__.py index fa3401e..2b91571 100644 --- a/basicsr/utils/__init__.py +++ b/basicsr/utils/__init__.py @@ -2,7 +2,7 @@ from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img from .logger import (MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger) -from .util import (check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, +from .misc import (check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt) __all__ = [ @@ -20,7 +20,7 @@ 'init_wandb_logger', 'get_root_logger', 'get_env_info', - # util.py + # misc.py 'set_random_seed', 'get_time_str', 'mkdir_and_rename', diff --git a/basicsr/utils/crawler_util.py b/basicsr/utils/crawler_util.py deleted file mode 100644 index b65dda3..0000000 --- a/basicsr/utils/crawler_util.py +++ /dev/null @@ -1,35 +0,0 @@ -import requests - - -def baidu_decode_url(encrypted_url): - """Decrypt baidu ecrypted url.""" - url = encrypted_url - map1 = {'_z2C$q': ':', '_z&e3B': '.', 'AzdH3F': '/'} - map2 = { - 'w': 'a', 'k': 'b', 'v': 'c', '1': 'd', 'j': 'e', - 'u': 'f', '2': 'g', 'i': 'h', 't': 'i', '3': 'j', - 'h': 'k', 's': 'l', '4': 'm', 'g': 'n', '5': 'o', - 'r': 'p', 'q': 'q', '6': 'r', 'f': 's', 'p': 't', - '7': 'u', 'e': 'v', 'o': 'w', '8': '1', 'd': '2', - 'n': '3', '9': '4', 'c': '5', 'm': '6', '0': '7', - 'b': '8', 'l': '9', 'a': '0' - } # yapf: disable - for (ciphertext, plaintext) in map1.items(): - url = url.replace(ciphertext, plaintext) - char_list = [char for char in url] - for i in range(len(char_list)): - if char_list[i] in map2: - char_list[i] = map2[char_list[i]] - url = ''.join(char_list) - return url - - -def setup_session(): - headers = { - 'User-Agent': ('Mozilla/5.0 (Macintosh; Intel Mac OS X 10_11_3)' - ' AppleWebKit/537.36 (KHTML, like Gecko) ' - 
'Chrome/48.0.2564.116 Safari/537.36') - } - session = requests.Session() - session.headers.update(headers) - return session diff --git a/basicsr/utils/download.py b/basicsr/utils/download_util.py similarity index 98% rename from basicsr/utils/download.py rename to basicsr/utils/download_util.py index d27266e..64a0016 100644 --- a/basicsr/utils/download.py +++ b/basicsr/utils/download_util.py @@ -2,7 +2,7 @@ import requests from tqdm import tqdm -from .util import sizeof_fmt +from .misc import sizeof_fmt def download_file_from_google_drive(file_id, save_path): diff --git a/basicsr/utils/lmdb.py b/basicsr/utils/lmdb_util.py similarity index 100% rename from basicsr/utils/lmdb.py rename to basicsr/utils/lmdb_util.py diff --git a/basicsr/utils/util.py b/basicsr/utils/misc.py similarity index 100% rename from basicsr/utils/util.py rename to basicsr/utils/misc.py diff --git a/docs/DatasetPreparation.md b/docs/DatasetPreparation.md index 8434cb1..207df31 100644 --- a/docs/DatasetPreparation.md +++ b/docs/DatasetPreparation.md @@ -115,7 +115,7 @@ For convenience, the binary content stored in LMDB dataset is encoded image by c **How to Make LMDB** We provide a script to make LMDB. Before running the script, we need to modify the corresponding parameters accordingly. At present, we support DIV2K, REDS and Vimeo90K datasets; other datasets can also be made in a similar way.
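For context on the command below: as noted above, each LMDB value is a cv2-encoded image keyed by its image name, so an entry can be read back in a few lines. A minimal sketch, where the LMDB path and the key are illustrative only:

```python
import cv2
import lmdb
import numpy as np

# illustrative path and key; any LMDB produced by create_lmdb.py follows the same layout
env = lmdb.open('datasets/DIV2K/DIV2K_train_HR_sub.lmdb', readonly=True, lock=False)
with env.begin() as txn:
    buf = txn.get('0001_s001'.encode('ascii'))
img = cv2.imdecode(np.frombuffer(buf, np.uint8), cv2.IMREAD_UNCHANGED)
print(img.shape)
```

After this reorganization the script also takes a `--dataset` flag (DIV2K, REDS or Vimeo90K), e.g. `python scripts/data_preparation/create_lmdb.py --dataset DIV2K`.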
- `python scripts/create_lmdb.py` + `python scripts/data_preparation/create_lmdb.py` #### Data Pre-fetcher @@ -155,17 +155,17 @@ It is recommended to symlink the dataset root to `datasets` with the command `ln 1. Download the datasets from the [official DIV2K website](https://data.vision.ee.ethz.ch/cvl/DIV2K/).
1. Crop to sub-images: DIV2K has 2K resolution (e.g., 2048 × 1080) images but the training patches are usually small (e.g., 128x128 or 192x192), so reading a whole image while using only a small part of it wastes IO. To accelerate the IO speed during training, we crop the 2K resolution images into sub-images (here, we crop them to 480x480 sub-images).
Note that the size of the sub-images is different from the training patch size (`gt_size`) defined in the config file. Specifically, the cropped 480x480 sub-images are stored, and the dataloader will further randomly crop them into `gt_size x gt_size` patches for training.
- Run the script [extract_subimages.py](../scripts/extract_subimages.py): + Run the script [extract_subimages.py](../scripts/data_preparation/extract_subimages.py): ```python - python scripts/extract_subimages.py + python scripts/data_preparation/extract_subimages.py ``` Remember to modify the paths and configurations if you have different settings. -1. [Optional] Create LMDB files. Please refer to [LMDB Description](#LMDB-Description). `python scripts/create_lmdb.py`. Use the `create_lmdb_for_div2k` function and remember to modify the paths and configurations accordingly. +1. [Optional] Create LMDB files. Please refer to [LMDB Description](#LMDB-Description). `python scripts/data_preparation/create_lmdb.py`. Use the `create_lmdb_for_div2k` function and remember to modify the paths and configurations accordingly. 1. Test the dataloader with the script `tests/test_paired_image_dataset.py`. Remember to modify the paths and configurations accordingly. -1. [Optional] If you want to use meta_info_file, you may need to run `python scripts/generate_meta_info.py` to generate the meta_info_file. +1. [Optional] If you want to use meta_info_file, you may need to run `python scripts/data_preparation/generate_meta_info.py` to generate the meta_info_file. ### Common Image SR Datasets @@ -277,8 +277,8 @@ All the left clips are used for training. Note that it it not required to explic **Preparation Steps** 1. Download the datasets from the [official website](https://seungjunnah.github.io/Datasets/reds.html). -1. Regroup the training and validation datasets: `python scripts/regroup_reds_dataset.py` -1. [Optional] Make LMDB files when necessary. Please refer to [LMDB Description](#LMDB-Description). `python scripts/create_lmdb.py`. Use the `create_lmdb_for_reds` function and remember to modify the paths and configurations accordingly. +1. Regroup the training and validation datasets: `python scripts/data_preparation/regroup_reds_dataset.py` +1. [Optional] Make LMDB files when necessary. Please refer to [LMDB Description](#LMDB-Description). `python scripts/data_preparation/create_lmdb.py`. Use the `create_lmdb_for_reds` function and remember to modify the paths and configurations accordingly. 1. Test the dataloader with the script `tests/test_reds_dataset.py`. Remember to modify the paths and configurations accordingly. @@ -289,7 +289,7 @@ Remember to modify the paths and configurations accordingly. 1. Download the dataset: [`Septuplets dataset --> The original training + test set (82GB)`](http://data.csail.mit.edu/tofu/dataset/vimeo_septuplet.zip).This is the Ground-Truth (GT). There is a `sep_trainlist.txt` file listing the training samples in the download zip file. 1. Generate the low-resolution images (TODO) The low-resolution images in the Vimeo90K test dataset are generated with the MATLAB bicubic downsampling kernel. Use the script `data_scripts/generate_LR_Vimeo90K.m` (run in MATLAB) to generate the low-resolution images. -1. [Optional] Make LMDB files when necessary. Please refer to [LMDB Description](#LMDB-Description). `python scripts/create_lmdb.py`. Use the `create_lmdb_for_vimeo90k` function and remember to modify the paths and configurations accordingly. +1. [Optional] Make LMDB files when necessary. Please refer to [LMDB Description](#LMDB-Description). `python scripts/data_preparation/create_lmdb.py`. Use the `create_lmdb_for_vimeo90k` function and remember to modify the paths and configurations accordingly. 1. Test the dataloader with the script `tests/test_vimeo90k_dataset.py`. 
Remember to modify the paths and configurations accordingly. @@ -303,5 +303,5 @@ Training dataset: [FFHQ](https://github.com/NVlabs/ffhq-dataset). 1. Extract tfrecords to images or LMDBs. (TensorFlow is required to read tfrecords). For each resolution, we will create images folder or LMDB files separately. ```bash - python scripts/extract_images_from_tfrecords.py + python scripts/data_preparation/extract_images_from_tfrecords.py ``` diff --git a/docs/DatasetPreparation_CN.md b/docs/DatasetPreparation_CN.md index a71adf2..b3e90a0 100644 --- a/docs/DatasetPreparation_CN.md +++ b/docs/DatasetPreparation_CN.md @@ -116,7 +116,7 @@ DIV2K_train_HR_sub.lmdb **如何制作** 我们提供了脚本来制作. 在运行脚本前, 需要根据需求修改相应的参数. 目前支持 DIV2K, REDS 和 Vimeo90K 数据集; 其他数据集可仿照进行制作.
- `python scripts/create_lmdb.py` + `python scripts/data_preparation/create_lmdb.py` #### 预读取数据 @@ -155,17 +155,17 @@ DIV2K 数据集被广泛使用在图像复原的任务中. 1. 从[官网](https://data.vision.ee.ethz.ch/cvl/DIV2K)下载数据. 1. Crop to sub-images: 因为 DIV2K 数据集是 2K 分辨率的 (比如: 2048x1080), 而我们在训练的时候往往并不要那么大 (常见的是 128x128 或者 192x192 的训练patch). 因此我们可以先把2K的图片裁剪成有overlap的 480x480 的子图像块. 然后再由 dataloader 从这个 480x480 的子图像块中随机crop出 128x128 或者 192x192 的训练patch.
- 运行脚本 [extract_subimages.py](../scripts/extract_subimages.py): + 运行脚本 [extract_subimages.py](../scripts/data_preparation/extract_subimages.py): ```python - python scripts/extract_subimages.py + python scripts/data_preparation/extract_subimages.py ``` 使用之前可能需要修改文件里面的路径和配置参数. **注意**: sub-image 的尺寸和训练patch的尺寸 (`gt_size`) 是不同的. 我们先把2K分辨率的图像 crop 成 sub-images (往往是 480x480), 然后存储起来. 在训练的时候, dataloader会读取这些sub-images, 然后进一步随机裁剪成 `gt_size` x `gt_size`的大小. -1. [可选] 若需要使用 LMDB, 则需要制作 LMDB, 参考 [LMDB具体说明](#LMDB具体说明). `python scripts/create_lmdb.py`, 注意选择`create_lmdb_for_div2k`函数, 并需要修改函数相应的配置和路径. +1. [可选] 若需要使用 LMDB, 则需要制作 LMDB, 参考 [LMDB具体说明](#LMDB具体说明). `python scripts/data_preparation/create_lmdb.py`, 注意选择`create_lmdb_for_div2k`函数, 并需要修改函数相应的配置和路径. 1. 测试: `tests/test_paired_image_dataset.py`, 注意修改函数相应的配置和路径. -1. [可选] 若需要使用 meta_info_file, 运行 `python scripts/generate_meta_info.py` 来生成 meta_info_file. +1. [可选] 若需要使用 meta_info_file, 运行 `python scripts/data_preparation/generate_meta_info.py` 来生成 meta_info_file. ### 其他常见图像超分数据集 @@ -277,8 +277,8 @@ DIV2K 数据集被广泛使用在图像复原的任务中. **数据准备步骤** 1. 从[官网](https://seungjunnah.github.io/Datasets/reds.html)下载数据 -1. 整合 training 和 validation 数据: `python scripts/regroup_reds_dataset.py` -1. [可选] 若需要使用 LMDB, 则需要制作 LMDB, 参考 [LMDB具体说明](#LMDB具体说明). `python scripts/create_lmdb.py`, 注意选择`create_lmdb_for_reds`函数, 并需要修改函数相应的配置和路径. +1. 整合 training 和 validation 数据: `python scripts/data_preparation/regroup_reds_dataset.py` +1. [可选] 若需要使用 LMDB, 则需要制作 LMDB, 参考 [LMDB具体说明](#LMDB具体说明). `python scripts/data_preparation/create_lmdb.py`, 注意选择`create_lmdb_for_reds`函数, 并需要修改函数相应的配置和路径. 1. 测试: `python tests/test_reds_dataset.py`, 注意修改函数相应的配置和路径. ### Vimeo90K @@ -290,7 +290,7 @@ DIV2K 数据集被广泛使用在图像复原的任务中. 1. 下载数据: [`Septuplets dataset --> The original training + test set (82GB)`](http://data.csail.mit.edu/tofu/dataset/vimeo_septuplet.zip). 这些是Ground-Truth. 里面有`sep_trainlist.txt`文件来区分训练数据. 1. 生成低分辨率图片. (TODO) The low-resolution images in the Vimeo90K test dataset are generated with the MATLAB bicubic downsampling kernel. Use the script `data_scripts/generate_LR_Vimeo90K.m` (run in MATLAB) to generate the low-resolution images. -1. [可选] 若需要使用 LMDB, 则需要制作 LMDB, 参考 [LMDB具体说明](#LMDB具体说明). `python scripts/create_lmdb.py`, 注意选择`create_lmdb_for_vimeo90k`函数, 并需要修改函数相应的配置和路径. +1. [可选] 若需要使用 LMDB, 则需要制作 LMDB, 参考 [LMDB具体说明](#LMDB具体说明). `python scripts/data_preparation/create_lmdb.py`, 注意选择`create_lmdb_for_vimeo90k`函数, 并需要修改函数相应的配置和路径. 1. 测试: `python tests/test_vimeo90k_dataset.py`, 注意修改函数相应的配置和路径. ## StyleGAN2 @@ -303,5 +303,5 @@ The low-resolution images in the Vimeo90K test dataset are generated with the MA 1. 从 tfrecords 提取到*图片*或者*LMDB*. (需要安装 TensorFlow 来读取 tfrecords). 我们对每一个分辨率的人脸都单独创建文件夹或者LMDB文件. ```bash - python scripts/extract_images_from_tfrecords.py + python scripts/data_preparation/extract_images_from_tfrecords.py ``` diff --git a/docs/HOWTOs.md b/docs/HOWTOs.md index a2d6433..ddda95f 100644 --- a/docs/HOWTOs.md +++ b/docs/HOWTOs.md @@ -8,23 +8,23 @@ 1. Download FFHQ dataset. Recommend to download the tfrecords files from [NVlabs/ffhq-dataset](https://github.com/NVlabs/ffhq-dataset). 1. Extract tfrecords to images or LMDBs (TensorFlow is required to read tfrecords): - > python scripts/extract_images_from_tfrecords.py + > python scripts/data_preparation/extract_images_from_tfrecords.py 1. Modify the config file in `options/train/StyleGAN/train_StyleGAN2_256_Cmul2_FFHQ.yml` 1. Train with distributed training. More training commands are in [TrainTest.md](TrainTest.md). 
> python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 basicsr/train.py -opt options/train/StyleGAN/train_StyleGAN2_256_Cmul2_FFHQ_800k.yml --launcher pytorch -## How to test StyleGAN2 +## How to inference StyleGAN2 1. Download pre-trained models from **ModelZoo** ([Google Drive](https://drive.google.com/drive/folders/15DgDtfaLASQ3iAPJEVHQF49g9msexECG?usp=sharing), [百度网盘](https://pan.baidu.com/s/1R6Nc4v3cl79XPAiK0Toe7g)) to the `experiments/pretrained_models` folder. 1. Test. - > python tests/test_stylegan2.py + > python inference/inference_stylegan2.py 1. The results are in the `samples` folder. -## How to test DFDNet +## How to inference DFDNet 1. Install [dlib](http://dlib.net/), because DFDNet uses dlib to do face recognition and landmark detection. [Installation reference](https://github.com/davisking/dlib). 1. Clone dlib repo: `git clone git@github.com:davisking/dlib.git` @@ -43,6 +43,6 @@ 4. Prepare the testing dataset in the `datasets`, for example, we put images in the `datasets/TestWhole` folder. 5. Test. - > python tests/test_face_dfdnet.py --upscale_factor=2 --test_path datasets/TestWhole + > python inference/inference_dfdnet.py --upscale_factor=2 --test_path datasets/TestWhole 6. The results are in the `results/DFDNet` folder. diff --git a/docs/HOWTOs_CN.md b/docs/HOWTOs_CN.md index aad7f25..df2ab25 100644 --- a/docs/HOWTOs_CN.md +++ b/docs/HOWTOs_CN.md @@ -8,7 +8,7 @@ 1. 下载 FFHQ 数据集. 推荐从 [NVlabs/ffhq-dataset](https://github.com/NVlabs/ffhq-dataset) 下载 tfrecords 文件. 1. 从tfrecords 提取到*图片*或者*LMDB*. (需要安装 TensorFlow 来读取 tfrecords). - > python scripts/extract_images_from_tfrecords.py + > python scripts/data_preparation/extract_images_from_tfrecords.py 1. 修改配置文件 `options/train/StyleGAN/train_StyleGAN2_256_Cmul2_FFHQ.yml` 1. 使用分布式训练. 更多训练命令: [TrainTest_CN.md](TrainTest_CN.md) @@ -20,7 +20,7 @@ 1. 从 **ModelZoo** ([Google Drive](https://drive.google.com/drive/folders/15DgDtfaLASQ3iAPJEVHQF49g9msexECG?usp=sharing), [百度网盘](https://pan.baidu.com/s/1R6Nc4v3cl79XPAiK0Toe7g)) 下载预训练模型到 `experiments/pretrained_models` 文件夹. 1. 测试. - > python tests/test_stylegan2.py + > python inference/inference_stylegan2.py 1. 结果在 `samples` 文件夹 @@ -43,6 +43,6 @@ 4. 准备测试图片到 `datasets`, 比如说我们把测试图片放在 `datasets/TestWhole` 文件夹. 5. 测试. - > python tests/test_face_dfdnet.py --upscale_factor=2 --test_path datasets/TestWhole + > python inference/inference_dfdnet.py --upscale_factor=2 --test_path datasets/TestWhole 6. 结果在 `results/DFDNet` 文件夹. 
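Beyond the script entry points above, the StyleGAN2 generator can also be exercised directly from Python. A rough sketch with untrained weights (the forward convention assumed here is a list of style codes in, an image plus latents out; loading a checkpoint from `experiments/pretrained_models` would replace the random initialization):

```python
import torch
from basicsr.models.archs.stylegan2_arch import StyleGAN2Generator

# a 256x256 generator, matching the channel_multiplier=2 config referenced above
net = StyleGAN2Generator(out_size=256, num_style_feat=512, num_mlp=8,
                         channel_multiplier=2).eval()
with torch.no_grad():
    img, _ = net([torch.randn(1, 512)])  # assumed: list of style codes -> (image, latents)
print(img.shape)  # expected: torch.Size([1, 3, 256, 256])
```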
diff --git a/test_scripts/test_face_dfdnet.py b/inference/inference_dfdnet.py similarity index 100% rename from test_scripts/test_face_dfdnet.py rename to inference/inference_dfdnet.py diff --git a/test_scripts/test_esrgan.py b/inference/inference_esrgan.py similarity index 100% rename from test_scripts/test_esrgan.py rename to inference/inference_esrgan.py diff --git a/test_scripts/test_stylegan2.py b/inference/inference_stylegan2.py similarity index 100% rename from test_scripts/test_stylegan2.py rename to inference/inference_stylegan2.py diff --git a/make.sh b/make.sh deleted file mode 100644 index 1990c6b..0000000 --- a/make.sh +++ /dev/null @@ -1,7 +0,0 @@ -#!/usr/bin/env bash - -# You may need to modify the following paths before compiling -CUDA_HOME=/usr/local/cuda \ -CUDNN_INCLUDE_DIR=/usr/local/cuda \ -CUDNN_LIB_DIR=/usr/local/cuda \ -python setup.py develop diff --git a/options/train/EDVR/train_EDVRM_woTSA_GAN_TODO.yml b/options/train/EDVR/train_EDVRM_woTSA_GAN_TODO.yml index c59be62..ec5a78e 100644 --- a/options/train/EDVR/train_EDVRM_woTSA_GAN_TODO.yml +++ b/options/train/EDVR/train_EDVRM_woTSA_GAN_TODO.yml @@ -107,9 +107,9 @@ train: 'conv5_4': 1 # before relu vgg_type: vgg19 use_input_norm: true + range_norm: false perceptual_weight: 1.0 style_weight: 0 - norm_img: false criterion: l1 gan_opt: type: GANLoss diff --git a/options/train/ESRGAN/train_ESRGAN_x4.yml b/options/train/ESRGAN/train_ESRGAN_x4.yml index 7d8faba..057de06 100644 --- a/options/train/ESRGAN/train_ESRGAN_x4.yml +++ b/options/train/ESRGAN/train_ESRGAN_x4.yml @@ -91,9 +91,9 @@ train: 'conv5_4': 1 # before relu vgg_type: vgg19 use_input_norm: true + range_norm: false perceptual_weight: 1.0 style_weight: 0 - norm_img: false criterion: l1 gan_opt: type: GANLoss diff --git a/options/train/SRResNet_SRGAN/train_MSRGAN_x4.yml b/options/train/SRResNet_SRGAN/train_MSRGAN_x4.yml index a0d4ace..e3681f2 100644 --- a/options/train/SRResNet_SRGAN/train_MSRGAN_x4.yml +++ b/options/train/SRResNet_SRGAN/train_MSRGAN_x4.yml @@ -96,9 +96,9 @@ train: 'conv5_4': 1 # before relu vgg_type: vgg19 use_input_norm: true + scale: false perceptual_weight: 1.0 style_weight: 0 - norm_img: false criterion: l1 gan_opt: type: GANLoss diff --git a/requirements.txt b/requirements.txt index afd2ca1..c014cdc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,6 @@ addict future lmdb -matplotlib numpy opencv-python Pillow diff --git a/scripts/create_lmdb.py b/scripts/data_preparation/create_lmdb.py similarity index 89% rename from scripts/create_lmdb.py rename to scripts/data_preparation/create_lmdb.py index 6c9787b..e8eec3b 100644 --- a/scripts/create_lmdb.py +++ b/scripts/data_preparation/create_lmdb.py @@ -1,7 +1,8 @@ +import argparse from os import path as osp from basicsr.utils import scandir -from basicsr.utils.lmdb import make_lmdb_from_imgs +from basicsr.utils.lmdb_util import make_lmdb_from_imgs def create_lmdb_for_div2k(): @@ -160,6 +161,22 @@ def prepare_keys_vimeo90k(folder_path, train_list_path, mode): if __name__ == '__main__': - create_lmdb_for_div2k() - # create_lmdb_for_reds() - # create_lmdb_for_vimeo90k() + parser = argparse.ArgumentParser() + + parser.add_argument( + '--dataset', + type=str, + help=( + "Options: 'DIV2K', 'REDS', 'Vimeo90K' " + 'You may need to modify the corresponding configurations in codes.' 
+ )) + args = parser.parse_args() + dataset = args.dataset.lower() + if dataset == 'div2k': + create_lmdb_for_div2k() + elif dataset == 'reds': + create_lmdb_for_reds() + elif dataset == 'vimeo90k': + create_lmdb_for_vimeo90k() + else: + raise ValueError('Wrong dataset.') diff --git a/scripts/download_datasets.py b/scripts/data_preparation/download_datasets.py similarity index 97% rename from scripts/download_datasets.py rename to scripts/data_preparation/download_datasets.py index bd4ebd9..215e3c8 100644 --- a/scripts/download_datasets.py +++ b/scripts/data_preparation/download_datasets.py @@ -3,7 +3,7 @@ import os from os import path as osp -from basicsr.utils.download import download_file_from_google_drive +from basicsr.utils.download_util import download_file_from_google_drive def download_dataset(dataset, file_ids): diff --git a/scripts/extract_images_from_tfrecords.py b/scripts/data_preparation/extract_images_from_tfrecords.py similarity index 98% rename from scripts/extract_images_from_tfrecords.py rename to scripts/data_preparation/extract_images_from_tfrecords.py index 8e0706b..14a4f67 100644 --- a/scripts/extract_images_from_tfrecords.py +++ b/scripts/data_preparation/extract_images_from_tfrecords.py @@ -4,7 +4,7 @@ import numpy as np import os -from basicsr.utils.lmdb import LmdbMaker +from basicsr.utils.lmdb_util import LmdbMaker def convert_celeba_tfrecords(tf_file, @@ -171,7 +171,7 @@ def make_ffhq_lmdb_from_imgs(folder_path, if __name__ == '__main__': """Read tfrecords w/o define a graph. - We have tested it on on TensorFlow 1.15 + We have tested it on TensorFlow 1.15 Ref: http://warmspringwinds.github.io/tensorflow/tf-slim/2016/12/21/tfrecords-guide/ diff --git a/scripts/extract_subimages.py b/scripts/data_preparation/extract_subimages.py similarity index 99% rename from scripts/extract_subimages.py rename to scripts/data_preparation/extract_subimages.py index 4845ca8..6424e8d 100644 --- a/scripts/extract_subimages.py +++ b/scripts/data_preparation/extract_subimages.py @@ -6,7 +6,7 @@ from os import path as osp from tqdm import tqdm -from basicsr.utils.util import scandir +from basicsr.utils import scandir def main(): diff --git a/scripts/generate_meta_info.py b/scripts/data_preparation/generate_meta_info.py similarity index 100% rename from scripts/generate_meta_info.py rename to scripts/data_preparation/generate_meta_info.py diff --git a/scripts/regroup_reds_dataset.py b/scripts/data_preparation/regroup_reds_dataset.py similarity index 100% rename from scripts/regroup_reds_dataset.py rename to scripts/data_preparation/regroup_reds_dataset.py diff --git a/scripts/gdrive_download.py b/scripts/download_gdrive.py similarity index 80% rename from scripts/gdrive_download.py rename to scripts/download_gdrive.py index e67cad9..c3e34c7 100644 --- a/scripts/gdrive_download.py +++ b/scripts/download_gdrive.py @@ -1,6 +1,6 @@ import argparse -from basicsr.utils.download import download_file_from_google_drive +from basicsr.utils.download_util import download_file_from_google_drive if __name__ == '__main__': parser = argparse.ArgumentParser() diff --git a/scripts/download_pretrained_models.py b/scripts/download_pretrained_models.py index e6eb06f..3514a68 100644 --- a/scripts/download_pretrained_models.py +++ b/scripts/download_pretrained_models.py @@ -2,7 +2,7 @@ import os from os import path as osp -from basicsr.utils.download import download_file_from_google_drive +from basicsr.utils.download_util import download_file_from_google_drive def download_pretrained_models(method, 
file_ids): diff --git a/scripts/calculate_fid_folder.py b/scripts/metrics/calculate_fid_folder.py similarity index 100% rename from scripts/calculate_fid_folder.py rename to scripts/metrics/calculate_fid_folder.py diff --git a/scripts/calculate_fid_stats_from_datasets.py b/scripts/metrics/calculate_fid_stats_from_datasets.py similarity index 100% rename from scripts/calculate_fid_stats_from_datasets.py rename to scripts/metrics/calculate_fid_stats_from_datasets.py diff --git a/scripts/calculate_lpips.py b/scripts/metrics/calculate_lpips.py similarity index 100% rename from scripts/calculate_lpips.py rename to scripts/metrics/calculate_lpips.py diff --git a/scripts/calculate_psnr_ssim.py b/scripts/metrics/calculate_psnr_ssim.py similarity index 99% rename from scripts/calculate_psnr_ssim.py rename to scripts/metrics/calculate_psnr_ssim.py index f73ccde..1a14af5 100644 --- a/scripts/calculate_psnr_ssim.py +++ b/scripts/metrics/calculate_psnr_ssim.py @@ -25,7 +25,7 @@ def main(): crop_border = 4 suffix = '_expname' test_y_channel = False - correct_mean_var = True + correct_mean_var = False # ------------------------------------------------------------------------- psnr_all = [] diff --git a/scripts/calculate_stylegan2_fid.py b/scripts/metrics/calculate_stylegan2_fid.py similarity index 100% rename from scripts/calculate_stylegan2_fid.py rename to scripts/metrics/calculate_stylegan2_fid.py diff --git a/scripts/convert_dfdnet.py b/scripts/model_conversion/convert_dfdnet.py similarity index 100% rename from scripts/convert_dfdnet.py rename to scripts/model_conversion/convert_dfdnet.py diff --git a/scripts/convert_models.py b/scripts/model_conversion/convert_models.py similarity index 100% rename from scripts/convert_models.py rename to scripts/model_conversion/convert_models.py diff --git a/scripts/convert_stylegan.py b/scripts/model_conversion/convert_stylegan.py similarity index 100% rename from scripts/convert_stylegan.py rename to scripts/model_conversion/convert_stylegan.py diff --git a/setup.cfg b/setup.cfg index 62caaa6..ae5a6eb 100644 --- a/setup.cfg +++ b/setup.cfg @@ -16,6 +16,6 @@ line_length = 79 multi_line_output = 0 known_standard_library = pkg_resources,setuptools known_first_party = basicsr -known_third_party = PIL,cv2,lmdb,matplotlib,numpy,requests,scipy,skimage,torch,torchvision,tqdm,yaml +known_third_party = PIL,cv2,lmdb,numpy,requests,scipy,skimage,torch,torchvision,tqdm,yaml no_lines_before = STDLIB,LOCALFOLDER default_section = THIRDPARTY diff --git a/tests/test_lr_scheduler.py b/tests/test_lr_scheduler.py index 9562ffd..b9642d1 100644 --- a/tests/test_lr_scheduler.py +++ b/tests/test_lr_scheduler.py @@ -1,10 +1,14 @@ -import matplotlib as mpl import torch -from matplotlib import pyplot as plt -from matplotlib import ticker as mtick from basicsr.models.lr_scheduler import CosineAnnealingRestartLR +try: + import matplotlib as mpl + from matplotlib import pyplot as plt + from matplotlib import ticker as mtick +except ImportError: + print('Please install matplotlib.') + mpl.use('Agg') From 61efb21eba31bbec7c3966a34421c74f783d9471 Mon Sep 17 00:00:00 2001 From: Xintao Date: Sat, 28 Nov 2020 21:48:15 +0800 Subject: [PATCH 18/23] update readme --- README.md | 19 +++++++++++++++---- README_CN.md | 19 +++++++++++++++---- 2 files changed, 30 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 684bcb1..dc35171 100644 --- a/README.md +++ b/README.md @@ -37,10 +37,21 @@ BasicSR (**Basic** **S**uper **R**estoration) is an open source **image and vide We provides 
simple pipelines to train/test/inference models for quick start. These pipelines/commands cannot cover all the cases and more details are in the following sections. -| | | | -| :--- | :---: | :---: | -| StyleGAN2 | [How to train StyleGAN2](docs/HOWTOs.md#How-to-train-StyleGAN2) | [How to inference StyleGAN2](docs/HOWTOs.md#How-to-inference-StyleGAN2) | -| DFDNet | *TODO* | [How to inference DFDNet](docs/HOWTOs.md#How-to-inference-DFDNet) | +| GAN | | | | | | +| :--- | :---: | :---: | :--- | :---: | :---: | +| StyleGAN2 | [Train](docs/HOWTOs.md#How-to-train-StyleGAN2) | [Inference](docs/HOWTOs.md#How-to-inference-StyleGAN2) | | | | +| **Face Restoration** | | | | | | +| DFDNet | - | [Inference](docs/HOWTOs.md#How-to-inference-DFDNet) | | | | +| **Super Resolution** | | | | | | +| ESRGAN | *TODO* | *TODO* | SRGAN | *TODO* | *TODO*| +| EDSR | *TODO* | *TODO* | SRResNet | *TODO* | *TODO*| +| RCAN | *TODO* | *TODO* | | | | +| EDVR | *TODO* | *TODO* | DUF | - | *TODO* | +| BasicVSR | *TODO* | *TODO* | TOF | - | *TODO* | +| **Deblurring** | | | | | | +| DeblurGANv2 | - | *TODO* | | | | +| **Denoise** | | | | | | +| RIDNet | - | *TODO* | CBDNet | - | *TODO*| ## :wrench: Dependencies and Installation diff --git a/README_CN.md b/README_CN.md index 57ff71b..ad0d2bb 100644 --- a/README_CN.md +++ b/README_CN.md @@ -36,10 +36,21 @@ BasicSR (**Basic** **S**uper **R**estoration) 是一个基于 PyTorch 的开源 我们提供了简单的流程来快速上手 训练/测试/推理 模型. 这些命令并不能涵盖所有用法, 更多的细节参见下面的部分. -| | | | -| :--- | :---: | :---: | -| StyleGAN2 | [如何训练 StyleGAN2](docs/HOWTOs_CN.md#如何训练-StyleGAN2) | [如何测试 StyleGAN2](docs/HOWTOs_CN.md#如何测试-StyleGAN2) | -| DFDNet | - | [如何测试 DFDNet](docs/HOWTOs_CN.md#如何测试-DFDNet) | +| GAN | | | | | | +| :--- | :---: | :---: | :--- | :---: | :---: | +| StyleGAN2 | [训练](docs/HOWTOs_CN.md#如何训练-StyleGAN2) | [测试](docs/HOWTOs_CN.md#如何测试-StyleGAN2) | | | | +| **Face Restoration** | | | | | | +| DFDNet | - | [测试](docs/HOWTOs_CN.md#如何测试-DFDNet) | | | | +| **Super Resolution** | | | | | | +| ESRGAN | *TODO* | *TODO* | SRGAN | *TODO* | *TODO*| +| EDSR | *TODO* | *TODO* | SRResNet | *TODO* | *TODO*| +| RCAN | *TODO* | *TODO* | | | | +| EDVR | *TODO* | *TODO* | DUF | - | *TODO* | +| BasicVSR | *TODO* | *TODO* | TOF | - | *TODO* | +| **Deblurring** | | | | | | +| DeblurGANv2 | - | *TODO* | | | | +| **Denoise** | | | | | | +| RIDNet | - | *TODO* | CBDNet | - | *TODO*| ## :wrench: 依赖和安装 From c0ba07d9acf102048fe3a6007dbe516c76a236ec Mon Sep 17 00:00:00 2001 From: Xintao Date: Sun, 29 Nov 2020 11:34:45 +0800 Subject: [PATCH 19/23] DFDNet refactor (#339) * fixbug: vgg extractor range_norm * fixbug: vgg extractor after relu * fixbug: vgg extractor after relu * separate FaceRestorationHelper * rm FFHQ_5_landmarks_template_1024 --- basicsr/models/archs/dfdnet_arch.py | 3 +- basicsr/utils/face_util.py | 217 +++++++++++++++++++++++++ inference/inference_dfdnet.py | 219 +------------------------- scripts/download_pretrained_models.py | 4 +- 4 files changed, 223 insertions(+), 220 deletions(-) create mode 100644 basicsr/utils/face_util.py diff --git a/basicsr/models/archs/dfdnet_arch.py b/basicsr/models/archs/dfdnet_arch.py index e4d3c9f..c887d90 100644 --- a/basicsr/models/archs/dfdnet_arch.py +++ b/basicsr/models/archs/dfdnet_arch.py @@ -66,7 +66,7 @@ def __init__(self, num_feat, dict_path): # part_sizes: [80, 80, 50, 110] channel_sizes = [128, 256, 512, 512] self.feature_sizes = np.array([256, 128, 64, 32]) - self.vgg_layers = ['conv2_2', 'conv3_4', 'conv4_4', 'conv5_4'] + self.vgg_layers = ['relu2_2', 'relu3_4', 'relu4_4', 'conv5_4'] 
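The switch from positional indexing to layer-name keys above assumes `VGGFeatureExtractor` returns its features as a dict keyed by layer name. A rough usage sketch under that assumption, mirroring how DFDNet calls the extractor after this refactor:

```python
import torch
from basicsr.models.archs.vgg_arch import VGGFeatureExtractor

layers = ['relu2_2', 'relu3_4', 'relu4_4', 'conv5_4']
# range_norm=True maps DFDNet's [-1, 1] inputs back to [0, 1] before the ImageNet normalization
vgg = VGGFeatureExtractor(layer_name_list=layers, vgg_type='vgg19',
                          use_input_norm=True, range_norm=True,
                          requires_grad=False).eval()
with torch.no_grad():
    feats = vgg(torch.rand(1, 3, 512, 512) * 2 - 1)  # a fake face tensor in [-1, 1]
print({k: tuple(v.shape) for k, v in feats.items()})  # assumed dict output, keyed by layer name
```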
self.flag_dict_device = False # dict @@ -77,6 +77,7 @@ def __init__(self, num_feat, dict_path): layer_name_list=self.vgg_layers, vgg_type='vgg19', use_input_norm=True, + range_norm=True, requires_grad=False) # attention block for fusing dictionary features and input features diff --git a/basicsr/utils/face_util.py b/basicsr/utils/face_util.py new file mode 100644 index 0000000..33fe178 --- /dev/null +++ b/basicsr/utils/face_util.py @@ -0,0 +1,217 @@ +import cv2 +import numpy as np +import os +import torch +from skimage import transform as trans + +from basicsr.utils import imwrite + +try: + import dlib +except ImportError: + print('Please install dlib before testing face restoration.' + 'Reference: https://github.com/davisking/dlib') + + +class FaceRestorationHelper(object): + """Helper for the face restoration pipeline.""" + + def __init__(self, upscale_factor, face_size=512): + self.upscale_factor = upscale_factor + self.face_size = (face_size, face_size) + + # standard 5 landmarks for FFHQ faces with 1024 x 1024 + self.face_template = np.array([[686.77227723, 488.62376238], + [586.77227723, 493.59405941], + [337.91089109, 488.38613861], + [437.95049505, 493.51485149], + [513.58415842, 678.5049505]]) + self.face_template = self.face_template / (1024 // face_size) + # for estimation the 2D similarity transformation + self.similarity_trans = trans.SimilarityTransform() + + self.all_landmarks_5 = [] + self.all_landmarks_68 = [] + self.affine_matrices = [] + self.inverse_affine_matrices = [] + self.cropped_faces = [] + self.restored_faces = [] + self.save_png = True + + def init_dlib(self, detection_path, landmark5_path, landmark68_path): + """Initialize the dlib detectors and predictors.""" + self.face_detector = dlib.cnn_face_detection_model_v1(detection_path) + self.shape_predictor_5 = dlib.shape_predictor(landmark5_path) + self.shape_predictor_68 = dlib.shape_predictor(landmark68_path) + + def free_dlib_gpu_memory(self): + del self.face_detector + del self.shape_predictor_5 + del self.shape_predictor_68 + + def read_input_image(self, img_path): + # self.input_img is Numpy array, (h, w, c) with RGB order + self.input_img = dlib.load_rgb_image(img_path) + + def detect_faces(self, + img_path, + upsample_num_times=1, + only_keep_largest=False): + """ + Args: + img_path (str): Image path. + upsample_num_times (int): Upsamples the image before running the + face detector + + Returns: + int: Number of detected faces. + """ + self.read_input_image(img_path) + det_faces = self.face_detector(self.input_img, upsample_num_times) + if len(det_faces) == 0: + print('No face detected. Try to increase upsample_num_times.') + else: + if only_keep_largest: + print('Detect several faces and only keep the largest.') + face_areas = [] + for i in range(len(det_faces)): + face_area = (det_faces[i].rect.right() - + det_faces[i].rect.left()) * ( + det_faces[i].rect.bottom() - + det_faces[i].rect.top()) + face_areas.append(face_area) + largest_idx = face_areas.index(max(face_areas)) + self.det_faces = [det_faces[largest_idx]] + else: + self.det_faces = det_faces + return len(self.det_faces) + + def get_face_landmarks_5(self): + for face in self.det_faces: + shape = self.shape_predictor_5(self.input_img, face.rect) + landmark = np.array([[part.x, part.y] for part in shape.parts()]) + self.all_landmarks_5.append(landmark) + return len(self.all_landmarks_5) + + def get_face_landmarks_68(self): + """Get 68 densemarks for cropped images. + + Should only have one face at most in the cropped image. 
+ """ + num_detected_face = 0 + for idx, face in enumerate(self.cropped_faces): + # face detection + det_face = self.face_detector(face, 1) # TODO: can we remove it? + if len(det_face) == 0: + print(f'Cannot find faces in cropped image with index {idx}.') + self.all_landmarks_68.append(None) + else: + if len(det_face) > 1: + print('Detect several faces in the cropped face. Use the ' + ' largest one. Note that it will also cause overlap ' + 'during paste_faces_to_input_image.') + face_areas = [] + for i in range(len(det_face)): + face_area = (det_face[i].rect.right() - + det_face[i].rect.left()) * ( + det_face[i].rect.bottom() - + det_face[i].rect.top()) + face_areas.append(face_area) + largest_idx = face_areas.index(max(face_areas)) + face_rect = det_face[largest_idx].rect + else: + face_rect = det_face[0].rect + shape = self.shape_predictor_68(face, face_rect) + landmark = np.array([[part.x, part.y] + for part in shape.parts()]) + self.all_landmarks_68.append(landmark) + num_detected_face += 1 + + return num_detected_face + + def warp_crop_faces(self, + save_cropped_path=None, + save_inverse_affine_path=None): + """Get affine matrix, warp and cropped faces. + + Also get inverse affine matrix for post-processing. + """ + for idx, landmark in enumerate(self.all_landmarks_5): + # use 5 landmarks to get affine matrix + self.similarity_trans.estimate(landmark, self.face_template) + affine_matrix = self.similarity_trans.params[0:2, :] + self.affine_matrices.append(affine_matrix) + # warp and crop faces + cropped_face = cv2.warpAffine(self.input_img, affine_matrix, + self.face_size) + self.cropped_faces.append(cropped_face) + # save the cropped face + if save_cropped_path is not None: + path, ext = os.path.splitext(save_cropped_path) + if self.save_png: + save_path = f'{path}_{idx:02d}.png' + else: + save_path = f'{path}_{idx:02d}{ext}' + + imwrite( + cv2.cvtColor(cropped_face, cv2.COLOR_RGB2BGR), save_path) + + # get inverse affine matrix + self.similarity_trans.estimate(self.face_template, + landmark * self.upscale_factor) + inverse_affine = self.similarity_trans.params[0:2, :] + self.inverse_affine_matrices.append(inverse_affine) + # save inverse affine matrices + if save_inverse_affine_path is not None: + path, _ = os.path.splitext(save_inverse_affine_path) + save_path = f'{path}_{idx:02d}.pth' + torch.save(inverse_affine, save_path) + + def add_restored_face(self, face): + self.restored_faces.append(face) + + def paste_faces_to_input_image(self, save_path): + # operate in the BGR order + input_img = cv2.cvtColor(self.input_img, cv2.COLOR_RGB2BGR) + h, w, _ = input_img.shape + h_up, w_up = h * self.upscale_factor, w * self.upscale_factor + # simply resize the background + upsample_img = cv2.resize(input_img, (w_up, h_up)) + assert len(self.restored_faces) == len(self.inverse_affine_matrices), ( + 'length of restored_faces and affine_matrices are different.') + for restored_face, inverse_affine in zip(self.restored_faces, + self.inverse_affine_matrices): + inv_restored = cv2.warpAffine(restored_face, inverse_affine, + (w_up, h_up)) + mask = np.ones((*self.face_size, 3), dtype=np.float32) + inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up)) + # remove the black borders + inv_mask_erosion = cv2.erode( + inv_mask, + np.ones((2 * self.upscale_factor, 2 * self.upscale_factor), + np.uint8)) + inv_restored_remove_border = inv_mask_erosion * inv_restored + total_face_area = np.sum(inv_mask_erosion) // 3 + # compute the fusion edge based on the area of face + w_edge = 
int(total_face_area**0.5) // 20 + erosion_radius = w_edge * 2 + inv_mask_center = cv2.erode( + inv_mask_erosion, + np.ones((erosion_radius, erosion_radius), np.uint8)) + blur_size = w_edge * 2 + inv_soft_mask = cv2.GaussianBlur(inv_mask_center, + (blur_size + 1, blur_size + 1), 0) + upsample_img = inv_soft_mask * inv_restored_remove_border + ( + 1 - inv_soft_mask) * upsample_img + if self.save_png: + save_path = save_path.replace('.jpg', + '.png').replace('.jpeg', '.png') + imwrite(upsample_img.astype(np.uint8), save_path) + + def clean_all(self): + self.all_landmarks_5 = [] + self.all_landmarks_68 = [] + self.restored_faces = [] + self.affine_matrices = [] + self.cropped_faces = [] + self.inverse_affine_matrices = [] diff --git a/inference/inference_dfdnet.py b/inference/inference_dfdnet.py index 331e731..982c524 100644 --- a/inference/inference_dfdnet.py +++ b/inference/inference_dfdnet.py @@ -1,220 +1,14 @@ import argparse -import cv2 import glob import numpy as np import os import torch import torchvision.transforms as transforms from skimage import io -from skimage import transform as trans from basicsr.models.archs.dfdnet_arch import DFDNet from basicsr.utils import imwrite, tensor2img - -try: - import dlib -except ImportError: - print('Please install dlib before testing face restoration.' - 'Reference: https://github.com/davisking/dlib') - - -class FaceRestorationHelper(object): - """Helper for the face restoration pipeline.""" - - def __init__(self, upscale_factor, face_template_path, out_size=512): - self.upscale_factor = upscale_factor - self.out_size = (out_size, out_size) - - # standard 5 landmarks for FFHQ faces with 1024 x 1024 - self.face_template = np.load(face_template_path) / (1024 // out_size) - # for estimation the 2D similarity transformation - self.similarity_trans = trans.SimilarityTransform() - - self.all_landmarks_5 = [] - self.all_landmarks_68 = [] - self.affine_matrices = [] - self.inverse_affine_matrices = [] - self.cropped_faces = [] - self.restored_faces = [] - self.save_png = True - - def init_dlib(self, detection_path, landmark5_path, landmark68_path): - """Initialize the dlib detectors and predictors.""" - self.face_detector = dlib.cnn_face_detection_model_v1(detection_path) - self.shape_predictor_5 = dlib.shape_predictor(landmark5_path) - self.shape_predictor_68 = dlib.shape_predictor(landmark68_path) - - def free_dlib_gpu_memory(self): - del self.face_detector - del self.shape_predictor_5 - del self.shape_predictor_68 - - def read_input_image(self, img_path): - # self.input_img is Numpy array, (h, w, c) with RGB order - self.input_img = dlib.load_rgb_image(img_path) - - def detect_faces(self, - img_path, - upsample_num_times=1, - only_keep_largest=False): - """ - Args: - img_path (str): Image path. - upsample_num_times (int): Upsamples the image before running the - face detector - - Returns: - int: Number of detected faces. - """ - self.read_input_image(img_path) - det_faces = self.face_detector(self.input_img, upsample_num_times) - if len(det_faces) == 0: - print('No face detected. 
Try to increase upsample_num_times.') - else: - if only_keep_largest: - print('Detect several faces and only keep the largest.') - face_areas = [] - for i in range(len(det_faces)): - face_area = (det_faces[i].rect.right() - - det_faces[i].rect.left()) * ( - det_faces[i].rect.bottom() - - det_faces[i].rect.top()) - face_areas.append(face_area) - largest_idx = face_areas.index(max(face_areas)) - self.det_faces = [det_faces[largest_idx]] - else: - self.det_faces = det_faces - return len(self.det_faces) - - def get_face_landmarks_5(self): - for face in self.det_faces: - shape = self.shape_predictor_5(self.input_img, face.rect) - landmark = np.array([[part.x, part.y] for part in shape.parts()]) - self.all_landmarks_5.append(landmark) - return len(self.all_landmarks_5) - - def get_face_landmarks_68(self): - """Get 68 densemarks for cropped images. - - Should only have one face at most in the cropped image. - """ - num_detected_face = 0 - for idx, face in enumerate(self.cropped_faces): - # face detection - det_face = self.face_detector(face, 1) # TODO: can we remove it - if len(det_face) == 0: - print(f'Cannot find faces in cropped image with index {idx}.') - self.all_landmarks_68.append(None) - else: - if len(det_face) > 1: - print('Detect several faces in the cropped face. Use the ' - ' largest one. Note that it will also cause overlap ' - 'during paste_faces_to_input_image.') - face_areas = [] - for i in range(len(det_face)): - face_area = (det_face[i].rect.right() - - det_face[i].rect.left()) * ( - det_face[i].rect.bottom() - - det_face[i].rect.top()) - face_areas.append(face_area) - largest_idx = face_areas.index(max(face_areas)) - face_rect = det_face[largest_idx].rect - else: - face_rect = det_face[0].rect - shape = self.shape_predictor_68(face, face_rect) - landmark = np.array([[part.x, part.y] - for part in shape.parts()]) - self.all_landmarks_68.append(landmark) - num_detected_face += 1 - - return num_detected_face - - def warp_crop_faces(self, - save_cropped_path=None, - save_inverse_affine_path=None): - """Get affine matrix, warp and cropped faces. - - Also get inverse affine matrix for post-processing. 
- """ - for idx, landmark in enumerate(self.all_landmarks_5): - # use 5 landmarks to get affine matrix - self.similarity_trans.estimate(landmark, self.face_template) - affine_matrix = self.similarity_trans.params[0:2, :] - self.affine_matrices.append(affine_matrix) - # warp and crop faces - cropped_face = cv2.warpAffine(self.input_img, affine_matrix, - self.out_size) - self.cropped_faces.append(cropped_face) - # save the cropped face - if save_cropped_path is not None: - path, ext = os.path.splitext(save_cropped_path) - if self.save_png: - save_path = f'{path}_{idx:02d}.png' - else: - save_path = f'{path}_{idx:02d}{ext}' - - imwrite( - cv2.cvtColor(cropped_face, cv2.COLOR_RGB2BGR), save_path) - - # get inverse affine matrix - self.similarity_trans.estimate(self.face_template, - landmark * self.upscale_factor) - inverse_affine = self.similarity_trans.params[0:2, :] - self.inverse_affine_matrices.append(inverse_affine) - # save inverse affine matrices - if save_inverse_affine_path is not None: - path, _ = os.path.splitext(save_inverse_affine_path) - save_path = f'{path}_{idx:02d}.pth' - torch.save(inverse_affine, save_path) - - def add_restored_face(self, face): - self.restored_faces.append(face) - - def paste_faces_to_input_image(self, save_path): - # operate in the BGR order - input_img = cv2.cvtColor(self.input_img, cv2.COLOR_RGB2BGR) - h, w, _ = input_img.shape - h_up, w_up = h * self.upscale_factor, w * self.upscale_factor - # simply resize the background - upsample_img = cv2.resize(input_img, (w_up, h_up)) - assert len(self.restored_faces) == len(self.inverse_affine_matrices), ( - 'length of restored_faces and affine_matrices are different.') - for restored_face, inverse_affine in zip(self.restored_faces, - self.inverse_affine_matrices): - inv_restored = cv2.warpAffine(restored_face, inverse_affine, - (w_up, h_up)) - mask = np.ones((*self.out_size, 3), dtype=np.float32) - inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up)) - # remove the black borders - inv_mask_erosion = cv2.erode( - inv_mask, - np.ones((2 * self.upscale_factor, 2 * self.upscale_factor), - np.uint8)) - inv_restored_remove_border = inv_mask_erosion * inv_restored - total_face_area = np.sum(inv_mask_erosion) // 3 - # compute the fusion edge based on the area of face - w_edge = int(total_face_area**0.5) // 20 - erosion_radius = w_edge * 2 - inv_mask_center = cv2.erode( - inv_mask_erosion, - np.ones((erosion_radius, erosion_radius), np.uint8)) - blur_size = w_edge * 2 - inv_soft_mask = cv2.GaussianBlur(inv_mask_center, - (blur_size + 1, blur_size + 1), 0) - upsample_img = inv_soft_mask * inv_restored_remove_border + ( - 1 - inv_soft_mask) * upsample_img - if self.save_png: - save_path = save_path.replace('.jpg', - '.png').replace('.jpeg', '.png') - imwrite(upsample_img.astype(np.uint8), save_path) - - def clean_all(self): - self.all_landmarks_5 = [] - self.all_landmarks_68 = [] - self.restored_faces = [] - self.affine_matrices = [] - self.cropped_faces = [] - self.inverse_affine_matrices = [] +from basicsr.utils.face_util import FaceRestorationHelper def get_part_location(landmarks): @@ -293,13 +87,7 @@ def get_part_location(landmarks): # official_adaption to True. 
parser.add_argument('--official_adaption', type=bool, default=True) - # The following are the paths for face template and dlib models - parser.add_argument( - '--face_template_path', - type=str, - default= # noqa: E251 - 'experiments/pretrained_models/DFDNet/FFHQ_5_landmarks_template_1024-90a00515.npy' # noqa: E501 - ) + # The following are the paths for dlib models parser.add_argument( '--detection_path', type=str, @@ -337,8 +125,7 @@ def get_part_location(landmarks): save_restore_root = os.path.join(result_root, 'restored_faces') save_final_root = os.path.join(result_root, 'final_results') - face_helper = FaceRestorationHelper( - args.upscale_factor, args.face_template_path, out_size=512) + face_helper = FaceRestorationHelper(args.upscale_factor, face_size=512) # scan all the jpg and png images for img_path in sorted( diff --git a/scripts/download_pretrained_models.py b/scripts/download_pretrained_models.py index 3514a68..3eb6911 100644 --- a/scripts/download_pretrained_models.py +++ b/scripts/download_pretrained_models.py @@ -112,9 +112,7 @@ def download_pretrained_models(method, file_ids): 'DFDNet_dict_512-f79685f0.pth': '1iH00oMsoN_1OJaEQw3zP7_wqiAYMnY79', 'DFDNet_official-d1fa5650.pth': - '1u6Sgcp8gVoy4uVTrOJKD3y9RuqH2JBAe', - 'FFHQ_5_landmarks_template_1024-90a00515.npy': - '1IQdQcq9QnpW6YzRwDaNbpV-rJ1Cq7RUq' + '1u6Sgcp8gVoy4uVTrOJKD3y9RuqH2JBAe' }, 'dlib': { 'mmod_human_face_detector-4cb19393.dat': From 2e54945daee2bfd4cecf781722cf07b84878e0d0 Mon Sep 17 00:00:00 2001 From: Xintao Date: Sun, 29 Nov 2020 13:44:14 +0800 Subject: [PATCH 20/23] update readme --- README.md | 10 ++++------ colab/README.md | 8 +++++++- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index dc35171..b1d4ade 100644 --- a/README.md +++ b/README.md @@ -17,17 +17,15 @@ BasicSR (**Basic** **S**uper **R**estoration) is an open source **image and vide ([HandyView](https://github.com/xinntao/HandyView), [HandyFigure](https://github.com/xinntao/HandyFigure), [HandyCrawler](https://github.com/xinntao/HandyCrawler), [HandyWriting](https://github.com/xinntao/HandyWriting)) ## :sparkles: New Features - -- Sep 8, 2020. Add **blind face restoration inference codes: [DFDNet](https://github.com/csxmli2016/DFDNet)**. Note that it is slightly different from the official testing codes. - > ECCV20: Blind Face Restoration via Deep Multi-scale Component Dictionaries
- > Xiaoming Li, Chaofeng Chen, Shangchen Zhou, Xianhui Lin, Wangmeng Zuo and Lei Zhang
+- Nov 29, 2020. Add **ESRGAN** and **DFDNet** [colab demo](colab). +- Sep 8, 2020. Add **blind face restoration inference codes: [DFDNet](https://github.com/csxmli2016/DFDNet)**. - Aug 27, 2020. Add **StyleGAN2 training and testing** codes: [StyleGAN2](https://github.com/rosinality/stylegan2-pytorch). - > CVPR20: Analyzing and Improving the Image Quality of StyleGAN
- > Tero Karras, Samuli Laine, Miika Aittala, Janne Hellsten, Jaakko Lehtinen and Timo Aila
More
    +
  • Sep 8, 2020. Add blind face restoration inference codes: DFDNet.
    ECCV20: Blind Face Restoration via Deep Multi-scale Component Dictionaries
    Xiaoming Li, Chaofeng Chen, Shangchen Zhou, Xianhui Lin, Wangmeng Zuo and Lei Zhang
  • +
  • Aug 27, 2020. Add StyleGAN2 training and testing codes.
    CVPR20: Analyzing and Improving the Image Quality of StyleGAN
    Tero Karras, Samuli Laine, Miika Aittala, Janne Hellsten, Jaakko Lehtinen and Timo Aila
  • Aug 19, 2020. A brand-new BasicSR v1.0.0 online.
diff --git a/colab/README.md b/colab/README.md index 729eb3e..0e83739 100644 --- a/colab/README.md +++ b/colab/README.md @@ -4,4 +4,10 @@ To maintain a small size of BasicSR repo, we do not include the original colab notebooks in this repo, but provide links to the google colab. -- [BasicSR_inference_ESRGAN.ipynb](https://colab.research.google.com/drive/1JQScYICvEC3VqaabLu-lxvq9h7kSV1ML?usp=sharing) +| Face Restoration| | +| :--- | :---: | +|DFDNet | [BasicSR_inference_DFDNet.ipynb](https://colab.research.google.com/drive/1RoNDeipp9yPjI3EbpEbUhn66k5Uzg4n8?usp=sharing)| +| **Super-Resolution**| | +|ESRGAN |[BasicSR_inference_ESRGAN.ipynb](https://colab.research.google.com/drive/1JQScYICvEC3VqaabLu-lxvq9h7kSV1ML?usp=sharing)| +| **Deblurring**| | +| **Denoise**| | From bf4910256addb677134e8958b6a24acafdd8f4ac Mon Sep 17 00:00:00 2001 From: Xintao Date: Sun, 29 Nov 2020 13:58:07 +0800 Subject: [PATCH 21/23] update readme --- README.md | 9 +++++---- README_CN.md | 13 ++++++------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index b1d4ade..2bfb77f 100644 --- a/README.md +++ b/README.md @@ -17,16 +17,17 @@ BasicSR (**Basic** **S**uper **R**estoration) is an open source **image and vide ([HandyView](https://github.com/xinntao/HandyView), [HandyFigure](https://github.com/xinntao/HandyFigure), [HandyCrawler](https://github.com/xinntao/HandyCrawler), [HandyWriting](https://github.com/xinntao/HandyWriting)) ## :sparkles: New Features + - Nov 29, 2020. Add **ESRGAN** and **DFDNet** [colab demo](colab). -- Sep 8, 2020. Add **blind face restoration inference codes: [DFDNet](https://github.com/csxmli2016/DFDNet)**. +- Sep 8, 2020. Add **blind face restoration** inference codes: [DFDNet](https://github.com/csxmli2016/DFDNet). - Aug 27, 2020. Add **StyleGAN2 training and testing** codes: [StyleGAN2](https://github.com/rosinality/stylegan2-pytorch).
More
    -
  • Sep 8, 2020. Add blind face restoration inference codes: DFDNet.
    ECCV20: Blind Face Restoration via Deep Multi-scale Component Dictionaries
    Xiaoming Li, Chaofeng Chen, Shangchen Zhou, Xianhui Lin, Wangmeng Zuo and Lei Zhang
  • -
  • Aug 27, 2020. Add StyleGAN2 training and testing codes.
    CVPR20: Analyzing and Improving the Image Quality of StyleGAN
    Tero Karras, Samuli Laine, Miika Aittala, Janne Hellsten, Jaakko Lehtinen and Timo Aila
  • -
  • Aug 19, 2020. A brand-new BasicSR v1.0.0 online.
  • +
  • Sep 8, 2020. Add blind face restoration inference codes: DFDNet.
    ECCV20: Blind Face Restoration via Deep Multi-scale Component Dictionaries
    Xiaoming Li, Chaofeng Chen, Shangchen Zhou, Xianhui Lin, Wangmeng Zuo and Lei Zhang
  • +
  • Aug 27, 2020. Add StyleGAN2 training and testing codes.
    CVPR20: Analyzing and Improving the Image Quality of StyleGAN
    Tero Karras, Samuli Laine, Miika Aittala, Janne Hellsten, Jaakko Lehtinen and Timo Aila
  • +
  • Aug 19, 2020. A brand-new BasicSR v1.0.0 online.
diff --git a/README_CN.md b/README_CN.md index ad0d2bb..0067f7a 100644 --- a/README_CN.md +++ b/README_CN.md @@ -18,17 +18,16 @@ BasicSR (**Basic** **S**uper **R**estoration) 是一个基于 PyTorch 的开源 ## :sparkles: 新的特性 -- Sep 8, 2020. 添加 **盲人脸复原推理代码: [DFDNet](https://github.com/csxmli2016/DFDNet)**. 注意和官方代码有些微差异. - > ECCV20: Blind Face Restoration via Deep Multi-scale Component Dictionaries
- > Xiaoming Li, Chaofeng Chen, Shangchen Zhou, Xianhui Lin, Wangmeng Zuo and Lei Zhang
-- Aug 27, 2020. 添加 **StyleGAN2 训练和测试** 代码: [StyleGAN2](https://github.com/rosinality/stylegan2-pytorch). - > CVPR20: Analyzing and Improving the Image Quality of StyleGAN
- > Tero Karras, Samuli Laine, Miika Aittala, Janne Hellsten, Jaakko Lehtinen and Timo Aila
+- Nov 29, 2020. 添加 **ESRGAN** and **DFDNet** [colab demo](colab). +- Sep 8, 2020. 添加 **盲人脸复原**测试代码: [DFDNet](https://github.com/csxmli2016/DFDNet). +- Aug 27, 2020. 添加 **StyleGAN2 训练和测试** 代码: [StyleGAN2](https://github.com/rosinality/stylegan2-pytorch).
更多
    -
  • Aug 19, 2020. 全新的 BasicSR v1.0.0 上线.
  • +
  • Sep 8, 2020. 添加 盲人脸复原 测试代码: DFDNet.
    ECCV20: Blind Face Restoration via Deep Multi-scale Component Dictionaries
    Xiaoming Li, Chaofeng Chen, Shangchen Zhou, Xianhui Lin, Wangmeng Zuo and Lei Zhang
  • +
  • Aug 27, 2020. 添加 StyleGAN2 训练和测试代码.
    CVPR20: Analyzing and Improving the Image Quality of StyleGAN
    Tero Karras, Samuli Laine, Miika Aittala, Janne Hellsten, Jaakko Lehtinen and Timo Aila
  • +
  • Aug 19, 2020. 全新的 BasicSR v1.0.0 上线.
From 5df13476b85257b7972c00d3cfd80c357ff0583d Mon Sep 17 00:00:00 2001 From: Xintao Date: Sun, 29 Nov 2020 14:00:46 +0800 Subject: [PATCH 22/23] update readme --- README.md | 4 ++-- README_CN.md | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 2bfb77f..49d5f68 100644 --- a/README.md +++ b/README.md @@ -25,8 +25,8 @@ BasicSR (**Basic** **S**uper **R**estoration) is an open source **image and vide
More
    -
  • Sep 8, 2020. Add blind face restoration inference codes: DFDNet.
    ECCV20: Blind Face Restoration via Deep Multi-scale Component Dictionaries
    Xiaoming Li, Chaofeng Chen, Shangchen Zhou, Xianhui Lin, Wangmeng Zuo and Lei Zhang
  • -
  • Aug 27, 2020. Add StyleGAN2 training and testing codes.
    CVPR20: Analyzing and Improving the Image Quality of StyleGAN
    Tero Karras, Samuli Laine, Miika Aittala, Janne Hellsten, Jaakko Lehtinen and Timo Aila
  • +
  • Sep 8, 2020. Add blind face restoration inference codes: DFDNet.
    ECCV20: Blind Face Restoration via Deep Multi-scale Component Dictionaries
    Xiaoming Li, Chaofeng Chen, Shangchen Zhou, Xianhui Lin, Wangmeng Zuo and Lei Zhang
  • +
  • Aug 27, 2020. Add StyleGAN2 training and testing codes.
    CVPR20: Analyzing and Improving the Image Quality of StyleGAN
    Tero Karras, Samuli Laine, Miika Aittala, Janne Hellsten, Jaakko Lehtinen and Timo Aila
  • Aug 19, 2020. A brand-new BasicSR v1.0.0 online.
diff --git a/README_CN.md b/README_CN.md index 0067f7a..09cd59c 100644 --- a/README_CN.md +++ b/README_CN.md @@ -25,8 +25,8 @@ BasicSR (**Basic** **S**uper **R**estoration) 是一个基于 PyTorch 的开源
更多
    -
  • Sep 8, 2020. 添加 盲人脸复原 测试代码: DFDNet.
    ECCV20: Blind Face Restoration via Deep Multi-scale Component Dictionaries
    Xiaoming Li, Chaofeng Chen, Shangchen Zhou, Xianhui Lin, Wangmeng Zuo and Lei Zhang
  • -
  • Aug 27, 2020. 添加 StyleGAN2 训练和测试代码.
    CVPR20: Analyzing and Improving the Image Quality of StyleGAN
    Tero Karras, Samuli Laine, Miika Aittala, Janne Hellsten, Jaakko Lehtinen and Timo Aila
  • +
  • Sep 8, 2020. 添加 盲人脸复原 测试代码: DFDNet.
    ECCV20: Blind Face Restoration via Deep Multi-scale Component Dictionaries
    Xiaoming Li, Chaofeng Chen, Shangchen Zhou, Xianhui Lin, Wangmeng Zuo and Lei Zhang
  • +
  • Aug 27, 2020. 添加 StyleGAN2 训练和测试代码.
    CVPR20: Analyzing and Improving the Image Quality of StyleGAN
    Tero Karras, Samuli Laine, Miika Aittala, Janne Hellsten, Jaakko Lehtinen and Timo Aila
  • Aug 19, 2020. 全新的 BasicSR v1.0.0 上线.
From 1464d8e5c1ab27c12285c21771c2fc5a71382d32 Mon Sep 17 00:00:00 2001 From: Xintao Date: Sun, 29 Nov 2020 14:04:37 +0800 Subject: [PATCH 23/23] :bookmark: VERSION 1.2.0 --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index 524cb55..26aaba0 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -1.1.1 +1.2.0