Skip to content

Commit

Permalink
Merge branch 'master' of github.com:xinntao/BasicSR
Browse files Browse the repository at this point in the history
  • Loading branch information
xinntao committed Sep 18, 2020
2 parents b399ac9 + b25509e commit 010d29a
Show file tree
Hide file tree
Showing 31 changed files with 151 additions and 127 deletions.
36 changes: 20 additions & 16 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,15 @@ Note that this version is not compatible with previous versions. If you want to

# :rocket: [BasicSR](https://github.com/xinntao/BasicSR)

[GitHub](https://github.com/xinntao/BasicSR) | [Gitee码云](https://gitee.com/xinntao/BasicSR) <br>
[English](README.md) | [简体中文](README_CN.md)
[English](README.md) **|** [简体中文](README_CN.md) &emsp; [GitHub](https://github.com/xinntao/BasicSR) **|** [Gitee码云](https://gitee.com/xinntao/BasicSR)

:arrow_double_down: Google Drive: [Pretrained Models](https://drive.google.com/drive/folders/15DgDtfaLASQ3iAPJEVHQF49g9msexECG?usp=sharing) **|** [Reproduced Experiments](https://drive.google.com/drive/folders/1XN4WXKJ53KQ0Cu0Yv-uCt8DZWq6uufaP?usp=sharing)
:arrow_double_down: 百度网盘: [预训练模型](https://pan.baidu.com/s/1R6Nc4v3cl79XPAiK0Toe7g) **|** [复现实验](https://pan.baidu.com/s/1UElD6q8sVAgn_cxeBDOlvQ) <br>
:chart_with_upwards_trend: [Training curves in wandb](https://app.wandb.ai/xintao/basicsr) <br>
:computer: [Commands for training and testing](docs/TrainTest.md) <br>
:zap: [HOWTOs](#zap-howtos)

---

BasicSR is an **open source** image and video super-resolution toolbox based on PyTorch (will extend to more restoration tasks in the future).<br>
<sub>([ESRGAN](https://github.com/xinntao/ESRGAN), [EDVR](https://github.com/xinntao/EDVR), [DNI](https://github.com/xinntao/DNI), [SFTGAN](https://github.com/xinntao/SFTGAN))</sub>
Expand Down Expand Up @@ -34,11 +41,11 @@ BasicSR is an **open source** image and video super-resolution toolbox based on
We provides simple pipelines to train/test/inference models for quick start.
These pipelines/commands cannot cover all the cases and more details are in the following sections.

- :zap: [How to train StyleGAN2](docs/HOWTOs.md#How-to-train-StyleGAN2)
- :zap: [How to test StyleGAN2](docs/HOWTOs.md#How-to-test-StyleGAN2)
- :zap: [How to test DFDNet](docs/HOWTOs.md#How-to-test-DFDNet)
- [How to train StyleGAN2](docs/HOWTOs.md#How-to-train-StyleGAN2)
- [How to test StyleGAN2](docs/HOWTOs.md#How-to-test-StyleGAN2)
- [How to test DFDNet](docs/HOWTOs.md#How-to-test-DFDNet)

## Dependencies and Installation
## :wrench: Dependencies and Installation

- Python >= 3.7 (Recommend to use [Anaconda](https://www.anaconda.com/download/#linux) or [Miniconda](https://docs.conda.io/en/latest/miniconda.html))
- [PyTorch >= 1.3](https://pytorch.org/)
Expand All @@ -54,25 +61,22 @@ python setup.py develop

Note that BasicSR is only tested in Ubuntu, and may be not suitable for Windows. You may try [Windows WSL with CUDA supports](https://docs.microsoft.com/en-us/windows/win32/direct3d12/gpu-cuda-in-wsl) :-) (It is now only available for insider build with Fast ring).

## TODO List
## :hourglass_flowing_sand: TODO List

Please see [project boards](https://github.com/xinntao/BasicSR/projects).

## Dataset Preparation
## :turtle: Dataset Preparation

- Please refer to **[DatasetPreparation.md](docs/DatasetPreparation.md)** for more details.
- The descriptions of currently supported datasets (`torch.utils.data.Dataset` classes) are in [Datasets.md](docs/Datasets.md).

## Train and Test
## :computer: Train and Test

- **Training and testing commands**: Please see **[TrainTest.md](docs/TrainTest.md)** for the basic usage.
- **Options/Configs**: Please refer to [Config.md](docs/Config.md).
- **Logging**: Please refer to [Logging.md](docs/Logging.md).

## Model Zoo and Baselines

**[Download official pre-trained models](https://drive.google.com/drive/folders/15DgDtfaLASQ3iAPJEVHQF49g9msexECG?usp=sharing)**<br>
**[Download reproduced models and logs](https://drive.google.com/drive/folders/1XN4WXKJ53KQ0Cu0Yv-uCt8DZWq6uufaP?usp=sharing)**
## :card_file_box: Model Zoo and Baselines

- The descriptions of currently supported models are in [Models.md](docs/Models.md).
- **Pre-trained models and log examples** are available in **[ModelZoo.md](docs/ModelZoo.md)**.
Expand All @@ -83,19 +87,19 @@ Please see [project boards](https://github.com/xinntao/BasicSR/projects).
<img src="./assets/wandb.jpg" height="280">
</a></p>

## Codebase Designs and Conventions
## :memo: Codebase Designs and Conventions

Please see [DesignConvention.md](docs/DesignConvention.md) for the designs and conventions of the BasicSR codebase.<br>
The figure below shows the overall framework. More descriptions for each component: <br>
**[Datasets.md](docs/Datasets.md)**&emsp;|&emsp;**[Models.md](docs/Models.md)**&emsp;|&emsp;**[Config.md](Config.md)**&emsp;|&emsp;**[Logging.md](docs/Logging.md)**

![overall_structure](./assets/overall_structure.png)

## License and Acknowledgement
## :scroll: License and Acknowledgement

This project is released under the Apache 2.0 license.
More details about license and acknowledgement are in [LICENSE](LICENSE/README.md).

## Contact
## :e-mail: Contact

If you have any question, please email `[email protected]`.
38 changes: 21 additions & 17 deletions README_CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,15 @@

# :rocket: [BasicSR](https://github.com/xinntao/BasicSR)

[GitHub](https://github.com/xinntao/BasicSR) | [Gitee码云](https://gitee.com/xinntao/BasicSR) <br>
[English](README.md) | [简体中文](README_CN.md)
[English](README.md) **|** [简体中文](README_CN.md) &emsp; [GitHub](https://github.com/xinntao/BasicSR) **|** [Gitee码云](https://gitee.com/xinntao/BasicSR)

:arrow_double_down: 百度网盘: [预训练模型](https://pan.baidu.com/s/1R6Nc4v3cl79XPAiK0Toe7g) **|** [复现实验](https://pan.baidu.com/s/1UElD6q8sVAgn_cxeBDOlvQ)
:arrow_double_down: Google Drive: [Pretrained Models](https://drive.google.com/drive/folders/15DgDtfaLASQ3iAPJEVHQF49g9msexECG?usp=sharing) **|** [Reproduced Experiments](https://drive.google.com/drive/folders/1XN4WXKJ53KQ0Cu0Yv-uCt8DZWq6uufaP?usp=sharing) <br>
:chart_with_upwards_trend: [wandb的训练曲线](https://app.wandb.ai/xintao/basicsr) <br>
:computer: [训练和测试的命令](docs/TrainTest_CN.md) <br>
:zap: [HOWTOs](#zap-howtos)

---

BasicSR 是一个基于 PyTorch 的**开源**图像视频超分辨率 (Super-Resolution) 工具箱 (之后会支持更多的 Restoration 任务).<br>
<sub>([ESRGAN](https://github.com/xinntao/ESRGAN), [EDVR](https://github.com/xinntao/EDVR), [DNI](https://github.com/xinntao/DNI), [SFTGAN](https://github.com/xinntao/SFTGAN))</sub>
Expand All @@ -30,15 +37,15 @@ BasicSR 是一个基于 PyTorch 的**开源**图像视频超分辨率 (Super-Res
</ul>
</details>

## :zap:HOWTOs
## :zap: HOWTOs

我们提供了简单的流程来快速上手 训练/测试/推理 模型. 这些命令并不能涵盖所有用法, 更多的细节参见下面的部分.

- :zap: [如何训练 StyleGAN2](docs/HOWTOs_CN.md#如何训练-StyleGAN2)
- :zap: [如何测试 StyleGAN2](docs/HOWTOs_CN.md#如何测试-StyleGAN2)
- :zap: [如何测试 DFDNet](docs/HOWTOs_CN.md#如何测试-DFDNet)
- [如何训练 StyleGAN2](docs/HOWTOs_CN.md#如何训练-StyleGAN2)
- [如何测试 StyleGAN2](docs/HOWTOs_CN.md#如何测试-StyleGAN2)
- [如何测试 DFDNet](docs/HOWTOs_CN.md#如何测试-DFDNet)

## 依赖和安装
## :wrench: 依赖和安装

- Python >= 3.7 (推荐使用 [Anaconda](https://www.anaconda.com/download/#linux)[Miniconda](https://docs.conda.io/en/latest/miniconda.html))
- [PyTorch >= 1.3](https://pytorch.org/)
Expand All @@ -54,25 +61,22 @@ python setup.py develop

注意: BasicSR 仅在 Ubuntu 下进行测试,或许不支持Windows. 可以在Windows下尝试[支持CUDA的Windows WSL](https://docs.microsoft.com/en-us/windows/win32/direct3d12/gpu-cuda-in-wsl) :-) (目前只有Fast ring的预览版系统可以安装).

## TODO 清单
## :hourglass_flowing_sand: TODO 清单

参见 [project boards](https://github.com/xinntao/BasicSR/projects).

## 数据准备
## :turtle: 数据准备

- 数据准备步骤, 参见 **[DatasetPreparation_CN.md](docs/DatasetPreparation_CN.md)**.
- 目前支持的数据集 (`torch.utils.data.Dataset`类), 参见 [Datasets_CN.md](docs/Datasets_CN.md).

## 训练和测试
## :computer: 训练和测试

- **训练和测试的命令**, 参见 **[TrainTest_CN.md](docs/TrainTest_CN.md)**.
- **Options/Configs**配置文件的说明, 参见 [Config_CN.md](docs/Config_CN.md).
- **Logging**日志系统的说明, 参见 [Logging_CN.md](docs/Logging_CN.md).

## 模型库和基准

**[下载官方提供的预训练模型](https://drive.google.com/drive/folders/15DgDtfaLASQ3iAPJEVHQF49g9msexECG?usp=sharing)** <br>
**[下载复现的模型和log](https://drive.google.com/drive/folders/1XN4WXKJ53KQ0Cu0Yv-uCt8DZWq6uufaP?usp=sharing)**
## :card_file_box: 模型库和基准

- 目前支持的模型描述, 参见 [Models_CN.md](docs/Models_CN.md).
- **预训练模型和log样例**, 参见 **[ModelZoo_CN.md](docs/ModelZoo_CN.md)**.
Expand All @@ -83,19 +87,19 @@ python setup.py develop
<img src="./assets/wandb.jpg" height="280">
</a></p>

## 代码库的设计和约定
## :memo: 代码库的设计和约定

参见 [DesignConvention_CN.md](docs/DesignConvention_CN.md).<br>
下图概括了整体的框架. 每个模块更多的描述参见: <br>
**[Datasets_CN.md](docs/Datasets_CN.md)**&emsp;|&emsp;**[Models_CN.md](docs/Models_CN.md)**&emsp;|&emsp;**[Config_CN.md](Config_CN.md)**&emsp;|&emsp;**[Logging_CN.md](docs/Logging_CN.md)**

![overall_structure](./assets/overall_structure.png)

## 许可
## :scroll: 许可

本项目使用 Apache 2.0 license.
更多细节参见 [LICENSE](LICENSE/README.md).

#### 联系
## :e-mail: 联系

若有任何问题, 请电邮 `[email protected]`.
2 changes: 1 addition & 1 deletion VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1.1.0
1.1.1
15 changes: 15 additions & 0 deletions basicsr/data/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,21 @@ def paired_paths_from_folder(folders, keys, filename_tmpl):
return paths


def paths_from_folder(folder):
"""Generate paths from folder.
Args:
folder (str): Folder path.
Returns:
list[str]: Returned path list.
"""

paths = list(mmcv.scandir(folder))
paths = [osp.join(folder, path) for path in paths]
return paths


def generate_gaussian_kernel(kernel_size=13, sigma=1.6):
"""Generate Gaussian kernel used in `duf_downsample`.
Expand Down
28 changes: 28 additions & 0 deletions basicsr/models/archs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,31 @@
importlib.import_module(f'basicsr.models.archs.{file_name}')
for file_name in arch_filenames
]


def dynamic_instantiation(modules, cls_type, opt):
"""Dynamically instantiate class.
Args:
modules (list[importlib modules]): List of modules from importlib
files.
cls_type (str): Class type.
opt (dict): Class initialization kwargs.
Returns:
class: Instantiated class.
"""

for module in modules:
cls_ = getattr(module, cls_type, None)
if cls_ is not None:
break
if cls_ is None:
raise ValueError(f'{cls_type} is not found.')
return cls_(**opt)


def define_network(opt):
network_type = opt.pop('type')
net = dynamic_instantiation(_arch_modules, network_type, opt)
return net
39 changes: 0 additions & 39 deletions basicsr/models/networks.py

This file was deleted.

4 changes: 2 additions & 2 deletions basicsr/models/sr_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from copy import deepcopy
from os import path as osp

from basicsr.models import networks as networks
from basicsr.models.archs import define_network
from basicsr.models.base_model import BaseModel
from basicsr.utils import ProgressBar, get_root_logger, tensor2img

Expand All @@ -20,7 +20,7 @@ def __init__(self, opt):
super(SRModel, self).__init__(opt)

# define network
self.net_g = networks.define_net_g(deepcopy(opt['network_g']))
self.net_g = define_network(deepcopy(opt['network_g']))
self.net_g = self.model_to_device(self.net_g)
self.print_network(self.net_g)

Expand Down
4 changes: 2 additions & 2 deletions basicsr/models/srgan_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from collections import OrderedDict
from copy import deepcopy

from basicsr.models import networks as networks
from basicsr.models.archs import define_network
from basicsr.models.sr_model import SRModel

loss_module = importlib.import_module('basicsr.models.losses')
Expand All @@ -16,7 +16,7 @@ def init_training_settings(self):
train_opt = self.opt['train']

# define network net_d
self.net_d = networks.define_net_d(deepcopy(self.opt['network_d']))
self.net_d = define_network(deepcopy(self.opt['network_d']))
self.net_d = self.model_to_device(self.net_d)
self.print_network(self.net_d)

Expand Down
17 changes: 10 additions & 7 deletions basicsr/models/stylegan2_model.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import importlib
import math
import mmcv
import numpy as np
import random
import torch
from collections import OrderedDict
from copy import deepcopy
from os import path as osp

from basicsr.models import networks as networks
from basicsr.models.archs import define_network
from basicsr.models.base_model import BaseModel
from basicsr.models.losses.losses import g_path_regularize, r1_penalty
from basicsr.utils import tensor2img
Expand All @@ -22,7 +23,7 @@ def __init__(self, opt):
super(StyleGAN2Model, self).__init__(opt)

# define network net_g
self.net_g = networks.define_net_g(deepcopy(opt['network_g']))
self.net_g = define_network(deepcopy(opt['network_g']))
self.net_g = self.model_to_device(self.net_g)
self.print_network(self.net_g)
# load pretrained model
Expand All @@ -34,8 +35,9 @@ def __init__(self, opt):

# latent dimension: self.num_style_feat
self.num_style_feat = opt['network_g']['num_style_feat']
num_val_samples = self.opt['val'].get('num_val_samples', 16)
self.fixed_sample = torch.randn(
16, self.num_style_feat, device=self.device)
num_val_samples, self.num_style_feat, device=self.device)

if self.is_train:
self.init_training_settings()
Expand All @@ -44,7 +46,7 @@ def init_training_settings(self):
train_opt = self.opt['train']

# define network net_d
self.net_d = networks.define_net_d(deepcopy(self.opt['network_d']))
self.net_d = define_network(deepcopy(self.opt['network_d']))
self.net_d = self.model_to_device(self.net_d)
self.print_network(self.net_d)

Expand All @@ -57,8 +59,8 @@ def init_training_settings(self):
# define network net_g with Exponential Moving Average (EMA)
# net_g_ema only used for testing on one GPU and saving, do not need to
# wrap with DistributedDataParallel
self.net_g_ema = networks.define_net_g(
deepcopy(self.opt['network_g'])).to(self.device)
self.net_g_ema = define_network(deepcopy(self.opt['network_g'])).to(
self.device)
# load pretrained model
load_path = self.opt['path'].get('pretrain_model_g', None)
if load_path is not None:
Expand Down Expand Up @@ -311,7 +313,8 @@ def nondist_validation(self, dataloader, current_iter, tb_logger,
f'test_{self.opt["name"]}.png')
mmcv.imwrite(result, save_img_path)
# add sample images to tb_logger
result = mmcv.bgr2rgb(result / 255.)
result = (result / 255.).astype(np.float32)
result = mmcv.bgr2rgb(result)
if tb_logger is not None:
tb_logger.add_image(
'samples', result, global_step=current_iter, dataformats='HWC')
Expand Down
4 changes: 2 additions & 2 deletions basicsr/models/video_gan_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from collections import OrderedDict
from copy import deepcopy

from basicsr.models import networks as networks
from basicsr.models.archs import define_network
from basicsr.models.video_base_model import VideoBaseModel

loss_module = importlib.import_module('basicsr.models.losses')
Expand All @@ -16,7 +16,7 @@ def init_training_settings(self):
train_opt = self.opt['train']

# define network net_d
self.net_d = networks.define_net_d(deepcopy(self.opt['network_d']))
self.net_d = define_network(deepcopy(self.opt['network_d']))
self.net_d = self.model_to_device(self.net_d)
self.print_network(self.net_d)

Expand Down
Loading

0 comments on commit 010d29a

Please sign in to comment.