diff --git a/README.md b/README.md
index eb05c84c..15a1f397 100644
--- a/README.md
+++ b/README.md
@@ -1,167 +1,111 @@
-
-
-
+# Hallo2: Long-Duration and High-Resolution Audio-driven Portrait Image Animation
-## Towards Robust Blind Face Restoration with Codebook Lookup Transformer (NeurIPS 2022)
+## ⚙️ Installation
-[Paper](https://arxiv.org/abs/2206.11253) | [Project Page](https://shangchenzhou.com/projects/CodeFormer/) | [Video](https://youtu.be/d3VDpkXlueI)
+- System requirements: Ubuntu 20.04 / Ubuntu 22.04, CUDA 12.1
+- Tested GPUs: A100
+Create a conda environment:
- [![Hugging Face](https://img.shields.io/badge/Demo-%F0%9F%A4%97%20Hugging%20Face-blue)](https://huggingface.co/spaces/sczhou/CodeFormer) [![Replicate](https://img.shields.io/badge/Demo-%F0%9F%9A%80%20Replicate-blue)](https://replicate.com/sczhou/codeformer) [![OpenXLab](https://img.shields.io/badge/Demo-%F0%9F%90%BC%20OpenXLab-blue)](https://openxlab.org.cn/apps/detail/ShangchenZhou/CodeFormer) ![Visitors](https://api.infinitescript.com/badgen/count?name=sczhou/CodeFormer<ext=Visitors)
-
-
-[Shangchen Zhou](https://shangchenzhou.com/), [Kelvin C.K. Chan](https://ckkelvinchan.github.io/), [Chongyi Li](https://li-chongyi.github.io/), [Chen Change Loy](https://www.mmlab-ntu.com/person/ccloy/)
-
-S-Lab, Nanyang Technological University
-
-
-
-
-:star: If CodeFormer is helpful to your images or projects, please help star this repo. Thanks! :hugs:
-
-
-### Update
-- **2023.07.20**: Integrated to :panda_face: [OpenXLab](https://openxlab.org.cn/apps). Try out online demo! [![OpenXLab](https://img.shields.io/badge/Demo-%F0%9F%90%BC%20OpenXLab-blue)](https://openxlab.org.cn/apps/detail/ShangchenZhou/CodeFormer)
-- **2023.04.19**: :whale: Training codes and config files are public available now.
-- **2023.04.09**: Add features of inpainting and colorization for cropped and aligned face images.
-- **2023.02.10**: Include `dlib` as a new face detector option, it produces more accurate face identity.
-- **2022.10.05**: Support video input `--input_path [YOUR_VIDEO.mp4]`. Try it to enhance your videos! :clapper:
-- **2022.09.14**: Integrated to :hugs: [Hugging Face](https://huggingface.co/spaces). Try out online demo! [![Hugging Face](https://img.shields.io/badge/Demo-%F0%9F%A4%97%20Hugging%20Face-blue)](https://huggingface.co/spaces/sczhou/CodeFormer)
-- **2022.09.09**: Integrated to :rocket: [Replicate](https://replicate.com/explore). Try out online demo! [![Replicate](https://img.shields.io/badge/Demo-%F0%9F%9A%80%20Replicate-blue)](https://replicate.com/sczhou/codeformer)
-- [**More**](docs/history_changelog.md)
-
-### TODO
-- [x] Add training code and config files
-- [x] Add checkpoint and script for face inpainting
-- [x] Add checkpoint and script for face colorization
-- [x] ~~Add background image enhancement~~
-
-#### :panda_face: Try Enhancing Old Photos / Fixing AI-arts
-[](https://imgsli.com/MTI3NTE2) [](https://imgsli.com/MTI3NTE1) [](https://imgsli.com/MTI3NTIw)
-
-#### Face Restoration
-
-
-
-
-#### Face Color Enhancement and Restoration
-
-
-
-#### Face Inpainting
-
-
-
-
-
-### Dependencies and Installation
-
-- Pytorch >= 1.7.1
-- CUDA >= 10.1
-- Other required packages in `requirements.txt`
+```bash
+ conda create -n hallo python=3.10
+ conda activate hallo
```
-# git clone this repository
-git clone https://github.com/sczhou/CodeFormer
-cd CodeFormer
-
-# create new anaconda env
-conda create -n codeformer python=3.8 -y
-conda activate codeformer
-
-# install python dependencies
-pip3 install -r requirements.txt
-python basicsr/setup.py develop
-conda install -c conda-forge dlib (only for face detection or cropping with dlib)
-```
-
-### Quick Inference
+Install packages with `pip`:
-#### Download Pre-trained Models:
-Download the facelib and dlib pretrained models from [[Releases](https://github.com/sczhou/CodeFormer/releases/tag/v0.1.0) | [Google Drive](https://drive.google.com/drive/folders/1b_3qwrzY_kTQh0-SnBoGBgOrJ_PLZSKm?usp=sharing) | [OneDrive](https://entuedu-my.sharepoint.com/:f:/g/personal/s200094_e_ntu_edu_sg/EvDxR7FcAbZMp_MA9ouq7aQB8XTppMb3-T0uGZ_2anI2mg?e=DXsJFo)] to the `weights/facelib` folder. You can manually download the pretrained models OR download by running the following command:
-```
-python scripts/download_pretrained_models.py facelib
-python scripts/download_pretrained_models.py dlib (only for dlib face detector)
-```
-
-Download the CodeFormer pretrained models from [[Releases](https://github.com/sczhou/CodeFormer/releases/tag/v0.1.0) | [Google Drive](https://drive.google.com/drive/folders/1CNNByjHDFt0b95q54yMVp6Ifo5iuU6QS?usp=sharing) | [OneDrive](https://entuedu-my.sharepoint.com/:f:/g/personal/s200094_e_ntu_edu_sg/EoKFj4wo8cdIn2-TY2IV6CYBhZ0pIG4kUOeHdPR_A5nlbg?e=AO8UN9)] to the `weights/CodeFormer` folder. You can manually download the pretrained models OR download by running the following command:
-```
-python scripts/download_pretrained_models.py CodeFormer
+```bash
+ pip install torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2 --index-url https://download.pytorch.org/whl/cu118
+ pip install -r requirements.txt
```
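+
+As a quick, optional sanity check (not part of the official setup), you can verify that the installed PyTorch build sees your GPU before continuing:
+```bash
+python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
+```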
-#### Prepare Testing Data:
-You can put the testing images in the `inputs/TestWhole` folder. If you would like to test on cropped and aligned faces, you can put them in the `inputs/cropped_faces` folder. You can get the cropped and aligned faces by running the following command:
+In addition, `ffmpeg` is also required:
+```bash
+ apt-get install ffmpeg
```
-# you may need to install dlib via: conda install -c conda-forge dlib
-python scripts/crop_align_face.py -i [input folder] -o [output folder]
-```
-
-#### Testing:
-[Note] If you want to compare CodeFormer in your paper, please run the following command indicating `--has_aligned` (for cropped and aligned face), as the command for the whole image will involve a process of face-background fusion that may damage hair texture on the boundary, which leads to unfair comparison.
+### 📥 Download Pretrained Models
-Fidelity weight *w* lays in [0, 1]. Generally, smaller *w* tends to produce a higher-quality result, while larger *w* yields a higher-fidelity result. The results will be saved in the `results` folder.
+You can easily get all the pretrained models required for inference from our [HuggingFace repo](https://huggingface.co/fudan-generative-ai/hallo2).
+Clone the pretrained models into the `${PROJECT_ROOT}/pretrained_models` directory with the command below:
-đ§đģ Face Restoration (cropped and aligned face)
-```
-# For cropped and aligned faces (512x512)
-python inference_codeformer.py -w 0.5 --has_aligned --input_path [image folder]|[image path]
+```shell
+git lfs install
+git clone https://huggingface.co/fudan-generative-ai/hallo2 pretrained_models
```
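+
+If `git lfs` is not available on your machine, the same repository can usually be fetched with the `huggingface-cli` tool from `huggingface_hub`; this is an alternative we sketch here, not part of the official instructions:
+```shell
+pip install "huggingface_hub[cli]"
+huggingface-cli download fudan-generative-ai/hallo2 --local-dir pretrained_models
+```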
-:framed_picture: Whole Image Enhancement
-```
-# For whole image
-# Add '--bg_upsampler realesrgan' to enhance the background regions with Real-ESRGAN
-# Add '--face_upsample' to further upsample restorated face with Real-ESRGAN
-python inference_codeformer.py -w 0.7 --input_path [image folder]|[image path]
-```
+Or you can download them separately from their source repositories:
+- [hallo2](https://huggingface.co/fudan-generative-ai/hallo2/blob/main/hallo2/net_g.pth): our video super-resolution checkpoint.
+- [facelib](https://github.com/sczhou/CodeFormer/releases/tag/v0.1.0): pretrained face detection and parsing models.
+- [realesrgan](https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/RealESRGAN_x2plus.pth): background upsampling model.
+- [CodeFormer](https://github.com/sczhou/CodeFormer/releases/download/v0.1.0): the pretrained [CodeFormer](https://github.com/sczhou/CodeFormer) model; downloading it is optional and only needed if you want to train our video super-resolution model from scratch.
-:clapper: Video Enhancement
-```
-# For Windows/Mac users, please install ffmpeg first
-conda install -c conda-forge ffmpeg
-```
-```
-# For video clips
-# Video path should end with '.mp4'|'.mov'|'.avi'
-python inference_codeformer.py --bg_upsampler realesrgan --face_upsample -w 1.0 --input_path [video path]
-```
+Finally, these pretrained models should be organized as follows:
-đ Face Colorization (cropped and aligned face)
-```
-# For cropped and aligned faces (512x512)
-# Colorize black and white or faded photo
-python inference_colorization.py --input_path [image folder]|[image path]
+```text
+./pretrained_models/
+|-- CodeFormer/
+| |-- codeformer.pth
+| `-- vqgan_code1024.pth
+|-- facelib
+| |-- detection_mobilenet0.25_Final.pth
+| |-- detection_Resnet50_Final.pth
+| |-- parsing_parsenet.pth
+| |-- yolov5l-face.pth
+| `-- yolov5n-face.pth
+|-- hallo2
+| `-- net_g.pth
+`-- realesrgan
+ `-- RealESRGAN_x2plus.pth
```
-đ¨ Face Inpainting (cropped and aligned face)
-```
-# For cropped and aligned faces (512x512)
-# Inputs could be masked by white brush using an image editing app (e.g., Photoshop)
-# (check out the examples in inputs/masked_faces)
-python inference_inpainting.py --input_path [image folder]|[image path]
-```
-### Training:
-The training commands can be found in the documents: [English](docs/train.md) **|** [įŽäŊä¸æ](docs/train_CN.md).
+### 🎮 Run Inference
+#### High-Resolution Animation
+Simply run `scripts/video_sr.py` and pass `input_path` and `output_path`:
-### Citation
-If our work is useful for your research, please consider citing:
+```bash
+python scripts/video_sr.py --input_path [input_video] --output_path [output_dir] --bg_upsampler realesrgan --face_upsample -w 1 -s 4
+```
- @inproceedings{zhou2022codeformer,
- author = {Zhou, Shangchen and Chan, Kelvin C.K. and Li, Chongyi and Loy, Chen Change},
- title = {Towards Robust Blind Face Restoration with Codebook Lookup TransFormer},
- booktitle = {NeurIPS},
- year = {2022}
- }
+Animation results will be saved to `output_dir`.
-### License
+For more options:
-This project is licensed under NTU S-Lab License 1.0. Redistribution and use should follow this license.
+```shell
+usage: video_sr.py [-h] [-i INPUT_PATH] [-o OUTPUT_PATH] [-w FIDELITY_WEIGHT] [-s UPSCALE] [--has_aligned] [--only_center_face] [--draw_box]
+ [--detection_model DETECTION_MODEL] [--bg_upsampler BG_UPSAMPLER] [--face_upsample] [--bg_tile BG_TILE] [--suffix SUFFIX]
-### Acknowledgement
+options:
+ -h, --help show this help message and exit
+ -i INPUT_PATH, --input_path INPUT_PATH
+ Input video
+ -o OUTPUT_PATH, --output_path OUTPUT_PATH
+ Output folder.
+ -w FIDELITY_WEIGHT, --fidelity_weight FIDELITY_WEIGHT
+ Balance the quality and fidelity. Default: 0.5
+ -s UPSCALE, --upscale UPSCALE
+ The final upsampling scale of the image. Default: 2
+ --has_aligned Input are cropped and aligned faces. Default: False
+ --only_center_face Only restore the center face. Default: False
+ --draw_box Draw the bounding box for the detected faces. Default: False
+ --detection_model DETECTION_MODEL
+ Face detector. Optional: retinaface_resnet50, retinaface_mobile0.25, YOLOv5l, YOLOv5n. Default: retinaface_resnet50
+ --bg_upsampler BG_UPSAMPLER
+ Background upsampler. Optional: realesrgan
+ --face_upsample Face upsampler after enhancement. Default: False
+ --bg_tile BG_TILE Tile size for background sampler. Default: 400
+ --suffix SUFFIX Suffix of the restored faces. Default: None
+```
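+
+For instance, if your input video already contains cropped and aligned 512x512 faces, a run might look like this (the input and output paths below are placeholders):
+```bash
+python scripts/video_sr.py -i ./examples/aligned_face.mp4 -o ./results/aligned_face_sr \
+    --has_aligned -w 1 -s 4
+```
+With `--has_aligned`, frames are treated as already-cropped faces, so the background upsampler options are not needed.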
-This project is based on [BasicSR](https://github.com/XPixelGroup/BasicSR). Some codes are brought from [Unleashing Transformers](https://github.com/samb-t/unleashing-transformers), [YOLOv5-face](https://github.com/deepcam-cn/yolov5-face), and [FaceXLib](https://github.com/xinntao/facexlib). We also adopt [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN) to support background image enhancement. Thanks for their awesome works.
+## Training
+#### Prepare Data for Training
+We use the VFHQ dataset for training; you can download it from its [homepage](https://liangbinxie.github.io/projects/vfhq/). Then update `dataroot_gt` in `./configs/train/video_sr.yaml`.
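+
+Judging from the data loader added in `basicsr/data/vfhq_dataset.py`, `dataroot_gt` should point to a directory containing one subfolder per clip, each holding the extracted PNG frames; clips with fewer frames than `video_length` are skipped. A sketch of the expected layout (clip and frame names are placeholders):
+```text
+/path/to/VFHQ/
+|-- clip_0001/
+|   |-- 00000000.png
+|   |-- 00000001.png
+|   `-- ...
+`-- clip_0002/
+    `-- ...
+```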
-### Contact
-If you have any questions, please feel free to reach me out at `shangchenzhou@gmail.com`.
+#### Training
+Start training with the following command:
+```shell
+python -m torch.distributed.launch --nproc_per_node=8 --master_port=4652 \
+basicsr/train.py -opt ./configs/train/video_sr.yaml \
+--launcher pytorch
+```
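+
+If you only have a single GPU, a non-distributed run should also work, assuming the standard BasicSR entry point where `--launcher` defaults to `none`:
+```shell
+python basicsr/train.py -opt ./configs/train/video_sr.yaml
+```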
diff --git a/assets/CodeFormer_logo.png b/assets/CodeFormer_logo.png
deleted file mode 100644
index 024cb724..00000000
Binary files a/assets/CodeFormer_logo.png and /dev/null differ
diff --git a/assets/color_enhancement_result1.png b/assets/color_enhancement_result1.png
deleted file mode 100644
index 34433db6..00000000
Binary files a/assets/color_enhancement_result1.png and /dev/null differ
diff --git a/assets/color_enhancement_result2.png b/assets/color_enhancement_result2.png
deleted file mode 100644
index 228690ac..00000000
Binary files a/assets/color_enhancement_result2.png and /dev/null differ
diff --git a/assets/framework.png b/assets/framework.png
new file mode 100755
index 00000000..2ed98dbb
Binary files /dev/null and b/assets/framework.png differ
diff --git a/assets/framework_1.jpg b/assets/framework_1.jpg
new file mode 100755
index 00000000..72de2f86
Binary files /dev/null and b/assets/framework_1.jpg differ
diff --git a/assets/framework_2.jpg b/assets/framework_2.jpg
new file mode 100755
index 00000000..bf9344c2
Binary files /dev/null and b/assets/framework_2.jpg differ
diff --git a/assets/imgsli_1.jpg b/assets/imgsli_1.jpg
deleted file mode 100644
index 313438a6..00000000
Binary files a/assets/imgsli_1.jpg and /dev/null differ
diff --git a/assets/imgsli_2.jpg b/assets/imgsli_2.jpg
deleted file mode 100644
index 42dd7f43..00000000
Binary files a/assets/imgsli_2.jpg and /dev/null differ
diff --git a/assets/imgsli_3.jpg b/assets/imgsli_3.jpg
deleted file mode 100644
index c3f67d9d..00000000
Binary files a/assets/imgsli_3.jpg and /dev/null differ
diff --git a/assets/inpainting_result1.png b/assets/inpainting_result1.png
deleted file mode 100644
index 2c6fa68a..00000000
Binary files a/assets/inpainting_result1.png and /dev/null differ
diff --git a/assets/inpainting_result2.png b/assets/inpainting_result2.png
deleted file mode 100644
index 2945f9f9..00000000
Binary files a/assets/inpainting_result2.png and /dev/null differ
diff --git a/assets/network.jpg b/assets/network.jpg
deleted file mode 100644
index 5aaa6bd1..00000000
Binary files a/assets/network.jpg and /dev/null differ
diff --git a/assets/restoration_result1.png b/assets/restoration_result1.png
deleted file mode 100644
index 8fd3b67e..00000000
Binary files a/assets/restoration_result1.png and /dev/null differ
diff --git a/assets/restoration_result2.png b/assets/restoration_result2.png
deleted file mode 100644
index a2ff2827..00000000
Binary files a/assets/restoration_result2.png and /dev/null differ
diff --git a/assets/restoration_result3.png b/assets/restoration_result3.png
deleted file mode 100644
index 022d7642..00000000
Binary files a/assets/restoration_result3.png and /dev/null differ
diff --git a/assets/restoration_result4.png b/assets/restoration_result4.png
deleted file mode 100644
index 5e965076..00000000
Binary files a/assets/restoration_result4.png and /dev/null differ
diff --git a/assets/wechat.jpeg b/assets/wechat.jpeg
new file mode 100755
index 00000000..f641fd9c
Binary files /dev/null and b/assets/wechat.jpeg differ
diff --git a/basicsr/VERSION b/basicsr/VERSION
old mode 100644
new mode 100755
diff --git a/basicsr/__init__.py b/basicsr/__init__.py
old mode 100644
new mode 100755
diff --git a/basicsr/archs/__init__.py b/basicsr/archs/__init__.py
old mode 100644
new mode 100755
diff --git a/basicsr/archs/arcface_arch.py b/basicsr/archs/arcface_arch.py
old mode 100644
new mode 100755
diff --git a/basicsr/archs/arch_util.py b/basicsr/archs/arch_util.py
old mode 100644
new mode 100755
diff --git a/basicsr/archs/codeformer_arch.py b/basicsr/archs/codeformer_arch.py
old mode 100644
new mode 100755
index de66937d..e74dea82
--- a/basicsr/archs/codeformer_arch.py
+++ b/basicsr/archs/codeformer_arch.py
@@ -9,6 +9,10 @@
from basicsr.utils import get_root_logger
from basicsr.utils.registry import ARCH_REGISTRY
+from einops import rearrange
+
+from tqdm import tqdm
+
def calc_mean_std(feat, eps=1e-5):
"""Calculate mean and std for adaptive_instance_normalization.
@@ -110,15 +114,25 @@ def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="ge
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
+ self.temp_norm = nn.LayerNorm(embed_dim)
+ self.temp_dropout = nn.Dropout(dropout)
+ self.temp_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
+
+ self.temp_ffn_norm = nn.LayerNorm(embed_dim)
+ self.temp_linear1 = nn.Linear(embed_dim, dim_mlp)
+ self.temp_ffn_dropout = nn.Dropout(dropout)
+ self.temp_linear2 = nn.Linear(dim_mlp, embed_dim)
+
self.activation = _get_activation_fn(activation)
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
return tensor if pos is None else tensor + pos
- def forward(self, tgt,
+ def forward(self, tgt, video_length, batch_size,
tgt_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
- query_pos: Optional[Tensor] = None):
+ query_pos: Optional[Tensor] = None
+ ):
# self attention
tgt2 = self.norm1(tgt)
@@ -131,6 +145,22 @@ def forward(self, tgt,
tgt2 = self.norm2(tgt)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
tgt = tgt + self.dropout2(tgt2)
+
+        # temporal attention: fold spatial tokens into the batch and attend along the frame axis
+ tgt = rearrange(tgt, "d (b f) c -> f (b d) c", f=video_length)
+ tgt2 = self.temp_norm(tgt)
+ query_pos = rearrange(query_pos, "d (b f) c -> f (b d) c", f=video_length)
+ q = k = self.with_pos_embed(tgt2, query_pos)
+ tgt2 = self.temp_attn(q, k, value=tgt2, attn_mask=tgt_mask,
+ key_padding_mask=tgt_key_padding_mask)[0]
+ tgt = tgt + self.temp_dropout(tgt2)
+ tgt = rearrange(tgt, "f (b d) c -> d (b f) c", b=batch_size)
+
+        # temporal feed-forward network (FFN)
+ tgt2 = self.temp_ffn_norm(tgt)
+ tgt2 = self.temp_linear2(self.temp_ffn_dropout(self.activation(self.temp_linear1(tgt2))))
+ tgt = tgt + self.temp_ffn_dropout(tgt2)
+
return tgt
class Fuse_sft_block(nn.Module):
@@ -151,9 +181,9 @@ def __init__(self, in_ch, out_ch):
def forward(self, enc_feat, dec_feat, w=1):
enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
scale = self.scale(enc_feat)
- shift = self.shift(enc_feat)
- residual = w * (dec_feat * scale + shift)
- out = dec_feat + residual
+ shift = self.shift(enc_feat)
+ out = w * (dec_feat * scale + shift) + dec_feat
+
return out
@@ -166,13 +196,13 @@ def __init__(self, dim_embd=512, n_head=8, n_layers=9,
super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)
if vqgan_path is not None:
- self.load_state_dict(
+ m, n = self.load_state_dict(
torch.load(vqgan_path, map_location='cpu')['params_ema'])
- if fix_modules is not None:
- for module in fix_modules:
- for param in getattr(self, module).parameters():
- param.requires_grad = False
+ # if fix_modules is not None:
+ # for module in fix_modules:
+ # for param in getattr(self, module).parameters():
+ # param.requires_grad = False
self.connect_list = connect_list
self.n_layers = n_layers
@@ -221,6 +251,8 @@ def _init_weights(self, module):
module.weight.data.fill_(1.0)
def forward(self, x, w=0, detach_16=True, code_only=False, adain=False):
+ b, f, _, _, _ = x.shape
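+        # merge the batch and frame dimensions so the 2D encoder processes every frame as an independent sample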
+ x = rearrange(x, "b f c h w -> (b f) c h w")
# ################### Encoder #####################
enc_feat_dict = {}
out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
@@ -238,7 +270,7 @@ def forward(self, x, w=0, detach_16=True, code_only=False, adain=False):
query_emb = feat_emb
# Transformer encoder
for layer in self.ft_layers:
- query_emb = layer(query_emb, query_pos=pos_emb)
+ query_emb = layer(query_emb, query_pos=pos_emb, video_length=f, batch_size=b)
# output logits
logits = self.idx_pred_layer(query_emb) # (hw)bn
@@ -248,17 +280,10 @@ def forward(self, x, w=0, detach_16=True, code_only=False, adain=False):
# logits doesn't need softmax before cross_entropy loss
return logits, lq_feat
- # ################# Quantization ###################
- # if self.training:
- # quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight])
- # # b(hw)c -> bc(hw) -> bchw
- # quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape)
- # ------------
soft_one_hot = F.softmax(logits, dim=2)
_, top_idx = torch.topk(soft_one_hot, 1, dim=2)
quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0],16,16,256])
- # preserve gradients
- # quant_feat = lq_feat + (quant_feat - lq_feat).detach()
+
if detach_16:
quant_feat = quant_feat.detach() # for training stage III
@@ -269,12 +294,69 @@ def forward(self, x, w=0, detach_16=True, code_only=False, adain=False):
x = quant_feat
fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
+
for i, block in enumerate(self.generator.blocks):
x = block(x)
if i in fuse_list: # fuse after i-th block
f_size = str(x.shape[-1])
if w>0:
x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
+
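+        # split the merged batch back into (batch, frames) now that decoding is done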
+ x = rearrange(x, "(b f) c h w -> b f c h w", f=f)
out = x
# logits doesn't need softmax before cross_entropy loss
- return out, logits, lq_feat
\ No newline at end of file
+ return out, logits, lq_feat
+
+
+ def inference(self, x, w=0, detach_16=True, adain=False):
+ with torch.no_grad():
+ b, f, _, _, _ = x.shape
+ x = rearrange(x, "b f c h w -> (b f) c h w")
+ # ################### Encoder #####################
+ enc_feat_dict = {}
+ out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
+ for i, block in enumerate(self.encoder.blocks):
+ x = block(x)
+ if i in out_list:
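+                    # keep encoder skip features on CPU so long clips do not exhaust GPU memory; they are moved back to the GPU during fusion below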
+ enc_feat_dict[str(x.shape[-1])] = x.detach().cpu().clone()
+
+ lq_feat = x.detach()
+ pos_emb = self.position_emb.unsqueeze(1).repeat(1,x.shape[0],1)
+ # BCHW -> BC(HW) -> (HW)BC
+ feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2,0,1))
+ query_emb = feat_emb
+ # Transformer encoder
+ for layer in self.ft_layers:
+ query_emb = layer(query_emb, query_pos=pos_emb, video_length=f, batch_size=b)
+
+ # output logits
+ logits = self.idx_pred_layer(query_emb) # (hw)bn
+ logits = logits.permute(1,0,2) # (hw)bn -> b(hw)n
+
+
+ soft_one_hot = F.softmax(logits, dim=2)
+ _, top_idx = torch.topk(soft_one_hot, 1, dim=2)
+ quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0],16,16,256])
+
+
+ if detach_16:
+ quant_feat = quant_feat.detach() # for training stage III
+ if adain:
+ quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)
+
+ # ################## Generator ####################
+ x = quant_feat
+ fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
+
+
+ for i, block in enumerate(self.generator.blocks):
+ x = block(x)
+ if i in fuse_list: # fuse after i-th block
+ f_size = str(x.shape[-1])
+ if w>0:
+ x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].to(x.device), x, w)
+
+ x = rearrange(x, "(b f) c h w -> b f c h w", f=f)
+ # logits doesn't need softmax before cross_entropy loss
+ return x, top_idx
+
\ No newline at end of file
diff --git a/basicsr/archs/rrdbnet_arch.py b/basicsr/archs/rrdbnet_arch.py
old mode 100644
new mode 100755
diff --git a/basicsr/archs/vgg_arch.py b/basicsr/archs/vgg_arch.py
old mode 100644
new mode 100755
diff --git a/basicsr/archs/vqgan_arch.py b/basicsr/archs/vqgan_arch.py
old mode 100644
new mode 100755
index 5ac69263..3a65de10
--- a/basicsr/archs/vqgan_arch.py
+++ b/basicsr/archs/vqgan_arch.py
@@ -11,6 +11,7 @@
from basicsr.utils import get_root_logger
from basicsr.utils.registry import ARCH_REGISTRY
+
def normalize(in_channels):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
diff --git a/basicsr/data/__init__.py b/basicsr/data/__init__.py
old mode 100644
new mode 100755
diff --git a/basicsr/data/data_sampler.py b/basicsr/data/data_sampler.py
old mode 100644
new mode 100755
diff --git a/basicsr/data/data_util.py b/basicsr/data/data_util.py
old mode 100644
new mode 100755
diff --git a/basicsr/data/ffhq_blind_joint_dataset.py b/basicsr/data/ffhq_blind_joint_dataset.py
deleted file mode 100755
index 0dc845f7..00000000
--- a/basicsr/data/ffhq_blind_joint_dataset.py
+++ /dev/null
@@ -1,324 +0,0 @@
-import cv2
-import math
-import random
-import numpy as np
-import os.path as osp
-from scipy.io import loadmat
-import torch
-import torch.utils.data as data
-from torchvision.transforms.functional import (adjust_brightness, adjust_contrast,
- adjust_hue, adjust_saturation, normalize)
-from basicsr.data import gaussian_kernels as gaussian_kernels
-from basicsr.data.transforms import augment
-from basicsr.data.data_util import paths_from_folder
-from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
-from basicsr.utils.registry import DATASET_REGISTRY
-
-@DATASET_REGISTRY.register()
-class FFHQBlindJointDataset(data.Dataset):
-
- def __init__(self, opt):
- super(FFHQBlindJointDataset, self).__init__()
- logger = get_root_logger()
- self.opt = opt
- # file client (io backend)
- self.file_client = None
- self.io_backend_opt = opt['io_backend']
-
- self.gt_folder = opt['dataroot_gt']
- self.gt_size = opt.get('gt_size', 512)
- self.in_size = opt.get('in_size', 512)
- assert self.gt_size >= self.in_size, 'Wrong setting.'
-
- self.mean = opt.get('mean', [0.5, 0.5, 0.5])
- self.std = opt.get('std', [0.5, 0.5, 0.5])
-
- self.component_path = opt.get('component_path', None)
- self.latent_gt_path = opt.get('latent_gt_path', None)
-
- if self.component_path is not None:
- self.crop_components = True
- self.components_dict = torch.load(self.component_path)
- self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1.4)
- self.nose_enlarge_ratio = opt.get('nose_enlarge_ratio', 1.1)
- self.mouth_enlarge_ratio = opt.get('mouth_enlarge_ratio', 1.3)
- else:
- self.crop_components = False
-
- if self.latent_gt_path is not None:
- self.load_latent_gt = True
- self.latent_gt_dict = torch.load(self.latent_gt_path)
- else:
- self.load_latent_gt = False
-
- if self.io_backend_opt['type'] == 'lmdb':
- self.io_backend_opt['db_paths'] = self.gt_folder
- if not self.gt_folder.endswith('.lmdb'):
- raise ValueError("'dataroot_gt' should end with '.lmdb', "f'but received {self.gt_folder}')
- with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
- self.paths = [line.split('.')[0] for line in fin]
- else:
- self.paths = paths_from_folder(self.gt_folder)
-
- # perform corrupt
- self.use_corrupt = opt.get('use_corrupt', True)
- self.use_motion_kernel = False
- # self.use_motion_kernel = opt.get('use_motion_kernel', True)
-
- if self.use_motion_kernel:
- self.motion_kernel_prob = opt.get('motion_kernel_prob', 0.001)
- motion_kernel_path = opt.get('motion_kernel_path', 'basicsr/data/motion-blur-kernels-32.pth')
- self.motion_kernels = torch.load(motion_kernel_path)
-
- if self.use_corrupt:
- # degradation configurations
- self.blur_kernel_size = self.opt['blur_kernel_size']
- self.kernel_list = self.opt['kernel_list']
- self.kernel_prob = self.opt['kernel_prob']
- # Small degradation
- self.blur_sigma = self.opt['blur_sigma']
- self.downsample_range = self.opt['downsample_range']
- self.noise_range = self.opt['noise_range']
- self.jpeg_range = self.opt['jpeg_range']
- # Large degradation
- self.blur_sigma_large = self.opt['blur_sigma_large']
- self.downsample_range_large = self.opt['downsample_range_large']
- self.noise_range_large = self.opt['noise_range_large']
- self.jpeg_range_large = self.opt['jpeg_range_large']
-
- # print
- logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, sigma: [{", ".join(map(str, self.blur_sigma))}]')
- logger.info(f'Downsample: downsample_range [{", ".join(map(str, self.downsample_range))}]')
- logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]')
- logger.info(f'JPEG compression: [{", ".join(map(str, self.jpeg_range))}]')
-
- # color jitter
- self.color_jitter_prob = opt.get('color_jitter_prob', None)
- self.color_jitter_pt_prob = opt.get('color_jitter_pt_prob', None)
- self.color_jitter_shift = opt.get('color_jitter_shift', 20)
- if self.color_jitter_prob is not None:
- logger.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, shift: {self.color_jitter_shift}')
-
- # to gray
- self.gray_prob = opt.get('gray_prob', 0.0)
- if self.gray_prob is not None:
- logger.info(f'Use random gray. Prob: {self.gray_prob}')
- self.color_jitter_shift /= 255.
-
- @staticmethod
- def color_jitter(img, shift):
- """jitter color: randomly jitter the RGB values, in numpy formats"""
- jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32)
- img = img + jitter_val
- img = np.clip(img, 0, 1)
- return img
-
- @staticmethod
- def color_jitter_pt(img, brightness, contrast, saturation, hue):
- """jitter color: randomly jitter the brightness, contrast, saturation, and hue, in torch Tensor formats"""
- fn_idx = torch.randperm(4)
- for fn_id in fn_idx:
- if fn_id == 0 and brightness is not None:
- brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
- img = adjust_brightness(img, brightness_factor)
-
- if fn_id == 1 and contrast is not None:
- contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item()
- img = adjust_contrast(img, contrast_factor)
-
- if fn_id == 2 and saturation is not None:
- saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item()
- img = adjust_saturation(img, saturation_factor)
-
- if fn_id == 3 and hue is not None:
- hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item()
- img = adjust_hue(img, hue_factor)
- return img
-
-
- def get_component_locations(self, name, status):
- components_bbox = self.components_dict[name]
- if status[0]: # hflip
- # exchange right and left eye
- tmp = components_bbox['left_eye']
- components_bbox['left_eye'] = components_bbox['right_eye']
- components_bbox['right_eye'] = tmp
- # modify the width coordinate
- components_bbox['left_eye'][0] = self.gt_size - components_bbox['left_eye'][0]
- components_bbox['right_eye'][0] = self.gt_size - components_bbox['right_eye'][0]
- components_bbox['nose'][0] = self.gt_size - components_bbox['nose'][0]
- components_bbox['mouth'][0] = self.gt_size - components_bbox['mouth'][0]
-
- locations_gt = {}
- locations_in = {}
- for part in ['left_eye', 'right_eye', 'nose', 'mouth']:
- mean = components_bbox[part][0:2]
- half_len = components_bbox[part][2]
- if 'eye' in part:
- half_len *= self.eye_enlarge_ratio
- elif part == 'nose':
- half_len *= self.nose_enlarge_ratio
- elif part == 'mouth':
- half_len *= self.mouth_enlarge_ratio
- loc = np.hstack((mean - half_len + 1, mean + half_len))
- loc = torch.from_numpy(loc).float()
- locations_gt[part] = loc
- loc_in = loc/(self.gt_size//self.in_size)
- locations_in[part] = loc_in
- return locations_gt, locations_in
-
-
- def __getitem__(self, index):
- if self.file_client is None:
- self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
-
- # load gt image
- gt_path = self.paths[index]
- name = osp.basename(gt_path)[:-4]
- img_bytes = self.file_client.get(gt_path)
- img_gt = imfrombytes(img_bytes, float32=True)
-
- # random horizontal flip
- img_gt, status = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False, return_status=True)
-
- if self.load_latent_gt:
- if status[0]:
- latent_gt = self.latent_gt_dict['hflip'][name]
- else:
- latent_gt = self.latent_gt_dict['orig'][name]
-
- if self.crop_components:
- locations_gt, locations_in = self.get_component_locations(name, status)
-
- # generate in image
- img_in = img_gt
- if self.use_corrupt:
- # motion blur
- if self.use_motion_kernel and random.random() < self.motion_kernel_prob:
- m_i = random.randint(0,31)
- k = self.motion_kernels[f'{m_i:02d}']
- img_in = cv2.filter2D(img_in,-1,k)
-
- # gaussian blur
- kernel = gaussian_kernels.random_mixed_kernels(
- self.kernel_list,
- self.kernel_prob,
- self.blur_kernel_size,
- self.blur_sigma,
- self.blur_sigma,
- [-math.pi, math.pi],
- noise_range=None)
- img_in = cv2.filter2D(img_in, -1, kernel)
-
- # downsample
- scale = np.random.uniform(self.downsample_range[0], self.downsample_range[1])
- img_in = cv2.resize(img_in, (int(self.gt_size // scale), int(self.gt_size // scale)), interpolation=cv2.INTER_LINEAR)
-
- # noise
- if self.noise_range is not None:
- noise_sigma = np.random.uniform(self.noise_range[0] / 255., self.noise_range[1] / 255.)
- noise = np.float32(np.random.randn(*(img_in.shape))) * noise_sigma
- img_in = img_in + noise
- img_in = np.clip(img_in, 0, 1)
-
- # jpeg
- if self.jpeg_range is not None:
- jpeg_p = np.random.uniform(self.jpeg_range[0], self.jpeg_range[1])
- encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), int(jpeg_p)]
- _, encimg = cv2.imencode('.jpg', img_in * 255., encode_param)
- img_in = np.float32(cv2.imdecode(encimg, 1)) / 255.
-
- # resize to in_size
- img_in = cv2.resize(img_in, (self.in_size, self.in_size), interpolation=cv2.INTER_LINEAR)
-
-
- # generate in_large with large degradation
- img_in_large = img_gt
-
- if self.use_corrupt:
- # motion blur
- if self.use_motion_kernel and random.random() < self.motion_kernel_prob:
- m_i = random.randint(0,31)
- k = self.motion_kernels[f'{m_i:02d}']
- img_in_large = cv2.filter2D(img_in_large,-1,k)
-
- # gaussian blur
- kernel = gaussian_kernels.random_mixed_kernels(
- self.kernel_list,
- self.kernel_prob,
- self.blur_kernel_size,
- self.blur_sigma_large,
- self.blur_sigma_large,
- [-math.pi, math.pi],
- noise_range=None)
- img_in_large = cv2.filter2D(img_in_large, -1, kernel)
-
- # downsample
- scale = np.random.uniform(self.downsample_range_large[0], self.downsample_range_large[1])
- img_in_large = cv2.resize(img_in_large, (int(self.gt_size // scale), int(self.gt_size // scale)), interpolation=cv2.INTER_LINEAR)
-
- # noise
- if self.noise_range_large is not None:
- noise_sigma = np.random.uniform(self.noise_range_large[0] / 255., self.noise_range_large[1] / 255.)
- noise = np.float32(np.random.randn(*(img_in_large.shape))) * noise_sigma
- img_in_large = img_in_large + noise
- img_in_large = np.clip(img_in_large, 0, 1)
-
- # jpeg
- if self.jpeg_range_large is not None:
- jpeg_p = np.random.uniform(self.jpeg_range_large[0], self.jpeg_range_large[1])
- encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), int(jpeg_p)]
- _, encimg = cv2.imencode('.jpg', img_in_large * 255., encode_param)
- img_in_large = np.float32(cv2.imdecode(encimg, 1)) / 255.
-
- # resize to in_size
- img_in_large = cv2.resize(img_in_large, (self.in_size, self.in_size), interpolation=cv2.INTER_LINEAR)
-
-
- # random color jitter (only for lq)
- if self.color_jitter_prob is not None and (np.random.uniform() < self.color_jitter_prob):
- img_in = self.color_jitter(img_in, self.color_jitter_shift)
- img_in_large = self.color_jitter(img_in_large, self.color_jitter_shift)
- # random to gray (only for lq)
- if self.gray_prob and np.random.uniform() < self.gray_prob:
- img_in = cv2.cvtColor(img_in, cv2.COLOR_BGR2GRAY)
- img_in = np.tile(img_in[:, :, None], [1, 1, 3])
- img_in_large = cv2.cvtColor(img_in_large, cv2.COLOR_BGR2GRAY)
- img_in_large = np.tile(img_in_large[:, :, None], [1, 1, 3])
-
- # BGR to RGB, HWC to CHW, numpy to tensor
- img_in, img_in_large, img_gt = img2tensor([img_in, img_in_large, img_gt], bgr2rgb=True, float32=True)
-
- # random color jitter (pytorch version) (only for lq)
- if self.color_jitter_pt_prob is not None and (np.random.uniform() < self.color_jitter_pt_prob):
- brightness = self.opt.get('brightness', (0.5, 1.5))
- contrast = self.opt.get('contrast', (0.5, 1.5))
- saturation = self.opt.get('saturation', (0, 1.5))
- hue = self.opt.get('hue', (-0.1, 0.1))
- img_in = self.color_jitter_pt(img_in, brightness, contrast, saturation, hue)
- img_in_large = self.color_jitter_pt(img_in_large, brightness, contrast, saturation, hue)
-
- # round and clip
- img_in = np.clip((img_in * 255.0).round(), 0, 255) / 255.
- img_in_large = np.clip((img_in_large * 255.0).round(), 0, 255) / 255.
-
- # Set vgg range_norm=True if use the normalization here
- # normalize
- normalize(img_in, self.mean, self.std, inplace=True)
- normalize(img_in_large, self.mean, self.std, inplace=True)
- normalize(img_gt, self.mean, self.std, inplace=True)
-
- return_dict = {'in': img_in, 'in_large_de': img_in_large, 'gt': img_gt, 'gt_path': gt_path}
-
- if self.crop_components:
- return_dict['locations_in'] = locations_in
- return_dict['locations_gt'] = locations_gt
-
- if self.load_latent_gt:
- return_dict['latent_gt'] = latent_gt
-
- return return_dict
-
-
- def __len__(self):
- return len(self.paths)
diff --git a/basicsr/data/gaussian_kernels.py b/basicsr/data/gaussian_kernels.py
index 0ce57f0a..a7c05a33 100755
--- a/basicsr/data/gaussian_kernels.py
+++ b/basicsr/data/gaussian_kernels.py
@@ -632,16 +632,16 @@ def show_one_kernel():
plt.show()
-def show_plateau_kernel():
- import matplotlib.pyplot as plt
- kernel_size = 21
+# def show_plateau_kernel():
+# import matplotlib.pyplot as plt
+# kernel_size = 21
- kernel = plateau_type1(kernel_size, 2, 4, -math.pi / 8, 2, grid=None)
- kernel_norm = bivariate_isotropic_Gaussian(kernel_size, 5)
- kernel_gau = bivariate_generalized_Gaussian(
- kernel_size, 2, 4, -math.pi / 8, 2, grid=None)
- delta_h, delta_w = mass_center_shift(kernel_size, kernel)
- print(delta_h, delta_w)
+# kernel = plateau_type1(kernel_size, 2, 4, -math.pi / 8, 2, grid=None)
+# kernel_norm = bivariate_isotropic_Gaussian(kernel_size, 5)
+# kernel_gau = bivariate_generalized_Gaussian(
+# kernel_size, 2, 4, -math.pi / 8, 2, grid=None)
+# delta_h, delta_w = mass_center_shift(kernel_size, kernel)
+# print(delta_h, delta_w)
# kernel_slice = kernel[10, :]
# kernel_gau_slice = kernel_gau[10, :]
@@ -662,29 +662,29 @@ def show_plateau_kernel():
# ax.plot(t, y2)
# plt.show()
- fig, axs = plt.subplots(nrows=2, ncols=2)
- # axs.set_axis_off()
- ax = axs[0][0]
- im = ax.matshow(kernel, cmap='jet', origin='upper')
- fig.colorbar(im, ax=ax)
-
- # image
- ax = axs[0][1]
- kernel_vis = kernel - np.min(kernel)
- kernel_vis = kernel_vis / np.max(kernel_vis) * 255.
- ax.imshow(kernel_vis, interpolation='nearest')
-
- _, xx, yy = mesh_grid(kernel_size)
- # contour
- ax = axs[1][0]
- CS = ax.contour(xx, yy, kernel, origin='upper')
- ax.clabel(CS, inline=1, fontsize=3)
+ # fig, axs = plt.subplots(nrows=2, ncols=2)
+ # # axs.set_axis_off()
+ # ax = axs[0][0]
+ # im = ax.matshow(kernel, cmap='jet', origin='upper')
+ # fig.colorbar(im, ax=ax)
+
+ # # image
+ # ax = axs[0][1]
+ # kernel_vis = kernel - np.min(kernel)
+ # kernel_vis = kernel_vis / np.max(kernel_vis) * 255.
+ # ax.imshow(kernel_vis, interpolation='nearest')
+
+ # _, xx, yy = mesh_grid(kernel_size)
+ # # contour
+ # ax = axs[1][0]
+ # CS = ax.contour(xx, yy, kernel, origin='upper')
+ # ax.clabel(CS, inline=1, fontsize=3)
+
+ # # contourf
+ # ax = axs[1][1]
+ # kernel = kernel / np.max(kernel)
+ # p = ax.contourf(
+ # xx, yy, kernel, origin='upper', levels=np.linspace(-0.05, 1.05, 10))
+ # fig.colorbar(p)
- # contourf
- ax = axs[1][1]
- kernel = kernel / np.max(kernel)
- p = ax.contourf(
- xx, yy, kernel, origin='upper', levels=np.linspace(-0.05, 1.05, 10))
- fig.colorbar(p)
-
- plt.show()
+ # plt.show()
diff --git a/basicsr/data/paired_image_dataset.py b/basicsr/data/paired_image_dataset.py
deleted file mode 100644
index c6a6c07b..00000000
--- a/basicsr/data/paired_image_dataset.py
+++ /dev/null
@@ -1,101 +0,0 @@
-from torch.utils import data as data
-from torchvision.transforms.functional import normalize
-
-from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb, paired_paths_from_meta_info_file
-from basicsr.data.transforms import augment, paired_random_crop
-from basicsr.utils import FileClient, imfrombytes, img2tensor
-from basicsr.utils.registry import DATASET_REGISTRY
-
-
-@DATASET_REGISTRY.register()
-class PairedImageDataset(data.Dataset):
- """Paired image dataset for image restoration.
-
- Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and
- GT image pairs.
-
- There are three modes:
- 1. 'lmdb': Use lmdb files.
- If opt['io_backend'] == lmdb.
- 2. 'meta_info_file': Use meta information file to generate paths.
- If opt['io_backend'] != lmdb and opt['meta_info_file'] is not None.
- 3. 'folder': Scan folders to generate paths.
- The rest.
-
- Args:
- opt (dict): Config for train datasets. It contains the following keys:
- dataroot_gt (str): Data root path for gt.
- dataroot_lq (str): Data root path for lq.
- meta_info_file (str): Path for meta information file.
- io_backend (dict): IO backend type and other kwarg.
- filename_tmpl (str): Template for each filename. Note that the
- template excludes the file extension. Default: '{}'.
- gt_size (int): Cropped patched size for gt patches.
- use_flip (bool): Use horizontal flips.
- use_rot (bool): Use rotation (use vertical flip and transposing h
- and w for implementation).
-
- scale (bool): Scale, which will be added automatically.
- phase (str): 'train' or 'val'.
- """
-
- def __init__(self, opt):
- super(PairedImageDataset, self).__init__()
- self.opt = opt
- # file client (io backend)
- self.file_client = None
- self.io_backend_opt = opt['io_backend']
- self.mean = opt['mean'] if 'mean' in opt else None
- self.std = opt['std'] if 'std' in opt else None
-
- self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
- if 'filename_tmpl' in opt:
- self.filename_tmpl = opt['filename_tmpl']
- else:
- self.filename_tmpl = '{}'
-
- if self.io_backend_opt['type'] == 'lmdb':
- self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
- self.io_backend_opt['client_keys'] = ['lq', 'gt']
- self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt'])
- elif 'meta_info_file' in self.opt and self.opt['meta_info_file'] is not None:
- self.paths = paired_paths_from_meta_info_file([self.lq_folder, self.gt_folder], ['lq', 'gt'],
- self.opt['meta_info_file'], self.filename_tmpl)
- else:
- self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl)
-
- def __getitem__(self, index):
- if self.file_client is None:
- self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
-
- scale = self.opt['scale']
-
- # Load gt and lq images. Dimension order: HWC; channel order: BGR;
- # image range: [0, 1], float32.
- gt_path = self.paths[index]['gt_path']
- img_bytes = self.file_client.get(gt_path, 'gt')
- img_gt = imfrombytes(img_bytes, float32=True)
- lq_path = self.paths[index]['lq_path']
- img_bytes = self.file_client.get(lq_path, 'lq')
- img_lq = imfrombytes(img_bytes, float32=True)
-
- # augmentation for training
- if self.opt['phase'] == 'train':
- gt_size = self.opt['gt_size']
- # random crop
- img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path)
- # flip, rotation
- img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_flip'], self.opt['use_rot'])
-
- # TODO: color space transform
- # BGR to RGB, HWC to CHW, numpy to tensor
- img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
- # normalize
- if self.mean is not None or self.std is not None:
- normalize(img_lq, self.mean, self.std, inplace=True)
- normalize(img_gt, self.mean, self.std, inplace=True)
-
- return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path}
-
- def __len__(self):
- return len(self.paths)
diff --git a/basicsr/data/prefetch_dataloader.py b/basicsr/data/prefetch_dataloader.py
old mode 100644
new mode 100755
diff --git a/basicsr/data/transforms.py b/basicsr/data/transforms.py
old mode 100644
new mode 100755
diff --git a/basicsr/data/ffhq_blind_dataset.py b/basicsr/data/vfhq_dataset.py
similarity index 57%
rename from basicsr/data/ffhq_blind_dataset.py
rename to basicsr/data/vfhq_dataset.py
index 9f900606..eadbdeb5 100755
--- a/basicsr/data/ffhq_blind_dataset.py
+++ b/basicsr/data/vfhq_dataset.py
@@ -15,16 +15,20 @@
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY
+from pathlib import Path
+import torchvision.transforms as transforms
+
@DATASET_REGISTRY.register()
-class FFHQBlindDataset(data.Dataset):
+class VFHQBlindDataset(data.Dataset):
def __init__(self, opt):
- super(FFHQBlindDataset, self).__init__()
+ super(VFHQBlindDataset, self).__init__()
logger = get_root_logger()
self.opt = opt
# file client (io backend)
self.file_client = None
self.io_backend_opt = opt['io_backend']
+ self.video_length = opt['video_length']
self.gt_folder = opt['dataroot_gt']
self.gt_size = opt.get('gt_size', 512)
@@ -59,26 +63,25 @@ def __init__(self, opt):
with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
self.paths = [line.split('.')[0] for line in fin]
else:
- self.paths = paths_from_folder(self.gt_folder)
+ gt_folder = Path(self.gt_folder)
+ sub_dir = gt_folder.iterdir()
+ self.paths = []
+ for p in sub_dir:
+ if p.is_dir():
+ l = list(p.glob('*.png'))
+ if len(l) > self.video_length:
+ self.paths.append(str(p))
+
# inpainting mask
self.gen_inpaint_mask = opt.get('gen_inpaint_mask', False)
if self.gen_inpaint_mask:
logger.info(f'generate mask ...')
- # self.mask_max_angle = opt.get('mask_max_angle', 10)
- # self.mask_max_len = opt.get('mask_max_len', 150)
- # self.mask_max_width = opt.get('mask_max_width', 50)
- # self.mask_draw_times = opt.get('mask_draw_times', 10)
- # # print
- # logger.info(f'mask_max_angle: {self.mask_max_angle}')
- # logger.info(f'mask_max_len: {self.mask_max_len}')
- # logger.info(f'mask_max_width: {self.mask_max_width}')
- # logger.info(f'mask_draw_times: {self.mask_draw_times}')
+
# perform corrupt
self.use_corrupt = opt.get('use_corrupt', True)
self.use_motion_kernel = False
- # self.use_motion_kernel = opt.get('use_motion_kernel', True)
if self.use_motion_kernel:
self.motion_kernel_prob = opt.get('motion_kernel_prob', 0.001)
@@ -113,6 +116,7 @@ def __init__(self, opt):
logger.info(f'Use random gray. Prob: {self.gray_prob}')
self.color_jitter_shift /= 255.
+
@staticmethod
def color_jitter(img, shift):
"""jitter color: randomly jitter the RGB values, in numpy formats"""
@@ -182,118 +186,113 @@ def __getitem__(self, index):
# load gt image
gt_path = self.paths[index]
- name = osp.basename(gt_path)[:-4]
- img_bytes = self.file_client.get(gt_path)
- img_gt = imfrombytes(img_bytes, float32=True)
-
- # random horizontal flip
- img_gt, status = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False, return_status=True)
-
- if self.load_latent_gt:
- if status[0]:
- latent_gt = self.latent_gt_dict['hflip'][name]
- else:
- latent_gt = self.latent_gt_dict['orig'][name]
-
- if self.crop_components:
- locations_gt, locations_in = self.get_component_locations(name, status)
-
- # generate in image
- img_in = img_gt
- if self.use_corrupt and not self.gen_inpaint_mask:
- # motion blur
- if self.use_motion_kernel and random.random() < self.motion_kernel_prob:
- m_i = random.randint(0,31)
- k = self.motion_kernels[f'{m_i:02d}']
- img_in = cv2.filter2D(img_in,-1,k)
-
- # gaussian blur
- kernel = gaussian_kernels.random_mixed_kernels(
- self.kernel_list,
- self.kernel_prob,
- self.blur_kernel_size,
- self.blur_sigma,
- self.blur_sigma,
- [-math.pi, math.pi],
- noise_range=None)
- img_in = cv2.filter2D(img_in, -1, kernel)
-
- # downsample
- scale = np.random.uniform(self.downsample_range[0], self.downsample_range[1])
- img_in = cv2.resize(img_in, (int(self.gt_size // scale), int(self.gt_size // scale)), interpolation=cv2.INTER_LINEAR)
-
- # noise
- if self.noise_range is not None:
- noise_sigma = np.random.uniform(self.noise_range[0] / 255., self.noise_range[1] / 255.)
- noise = np.float32(np.random.randn(*(img_in.shape))) * noise_sigma
- img_in = img_in + noise
- img_in = np.clip(img_in, 0, 1)
-
- # jpeg
- if self.jpeg_range is not None:
- jpeg_p = np.random.uniform(self.jpeg_range[0], self.jpeg_range[1])
- encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), int(jpeg_p)]
- _, encimg = cv2.imencode('.jpg', img_in * 255., encode_param)
- img_in = np.float32(cv2.imdecode(encimg, 1)) / 255.
-
- # resize to in_size
- img_in = cv2.resize(img_in, (self.in_size, self.in_size), interpolation=cv2.INTER_LINEAR)
-
- # if self.gen_inpaint_mask:
- # inpaint_mask = random_ff_mask(shape=(self.gt_size,self.gt_size),
- # max_angle = self.mask_max_angle, max_len = self.mask_max_len,
- # max_width = self.mask_max_width, times = self.mask_draw_times)
- # img_in = img_in * (1 - inpaint_mask.reshape(self.gt_size,self.gt_size,1)) + \
- # 1.0 * inpaint_mask.reshape(self.gt_size,self.gt_size,1)
-
- # inpaint_mask = torch.from_numpy(inpaint_mask).view(1,self.gt_size,self.gt_size)
-
- if self.gen_inpaint_mask:
- img_in = (img_in*255).astype('uint8')
- img_in = brush_stroke_mask(Image.fromarray(img_in))
- img_in = np.array(img_in) / 255.
-
- # random color jitter (only for lq)
- if self.color_jitter_prob is not None and (np.random.uniform() < self.color_jitter_prob):
- img_in = self.color_jitter(img_in, self.color_jitter_shift)
- # random to gray (only for lq)
- if self.gray_prob and np.random.uniform() < self.gray_prob:
- img_in = cv2.cvtColor(img_in, cv2.COLOR_BGR2GRAY)
- img_in = np.tile(img_in[:, :, None], [1, 1, 3])
-
- # BGR to RGB, HWC to CHW, numpy to tensor
- img_in, img_gt = img2tensor([img_in, img_gt], bgr2rgb=True, float32=True)
- # random color jitter (pytorch version) (only for lq)
- if self.color_jitter_pt_prob is not None and (np.random.uniform() < self.color_jitter_pt_prob):
- brightness = self.opt.get('brightness', (0.5, 1.5))
- contrast = self.opt.get('contrast', (0.5, 1.5))
- saturation = self.opt.get('saturation', (0, 1.5))
- hue = self.opt.get('hue', (-0.1, 0.1))
- img_in = self.color_jitter_pt(img_in, brightness, contrast, saturation, hue)
+        image_list = sorted(Path(gt_path).glob('*.png'))
+        length = len(image_list)
- # round and clip
- img_in = np.clip((img_in * 255.0).round(), 0, 255) / 255.
+        start_idx = random.randint(0, length - self.video_length - 1)
+ in_list = []
+ gt_list = []
- # Set vgg range_norm=True if use the normalization here
- # normalize
- normalize(img_in, self.mean, self.std, inplace=True)
- normalize(img_gt, self.mean, self.std, inplace=True)
+ for i in range(start_idx, start_idx+self.video_length):
+ gt_path_idx = image_list[i]
- return_dict = {'in': img_in, 'gt': img_gt, 'gt_path': gt_path}
-
- if self.crop_components:
- return_dict['locations_in'] = locations_in
- return_dict['locations_gt'] = locations_gt
-
- if self.load_latent_gt:
- return_dict['latent_gt'] = latent_gt
-
- # if self.gen_inpaint_mask:
- # return_dict['inpaint_mask'] = inpaint_mask
+ img_bytes = self.file_client.get(gt_path_idx)
+ img_gt = imfrombytes(img_bytes, float32=True)
+
+ # random horizontal flip
+ img_gt, status = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False, return_status=True)
+
+
+ # generate in image
+ img_in = img_gt
+ if self.use_corrupt and not self.gen_inpaint_mask:
+ # motion blur
+ if self.use_motion_kernel and random.random() < self.motion_kernel_prob:
+ m_i = random.randint(0,31)
+ k = self.motion_kernels[f'{m_i:02d}']
+ img_in = cv2.filter2D(img_in,-1,k)
+
+ # gaussian blur
+ kernel = gaussian_kernels.random_mixed_kernels(
+ self.kernel_list,
+ self.kernel_prob,
+ self.blur_kernel_size,
+ self.blur_sigma,
+ self.blur_sigma,
+ [-math.pi, math.pi],
+ noise_range=None)
+ img_in = cv2.filter2D(img_in, -1, kernel)
+
+ # downsample
+ scale = np.random.uniform(self.downsample_range[0], self.downsample_range[1])
+ img_in = cv2.resize(img_in, (int(self.gt_size // scale), int(self.gt_size // scale)), interpolation=cv2.INTER_LINEAR)
+
+ # noise
+ if self.noise_range is not None:
+ noise_sigma = np.random.uniform(self.noise_range[0] / 255., self.noise_range[1] / 255.)
+ noise = np.float32(np.random.randn(*(img_in.shape))) * noise_sigma
+ img_in = img_in + noise
+ img_in = np.clip(img_in, 0, 1)
+
+ # jpeg
+ if self.jpeg_range is not None:
+ jpeg_p = np.random.uniform(self.jpeg_range[0], self.jpeg_range[1])
+ encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), int(jpeg_p)]
+ _, encimg = cv2.imencode('.jpg', img_in * 255., encode_param)
+ img_in = np.float32(cv2.imdecode(encimg, 1)) / 255.
+
+ # resize to in_size
+ img_in = cv2.resize(img_in, (self.in_size, self.in_size), interpolation=cv2.INTER_LINEAR)
+
+
+ if self.gen_inpaint_mask:
+ img_in = (img_in*255).astype('uint8')
+ img_in = brush_stroke_mask(Image.fromarray(img_in))
+ img_in = np.array(img_in) / 255.
+
+ # random color jitter (only for lq)
+ if self.color_jitter_prob is not None and (np.random.uniform() < self.color_jitter_prob):
+ img_in = self.color_jitter(img_in, self.color_jitter_shift)
+ # random to gray (only for lq)
+ if self.gray_prob and np.random.uniform() < self.gray_prob:
+ img_in = cv2.cvtColor(img_in, cv2.COLOR_BGR2GRAY)
+ img_in = np.tile(img_in[:, :, None], [1, 1, 3])
+
+ # BGR to RGB, HWC to CHW, numpy to tensor
+ img_in, img_gt = img2tensor([img_in, img_gt], bgr2rgb=True, float32=True)
+
+ # random color jitter (pytorch version) (only for lq)
+ if self.color_jitter_pt_prob is not None and (np.random.uniform() < self.color_jitter_pt_prob):
+ brightness = self.opt.get('brightness', (0.5, 1.5))
+ contrast = self.opt.get('contrast', (0.5, 1.5))
+ saturation = self.opt.get('saturation', (0, 1.5))
+ hue = self.opt.get('hue', (-0.1, 0.1))
+ img_in = self.color_jitter_pt(img_in, brightness, contrast, saturation, hue)
+
+ # round and clip
+ img_in = np.clip((img_in * 255.0).round(), 0, 255) / 255.
+
+ # Set vgg range_norm=True if use the normalization here
+ # normalize
+ normalize(img_in, self.mean, self.std, inplace=True)
+ normalize(img_gt, self.mean, self.std, inplace=True)
+
+ img_in = img_in.unsqueeze(0)
+ img_gt = img_gt.unsqueeze(0)
+
+ in_list.append(img_in)
+ gt_list.append(img_gt)
+
+ in_video = torch.cat(in_list, dim=0)
+ gt_video = torch.cat(gt_list, dim=0)
+
+ return_dict = {'in': in_video, 'gt': gt_video}
return return_dict
def __len__(self):
- return len(self.paths)
\ No newline at end of file
+ return len(self.paths)
+
+
\ No newline at end of file
diff --git a/basicsr/losses/__init__.py b/basicsr/losses/__init__.py
old mode 100644
new mode 100755
diff --git a/basicsr/losses/loss_util.py b/basicsr/losses/loss_util.py
old mode 100644
new mode 100755
diff --git a/basicsr/losses/losses.py b/basicsr/losses/losses.py
old mode 100644
new mode 100755
diff --git a/basicsr/metrics/__init__.py b/basicsr/metrics/__init__.py
old mode 100644
new mode 100755
diff --git a/basicsr/metrics/metric_util.py b/basicsr/metrics/metric_util.py
old mode 100644
new mode 100755
diff --git a/basicsr/metrics/psnr_ssim.py b/basicsr/metrics/psnr_ssim.py
old mode 100644
new mode 100755
diff --git a/basicsr/models/__init__.py b/basicsr/models/__init__.py
old mode 100644
new mode 100755
diff --git a/basicsr/models/base_model.py b/basicsr/models/base_model.py
old mode 100644
new mode 100755
diff --git a/basicsr/models/codeformer_joint_model.py b/basicsr/models/codeformer_joint_model.py
deleted file mode 100644
index 9e4b9d12..00000000
--- a/basicsr/models/codeformer_joint_model.py
+++ /dev/null
@@ -1,350 +0,0 @@
-import torch
-from collections import OrderedDict
-from os import path as osp
-from tqdm import tqdm
-
-
-from basicsr.archs import build_network
-from basicsr.losses import build_loss
-from basicsr.metrics import calculate_metric
-from basicsr.utils import get_root_logger, imwrite, tensor2img
-from basicsr.utils.registry import MODEL_REGISTRY
-import torch.nn.functional as F
-from .sr_model import SRModel
-
-
-@MODEL_REGISTRY.register()
-class CodeFormerJointModel(SRModel):
- def feed_data(self, data):
- self.gt = data['gt'].to(self.device)
- self.input = data['in'].to(self.device)
- self.input_large_de = data['in_large_de'].to(self.device)
- self.b = self.gt.shape[0]
-
- if 'latent_gt' in data:
- self.idx_gt = data['latent_gt'].to(self.device)
- self.idx_gt = self.idx_gt.view(self.b, -1)
- else:
- self.idx_gt = None
-
- def init_training_settings(self):
- logger = get_root_logger()
- train_opt = self.opt['train']
-
- self.ema_decay = train_opt.get('ema_decay', 0)
- if self.ema_decay > 0:
- logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')
- # define network net_g with Exponential Moving Average (EMA)
- # net_g_ema is used only for testing on one GPU and saving
- # There is no need to wrap with DistributedDataParallel
- self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
- # load pretrained model
- load_path = self.opt['path'].get('pretrain_network_g', None)
- if load_path is not None:
- self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
- else:
- self.model_ema(0) # copy net_g weight
- self.net_g_ema.eval()
-
- if self.opt['datasets']['train'].get('latent_gt_path', None) is not None:
- self.generate_idx_gt = False
- elif self.opt.get('network_vqgan', None) is not None:
- self.hq_vqgan_fix = build_network(self.opt['network_vqgan']).to(self.device)
- self.hq_vqgan_fix.eval()
- self.generate_idx_gt = True
- for param in self.hq_vqgan_fix.parameters():
- param.requires_grad = False
- else:
- raise NotImplementedError(f'Should have a network_vqgan config or pre-calculated latent code.')
-
- logger.info(f'Need to generate latent GT code: {self.generate_idx_gt}')
-
- self.hq_feat_loss = train_opt.get('use_hq_feat_loss', True)
- self.feat_loss_weight = train_opt.get('feat_loss_weight', 1.0)
- self.cross_entropy_loss = train_opt.get('cross_entropy_loss', True)
- self.entropy_loss_weight = train_opt.get('entropy_loss_weight', 0.5)
- self.scale_adaptive_gan_weight = train_opt.get('scale_adaptive_gan_weight', 0.8)
-
- # define network net_d
- self.net_d = build_network(self.opt['network_d'])
- self.net_d = self.model_to_device(self.net_d)
- self.print_network(self.net_d)
-
- # load pretrained models
- load_path = self.opt['path'].get('pretrain_network_d', None)
- if load_path is not None:
- self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True))
-
- self.net_g.train()
- self.net_d.train()
-
- # define losses
- if train_opt.get('pixel_opt'):
- self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
- else:
- self.cri_pix = None
-
- if train_opt.get('perceptual_opt'):
- self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
- else:
- self.cri_perceptual = None
-
- if train_opt.get('gan_opt'):
- self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)
-
-
- self.fix_generator = train_opt.get('fix_generator', True)
- logger.info(f'fix_generator: {self.fix_generator}')
-
- self.net_g_start_iter = train_opt.get('net_g_start_iter', 0)
- self.net_d_iters = train_opt.get('net_d_iters', 1)
- self.net_d_start_iter = train_opt.get('net_d_start_iter', 0)
-
- # set up optimizers and schedulers
- self.setup_optimizers()
- self.setup_schedulers()
-
- def calculate_adaptive_weight(self, recon_loss, g_loss, last_layer, disc_weight_max):
- recon_grads = torch.autograd.grad(recon_loss, last_layer, retain_graph=True)[0]
- g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
-
- d_weight = torch.norm(recon_grads) / (torch.norm(g_grads) + 1e-4)
- d_weight = torch.clamp(d_weight, 0.0, disc_weight_max).detach()
- return d_weight
-
- def setup_optimizers(self):
- train_opt = self.opt['train']
- # optimizer g
- optim_params_g = []
- for k, v in self.net_g.named_parameters():
- if v.requires_grad:
- optim_params_g.append(v)
- else:
- logger = get_root_logger()
- logger.warning(f'Params {k} will not be optimized.')
- optim_type = train_opt['optim_g'].pop('type')
- self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, **train_opt['optim_g'])
- self.optimizers.append(self.optimizer_g)
- # optimizer d
- optim_type = train_opt['optim_d'].pop('type')
- self.optimizer_d = self.get_optimizer(optim_type, self.net_d.parameters(), **train_opt['optim_d'])
- self.optimizers.append(self.optimizer_d)
-
- def gray_resize_for_identity(self, out, size=128):
- out_gray = (0.2989 * out[:, 0, :, :] + 0.5870 * out[:, 1, :, :] + 0.1140 * out[:, 2, :, :])
- out_gray = out_gray.unsqueeze(1)
- out_gray = F.interpolate(out_gray, (size, size), mode='bilinear', align_corners=False)
- return out_gray
-
- def optimize_parameters(self, current_iter):
- logger = get_root_logger()
- # optimize net_g
- for p in self.net_d.parameters():
- p.requires_grad = False
-
- self.optimizer_g.zero_grad()
-
- if self.generate_idx_gt:
- x = self.hq_vqgan_fix.encoder(self.gt)
- output, _, quant_stats = self.hq_vqgan_fix.quantize(x)
- min_encoding_indices = quant_stats['min_encoding_indices']
- self.idx_gt = min_encoding_indices.view(self.b, -1)
-
- if current_iter <= 40000: # small degradation
- small_per_n = 1
- w = 1
- elif current_iter <= 80000: # small degradation
- small_per_n = 1
- w = 1.3
- elif current_iter <= 120000: # large degradation
- small_per_n = 120000
- w = 0
- else: # mixed degradation
- small_per_n = 15
- w = 1.3
-
- if current_iter % small_per_n == 0:
- self.output, logits, lq_feat = self.net_g(self.input, w=w, detach_16=True)
- large_de = False
- else:
- logits, lq_feat = self.net_g(self.input_large_de, code_only=True)
- large_de = True
-
- if self.hq_feat_loss:
- # quant_feats
- quant_feat_gt = self.net_g.module.quantize.get_codebook_feat(self.idx_gt, shape=[self.b,16,16,256])
-
- l_g_total = 0
- loss_dict = OrderedDict()
- if current_iter % self.net_d_iters == 0 and current_iter > self.net_g_start_iter:
- # hq_feat_loss
- if not 'transformer' in self.opt['network_g']['fix_modules']:
- if self.hq_feat_loss: # codebook loss
- l_feat_encoder = torch.mean((quant_feat_gt.detach()-lq_feat)**2) * self.feat_loss_weight
- l_g_total += l_feat_encoder
- loss_dict['l_feat_encoder'] = l_feat_encoder
-
- # cross_entropy_loss
- if self.cross_entropy_loss:
- # b(hw)n -> bn(hw)
- cross_entropy_loss = F.cross_entropy(logits.permute(0, 2, 1), self.idx_gt) * self.entropy_loss_weight
- l_g_total += cross_entropy_loss
- loss_dict['cross_entropy_loss'] = cross_entropy_loss
-
- # pixel loss
- if not large_de: # when large degradation don't need image-level loss
- if self.cri_pix:
- l_g_pix = self.cri_pix(self.output, self.gt)
- l_g_total += l_g_pix
- loss_dict['l_g_pix'] = l_g_pix
-
- # perceptual loss
- if self.cri_perceptual:
- l_g_percep = self.cri_perceptual(self.output, self.gt)
- l_g_total += l_g_percep
- loss_dict['l_g_percep'] = l_g_percep
-
- # gan loss
- if current_iter > self.net_d_start_iter:
- fake_g_pred = self.net_d(self.output)
- l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
- recon_loss = l_g_pix + l_g_percep
- if not self.fix_generator:
- last_layer = self.net_g.module.generator.blocks[-1].weight
- d_weight = self.calculate_adaptive_weight(recon_loss, l_g_gan, last_layer, disc_weight_max=1.0)
- else:
- largest_fuse_size = self.opt['network_g']['connect_list'][-1]
- last_layer = self.net_g.module.fuse_convs_dict[largest_fuse_size].shift[-1].weight
- d_weight = self.calculate_adaptive_weight(recon_loss, l_g_gan, last_layer, disc_weight_max=1.0)
-
- d_weight *= self.scale_adaptive_gan_weight # 0.8
- loss_dict['d_weight'] = d_weight
- l_g_total += d_weight * l_g_gan
- loss_dict['l_g_gan'] = d_weight * l_g_gan
-
- l_g_total.backward()
- self.optimizer_g.step()
-
- if self.ema_decay > 0:
- self.model_ema(decay=self.ema_decay)
-
- # optimize net_d
- if not large_de:
- if current_iter > self.net_d_start_iter:
- for p in self.net_d.parameters():
- p.requires_grad = True
-
- self.optimizer_d.zero_grad()
- # real
- real_d_pred = self.net_d(self.gt)
- l_d_real = self.cri_gan(real_d_pred, True, is_disc=True)
- loss_dict['l_d_real'] = l_d_real
- loss_dict['out_d_real'] = torch.mean(real_d_pred.detach())
- l_d_real.backward()
- # fake
- fake_d_pred = self.net_d(self.output.detach())
- l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)
- loss_dict['l_d_fake'] = l_d_fake
- loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach())
- l_d_fake.backward()
-
- self.optimizer_d.step()
-
- self.log_dict = self.reduce_loss_dict(loss_dict)
-
-
- def test(self):
- with torch.no_grad():
- if hasattr(self, 'net_g_ema'):
- self.net_g_ema.eval()
- self.output, _, _ = self.net_g_ema(self.input, w=1)
- else:
- logger = get_root_logger()
- logger.warning('Do not have self.net_g_ema, use self.net_g.')
- self.net_g.eval()
- self.output, _, _ = self.net_g(self.input, w=1)
- self.net_g.train()
-
-
- def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
- if self.opt['rank'] == 0:
- self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
-
-
- def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
- dataset_name = dataloader.dataset.opt['name']
- with_metrics = self.opt['val'].get('metrics') is not None
- if with_metrics:
- self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
- pbar = tqdm(total=len(dataloader), unit='image')
-
- for idx, val_data in enumerate(dataloader):
- img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
- self.feed_data(val_data)
- self.test()
-
- visuals = self.get_current_visuals()
- sr_img = tensor2img([visuals['result']])
- if 'gt' in visuals:
- gt_img = tensor2img([visuals['gt']])
- del self.gt
-
- # tentative for out of GPU memory
- del self.lq
- del self.output
- torch.cuda.empty_cache()
-
- if save_img:
- if self.opt['is_train']:
- save_img_path = osp.join(self.opt['path']['visualization'], img_name,
- f'{img_name}_{current_iter}.png')
- else:
- if self.opt['val']['suffix']:
- save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
- f'{img_name}_{self.opt["val"]["suffix"]}.png')
- else:
- save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
- f'{img_name}_{self.opt["name"]}.png')
- imwrite(sr_img, save_img_path)
-
- if with_metrics:
- # calculate metrics
- for name, opt_ in self.opt['val']['metrics'].items():
- metric_data = dict(img1=sr_img, img2=gt_img)
- self.metric_results[name] += calculate_metric(metric_data, opt_)
- pbar.update(1)
- pbar.set_description(f'Test {img_name}')
- pbar.close()
-
- if with_metrics:
- for metric in self.metric_results.keys():
- self.metric_results[metric] /= (idx + 1)
-
- self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
-
-
- def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
- log_str = f'Validation {dataset_name}\n'
- for metric, value in self.metric_results.items():
- log_str += f'\t # {metric}: {value:.4f}\n'
- logger = get_root_logger()
- logger.info(log_str)
- if tb_logger:
- for metric, value in self.metric_results.items():
- tb_logger.add_scalar(f'metrics/{metric}', value, current_iter)
-
-
- def get_current_visuals(self):
- out_dict = OrderedDict()
- out_dict['gt'] = self.gt.detach().cpu()
- out_dict['result'] = self.output.detach().cpu()
- return out_dict
-
-
- def save(self, epoch, current_iter):
- if self.ema_decay > 0:
- self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
- else:
- self.save_network(self.net_g, 'net_g', current_iter)
- self.save_network(self.net_d, 'net_d', current_iter)
- self.save_training_state(epoch, current_iter)
diff --git a/basicsr/models/codeformer_model.py b/basicsr/models/codeformer_model.py
deleted file mode 100644
index 61829f12..00000000
--- a/basicsr/models/codeformer_model.py
+++ /dev/null
@@ -1,332 +0,0 @@
-import torch
-from collections import OrderedDict
-from os import path as osp
-from tqdm import tqdm
-
-from basicsr.archs import build_network
-from basicsr.losses import build_loss
-from basicsr.metrics import calculate_metric
-from basicsr.utils import get_root_logger, imwrite, tensor2img
-from basicsr.utils.registry import MODEL_REGISTRY
-import torch.nn.functional as F
-from .sr_model import SRModel
-
-
-@MODEL_REGISTRY.register()
-class CodeFormerModel(SRModel):
- def feed_data(self, data):
- self.gt = data['gt'].to(self.device)
- self.input = data['in'].to(self.device)
- self.b = self.gt.shape[0]
-
- if 'latent_gt' in data:
- self.idx_gt = data['latent_gt'].to(self.device)
- self.idx_gt = self.idx_gt.view(self.b, -1)
- else:
- self.idx_gt = None
-
- def init_training_settings(self):
- logger = get_root_logger()
- train_opt = self.opt['train']
-
- self.ema_decay = train_opt.get('ema_decay', 0)
- if self.ema_decay > 0:
- logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')
- # define network net_g with Exponential Moving Average (EMA)
- # net_g_ema is used only for testing on one GPU and saving
- # There is no need to wrap with DistributedDataParallel
- self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
- # load pretrained model
- load_path = self.opt['path'].get('pretrain_network_g', None)
- if load_path is not None:
- self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
- else:
- self.model_ema(0) # copy net_g weight
- self.net_g_ema.eval()
-
- if self.opt.get('network_vqgan', None) is not None and self.opt['datasets'].get('latent_gt_path') is None:
- self.hq_vqgan_fix = build_network(self.opt['network_vqgan']).to(self.device)
- self.hq_vqgan_fix.eval()
- self.generate_idx_gt = True
- for param in self.hq_vqgan_fix.parameters():
- param.requires_grad = False
- else:
- self.generate_idx_gt = False
-
- self.hq_feat_loss = train_opt.get('use_hq_feat_loss', True)
- self.feat_loss_weight = train_opt.get('feat_loss_weight', 1.0)
- self.cross_entropy_loss = train_opt.get('cross_entropy_loss', True)
- self.entropy_loss_weight = train_opt.get('entropy_loss_weight', 0.5)
- self.fidelity_weight = train_opt.get('fidelity_weight', 1.0)
- self.scale_adaptive_gan_weight = train_opt.get('scale_adaptive_gan_weight', 0.8)
-
-
- self.net_g.train()
- # define network net_d
- if self.fidelity_weight > 0:
- self.net_d = build_network(self.opt['network_d'])
- self.net_d = self.model_to_device(self.net_d)
- self.print_network(self.net_d)
-
- # load pretrained models
- load_path = self.opt['path'].get('pretrain_network_d', None)
- if load_path is not None:
- self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True))
-
- self.net_d.train()
-
- # define losses
- if train_opt.get('pixel_opt'):
- self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
- else:
- self.cri_pix = None
-
- if train_opt.get('perceptual_opt'):
- self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
- else:
- self.cri_perceptual = None
-
- if train_opt.get('gan_opt'):
- self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)
-
-
- self.fix_generator = train_opt.get('fix_generator', True)
- logger.info(f'fix_generator: {self.fix_generator}')
-
- self.net_g_start_iter = train_opt.get('net_g_start_iter', 0)
- self.net_d_iters = train_opt.get('net_d_iters', 1)
- self.net_d_start_iter = train_opt.get('net_d_start_iter', 0)
-
- # set up optimizers and schedulers
- self.setup_optimizers()
- self.setup_schedulers()
-
- def calculate_adaptive_weight(self, recon_loss, g_loss, last_layer, disc_weight_max):
- recon_grads = torch.autograd.grad(recon_loss, last_layer, retain_graph=True)[0]
- g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
-
- d_weight = torch.norm(recon_grads) / (torch.norm(g_grads) + 1e-4)
- d_weight = torch.clamp(d_weight, 0.0, disc_weight_max).detach()
- return d_weight
-
- def setup_optimizers(self):
- train_opt = self.opt['train']
- # optimizer g
- optim_params_g = []
- for k, v in self.net_g.named_parameters():
- if v.requires_grad:
- optim_params_g.append(v)
- else:
- logger = get_root_logger()
- logger.warning(f'Params {k} will not be optimized.')
- optim_type = train_opt['optim_g'].pop('type')
- self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, **train_opt['optim_g'])
- self.optimizers.append(self.optimizer_g)
- # optimizer d
- if self.fidelity_weight > 0:
- optim_type = train_opt['optim_d'].pop('type')
- self.optimizer_d = self.get_optimizer(optim_type, self.net_d.parameters(), **train_opt['optim_d'])
- self.optimizers.append(self.optimizer_d)
-
- def gray_resize_for_identity(self, out, size=128):
- out_gray = (0.2989 * out[:, 0, :, :] + 0.5870 * out[:, 1, :, :] + 0.1140 * out[:, 2, :, :])
- out_gray = out_gray.unsqueeze(1)
- out_gray = F.interpolate(out_gray, (size, size), mode='bilinear', align_corners=False)
- return out_gray
-
- def optimize_parameters(self, current_iter):
- logger = get_root_logger()
- # optimize net_g
- for p in self.net_d.parameters():
- p.requires_grad = False
-
- self.optimizer_g.zero_grad()
-
- if self.generate_idx_gt:
- x = self.hq_vqgan_fix.encoder(self.gt)
- output, _, quant_stats = self.hq_vqgan_fix.quantize(x)
- min_encoding_indices = quant_stats['min_encoding_indices']
- self.idx_gt = min_encoding_indices.view(self.b, -1)
-
- if self.fidelity_weight > 0:
- self.output, logits, lq_feat = self.net_g(self.input, w=self.fidelity_weight, detach_16=True)
- else:
- logits, lq_feat = self.net_g(self.input, w=0, code_only=True)
-
- if self.hq_feat_loss:
- # quant_feats
- quant_feat_gt = self.net_g.module.quantize.get_codebook_feat(self.idx_gt, shape=[self.b,16,16,256])
-
- l_g_total = 0
- loss_dict = OrderedDict()
- if current_iter % self.net_d_iters == 0 and current_iter > self.net_g_start_iter:
- # hq_feat_loss
- if self.hq_feat_loss: # codebook loss
- l_feat_encoder = torch.mean((quant_feat_gt.detach()-lq_feat)**2) * self.feat_loss_weight
- l_g_total += l_feat_encoder
- loss_dict['l_feat_encoder'] = l_feat_encoder
-
- # cross_entropy_loss
- if self.cross_entropy_loss:
- # b(hw)n -> bn(hw)
- cross_entropy_loss = F.cross_entropy(logits.permute(0, 2, 1), self.idx_gt) * self.entropy_loss_weight
- l_g_total += cross_entropy_loss
- loss_dict['cross_entropy_loss'] = cross_entropy_loss
-
- if self.fidelity_weight > 0: # when fidelity_weight == 0 don't need image-level loss
- # pixel loss
- if self.cri_pix:
- l_g_pix = self.cri_pix(self.output, self.gt)
- l_g_total += l_g_pix
- loss_dict['l_g_pix'] = l_g_pix
-
- # perceptual loss
- if self.cri_perceptual:
- l_g_percep = self.cri_perceptual(self.output, self.gt)
- l_g_total += l_g_percep
- loss_dict['l_g_percep'] = l_g_percep
-
- # gan loss
- if current_iter > self.net_d_start_iter:
- fake_g_pred = self.net_d(self.output)
- l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
- recon_loss = l_g_pix + l_g_percep
- if not self.fix_generator:
- last_layer = self.net_g.module.generator.blocks[-1].weight
- d_weight = self.calculate_adaptive_weight(recon_loss, l_g_gan, last_layer, disc_weight_max=1.0)
- else:
- largest_fuse_size = self.opt['network_g']['connect_list'][-1]
- last_layer = self.net_g.module.fuse_convs_dict[largest_fuse_size].shift[-1].weight
- d_weight = self.calculate_adaptive_weight(recon_loss, l_g_gan, last_layer, disc_weight_max=1.0)
-
- d_weight *= self.scale_adaptive_gan_weight # 0.8
- loss_dict['d_weight'] = d_weight
- l_g_total += d_weight * l_g_gan
- loss_dict['l_g_gan'] = d_weight * l_g_gan
-
- l_g_total.backward()
- self.optimizer_g.step()
-
- if self.ema_decay > 0:
- self.model_ema(decay=self.ema_decay)
-
- # optimize net_d
- if current_iter > self.net_d_start_iter and self.fidelity_weight > 0:
- for p in self.net_d.parameters():
- p.requires_grad = True
-
- self.optimizer_d.zero_grad()
- # real
- real_d_pred = self.net_d(self.gt)
- l_d_real = self.cri_gan(real_d_pred, True, is_disc=True)
- loss_dict['l_d_real'] = l_d_real
- loss_dict['out_d_real'] = torch.mean(real_d_pred.detach())
- l_d_real.backward()
- # fake
- fake_d_pred = self.net_d(self.output.detach())
- l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)
- loss_dict['l_d_fake'] = l_d_fake
- loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach())
- l_d_fake.backward()
-
- self.optimizer_d.step()
-
- self.log_dict = self.reduce_loss_dict(loss_dict)
-
-
- def test(self):
- with torch.no_grad():
- if hasattr(self, 'net_g_ema'):
- self.net_g_ema.eval()
- self.output, _, _ = self.net_g_ema(self.input, w=self.fidelity_weight)
- else:
- logger = get_root_logger()
- logger.warning('Do not have self.net_g_ema, use self.net_g.')
- self.net_g.eval()
- self.output, _, _ = self.net_g(self.input, w=self.fidelity_weight)
- self.net_g.train()
-
-
- def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
- if self.opt['rank'] == 0:
- self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
-
-
- def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
- dataset_name = dataloader.dataset.opt['name']
- with_metrics = self.opt['val'].get('metrics') is not None
- if with_metrics:
- self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
- pbar = tqdm(total=len(dataloader), unit='image')
-
- for idx, val_data in enumerate(dataloader):
- img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
- self.feed_data(val_data)
- self.test()
-
- visuals = self.get_current_visuals()
- sr_img = tensor2img([visuals['result']])
- if 'gt' in visuals:
- gt_img = tensor2img([visuals['gt']])
- del self.gt
-
- # tentative for out of GPU memory
- del self.lq
- del self.output
- torch.cuda.empty_cache()
-
- if save_img:
- if self.opt['is_train']:
- save_img_path = osp.join(self.opt['path']['visualization'], img_name,
- f'{img_name}_{current_iter}.png')
- else:
- if self.opt['val']['suffix']:
- save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
- f'{img_name}_{self.opt["val"]["suffix"]}.png')
- else:
- save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
- f'{img_name}_{self.opt["name"]}.png')
- imwrite(sr_img, save_img_path)
-
- if with_metrics:
- # calculate metrics
- for name, opt_ in self.opt['val']['metrics'].items():
- metric_data = dict(img1=sr_img, img2=gt_img)
- self.metric_results[name] += calculate_metric(metric_data, opt_)
- pbar.update(1)
- pbar.set_description(f'Test {img_name}')
- pbar.close()
-
- if with_metrics:
- for metric in self.metric_results.keys():
- self.metric_results[metric] /= (idx + 1)
-
- self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
-
-
- def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
- log_str = f'Validation {dataset_name}\n'
- for metric, value in self.metric_results.items():
- log_str += f'\t # {metric}: {value:.4f}\n'
- logger = get_root_logger()
- logger.info(log_str)
- if tb_logger:
- for metric, value in self.metric_results.items():
- tb_logger.add_scalar(f'metrics/{metric}', value, current_iter)
-
-
- def get_current_visuals(self):
- out_dict = OrderedDict()
- out_dict['gt'] = self.gt.detach().cpu()
- out_dict['result'] = self.output.detach().cpu()
- return out_dict
-
-
- def save(self, epoch, current_iter):
- if self.ema_decay > 0:
- self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
- else:
- self.save_network(self.net_g, 'net_g', current_iter)
- if self.fidelity_weight > 0:
- self.save_network(self.net_d, 'net_d', current_iter)
- self.save_training_state(epoch, current_iter)
diff --git a/basicsr/models/codeformer_idx_model.py b/basicsr/models/codeformer_temporal_model.py
old mode 100644
new mode 100755
similarity index 90%
rename from basicsr/models/codeformer_idx_model.py
rename to basicsr/models/codeformer_temporal_model.py
index 005957dd..44ff2d96
--- a/basicsr/models/codeformer_idx_model.py
+++ b/basicsr/models/codeformer_temporal_model.py
@@ -2,6 +2,7 @@
from collections import OrderedDict
from os import path as osp
from tqdm import tqdm
+from einops import rearrange
from basicsr.archs import build_network
from basicsr.metrics import calculate_metric
@@ -10,13 +11,16 @@
import torch.nn.functional as F
from .sr_model import SRModel
+# from icecream import ic
@MODEL_REGISTRY.register()
-class CodeFormerIdxModel(SRModel):
+class CodeFormerTempModel(SRModel):
def feed_data(self, data):
self.gt = data['gt'].to(self.device)
self.input = data['in'].to(self.device)
self.b = self.gt.shape[0]
+ self.f = self.gt.shape[1]
+ self.bf = self.b * self.f
if 'latent_gt' in data:
self.idx_gt = data['latent_gt'].to(self.device)
@@ -62,6 +66,13 @@ def init_training_settings(self):
self.entropy_loss_weight = train_opt.get('entropy_loss_weight', 0.5)
self.net_g.train()
+ self.net_g.requires_grad_(False)
+
+ trainable_module = train_opt['trainable_para']
+ for name, module in self.net_g.named_modules():
+ if trainable_module in name:
+ for params in module.parameters():
+ params.requires_grad_(True)
# set up optimizers and schedulers
self.setup_optimizers()
@@ -72,16 +83,18 @@ def setup_optimizers(self):
train_opt = self.opt['train']
# optimizer g
optim_params_g = []
+ optim_name = []
for k, v in self.net_g.named_parameters():
if v.requires_grad:
optim_params_g.append(v)
- else:
- logger = get_root_logger()
- logger.warning(f'Params {k} will not be optimized.')
+ optim_name.append(k)
+
optim_type = train_opt['optim_g'].pop('type')
self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, **train_opt['optim_g'])
self.optimizers.append(self.optimizer_g)
+ # ic(optim_name)
+
def optimize_parameters(self, current_iter):
logger = get_root_logger()
@@ -89,16 +102,21 @@ def optimize_parameters(self, current_iter):
self.optimizer_g.zero_grad()
if self.generate_idx_gt:
- x = self.hq_vqgan_fix.encoder(self.gt)
+ x = rearrange(self.gt, "b f c h w -> (b f) c h w")
+ x = self.hq_vqgan_fix.encoder(x)
_, _, quant_stats = self.hq_vqgan_fix.quantize(x)
min_encoding_indices = quant_stats['min_encoding_indices']
- self.idx_gt = min_encoding_indices.view(self.b, -1)
+ # ic(min_encoding_indices.shape)
+ self.idx_gt = min_encoding_indices.view(self.bf, -1)
+ # ic(self.idx_gt.shape)
if self.hq_feat_loss:
# quant_feats
- quant_feat_gt = self.net_g.module.quantize.get_codebook_feat(self.idx_gt, shape=[self.b,16,16,256])
+ quant_feat_gt = self.net_g.module.quantize.get_codebook_feat(self.idx_gt, shape=[self.bf,16,16,256])
logits, lq_feat = self.net_g(self.input, w=0, code_only=True)
+ # ic(logits.shape)
+ # ic(lq_feat.shape)
l_g_total = 0
loss_dict = OrderedDict()
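The rename above also changes how training is set up: the whole generator is frozen and only modules whose name contains the `trainable_para` substring (here `temp`) receive gradients, while ground-truth clips are flattened from `(b, f, c, h, w)` to `(b*f, c, h, w)` before the frozen VQGAN encoder. A toy sketch of those two steps (module names are made up for illustration; only the pattern mirrors the diff):

```python
import torch
import torch.nn as nn
from einops import rearrange

# Toy stand-in for net_g: a frozen backbone plus one temporal module.
net_g = nn.Sequential()
net_g.add_module('backbone', nn.Conv2d(3, 8, 3, padding=1))
net_g.add_module('temp_attn', nn.Conv2d(8, 8, 3, padding=1))

trainable_para = 'temp'            # from train.trainable_para in the config
net_g.requires_grad_(False)        # freeze everything first
for name, module in net_g.named_modules():
    if trainable_para in name:     # re-enable only the temporal modules
        for params in module.parameters():
            params.requires_grad_(True)

optim_params_g = [v for _, v in net_g.named_parameters() if v.requires_grad]
print(f'{len(optim_params_g)} tensors will be optimized')  # temp_attn weight + bias

# Flatten the frame dimension so a 2D encoder can process every frame.
gt = torch.randn(2, 16, 3, 64, 64)                 # (b, f, c, h, w)
x = rearrange(gt, 'b f c h w -> (b f) c h w')      # (b*f, c, h, w)
print(tuple(x.shape))                              # (32, 3, 64, 64)
```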
diff --git a/basicsr/models/lr_scheduler.py b/basicsr/models/lr_scheduler.py
old mode 100644
new mode 100755
diff --git a/basicsr/models/sr_model.py b/basicsr/models/sr_model.py
old mode 100644
new mode 100755
diff --git a/basicsr/models/vqgan_model.py b/basicsr/models/vqgan_model.py
old mode 100644
new mode 100755
diff --git a/basicsr/ops/__init__.py b/basicsr/ops/__init__.py
old mode 100644
new mode 100755
diff --git a/basicsr/ops/dcn/__init__.py b/basicsr/ops/dcn/__init__.py
old mode 100644
new mode 100755
diff --git a/basicsr/ops/dcn/deform_conv.py b/basicsr/ops/dcn/deform_conv.py
old mode 100644
new mode 100755
diff --git a/basicsr/ops/dcn/src/deform_conv_cuda.cpp b/basicsr/ops/dcn/src/deform_conv_cuda.cpp
old mode 100644
new mode 100755
diff --git a/basicsr/ops/dcn/src/deform_conv_cuda_kernel.cu b/basicsr/ops/dcn/src/deform_conv_cuda_kernel.cu
old mode 100644
new mode 100755
diff --git a/basicsr/ops/dcn/src/deform_conv_ext.cpp b/basicsr/ops/dcn/src/deform_conv_ext.cpp
old mode 100644
new mode 100755
diff --git a/basicsr/ops/fused_act/__init__.py b/basicsr/ops/fused_act/__init__.py
old mode 100644
new mode 100755
diff --git a/basicsr/ops/fused_act/fused_act.py b/basicsr/ops/fused_act/fused_act.py
old mode 100644
new mode 100755
diff --git a/basicsr/ops/fused_act/src/fused_bias_act.cpp b/basicsr/ops/fused_act/src/fused_bias_act.cpp
old mode 100644
new mode 100755
diff --git a/basicsr/ops/fused_act/src/fused_bias_act_kernel.cu b/basicsr/ops/fused_act/src/fused_bias_act_kernel.cu
old mode 100644
new mode 100755
diff --git a/basicsr/ops/upfirdn2d/__init__.py b/basicsr/ops/upfirdn2d/__init__.py
old mode 100644
new mode 100755
diff --git a/basicsr/ops/upfirdn2d/src/upfirdn2d.cpp b/basicsr/ops/upfirdn2d/src/upfirdn2d.cpp
old mode 100644
new mode 100755
diff --git a/basicsr/ops/upfirdn2d/src/upfirdn2d_kernel.cu b/basicsr/ops/upfirdn2d/src/upfirdn2d_kernel.cu
old mode 100644
new mode 100755
diff --git a/basicsr/ops/upfirdn2d/upfirdn2d.py b/basicsr/ops/upfirdn2d/upfirdn2d.py
old mode 100644
new mode 100755
diff --git a/basicsr/setup.py b/basicsr/setup.py
old mode 100644
new mode 100755
diff --git a/basicsr/train.py b/basicsr/train.py
old mode 100644
new mode 100755
index a01c0dfc..84fe0a28
--- a/basicsr/train.py
+++ b/basicsr/train.py
@@ -25,7 +25,7 @@ def parse_options(root_path, is_train=True):
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, required=True, help='Path to option YAML file.')
parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none', help='job launcher')
- parser.add_argument('--local_rank', type=int, default=0)
+ parser.add_argument('--local-rank','--local_rank', type=int, default=0)
args = parser.parse_args()
opt = parse(args.opt, root_path, is_train=is_train)
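The `--local-rank`/`--local_rank` change above appears intended to accept both spellings of the launcher-injected argument (older `torch.distributed.launch` passes the underscore form, while newer PyTorch launchers pass the hyphenated form). A standalone sketch of the same argparse pattern:

```python
import argparse

parser = argparse.ArgumentParser()
# Both option strings map to the same destination, args.local_rank.
parser.add_argument('--local-rank', '--local_rank', type=int, default=0)

assert parser.parse_args(['--local-rank', '3']).local_rank == 3
assert parser.parse_args(['--local_rank', '5']).local_rank == 5
assert parser.parse_args([]).local_rank == 0
```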
diff --git a/basicsr/utils/__init__.py b/basicsr/utils/__init__.py
old mode 100644
new mode 100755
diff --git a/basicsr/utils/dist_util.py b/basicsr/utils/dist_util.py
old mode 100644
new mode 100755
diff --git a/basicsr/utils/download_util.py b/basicsr/utils/download_util.py
old mode 100644
new mode 100755
diff --git a/basicsr/utils/file_client.py b/basicsr/utils/file_client.py
old mode 100644
new mode 100755
diff --git a/basicsr/utils/img_util.py b/basicsr/utils/img_util.py
old mode 100644
new mode 100755
diff --git a/basicsr/utils/lmdb_util.py b/basicsr/utils/lmdb_util.py
old mode 100644
new mode 100755
diff --git a/basicsr/utils/logger.py b/basicsr/utils/logger.py
old mode 100644
new mode 100755
diff --git a/basicsr/utils/matlab_functions.py b/basicsr/utils/matlab_functions.py
old mode 100644
new mode 100755
diff --git a/basicsr/utils/misc.py b/basicsr/utils/misc.py
old mode 100644
new mode 100755
diff --git a/basicsr/utils/options.py b/basicsr/utils/options.py
old mode 100644
new mode 100755
diff --git a/basicsr/utils/realesrgan_utils.py b/basicsr/utils/realesrgan_utils.py
old mode 100644
new mode 100755
diff --git a/basicsr/utils/registry.py b/basicsr/utils/registry.py
old mode 100644
new mode 100755
diff --git a/basicsr/utils/video_util.py b/basicsr/utils/video_util.py
old mode 100644
new mode 100755
diff --git a/options/CodeFormer_stage2.yml b/configs/train/video_sr.yaml
old mode 100755
new mode 100644
similarity index 83%
rename from options/CodeFormer_stage2.yml
rename to configs/train/video_sr.yaml
index 4dfe9c9d..28ba65db
--- a/options/CodeFormer_stage2.yml
+++ b/configs/train/video_sr.yaml
@@ -1,15 +1,15 @@
# general settings
-name: CodeFormer_stage2
-model_type: CodeFormerIdxModel
+name: CodeFormer_temp
+model_type: CodeFormerTempModel
num_gpu: 8
manual_seed: 0
# dataset and data loader settings
datasets:
train:
- name: FFHQ
- type: FFHQBlindDataset
- dataroot_gt: datasets/ffhq/ffhq_512
+ name: VFHQ
+ type: VFHQBlindDataset
+ dataroot_gt: ./VFHQ/image
filename_tmpl: '{}'
io_backend:
type: disk
@@ -20,6 +20,7 @@ datasets:
std: [0.5, 0.5, 0.5]
use_hflip: true
use_corrupt: true
+ video_length: 16
# large degradation in stageII
blur_kernel_size: 41
@@ -33,12 +34,11 @@ datasets:
jpeg_range: [30, 80]
latent_gt_path: ~ # without pre-calculated latent code
- # latent_gt_path: './experiments/pretrained_models/VQGAN/latent_gt_code1024.pth'
# data loader
- num_worker_per_gpu: 2
+ num_worker_per_gpu: 8
batch_size_per_gpu: 4
- dataset_enlarge_ratio: 100
+ dataset_enlarge_ratio: 1
prefetch_mode: ~
# val:
@@ -61,7 +61,7 @@ network_g:
codebook_size: 1024
connect_list: ['32', '64', '128', '256']
fix_modules: ['quantize','generator']
- vqgan_path: './experiments/pretrained_models/vqgan/vqgan_code1024.pth' # pretrained VQGAN
+ vqgan_path: './pretrained_models/CodeFormer/vqgan_code1024.pth' # pretrained VQGAN
network_vqgan: # this config is needed if no pre-calculated latent
type: VQAutoEncoder
@@ -70,10 +70,11 @@ network_vqgan: # this config is needed if no pre-calculated latent
ch_mult: [1, 2, 2, 4, 4, 8]
quantizer: 'nearest'
codebook_size: 1024
+ model_path: './pretrained_models/CodeFormer/vqgan_code1024.pth'
# path
path:
- pretrain_network_g: ~
+ pretrain_network_g: './pretrained_models/CodeFormer/codeformer.pth'
param_key_g: params_ema
strict_load_g: false
pretrain_network_d: ~
@@ -88,6 +89,8 @@ train:
entropy_loss_weight: 0.5
fidelity_weight: 0
+ trainable_para: temp
+
optim_g:
type: Adam
lr: !!float 1e-4
@@ -119,7 +122,7 @@ train:
# validation settings
val:
- val_freq: !!float 5e10 # no validation
+ val_freq: 1000
save_img: true
metrics:
@@ -130,8 +133,8 @@ val:
# logging settings
logger:
- print_freq: 100
- save_checkpoint_freq: !!float 1e4
+ print_freq: 1
+ save_checkpoint_freq: 1000
use_tb_logger: true
wandb:
project: ~
diff --git a/docs/history_changelog.md b/docs/history_changelog.md
deleted file mode 100644
index 6c35e34e..00000000
--- a/docs/history_changelog.md
+++ /dev/null
@@ -1,15 +0,0 @@
-# History of Changelog
-
-- **2023.04.19**: :whale: Training codes and config files are public available now.
-- **2023.04.09**: Add features of inpainting and colorization for cropped face images.
-- **2023.02.10**: Include `dlib` as a new face detector option, it produces more accurate face identity.
-- **2022.10.05**: Support video input `--input_path [YOUR_VIDEO.mp4]`. Try it to enhance your videos! :clapper:
-- **2022.09.14**: Integrated to :hugs: [Hugging Face](https://huggingface.co/spaces). Try out online demo! [![Hugging Face](https://img.shields.io/badge/Demo-%F0%9F%A4%97%20Hugging%20Face-blue)](https://huggingface.co/spaces/sczhou/CodeFormer)
-- **2022.09.09**: Integrated to :rocket: [Replicate](https://replicate.com/explore). Try out online demo! [![Replicate](https://img.shields.io/badge/Demo-%F0%9F%9A%80%20Replicate-blue)](https://replicate.com/sczhou/codeformer)
-- **2022.09.04**: Add face upsampling `--face_upsample` for high-resolution AI-created face enhancement.
-- **2022.08.23**: Some modifications on face detection and fusion for better AI-created face enhancement.
-- **2022.08.07**: Integrate [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN) to support background image enhancement.
-- **2022.07.29**: Integrate new face detectors of `['RetinaFace'(default), 'YOLOv5']`.
-- **2022.07.17**: Add Colab demo of CodeFormer.
-- **2022.07.16**: Release inference code for face restoration. :blush:
-- **2022.06.21**: This repo is created.
\ No newline at end of file
diff --git a/docs/train.md b/docs/train.md
deleted file mode 100644
index 873ee7cc..00000000
--- a/docs/train.md
+++ /dev/null
@@ -1,37 +0,0 @@
-# :milky_way: Training Procedures
-[English](train.md) **|** [įŽäŊä¸æ](train_CN.md)
-## Preparing Dataset
-
-- Download training dataset: [FFHQ](https://github.com/NVlabs/ffhq-dataset)
-
----
-
-## Training
-```
-For PyTorch versions >= 1.10, please replace `python -m torch.distributed.launch` in the commands below with `torchrun`.
-```
-
-### Stage I - VQGAN
-- Training VQGAN:
- > python -m torch.distributed.launch --nproc_per_node=gpu_num --master_port=4321 basicsr/train.py -opt options/VQGAN_512_ds32_nearest_stage1.yml --launcher pytorch
-
-- After VQGAN training, you can pre-calculate code sequence for the training dataset to speed up the later training stages:
- > python scripts/generate_latent_gt.py
-
-- If you don't require training your own VQGAN, you can find pre-trained VQGAN (`vqgan_code1024.pth`) and the corresponding code sequence (`latent_gt_code1024.pth`) in the folder of Releases v0.1.0: https://github.com/sczhou/CodeFormer/releases/tag/v0.1.0
-
-### Stage II - CodeFormer (w=0)
-- Training Code Sequence Prediction Module:
- > python -m torch.distributed.launch --nproc_per_node=gpu_num --master_port=4322 basicsr/train.py -opt options/CodeFormer_stage2.yml --launcher pytorch
-
-- Pre-trained CodeFormer of stage II (`codeformer_stage2.pth`) can be found in the folder of Releases v0.1.0: https://github.com/sczhou/CodeFormer/releases/tag/v0.1.0
-
-### Stage III - CodeFormer (w=1)
-- Training Controllable Module:
- > python -m torch.distributed.launch --nproc_per_node=gpu_num --master_port=4323 basicsr/train.py -opt options/CodeFormer_stage3.yml --launcher pytorch
-
-- Pre-trained CodeFormer (`codeformer.pth`) can be found in the folder of Releases v0.1.0: https://github.com/sczhou/CodeFormer/releases/tag/v0.1.0
-
----
-
-:whale: The project was built using the framework [BasicSR](https://github.com/XPixelGroup/BasicSR). For detailed information on training, resuming, and other related topics, please refer to the documentation: https://github.com/XPixelGroup/BasicSR/blob/master/docs/TrainTest.md
diff --git a/docs/train_CN.md b/docs/train_CN.md
deleted file mode 100644
index c1ac2cc4..00000000
--- a/docs/train_CN.md
+++ /dev/null
@@ -1,37 +0,0 @@
-# :milky_way: Training Documentation
-[English](train.md) **|** [Simplified Chinese](train_CN.md)
-
-## Preparing Dataset
-- Download the training dataset: [FFHQ](https://github.com/NVlabs/ffhq-dataset)
-
----
-
-## Training
-```
-For PyTorch versions >= 1.10, please replace `python -m torch.distributed.launch` in the commands below with `torchrun`.
-```
-
-### Stage I - VQGAN
-- Train the VQGAN:
- > python -m torch.distributed.launch --nproc_per_node=gpu_num --master_port=4321 basicsr/train.py -opt options/VQGAN_512_ds32_nearest_stage1.yml --launcher pytorch
-
-- After training the VQGAN, you can pre-calculate the code sequence for the training dataset with the script below to speed up the later training stages:
- > python scripts/generate_latent_gt.py
-
-- If you don't need to train your own VQGAN, the pre-trained VQGAN (`vqgan_code1024.pth`) and the corresponding code sequence (`latent_gt_code1024.pth`) can be found in Releases v0.1.0: https://github.com/sczhou/CodeFormer/releases/tag/v0.1.0
-
-### Stage II - CodeFormer (w=0)
-- Train the code sequence prediction module:
- > python -m torch.distributed.launch --nproc_per_node=gpu_num --master_port=4322 basicsr/train.py -opt options/CodeFormer_stage2.yml --launcher pytorch
-
-- The pre-trained stage II CodeFormer (`codeformer_stage2.pth`) can be found in Releases v0.1.0: https://github.com/sczhou/CodeFormer/releases/tag/v0.1.0
-
-### Stage III - CodeFormer (w=1)
-- Train the controllable module:
- > python -m torch.distributed.launch --nproc_per_node=gpu_num --master_port=4323 basicsr/train.py -opt options/CodeFormer_stage3.yml --launcher pytorch
-
-- The pre-trained CodeFormer (`codeformer.pth`) can be found in Releases v0.1.0: https://github.com/sczhou/CodeFormer/releases/tag/v0.1.0
-
----
-
-:whale: This project is built on the [BasicSR](https://github.com/XPixelGroup/BasicSR) framework. For detailed information on training, resuming, and other related topics, please refer to the documentation: https://github.com/XPixelGroup/BasicSR/blob/master/docs/TrainTest_CN.md
\ No newline at end of file
diff --git a/inference_colorization.py b/inference_colorization.py
deleted file mode 100644
index 0f1b763c..00000000
--- a/inference_colorization.py
+++ /dev/null
@@ -1,86 +0,0 @@
-import os
-import cv2
-import argparse
-import glob
-import torch
-from torchvision.transforms.functional import normalize
-from basicsr.utils import imwrite, img2tensor, tensor2img
-from basicsr.utils.download_util import load_file_from_url
-from basicsr.utils.misc import get_device
-from basicsr.utils.registry import ARCH_REGISTRY
-
-pretrain_model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer_colorization.pth'
-
-if __name__ == '__main__':
- # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
- device = get_device()
- parser = argparse.ArgumentParser()
-
- parser.add_argument('-i', '--input_path', type=str, default='./inputs/gray_faces',
- help='Input image or folder. Default: inputs/gray_faces')
- parser.add_argument('-o', '--output_path', type=str, default=None,
- help='Output folder. Default: results/')
- parser.add_argument('--suffix', type=str, default=None,
- help='Suffix of the restored faces. Default: None')
- args = parser.parse_args()
-
- # ------------------------ input & output ------------------------
- print('[NOTE] The input face images should be aligned and cropped to a resolution of 512x512.')
- if args.input_path.endswith(('jpg', 'jpeg', 'png', 'JPG', 'JPEG', 'PNG')): # input single img path
- input_img_list = [args.input_path]
- result_root = f'results/test_colorization_img'
- else: # input img folder
- if args.input_path.endswith('/'): # solve when path ends with /
- args.input_path = args.input_path[:-1]
- # scan all the jpg and png images
- input_img_list = sorted(glob.glob(os.path.join(args.input_path, '*.[jpJP][pnPN]*[gG]')))
- result_root = f'results/{os.path.basename(args.input_path)}'
-
- if not args.output_path is None: # set output path
- result_root = args.output_path
-
- test_img_num = len(input_img_list)
-
- # ------------------ set up CodeFormer restorer -------------------
- net = ARCH_REGISTRY.get('CodeFormer')(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9,
- connect_list=['32', '64', '128']).to(device)
-
- # ckpt_path = 'weights/CodeFormer/codeformer.pth'
- ckpt_path = load_file_from_url(url=pretrain_model_url,
- model_dir='weights/CodeFormer', progress=True, file_name=None)
- checkpoint = torch.load(ckpt_path)['params_ema']
- net.load_state_dict(checkpoint)
- net.eval()
-
- # -------------------- start processing ---------------------
- for i, img_path in enumerate(input_img_list):
- img_name = os.path.basename(img_path)
- basename, ext = os.path.splitext(img_name)
- print(f'[{i+1}/{test_img_num}] Processing: {img_name}')
- input_face = cv2.imread(img_path)
- assert input_face.shape[:2] == (512, 512), 'Input resolution must be 512x512 for colorization.'
- # input_face = cv2.resize(input_face, (512, 512), interpolation=cv2.INTER_LINEAR)
- input_face = img2tensor(input_face / 255., bgr2rgb=True, float32=True)
- normalize(input_face, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
- input_face = input_face.unsqueeze(0).to(device)
- try:
- with torch.no_grad():
- # w is fixed to 0 since we didn't train the Stage III for colorization
- output_face = net(input_face, w=0, adain=True)[0]
- save_face = tensor2img(output_face, rgb2bgr=True, min_max=(-1, 1))
- del output_face
- torch.cuda.empty_cache()
- except Exception as error:
- print(f'\tFailed inference for CodeFormer: {error}')
- save_face = tensor2img(input_face, rgb2bgr=True, min_max=(-1, 1))
-
- save_face = save_face.astype('uint8')
-
- # save face
- if args.suffix is not None:
- basename = f'{basename}_{args.suffix}'
- save_restore_path = os.path.join(result_root, f'{basename}.png')
- imwrite(save_face, save_restore_path)
-
- print(f'\nAll results are saved in {result_root}')
-
diff --git a/inference_inpainting.py b/inference_inpainting.py
deleted file mode 100644
index 9cbfb69a..00000000
--- a/inference_inpainting.py
+++ /dev/null
@@ -1,91 +0,0 @@
-import os
-import cv2
-import argparse
-import glob
-import torch
-from torchvision.transforms.functional import normalize
-from basicsr.utils import imwrite, img2tensor, tensor2img
-from basicsr.utils.download_util import load_file_from_url
-from basicsr.utils.misc import get_device
-from basicsr.utils.registry import ARCH_REGISTRY
-
-pretrain_model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer_inpainting.pth'
-
-if __name__ == '__main__':
- # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
- device = get_device()
- parser = argparse.ArgumentParser()
-
- parser.add_argument('-i', '--input_path', type=str, default='./inputs/masked_faces',
- help='Input image or folder. Default: inputs/masked_faces')
- parser.add_argument('-o', '--output_path', type=str, default=None,
- help='Output folder. Default: results/')
- parser.add_argument('--suffix', type=str, default=None,
- help='Suffix of the restored faces. Default: None')
- args = parser.parse_args()
-
- # ------------------------ input & output ------------------------
- print('[NOTE] The input face images should be aligned and cropped to a resolution of 512x512.')
- if args.input_path.endswith(('jpg', 'jpeg', 'png', 'JPG', 'JPEG', 'PNG')): # input single img path
- input_img_list = [args.input_path]
- result_root = f'results/test_inpainting_img'
- else: # input img folder
- if args.input_path.endswith('/'): # solve when path ends with /
- args.input_path = args.input_path[:-1]
- # scan all the jpg and png images
- input_img_list = sorted(glob.glob(os.path.join(args.input_path, '*.[jpJP][pnPN]*[gG]')))
- result_root = f'results/{os.path.basename(args.input_path)}'
-
- if not args.output_path is None: # set output path
- result_root = args.output_path
-
- test_img_num = len(input_img_list)
-
- # ------------------ set up CodeFormer restorer -------------------
- net = ARCH_REGISTRY.get('CodeFormer')(dim_embd=512, codebook_size=512, n_head=8, n_layers=9,
- connect_list=['32', '64', '128']).to(device)
-
- # ckpt_path = 'weights/CodeFormer/codeformer.pth'
- ckpt_path = load_file_from_url(url=pretrain_model_url,
- model_dir='weights/CodeFormer', progress=True, file_name=None)
- checkpoint = torch.load(ckpt_path)['params_ema']
- net.load_state_dict(checkpoint)
- net.eval()
-
- # -------------------- start processing ---------------------
- for i, img_path in enumerate(input_img_list):
- img_name = os.path.basename(img_path)
- basename, ext = os.path.splitext(img_name)
- print(f'[{i+1}/{test_img_num}] Processing: {img_name}')
- input_face = cv2.imread(img_path)
- assert input_face.shape[:2] == (512, 512), 'Input resolution must be 512x512 for inpainting.'
- # input_face = cv2.resize(input_face, (512, 512), interpolation=cv2.INTER_LINEAR)
- input_face = img2tensor(input_face / 255., bgr2rgb=True, float32=True)
- normalize(input_face, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
- input_face = input_face.unsqueeze(0).to(device)
- try:
- with torch.no_grad():
- mask = torch.zeros(512, 512)
- m_ind = torch.sum(input_face[0], dim=0)
- mask[m_ind==3] = 1.0
- mask = mask.view(1, 1, 512, 512).to(device)
- # w is fixed to 1, adain=False for inpainting
- output_face = net(input_face, w=1, adain=False)[0]
- output_face = (1-mask)*input_face + mask*output_face
- save_face = tensor2img(output_face, rgb2bgr=True, min_max=(-1, 1))
- del output_face
- torch.cuda.empty_cache()
- except Exception as error:
- print(f'\tFailed inference for CodeFormer: {error}')
- save_face = tensor2img(input_face, rgb2bgr=True, min_max=(-1, 1))
-
- save_face = save_face.astype('uint8')
-
- # save face
- if args.suffix is not None:
- basename = f'{basename}_{args.suffix}'
- save_restore_path = os.path.join(result_root, f'{basename}.png')
- imwrite(save_face, save_restore_path)
-
- print(f'\nAll results are saved in {result_root}')
-
diff --git a/inputs/cropped_faces/0143.png b/inputs/cropped_faces/0143.png
deleted file mode 100644
index 065b3f00..00000000
Binary files a/inputs/cropped_faces/0143.png and /dev/null differ
diff --git a/inputs/cropped_faces/0240.png b/inputs/cropped_faces/0240.png
deleted file mode 100644
index 7a117017..00000000
Binary files a/inputs/cropped_faces/0240.png and /dev/null differ
diff --git a/inputs/cropped_faces/0342.png b/inputs/cropped_faces/0342.png
deleted file mode 100644
index 8f5aeeae..00000000
Binary files a/inputs/cropped_faces/0342.png and /dev/null differ
diff --git a/inputs/cropped_faces/0345.png b/inputs/cropped_faces/0345.png
deleted file mode 100644
index b8f71b6d..00000000
Binary files a/inputs/cropped_faces/0345.png and /dev/null differ
diff --git a/inputs/cropped_faces/0368.png b/inputs/cropped_faces/0368.png
deleted file mode 100644
index 262778a9..00000000
Binary files a/inputs/cropped_faces/0368.png and /dev/null differ
diff --git a/inputs/cropped_faces/0412.png b/inputs/cropped_faces/0412.png
deleted file mode 100644
index c4a63b13..00000000
Binary files a/inputs/cropped_faces/0412.png and /dev/null differ
diff --git a/inputs/cropped_faces/0444.png b/inputs/cropped_faces/0444.png
deleted file mode 100644
index 9028dd0e..00000000
Binary files a/inputs/cropped_faces/0444.png and /dev/null differ
diff --git a/inputs/cropped_faces/0478.png b/inputs/cropped_faces/0478.png
deleted file mode 100644
index f0924061..00000000
Binary files a/inputs/cropped_faces/0478.png and /dev/null differ
diff --git a/inputs/cropped_faces/0500.png b/inputs/cropped_faces/0500.png
deleted file mode 100644
index 7a7146b4..00000000
Binary files a/inputs/cropped_faces/0500.png and /dev/null differ
diff --git a/inputs/cropped_faces/0599.png b/inputs/cropped_faces/0599.png
deleted file mode 100644
index ff26ccda..00000000
Binary files a/inputs/cropped_faces/0599.png and /dev/null differ
diff --git a/inputs/cropped_faces/0717.png b/inputs/cropped_faces/0717.png
deleted file mode 100644
index 9342b5e5..00000000
Binary files a/inputs/cropped_faces/0717.png and /dev/null differ
diff --git a/inputs/cropped_faces/0720.png b/inputs/cropped_faces/0720.png
deleted file mode 100644
index af384dce..00000000
Binary files a/inputs/cropped_faces/0720.png and /dev/null differ
diff --git a/inputs/cropped_faces/0729.png b/inputs/cropped_faces/0729.png
deleted file mode 100644
index 4f70f46e..00000000
Binary files a/inputs/cropped_faces/0729.png and /dev/null differ
diff --git a/inputs/cropped_faces/0763.png b/inputs/cropped_faces/0763.png
deleted file mode 100644
index 1263df7b..00000000
Binary files a/inputs/cropped_faces/0763.png and /dev/null differ
diff --git a/inputs/cropped_faces/0770.png b/inputs/cropped_faces/0770.png
deleted file mode 100644
index 40a64e83..00000000
Binary files a/inputs/cropped_faces/0770.png and /dev/null differ
diff --git a/inputs/cropped_faces/0777.png b/inputs/cropped_faces/0777.png
deleted file mode 100644
index c72cb26f..00000000
Binary files a/inputs/cropped_faces/0777.png and /dev/null differ
diff --git a/inputs/cropped_faces/0885.png b/inputs/cropped_faces/0885.png
deleted file mode 100644
index f3ea2632..00000000
Binary files a/inputs/cropped_faces/0885.png and /dev/null differ
diff --git a/inputs/cropped_faces/0934.png b/inputs/cropped_faces/0934.png
deleted file mode 100644
index bf82c2d3..00000000
Binary files a/inputs/cropped_faces/0934.png and /dev/null differ
diff --git a/inputs/cropped_faces/Solvay_conference_1927_0018.png b/inputs/cropped_faces/Solvay_conference_1927_0018.png
deleted file mode 100644
index 0f79547a..00000000
Binary files a/inputs/cropped_faces/Solvay_conference_1927_0018.png and /dev/null differ
diff --git a/inputs/cropped_faces/Solvay_conference_1927_2_16.png b/inputs/cropped_faces/Solvay_conference_1927_2_16.png
deleted file mode 100644
index f75b9602..00000000
Binary files a/inputs/cropped_faces/Solvay_conference_1927_2_16.png and /dev/null differ
diff --git a/inputs/gray_faces/067_David_Beckham_00.png b/inputs/gray_faces/067_David_Beckham_00.png
deleted file mode 100644
index 69dc41a6..00000000
Binary files a/inputs/gray_faces/067_David_Beckham_00.png and /dev/null differ
diff --git a/inputs/gray_faces/089_Miley_Cyrus_00.png b/inputs/gray_faces/089_Miley_Cyrus_00.png
deleted file mode 100644
index 29f56936..00000000
Binary files a/inputs/gray_faces/089_Miley_Cyrus_00.png and /dev/null differ
diff --git a/inputs/gray_faces/099_Victoria_Beckham_00.png b/inputs/gray_faces/099_Victoria_Beckham_00.png
deleted file mode 100644
index ccf375f3..00000000
Binary files a/inputs/gray_faces/099_Victoria_Beckham_00.png and /dev/null differ
diff --git a/inputs/gray_faces/111_Alexa_Chung_00.png b/inputs/gray_faces/111_Alexa_Chung_00.png
deleted file mode 100644
index 759c8ebf..00000000
Binary files a/inputs/gray_faces/111_Alexa_Chung_00.png and /dev/null differ
diff --git a/inputs/gray_faces/132_Robert_Downey_Jr_00.png b/inputs/gray_faces/132_Robert_Downey_Jr_00.png
deleted file mode 100644
index d4ec0cbd..00000000
Binary files a/inputs/gray_faces/132_Robert_Downey_Jr_00.png and /dev/null differ
diff --git a/inputs/gray_faces/158_Jimmy_Fallon_00.png b/inputs/gray_faces/158_Jimmy_Fallon_00.png
deleted file mode 100644
index aabb515d..00000000
Binary files a/inputs/gray_faces/158_Jimmy_Fallon_00.png and /dev/null differ
diff --git a/inputs/gray_faces/161_Zac_Efron_00.png b/inputs/gray_faces/161_Zac_Efron_00.png
deleted file mode 100644
index 264d2f48..00000000
Binary files a/inputs/gray_faces/161_Zac_Efron_00.png and /dev/null differ
diff --git a/inputs/gray_faces/169_John_Lennon_00.png b/inputs/gray_faces/169_John_Lennon_00.png
deleted file mode 100644
index 45aef40e..00000000
Binary files a/inputs/gray_faces/169_John_Lennon_00.png and /dev/null differ
diff --git a/inputs/gray_faces/170_Marilyn_Monroe_00.png b/inputs/gray_faces/170_Marilyn_Monroe_00.png
deleted file mode 100644
index 6e1f2a11..00000000
Binary files a/inputs/gray_faces/170_Marilyn_Monroe_00.png and /dev/null differ
diff --git a/inputs/gray_faces/Einstein01.png b/inputs/gray_faces/Einstein01.png
deleted file mode 100644
index 9fd17d39..00000000
Binary files a/inputs/gray_faces/Einstein01.png and /dev/null differ
diff --git a/inputs/gray_faces/Einstein02.png b/inputs/gray_faces/Einstein02.png
deleted file mode 100644
index 4650dd3b..00000000
Binary files a/inputs/gray_faces/Einstein02.png and /dev/null differ
diff --git a/inputs/gray_faces/Hepburn01.png b/inputs/gray_faces/Hepburn01.png
deleted file mode 100644
index 7ef44227..00000000
Binary files a/inputs/gray_faces/Hepburn01.png and /dev/null differ
diff --git a/inputs/gray_faces/Hepburn02.png b/inputs/gray_faces/Hepburn02.png
deleted file mode 100644
index f0f364b8..00000000
Binary files a/inputs/gray_faces/Hepburn02.png and /dev/null differ
diff --git a/inputs/masked_faces/00105.png b/inputs/masked_faces/00105.png
deleted file mode 100644
index 28782c91..00000000
Binary files a/inputs/masked_faces/00105.png and /dev/null differ
diff --git a/inputs/masked_faces/00108.png b/inputs/masked_faces/00108.png
deleted file mode 100644
index 25745db1..00000000
Binary files a/inputs/masked_faces/00108.png and /dev/null differ
diff --git a/inputs/masked_faces/00169.png b/inputs/masked_faces/00169.png
deleted file mode 100644
index c5a64150..00000000
Binary files a/inputs/masked_faces/00169.png and /dev/null differ
diff --git a/inputs/masked_faces/00588.png b/inputs/masked_faces/00588.png
deleted file mode 100644
index c734740e..00000000
Binary files a/inputs/masked_faces/00588.png and /dev/null differ
diff --git a/inputs/masked_faces/00664.png b/inputs/masked_faces/00664.png
deleted file mode 100644
index 1afc6218..00000000
Binary files a/inputs/masked_faces/00664.png and /dev/null differ
diff --git a/inputs/whole_imgs/00.jpg b/inputs/whole_imgs/00.jpg
deleted file mode 100644
index d6e323e5..00000000
Binary files a/inputs/whole_imgs/00.jpg and /dev/null differ
diff --git a/inputs/whole_imgs/01.jpg b/inputs/whole_imgs/01.jpg
deleted file mode 100644
index 485fc6a5..00000000
Binary files a/inputs/whole_imgs/01.jpg and /dev/null differ
diff --git a/inputs/whole_imgs/02.png b/inputs/whole_imgs/02.png
deleted file mode 100644
index 378e7b15..00000000
Binary files a/inputs/whole_imgs/02.png and /dev/null differ
diff --git a/inputs/whole_imgs/03.jpg b/inputs/whole_imgs/03.jpg
deleted file mode 100644
index b6c84281..00000000
Binary files a/inputs/whole_imgs/03.jpg and /dev/null differ
diff --git a/inputs/whole_imgs/04.jpg b/inputs/whole_imgs/04.jpg
deleted file mode 100644
index bb94681a..00000000
Binary files a/inputs/whole_imgs/04.jpg and /dev/null differ
diff --git a/inputs/whole_imgs/05.jpg b/inputs/whole_imgs/05.jpg
deleted file mode 100644
index 4dc33735..00000000
Binary files a/inputs/whole_imgs/05.jpg and /dev/null differ
diff --git a/inputs/whole_imgs/06.png b/inputs/whole_imgs/06.png
deleted file mode 100644
index 49c2fff2..00000000
Binary files a/inputs/whole_imgs/06.png and /dev/null differ
diff --git a/options/CodeFormer_colorization.yml b/options/CodeFormer_colorization.yml
deleted file mode 100755
index 6fa595b4..00000000
--- a/options/CodeFormer_colorization.yml
+++ /dev/null
@@ -1,145 +0,0 @@
-# general settings
-name: CodeFormer_colorization
-model_type: CodeFormerIdxModel
-num_gpu: 8
-manual_seed: 0
-
-# dataset and data loader settings
-datasets:
- train:
- name: FFHQ
- type: FFHQBlindDataset
- dataroot_gt: datasets/ffhq/ffhq_512
- filename_tmpl: '{}'
- io_backend:
- type: disk
-
- in_size: 512
- gt_size: 512
- mean: [0.5, 0.5, 0.5]
- std: [0.5, 0.5, 0.5]
- use_hflip: true
- use_corrupt: true
-
- # large degradation in stageII
- blur_kernel_size: 41
- use_motion_kernel: false
- motion_kernel_prob: 0.001
- kernel_list: ['iso', 'aniso']
- kernel_prob: [0.5, 0.5]
- blur_sigma: [1, 15]
- downsample_range: [4, 30]
- noise_range: [0, 20]
- jpeg_range: [30, 80]
-
- # color jitter and gray
- color_jitter_prob: 0.3
- color_jitter_shift: 20
- color_jitter_pt_prob: 0.3
- gray_prob: 0.01
-
- latent_gt_path: ~ # without pre-calculated latent code
- # latent_gt_path: './experiments/pretrained_models/VQGAN/latent_gt_code1024.pth'
-
- # data loader
- num_worker_per_gpu: 2
- batch_size_per_gpu: 4
- dataset_enlarge_ratio: 100
- prefetch_mode: ~
-
- # val:
- # name: CelebA-HQ-512
- # type: PairedImageDataset
- # dataroot_lq: datasets/faces/validation/lq
- # dataroot_gt: datasets/faces/validation/gt
- # io_backend:
- # type: disk
- # mean: [0.5, 0.5, 0.5]
- # std: [0.5, 0.5, 0.5]
- # scale: 1
-
-# network structures
-network_g:
- type: CodeFormer
- dim_embd: 512
- n_head: 8
- n_layers: 9
- codebook_size: 1024
- connect_list: ['32', '64', '128', '256']
- fix_modules: ['quantize','generator']
- vqgan_path: './experiments/pretrained_models/vqgan/vqgan_code1024.pth' # pretrained VQGAN
-
-network_vqgan: # this config is needed if no pre-calculated latent
- type: VQAutoEncoder
- img_size: 512
- nf: 64
- ch_mult: [1, 2, 2, 4, 4, 8]
- quantizer: 'nearest'
- codebook_size: 1024
-
-# path
-path:
- pretrain_network_g: ~
- param_key_g: params_ema
- strict_load_g: false
- pretrain_network_d: ~
- strict_load_d: true
- resume_state: ~
-
-# base_lr(4.5e-6)*batch_size(4)
-train:
- use_hq_feat_loss: true
- feat_loss_weight: 1.0
- cross_entropy_loss: true
- entropy_loss_weight: 0.5
- fidelity_weight: 0
-
- optim_g:
- type: Adam
- lr: !!float 1e-4
- weight_decay: 0
- betas: [0.9, 0.99]
-
- scheduler:
- type: MultiStepLR
- milestones: [400000, 450000]
- gamma: 0.5
-
- total_iter: 500000
-
- warmup_iter: -1 # no warm up
- ema_decay: 0.995
-
- use_adaptive_weight: true
-
- net_g_start_iter: 0
- net_d_iters: 1
- net_d_start_iter: 0
- manual_seed: 0
-
-# validation settings
-val:
- val_freq: !!float 5e10 # no validation
- save_img: true
-
- metrics:
- psnr: # metric name, can be arbitrary
- type: calculate_psnr
- crop_border: 4
- test_y_channel: false
-
-# logging settings
-logger:
- print_freq: 100
- save_checkpoint_freq: !!float 1e4
- use_tb_logger: true
- wandb:
- project: ~
- resume_id: ~
-
-# dist training settings
-dist_params:
- backend: nccl
- port: 29419
-
-find_unused_parameters: true
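Note on the colorization config removed above: its `datasets.train` block encodes the stage-II "large degradation" recipe (41-tap blur with sigma 1–15, 4–30× downsampling, Gaussian noise up to 20, JPEG quality 30–80). The sketch below only illustrates that blur → downsample → noise → JPEG chain, not the repository's `FFHQBlindDataset`; the helper name `degrade_large` and the hard-coded 512 resolution are assumptions.

```python
import cv2
import numpy as np

def degrade_large(img_512, rng=np.random.default_rng(0)):
    """Toy blur -> downsample -> noise -> JPEG chain using the ranges from the
    deleted CodeFormer_colorization.yml; illustrative only, not FFHQBlindDataset."""
    sigma = rng.uniform(1, 15)                        # blur_sigma: [1, 15]
    img = cv2.GaussianBlur(img_512, (41, 41), sigma)  # blur_kernel_size: 41
    scale = rng.uniform(4, 30)                        # downsample_range: [4, 30]
    small = cv2.resize(img, None, fx=1 / scale, fy=1 / scale, interpolation=cv2.INTER_LINEAR)
    img = cv2.resize(small, (512, 512), interpolation=cv2.INTER_LINEAR)
    noise_std = rng.uniform(0, 20)                    # noise_range: [0, 20] on the 8-bit scale
    img = np.clip(img + rng.normal(0, noise_std, img.shape), 0, 255).astype(np.uint8)
    quality = int(rng.uniform(30, 80))                # jpeg_range: [30, 80]
    _, buf = cv2.imencode('.jpg', img, [cv2.IMWRITE_JPEG_QUALITY, quality])
    return cv2.imdecode(buf, cv2.IMREAD_COLOR)

lq = degrade_large(cv2.imread('face_512.png'))        # hypothetical 512x512 input image
```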
diff --git a/options/CodeFormer_inpainting.yml b/options/CodeFormer_inpainting.yml
deleted file mode 100755
index ddd68452..00000000
--- a/options/CodeFormer_inpainting.yml
+++ /dev/null
@@ -1,159 +0,0 @@
-# general settings
-name: CodeFormer_inpainting
-model_type: CodeFormerModel
-num_gpu: 4
-manual_seed: 0
-
-# dataset and data loader settings
-datasets:
- train:
- name: FFHQ
- type: FFHQBlindDataset
- dataroot_gt: datasets/ffhq/ffhq_512
- filename_tmpl: '{}'
- io_backend:
- type: disk
-
- in_size: 512
- gt_size: 512
- mean: [0.5, 0.5, 0.5]
- std: [0.5, 0.5, 0.5]
- use_hflip: true
- use_corrupt: false
- gen_inpaint_mask: true
-
- latent_gt_path: ~ # without pre-calculated latent code
- # latent_gt_path: './experiments/pretrained_models/VQGAN/latent_gt_code1024.pth'
-
- # data loader
- num_worker_per_gpu: 2
- batch_size_per_gpu: 3
- dataset_enlarge_ratio: 100
- prefetch_mode: ~
-
- # val:
- # name: CelebA-HQ-512
- # type: PairedImageDataset
- # dataroot_lq: datasets/faces/validation/lq
- # dataroot_gt: datasets/faces/validation/gt
- # io_backend:
- # type: disk
- # mean: [0.5, 0.5, 0.5]
- # std: [0.5, 0.5, 0.5]
- # scale: 1
-
-# network structures
-network_g:
- type: CodeFormer
- dim_embd: 512
- n_head: 8
- n_layers: 9
- codebook_size: 1024
- connect_list: ['32', '64', '128']
- fix_modules: ['quantize','generator']
- vqgan_path: './experiments/pretrained_models/vqgan/vqgan_code1024.pth' # pretrained VQGAN
-
-network_vqgan: # this config is needed if no pre-calculated latent
- type: VQAutoEncoder
- img_size: 512
- nf: 64
- ch_mult: [1, 2, 2, 4, 4, 8]
- quantizer: 'nearest'
- codebook_size: 1024
-
-network_d:
- type: VQGANDiscriminator
- nc: 3
- ndf: 64
- n_layers: 4
- model_path: ~
-
-# path
-path:
- pretrain_network_g: ~
- param_key_g: params_ema
- strict_load_g: true
- pretrain_network_d: ~
- strict_load_d: true
- resume_state: ~
-
-# base_lr(4.5e-6)*bach_size(4)
-train:
- use_hq_feat_loss: true
- feat_loss_weight: 1.0
- cross_entropy_loss: true
- entropy_loss_weight: 0.5
- scale_adaptive_gan_weight: 0.1
- fidelity_weight: 1.0
-
- optim_g:
- type: Adam
- lr: !!float 7e-5
- weight_decay: 0
- betas: [0.9, 0.99]
- optim_d:
- type: Adam
- lr: !!float 7e-5
- weight_decay: 0
- betas: [0.9, 0.99]
-
- scheduler:
- type: MultiStepLR
- milestones: [250000, 300000]
- gamma: 0.5
-
- total_iter: 300000
-
- warmup_iter: -1 # no warm up
- ema_decay: 0.997
-
- pixel_opt:
- type: L1Loss
- loss_weight: 1.0
- reduction: mean
-
- perceptual_opt:
- type: LPIPSLoss
- loss_weight: 1.0
- use_input_norm: true
- range_norm: true
-
- gan_opt:
- type: GANLoss
- gan_type: hinge
- loss_weight: !!float 1.0 # adaptive_weighting
-
-
- use_adaptive_weight: true
-
- net_g_start_iter: 0
- net_d_iters: 1
- net_d_start_iter: 296001
- manual_seed: 0
-
-# validation settings
-val:
- val_freq: !!float 5e10 # no validation
- save_img: true
-
- metrics:
- psnr: # metric name, can be arbitrary
- type: calculate_psnr
- crop_border: 4
- test_y_channel: false
-
-# logging settings
-logger:
- print_freq: 100
- save_checkpoint_freq: !!float 1e4
- use_tb_logger: true
- wandb:
- project: ~
- resume_id: ~
-
-# dist training settings
-dist_params:
- backend: nccl
- port: 29420
-
-find_unused_parameters: true
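Note on the inpainting config removed above: it sets `use_corrupt: false` and `gen_inpaint_mask: true`, so the low-quality input comes from masking the clean 512×512 face rather than degrading it. A toy mask generator in that spirit is sketched below; `random_box_mask` is an assumption for illustration and the real dataset class may draw very different masks.

```python
import numpy as np

def random_box_mask(size=512, rng=np.random.default_rng(0)):
    """Toy stand-in for gen_inpaint_mask: zero out a random rectangle of the face."""
    mask = np.ones((size, size, 1), dtype=np.float32)
    h = rng.integers(size // 8, size // 2)
    w = rng.integers(size // 8, size // 2)
    top = rng.integers(0, size - h)
    left = rng.integers(0, size - w)
    mask[top:top + h, left:left + w] = 0.0
    return mask

img = np.random.rand(512, 512, 3).astype(np.float32)   # placeholder clean face in [0, 1]
mask = random_box_mask()
masked_input = img * mask                               # the network sees the hole, GT stays intact
```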
diff --git a/options/CodeFormer_stage3.yml b/options/CodeFormer_stage3.yml
deleted file mode 100755
index fbca3f2a..00000000
--- a/options/CodeFormer_stage3.yml
+++ /dev/null
@@ -1,171 +0,0 @@
-# general settings
-name: CodeFormer_stage3
-model_type: CodeFormerJointModel
-num_gpu: 8
-manual_seed: 0
-
-# dataset and data loader settings
-datasets:
- train:
- name: FFHQ
- type: FFHQBlindJointDataset
- dataroot_gt: datasets/ffhq/ffhq_512
- filename_tmpl: '{}'
- io_backend:
- type: disk
-
- in_size: 512
- gt_size: 512
- mean: [0.5, 0.5, 0.5]
- std: [0.5, 0.5, 0.5]
- use_hflip: true
- use_corrupt: true
-
- blur_kernel_size: 41
- use_motion_kernel: false
- motion_kernel_prob: 0.001
- kernel_list: ['iso', 'aniso']
- kernel_prob: [0.5, 0.5]
- # small degradation in stageIII
- blur_sigma: [0.1, 10]
- downsample_range: [1, 12]
- noise_range: [0, 15]
- jpeg_range: [60, 100]
- # large degradation in stageII
- blur_sigma_large: [1, 15]
- downsample_range_large: [4, 30]
- noise_range_large: [0, 20]
- jpeg_range_large: [30, 80]
-
- latent_gt_path: ~ # without pre-calculated latent code
- # latent_gt_path: './experiments/pretrained_models/VQGAN/latent_gt_code1024.pth'
-
- # data loader
- num_worker_per_gpu: 1
- batch_size_per_gpu: 3
- dataset_enlarge_ratio: 100
- prefetch_mode: ~
-
- # val:
- # name: CelebA-HQ-512
- # type: PairedImageDataset
- # dataroot_lq: datasets/faces/validation/lq
- # dataroot_gt: datasets/faces/validation/gt
- # io_backend:
- # type: disk
- # mean: [0.5, 0.5, 0.5]
- # std: [0.5, 0.5, 0.5]
- # scale: 1
-
-# network structures
-network_g:
- type: CodeFormer
- dim_embd: 512
- n_head: 8
- n_layers: 9
- codebook_size: 1024
- connect_list: ['32', '64', '128', '256']
- fix_modules: ['quantize','generator']
-
-network_vqgan: # this config is needed if no pre-calculated latent
- type: VQAutoEncoder
- img_size: 512
- nf: 64
- ch_mult: [1, 2, 2, 4, 4, 8]
- quantizer: 'nearest'
- codebook_size: 1024
-
-network_d:
- type: VQGANDiscriminator
- nc: 3
- ndf: 64
- n_layers: 4
-
-# path
-path:
- pretrain_network_g: './experiments/pretrained_models/CodeFormer_stage2/net_g_latest.pth' # pretrained G model in StageII
- param_key_g: params_ema
- strict_load_g: false
- pretrain_network_d: './experiments/pretrained_models/CodeFormer_stage2/net_d_latest.pth' # pretrained D model in StageII
- resume_state: ~
-
-# base_lr(4.5e-6)*bach_size(4)
-train:
- use_hq_feat_loss: true
- feat_loss_weight: 1.0
- cross_entropy_loss: true
- entropy_loss_weight: 0.5
- scale_adaptive_gan_weight: 0.1
-
- optim_g:
- type: Adam
- lr: !!float 5e-5
- weight_decay: 0
- betas: [0.9, 0.99]
- optim_d:
- type: Adam
- lr: !!float 5e-5
- weight_decay: 0
- betas: [0.9, 0.99]
-
- scheduler:
- type: CosineAnnealingRestartLR
- periods: [150000]
- restart_weights: [1]
- eta_min: !!float 2e-5
-
-
- total_iter: 150000
-
- warmup_iter: -1 # no warm up
- ema_decay: 0.997
-
- pixel_opt:
- type: L1Loss
- loss_weight: 1.0
- reduction: mean
-
- perceptual_opt:
- type: LPIPSLoss
- loss_weight: 1.0
- use_input_norm: true
- range_norm: true
-
- gan_opt:
- type: GANLoss
- gan_type: hinge
- loss_weight: !!float 1.0 # adaptive_weighting
-
- use_adaptive_weight: true
-
- net_g_start_iter: 0
- net_d_iters: 1
- net_d_start_iter: 5001
- manual_seed: 0
-
-# validation settings
-val:
- val_freq: !!float 5e10 # no validation
- save_img: true
-
- metrics:
- psnr: # metric name, can be arbitrary
- type: calculate_psnr
- crop_border: 4
- test_y_channel: false
-
-# logging settings
-logger:
- print_freq: 100
- save_checkpoint_freq: !!float 5e3
- use_tb_logger: true
- wandb:
- project: ~
- resume_id: ~
-
-# dist training settings
-dist_params:
- backend: nccl
- port: 29413
-
-find_unused_parameters: true
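Note on the stage-III config removed above: it warm-starts from the stage-II generator and discriminator with `param_key_g: params_ema` and `strict_load_g: false`. A minimal sketch of that loading convention follows, mirroring what the new `scripts/video_sr.py` later in this diff does; the checkpoint path is taken from the config and `basicsr.archs` is assumed to register the `CodeFormer` architecture on import.

```python
import torch
import basicsr.archs  # assumed to register the CodeFormer architecture on import
from basicsr.utils.registry import ARCH_REGISTRY

net = ARCH_REGISTRY.get('CodeFormer')(dim_embd=512, codebook_size=1024, n_head=8,
                                       n_layers=9, connect_list=['32', '64', '128', '256'])
# BasicSR-style checkpoints keep EMA weights under 'params_ema' (param_key_g in the config)
state = torch.load('./experiments/pretrained_models/CodeFormer_stage2/net_g_latest.pth',
                   map_location='cpu')['params_ema']
missing, unexpected = net.load_state_dict(state, strict=False)  # strict_load_g: false
print('missing keys:', missing)
print('unexpected keys:', unexpected)
```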
diff --git a/options/VQGAN_512_ds32_nearest_stage1.yml b/options/VQGAN_512_ds32_nearest_stage1.yml
deleted file mode 100755
index 0753fc36..00000000
--- a/options/VQGAN_512_ds32_nearest_stage1.yml
+++ /dev/null
@@ -1,136 +0,0 @@
-# general settings
-name: VQGAN-512-ds32-nearest-stage1
-model_type: VQGANModel
-num_gpu: 8
-manual_seed: 0
-
-# dataset and data loader settings
-datasets:
- train:
- name: FFHQ
- type: FFHQBlindDataset
- dataroot_gt: datasets/ffhq/ffhq_512
- filename_tmpl: '{}'
- io_backend:
- type: disk
-
- in_size: 512
- gt_size: 512
- mean: [0.5, 0.5, 0.5]
- std: [0.5, 0.5, 0.5]
- use_hflip: true
- use_corrupt: false # for VQGAN
-
- # data loader
- num_worker_per_gpu: 2
- batch_size_per_gpu: 4
- dataset_enlarge_ratio: 100
-
- prefetch_mode: cpu
- num_prefetch_queue: 4
-
- # val:
- # name: CelebA-HQ-512
- # type: PairedImageDataset
- # dataroot_lq: datasets/faces/validation/gt
- # dataroot_gt: datasets/faces/validation/gt
- # io_backend:
- # type: disk
- # mean: [0.5, 0.5, 0.5]
- # std: [0.5, 0.5, 0.5]
- # scale: 1
-
-# network structures
-network_g:
- type: VQAutoEncoder
- img_size: 512
- nf: 64
- ch_mult: [1, 2, 2, 4, 4, 8]
- quantizer: 'nearest'
- codebook_size: 1024
-
-network_d:
- type: VQGANDiscriminator
- nc: 3
- ndf: 64
-
-# path
-path:
- pretrain_network_g: ~
- param_key_g: params_ema
- strict_load_g: true
- pretrain_network_d: ~
- strict_load_d: true
- resume_state: ~
-
-# base_lr(4.5e-6)*bach_size(4)
-train:
- optim_g:
- type: Adam
- lr: !!float 7e-5
- weight_decay: 0
- betas: [0.9, 0.99]
- optim_d:
- type: Adam
- lr: !!float 7e-5
- weight_decay: 0
- betas: [0.9, 0.99]
-
- scheduler:
- type: CosineAnnealingRestartLR
- periods: [1600000]
- restart_weights: [1]
- eta_min: !!float 6e-5 # no lr reduce in official vqgan code
-
- total_iter: 1600000
-
- warmup_iter: -1 # no warm up
- ema_decay: 0.995 # GFPGAN: 0.5**(32 / (10 * 1000) == 0.998; Unleashing: 0.995
-
- pixel_opt:
- type: L1Loss
- loss_weight: 1.0
- reduction: mean
-
- perceptual_opt:
- type: LPIPSLoss
- loss_weight: 1.0
- use_input_norm: true
- range_norm: true
-
- gan_opt:
- type: GANLoss
- gan_type: hinge
- loss_weight: !!float 1.0 # adaptive_weighting
-
- net_g_start_iter: 0
- net_d_iters: 1
- net_d_start_iter: 30001
- manual_seed: 0
-
-# validation settings
-val:
- val_freq: !!float 5e10 # no validation
- save_img: true
-
- metrics:
- psnr: # metric name, can be arbitrary
- type: calculate_psnr
- crop_border: 4
- test_y_channel: false
-
-# logging settings
-logger:
- print_freq: 100
- save_checkpoint_freq: !!float 1e4
- use_tb_logger: true
- wandb:
- project: ~
- resume_id: ~
-
-# dist training settings
-dist_params:
- backend: nccl
- port: 29411
-
-find_unused_parameters: true
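Note on the VQGAN stage-I config removed above: it trains with `gan_type: hinge` and only enables the discriminator after iteration 30001 (`net_d_start_iter`). A sketch of the standard hinge objectives that this setting usually denotes is given below; it is not necessarily the exact `GANLoss` implementation used by the repository.

```python
import torch
import torch.nn.functional as F

def d_hinge_loss(real_logits, fake_logits):
    """Standard hinge loss for the discriminator (what gan_type: 'hinge' usually means)."""
    return F.relu(1.0 - real_logits).mean() + F.relu(1.0 + fake_logits).mean()

def g_hinge_loss(fake_logits):
    """Generator side of the hinge objective."""
    return -fake_logits.mean()

real = torch.randn(4, 1)          # placeholder discriminator outputs
fake = torch.randn(4, 1)
print(d_hinge_loss(real, fake).item(), g_hinge_loss(fake).item())
```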
diff --git a/requirements.txt b/requirements.txt
old mode 100644
new mode 100755
index 7e1950a0..0232d12a
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,17 +1,31 @@
-addict
-future
-lmdb
-numpy
-opencv-python
-Pillow
-pyyaml
-requests
-scikit-image
-scipy
-tb-nightly
-torch>=1.7.1
-torchvision
-tqdm
-yapf
+accelerate==0.28.0
+audio-separator==0.17.2
+av==12.1.0
+bitsandbytes==0.43.1
+decord==0.6.0
+diffusers==0.27.2
+einops==0.8.0
+insightface==0.7.3
+librosa==0.10.2.post1
+mediapipe[vision]==0.10.14
+mlflow==2.13.1
+moviepy==1.0.3
+numpy==1.26.4
+omegaconf==2.3.0
+onnx2torch==1.5.14
+onnx==1.16.1
+onnxruntime-gpu==1.18.0
+opencv-contrib-python==4.9.0.80
+opencv-python-headless==4.9.0.80
+opencv-python==4.9.0.80
+pillow==10.3.0
+setuptools==70.0.0
+tqdm==4.66.4
+transformers==4.39.2
+xformers==0.0.25.post1
+isort==5.13.2
+pylint==3.2.2
+pre-commit==3.7.1
+gradio==4.36.1
lpips
-gdown # supports downloading the large file from Google Drive
\ No newline at end of file
+ffmpeg-python==0.2.0
\ No newline at end of file
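Note on the rewritten requirements: every dependency is now pinned to an exact version except `lpips`, and the file still lacks a trailing newline. The small stdlib sketch below, with an assumed helper name `unpinned`, makes such unpinned entries easy to spot.

```python
from pathlib import Path

def unpinned(requirements_path='requirements.txt'):
    """Return requirement lines without an exact '==' pin (comments and blanks skipped)."""
    lines = Path(requirements_path).read_text().splitlines()
    return [ln.strip() for ln in lines
            if ln.strip() and not ln.strip().startswith('#') and '==' not in ln]

print(unpinned())  # for the file above, this reports ['lpips']
```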
diff --git a/scripts/crop_align_face.py b/scripts/crop_align_face.py
deleted file mode 100755
index c44d6e8f..00000000
--- a/scripts/crop_align_face.py
+++ /dev/null
@@ -1,205 +0,0 @@
-"""
-brief: face alignment with FFHQ method (https://github.com/NVlabs/ffhq-dataset)
-author: lzhbrian (https://lzhbrian.me)
-link: https://gist.github.com/lzhbrian/bde87ab23b499dd02ba4f588258f57d5
-date: 2020.1.5
-note: code is heavily borrowed from
- https://github.com/NVlabs/ffhq-dataset
- http://dlib.net/face_landmark_detection.py.html
-requirements:
- conda install Pillow numpy scipy
- conda install -c conda-forge dlib
- # download face landmark model from:
- # http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2
-"""
-
-import os
-import glob
-import numpy as np
-import PIL
-import PIL.Image
-import scipy
-import scipy.ndimage
-import argparse
-from basicsr.utils.download_util import load_file_from_url
-
-try:
- import dlib
-except ImportError:
- print('Please install dlib by running:' 'conda install -c conda-forge dlib')
-
-# download model from: http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2
-shape_predictor_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/shape_predictor_68_face_landmarks-fbdc2cb8.dat'
-ckpt_path = load_file_from_url(url=shape_predictor_url,
- model_dir='weights/dlib', progress=True, file_name=None)
-predictor = dlib.shape_predictor('weights/dlib/shape_predictor_68_face_landmarks-fbdc2cb8.dat')
-
-
-def get_landmark(filepath, only_keep_largest=True):
- """get landmark with dlib
- :return: np.array shape=(68, 2)
- """
- detector = dlib.get_frontal_face_detector()
-
- img = dlib.load_rgb_image(filepath)
- dets = detector(img, 1)
-
- # Shangchen modified
- print("\tNumber of faces detected: {}".format(len(dets)))
- if only_keep_largest:
- print('\tOnly keep the largest.')
- face_areas = []
- for k, d in enumerate(dets):
- face_area = (d.right() - d.left()) * (d.bottom() - d.top())
- face_areas.append(face_area)
-
- largest_idx = face_areas.index(max(face_areas))
- d = dets[largest_idx]
- shape = predictor(img, d)
- # print("Part 0: {}, Part 1: {} ...".format(
- # shape.part(0), shape.part(1)))
- else:
- for k, d in enumerate(dets):
- # print("Detection {}: Left: {} Top: {} Right: {} Bottom: {}".format(
- # k, d.left(), d.top(), d.right(), d.bottom()))
- # Get the landmarks/parts for the face in box d.
- shape = predictor(img, d)
- # print("Part 0: {}, Part 1: {} ...".format(
- # shape.part(0), shape.part(1)))
-
- t = list(shape.parts())
- a = []
- for tt in t:
- a.append([tt.x, tt.y])
- lm = np.array(a)
- # lm is a shape=(68,2) np.array
- return lm
-
-def align_face(filepath, out_path):
- """
- :param filepath: str
- :return: PIL Image
- """
- try:
- lm = get_landmark(filepath)
- except:
- print('No landmark ...')
- return
-
- lm_chin = lm[0:17] # left-right
- lm_eyebrow_left = lm[17:22] # left-right
- lm_eyebrow_right = lm[22:27] # left-right
- lm_nose = lm[27:31] # top-down
- lm_nostrils = lm[31:36] # top-down
- lm_eye_left = lm[36:42] # left-clockwise
- lm_eye_right = lm[42:48] # left-clockwise
- lm_mouth_outer = lm[48:60] # left-clockwise
- lm_mouth_inner = lm[60:68] # left-clockwise
-
- # Calculate auxiliary vectors.
- eye_left = np.mean(lm_eye_left, axis=0)
- eye_right = np.mean(lm_eye_right, axis=0)
- eye_avg = (eye_left + eye_right) * 0.5
- eye_to_eye = eye_right - eye_left
- mouth_left = lm_mouth_outer[0]
- mouth_right = lm_mouth_outer[6]
- mouth_avg = (mouth_left + mouth_right) * 0.5
- eye_to_mouth = mouth_avg - eye_avg
-
- # Choose oriented crop rectangle.
- x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
- x /= np.hypot(*x)
- x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
- y = np.flipud(x) * [-1, 1]
- c = eye_avg + eye_to_mouth * 0.1
- quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
- qsize = np.hypot(*x) * 2
-
- # read image
- img = PIL.Image.open(filepath)
-
- output_size = 512
- transform_size = 4096
- enable_padding = False
-
- # Shrink.
- shrink = int(np.floor(qsize / output_size * 0.5))
- if shrink > 1:
- rsize = (int(np.rint(float(img.size[0]) / shrink)),
- int(np.rint(float(img.size[1]) / shrink)))
- img = img.resize(rsize, PIL.Image.ANTIALIAS)
- quad /= shrink
- qsize /= shrink
-
- # Crop.
- border = max(int(np.rint(qsize * 0.1)), 3)
- crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))),
- int(np.ceil(max(quad[:, 0]))), int(np.ceil(max(quad[:, 1]))))
- crop = (max(crop[0] - border, 0), max(crop[1] - border, 0),
- min(crop[2] + border,
- img.size[0]), min(crop[3] + border, img.size[1]))
- if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
- img = img.crop(crop)
- quad -= crop[0:2]
-
- # Pad.
- pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))),
- int(np.ceil(max(quad[:, 0]))), int(np.ceil(max(quad[:, 1]))))
- pad = (max(-pad[0] + border,
- 0), max(-pad[1] + border,
- 0), max(pad[2] - img.size[0] + border,
- 0), max(pad[3] - img.size[1] + border, 0))
- if enable_padding and max(pad) > border - 4:
- pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
- img = np.pad(
- np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)),
- 'reflect')
- h, w, _ = img.shape
- y, x, _ = np.ogrid[:h, :w, :1]
- mask = np.maximum(
- 1.0 -
- np.minimum(np.float32(x) / pad[0],
- np.float32(w - 1 - x) / pad[2]), 1.0 -
- np.minimum(np.float32(y) / pad[1],
- np.float32(h - 1 - y) / pad[3]))
- blur = qsize * 0.02
- img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) -
- img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
- img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0)
- img = PIL.Image.fromarray(
- np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB')
- quad += pad[:2]
-
- img = img.transform((transform_size, transform_size), PIL.Image.QUAD,
- (quad + 0.5).flatten(), PIL.Image.BILINEAR)
-
- if output_size < transform_size:
- img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS)
-
- # Save aligned image.
- # print('saveing: ', out_path)
- img.save(out_path)
-
- return img, np.max(quad[:, 0]) - np.min(quad[:, 0])
-
-
-if __name__ == '__main__':
- parser = argparse.ArgumentParser()
- parser.add_argument('-i', '--in_dir', type=str, default='./inputs/whole_imgs')
- parser.add_argument('-o', '--out_dir', type=str, default='./inputs/cropped_faces')
- args = parser.parse_args()
-
- if args.out_dir.endswith('/'): # solve when path ends with /
- args.out_dir = args.out_dir[:-1]
- dir_name = os.path.abspath(args.out_dir)
- os.makedirs(dir_name, exist_ok=True)
-
- img_list = sorted(glob.glob(os.path.join(args.in_dir, '*.[jpJP][pnPN]*[gG]')))
- test_img_num = len(img_list)
-
- for i, in_path in enumerate(img_list):
- img_name = os.path.basename(in_path)
- print(f'[{i+1}/{test_img_num}] Processing: {img_name}')
- out_path = os.path.join(args.out_dir, in_path.split("/")[-1])
- out_path = out_path.replace('.jpg', '.png')
- size_ = align_face(in_path, out_path)
\ No newline at end of file
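Note on the alignment script removed above: its core is the FFHQ-style oriented crop quad computed from the 68 dlib landmarks. The numpy sketch below isolates just that geometry, following the deleted code; the random landmark array at the end is a placeholder for real detector output.

```python
import numpy as np

def ffhq_crop_quad(lm):
    """Oriented crop rectangle from 68 landmarks, following the geometry in the
    deleted scripts/crop_align_face.py (FFHQ-style alignment)."""
    eye_left = lm[36:42].mean(axis=0)
    eye_right = lm[42:48].mean(axis=0)
    eye_avg = (eye_left + eye_right) * 0.5
    eye_to_eye = eye_right - eye_left
    mouth_avg = (lm[48] + lm[54]) * 0.5          # outer mouth corners
    eye_to_mouth = mouth_avg - eye_avg
    x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
    x /= np.hypot(*x)
    x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
    y = np.flipud(x) * [-1, 1]
    c = eye_avg + eye_to_mouth * 0.1
    quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
    return quad, np.hypot(*x) * 2                # quad corners and qsize

lm = np.random.rand(68, 2) * 512                 # placeholder landmarks
quad, qsize = ffhq_crop_quad(lm)
```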
diff --git a/scripts/download_pretrained_models.py b/scripts/download_pretrained_models.py
deleted file mode 100644
index 70737833..00000000
--- a/scripts/download_pretrained_models.py
+++ /dev/null
@@ -1,52 +0,0 @@
-import argparse
-import os
-from os import path as osp
-
-from basicsr.utils.download_util import load_file_from_url
-
-
-def download_pretrained_models(method, file_urls):
- if method == 'CodeFormer_train':
- method = 'CodeFormer'
- save_path_root = f'./weights/{method}'
- os.makedirs(save_path_root, exist_ok=True)
-
- for file_name, file_url in file_urls.items():
- save_path = load_file_from_url(url=file_url, model_dir=save_path_root, progress=True, file_name=file_name)
-
-
-if __name__ == '__main__':
- parser = argparse.ArgumentParser()
-
- parser.add_argument(
- 'method',
- type=str,
- help=("Options: 'CodeFormer' 'facelib' 'dlib'. Set to 'all' to download all the models."))
- args = parser.parse_args()
-
- file_urls = {
- 'CodeFormer': {
- 'codeformer.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
- },
- 'CodeFormer_train': {
- 'vqgan_code1024.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/vqgan_code1024.pth',
- 'latent_gt_code1024.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/latent_gt_code1024.pth',
- 'codeformer_stage2.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer_stage2.pth',
- 'codeformer.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
- },
- 'facelib': {
- # 'yolov5l-face.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5l-face.pth',
- 'detection_Resnet50_Final.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_Resnet50_Final.pth',
- 'parsing_parsenet.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth'
- },
- 'dlib': {
- 'mmod_human_face_detector-4cb19393.dat': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/mmod_human_face_detector-4cb19393.dat',
- 'shape_predictor_5_face_landmarks-c4b1e980.dat': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/shape_predictor_5_face_landmarks-c4b1e980.dat'
- }
- }
-
- if args.method == 'all':
- for method in file_urls.keys():
- download_pretrained_models(method, file_urls[method])
- else:
- download_pretrained_models(args.method, file_urls[args.method])
\ No newline at end of file
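Note on the downloader removed above: it is a thin wrapper around `basicsr.utils.download_util.load_file_from_url` over a dictionary of GitHub release URLs. A stdlib-only stand-in with the same skip-if-present behaviour is sketched below; the `fetch` helper is an assumption, only the URL comes from the deleted script.

```python
import os
import urllib.request

def fetch(url, model_dir='./weights/CodeFormer'):
    """Minimal stand-in for load_file_from_url: download once, skip if already present."""
    os.makedirs(model_dir, exist_ok=True)
    save_path = os.path.join(model_dir, os.path.basename(url))
    if not os.path.exists(save_path):
        print(f'Downloading {url} -> {save_path}')
        urllib.request.urlretrieve(url, save_path)
    return save_path

fetch('https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth')
```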
diff --git a/scripts/download_pretrained_models_from_gdrive.py b/scripts/download_pretrained_models_from_gdrive.py
deleted file mode 100644
index 7df5be6f..00000000
--- a/scripts/download_pretrained_models_from_gdrive.py
+++ /dev/null
@@ -1,60 +0,0 @@
-import argparse
-import os
-from os import path as osp
-
-# from basicsr.utils.download_util import download_file_from_google_drive
-import gdown
-
-
-def download_pretrained_models(method, file_ids):
- save_path_root = f'./weights/{method}'
- os.makedirs(save_path_root, exist_ok=True)
-
- for file_name, file_id in file_ids.items():
- file_url = 'https://drive.google.com/uc?id='+file_id
- save_path = osp.abspath(osp.join(save_path_root, file_name))
- if osp.exists(save_path):
- user_response = input(f'{file_name} already exist. Do you want to cover it? Y/N\n')
- if user_response.lower() == 'y':
- print(f'Covering {file_name} to {save_path}')
- gdown.download(file_url, save_path, quiet=False)
- # download_file_from_google_drive(file_id, save_path)
- elif user_response.lower() == 'n':
- print(f'Skipping {file_name}')
- else:
- raise ValueError('Wrong input. Only accepts Y/N.')
- else:
- print(f'Downloading {file_name} to {save_path}')
- gdown.download(file_url, save_path, quiet=False)
- # download_file_from_google_drive(file_id, save_path)
-
-if __name__ == '__main__':
- parser = argparse.ArgumentParser()
-
- parser.add_argument(
- 'method',
- type=str,
- help=("Options: 'CodeFormer' 'facelib'. Set to 'all' to download all the models."))
- args = parser.parse_args()
-
- # file name: file id
- # 'dlib': {
- # 'mmod_human_face_detector-4cb19393.dat': '1qD-OqY8M6j4PWUP_FtqfwUPFPRMu6ubX',
- # 'shape_predictor_5_face_landmarks-c4b1e980.dat': '1vF3WBUApw4662v9Pw6wke3uk1qxnmLdg',
- # 'shape_predictor_68_face_landmarks-fbdc2cb8.dat': '1tJyIVdCHaU6IDMDx86BZCxLGZfsWB8yq'
- # }
- file_ids = {
- 'CodeFormer': {
- 'codeformer.pth': '1v_E_vZvP-dQPF55Kc5SRCjaKTQXDz-JB'
- },
- 'facelib': {
- 'yolov5l-face.pth': '131578zMA6B2x8VQHyHfa6GEPtulMCNzV',
- 'parsing_parsenet.pth': '16pkohyZZ8ViHGBk3QtVqxLZKzdo466bK'
- }
- }
-
- if args.method == 'all':
- for method in file_ids.keys():
- download_pretrained_models(method, file_ids[method])
- else:
- download_pretrained_models(args.method, file_ids[args.method])
\ No newline at end of file
diff --git a/scripts/generate_latent_gt.py b/scripts/generate_latent_gt.py
deleted file mode 100644
index 3f1f17b0..00000000
--- a/scripts/generate_latent_gt.py
+++ /dev/null
@@ -1,67 +0,0 @@
-import argparse
-import glob
-import numpy as np
-import os
-import cv2
-import torch
-from torchvision.transforms.functional import normalize
-from basicsr.utils import imwrite, img2tensor, tensor2img
-
-from basicsr.utils.registry import ARCH_REGISTRY
-
-if __name__ == '__main__':
- parser = argparse.ArgumentParser()
- parser.add_argument('-i', '--test_path', type=str, default='datasets/ffhq/ffhq_512')
- parser.add_argument('-o', '--save_root', type=str, default='./experiments/pretrained_models/vqgan')
- parser.add_argument('--codebook_size', type=int, default=1024)
- parser.add_argument('--ckpt_path', type=str, default='./experiments/pretrained_models/vqgan/net_g.pth')
- args = parser.parse_args()
-
- if args.save_root.endswith('/'): # solve when path ends with /
- args.save_root = args.save_root[:-1]
- dir_name = os.path.abspath(args.save_root)
- os.makedirs(dir_name, exist_ok=True)
-
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
- test_path = args.test_path
- save_root = args.save_root
- ckpt_path = args.ckpt_path
- codebook_size = args.codebook_size
-
- vqgan = ARCH_REGISTRY.get('VQAutoEncoder')(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',
- codebook_size=codebook_size).to(device)
- checkpoint = torch.load(ckpt_path)['params_ema']
-
- vqgan.load_state_dict(checkpoint)
- vqgan.eval()
-
- sum_latent = np.zeros((codebook_size)).astype('float64')
- size_latent = 16
- latent = {}
- latent['orig'] = {}
- latent['hflip'] = {}
- for i in ['orig', 'hflip']:
- # for i in ['hflip']:
- for img_path in sorted(glob.glob(os.path.join(test_path, '*.[jp][pn]g'))):
- img_name = os.path.basename(img_path)
- img = cv2.imread(img_path)
- if i == 'hflip':
- cv2.flip(img, 1, img)
- img = img2tensor(img / 255., bgr2rgb=True, float32=True)
- normalize(img, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
- img = img.unsqueeze(0).to(device)
- with torch.no_grad():
- # output = net(img)[0]
- x, feat_dict = vqgan.encoder(img, True)
- x, _, log = vqgan.quantize(x)
- # del output
- torch.cuda.empty_cache()
-
- min_encoding_indices = log['min_encoding_indices']
- min_encoding_indices = min_encoding_indices.view(size_latent,size_latent)
- latent[i][img_name[:-4]] = min_encoding_indices.cpu().numpy()
- print(img_name, latent[i][img_name[:-4]].shape)
-
- latent_save_path = os.path.join(save_root, f'latent_gt_code{codebook_size}.pth')
- torch.save(latent, latent_save_path)
- print(f'\nLatent GT code are saved in {save_root}')
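Note on the latent-GT script removed above: for each FFHQ image (and its horizontal flip) it stores a 16×16 map of nearest-codebook indices, and the training configs point at the resulting file through `latent_gt_path`. The sketch below shows one plausible way such a file could be read back as a cross-entropy target; the image key and the flattening to 256 tokens are assumptions, not the repository's training code.

```python
import torch

latent = torch.load('./experiments/pretrained_models/vqgan/latent_gt_code1024.pth')
# latent is {'orig': {img_name: (16, 16) index array}, 'hflip': {...}}
idx_map = torch.as_tensor(latent['orig']['00000'])          # hypothetical image key
target = idx_map.view(-1).long()                            # 256 code indices in [0, 1024)

# a transformer predicting logits of shape (256, 1024) could then be supervised with:
logits = torch.randn(target.numel(), 1024)                  # placeholder predictions
loss = torch.nn.functional.cross_entropy(logits, target)
```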
diff --git a/scripts/inference_vqgan.py b/scripts/inference_vqgan.py
deleted file mode 100644
index 62644bba..00000000
--- a/scripts/inference_vqgan.py
+++ /dev/null
@@ -1,59 +0,0 @@
-import argparse
-import glob
-import numpy as np
-import os
-import cv2
-import torch
-from torchvision.transforms.functional import normalize
-from basicsr.utils import imwrite, img2tensor, tensor2img
-
-from basicsr.utils.registry import ARCH_REGISTRY
-
-if __name__ == '__main__':
- parser = argparse.ArgumentParser()
- parser.add_argument('-i', '--test_path', type=str, default='datasets/ffhq/ffhq_512')
- parser.add_argument('-o', '--save_root', type=str, default='./results/vqgan_rec')
- parser.add_argument('--codebook_size', type=int, default=1024)
- parser.add_argument('--ckpt_path', type=str, default='./experiments/pretrained_models/vqgan/net_g.pth')
- args = parser.parse_args()
-
- if args.save_root.endswith('/'): # solve when path ends with /
- args.save_root = args.save_root[:-1]
- dir_name = os.path.abspath(args.save_root)
- os.makedirs(dir_name, exist_ok=True)
-
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
- test_path = args.test_path
- save_root = args.save_root
- ckpt_path = args.ckpt_path
- codebook_size = args.codebook_size
-
- vqgan = ARCH_REGISTRY.get('VQAutoEncoder')(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',
- codebook_size=codebook_size).to(device)
- checkpoint = torch.load(ckpt_path)['params_ema']
-
- vqgan.load_state_dict(checkpoint)
- vqgan.eval()
-
- for img_path in sorted(glob.glob(os.path.join(test_path, '*.[jp][pn]g'))):
- img_name = os.path.basename(img_path)
- print(img_name)
- img = cv2.imread(img_path)
- img = img2tensor(img / 255., bgr2rgb=True, float32=True)
- normalize(img, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
- img = img.unsqueeze(0).to(device)
- with torch.no_grad():
- output = vqgan(img)[0]
- output = tensor2img(output, min_max=[-1,1])
- img = tensor2img(img, min_max=[-1,1])
- restored_img = np.concatenate([img, output], axis=1)
- restored_img = output
- del output
- torch.cuda.empty_cache()
-
- path = os.path.splitext(os.path.join(save_root, img_name))[0]
- save_path = f'{path}.png'
- imwrite(restored_img, save_path)
-
- print(f'\nAll results are saved in {save_root}')
-
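Note on the VQGAN reconstruction script removed above: the side-by-side comparison it builds is immediately overwritten, so only the reconstruction is saved, and outputs are converted from the [-1, 1] range via `tensor2img(..., min_max=[-1, 1])`. The sketch below is a rough stand-in for that conversion, not the BasicSR utility itself.

```python
import numpy as np
import torch

def to_uint8_bgr(t, min_max=(-1, 1)):
    """Rough stand-in for basicsr.utils.tensor2img: map a (3, H, W) RGB tensor in
    min_max to an 8-bit BGR array (the real util does more input checking)."""
    t = t.detach().float().clamp(*min_max)
    t = (t - min_max[0]) / (min_max[1] - min_max[0])          # -> [0, 1]
    return (t.numpy().transpose(1, 2, 0)[:, :, ::-1] * 255.0).round().astype(np.uint8)

rec = torch.tanh(torch.randn(3, 512, 512))                    # placeholder VQGAN output
print(to_uint8_bgr(rec).shape)                                 # (512, 512, 3)
```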
diff --git a/inference_codeformer.py b/scripts/video_sr.py
old mode 100644
new mode 100755
similarity index 58%
rename from inference_codeformer.py
rename to scripts/video_sr.py
index 1a38cc95..8c885995
--- a/inference_codeformer.py
+++ b/scripts/video_sr.py
@@ -2,8 +2,13 @@
import cv2
import argparse
import glob
+import sys
+
import torch
from torchvision.transforms.functional import normalize
+
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
from basicsr.utils import imwrite, img2tensor, tensor2img
from basicsr.utils.download_util import load_file_from_url
from basicsr.utils.misc import gpu_is_available, get_device
@@ -12,9 +17,6 @@
from basicsr.utils.registry import ARCH_REGISTRY
-pretrain_model_url = {
- 'restoration': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth',
-}
def set_realesrgan():
from basicsr.archs.rrdbnet_arch import RRDBNet
@@ -36,7 +38,7 @@ def set_realesrgan():
)
upsampler = RealESRGANer(
scale=2,
- model_path="https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/RealESRGAN_x2plus.pth",
+ model_path="./pretrained_models/realesrgan/RealESRGAN_x2plus.pth",
model=model,
tile=args.bg_tile,
tile_pad=40,
@@ -52,15 +54,17 @@ def set_realesrgan():
category=RuntimeWarning)
return upsampler
+
+
+
if __name__ == '__main__':
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = get_device()
parser = argparse.ArgumentParser()
- parser.add_argument('-i', '--input_path', type=str, default='./inputs/whole_imgs',
- help='Input image, video or folder. Default: inputs/whole_imgs')
+ parser.add_argument('-i', '--input_path', type=str, help='Input video')
parser.add_argument('-o', '--output_path', type=str, default=None,
- help='Output folder. Default: results/_')
+ help='Output folder')
parser.add_argument('-w', '--fidelity_weight', type=float, default=0.5,
help='Balance the quality and fidelity. Default: 0.5')
parser.add_argument('-s', '--upscale', type=int, default=2,
@@ -71,23 +75,19 @@ def set_realesrgan():
# large det_model: 'YOLOv5l', 'retinaface_resnet50'
# small det_model: 'YOLOv5n', 'retinaface_mobile0.25'
parser.add_argument('--detection_model', type=str, default='retinaface_resnet50',
- help='Face detector. Optional: retinaface_resnet50, retinaface_mobile0.25, YOLOv5l, YOLOv5n, dlib. \
+ help='Face detector. Optional: retinaface_resnet50, retinaface_mobile0.25, YOLOv5l, YOLOv5n. \
Default: retinaface_resnet50')
parser.add_argument('--bg_upsampler', type=str, default='None', help='Background upsampler. Optional: realesrgan')
parser.add_argument('--face_upsample', action='store_true', help='Face upsampler after enhancement. Default: False')
parser.add_argument('--bg_tile', type=int, default=400, help='Tile size for background sampler. Default: 400')
parser.add_argument('--suffix', type=str, default=None, help='Suffix of the restored faces. Default: None')
- parser.add_argument('--save_video_fps', type=float, default=None, help='Frame rate for saving video. Default: None')
-
+
args = parser.parse_args()
# ------------------------ input & output ------------------------
w = args.fidelity_weight
input_video = False
- if args.input_path.endswith(('jpg', 'jpeg', 'png', 'JPG', 'JPEG', 'PNG')): # input single img path
- input_img_list = [args.input_path]
- result_root = f'results/test_img_{w}'
- elif args.input_path.endswith(('mp4', 'mov', 'avi', 'MP4', 'MOV', 'AVI')): # input video path
+ if args.input_path.endswith(('mp4', 'mov', 'avi', 'MP4', 'MOV', 'AVI')): # input video path
from basicsr.utils.video_util import VideoReader, VideoWriter
input_img_list = []
vidreader = VideoReader(args.input_path)
@@ -96,17 +96,13 @@ def set_realesrgan():
input_img_list.append(image)
image = vidreader.get_frame()
audio = vidreader.get_audio()
- fps = vidreader.get_fps() if args.save_video_fps is None else args.save_video_fps
+ fps = vidreader.get_fps()
video_name = os.path.basename(args.input_path)[:-4]
- result_root = f'results/{video_name}_{w}'
+ result_root = f'./hq_results/{video_name}_{w}_{args.upscale}'
input_video = True
vidreader.close()
- else: # input img folder
- if args.input_path.endswith('/'): # solve when path ends with /
- args.input_path = args.input_path[:-1]
- # scan all the jpg and png images
- input_img_list = sorted(glob.glob(os.path.join(args.input_path, '*.[jpJP][pnPN]*[gG]')))
- result_root = f'results/{os.path.basename(args.input_path)}_{w}'
+ else:
+ raise RuntimeError("input should be mp4 file")
if not args.output_path is None: # set output path
result_root = args.output_path
@@ -135,11 +131,12 @@ def set_realesrgan():
net = ARCH_REGISTRY.get('CodeFormer')(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9,
connect_list=['32', '64', '128', '256']).to(device)
- # ckpt_path = 'weights/CodeFormer/codeformer.pth'
- ckpt_path = load_file_from_url(url=pretrain_model_url['restoration'],
- model_dir='weights/CodeFormer', progress=True, file_name=None)
+ ckpt_path = './pretrained_models/hallo2/net_g.pth'
+
checkpoint = torch.load(ckpt_path)['params_ema']
- net.load_state_dict(checkpoint)
+ m, n = net.load_state_dict(checkpoint, strict=False)
+ print("missing key: ", m)
+ assert len(n)==0
net.eval()
# ------------------ set up FaceRestoreHelper -------------------
@@ -160,96 +157,124 @@ def set_realesrgan():
save_ext='png',
use_parse=True,
device=device)
+
+ n = -1
+ input_img_list = input_img_list[:n]
+ length = len(input_img_list)
+
+ overlay = 4
+ chunk = 16
+ idx_list = []
+
+ i=0
+ j=0
+ while i < length and j < length:
+ j = min(i+chunk, length)
+ idx_list.append([i, j])
+ i = j-overlay
+
+
+ id_list = []
# -------------------- start to processing ---------------------
- for i, img_path in enumerate(input_img_list):
+ for i, idx in enumerate(idx_list):
# clean all the intermediate results to process the next image
face_helper.clean_all()
+
+ start = idx[0]
+ end = idx[1]
+
+ img_list = input_img_list[start:end]
+
+ for j, img_path in enumerate(img_list):
- if isinstance(img_path, str):
- img_name = os.path.basename(img_path)
- basename, ext = os.path.splitext(img_name)
- print(f'[{i+1}/{test_img_num}] Processing: {img_name}')
- img = cv2.imread(img_path, cv2.IMREAD_COLOR)
- else: # for video processing
- basename = str(i).zfill(6)
- img_name = f'{video_name}_{basename}' if input_video else basename
- print(f'[{i+1}/{test_img_num}] Processing: {img_name}')
- img = img_path
-
- if args.has_aligned:
- # the input faces are already cropped and aligned
- img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR)
- face_helper.is_gray = is_gray(img, threshold=10)
- if face_helper.is_gray:
- print('Grayscale input: True')
- face_helper.cropped_faces = [img]
- else:
- face_helper.read_image(img)
- # get face landmarks for each face
- num_det_faces = face_helper.get_face_landmarks_5(
- only_center_face=args.only_center_face, resize=640, eye_dist_threshold=5)
- print(f'\tdetect {num_det_faces} faces')
- # align and warp each face
- face_helper.align_warp_face()
+ if isinstance(img_path, str):
+ img_name = os.path.basename(img_path)
+ basename, ext = os.path.splitext(img_name)
+ print(f'[{j+1}/{chunk}] Processing: {img_name}')
+ img = cv2.imread(img_path, cv2.IMREAD_COLOR)
+ else: # for video processing
+ basename = str(i).zfill(4)
+ img_name = f'{video_name}_{basename}_{j}' if input_video else basename
+ print(f'[{j+1}/{chunk}] Processing: {img_name}')
+ img = img_path
+ if args.has_aligned:
+ # the input faces are already cropped and aligned
+ img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR)
+ face_helper.is_gray = is_gray(img, threshold=10)
+ if face_helper.is_gray:
+ print('Grayscale input: True')
+ face_helper.cropped_faces = [img]
+ else:
+ face_helper.read_image(img)
+ # get face landmarks for each face
+ num_det_faces = face_helper.get_face_landmarks_5(
+ only_center_face=args.only_center_face, resize=640, eye_dist_threshold=5)
+ print(f'\tdetect {num_det_faces} faces')
+ # align and warp each face
+ face_helper.align_warp_face()
+
+ crop_image = []
# face restoration for each cropped face
for idx, cropped_face in enumerate(face_helper.cropped_faces):
# prepare data
cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
- cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
-
- try:
- with torch.no_grad():
- output = net(cropped_face_t, w=w, adain=True)[0]
- restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
- del output
- torch.cuda.empty_cache()
- except Exception as error:
- print(f'\tFailed inference for CodeFormer: {error}')
- restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
+ cropped_face_t = cropped_face_t.unsqueeze(0)
+
+ crop_image.append(cropped_face_t)
+
+ assert len(crop_image)==len(img_list)
+
+ crop_image = torch.cat(crop_image, dim=0).to(device)
+ crop_image = crop_image.unsqueeze(0)
+
+ output, top_idx = net.inference(crop_image, w=w, adain=True)
+ assert output.shape==crop_image.shape
+
+ for k in range(output.shape[1]):
+ face_output = output[:, k:k+1]
+ restored_face = tensor2img(face_output.squeeze_(1), rgb2bgr=True, min_max=(-1, 1))
restored_face = restored_face.astype('uint8')
+ cropped_face = face_helper.cropped_faces[k]
face_helper.add_restored_face(restored_face, cropped_face)
+ bg_img_list = []
# paste_back
if not args.has_aligned:
- # upsample the background
- if bg_upsampler is not None:
- # Now only support RealESRGAN for upsampling background
- bg_img = bg_upsampler.enhance(img, outscale=args.upscale)[0]
- else:
- bg_img = None
+ for img in img_list:
+ # upsample the background
+ if bg_upsampler is not None:
+ # Now only support RealESRGAN for upsampling background
+ bg_img = bg_upsampler.enhance(img, outscale=args.upscale)[0]
+ else:
+ bg_img = None
+ bg_img_list.append(bg_img)
+
+
face_helper.get_inverse_affine(None)
# paste each restored face to the input image
if args.face_upsample and face_upsampler is not None:
- restored_img = face_helper.paste_faces_to_input_image(upsample_img=bg_img, draw_box=args.draw_box, face_upsampler=face_upsampler)
- else:
- restored_img = face_helper.paste_faces_to_input_image(upsample_img=bg_img, draw_box=args.draw_box)
-
- # save faces
- for idx, (cropped_face, restored_face) in enumerate(zip(face_helper.cropped_faces, face_helper.restored_faces)):
- # save cropped face
- if not args.has_aligned:
- save_crop_path = os.path.join(result_root, 'cropped_faces', f'{basename}_{idx:02d}.png')
- imwrite(cropped_face, save_crop_path)
- # save restored face
- if args.has_aligned:
- save_face_name = f'{basename}.png'
+ restored_img_list = face_helper.paste_faces_to_input_image(upsample_img_list=bg_img_list, draw_box=args.draw_box, face_upsampler=face_upsampler)
else:
- save_face_name = f'{basename}_{idx:02d}.png'
- if args.suffix is not None:
- save_face_name = f'{save_face_name[:-4]}_{args.suffix}.png'
- save_restore_path = os.path.join(result_root, 'restored_faces', save_face_name)
- imwrite(restored_face, save_restore_path)
+ restored_img_list = face_helper.paste_faces_to_input_image(upsample_img_list=bg_img_list, draw_box=args.draw_box)
+
+ torch.cuda.empty_cache()
+
+ if i!=0:
+ restored_img_list = restored_img_list[overlay:]
+
# save restored img
- if not args.has_aligned and restored_img is not None:
+ if not args.has_aligned and len(restored_img_list)!=0:
if args.suffix is not None:
- basename = f'{basename}_{args.suffix}'
- save_restore_path = os.path.join(result_root, 'final_results', f'{basename}.png')
- imwrite(restored_img, save_restore_path)
+ basename = f'{video_name}_{args.suffix}_{i}'
+ for k, restored_img in enumerate(restored_img_list):
+ kk = str(k).zfill(3)
+ save_restore_path = os.path.join(result_root, 'final_results', f'{basename}_{kk}.png')
+ imwrite(restored_img, save_restore_path)
# save enhanced video
if input_video:
@@ -257,18 +282,24 @@ def set_realesrgan():
# load images
video_frames = []
img_list = sorted(glob.glob(os.path.join(result_root, 'final_results', '*.[jp][pn]g')))
- for img_path in img_list:
- img = cv2.imread(img_path)
- video_frames.append(img)
+
+ assert len(img_list)==length, print(len(img_list), length)
+
# write images to video
- height, width = video_frames[0].shape[:2]
+ sample_img = cv2.imread(img_list[0])
+ height, width = sample_img.shape[:2]
+
if args.suffix is not None:
video_name = f'{video_name}_{args.suffix}.png'
save_restore_path = os.path.join(result_root, f'{video_name}.mp4')
+
vidwriter = VideoWriter(save_restore_path, height, width, fps, audio)
- for f in video_frames:
- vidwriter.write_frame(f)
+ for img_path in img_list:
+ print(img_path)
+ img = cv2.imread(img_path)
+ vidwriter.write_frame(img)
+
vidwriter.close()
print(f'\nAll results are saved in {result_root}')
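Note on the renamed `scripts/video_sr.py` above: the script now restores video frames in overlapping windows (`chunk = 16`, `overlay = 4`), runs each window through `net.inference(...)` as one batch, trims the last frame up front via `input_img_list[:-1]`, and drops the first `overlay` restored frames of every window after the first so overlapping frames are written only once. The window bookkeeping is reproduced below with an arbitrary 40-frame clip to show that each frame is saved exactly once.

```python
def make_windows(length, chunk=16, overlay=4):
    """Overlapping windows exactly as built in scripts/video_sr.py."""
    idx_list, i, j = [], 0, 0
    while i < length and j < length:
        j = min(i + chunk, length)
        idx_list.append([i, j])
        i = j - overlay
    return idx_list

def saved_indices(length, chunk=16, overlay=4):
    """Frame indices that reach the output after trimming the overlap of later windows."""
    saved = []
    for n, (start, end) in enumerate(make_windows(length, chunk, overlay)):
        frames = list(range(start, end))
        if n != 0:
            frames = frames[overlay:]     # drop re-processed overlap frames
        saved.extend(frames)
    return saved

print(make_windows(40))                       # [[0, 16], [12, 28], [24, 40]]
print(saved_indices(40) == list(range(40)))   # True: every frame saved exactly once
```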
diff --git a/web-demos/hugging_face/app.py b/web-demos/hugging_face/app.py
deleted file mode 100644
index c614e7c8..00000000
--- a/web-demos/hugging_face/app.py
+++ /dev/null
@@ -1,283 +0,0 @@
-"""
-This file is used for deploying hugging face demo:
-https://huggingface.co/spaces/sczhou/CodeFormer
-"""
-
-import sys
-sys.path.append('CodeFormer')
-import os
-import cv2
-import torch
-import torch.nn.functional as F
-import gradio as gr
-
-from torchvision.transforms.functional import normalize
-
-from basicsr.archs.rrdbnet_arch import RRDBNet
-from basicsr.utils import imwrite, img2tensor, tensor2img
-from basicsr.utils.download_util import load_file_from_url
-from basicsr.utils.misc import gpu_is_available, get_device
-from basicsr.utils.realesrgan_utils import RealESRGANer
-from basicsr.utils.registry import ARCH_REGISTRY
-
-from facelib.utils.face_restoration_helper import FaceRestoreHelper
-from facelib.utils.misc import is_gray
-
-
-os.system("pip freeze")
-
-pretrain_model_url = {
- 'codeformer': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth',
- 'detection': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_Resnet50_Final.pth',
- 'parsing': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth',
- 'realesrgan': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/RealESRGAN_x2plus.pth'
-}
-# download weights
-if not os.path.exists('CodeFormer/weights/CodeFormer/codeformer.pth'):
- load_file_from_url(url=pretrain_model_url['codeformer'], model_dir='CodeFormer/weights/CodeFormer', progress=True, file_name=None)
-if not os.path.exists('CodeFormer/weights/facelib/detection_Resnet50_Final.pth'):
- load_file_from_url(url=pretrain_model_url['detection'], model_dir='CodeFormer/weights/facelib', progress=True, file_name=None)
-if not os.path.exists('CodeFormer/weights/facelib/parsing_parsenet.pth'):
- load_file_from_url(url=pretrain_model_url['parsing'], model_dir='CodeFormer/weights/facelib', progress=True, file_name=None)
-if not os.path.exists('CodeFormer/weights/realesrgan/RealESRGAN_x2plus.pth'):
- load_file_from_url(url=pretrain_model_url['realesrgan'], model_dir='CodeFormer/weights/realesrgan', progress=True, file_name=None)
-
-# download images
-torch.hub.download_url_to_file(
- 'https://replicate.com/api/models/sczhou/codeformer/files/fa3fe3d1-76b0-4ca8-ac0d-0a925cb0ff54/06.png',
- '01.png')
-torch.hub.download_url_to_file(
- 'https://replicate.com/api/models/sczhou/codeformer/files/a1daba8e-af14-4b00-86a4-69cec9619b53/04.jpg',
- '02.jpg')
-torch.hub.download_url_to_file(
- 'https://replicate.com/api/models/sczhou/codeformer/files/542d64f9-1712-4de7-85f7-3863009a7c3d/03.jpg',
- '03.jpg')
-torch.hub.download_url_to_file(
- 'https://replicate.com/api/models/sczhou/codeformer/files/a11098b0-a18a-4c02-a19a-9a7045d68426/010.jpg',
- '04.jpg')
-torch.hub.download_url_to_file(
- 'https://replicate.com/api/models/sczhou/codeformer/files/7cf19c2c-e0cf-4712-9af8-cf5bdbb8d0ee/012.jpg',
- '05.jpg')
-
-def imread(img_path):
- img = cv2.imread(img_path)
- img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
- return img
-
-# set enhancer with RealESRGAN
-def set_realesrgan():
- # half = True if torch.cuda.is_available() else False
- half = True if gpu_is_available() else False
- model = RRDBNet(
- num_in_ch=3,
- num_out_ch=3,
- num_feat=64,
- num_block=23,
- num_grow_ch=32,
- scale=2,
- )
- upsampler = RealESRGANer(
- scale=2,
- model_path="CodeFormer/weights/realesrgan/RealESRGAN_x2plus.pth",
- model=model,
- tile=400,
- tile_pad=40,
- pre_pad=0,
- half=half,
- )
- return upsampler
-
-upsampler = set_realesrgan()
-# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-device = get_device()
-codeformer_net = ARCH_REGISTRY.get("CodeFormer")(
- dim_embd=512,
- codebook_size=1024,
- n_head=8,
- n_layers=9,
- connect_list=["32", "64", "128", "256"],
-).to(device)
-ckpt_path = "CodeFormer/weights/CodeFormer/codeformer.pth"
-checkpoint = torch.load(ckpt_path)["params_ema"]
-codeformer_net.load_state_dict(checkpoint)
-codeformer_net.eval()
-
-os.makedirs('output', exist_ok=True)
-
-def inference(image, background_enhance, face_upsample, upscale, codeformer_fidelity):
- """Run a single prediction on the model"""
- try: # global try
- # take the default setting for the demo
- has_aligned = False
- only_center_face = False
- draw_box = False
- detection_model = "retinaface_resnet50"
- print('Inp:', image, background_enhance, face_upsample, upscale, codeformer_fidelity)
-
- img = cv2.imread(str(image), cv2.IMREAD_COLOR)
- print('\timage size:', img.shape)
-
- upscale = int(upscale) # convert type to int
- if upscale > 4: # avoid memory exceeded due to too large upscale
- upscale = 4
- if upscale > 2 and max(img.shape[:2])>1000: # avoid memory exceeded due to too large img resolution
- upscale = 2
- if max(img.shape[:2]) > 1500: # avoid memory exceeded due to too large img resolution
- upscale = 1
- background_enhance = False
- face_upsample = False
-
- face_helper = FaceRestoreHelper(
- upscale,
- face_size=512,
- crop_ratio=(1, 1),
- det_model=detection_model,
- save_ext="png",
- use_parse=True,
- device=device,
- )
- bg_upsampler = upsampler if background_enhance else None
- face_upsampler = upsampler if face_upsample else None
-
- if has_aligned:
- # the input faces are already cropped and aligned
- img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR)
- face_helper.is_gray = is_gray(img, threshold=5)
- if face_helper.is_gray:
- print('\tgrayscale input: True')
- face_helper.cropped_faces = [img]
- else:
- face_helper.read_image(img)
- # get face landmarks for each face
- num_det_faces = face_helper.get_face_landmarks_5(
- only_center_face=only_center_face, resize=640, eye_dist_threshold=5
- )
- print(f'\tdetect {num_det_faces} faces')
- # align and warp each face
- face_helper.align_warp_face()
-
- # face restoration for each cropped face
- for idx, cropped_face in enumerate(face_helper.cropped_faces):
- # prepare data
- cropped_face_t = img2tensor(
- cropped_face / 255.0, bgr2rgb=True, float32=True
- )
- normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
- cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
-
- try:
- with torch.no_grad():
- output = codeformer_net(
- cropped_face_t, w=codeformer_fidelity, adain=True
- )[0]
- restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
- del output
- torch.cuda.empty_cache()
- except RuntimeError as error:
- print(f"Failed inference for CodeFormer: {error}")
- restored_face = tensor2img(
- cropped_face_t, rgb2bgr=True, min_max=(-1, 1)
- )
-
- restored_face = restored_face.astype("uint8")
- face_helper.add_restored_face(restored_face)
-
- # paste_back
- if not has_aligned:
- # upsample the background
- if bg_upsampler is not None:
- # Now only support RealESRGAN for upsampling background
- bg_img = bg_upsampler.enhance(img, outscale=upscale)[0]
- else:
- bg_img = None
- face_helper.get_inverse_affine(None)
- # paste each restored face to the input image
- if face_upsample and face_upsampler is not None:
- restored_img = face_helper.paste_faces_to_input_image(
- upsample_img=bg_img,
- draw_box=draw_box,
- face_upsampler=face_upsampler,
- )
- else:
- restored_img = face_helper.paste_faces_to_input_image(
- upsample_img=bg_img, draw_box=draw_box
- )
-
- # save restored img
- save_path = f'output/out.png'
- imwrite(restored_img, str(save_path))
-
- restored_img = cv2.cvtColor(restored_img, cv2.COLOR_BGR2RGB)
- return restored_img, save_path
- except Exception as error:
- print('Global exception', error)
- return None, None
-
-
-title = "CodeFormer: Robust Face Restoration and Enhancement Network"
-description = r"""
-Official Gradio demo for Towards Robust Blind Face Restoration with Codebook Lookup Transformer (NeurIPS 2022).
-🔥 CodeFormer is a robust face restoration algorithm for old photos or AI-generated faces.
-🤗 Try CodeFormer for improved stable-diffusion generation!
-"""
-article = r"""
-If CodeFormer is helpful, please help to ⭐ the Github Repo. Thanks!
-[![GitHub Stars](https://img.shields.io/github/stars/sczhou/CodeFormer?style=social)](https://github.com/sczhou/CodeFormer)
-
----
-
-📝 **Citation**
-
-If our work is useful for your research, please consider citing:
-```bibtex
-@inproceedings{zhou2022codeformer,
- author = {Zhou, Shangchen and Chan, Kelvin C.K. and Li, Chongyi and Loy, Chen Change},
- title = {Towards Robust Blind Face Restoration with Codebook Lookup TransFormer},
- booktitle = {NeurIPS},
- year = {2022}
-}
-```
-
-📋 **License**
-
-This project is licensed under S-Lab License 1.0.
-Redistribution and use for non-commercial purposes should follow this license.
-
-📧 **Contact**
-
-If you have any questions, please feel free to reach me out at shangchenzhou@gmail.com.
-
-
- đ¤ Find Me:
-
-
-
-
-
-"""
-
-demo = gr.Interface(
- inference, [
- gr.inputs.Image(type="filepath", label="Input"),
- gr.inputs.Checkbox(default=True, label="Background_Enhance"),
- gr.inputs.Checkbox(default=True, label="Face_Upsample"),
- gr.inputs.Number(default=2, label="Rescaling_Factor (up to 4)"),
- gr.Slider(0, 1, value=0.5, step=0.01, label='Codeformer_Fidelity (0 for better quality, 1 for better identity)')
- ], [
- gr.outputs.Image(type="numpy", label="Output"),
- gr.outputs.File(label="Download the output")
- ],
- title=title,
- description=description,
- article=article,
- examples=[
- ['01.png', True, True, 2, 0.7],
- ['02.jpg', True, True, 2, 0.7],
- ['03.jpg', True, True, 2, 0.7],
- ['04.jpg', True, True, 2, 0.1],
- ['05.jpg', True, True, 2, 0.1]
- ]
- )
-
-demo.queue(concurrency_count=2)
-demo.launch()
\ No newline at end of file
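Note on the Hugging Face demo removed above: it guards against out-of-memory on the shared Space by clamping the upscale factor from the input resolution. That guard is condensed into one function below; the thresholds are the demo's, while packing the three outputs into a single helper is an assumption for illustration.

```python
def clamp_upscale(upscale, height, width):
    """Memory guard from the deleted Hugging Face demo: large inputs get smaller upscales."""
    upscale = min(int(upscale), 4)                  # never more than 4x
    background_enhance, face_upsample = True, True
    if upscale > 2 and max(height, width) > 1000:   # big image: cap at 2x
        upscale = 2
    if max(height, width) > 1500:                   # very big image: no upscaling, no extras
        upscale = 1
        background_enhance = False
        face_upsample = False
    return upscale, background_enhance, face_upsample

print(clamp_upscale(4, 2000, 1200))                 # (1, False, False)
```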
diff --git a/web-demos/replicate/cog.yaml b/web-demos/replicate/cog.yaml
deleted file mode 100644
index 3f458969..00000000
--- a/web-demos/replicate/cog.yaml
+++ /dev/null
@@ -1,30 +0,0 @@
-"""
-This file is used for deploying replicate demo:
-https://replicate.com/sczhou/codeformer
-"""
-
-build:
- gpu: true
- cuda: "11.3"
- python_version: "3.8"
- system_packages:
- - "libgl1-mesa-glx"
- - "libglib2.0-0"
- python_packages:
- - "ipython==8.4.0"
- - "future==0.18.2"
- - "lmdb==1.3.0"
- - "scikit-image==0.19.3"
- - "torch==1.11.0 --extra-index-url=https://download.pytorch.org/whl/cu113"
- - "torchvision==0.12.0 --extra-index-url=https://download.pytorch.org/whl/cu113"
- - "scipy==1.9.0"
- - "gdown==4.5.1"
- - "pyyaml==6.0"
- - "tb-nightly==2.11.0a20220906"
- - "tqdm==4.64.1"
- - "yapf==0.32.0"
- - "lpips==0.1.4"
- - "Pillow==9.2.0"
- - "opencv-python==4.6.0.66"
-
-predict: "predict.py:Predictor"
diff --git a/web-demos/replicate/predict.py b/web-demos/replicate/predict.py
deleted file mode 100644
index 1b73cbcd..00000000
--- a/web-demos/replicate/predict.py
+++ /dev/null
@@ -1,191 +0,0 @@
-"""
-This file is used for deploying replicate demo:
-https://replicate.com/sczhou/codeformer
-running: cog predict -i image=@inputs/whole_imgs/04.jpg -i codeformer_fidelity=0.5 -i upscale=2
-push: cog push r8.im/sczhou/codeformer
-"""
-
-import tempfile
-import cv2
-import torch
-from torchvision.transforms.functional import normalize
-try:
- from cog import BasePredictor, Input, Path
-except Exception:
- print('please install cog package')
-
-from basicsr.archs.rrdbnet_arch import RRDBNet
-from basicsr.utils import imwrite, img2tensor, tensor2img
-from basicsr.utils.realesrgan_utils import RealESRGANer
-from basicsr.utils.misc import gpu_is_available
-from basicsr.utils.registry import ARCH_REGISTRY
-
-from facelib.utils.face_restoration_helper import FaceRestoreHelper
-
-class Predictor(BasePredictor):
- def setup(self):
- """Load the model into memory to make running multiple predictions efficient"""
- self.device = "cuda:0"
- self.upsampler = set_realesrgan()
- self.net = ARCH_REGISTRY.get("CodeFormer")(
- dim_embd=512,
- codebook_size=1024,
- n_head=8,
- n_layers=9,
- connect_list=["32", "64", "128", "256"],
- ).to(self.device)
- ckpt_path = "weights/CodeFormer/codeformer.pth"
- checkpoint = torch.load(ckpt_path)[
- "params_ema"
- ] # update file permission if cannot load
- self.net.load_state_dict(checkpoint)
- self.net.eval()
-
- def predict(
- self,
- image: Path = Input(description="Input image"),
- codeformer_fidelity: float = Input(
- default=0.5,
- ge=0,
- le=1,
- description="Balance the quality (lower number) and fidelity (higher number).",
- ),
- background_enhance: bool = Input(
- description="Enhance background image with Real-ESRGAN", default=True
- ),
- face_upsample: bool = Input(
- description="Upsample restored faces for high-resolution AI-created images",
- default=True,
- ),
- upscale: int = Input(
- description="The final upsampling scale of the image",
- default=2,
- ),
- ) -> Path:
- """Run a single prediction on the model"""
-
- # take the default setting for the demo
- has_aligned = False
- only_center_face = False
- draw_box = False
- detection_model = "retinaface_resnet50"
-
- self.face_helper = FaceRestoreHelper(
- upscale,
- face_size=512,
- crop_ratio=(1, 1),
- det_model=detection_model,
- save_ext="png",
- use_parse=True,
- device=self.device,
- )
-
- bg_upsampler = self.upsampler if background_enhance else None
- face_upsampler = self.upsampler if face_upsample else None
-
- img = cv2.imread(str(image), cv2.IMREAD_COLOR)
-
- if has_aligned:
- # the input faces are already cropped and aligned
- img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR)
- self.face_helper.cropped_faces = [img]
- else:
- self.face_helper.read_image(img)
- # get face landmarks for each face
- num_det_faces = self.face_helper.get_face_landmarks_5(
- only_center_face=only_center_face, resize=640, eye_dist_threshold=5
- )
- print(f"\tdetect {num_det_faces} faces")
- # align and warp each face
- self.face_helper.align_warp_face()
-
- # face restoration for each cropped face
- for idx, cropped_face in enumerate(self.face_helper.cropped_faces):
- # prepare data
- cropped_face_t = img2tensor(
- cropped_face / 255.0, bgr2rgb=True, float32=True
- )
- normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
- cropped_face_t = cropped_face_t.unsqueeze(0).to(self.device)
-
- try:
- with torch.no_grad():
- output = self.net(
- cropped_face_t, w=codeformer_fidelity, adain=True
- )[0]
- restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
- del output
- torch.cuda.empty_cache()
- except Exception as error:
- print(f"\tFailed inference for CodeFormer: {error}")
- restored_face = tensor2img(
- cropped_face_t, rgb2bgr=True, min_max=(-1, 1)
- )
-
- restored_face = restored_face.astype("uint8")
- self.face_helper.add_restored_face(restored_face)
-
- # paste_back
- if not has_aligned:
- # upsample the background
- if bg_upsampler is not None:
- # Now only support RealESRGAN for upsampling background
- bg_img = bg_upsampler.enhance(img, outscale=upscale)[0]
- else:
- bg_img = None
- self.face_helper.get_inverse_affine(None)
- # paste each restored face to the input image
- if face_upsample and face_upsampler is not None:
- restored_img = self.face_helper.paste_faces_to_input_image(
- upsample_img=bg_img,
- draw_box=draw_box,
- face_upsampler=face_upsampler,
- )
- else:
- restored_img = self.face_helper.paste_faces_to_input_image(
- upsample_img=bg_img, draw_box=draw_box
- )
-
- # save restored img
- out_path = Path(tempfile.mkdtemp()) / 'output.png'
- imwrite(restored_img, str(out_path))
-
- return out_path
-
-
-def imread(img_path):
- img = cv2.imread(img_path)
- img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
- return img
-
-
-def set_realesrgan():
- # if not torch.cuda.is_available(): # CPU
- if not gpu_is_available(): # CPU
- import warnings
-
- warnings.warn(
- "The unoptimized RealESRGAN is slow on CPU. We do not use it. "
- "If you really want to use it, please modify the corresponding codes.",
- category=RuntimeWarning,
- )
- upsampler = None
- else:
- model = RRDBNet(
- num_in_ch=3,
- num_out_ch=3,
- num_feat=64,
- num_block=23,
- num_grow_ch=32,
- scale=2,
- )
- upsampler = RealESRGANer(
- scale=2,
- model_path="./weights/realesrgan/RealESRGAN_x2plus.pth",
- model=model,
- tile=400,
- tile_pad=40,
- pre_pad=0,
- half=True,
- )
- return upsampler
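Note on the Replicate wrapper removed above: `cog.yaml` pins the build environment while `predict.py` loads weights once in `setup()` and serves each request through `predict()`. A stripped-down skeleton of that lifecycle is sketched below using only the `cog` symbols the deleted file imports; the model is a placeholder, not the actual restoration pipeline, and running it assumes the `cog` package is installed.

```python
from cog import BasePredictor, Input, Path   # same imports as the deleted predict.py
import torch

class Predictor(BasePredictor):
    def setup(self):
        # one-time work per container: pick the device and load weights
        self.device = "cuda:0"
        self.net = torch.nn.Identity().to(self.device)   # placeholder for the real model

    def predict(
        self,
        image: Path = Input(description="Input image"),
        codeformer_fidelity: float = Input(default=0.5, ge=0, le=1,
                                           description="quality (low) vs fidelity (high)"),
    ) -> Path:
        # per-request work: read the input, run the model, return an output file path
        print(f"would restore {image} with w={codeformer_fidelity}")
        return image
```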
diff --git a/weights/CodeFormer/.gitkeep b/weights/CodeFormer/.gitkeep
deleted file mode 100644
index e69de29b..00000000
diff --git a/weights/README.md b/weights/README.md
deleted file mode 100644
index 67ad334b..00000000
--- a/weights/README.md
+++ /dev/null
@@ -1,3 +0,0 @@
-# Weights
-
-Put the downloaded pre-trained models to this folder.
\ No newline at end of file
diff --git a/weights/facelib/.gitkeep b/weights/facelib/.gitkeep
deleted file mode 100644
index e69de29b..00000000