diff --git a/README.md b/README.md index eb05c84c..15a1f397 100644 --- a/README.md +++ b/README.md @@ -1,167 +1,111 @@ -

- -

+Hallo2: Long-Duration and High-Resolution Audio-driven Portrait Image Animation

-## Towards Robust Blind Face Restoration with Codebook Lookup Transformer (NeurIPS 2022) +## ⚙ī¸ Installation -[Paper](https://arxiv.org/abs/2206.11253) | [Project Page](https://shangchenzhou.com/projects/CodeFormer/) | [Video](https://youtu.be/d3VDpkXlueI) +- System requirement: Ubuntu 20.04/Ubuntu 22.04, Cuda 12.1 +- Tested GPUs: A100 +Create conda environment: -google colab logo [![Hugging Face](https://img.shields.io/badge/Demo-%F0%9F%A4%97%20Hugging%20Face-blue)](https://huggingface.co/spaces/sczhou/CodeFormer) [![Replicate](https://img.shields.io/badge/Demo-%F0%9F%9A%80%20Replicate-blue)](https://replicate.com/sczhou/codeformer) [![OpenXLab](https://img.shields.io/badge/Demo-%F0%9F%90%BC%20OpenXLab-blue)](https://openxlab.org.cn/apps/detail/ShangchenZhou/CodeFormer) ![Visitors](https://api.infinitescript.com/badgen/count?name=sczhou/CodeFormer<ext=Visitors) - - -[Shangchen Zhou](https://shangchenzhou.com/), [Kelvin C.K. Chan](https://ckkelvinchan.github.io/), [Chongyi Li](https://li-chongyi.github.io/), [Chen Change Loy](https://www.mmlab-ntu.com/person/ccloy/) - -S-Lab, Nanyang Technological University - - - - -:star: If CodeFormer is helpful to your images or projects, please help star this repo. Thanks! :hugs: - - -### Update -- **2023.07.20**: Integrated to :panda_face: [OpenXLab](https://openxlab.org.cn/apps). Try out online demo! [![OpenXLab](https://img.shields.io/badge/Demo-%F0%9F%90%BC%20OpenXLab-blue)](https://openxlab.org.cn/apps/detail/ShangchenZhou/CodeFormer) -- **2023.04.19**: :whale: Training codes and config files are public available now. -- **2023.04.09**: Add features of inpainting and colorization for cropped and aligned face images. -- **2023.02.10**: Include `dlib` as a new face detector option, it produces more accurate face identity. -- **2022.10.05**: Support video input `--input_path [YOUR_VIDEO.mp4]`. Try it to enhance your videos! :clapper: -- **2022.09.14**: Integrated to :hugs: [Hugging Face](https://huggingface.co/spaces). Try out online demo! [![Hugging Face](https://img.shields.io/badge/Demo-%F0%9F%A4%97%20Hugging%20Face-blue)](https://huggingface.co/spaces/sczhou/CodeFormer) -- **2022.09.09**: Integrated to :rocket: [Replicate](https://replicate.com/explore). Try out online demo! 
[![Replicate](https://img.shields.io/badge/Demo-%F0%9F%9A%80%20Replicate-blue)](https://replicate.com/sczhou/codeformer) -- [**More**](docs/history_changelog.md) - -### TODO -- [x] Add training code and config files -- [x] Add checkpoint and script for face inpainting -- [x] Add checkpoint and script for face colorization -- [x] ~~Add background image enhancement~~ - -#### :panda_face: Try Enhancing Old Photos / Fixing AI-arts -[](https://imgsli.com/MTI3NTE2) [](https://imgsli.com/MTI3NTE1) [](https://imgsli.com/MTI3NTIw) - -#### Face Restoration - - - - -#### Face Color Enhancement and Restoration - - - -#### Face Inpainting - - - - - -### Dependencies and Installation - -- Pytorch >= 1.7.1 -- CUDA >= 10.1 -- Other required packages in `requirements.txt` +```bash + conda create -n hallo python=3.10 + conda activate hallo ``` -# git clone this repository -git clone https://github.com/sczhou/CodeFormer -cd CodeFormer - -# create new anaconda env -conda create -n codeformer python=3.8 -y -conda activate codeformer - -# install python dependencies -pip3 install -r requirements.txt -python basicsr/setup.py develop -conda install -c conda-forge dlib (only for face detection or cropping with dlib) -``` - -### Quick Inference +Install packages with `pip` -#### Download Pre-trained Models: -Download the facelib and dlib pretrained models from [[Releases](https://github.com/sczhou/CodeFormer/releases/tag/v0.1.0) | [Google Drive](https://drive.google.com/drive/folders/1b_3qwrzY_kTQh0-SnBoGBgOrJ_PLZSKm?usp=sharing) | [OneDrive](https://entuedu-my.sharepoint.com/:f:/g/personal/s200094_e_ntu_edu_sg/EvDxR7FcAbZMp_MA9ouq7aQB8XTppMb3-T0uGZ_2anI2mg?e=DXsJFo)] to the `weights/facelib` folder. You can manually download the pretrained models OR download by running the following command: -``` -python scripts/download_pretrained_models.py facelib -python scripts/download_pretrained_models.py dlib (only for dlib face detector) -``` - -Download the CodeFormer pretrained models from [[Releases](https://github.com/sczhou/CodeFormer/releases/tag/v0.1.0) | [Google Drive](https://drive.google.com/drive/folders/1CNNByjHDFt0b95q54yMVp6Ifo5iuU6QS?usp=sharing) | [OneDrive](https://entuedu-my.sharepoint.com/:f:/g/personal/s200094_e_ntu_edu_sg/EoKFj4wo8cdIn2-TY2IV6CYBhZ0pIG4kUOeHdPR_A5nlbg?e=AO8UN9)] to the `weights/CodeFormer` folder. You can manually download the pretrained models OR download by running the following command: -``` -python scripts/download_pretrained_models.py CodeFormer +```bash + pip install torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2 --index-url https://download.pytorch.org/whl/cu118 + pip install -r requirements.txt ``` -#### Prepare Testing Data: -You can put the testing images in the `inputs/TestWhole` folder. If you would like to test on cropped and aligned faces, you can put them in the `inputs/cropped_faces` folder. You can get the cropped and aligned faces by running the following command: +Besides, ffmpeg is also needed: +```bash + apt-get install ffmpeg ``` -# you may need to install dlib via: conda install -c conda-forge dlib -python scripts/crop_align_face.py -i [input folder] -o [output folder] -``` - -#### Testing: -[Note] If you want to compare CodeFormer in your paper, please run the following command indicating `--has_aligned` (for cropped and aligned face), as the command for the whole image will involve a process of face-background fusion that may damage hair texture on the boundary, which leads to unfair comparison. 
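After installation, a quick sanity check can confirm that the CUDA-enabled PyTorch build and ffmpeg are usable (a minimal sketch; on a machine with a working GPU setup the first command should print the pinned 2.2.2 version and `True`):

```bash
# Check that the installed PyTorch build can see the GPU
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
# Check that ffmpeg is on the PATH
ffmpeg -version
```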
+### đŸ“Ĩ Download Pretrained Models
-Fidelity weight *w* lays in [0, 1]. Generally, smaller *w* tends to produce a higher-quality result, while larger *w* yields a higher-fidelity result. The results will be saved in the `results` folder.
+All pretrained models required for inference can be downloaded from our [HuggingFace repo](https://huggingface.co/fudan-generative-ai/hallo2).
+Clone them into the `${PROJECT_ROOT}/pretrained_models` directory with the command below:
-🧑đŸģ Face Restoration (cropped and aligned face)
-```
-# For cropped and aligned faces (512x512)
-python inference_codeformer.py -w 0.5 --has_aligned --input_path [image folder]|[image path]
+```shell
+git lfs install
+git clone https://huggingface.co/fudan-generative-ai/hallo2 pretrained_models
```
-:framed_picture: Whole Image Enhancement
-```
-# For whole image
-# Add '--bg_upsampler realesrgan' to enhance the background regions with Real-ESRGAN
-# Add '--face_upsample' to further upsample restorated face with Real-ESRGAN
-python inference_codeformer.py -w 0.7 --input_path [image folder]|[image path]
-```
+Alternatively, you can download them separately from their source repos:
+- [hallo2](https://huggingface.co/fudan-generative-ai/hallo2/blob/main/hallo2/net_g.pth): our video super-resolution checkpoint.
+- [facelib](https://github.com/sczhou/CodeFormer/releases/tag/v0.1.0): pretrained face detection and parsing models.
+- [realesrgan](https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/RealESRGAN_x2plus.pth): background upsampling model.
+- [CodeFormer](https://github.com/sczhou/CodeFormer/releases/download/v0.1.0): the pretrained [CodeFormer](https://github.com/sczhou/CodeFormer) model; only needed if you want to train our video super-resolution model from scratch.
-:clapper: Video Enhancement
-```
-# For Windows/Mac users, please install ffmpeg first
-conda install -c conda-forge ffmpeg
-```
-```
-# For video clips
-# Video path should end with '.mp4'|'.mov'|'.avi'
-python inference_codeformer.py --bg_upsampler realesrgan --face_upsample -w 1.0 --input_path [video path]
-```
+Finally, these pretrained models should be organized as follows:
-🌈 Face Colorization (cropped and aligned face)
-```
-# For cropped and aligned faces (512x512)
-# Colorize black and white or faded photo
-python inference_colorization.py --input_path [image folder]|[image path]
+```text
+./pretrained_models/
+|-- CodeFormer/
+|   |-- codeformer.pth
+|   `-- vqgan_code1024.pth
+|-- facelib
+|   |-- detection_mobilenet0.25_Final.pth
+|   |-- detection_Resnet50_Final.pth
+|   |-- parsing_parsenet.pth
+|   |-- yolov5l-face.pth
+|   `-- yolov5n-face.pth
+|-- hallo2
+|   `-- net_g.pth
+`-- realesrgan
+    `-- RealESRGAN_x2plus.pth
```
-🎨 Face Inpainting (cropped and aligned face)
-```
-# For cropped and aligned faces (512x512)
-# Inputs could be masked by white brush using an image editing app (e.g., Photoshop)
-# (check out the examples in inputs/masked_faces)
-python inference_inpainting.py --input_path [image folder]|[image path]
-```
-### Training:
-The training commands can be found in the documents: [English](docs/train.md) **|** [įŽ€äŊ“中文](docs/train_CN.md).
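If `git lfs` is unavailable, the same Hugging Face repository can also be fetched with the Hugging Face CLI (an alternative sketch, assuming a recent `huggingface_hub` release that ships the `huggingface-cli download` command):

```bash
# Install the CLI and download the whole model repo into ./pretrained_models
pip install -U "huggingface_hub[cli]"
huggingface-cli download fudan-generative-ai/hallo2 --local-dir pretrained_models
```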
+### 🎮 Run Inference
+#### High-Resolution Animation
+Simply run `scripts/video_sr.py` and pass `input_path` and `output_path`:
-### Citation
-If our work is useful for your research, please consider citing:
+```bash
+python scripts/video_sr.py --input_path [input_video] --output_path [output_dir] --bg_upsampler realesrgan --face_upsample -w 1 -s 4
+```
- @inproceedings{zhou2022codeformer,
- author = {Zhou, Shangchen and Chan, Kelvin C.K. and Li, Chongyi and Loy, Chen Change},
- title = {Towards Robust Blind Face Restoration with Codebook Lookup TransFormer},
- booktitle = {NeurIPS},
- year = {2022}
- }
+Animation results will be saved in `output_dir`.
-### License
+For more options:
-This project is licensed under NTU S-Lab License 1.0. Redistribution and use should follow this license.
+```shell
+usage: video_sr.py [-h] [-i INPUT_PATH] [-o OUTPUT_PATH] [-w FIDELITY_WEIGHT] [-s UPSCALE] [--has_aligned] [--only_center_face] [--draw_box]
+                   [--detection_model DETECTION_MODEL] [--bg_upsampler BG_UPSAMPLER] [--face_upsample] [--bg_tile BG_TILE] [--suffix SUFFIX]
-### Acknowledgement
+options:
+  -h, --help            show this help message and exit
+  -i INPUT_PATH, --input_path INPUT_PATH
+                        Input video
+  -o OUTPUT_PATH, --output_path OUTPUT_PATH
+                        Output folder.
+  -w FIDELITY_WEIGHT, --fidelity_weight FIDELITY_WEIGHT
+                        Balance the quality and fidelity. Default: 0.5
+  -s UPSCALE, --upscale UPSCALE
+                        The final upsampling scale of the image. Default: 2
+  --has_aligned         Input are cropped and aligned faces. Default: False
+  --only_center_face    Only restore the center face. Default: False
+  --draw_box            Draw the bounding box for the detected faces. Default: False
+  --detection_model DETECTION_MODEL
+                        Face detector. Optional: retinaface_resnet50, retinaface_mobile0.25, YOLOv5l, YOLOv5n. Default: retinaface_resnet50
+  --bg_upsampler BG_UPSAMPLER
+                        Background upsampler. Optional: realesrgan
+  --face_upsample       Face upsampler after enhancement. Default: False
+  --bg_tile BG_TILE     Tile size for background sampler. Default: 400
+  --suffix SUFFIX       Suffix of the restored faces. Default: None
+```
-This project is based on [BasicSR](https://github.com/XPixelGroup/BasicSR). Some codes are brought from [Unleashing Transformers](https://github.com/samb-t/unleashing-transformers), [YOLOv5-face](https://github.com/deepcam-cn/yolov5-face), and [FaceXLib](https://github.com/xinntao/facexlib). We also adopt [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN) to support background image enhancement. Thanks for their awesome works.
+## Training
+#### Prepare data for training
+We use the VFHQ dataset for training; you can download it from its [homepage](https://liangbinxie.github.io/projects/vfhq/). Then update `dataroot_gt` in `./configs/train/video_sr.yaml`.
-### Contact
-If you have any questions, please feel free to reach me out at `shangchenzhou@gmail.com`.
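For reference, the `VFHQBlindDataset` added in this patch scans `dataroot_gt` for sub-directories of extracted PNG frames and keeps only clips with more than `video_length` frames, so the training data is expected to be organized roughly as below (clip and frame names are illustrative):

```text
${dataroot_gt}/
|-- clip_00001/
|   |-- 00000000.png
|   |-- 00000001.png
|   `-- ...
`-- clip_00002/
    |-- 00000000.png
    `-- ...
```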
+#### training +Start training with the following command: +```shell +python -m torch.distributed.launch --nproc_per_node=8 --master_port=4652 \ +basicsr/train.py -opt ./configs/train/video_sr.yaml \ +--launcher pytorch +``` diff --git a/assets/CodeFormer_logo.png b/assets/CodeFormer_logo.png deleted file mode 100644 index 024cb724..00000000 Binary files a/assets/CodeFormer_logo.png and /dev/null differ diff --git a/assets/color_enhancement_result1.png b/assets/color_enhancement_result1.png deleted file mode 100644 index 34433db6..00000000 Binary files a/assets/color_enhancement_result1.png and /dev/null differ diff --git a/assets/color_enhancement_result2.png b/assets/color_enhancement_result2.png deleted file mode 100644 index 228690ac..00000000 Binary files a/assets/color_enhancement_result2.png and /dev/null differ diff --git a/assets/framework.png b/assets/framework.png new file mode 100755 index 00000000..2ed98dbb Binary files /dev/null and b/assets/framework.png differ diff --git a/assets/framework_1.jpg b/assets/framework_1.jpg new file mode 100755 index 00000000..72de2f86 Binary files /dev/null and b/assets/framework_1.jpg differ diff --git a/assets/framework_2.jpg b/assets/framework_2.jpg new file mode 100755 index 00000000..bf9344c2 Binary files /dev/null and b/assets/framework_2.jpg differ diff --git a/assets/imgsli_1.jpg b/assets/imgsli_1.jpg deleted file mode 100644 index 313438a6..00000000 Binary files a/assets/imgsli_1.jpg and /dev/null differ diff --git a/assets/imgsli_2.jpg b/assets/imgsli_2.jpg deleted file mode 100644 index 42dd7f43..00000000 Binary files a/assets/imgsli_2.jpg and /dev/null differ diff --git a/assets/imgsli_3.jpg b/assets/imgsli_3.jpg deleted file mode 100644 index c3f67d9d..00000000 Binary files a/assets/imgsli_3.jpg and /dev/null differ diff --git a/assets/inpainting_result1.png b/assets/inpainting_result1.png deleted file mode 100644 index 2c6fa68a..00000000 Binary files a/assets/inpainting_result1.png and /dev/null differ diff --git a/assets/inpainting_result2.png b/assets/inpainting_result2.png deleted file mode 100644 index 2945f9f9..00000000 Binary files a/assets/inpainting_result2.png and /dev/null differ diff --git a/assets/network.jpg b/assets/network.jpg deleted file mode 100644 index 5aaa6bd1..00000000 Binary files a/assets/network.jpg and /dev/null differ diff --git a/assets/restoration_result1.png b/assets/restoration_result1.png deleted file mode 100644 index 8fd3b67e..00000000 Binary files a/assets/restoration_result1.png and /dev/null differ diff --git a/assets/restoration_result2.png b/assets/restoration_result2.png deleted file mode 100644 index a2ff2827..00000000 Binary files a/assets/restoration_result2.png and /dev/null differ diff --git a/assets/restoration_result3.png b/assets/restoration_result3.png deleted file mode 100644 index 022d7642..00000000 Binary files a/assets/restoration_result3.png and /dev/null differ diff --git a/assets/restoration_result4.png b/assets/restoration_result4.png deleted file mode 100644 index 5e965076..00000000 Binary files a/assets/restoration_result4.png and /dev/null differ diff --git a/assets/wechat.jpeg b/assets/wechat.jpeg new file mode 100755 index 00000000..f641fd9c Binary files /dev/null and b/assets/wechat.jpeg differ diff --git a/basicsr/VERSION b/basicsr/VERSION old mode 100644 new mode 100755 diff --git a/basicsr/__init__.py b/basicsr/__init__.py old mode 100644 new mode 100755 diff --git a/basicsr/archs/__init__.py b/basicsr/archs/__init__.py old mode 100644 new mode 100755 diff --git 
a/basicsr/archs/arcface_arch.py b/basicsr/archs/arcface_arch.py old mode 100644 new mode 100755 diff --git a/basicsr/archs/arch_util.py b/basicsr/archs/arch_util.py old mode 100644 new mode 100755 diff --git a/basicsr/archs/codeformer_arch.py b/basicsr/archs/codeformer_arch.py old mode 100644 new mode 100755 index de66937d..e74dea82 --- a/basicsr/archs/codeformer_arch.py +++ b/basicsr/archs/codeformer_arch.py @@ -9,6 +9,10 @@ from basicsr.utils import get_root_logger from basicsr.utils.registry import ARCH_REGISTRY +from einops import rearrange + +from tqdm import tqdm + def calc_mean_std(feat, eps=1e-5): """Calculate mean and std for adaptive_instance_normalization. @@ -110,15 +114,25 @@ def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="ge self.dropout1 = nn.Dropout(dropout) self.dropout2 = nn.Dropout(dropout) + self.temp_norm = nn.LayerNorm(embed_dim) + self.temp_dropout = nn.Dropout(dropout) + self.temp_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout) + + self.temp_ffn_norm = nn.LayerNorm(embed_dim) + self.temp_linear1 = nn.Linear(embed_dim, dim_mlp) + self.temp_ffn_dropout = nn.Dropout(dropout) + self.temp_linear2 = nn.Linear(dim_mlp, embed_dim) + self.activation = _get_activation_fn(activation) def with_pos_embed(self, tensor, pos: Optional[Tensor]): return tensor if pos is None else tensor + pos - def forward(self, tgt, + def forward(self, tgt, video_length, batch_size, tgt_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None, - query_pos: Optional[Tensor] = None): + query_pos: Optional[Tensor] = None + ): # self attention tgt2 = self.norm1(tgt) @@ -131,6 +145,22 @@ def forward(self, tgt, tgt2 = self.norm2(tgt) tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) tgt = tgt + self.dropout2(tgt2) + + # tmp attn + tgt = rearrange(tgt, "d (b f) c -> f (b d) c", f=video_length) + tgt2 = self.temp_norm(tgt) + query_pos = rearrange(query_pos, "d (b f) c -> f (b d) c", f=video_length) + q = k = self.with_pos_embed(tgt2, query_pos) + tgt2 = self.temp_attn(q, k, value=tgt2, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.temp_dropout(tgt2) + tgt = rearrange(tgt, "f (b d) c -> d (b f) c", b=batch_size) + + # ffn + tgt2 = self.temp_ffn_norm(tgt) + tgt2 = self.temp_linear2(self.temp_ffn_dropout(self.activation(self.temp_linear1(tgt2)))) + tgt = tgt + self.temp_ffn_dropout(tgt2) + return tgt class Fuse_sft_block(nn.Module): @@ -151,9 +181,9 @@ def __init__(self, in_ch, out_ch): def forward(self, enc_feat, dec_feat, w=1): enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1)) scale = self.scale(enc_feat) - shift = self.shift(enc_feat) - residual = w * (dec_feat * scale + shift) - out = dec_feat + residual + shift = self.shift(enc_feat) + out = w * (dec_feat * scale + shift) + dec_feat + return out @@ -166,13 +196,13 @@ def __init__(self, dim_embd=512, n_head=8, n_layers=9, super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size) if vqgan_path is not None: - self.load_state_dict( + m, n = self.load_state_dict( torch.load(vqgan_path, map_location='cpu')['params_ema']) - if fix_modules is not None: - for module in fix_modules: - for param in getattr(self, module).parameters(): - param.requires_grad = False + # if fix_modules is not None: + # for module in fix_modules: + # for param in getattr(self, module).parameters(): + # param.requires_grad = False self.connect_list = connect_list self.n_layers = n_layers @@ -221,6 +251,8 @@ def 
_init_weights(self, module): module.weight.data.fill_(1.0) def forward(self, x, w=0, detach_16=True, code_only=False, adain=False): + b, f, _, _, _ = x.shape + x = rearrange(x, "b f c h w -> (b f) c h w") # ################### Encoder ##################### enc_feat_dict = {} out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list] @@ -238,7 +270,7 @@ def forward(self, x, w=0, detach_16=True, code_only=False, adain=False): query_emb = feat_emb # Transformer encoder for layer in self.ft_layers: - query_emb = layer(query_emb, query_pos=pos_emb) + query_emb = layer(query_emb, query_pos=pos_emb, video_length=f, batch_size=b) # output logits logits = self.idx_pred_layer(query_emb) # (hw)bn @@ -248,17 +280,10 @@ def forward(self, x, w=0, detach_16=True, code_only=False, adain=False): # logits doesn't need softmax before cross_entropy loss return logits, lq_feat - # ################# Quantization ################### - # if self.training: - # quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight]) - # # b(hw)c -> bc(hw) -> bchw - # quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape) - # ------------ soft_one_hot = F.softmax(logits, dim=2) _, top_idx = torch.topk(soft_one_hot, 1, dim=2) quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0],16,16,256]) - # preserve gradients - # quant_feat = lq_feat + (quant_feat - lq_feat).detach() + if detach_16: quant_feat = quant_feat.detach() # for training stage III @@ -269,12 +294,69 @@ def forward(self, x, w=0, detach_16=True, code_only=False, adain=False): x = quant_feat fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list] + for i, block in enumerate(self.generator.blocks): x = block(x) if i in fuse_list: # fuse after i-th block f_size = str(x.shape[-1]) if w>0: x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w) + + x = rearrange(x, "(b f) c h w -> b f c h w", f=f) out = x # logits doesn't need softmax before cross_entropy loss - return out, logits, lq_feat \ No newline at end of file + return out, logits, lq_feat + + + def inference(self, x, w=0, detach_16=True, adain=False): + with torch.no_grad(): + b, f, _, _, _ = x.shape + x = rearrange(x, "b f c h w -> (b f) c h w") + # ################### Encoder ##################### + enc_feat_dict = {} + out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list] + for i, block in enumerate(self.encoder.blocks): + x = block(x) + if i in out_list: + enc_feat_dict[str(x.shape[-1])] = x.detach().cpu().clone() + + lq_feat = x.detach() + pos_emb = self.position_emb.unsqueeze(1).repeat(1,x.shape[0],1) + # BCHW -> BC(HW) -> (HW)BC + feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2,0,1)) + query_emb = feat_emb + # Transformer encoder + for layer in self.ft_layers: + query_emb = layer(query_emb, query_pos=pos_emb, video_length=f, batch_size=b) + + # output logits + logits = self.idx_pred_layer(query_emb) # (hw)bn + logits = logits.permute(1,0,2) # (hw)bn -> b(hw)n + + + soft_one_hot = F.softmax(logits, dim=2) + _, top_idx = torch.topk(soft_one_hot, 1, dim=2) + quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0],16,16,256]) + + + if detach_16: + quant_feat = quant_feat.detach() # for training stage III + if adain: + quant_feat = adaptive_instance_normalization(quant_feat, lq_feat) + + # ################## Generator #################### + x = quant_feat + fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list] + + + for i, block in 
enumerate(self.generator.blocks): + x = block(x) + if i in fuse_list: # fuse after i-th block + f_size = str(x.shape[-1]) + if w>0: + x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].to(x.device), x, w) + + x = rearrange(x, "(b f) c h w -> b f c h w", f=f) + # logits doesn't need softmax before cross_entropy loss + return x, top_idx + \ No newline at end of file diff --git a/basicsr/archs/rrdbnet_arch.py b/basicsr/archs/rrdbnet_arch.py old mode 100644 new mode 100755 diff --git a/basicsr/archs/vgg_arch.py b/basicsr/archs/vgg_arch.py old mode 100644 new mode 100755 diff --git a/basicsr/archs/vqgan_arch.py b/basicsr/archs/vqgan_arch.py old mode 100644 new mode 100755 index 5ac69263..3a65de10 --- a/basicsr/archs/vqgan_arch.py +++ b/basicsr/archs/vqgan_arch.py @@ -11,6 +11,7 @@ from basicsr.utils import get_root_logger from basicsr.utils.registry import ARCH_REGISTRY + def normalize(in_channels): return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) diff --git a/basicsr/data/__init__.py b/basicsr/data/__init__.py old mode 100644 new mode 100755 diff --git a/basicsr/data/data_sampler.py b/basicsr/data/data_sampler.py old mode 100644 new mode 100755 diff --git a/basicsr/data/data_util.py b/basicsr/data/data_util.py old mode 100644 new mode 100755 diff --git a/basicsr/data/ffhq_blind_joint_dataset.py b/basicsr/data/ffhq_blind_joint_dataset.py deleted file mode 100755 index 0dc845f7..00000000 --- a/basicsr/data/ffhq_blind_joint_dataset.py +++ /dev/null @@ -1,324 +0,0 @@ -import cv2 -import math -import random -import numpy as np -import os.path as osp -from scipy.io import loadmat -import torch -import torch.utils.data as data -from torchvision.transforms.functional import (adjust_brightness, adjust_contrast, - adjust_hue, adjust_saturation, normalize) -from basicsr.data import gaussian_kernels as gaussian_kernels -from basicsr.data.transforms import augment -from basicsr.data.data_util import paths_from_folder -from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor -from basicsr.utils.registry import DATASET_REGISTRY - -@DATASET_REGISTRY.register() -class FFHQBlindJointDataset(data.Dataset): - - def __init__(self, opt): - super(FFHQBlindJointDataset, self).__init__() - logger = get_root_logger() - self.opt = opt - # file client (io backend) - self.file_client = None - self.io_backend_opt = opt['io_backend'] - - self.gt_folder = opt['dataroot_gt'] - self.gt_size = opt.get('gt_size', 512) - self.in_size = opt.get('in_size', 512) - assert self.gt_size >= self.in_size, 'Wrong setting.' 
- - self.mean = opt.get('mean', [0.5, 0.5, 0.5]) - self.std = opt.get('std', [0.5, 0.5, 0.5]) - - self.component_path = opt.get('component_path', None) - self.latent_gt_path = opt.get('latent_gt_path', None) - - if self.component_path is not None: - self.crop_components = True - self.components_dict = torch.load(self.component_path) - self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1.4) - self.nose_enlarge_ratio = opt.get('nose_enlarge_ratio', 1.1) - self.mouth_enlarge_ratio = opt.get('mouth_enlarge_ratio', 1.3) - else: - self.crop_components = False - - if self.latent_gt_path is not None: - self.load_latent_gt = True - self.latent_gt_dict = torch.load(self.latent_gt_path) - else: - self.load_latent_gt = False - - if self.io_backend_opt['type'] == 'lmdb': - self.io_backend_opt['db_paths'] = self.gt_folder - if not self.gt_folder.endswith('.lmdb'): - raise ValueError("'dataroot_gt' should end with '.lmdb', "f'but received {self.gt_folder}') - with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin: - self.paths = [line.split('.')[0] for line in fin] - else: - self.paths = paths_from_folder(self.gt_folder) - - # perform corrupt - self.use_corrupt = opt.get('use_corrupt', True) - self.use_motion_kernel = False - # self.use_motion_kernel = opt.get('use_motion_kernel', True) - - if self.use_motion_kernel: - self.motion_kernel_prob = opt.get('motion_kernel_prob', 0.001) - motion_kernel_path = opt.get('motion_kernel_path', 'basicsr/data/motion-blur-kernels-32.pth') - self.motion_kernels = torch.load(motion_kernel_path) - - if self.use_corrupt: - # degradation configurations - self.blur_kernel_size = self.opt['blur_kernel_size'] - self.kernel_list = self.opt['kernel_list'] - self.kernel_prob = self.opt['kernel_prob'] - # Small degradation - self.blur_sigma = self.opt['blur_sigma'] - self.downsample_range = self.opt['downsample_range'] - self.noise_range = self.opt['noise_range'] - self.jpeg_range = self.opt['jpeg_range'] - # Large degradation - self.blur_sigma_large = self.opt['blur_sigma_large'] - self.downsample_range_large = self.opt['downsample_range_large'] - self.noise_range_large = self.opt['noise_range_large'] - self.jpeg_range_large = self.opt['jpeg_range_large'] - - # print - logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, sigma: [{", ".join(map(str, self.blur_sigma))}]') - logger.info(f'Downsample: downsample_range [{", ".join(map(str, self.downsample_range))}]') - logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]') - logger.info(f'JPEG compression: [{", ".join(map(str, self.jpeg_range))}]') - - # color jitter - self.color_jitter_prob = opt.get('color_jitter_prob', None) - self.color_jitter_pt_prob = opt.get('color_jitter_pt_prob', None) - self.color_jitter_shift = opt.get('color_jitter_shift', 20) - if self.color_jitter_prob is not None: - logger.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, shift: {self.color_jitter_shift}') - - # to gray - self.gray_prob = opt.get('gray_prob', 0.0) - if self.gray_prob is not None: - logger.info(f'Use random gray. Prob: {self.gray_prob}') - self.color_jitter_shift /= 255. 
- - @staticmethod - def color_jitter(img, shift): - """jitter color: randomly jitter the RGB values, in numpy formats""" - jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32) - img = img + jitter_val - img = np.clip(img, 0, 1) - return img - - @staticmethod - def color_jitter_pt(img, brightness, contrast, saturation, hue): - """jitter color: randomly jitter the brightness, contrast, saturation, and hue, in torch Tensor formats""" - fn_idx = torch.randperm(4) - for fn_id in fn_idx: - if fn_id == 0 and brightness is not None: - brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item() - img = adjust_brightness(img, brightness_factor) - - if fn_id == 1 and contrast is not None: - contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item() - img = adjust_contrast(img, contrast_factor) - - if fn_id == 2 and saturation is not None: - saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item() - img = adjust_saturation(img, saturation_factor) - - if fn_id == 3 and hue is not None: - hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item() - img = adjust_hue(img, hue_factor) - return img - - - def get_component_locations(self, name, status): - components_bbox = self.components_dict[name] - if status[0]: # hflip - # exchange right and left eye - tmp = components_bbox['left_eye'] - components_bbox['left_eye'] = components_bbox['right_eye'] - components_bbox['right_eye'] = tmp - # modify the width coordinate - components_bbox['left_eye'][0] = self.gt_size - components_bbox['left_eye'][0] - components_bbox['right_eye'][0] = self.gt_size - components_bbox['right_eye'][0] - components_bbox['nose'][0] = self.gt_size - components_bbox['nose'][0] - components_bbox['mouth'][0] = self.gt_size - components_bbox['mouth'][0] - - locations_gt = {} - locations_in = {} - for part in ['left_eye', 'right_eye', 'nose', 'mouth']: - mean = components_bbox[part][0:2] - half_len = components_bbox[part][2] - if 'eye' in part: - half_len *= self.eye_enlarge_ratio - elif part == 'nose': - half_len *= self.nose_enlarge_ratio - elif part == 'mouth': - half_len *= self.mouth_enlarge_ratio - loc = np.hstack((mean - half_len + 1, mean + half_len)) - loc = torch.from_numpy(loc).float() - locations_gt[part] = loc - loc_in = loc/(self.gt_size//self.in_size) - locations_in[part] = loc_in - return locations_gt, locations_in - - - def __getitem__(self, index): - if self.file_client is None: - self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) - - # load gt image - gt_path = self.paths[index] - name = osp.basename(gt_path)[:-4] - img_bytes = self.file_client.get(gt_path) - img_gt = imfrombytes(img_bytes, float32=True) - - # random horizontal flip - img_gt, status = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False, return_status=True) - - if self.load_latent_gt: - if status[0]: - latent_gt = self.latent_gt_dict['hflip'][name] - else: - latent_gt = self.latent_gt_dict['orig'][name] - - if self.crop_components: - locations_gt, locations_in = self.get_component_locations(name, status) - - # generate in image - img_in = img_gt - if self.use_corrupt: - # motion blur - if self.use_motion_kernel and random.random() < self.motion_kernel_prob: - m_i = random.randint(0,31) - k = self.motion_kernels[f'{m_i:02d}'] - img_in = cv2.filter2D(img_in,-1,k) - - # gaussian blur - kernel = gaussian_kernels.random_mixed_kernels( - self.kernel_list, - self.kernel_prob, - self.blur_kernel_size, - self.blur_sigma, - 
self.blur_sigma, - [-math.pi, math.pi], - noise_range=None) - img_in = cv2.filter2D(img_in, -1, kernel) - - # downsample - scale = np.random.uniform(self.downsample_range[0], self.downsample_range[1]) - img_in = cv2.resize(img_in, (int(self.gt_size // scale), int(self.gt_size // scale)), interpolation=cv2.INTER_LINEAR) - - # noise - if self.noise_range is not None: - noise_sigma = np.random.uniform(self.noise_range[0] / 255., self.noise_range[1] / 255.) - noise = np.float32(np.random.randn(*(img_in.shape))) * noise_sigma - img_in = img_in + noise - img_in = np.clip(img_in, 0, 1) - - # jpeg - if self.jpeg_range is not None: - jpeg_p = np.random.uniform(self.jpeg_range[0], self.jpeg_range[1]) - encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), int(jpeg_p)] - _, encimg = cv2.imencode('.jpg', img_in * 255., encode_param) - img_in = np.float32(cv2.imdecode(encimg, 1)) / 255. - - # resize to in_size - img_in = cv2.resize(img_in, (self.in_size, self.in_size), interpolation=cv2.INTER_LINEAR) - - - # generate in_large with large degradation - img_in_large = img_gt - - if self.use_corrupt: - # motion blur - if self.use_motion_kernel and random.random() < self.motion_kernel_prob: - m_i = random.randint(0,31) - k = self.motion_kernels[f'{m_i:02d}'] - img_in_large = cv2.filter2D(img_in_large,-1,k) - - # gaussian blur - kernel = gaussian_kernels.random_mixed_kernels( - self.kernel_list, - self.kernel_prob, - self.blur_kernel_size, - self.blur_sigma_large, - self.blur_sigma_large, - [-math.pi, math.pi], - noise_range=None) - img_in_large = cv2.filter2D(img_in_large, -1, kernel) - - # downsample - scale = np.random.uniform(self.downsample_range_large[0], self.downsample_range_large[1]) - img_in_large = cv2.resize(img_in_large, (int(self.gt_size // scale), int(self.gt_size // scale)), interpolation=cv2.INTER_LINEAR) - - # noise - if self.noise_range_large is not None: - noise_sigma = np.random.uniform(self.noise_range_large[0] / 255., self.noise_range_large[1] / 255.) - noise = np.float32(np.random.randn(*(img_in_large.shape))) * noise_sigma - img_in_large = img_in_large + noise - img_in_large = np.clip(img_in_large, 0, 1) - - # jpeg - if self.jpeg_range_large is not None: - jpeg_p = np.random.uniform(self.jpeg_range_large[0], self.jpeg_range_large[1]) - encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), int(jpeg_p)] - _, encimg = cv2.imencode('.jpg', img_in_large * 255., encode_param) - img_in_large = np.float32(cv2.imdecode(encimg, 1)) / 255. 
- - # resize to in_size - img_in_large = cv2.resize(img_in_large, (self.in_size, self.in_size), interpolation=cv2.INTER_LINEAR) - - - # random color jitter (only for lq) - if self.color_jitter_prob is not None and (np.random.uniform() < self.color_jitter_prob): - img_in = self.color_jitter(img_in, self.color_jitter_shift) - img_in_large = self.color_jitter(img_in_large, self.color_jitter_shift) - # random to gray (only for lq) - if self.gray_prob and np.random.uniform() < self.gray_prob: - img_in = cv2.cvtColor(img_in, cv2.COLOR_BGR2GRAY) - img_in = np.tile(img_in[:, :, None], [1, 1, 3]) - img_in_large = cv2.cvtColor(img_in_large, cv2.COLOR_BGR2GRAY) - img_in_large = np.tile(img_in_large[:, :, None], [1, 1, 3]) - - # BGR to RGB, HWC to CHW, numpy to tensor - img_in, img_in_large, img_gt = img2tensor([img_in, img_in_large, img_gt], bgr2rgb=True, float32=True) - - # random color jitter (pytorch version) (only for lq) - if self.color_jitter_pt_prob is not None and (np.random.uniform() < self.color_jitter_pt_prob): - brightness = self.opt.get('brightness', (0.5, 1.5)) - contrast = self.opt.get('contrast', (0.5, 1.5)) - saturation = self.opt.get('saturation', (0, 1.5)) - hue = self.opt.get('hue', (-0.1, 0.1)) - img_in = self.color_jitter_pt(img_in, brightness, contrast, saturation, hue) - img_in_large = self.color_jitter_pt(img_in_large, brightness, contrast, saturation, hue) - - # round and clip - img_in = np.clip((img_in * 255.0).round(), 0, 255) / 255. - img_in_large = np.clip((img_in_large * 255.0).round(), 0, 255) / 255. - - # Set vgg range_norm=True if use the normalization here - # normalize - normalize(img_in, self.mean, self.std, inplace=True) - normalize(img_in_large, self.mean, self.std, inplace=True) - normalize(img_gt, self.mean, self.std, inplace=True) - - return_dict = {'in': img_in, 'in_large_de': img_in_large, 'gt': img_gt, 'gt_path': gt_path} - - if self.crop_components: - return_dict['locations_in'] = locations_in - return_dict['locations_gt'] = locations_gt - - if self.load_latent_gt: - return_dict['latent_gt'] = latent_gt - - return return_dict - - - def __len__(self): - return len(self.paths) diff --git a/basicsr/data/gaussian_kernels.py b/basicsr/data/gaussian_kernels.py index 0ce57f0a..a7c05a33 100755 --- a/basicsr/data/gaussian_kernels.py +++ b/basicsr/data/gaussian_kernels.py @@ -632,16 +632,16 @@ def show_one_kernel(): plt.show() -def show_plateau_kernel(): - import matplotlib.pyplot as plt - kernel_size = 21 +# def show_plateau_kernel(): +# import matplotlib.pyplot as plt +# kernel_size = 21 - kernel = plateau_type1(kernel_size, 2, 4, -math.pi / 8, 2, grid=None) - kernel_norm = bivariate_isotropic_Gaussian(kernel_size, 5) - kernel_gau = bivariate_generalized_Gaussian( - kernel_size, 2, 4, -math.pi / 8, 2, grid=None) - delta_h, delta_w = mass_center_shift(kernel_size, kernel) - print(delta_h, delta_w) +# kernel = plateau_type1(kernel_size, 2, 4, -math.pi / 8, 2, grid=None) +# kernel_norm = bivariate_isotropic_Gaussian(kernel_size, 5) +# kernel_gau = bivariate_generalized_Gaussian( +# kernel_size, 2, 4, -math.pi / 8, 2, grid=None) +# delta_h, delta_w = mass_center_shift(kernel_size, kernel) +# print(delta_h, delta_w) # kernel_slice = kernel[10, :] # kernel_gau_slice = kernel_gau[10, :] @@ -662,29 +662,29 @@ def show_plateau_kernel(): # ax.plot(t, y2) # plt.show() - fig, axs = plt.subplots(nrows=2, ncols=2) - # axs.set_axis_off() - ax = axs[0][0] - im = ax.matshow(kernel, cmap='jet', origin='upper') - fig.colorbar(im, ax=ax) - - # image - ax = axs[0][1] - kernel_vis = 
kernel - np.min(kernel) - kernel_vis = kernel_vis / np.max(kernel_vis) * 255. - ax.imshow(kernel_vis, interpolation='nearest') - - _, xx, yy = mesh_grid(kernel_size) - # contour - ax = axs[1][0] - CS = ax.contour(xx, yy, kernel, origin='upper') - ax.clabel(CS, inline=1, fontsize=3) + # fig, axs = plt.subplots(nrows=2, ncols=2) + # # axs.set_axis_off() + # ax = axs[0][0] + # im = ax.matshow(kernel, cmap='jet', origin='upper') + # fig.colorbar(im, ax=ax) + + # # image + # ax = axs[0][1] + # kernel_vis = kernel - np.min(kernel) + # kernel_vis = kernel_vis / np.max(kernel_vis) * 255. + # ax.imshow(kernel_vis, interpolation='nearest') + + # _, xx, yy = mesh_grid(kernel_size) + # # contour + # ax = axs[1][0] + # CS = ax.contour(xx, yy, kernel, origin='upper') + # ax.clabel(CS, inline=1, fontsize=3) + + # # contourf + # ax = axs[1][1] + # kernel = kernel / np.max(kernel) + # p = ax.contourf( + # xx, yy, kernel, origin='upper', levels=np.linspace(-0.05, 1.05, 10)) + # fig.colorbar(p) - # contourf - ax = axs[1][1] - kernel = kernel / np.max(kernel) - p = ax.contourf( - xx, yy, kernel, origin='upper', levels=np.linspace(-0.05, 1.05, 10)) - fig.colorbar(p) - - plt.show() + # plt.show() diff --git a/basicsr/data/paired_image_dataset.py b/basicsr/data/paired_image_dataset.py deleted file mode 100644 index c6a6c07b..00000000 --- a/basicsr/data/paired_image_dataset.py +++ /dev/null @@ -1,101 +0,0 @@ -from torch.utils import data as data -from torchvision.transforms.functional import normalize - -from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb, paired_paths_from_meta_info_file -from basicsr.data.transforms import augment, paired_random_crop -from basicsr.utils import FileClient, imfrombytes, img2tensor -from basicsr.utils.registry import DATASET_REGISTRY - - -@DATASET_REGISTRY.register() -class PairedImageDataset(data.Dataset): - """Paired image dataset for image restoration. - - Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and - GT image pairs. - - There are three modes: - 1. 'lmdb': Use lmdb files. - If opt['io_backend'] == lmdb. - 2. 'meta_info_file': Use meta information file to generate paths. - If opt['io_backend'] != lmdb and opt['meta_info_file'] is not None. - 3. 'folder': Scan folders to generate paths. - The rest. - - Args: - opt (dict): Config for train datasets. It contains the following keys: - dataroot_gt (str): Data root path for gt. - dataroot_lq (str): Data root path for lq. - meta_info_file (str): Path for meta information file. - io_backend (dict): IO backend type and other kwarg. - filename_tmpl (str): Template for each filename. Note that the - template excludes the file extension. Default: '{}'. - gt_size (int): Cropped patched size for gt patches. - use_flip (bool): Use horizontal flips. - use_rot (bool): Use rotation (use vertical flip and transposing h - and w for implementation). - - scale (bool): Scale, which will be added automatically. - phase (str): 'train' or 'val'. 
- """ - - def __init__(self, opt): - super(PairedImageDataset, self).__init__() - self.opt = opt - # file client (io backend) - self.file_client = None - self.io_backend_opt = opt['io_backend'] - self.mean = opt['mean'] if 'mean' in opt else None - self.std = opt['std'] if 'std' in opt else None - - self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq'] - if 'filename_tmpl' in opt: - self.filename_tmpl = opt['filename_tmpl'] - else: - self.filename_tmpl = '{}' - - if self.io_backend_opt['type'] == 'lmdb': - self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder] - self.io_backend_opt['client_keys'] = ['lq', 'gt'] - self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt']) - elif 'meta_info_file' in self.opt and self.opt['meta_info_file'] is not None: - self.paths = paired_paths_from_meta_info_file([self.lq_folder, self.gt_folder], ['lq', 'gt'], - self.opt['meta_info_file'], self.filename_tmpl) - else: - self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl) - - def __getitem__(self, index): - if self.file_client is None: - self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) - - scale = self.opt['scale'] - - # Load gt and lq images. Dimension order: HWC; channel order: BGR; - # image range: [0, 1], float32. - gt_path = self.paths[index]['gt_path'] - img_bytes = self.file_client.get(gt_path, 'gt') - img_gt = imfrombytes(img_bytes, float32=True) - lq_path = self.paths[index]['lq_path'] - img_bytes = self.file_client.get(lq_path, 'lq') - img_lq = imfrombytes(img_bytes, float32=True) - - # augmentation for training - if self.opt['phase'] == 'train': - gt_size = self.opt['gt_size'] - # random crop - img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path) - # flip, rotation - img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_flip'], self.opt['use_rot']) - - # TODO: color space transform - # BGR to RGB, HWC to CHW, numpy to tensor - img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True) - # normalize - if self.mean is not None or self.std is not None: - normalize(img_lq, self.mean, self.std, inplace=True) - normalize(img_gt, self.mean, self.std, inplace=True) - - return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path} - - def __len__(self): - return len(self.paths) diff --git a/basicsr/data/prefetch_dataloader.py b/basicsr/data/prefetch_dataloader.py old mode 100644 new mode 100755 diff --git a/basicsr/data/transforms.py b/basicsr/data/transforms.py old mode 100644 new mode 100755 diff --git a/basicsr/data/ffhq_blind_dataset.py b/basicsr/data/vfhq_dataset.py similarity index 57% rename from basicsr/data/ffhq_blind_dataset.py rename to basicsr/data/vfhq_dataset.py index 9f900606..eadbdeb5 100755 --- a/basicsr/data/ffhq_blind_dataset.py +++ b/basicsr/data/vfhq_dataset.py @@ -15,16 +15,20 @@ from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor from basicsr.utils.registry import DATASET_REGISTRY +from pathlib import Path +import torchvision.transforms as transforms + @DATASET_REGISTRY.register() -class FFHQBlindDataset(data.Dataset): +class VFHQBlindDataset(data.Dataset): def __init__(self, opt): - super(FFHQBlindDataset, self).__init__() + super(VFHQBlindDataset, self).__init__() logger = get_root_logger() self.opt = opt # file client (io backend) self.file_client = None self.io_backend_opt = opt['io_backend'] + self.video_length = opt['video_length'] self.gt_folder = 
opt['dataroot_gt'] self.gt_size = opt.get('gt_size', 512) @@ -59,26 +63,25 @@ def __init__(self, opt): with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin: self.paths = [line.split('.')[0] for line in fin] else: - self.paths = paths_from_folder(self.gt_folder) + gt_folder = Path(self.gt_folder) + sub_dir = gt_folder.iterdir() + self.paths = [] + for p in sub_dir: + if p.is_dir(): + l = list(p.glob('*.png')) + if len(l) > self.video_length: + self.paths.append(str(p)) + # inpainting mask self.gen_inpaint_mask = opt.get('gen_inpaint_mask', False) if self.gen_inpaint_mask: logger.info(f'generate mask ...') - # self.mask_max_angle = opt.get('mask_max_angle', 10) - # self.mask_max_len = opt.get('mask_max_len', 150) - # self.mask_max_width = opt.get('mask_max_width', 50) - # self.mask_draw_times = opt.get('mask_draw_times', 10) - # # print - # logger.info(f'mask_max_angle: {self.mask_max_angle}') - # logger.info(f'mask_max_len: {self.mask_max_len}') - # logger.info(f'mask_max_width: {self.mask_max_width}') - # logger.info(f'mask_draw_times: {self.mask_draw_times}') + # perform corrupt self.use_corrupt = opt.get('use_corrupt', True) self.use_motion_kernel = False - # self.use_motion_kernel = opt.get('use_motion_kernel', True) if self.use_motion_kernel: self.motion_kernel_prob = opt.get('motion_kernel_prob', 0.001) @@ -113,6 +116,7 @@ def __init__(self, opt): logger.info(f'Use random gray. Prob: {self.gray_prob}') self.color_jitter_shift /= 255. + @staticmethod def color_jitter(img, shift): """jitter color: randomly jitter the RGB values, in numpy formats""" @@ -182,118 +186,113 @@ def __getitem__(self, index): # load gt image gt_path = self.paths[index] - name = osp.basename(gt_path)[:-4] - img_bytes = self.file_client.get(gt_path) - img_gt = imfrombytes(img_bytes, float32=True) - - # random horizontal flip - img_gt, status = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False, return_status=True) - - if self.load_latent_gt: - if status[0]: - latent_gt = self.latent_gt_dict['hflip'][name] - else: - latent_gt = self.latent_gt_dict['orig'][name] - - if self.crop_components: - locations_gt, locations_in = self.get_component_locations(name, status) - - # generate in image - img_in = img_gt - if self.use_corrupt and not self.gen_inpaint_mask: - # motion blur - if self.use_motion_kernel and random.random() < self.motion_kernel_prob: - m_i = random.randint(0,31) - k = self.motion_kernels[f'{m_i:02d}'] - img_in = cv2.filter2D(img_in,-1,k) - - # gaussian blur - kernel = gaussian_kernels.random_mixed_kernels( - self.kernel_list, - self.kernel_prob, - self.blur_kernel_size, - self.blur_sigma, - self.blur_sigma, - [-math.pi, math.pi], - noise_range=None) - img_in = cv2.filter2D(img_in, -1, kernel) - - # downsample - scale = np.random.uniform(self.downsample_range[0], self.downsample_range[1]) - img_in = cv2.resize(img_in, (int(self.gt_size // scale), int(self.gt_size // scale)), interpolation=cv2.INTER_LINEAR) - - # noise - if self.noise_range is not None: - noise_sigma = np.random.uniform(self.noise_range[0] / 255., self.noise_range[1] / 255.) - noise = np.float32(np.random.randn(*(img_in.shape))) * noise_sigma - img_in = img_in + noise - img_in = np.clip(img_in, 0, 1) - - # jpeg - if self.jpeg_range is not None: - jpeg_p = np.random.uniform(self.jpeg_range[0], self.jpeg_range[1]) - encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), int(jpeg_p)] - _, encimg = cv2.imencode('.jpg', img_in * 255., encode_param) - img_in = np.float32(cv2.imdecode(encimg, 1)) / 255. 
- - # resize to in_size - img_in = cv2.resize(img_in, (self.in_size, self.in_size), interpolation=cv2.INTER_LINEAR) - - # if self.gen_inpaint_mask: - # inpaint_mask = random_ff_mask(shape=(self.gt_size,self.gt_size), - # max_angle = self.mask_max_angle, max_len = self.mask_max_len, - # max_width = self.mask_max_width, times = self.mask_draw_times) - # img_in = img_in * (1 - inpaint_mask.reshape(self.gt_size,self.gt_size,1)) + \ - # 1.0 * inpaint_mask.reshape(self.gt_size,self.gt_size,1) - - # inpaint_mask = torch.from_numpy(inpaint_mask).view(1,self.gt_size,self.gt_size) - - if self.gen_inpaint_mask: - img_in = (img_in*255).astype('uint8') - img_in = brush_stroke_mask(Image.fromarray(img_in)) - img_in = np.array(img_in) / 255. - - # random color jitter (only for lq) - if self.color_jitter_prob is not None and (np.random.uniform() < self.color_jitter_prob): - img_in = self.color_jitter(img_in, self.color_jitter_shift) - # random to gray (only for lq) - if self.gray_prob and np.random.uniform() < self.gray_prob: - img_in = cv2.cvtColor(img_in, cv2.COLOR_BGR2GRAY) - img_in = np.tile(img_in[:, :, None], [1, 1, 3]) - - # BGR to RGB, HWC to CHW, numpy to tensor - img_in, img_gt = img2tensor([img_in, img_gt], bgr2rgb=True, float32=True) - # random color jitter (pytorch version) (only for lq) - if self.color_jitter_pt_prob is not None and (np.random.uniform() < self.color_jitter_pt_prob): - brightness = self.opt.get('brightness', (0.5, 1.5)) - contrast = self.opt.get('contrast', (0.5, 1.5)) - saturation = self.opt.get('saturation', (0, 1.5)) - hue = self.opt.get('hue', (-0.1, 0.1)) - img_in = self.color_jitter_pt(img_in, brightness, contrast, saturation, hue) + image_list = list(Path(gt_path).glob('*.png')) + lenght = len(image_list) - # round and clip - img_in = np.clip((img_in * 255.0).round(), 0, 255) / 255. 
+ start_idx = random.randint(0, lenght-self.video_length-1) + in_list = [] + gt_list = [] - # Set vgg range_norm=True if use the normalization here - # normalize - normalize(img_in, self.mean, self.std, inplace=True) - normalize(img_gt, self.mean, self.std, inplace=True) + for i in range(start_idx, start_idx+self.video_length): + gt_path_idx = image_list[i] - return_dict = {'in': img_in, 'gt': img_gt, 'gt_path': gt_path} - - if self.crop_components: - return_dict['locations_in'] = locations_in - return_dict['locations_gt'] = locations_gt - - if self.load_latent_gt: - return_dict['latent_gt'] = latent_gt - - # if self.gen_inpaint_mask: - # return_dict['inpaint_mask'] = inpaint_mask + img_bytes = self.file_client.get(gt_path_idx) + img_gt = imfrombytes(img_bytes, float32=True) + + # random horizontal flip + img_gt, status = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False, return_status=True) + + + # generate in image + img_in = img_gt + if self.use_corrupt and not self.gen_inpaint_mask: + # motion blur + if self.use_motion_kernel and random.random() < self.motion_kernel_prob: + m_i = random.randint(0,31) + k = self.motion_kernels[f'{m_i:02d}'] + img_in = cv2.filter2D(img_in,-1,k) + + # gaussian blur + kernel = gaussian_kernels.random_mixed_kernels( + self.kernel_list, + self.kernel_prob, + self.blur_kernel_size, + self.blur_sigma, + self.blur_sigma, + [-math.pi, math.pi], + noise_range=None) + img_in = cv2.filter2D(img_in, -1, kernel) + + # downsample + scale = np.random.uniform(self.downsample_range[0], self.downsample_range[1]) + img_in = cv2.resize(img_in, (int(self.gt_size // scale), int(self.gt_size // scale)), interpolation=cv2.INTER_LINEAR) + + # noise + if self.noise_range is not None: + noise_sigma = np.random.uniform(self.noise_range[0] / 255., self.noise_range[1] / 255.) + noise = np.float32(np.random.randn(*(img_in.shape))) * noise_sigma + img_in = img_in + noise + img_in = np.clip(img_in, 0, 1) + + # jpeg + if self.jpeg_range is not None: + jpeg_p = np.random.uniform(self.jpeg_range[0], self.jpeg_range[1]) + encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), int(jpeg_p)] + _, encimg = cv2.imencode('.jpg', img_in * 255., encode_param) + img_in = np.float32(cv2.imdecode(encimg, 1)) / 255. + + # resize to in_size + img_in = cv2.resize(img_in, (self.in_size, self.in_size), interpolation=cv2.INTER_LINEAR) + + + if self.gen_inpaint_mask: + img_in = (img_in*255).astype('uint8') + img_in = brush_stroke_mask(Image.fromarray(img_in)) + img_in = np.array(img_in) / 255. + + # random color jitter (only for lq) + if self.color_jitter_prob is not None and (np.random.uniform() < self.color_jitter_prob): + img_in = self.color_jitter(img_in, self.color_jitter_shift) + # random to gray (only for lq) + if self.gray_prob and np.random.uniform() < self.gray_prob: + img_in = cv2.cvtColor(img_in, cv2.COLOR_BGR2GRAY) + img_in = np.tile(img_in[:, :, None], [1, 1, 3]) + + # BGR to RGB, HWC to CHW, numpy to tensor + img_in, img_gt = img2tensor([img_in, img_gt], bgr2rgb=True, float32=True) + + # random color jitter (pytorch version) (only for lq) + if self.color_jitter_pt_prob is not None and (np.random.uniform() < self.color_jitter_pt_prob): + brightness = self.opt.get('brightness', (0.5, 1.5)) + contrast = self.opt.get('contrast', (0.5, 1.5)) + saturation = self.opt.get('saturation', (0, 1.5)) + hue = self.opt.get('hue', (-0.1, 0.1)) + img_in = self.color_jitter_pt(img_in, brightness, contrast, saturation, hue) + + # round and clip + img_in = np.clip((img_in * 255.0).round(), 0, 255) / 255. 
+ + # Set vgg range_norm=True if use the normalization here + # normalize + normalize(img_in, self.mean, self.std, inplace=True) + normalize(img_gt, self.mean, self.std, inplace=True) + + img_in = img_in.unsqueeze(0) + img_gt = img_gt.unsqueeze(0) + + in_list.append(img_in) + gt_list.append(img_gt) + + in_video = torch.cat(in_list, dim=0) + gt_video = torch.cat(gt_list, dim=0) + + return_dict = {'in': in_video, 'gt': gt_video} return return_dict def __len__(self): - return len(self.paths) \ No newline at end of file + return len(self.paths) + + \ No newline at end of file diff --git a/basicsr/losses/__init__.py b/basicsr/losses/__init__.py old mode 100644 new mode 100755 diff --git a/basicsr/losses/loss_util.py b/basicsr/losses/loss_util.py old mode 100644 new mode 100755 diff --git a/basicsr/losses/losses.py b/basicsr/losses/losses.py old mode 100644 new mode 100755 diff --git a/basicsr/metrics/__init__.py b/basicsr/metrics/__init__.py old mode 100644 new mode 100755 diff --git a/basicsr/metrics/metric_util.py b/basicsr/metrics/metric_util.py old mode 100644 new mode 100755 diff --git a/basicsr/metrics/psnr_ssim.py b/basicsr/metrics/psnr_ssim.py old mode 100644 new mode 100755 diff --git a/basicsr/models/__init__.py b/basicsr/models/__init__.py old mode 100644 new mode 100755 diff --git a/basicsr/models/base_model.py b/basicsr/models/base_model.py old mode 100644 new mode 100755 diff --git a/basicsr/models/codeformer_joint_model.py b/basicsr/models/codeformer_joint_model.py deleted file mode 100644 index 9e4b9d12..00000000 --- a/basicsr/models/codeformer_joint_model.py +++ /dev/null @@ -1,350 +0,0 @@ -import torch -from collections import OrderedDict -from os import path as osp -from tqdm import tqdm - - -from basicsr.archs import build_network -from basicsr.losses import build_loss -from basicsr.metrics import calculate_metric -from basicsr.utils import get_root_logger, imwrite, tensor2img -from basicsr.utils.registry import MODEL_REGISTRY -import torch.nn.functional as F -from .sr_model import SRModel - - -@MODEL_REGISTRY.register() -class CodeFormerJointModel(SRModel): - def feed_data(self, data): - self.gt = data['gt'].to(self.device) - self.input = data['in'].to(self.device) - self.input_large_de = data['in_large_de'].to(self.device) - self.b = self.gt.shape[0] - - if 'latent_gt' in data: - self.idx_gt = data['latent_gt'].to(self.device) - self.idx_gt = self.idx_gt.view(self.b, -1) - else: - self.idx_gt = None - - def init_training_settings(self): - logger = get_root_logger() - train_opt = self.opt['train'] - - self.ema_decay = train_opt.get('ema_decay', 0) - if self.ema_decay > 0: - logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}') - # define network net_g with Exponential Moving Average (EMA) - # net_g_ema is used only for testing on one GPU and saving - # There is no need to wrap with DistributedDataParallel - self.net_g_ema = build_network(self.opt['network_g']).to(self.device) - # load pretrained model - load_path = self.opt['path'].get('pretrain_network_g', None) - if load_path is not None: - self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema') - else: - self.model_ema(0) # copy net_g weight - self.net_g_ema.eval() - - if self.opt['datasets']['train'].get('latent_gt_path', None) is not None: - self.generate_idx_gt = False - elif self.opt.get('network_vqgan', None) is not None: - self.hq_vqgan_fix = build_network(self.opt['network_vqgan']).to(self.device) - self.hq_vqgan_fix.eval() - 
self.generate_idx_gt = True - for param in self.hq_vqgan_fix.parameters(): - param.requires_grad = False - else: - raise NotImplementedError(f'Shoule have network_vqgan config or pre-calculated latent code.') - - logger.info(f'Need to generate latent GT code: {self.generate_idx_gt}') - - self.hq_feat_loss = train_opt.get('use_hq_feat_loss', True) - self.feat_loss_weight = train_opt.get('feat_loss_weight', 1.0) - self.cross_entropy_loss = train_opt.get('cross_entropy_loss', True) - self.entropy_loss_weight = train_opt.get('entropy_loss_weight', 0.5) - self.scale_adaptive_gan_weight = train_opt.get('scale_adaptive_gan_weight', 0.8) - - # define network net_d - self.net_d = build_network(self.opt['network_d']) - self.net_d = self.model_to_device(self.net_d) - self.print_network(self.net_d) - - # load pretrained models - load_path = self.opt['path'].get('pretrain_network_d', None) - if load_path is not None: - self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True)) - - self.net_g.train() - self.net_d.train() - - # define losses - if train_opt.get('pixel_opt'): - self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device) - else: - self.cri_pix = None - - if train_opt.get('perceptual_opt'): - self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device) - else: - self.cri_perceptual = None - - if train_opt.get('gan_opt'): - self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device) - - - self.fix_generator = train_opt.get('fix_generator', True) - logger.info(f'fix_generator: {self.fix_generator}') - - self.net_g_start_iter = train_opt.get('net_g_start_iter', 0) - self.net_d_iters = train_opt.get('net_d_iters', 1) - self.net_d_start_iter = train_opt.get('net_d_start_iter', 0) - - # set up optimizers and schedulers - self.setup_optimizers() - self.setup_schedulers() - - def calculate_adaptive_weight(self, recon_loss, g_loss, last_layer, disc_weight_max): - recon_grads = torch.autograd.grad(recon_loss, last_layer, retain_graph=True)[0] - g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] - - d_weight = torch.norm(recon_grads) / (torch.norm(g_grads) + 1e-4) - d_weight = torch.clamp(d_weight, 0.0, disc_weight_max).detach() - return d_weight - - def setup_optimizers(self): - train_opt = self.opt['train'] - # optimizer g - optim_params_g = [] - for k, v in self.net_g.named_parameters(): - if v.requires_grad: - optim_params_g.append(v) - else: - logger = get_root_logger() - logger.warning(f'Params {k} will not be optimized.') - optim_type = train_opt['optim_g'].pop('type') - self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, **train_opt['optim_g']) - self.optimizers.append(self.optimizer_g) - # optimizer d - optim_type = train_opt['optim_d'].pop('type') - self.optimizer_d = self.get_optimizer(optim_type, self.net_d.parameters(), **train_opt['optim_d']) - self.optimizers.append(self.optimizer_d) - - def gray_resize_for_identity(self, out, size=128): - out_gray = (0.2989 * out[:, 0, :, :] + 0.5870 * out[:, 1, :, :] + 0.1140 * out[:, 2, :, :]) - out_gray = out_gray.unsqueeze(1) - out_gray = F.interpolate(out_gray, (size, size), mode='bilinear', align_corners=False) - return out_gray - - def optimize_parameters(self, current_iter): - logger = get_root_logger() - # optimize net_g - for p in self.net_d.parameters(): - p.requires_grad = False - - self.optimizer_g.zero_grad() - - if self.generate_idx_gt: - x = self.hq_vqgan_fix.encoder(self.gt) - output, _, quant_stats = self.hq_vqgan_fix.quantize(x) - 
min_encoding_indices = quant_stats['min_encoding_indices'] - self.idx_gt = min_encoding_indices.view(self.b, -1) - - if current_iter <= 40000: # small degradation - small_per_n = 1 - w = 1 - elif current_iter <= 80000: # small degradation - small_per_n = 1 - w = 1.3 - elif current_iter <= 120000: # large degradation - small_per_n = 120000 - w = 0 - else: # mixed degradation - small_per_n = 15 - w = 1.3 - - if current_iter % small_per_n == 0: - self.output, logits, lq_feat = self.net_g(self.input, w=w, detach_16=True) - large_de = False - else: - logits, lq_feat = self.net_g(self.input_large_de, code_only=True) - large_de = True - - if self.hq_feat_loss: - # quant_feats - quant_feat_gt = self.net_g.module.quantize.get_codebook_feat(self.idx_gt, shape=[self.b,16,16,256]) - - l_g_total = 0 - loss_dict = OrderedDict() - if current_iter % self.net_d_iters == 0 and current_iter > self.net_g_start_iter: - # hq_feat_loss - if not 'transformer' in self.opt['network_g']['fix_modules']: - if self.hq_feat_loss: # codebook loss - l_feat_encoder = torch.mean((quant_feat_gt.detach()-lq_feat)**2) * self.feat_loss_weight - l_g_total += l_feat_encoder - loss_dict['l_feat_encoder'] = l_feat_encoder - - # cross_entropy_loss - if self.cross_entropy_loss: - # b(hw)n -> bn(hw) - cross_entropy_loss = F.cross_entropy(logits.permute(0, 2, 1), self.idx_gt) * self.entropy_loss_weight - l_g_total += cross_entropy_loss - loss_dict['cross_entropy_loss'] = cross_entropy_loss - - # pixel loss - if not large_de: # when large degradation don't need image-level loss - if self.cri_pix: - l_g_pix = self.cri_pix(self.output, self.gt) - l_g_total += l_g_pix - loss_dict['l_g_pix'] = l_g_pix - - # perceptual loss - if self.cri_perceptual: - l_g_percep = self.cri_perceptual(self.output, self.gt) - l_g_total += l_g_percep - loss_dict['l_g_percep'] = l_g_percep - - # gan loss - if current_iter > self.net_d_start_iter: - fake_g_pred = self.net_d(self.output) - l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False) - recon_loss = l_g_pix + l_g_percep - if not self.fix_generator: - last_layer = self.net_g.module.generator.blocks[-1].weight - d_weight = self.calculate_adaptive_weight(recon_loss, l_g_gan, last_layer, disc_weight_max=1.0) - else: - largest_fuse_size = self.opt['network_g']['connect_list'][-1] - last_layer = self.net_g.module.fuse_convs_dict[largest_fuse_size].shift[-1].weight - d_weight = self.calculate_adaptive_weight(recon_loss, l_g_gan, last_layer, disc_weight_max=1.0) - - d_weight *= self.scale_adaptive_gan_weight # 0.8 - loss_dict['d_weight'] = d_weight - l_g_total += d_weight * l_g_gan - loss_dict['l_g_gan'] = d_weight * l_g_gan - - l_g_total.backward() - self.optimizer_g.step() - - if self.ema_decay > 0: - self.model_ema(decay=self.ema_decay) - - # optimize net_d - if not large_de: - if current_iter > self.net_d_start_iter: - for p in self.net_d.parameters(): - p.requires_grad = True - - self.optimizer_d.zero_grad() - # real - real_d_pred = self.net_d(self.gt) - l_d_real = self.cri_gan(real_d_pred, True, is_disc=True) - loss_dict['l_d_real'] = l_d_real - loss_dict['out_d_real'] = torch.mean(real_d_pred.detach()) - l_d_real.backward() - # fake - fake_d_pred = self.net_d(self.output.detach()) - l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True) - loss_dict['l_d_fake'] = l_d_fake - loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach()) - l_d_fake.backward() - - self.optimizer_d.step() - - self.log_dict = self.reduce_loss_dict(loss_dict) - - - def test(self): - with torch.no_grad(): - if hasattr(self, 
'net_g_ema'): - self.net_g_ema.eval() - self.output, _, _ = self.net_g_ema(self.input, w=1) - else: - logger = get_root_logger() - logger.warning('Do not have self.net_g_ema, use self.net_g.') - self.net_g.eval() - self.output, _, _ = self.net_g(self.input, w=1) - self.net_g.train() - - - def dist_validation(self, dataloader, current_iter, tb_logger, save_img): - if self.opt['rank'] == 0: - self.nondist_validation(dataloader, current_iter, tb_logger, save_img) - - - def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): - dataset_name = dataloader.dataset.opt['name'] - with_metrics = self.opt['val'].get('metrics') is not None - if with_metrics: - self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()} - pbar = tqdm(total=len(dataloader), unit='image') - - for idx, val_data in enumerate(dataloader): - img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0] - self.feed_data(val_data) - self.test() - - visuals = self.get_current_visuals() - sr_img = tensor2img([visuals['result']]) - if 'gt' in visuals: - gt_img = tensor2img([visuals['gt']]) - del self.gt - - # tentative for out of GPU memory - del self.lq - del self.output - torch.cuda.empty_cache() - - if save_img: - if self.opt['is_train']: - save_img_path = osp.join(self.opt['path']['visualization'], img_name, - f'{img_name}_{current_iter}.png') - else: - if self.opt['val']['suffix']: - save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, - f'{img_name}_{self.opt["val"]["suffix"]}.png') - else: - save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, - f'{img_name}_{self.opt["name"]}.png') - imwrite(sr_img, save_img_path) - - if with_metrics: - # calculate metrics - for name, opt_ in self.opt['val']['metrics'].items(): - metric_data = dict(img1=sr_img, img2=gt_img) - self.metric_results[name] += calculate_metric(metric_data, opt_) - pbar.update(1) - pbar.set_description(f'Test {img_name}') - pbar.close() - - if with_metrics: - for metric in self.metric_results.keys(): - self.metric_results[metric] /= (idx + 1) - - self._log_validation_metric_values(current_iter, dataset_name, tb_logger) - - - def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger): - log_str = f'Validation {dataset_name}\n' - for metric, value in self.metric_results.items(): - log_str += f'\t # {metric}: {value:.4f}\n' - logger = get_root_logger() - logger.info(log_str) - if tb_logger: - for metric, value in self.metric_results.items(): - tb_logger.add_scalar(f'metrics/{metric}', value, current_iter) - - - def get_current_visuals(self): - out_dict = OrderedDict() - out_dict['gt'] = self.gt.detach().cpu() - out_dict['result'] = self.output.detach().cpu() - return out_dict - - - def save(self, epoch, current_iter): - if self.ema_decay > 0: - self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema']) - else: - self.save_network(self.net_g, 'net_g', current_iter) - self.save_network(self.net_d, 'net_d', current_iter) - self.save_training_state(epoch, current_iter) diff --git a/basicsr/models/codeformer_model.py b/basicsr/models/codeformer_model.py deleted file mode 100644 index 61829f12..00000000 --- a/basicsr/models/codeformer_model.py +++ /dev/null @@ -1,332 +0,0 @@ -import torch -from collections import OrderedDict -from os import path as osp -from tqdm import tqdm - -from basicsr.archs import build_network -from basicsr.losses import build_loss -from basicsr.metrics import calculate_metric -from 
basicsr.utils import get_root_logger, imwrite, tensor2img -from basicsr.utils.registry import MODEL_REGISTRY -import torch.nn.functional as F -from .sr_model import SRModel - - -@MODEL_REGISTRY.register() -class CodeFormerModel(SRModel): - def feed_data(self, data): - self.gt = data['gt'].to(self.device) - self.input = data['in'].to(self.device) - self.b = self.gt.shape[0] - - if 'latent_gt' in data: - self.idx_gt = data['latent_gt'].to(self.device) - self.idx_gt = self.idx_gt.view(self.b, -1) - else: - self.idx_gt = None - - def init_training_settings(self): - logger = get_root_logger() - train_opt = self.opt['train'] - - self.ema_decay = train_opt.get('ema_decay', 0) - if self.ema_decay > 0: - logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}') - # define network net_g with Exponential Moving Average (EMA) - # net_g_ema is used only for testing on one GPU and saving - # There is no need to wrap with DistributedDataParallel - self.net_g_ema = build_network(self.opt['network_g']).to(self.device) - # load pretrained model - load_path = self.opt['path'].get('pretrain_network_g', None) - if load_path is not None: - self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema') - else: - self.model_ema(0) # copy net_g weight - self.net_g_ema.eval() - - if self.opt.get('network_vqgan', None) is not None and self.opt['datasets'].get('latent_gt_path') is None: - self.hq_vqgan_fix = build_network(self.opt['network_vqgan']).to(self.device) - self.hq_vqgan_fix.eval() - self.generate_idx_gt = True - for param in self.hq_vqgan_fix.parameters(): - param.requires_grad = False - else: - self.generate_idx_gt = False - - self.hq_feat_loss = train_opt.get('use_hq_feat_loss', True) - self.feat_loss_weight = train_opt.get('feat_loss_weight', 1.0) - self.cross_entropy_loss = train_opt.get('cross_entropy_loss', True) - self.entropy_loss_weight = train_opt.get('entropy_loss_weight', 0.5) - self.fidelity_weight = train_opt.get('fidelity_weight', 1.0) - self.scale_adaptive_gan_weight = train_opt.get('scale_adaptive_gan_weight', 0.8) - - - self.net_g.train() - # define network net_d - if self.fidelity_weight > 0: - self.net_d = build_network(self.opt['network_d']) - self.net_d = self.model_to_device(self.net_d) - self.print_network(self.net_d) - - # load pretrained models - load_path = self.opt['path'].get('pretrain_network_d', None) - if load_path is not None: - self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True)) - - self.net_d.train() - - # define losses - if train_opt.get('pixel_opt'): - self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device) - else: - self.cri_pix = None - - if train_opt.get('perceptual_opt'): - self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device) - else: - self.cri_perceptual = None - - if train_opt.get('gan_opt'): - self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device) - - - self.fix_generator = train_opt.get('fix_generator', True) - logger.info(f'fix_generator: {self.fix_generator}') - - self.net_g_start_iter = train_opt.get('net_g_start_iter', 0) - self.net_d_iters = train_opt.get('net_d_iters', 1) - self.net_d_start_iter = train_opt.get('net_d_start_iter', 0) - - # set up optimizers and schedulers - self.setup_optimizers() - self.setup_schedulers() - - def calculate_adaptive_weight(self, recon_loss, g_loss, last_layer, disc_weight_max): - recon_grads = torch.autograd.grad(recon_loss, last_layer, retain_graph=True)[0] - g_grads = 
torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] - - d_weight = torch.norm(recon_grads) / (torch.norm(g_grads) + 1e-4) - d_weight = torch.clamp(d_weight, 0.0, disc_weight_max).detach() - return d_weight - - def setup_optimizers(self): - train_opt = self.opt['train'] - # optimizer g - optim_params_g = [] - for k, v in self.net_g.named_parameters(): - if v.requires_grad: - optim_params_g.append(v) - else: - logger = get_root_logger() - logger.warning(f'Params {k} will not be optimized.') - optim_type = train_opt['optim_g'].pop('type') - self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, **train_opt['optim_g']) - self.optimizers.append(self.optimizer_g) - # optimizer d - if self.fidelity_weight > 0: - optim_type = train_opt['optim_d'].pop('type') - self.optimizer_d = self.get_optimizer(optim_type, self.net_d.parameters(), **train_opt['optim_d']) - self.optimizers.append(self.optimizer_d) - - def gray_resize_for_identity(self, out, size=128): - out_gray = (0.2989 * out[:, 0, :, :] + 0.5870 * out[:, 1, :, :] + 0.1140 * out[:, 2, :, :]) - out_gray = out_gray.unsqueeze(1) - out_gray = F.interpolate(out_gray, (size, size), mode='bilinear', align_corners=False) - return out_gray - - def optimize_parameters(self, current_iter): - logger = get_root_logger() - # optimize net_g - for p in self.net_d.parameters(): - p.requires_grad = False - - self.optimizer_g.zero_grad() - - if self.generate_idx_gt: - x = self.hq_vqgan_fix.encoder(self.gt) - output, _, quant_stats = self.hq_vqgan_fix.quantize(x) - min_encoding_indices = quant_stats['min_encoding_indices'] - self.idx_gt = min_encoding_indices.view(self.b, -1) - - if self.fidelity_weight > 0: - self.output, logits, lq_feat = self.net_g(self.input, w=self.fidelity_weight, detach_16=True) - else: - logits, lq_feat = self.net_g(self.input, w=0, code_only=True) - - if self.hq_feat_loss: - # quant_feats - quant_feat_gt = self.net_g.module.quantize.get_codebook_feat(self.idx_gt, shape=[self.b,16,16,256]) - - l_g_total = 0 - loss_dict = OrderedDict() - if current_iter % self.net_d_iters == 0 and current_iter > self.net_g_start_iter: - # hq_feat_loss - if self.hq_feat_loss: # codebook loss - l_feat_encoder = torch.mean((quant_feat_gt.detach()-lq_feat)**2) * self.feat_loss_weight - l_g_total += l_feat_encoder - loss_dict['l_feat_encoder'] = l_feat_encoder - - # cross_entropy_loss - if self.cross_entropy_loss: - # b(hw)n -> bn(hw) - cross_entropy_loss = F.cross_entropy(logits.permute(0, 2, 1), self.idx_gt) * self.entropy_loss_weight - l_g_total += cross_entropy_loss - loss_dict['cross_entropy_loss'] = cross_entropy_loss - - if self.fidelity_weight > 0: # when fidelity_weight == 0 don't need image-level loss - # pixel loss - if self.cri_pix: - l_g_pix = self.cri_pix(self.output, self.gt) - l_g_total += l_g_pix - loss_dict['l_g_pix'] = l_g_pix - - # perceptual loss - if self.cri_perceptual: - l_g_percep = self.cri_perceptual(self.output, self.gt) - l_g_total += l_g_percep - loss_dict['l_g_percep'] = l_g_percep - - # gan loss - if current_iter > self.net_d_start_iter: - fake_g_pred = self.net_d(self.output) - l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False) - recon_loss = l_g_pix + l_g_percep - if not self.fix_generator: - last_layer = self.net_g.module.generator.blocks[-1].weight - d_weight = self.calculate_adaptive_weight(recon_loss, l_g_gan, last_layer, disc_weight_max=1.0) - else: - largest_fuse_size = self.opt['network_g']['connect_list'][-1] - last_layer = self.net_g.module.fuse_convs_dict[largest_fuse_size].shift[-1].weight 
- d_weight = self.calculate_adaptive_weight(recon_loss, l_g_gan, last_layer, disc_weight_max=1.0) - - d_weight *= self.scale_adaptive_gan_weight # 0.8 - loss_dict['d_weight'] = d_weight - l_g_total += d_weight * l_g_gan - loss_dict['l_g_gan'] = d_weight * l_g_gan - - l_g_total.backward() - self.optimizer_g.step() - - if self.ema_decay > 0: - self.model_ema(decay=self.ema_decay) - - # optimize net_d - if current_iter > self.net_d_start_iter and self.fidelity_weight > 0: - for p in self.net_d.parameters(): - p.requires_grad = True - - self.optimizer_d.zero_grad() - # real - real_d_pred = self.net_d(self.gt) - l_d_real = self.cri_gan(real_d_pred, True, is_disc=True) - loss_dict['l_d_real'] = l_d_real - loss_dict['out_d_real'] = torch.mean(real_d_pred.detach()) - l_d_real.backward() - # fake - fake_d_pred = self.net_d(self.output.detach()) - l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True) - loss_dict['l_d_fake'] = l_d_fake - loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach()) - l_d_fake.backward() - - self.optimizer_d.step() - - self.log_dict = self.reduce_loss_dict(loss_dict) - - - def test(self): - with torch.no_grad(): - if hasattr(self, 'net_g_ema'): - self.net_g_ema.eval() - self.output, _, _ = self.net_g_ema(self.input, w=self.fidelity_weight) - else: - logger = get_root_logger() - logger.warning('Do not have self.net_g_ema, use self.net_g.') - self.net_g.eval() - self.output, _, _ = self.net_g(self.input, w=self.fidelity_weight) - self.net_g.train() - - - def dist_validation(self, dataloader, current_iter, tb_logger, save_img): - if self.opt['rank'] == 0: - self.nondist_validation(dataloader, current_iter, tb_logger, save_img) - - - def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): - dataset_name = dataloader.dataset.opt['name'] - with_metrics = self.opt['val'].get('metrics') is not None - if with_metrics: - self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()} - pbar = tqdm(total=len(dataloader), unit='image') - - for idx, val_data in enumerate(dataloader): - img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0] - self.feed_data(val_data) - self.test() - - visuals = self.get_current_visuals() - sr_img = tensor2img([visuals['result']]) - if 'gt' in visuals: - gt_img = tensor2img([visuals['gt']]) - del self.gt - - # tentative for out of GPU memory - del self.lq - del self.output - torch.cuda.empty_cache() - - if save_img: - if self.opt['is_train']: - save_img_path = osp.join(self.opt['path']['visualization'], img_name, - f'{img_name}_{current_iter}.png') - else: - if self.opt['val']['suffix']: - save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, - f'{img_name}_{self.opt["val"]["suffix"]}.png') - else: - save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, - f'{img_name}_{self.opt["name"]}.png') - imwrite(sr_img, save_img_path) - - if with_metrics: - # calculate metrics - for name, opt_ in self.opt['val']['metrics'].items(): - metric_data = dict(img1=sr_img, img2=gt_img) - self.metric_results[name] += calculate_metric(metric_data, opt_) - pbar.update(1) - pbar.set_description(f'Test {img_name}') - pbar.close() - - if with_metrics: - for metric in self.metric_results.keys(): - self.metric_results[metric] /= (idx + 1) - - self._log_validation_metric_values(current_iter, dataset_name, tb_logger) - - - def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger): - log_str = f'Validation {dataset_name}\n' - for metric, value in 
self.metric_results.items(): - log_str += f'\t # {metric}: {value:.4f}\n' - logger = get_root_logger() - logger.info(log_str) - if tb_logger: - for metric, value in self.metric_results.items(): - tb_logger.add_scalar(f'metrics/{metric}', value, current_iter) - - - def get_current_visuals(self): - out_dict = OrderedDict() - out_dict['gt'] = self.gt.detach().cpu() - out_dict['result'] = self.output.detach().cpu() - return out_dict - - - def save(self, epoch, current_iter): - if self.ema_decay > 0: - self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema']) - else: - self.save_network(self.net_g, 'net_g', current_iter) - if self.fidelity_weight > 0: - self.save_network(self.net_d, 'net_d', current_iter) - self.save_training_state(epoch, current_iter) diff --git a/basicsr/models/codeformer_idx_model.py b/basicsr/models/codeformer_temporal_model.py old mode 100644 new mode 100755 similarity index 90% rename from basicsr/models/codeformer_idx_model.py rename to basicsr/models/codeformer_temporal_model.py index 005957dd..44ff2d96 --- a/basicsr/models/codeformer_idx_model.py +++ b/basicsr/models/codeformer_temporal_model.py @@ -2,6 +2,7 @@ from collections import OrderedDict from os import path as osp from tqdm import tqdm +from einops import rearrange from basicsr.archs import build_network from basicsr.metrics import calculate_metric @@ -10,13 +11,16 @@ import torch.nn.functional as F from .sr_model import SRModel +# from icecream import ic @MODEL_REGISTRY.register() -class CodeFormerIdxModel(SRModel): +class CodeFormerTempModel(SRModel): def feed_data(self, data): self.gt = data['gt'].to(self.device) self.input = data['in'].to(self.device) self.b = self.gt.shape[0] + self.f = self.gt.shape[1] + self.bf = self.b * self.f if 'latent_gt' in data: self.idx_gt = data['latent_gt'].to(self.device) @@ -62,6 +66,13 @@ def init_training_settings(self): self.entropy_loss_weight = train_opt.get('entropy_loss_weight', 0.5) self.net_g.train() + self.net_g.requires_grad_(False) + + trainable_module = train_opt['trainable_para'] + for name, module in self.net_g.named_modules(): + if trainable_module in name : + for params in module.parameters(): + params.requires_grad_(True) # set up optimizers and schedulers self.setup_optimizers() @@ -72,16 +83,18 @@ def setup_optimizers(self): train_opt = self.opt['train'] # optimizer g optim_params_g = [] + optim_name = [] for k, v in self.net_g.named_parameters(): if v.requires_grad: optim_params_g.append(v) - else: - logger = get_root_logger() - logger.warning(f'Params {k} will not be optimized.') + optim_name.append(k) + optim_type = train_opt['optim_g'].pop('type') self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, **train_opt['optim_g']) self.optimizers.append(self.optimizer_g) + # ic(optim_name) + def optimize_parameters(self, current_iter): logger = get_root_logger() @@ -89,16 +102,21 @@ def optimize_parameters(self, current_iter): self.optimizer_g.zero_grad() if self.generate_idx_gt: - x = self.hq_vqgan_fix.encoder(self.gt) + x = rearrange(self.gt, "b f c h w -> (b f) c h w") + x = self.hq_vqgan_fix.encoder(x) _, _, quant_stats = self.hq_vqgan_fix.quantize(x) min_encoding_indices = quant_stats['min_encoding_indices'] - self.idx_gt = min_encoding_indices.view(self.b, -1) + # ic(min_encoding_indices.shape) + self.idx_gt = min_encoding_indices.view(self.bf, -1) + # ic(self.idx_gt.shape) if self.hq_feat_loss: # quant_feats - quant_feat_gt = self.net_g.module.quantize.get_codebook_feat(self.idx_gt, 
shape=[self.b,16,16,256]) + quant_feat_gt = self.net_g.module.quantize.get_codebook_feat(self.idx_gt, shape=[self.bf,16,16,256]) logits, lq_feat = self.net_g(self.input, w=0, code_only=True) + # ic(logits.shape) + # ic(lq_feat.shape) l_g_total = 0 loss_dict = OrderedDict() diff --git a/basicsr/models/lr_scheduler.py b/basicsr/models/lr_scheduler.py old mode 100644 new mode 100755 diff --git a/basicsr/models/sr_model.py b/basicsr/models/sr_model.py old mode 100644 new mode 100755 diff --git a/basicsr/models/vqgan_model.py b/basicsr/models/vqgan_model.py old mode 100644 new mode 100755 diff --git a/basicsr/ops/__init__.py b/basicsr/ops/__init__.py old mode 100644 new mode 100755 diff --git a/basicsr/ops/dcn/__init__.py b/basicsr/ops/dcn/__init__.py old mode 100644 new mode 100755 diff --git a/basicsr/ops/dcn/deform_conv.py b/basicsr/ops/dcn/deform_conv.py old mode 100644 new mode 100755 diff --git a/basicsr/ops/dcn/src/deform_conv_cuda.cpp b/basicsr/ops/dcn/src/deform_conv_cuda.cpp old mode 100644 new mode 100755 diff --git a/basicsr/ops/dcn/src/deform_conv_cuda_kernel.cu b/basicsr/ops/dcn/src/deform_conv_cuda_kernel.cu old mode 100644 new mode 100755 diff --git a/basicsr/ops/dcn/src/deform_conv_ext.cpp b/basicsr/ops/dcn/src/deform_conv_ext.cpp old mode 100644 new mode 100755 diff --git a/basicsr/ops/fused_act/__init__.py b/basicsr/ops/fused_act/__init__.py old mode 100644 new mode 100755 diff --git a/basicsr/ops/fused_act/fused_act.py b/basicsr/ops/fused_act/fused_act.py old mode 100644 new mode 100755 diff --git a/basicsr/ops/fused_act/src/fused_bias_act.cpp b/basicsr/ops/fused_act/src/fused_bias_act.cpp old mode 100644 new mode 100755 diff --git a/basicsr/ops/fused_act/src/fused_bias_act_kernel.cu b/basicsr/ops/fused_act/src/fused_bias_act_kernel.cu old mode 100644 new mode 100755 diff --git a/basicsr/ops/upfirdn2d/__init__.py b/basicsr/ops/upfirdn2d/__init__.py old mode 100644 new mode 100755 diff --git a/basicsr/ops/upfirdn2d/src/upfirdn2d.cpp b/basicsr/ops/upfirdn2d/src/upfirdn2d.cpp old mode 100644 new mode 100755 diff --git a/basicsr/ops/upfirdn2d/src/upfirdn2d_kernel.cu b/basicsr/ops/upfirdn2d/src/upfirdn2d_kernel.cu old mode 100644 new mode 100755 diff --git a/basicsr/ops/upfirdn2d/upfirdn2d.py b/basicsr/ops/upfirdn2d/upfirdn2d.py old mode 100644 new mode 100755 diff --git a/basicsr/setup.py b/basicsr/setup.py old mode 100644 new mode 100755 diff --git a/basicsr/train.py b/basicsr/train.py old mode 100644 new mode 100755 index a01c0dfc..84fe0a28 --- a/basicsr/train.py +++ b/basicsr/train.py @@ -25,7 +25,7 @@ def parse_options(root_path, is_train=True): parser = argparse.ArgumentParser() parser.add_argument('-opt', type=str, required=True, help='Path to option YAML file.') parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none', help='job launcher') - parser.add_argument('--local_rank', type=int, default=0) + parser.add_argument('--local-rank','--local_rank', type=int, default=0) args = parser.parse_args() opt = parse(args.opt, root_path, is_train=is_train) diff --git a/basicsr/utils/__init__.py b/basicsr/utils/__init__.py old mode 100644 new mode 100755 diff --git a/basicsr/utils/dist_util.py b/basicsr/utils/dist_util.py old mode 100644 new mode 100755 diff --git a/basicsr/utils/download_util.py b/basicsr/utils/download_util.py old mode 100644 new mode 100755 diff --git a/basicsr/utils/file_client.py b/basicsr/utils/file_client.py old mode 100644 new mode 100755 diff --git a/basicsr/utils/img_util.py b/basicsr/utils/img_util.py old mode 100644 new 
mode 100755 diff --git a/basicsr/utils/lmdb_util.py b/basicsr/utils/lmdb_util.py old mode 100644 new mode 100755 diff --git a/basicsr/utils/logger.py b/basicsr/utils/logger.py old mode 100644 new mode 100755 diff --git a/basicsr/utils/matlab_functions.py b/basicsr/utils/matlab_functions.py old mode 100644 new mode 100755 diff --git a/basicsr/utils/misc.py b/basicsr/utils/misc.py old mode 100644 new mode 100755 diff --git a/basicsr/utils/options.py b/basicsr/utils/options.py old mode 100644 new mode 100755 diff --git a/basicsr/utils/realesrgan_utils.py b/basicsr/utils/realesrgan_utils.py old mode 100644 new mode 100755 diff --git a/basicsr/utils/registry.py b/basicsr/utils/registry.py old mode 100644 new mode 100755 diff --git a/basicsr/utils/video_util.py b/basicsr/utils/video_util.py old mode 100644 new mode 100755 diff --git a/options/CodeFormer_stage2.yml b/configs/train/video_sr.yaml old mode 100755 new mode 100644 similarity index 83% rename from options/CodeFormer_stage2.yml rename to configs/train/video_sr.yaml index 4dfe9c9d..28ba65db --- a/options/CodeFormer_stage2.yml +++ b/configs/train/video_sr.yaml @@ -1,15 +1,15 @@ # general settings -name: CodeFormer_stage2 -model_type: CodeFormerIdxModel +name: CodeFormer_temp +model_type: CodeFormerTempModel num_gpu: 8 manual_seed: 0 # dataset and data loader settings datasets: train: - name: FFHQ - type: FFHQBlindDataset - dataroot_gt: datasets/ffhq/ffhq_512 + name: VFHQ + type: VFHQBlindDataset + dataroot_gt: ./VFHQ/image filename_tmpl: '{}' io_backend: type: disk @@ -20,6 +20,7 @@ datasets: std: [0.5, 0.5, 0.5] use_hflip: true use_corrupt: true + video_length: 16 # large degradation in stageII blur_kernel_size: 41 @@ -33,12 +34,11 @@ datasets: jpeg_range: [30, 80] latent_gt_path: ~ # without pre-calculated latent code - # latent_gt_path: './experiments/pretrained_models/VQGAN/latent_gt_code1024.pth' # data loader - num_worker_per_gpu: 2 + num_worker_per_gpu: 8 batch_size_per_gpu: 4 - dataset_enlarge_ratio: 100 + dataset_enlarge_ratio: 1 prefetch_mode: ~ # val: @@ -61,7 +61,7 @@ network_g: codebook_size: 1024 connect_list: ['32', '64', '128', '256'] fix_modules: ['quantize','generator'] - vqgan_path: './experiments/pretrained_models/vqgan/vqgan_code1024.pth' # pretrained VQGAN + vqgan_path: './pretrained_models/CodeFormer/vqgan_code1024.pth' # pretrained VQGAN network_vqgan: # this config is needed if no pre-calculated latent type: VQAutoEncoder @@ -70,10 +70,11 @@ network_vqgan: # this config is needed if no pre-calculated latent ch_mult: [1, 2, 2, 4, 4, 8] quantizer: 'nearest' codebook_size: 1024 + model_path: './pretrained_models/CodeFormer/vqgan_code1024.pth' # path path: - pretrain_network_g: ~ + pretrain_network_g: './pretrained_models/CodeFormer/codeformer.pth' param_key_g: params_ema strict_load_g: false pretrain_network_d: ~ @@ -88,6 +89,8 @@ train: entropy_loss_weight: 0.5 fidelity_weight: 0 + trainable_para: temp + optim_g: type: Adam lr: !!float 1e-4 @@ -119,7 +122,7 @@ train: # validation settings val: - val_freq: !!float 5e10 # no validation + val_freq: 1000 save_img: true metrics: @@ -130,8 +133,8 @@ val: # logging settings logger: - print_freq: 100 - save_checkpoint_freq: !!float 1e4 + print_freq: 1 + save_checkpoint_freq: 1000 use_tb_logger: true wandb: project: ~ diff --git a/docs/history_changelog.md b/docs/history_changelog.md deleted file mode 100644 index 6c35e34e..00000000 --- a/docs/history_changelog.md +++ /dev/null @@ -1,15 +0,0 @@ -# History of Changelog - -- **2023.04.19**: :whale: Training codes and config 
files are public available now. -- **2023.04.09**: Add features of inpainting and colorization for cropped face images. -- **2023.02.10**: Include `dlib` as a new face detector option, it produces more accurate face identity. -- **2022.10.05**: Support video input `--input_path [YOUR_VIDEO.mp4]`. Try it to enhance your videos! :clapper: -- **2022.09.14**: Integrated to :hugs: [Hugging Face](https://huggingface.co/spaces). Try out online demo! [![Hugging Face](https://img.shields.io/badge/Demo-%F0%9F%A4%97%20Hugging%20Face-blue)](https://huggingface.co/spaces/sczhou/CodeFormer) -- **2022.09.09**: Integrated to :rocket: [Replicate](https://replicate.com/explore). Try out online demo! [![Replicate](https://img.shields.io/badge/Demo-%F0%9F%9A%80%20Replicate-blue)](https://replicate.com/sczhou/codeformer) -- **2022.09.04**: Add face upsampling `--face_upsample` for high-resolution AI-created face enhancement. -- **2022.08.23**: Some modifications on face detection and fusion for better AI-created face enhancement. -- **2022.08.07**: Integrate [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN) to support background image enhancement. -- **2022.07.29**: Integrate new face detectors of `['RetinaFace'(default), 'YOLOv5']`. -- **2022.07.17**: Add Colab demo of CodeFormer. google colab logo -- **2022.07.16**: Release inference code for face restoration. :blush: -- **2022.06.21**: This repo is created. \ No newline at end of file diff --git a/docs/train.md b/docs/train.md deleted file mode 100644 index 873ee7cc..00000000 --- a/docs/train.md +++ /dev/null @@ -1,37 +0,0 @@ -# :milky_way: Training Procedures -[English](train.md) **|** [įŽ€äŊ“中文](train_CN.md) -## Preparing Dataset - -- Download training dataset: [FFHQ](https://github.com/NVlabs/ffhq-dataset) - ---- - -## Training -``` -For PyTorch versions >= 1.10, please replace `python -m torch.distributed.launch` in the commands below with `torchrun`. -``` - -### 👾 Stage I - VQGAN -- Training VQGAN: - > python -m torch.distributed.launch --nproc_per_node=gpu_num --master_port=4321 basicsr/train.py -opt options/VQGAN_512_ds32_nearest_stage1.yml --launcher pytorch - -- After VQGAN training, you can pre-calculate code sequence for the training dataset to speed up the later training stages: - > python scripts/generate_latent_gt.py - -- If you don't require training your own VQGAN, you can find pre-trained VQGAN (`vqgan_code1024.pth`) and the corresponding code sequence (`latent_gt_code1024.pth`) in the folder of Releases v0.1.0: https://github.com/sczhou/CodeFormer/releases/tag/v0.1.0 - -### 🚀 Stage II - CodeFormer (w=0) -- Training Code Sequence Prediction Module: - > python -m torch.distributed.launch --nproc_per_node=gpu_num --master_port=4322 basicsr/train.py -opt options/CodeFormer_stage2.yml --launcher pytorch - -- Pre-trained CodeFormer of stage II (`codeformer_stage2.pth`) can be found in the folder of Releases v0.1.0: https://github.com/sczhou/CodeFormer/releases/tag/v0.1.0 - -### 🛸 Stage III - CodeFormer (w=1) -- Training Controllable Module: - > python -m torch.distributed.launch --nproc_per_node=gpu_num --master_port=4323 basicsr/train.py -opt options/CodeFormer_stage3.yml --launcher pytorch - -- Pre-trained CodeFormer (`codeformer.pth`) can be found in the folder of Releases v0.1.0: https://github.com/sczhou/CodeFormer/releases/tag/v0.1.0 - ---- - -:whale: The project was built using the framework [BasicSR](https://github.com/XPixelGroup/BasicSR). 
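Editorial note: the training doc removed above mentions pre-calculating the code sequence for the dataset (the `latent_gt_path` option) via `scripts/generate_latent_gt.py`, while the model code in this patch instead falls back to computing it on the fly through a frozen VQGAN (`generate_idx_gt`). A minimal hedged sketch of that pre-computation, assuming a `VQAutoEncoder`-style model exposing `encoder` and `quantize` exactly as used in the model code; the real script's interface is not shown in this patch:

```python
# Hedged sketch: a simplified stand-in for the latent-GT pre-computation, not the
# actual scripts/generate_latent_gt.py implementation.
import torch

@torch.no_grad()
def precompute_latent_gt(vqgan, gt_images):
    """gt_images: (N, 3, 512, 512) tensor, normalized like the training data."""
    vqgan.eval()
    x = vqgan.encoder(gt_images)               # encode to the latent feature grid
    _, _, quant_stats = vqgan.quantize(x)      # nearest-codebook lookup
    idx = quant_stats['min_encoding_indices']  # one codebook index per latent position
    return idx.view(gt_images.shape[0], -1)    # flattened code sequence per image
```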
For detailed information on training, resuming, and other related topics, please refer to the documentation: https://github.com/XPixelGroup/BasicSR/blob/master/docs/TrainTest.md diff --git a/docs/train_CN.md b/docs/train_CN.md deleted file mode 100644 index c1ac2cc4..00000000 --- a/docs/train_CN.md +++ /dev/null @@ -1,37 +0,0 @@ -# :milky_way: 莭įģƒæ–‡æĄŖ -[English](train.md) **|** [įŽ€äŊ“中文](train_CN.md) - -## 准备数捎集 -- 下čŊŊ莭įģƒæ•°æŽé›†: [FFHQ](https://github.com/NVlabs/ffhq-dataset) - ---- - -## 莭įģƒ -``` -寚äēŽPyTorchį‰ˆæœŦ >= 1.10, č¯ˇå°†ä¸‹éĸå‘Ŋäģ¤ä¸­įš„`python -m torch.distributed.launch`æ›ŋæĸä¸ē`torchrun`. -``` - -### 👾 é˜ļæŽĩ I - VQGAN -- 莭įģƒVQGAN: - > python -m torch.distributed.launch --nproc_per_node=gpu_num --master_port=4321 basicsr/train.py -opt options/VQGAN_512_ds32_nearest_stage1.yml --launcher pytorch - -- 莭įģƒåŽŒVQGAN后īŧŒå¯äģĨ通čŋ‡ä¸‹éĸäģŖį éĸ„å…ˆčŽˇåž—čŽ­įģƒæ•°æŽé›†įš„密į æœŦåēåˆ—īŧŒäģŽč€ŒåŠ é€ŸåŽéĸé˜ļæŽĩįš„莭įģƒčŋ‡į¨‹: - > python scripts/generate_latent_gt.py - -- åĻ‚æžœäŊ ä¸éœ€čĻčŽ­įģƒč‡Ēåˇąįš„VQGANīŧŒå¯äģĨ在Release v0.1.0文æĄŖ中扞到éĸ„莭įģƒįš„VQGAN (`vqgan_code1024.pth`)和寚åē”įš„密į æœŦåēåˆ— (`latent_gt_code1024.pth`): https://github.com/sczhou/CodeFormer/releases/tag/v0.1.0 - -### 🚀 é˜ļæŽĩ II - CodeFormer (w=0) -- 莭įģƒå¯†į æœŦ莭įģƒéĸ„æĩ‹æ¨Ąå—: - > python -m torch.distributed.launch --nproc_per_node=gpu_num --master_port=4322 basicsr/train.py -opt options/CodeFormer_stage2.yml --launcher pytorch - -- éĸ„莭įģƒCodeFormerįŦŦäēŒé˜ļæŽĩæ¨Ąåž‹ (`codeformer_stage2.pth`)可äģĨ在Releases v0.1.0文æĄŖ里下čŊŊ: https://github.com/sczhou/CodeFormer/releases/tag/v0.1.0 - -### 🛸 é˜ļæŽĩ III - CodeFormer (w=1) -- 莭įģƒå¯č°ƒæ¨Ąå—: - > python -m torch.distributed.launch --nproc_per_node=gpu_num --master_port=4323 basicsr/train.py -opt options/CodeFormer_stage3.yml --launcher pytorch - -- éĸ„莭įģƒCodeFormeræ¨Ąåž‹ (`codeformer.pth`)可äģĨ在Releases v0.1.0文æĄŖ里下čŊŊ: https://github.com/sczhou/CodeFormer/releases/tag/v0.1.0 - ---- - -:whale: č¯Ĩ饚į›Žæ˜¯åŸēäēŽ[BasicSR](https://github.com/XPixelGroup/BasicSR)æĄ†æžļ搭åģēīŧŒæœ‰å…ŗ莭įģƒã€Resumeį­‰č¯Ļįģ†äģ‹įģå¯äģĨæŸĨįœ‹æ–‡æĄŖ: https://github.com/XPixelGroup/BasicSR/blob/master/docs/TrainTest_CN.md \ No newline at end of file diff --git a/inference_colorization.py b/inference_colorization.py deleted file mode 100644 index 0f1b763c..00000000 --- a/inference_colorization.py +++ /dev/null @@ -1,86 +0,0 @@ -import os -import cv2 -import argparse -import glob -import torch -from torchvision.transforms.functional import normalize -from basicsr.utils import imwrite, img2tensor, tensor2img -from basicsr.utils.download_util import load_file_from_url -from basicsr.utils.misc import get_device -from basicsr.utils.registry import ARCH_REGISTRY - -pretrain_model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer_colorization.pth' - -if __name__ == '__main__': - # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - device = get_device() - parser = argparse.ArgumentParser() - - parser.add_argument('-i', '--input_path', type=str, default='./inputs/gray_faces', - help='Input image or folder. Default: inputs/gray_faces') - parser.add_argument('-o', '--output_path', type=str, default=None, - help='Output folder. Default: results/') - parser.add_argument('--suffix', type=str, default=None, - help='Suffix of the restored faces. 
Default: None') - args = parser.parse_args() - - # ------------------------ input & output ------------------------ - print('[NOTE] The input face images should be aligned and cropped to a resolution of 512x512.') - if args.input_path.endswith(('jpg', 'jpeg', 'png', 'JPG', 'JPEG', 'PNG')): # input single img path - input_img_list = [args.input_path] - result_root = f'results/test_colorization_img' - else: # input img folder - if args.input_path.endswith('/'): # solve when path ends with / - args.input_path = args.input_path[:-1] - # scan all the jpg and png images - input_img_list = sorted(glob.glob(os.path.join(args.input_path, '*.[jpJP][pnPN]*[gG]'))) - result_root = f'results/{os.path.basename(args.input_path)}' - - if not args.output_path is None: # set output path - result_root = args.output_path - - test_img_num = len(input_img_list) - - # ------------------ set up CodeFormer restorer ------------------- - net = ARCH_REGISTRY.get('CodeFormer')(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, - connect_list=['32', '64', '128']).to(device) - - # ckpt_path = 'weights/CodeFormer/codeformer.pth' - ckpt_path = load_file_from_url(url=pretrain_model_url, - model_dir='weights/CodeFormer', progress=True, file_name=None) - checkpoint = torch.load(ckpt_path)['params_ema'] - net.load_state_dict(checkpoint) - net.eval() - - # -------------------- start to processing --------------------- - for i, img_path in enumerate(input_img_list): - img_name = os.path.basename(img_path) - basename, ext = os.path.splitext(img_name) - print(f'[{i+1}/{test_img_num}] Processing: {img_name}') - input_face = cv2.imread(img_path) - assert input_face.shape[:2] == (512, 512), 'Input resolution must be 512x512 for colorization.' - # input_face = cv2.resize(input_face, (512, 512), interpolation=cv2.INTER_LINEAR) - input_face = img2tensor(input_face / 255., bgr2rgb=True, float32=True) - normalize(input_face, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) - input_face = input_face.unsqueeze(0).to(device) - try: - with torch.no_grad(): - # w is fixed to 0 since we didn't train the Stage III for colorization - output_face = net(input_face, w=0, adain=True)[0] - save_face = tensor2img(output_face, rgb2bgr=True, min_max=(-1, 1)) - del output_face - torch.cuda.empty_cache() - except Exception as error: - print(f'\tFailed inference for CodeFormer: {error}') - save_face = tensor2img(input_face, rgb2bgr=True, min_max=(-1, 1)) - - save_face = save_face.astype('uint8') - - # save face - if args.suffix is not None: - basename = f'{basename}_{args.suffix}' - save_restore_path = os.path.join(result_root, f'{basename}.png') - imwrite(save_face, save_restore_path) - - print(f'\nAll results are saved in {result_root}') - diff --git a/inference_inpainting.py b/inference_inpainting.py deleted file mode 100644 index 9cbfb69a..00000000 --- a/inference_inpainting.py +++ /dev/null @@ -1,91 +0,0 @@ -import os -import cv2 -import argparse -import glob -import torch -from torchvision.transforms.functional import normalize -from basicsr.utils import imwrite, img2tensor, tensor2img -from basicsr.utils.download_util import load_file_from_url -from basicsr.utils.misc import get_device -from basicsr.utils.registry import ARCH_REGISTRY - -pretrain_model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer_inpainting.pth' - -if __name__ == '__main__': - # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - device = get_device() - parser = argparse.ArgumentParser() - - parser.add_argument('-i', 
'--input_path', type=str, default='./inputs/masked_faces', - help='Input image or folder. Default: inputs/masked_faces') - parser.add_argument('-o', '--output_path', type=str, default=None, - help='Output folder. Default: results/') - parser.add_argument('--suffix', type=str, default=None, - help='Suffix of the restored faces. Default: None') - args = parser.parse_args() - - # ------------------------ input & output ------------------------ - print('[NOTE] The input face images should be aligned and cropped to a resolution of 512x512.') - if args.input_path.endswith(('jpg', 'jpeg', 'png', 'JPG', 'JPEG', 'PNG')): # input single img path - input_img_list = [args.input_path] - result_root = f'results/test_inpainting_img' - else: # input img folder - if args.input_path.endswith('/'): # solve when path ends with / - args.input_path = args.input_path[:-1] - # scan all the jpg and png images - input_img_list = sorted(glob.glob(os.path.join(args.input_path, '*.[jpJP][pnPN]*[gG]'))) - result_root = f'results/{os.path.basename(args.input_path)}' - - if not args.output_path is None: # set output path - result_root = args.output_path - - test_img_num = len(input_img_list) - - # ------------------ set up CodeFormer restorer ------------------- - net = ARCH_REGISTRY.get('CodeFormer')(dim_embd=512, codebook_size=512, n_head=8, n_layers=9, - connect_list=['32', '64', '128']).to(device) - - # ckpt_path = 'weights/CodeFormer/codeformer.pth' - ckpt_path = load_file_from_url(url=pretrain_model_url, - model_dir='weights/CodeFormer', progress=True, file_name=None) - checkpoint = torch.load(ckpt_path)['params_ema'] - net.load_state_dict(checkpoint) - net.eval() - - # -------------------- start to processing --------------------- - for i, img_path in enumerate(input_img_list): - img_name = os.path.basename(img_path) - basename, ext = os.path.splitext(img_name) - print(f'[{i+1}/{test_img_num}] Processing: {img_name}') - input_face = cv2.imread(img_path) - assert input_face.shape[:2] == (512, 512), 'Input resolution must be 512x512 for inpainting.' 
- # input_face = cv2.resize(input_face, (512, 512), interpolation=cv2.INTER_LINEAR) - input_face = img2tensor(input_face / 255., bgr2rgb=True, float32=True) - normalize(input_face, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) - input_face = input_face.unsqueeze(0).to(device) - try: - with torch.no_grad(): - mask = torch.zeros(512, 512) - m_ind = torch.sum(input_face[0], dim=0) - mask[m_ind==3] = 1.0 - mask = mask.view(1, 1, 512, 512).to(device) - # w is fixed to 1, adain=False for inpainting - output_face = net(input_face, w=1, adain=False)[0] - output_face = (1-mask)*input_face + mask*output_face - save_face = tensor2img(output_face, rgb2bgr=True, min_max=(-1, 1)) - del output_face - torch.cuda.empty_cache() - except Exception as error: - print(f'\tFailed inference for CodeFormer: {error}') - save_face = tensor2img(input_face, rgb2bgr=True, min_max=(-1, 1)) - - save_face = save_face.astype('uint8') - - # save face - if args.suffix is not None: - basename = f'{basename}_{args.suffix}' - save_restore_path = os.path.join(result_root, f'{basename}.png') - imwrite(save_face, save_restore_path) - - print(f'\nAll results are saved in {result_root}') - diff --git a/inputs/cropped_faces/0143.png b/inputs/cropped_faces/0143.png deleted file mode 100644 index 065b3f00..00000000 Binary files a/inputs/cropped_faces/0143.png and /dev/null differ diff --git a/inputs/cropped_faces/0240.png b/inputs/cropped_faces/0240.png deleted file mode 100644 index 7a117017..00000000 Binary files a/inputs/cropped_faces/0240.png and /dev/null differ diff --git a/inputs/cropped_faces/0342.png b/inputs/cropped_faces/0342.png deleted file mode 100644 index 8f5aeeae..00000000 Binary files a/inputs/cropped_faces/0342.png and /dev/null differ diff --git a/inputs/cropped_faces/0345.png b/inputs/cropped_faces/0345.png deleted file mode 100644 index b8f71b6d..00000000 Binary files a/inputs/cropped_faces/0345.png and /dev/null differ diff --git a/inputs/cropped_faces/0368.png b/inputs/cropped_faces/0368.png deleted file mode 100644 index 262778a9..00000000 Binary files a/inputs/cropped_faces/0368.png and /dev/null differ diff --git a/inputs/cropped_faces/0412.png b/inputs/cropped_faces/0412.png deleted file mode 100644 index c4a63b13..00000000 Binary files a/inputs/cropped_faces/0412.png and /dev/null differ diff --git a/inputs/cropped_faces/0444.png b/inputs/cropped_faces/0444.png deleted file mode 100644 index 9028dd0e..00000000 Binary files a/inputs/cropped_faces/0444.png and /dev/null differ diff --git a/inputs/cropped_faces/0478.png b/inputs/cropped_faces/0478.png deleted file mode 100644 index f0924061..00000000 Binary files a/inputs/cropped_faces/0478.png and /dev/null differ diff --git a/inputs/cropped_faces/0500.png b/inputs/cropped_faces/0500.png deleted file mode 100644 index 7a7146b4..00000000 Binary files a/inputs/cropped_faces/0500.png and /dev/null differ diff --git a/inputs/cropped_faces/0599.png b/inputs/cropped_faces/0599.png deleted file mode 100644 index ff26ccda..00000000 Binary files a/inputs/cropped_faces/0599.png and /dev/null differ diff --git a/inputs/cropped_faces/0717.png b/inputs/cropped_faces/0717.png deleted file mode 100644 index 9342b5e5..00000000 Binary files a/inputs/cropped_faces/0717.png and /dev/null differ diff --git a/inputs/cropped_faces/0720.png b/inputs/cropped_faces/0720.png deleted file mode 100644 index af384dce..00000000 Binary files a/inputs/cropped_faces/0720.png and /dev/null differ diff --git a/inputs/cropped_faces/0729.png b/inputs/cropped_faces/0729.png deleted file mode 
100644 index 4f70f46e..00000000 Binary files a/inputs/cropped_faces/0729.png and /dev/null differ diff --git a/inputs/cropped_faces/0763.png b/inputs/cropped_faces/0763.png deleted file mode 100644 index 1263df7b..00000000 Binary files a/inputs/cropped_faces/0763.png and /dev/null differ diff --git a/inputs/cropped_faces/0770.png b/inputs/cropped_faces/0770.png deleted file mode 100644 index 40a64e83..00000000 Binary files a/inputs/cropped_faces/0770.png and /dev/null differ diff --git a/inputs/cropped_faces/0777.png b/inputs/cropped_faces/0777.png deleted file mode 100644 index c72cb26f..00000000 Binary files a/inputs/cropped_faces/0777.png and /dev/null differ diff --git a/inputs/cropped_faces/0885.png b/inputs/cropped_faces/0885.png deleted file mode 100644 index f3ea2632..00000000 Binary files a/inputs/cropped_faces/0885.png and /dev/null differ diff --git a/inputs/cropped_faces/0934.png b/inputs/cropped_faces/0934.png deleted file mode 100644 index bf82c2d3..00000000 Binary files a/inputs/cropped_faces/0934.png and /dev/null differ diff --git a/inputs/cropped_faces/Solvay_conference_1927_0018.png b/inputs/cropped_faces/Solvay_conference_1927_0018.png deleted file mode 100644 index 0f79547a..00000000 Binary files a/inputs/cropped_faces/Solvay_conference_1927_0018.png and /dev/null differ diff --git a/inputs/cropped_faces/Solvay_conference_1927_2_16.png b/inputs/cropped_faces/Solvay_conference_1927_2_16.png deleted file mode 100644 index f75b9602..00000000 Binary files a/inputs/cropped_faces/Solvay_conference_1927_2_16.png and /dev/null differ diff --git a/inputs/gray_faces/067_David_Beckham_00.png b/inputs/gray_faces/067_David_Beckham_00.png deleted file mode 100644 index 69dc41a6..00000000 Binary files a/inputs/gray_faces/067_David_Beckham_00.png and /dev/null differ diff --git a/inputs/gray_faces/089_Miley_Cyrus_00.png b/inputs/gray_faces/089_Miley_Cyrus_00.png deleted file mode 100644 index 29f56936..00000000 Binary files a/inputs/gray_faces/089_Miley_Cyrus_00.png and /dev/null differ diff --git a/inputs/gray_faces/099_Victoria_Beckham_00.png b/inputs/gray_faces/099_Victoria_Beckham_00.png deleted file mode 100644 index ccf375f3..00000000 Binary files a/inputs/gray_faces/099_Victoria_Beckham_00.png and /dev/null differ diff --git a/inputs/gray_faces/111_Alexa_Chung_00.png b/inputs/gray_faces/111_Alexa_Chung_00.png deleted file mode 100644 index 759c8ebf..00000000 Binary files a/inputs/gray_faces/111_Alexa_Chung_00.png and /dev/null differ diff --git a/inputs/gray_faces/132_Robert_Downey_Jr_00.png b/inputs/gray_faces/132_Robert_Downey_Jr_00.png deleted file mode 100644 index d4ec0cbd..00000000 Binary files a/inputs/gray_faces/132_Robert_Downey_Jr_00.png and /dev/null differ diff --git a/inputs/gray_faces/158_Jimmy_Fallon_00.png b/inputs/gray_faces/158_Jimmy_Fallon_00.png deleted file mode 100644 index aabb515d..00000000 Binary files a/inputs/gray_faces/158_Jimmy_Fallon_00.png and /dev/null differ diff --git a/inputs/gray_faces/161_Zac_Efron_00.png b/inputs/gray_faces/161_Zac_Efron_00.png deleted file mode 100644 index 264d2f48..00000000 Binary files a/inputs/gray_faces/161_Zac_Efron_00.png and /dev/null differ diff --git a/inputs/gray_faces/169_John_Lennon_00.png b/inputs/gray_faces/169_John_Lennon_00.png deleted file mode 100644 index 45aef40e..00000000 Binary files a/inputs/gray_faces/169_John_Lennon_00.png and /dev/null differ diff --git a/inputs/gray_faces/170_Marilyn_Monroe_00.png b/inputs/gray_faces/170_Marilyn_Monroe_00.png deleted file mode 100644 index 6e1f2a11..00000000 
Binary files a/inputs/gray_faces/170_Marilyn_Monroe_00.png and /dev/null differ diff --git a/inputs/gray_faces/Einstein01.png b/inputs/gray_faces/Einstein01.png deleted file mode 100644 index 9fd17d39..00000000 Binary files a/inputs/gray_faces/Einstein01.png and /dev/null differ diff --git a/inputs/gray_faces/Einstein02.png b/inputs/gray_faces/Einstein02.png deleted file mode 100644 index 4650dd3b..00000000 Binary files a/inputs/gray_faces/Einstein02.png and /dev/null differ diff --git a/inputs/gray_faces/Hepburn01.png b/inputs/gray_faces/Hepburn01.png deleted file mode 100644 index 7ef44227..00000000 Binary files a/inputs/gray_faces/Hepburn01.png and /dev/null differ diff --git a/inputs/gray_faces/Hepburn02.png b/inputs/gray_faces/Hepburn02.png deleted file mode 100644 index f0f364b8..00000000 Binary files a/inputs/gray_faces/Hepburn02.png and /dev/null differ diff --git a/inputs/masked_faces/00105.png b/inputs/masked_faces/00105.png deleted file mode 100644 index 28782c91..00000000 Binary files a/inputs/masked_faces/00105.png and /dev/null differ diff --git a/inputs/masked_faces/00108.png b/inputs/masked_faces/00108.png deleted file mode 100644 index 25745db1..00000000 Binary files a/inputs/masked_faces/00108.png and /dev/null differ diff --git a/inputs/masked_faces/00169.png b/inputs/masked_faces/00169.png deleted file mode 100644 index c5a64150..00000000 Binary files a/inputs/masked_faces/00169.png and /dev/null differ diff --git a/inputs/masked_faces/00588.png b/inputs/masked_faces/00588.png deleted file mode 100644 index c734740e..00000000 Binary files a/inputs/masked_faces/00588.png and /dev/null differ diff --git a/inputs/masked_faces/00664.png b/inputs/masked_faces/00664.png deleted file mode 100644 index 1afc6218..00000000 Binary files a/inputs/masked_faces/00664.png and /dev/null differ diff --git a/inputs/whole_imgs/00.jpg b/inputs/whole_imgs/00.jpg deleted file mode 100644 index d6e323e5..00000000 Binary files a/inputs/whole_imgs/00.jpg and /dev/null differ diff --git a/inputs/whole_imgs/01.jpg b/inputs/whole_imgs/01.jpg deleted file mode 100644 index 485fc6a5..00000000 Binary files a/inputs/whole_imgs/01.jpg and /dev/null differ diff --git a/inputs/whole_imgs/02.png b/inputs/whole_imgs/02.png deleted file mode 100644 index 378e7b15..00000000 Binary files a/inputs/whole_imgs/02.png and /dev/null differ diff --git a/inputs/whole_imgs/03.jpg b/inputs/whole_imgs/03.jpg deleted file mode 100644 index b6c84281..00000000 Binary files a/inputs/whole_imgs/03.jpg and /dev/null differ diff --git a/inputs/whole_imgs/04.jpg b/inputs/whole_imgs/04.jpg deleted file mode 100644 index bb94681a..00000000 Binary files a/inputs/whole_imgs/04.jpg and /dev/null differ diff --git a/inputs/whole_imgs/05.jpg b/inputs/whole_imgs/05.jpg deleted file mode 100644 index 4dc33735..00000000 Binary files a/inputs/whole_imgs/05.jpg and /dev/null differ diff --git a/inputs/whole_imgs/06.png b/inputs/whole_imgs/06.png deleted file mode 100644 index 49c2fff2..00000000 Binary files a/inputs/whole_imgs/06.png and /dev/null differ diff --git a/options/CodeFormer_colorization.yml b/options/CodeFormer_colorization.yml deleted file mode 100755 index 6fa595b4..00000000 --- a/options/CodeFormer_colorization.yml +++ /dev/null @@ -1,145 +0,0 @@ -# general settings -name: CodeFormer_colorization -model_type: CodeFormerIdxModel -num_gpu: 8 -manual_seed: 0 - -# dataset and data loader settings -datasets: - train: - name: FFHQ - type: FFHQBlindDataset - dataroot_gt: datasets/ffhq/ffhq_512 - filename_tmpl: '{}' - io_backend: - 
type: disk - - in_size: 512 - gt_size: 512 - mean: [0.5, 0.5, 0.5] - std: [0.5, 0.5, 0.5] - use_hflip: true - use_corrupt: true - - # large degradation in stageII - blur_kernel_size: 41 - use_motion_kernel: false - motion_kernel_prob: 0.001 - kernel_list: ['iso', 'aniso'] - kernel_prob: [0.5, 0.5] - blur_sigma: [1, 15] - downsample_range: [4, 30] - noise_range: [0, 20] - jpeg_range: [30, 80] - - # color jitter and gray - color_jitter_prob: 0.3 - color_jitter_shift: 20 - color_jitter_pt_prob: 0.3 - gray_prob: 0.01 - - latent_gt_path: ~ # without pre-calculated latent code - # latent_gt_path: './experiments/pretrained_models/VQGAN/latent_gt_code1024.pth' - - # data loader - num_worker_per_gpu: 2 - batch_size_per_gpu: 4 - dataset_enlarge_ratio: 100 - prefetch_mode: ~ - - # val: - # name: CelebA-HQ-512 - # type: PairedImageDataset - # dataroot_lq: datasets/faces/validation/lq - # dataroot_gt: datasets/faces/validation/gt - # io_backend: - # type: disk - # mean: [0.5, 0.5, 0.5] - # std: [0.5, 0.5, 0.5] - # scale: 1 - -# network structures -network_g: - type: CodeFormer - dim_embd: 512 - n_head: 8 - n_layers: 9 - codebook_size: 1024 - connect_list: ['32', '64', '128', '256'] - fix_modules: ['quantize','generator'] - vqgan_path: './experiments/pretrained_models/vqgan/vqgan_code1024.pth' # pretrained VQGAN - -network_vqgan: # this config is needed if no pre-calculated latent - type: VQAutoEncoder - img_size: 512 - nf: 64 - ch_mult: [1, 2, 2, 4, 4, 8] - quantizer: 'nearest' - codebook_size: 1024 - -# path -path: - pretrain_network_g: ~ - param_key_g: params_ema - strict_load_g: false - pretrain_network_d: ~ - strict_load_d: true - resume_state: ~ - -# base_lr(4.5e-6)*bach_size(4) -train: - use_hq_feat_loss: true - feat_loss_weight: 1.0 - cross_entropy_loss: true - entropy_loss_weight: 0.5 - fidelity_weight: 0 - - optim_g: - type: Adam - lr: !!float 1e-4 - weight_decay: 0 - betas: [0.9, 0.99] - - scheduler: - type: MultiStepLR - milestones: [400000, 450000] - gamma: 0.5 - - total_iter: 500000 - - warmup_iter: -1 # no warm up - ema_decay: 0.995 - - use_adaptive_weight: true - - net_g_start_iter: 0 - net_d_iters: 1 - net_d_start_iter: 0 - manual_seed: 0 - -# validation settings -val: - val_freq: !!float 5e10 # no validation - save_img: true - - metrics: - psnr: # metric name, can be arbitrary - type: calculate_psnr - crop_border: 4 - test_y_channel: false - -# logging settings -logger: - print_freq: 100 - save_checkpoint_freq: !!float 1e4 - use_tb_logger: true - wandb: - project: ~ - resume_id: ~ - -# dist training settings -dist_params: - backend: nccl - port: 29419 - -find_unused_parameters: true diff --git a/options/CodeFormer_inpainting.yml b/options/CodeFormer_inpainting.yml deleted file mode 100755 index ddd68452..00000000 --- a/options/CodeFormer_inpainting.yml +++ /dev/null @@ -1,159 +0,0 @@ -# general settings -name: CodeFormer_inpainting -model_type: CodeFormerModel -num_gpu: 4 -manual_seed: 0 - -# dataset and data loader settings -datasets: - train: - name: FFHQ - type: FFHQBlindDataset - dataroot_gt: datasets/ffhq/ffhq_512 - filename_tmpl: '{}' - io_backend: - type: disk - - in_size: 512 - gt_size: 512 - mean: [0.5, 0.5, 0.5] - std: [0.5, 0.5, 0.5] - use_hflip: true - use_corrupt: false - gen_inpaint_mask: true - - latent_gt_path: ~ # without pre-calculated latent code - # latent_gt_path: './experiments/pretrained_models/VQGAN/latent_gt_code1024.pth' - - # data loader - num_worker_per_gpu: 2 - batch_size_per_gpu: 3 - dataset_enlarge_ratio: 100 - prefetch_mode: ~ - - # val: - # name: 
CelebA-HQ-512 - # type: PairedImageDataset - # dataroot_lq: datasets/faces/validation/lq - # dataroot_gt: datasets/faces/validation/gt - # io_backend: - # type: disk - # mean: [0.5, 0.5, 0.5] - # std: [0.5, 0.5, 0.5] - # scale: 1 - -# network structures -network_g: - type: CodeFormer - dim_embd: 512 - n_head: 8 - n_layers: 9 - codebook_size: 1024 - connect_list: ['32', '64', '128'] - fix_modules: ['quantize','generator'] - vqgan_path: './experiments/pretrained_models/vqgan/vqgan_code1024.pth' # pretrained VQGAN - -network_vqgan: # this config is needed if no pre-calculated latent - type: VQAutoEncoder - img_size: 512 - nf: 64 - ch_mult: [1, 2, 2, 4, 4, 8] - quantizer: 'nearest' - codebook_size: 1024 - -network_d: - type: VQGANDiscriminator - nc: 3 - ndf: 64 - n_layers: 4 - model_path: ~ - -# path -path: - pretrain_network_g: ~ - param_key_g: params_ema - strict_load_g: true - pretrain_network_d: ~ - strict_load_d: true - resume_state: ~ - -# base_lr(4.5e-6)*bach_size(4) -train: - use_hq_feat_loss: true - feat_loss_weight: 1.0 - cross_entropy_loss: true - entropy_loss_weight: 0.5 - scale_adaptive_gan_weight: 0.1 - fidelity_weight: 1.0 - - optim_g: - type: Adam - lr: !!float 7e-5 - weight_decay: 0 - betas: [0.9, 0.99] - optim_d: - type: Adam - lr: !!float 7e-5 - weight_decay: 0 - betas: [0.9, 0.99] - - scheduler: - type: MultiStepLR - milestones: [250000, 300000] - gamma: 0.5 - - total_iter: 300000 - - warmup_iter: -1 # no warm up - ema_decay: 0.997 - - pixel_opt: - type: L1Loss - loss_weight: 1.0 - reduction: mean - - perceptual_opt: - type: LPIPSLoss - loss_weight: 1.0 - use_input_norm: true - range_norm: true - - gan_opt: - type: GANLoss - gan_type: hinge - loss_weight: !!float 1.0 # adaptive_weighting - - - use_adaptive_weight: true - - net_g_start_iter: 0 - net_d_iters: 1 - net_d_start_iter: 296001 - manual_seed: 0 - -# validation settings -val: - val_freq: !!float 5e10 # no validation - save_img: true - - metrics: - psnr: # metric name, can be arbitrary - type: calculate_psnr - crop_border: 4 - test_y_channel: false - -# logging settings -logger: - print_freq: 100 - save_checkpoint_freq: !!float 1e4 - use_tb_logger: true - wandb: - project: ~ - resume_id: ~ - -# dist training settings -dist_params: - backend: nccl - port: 29420 - -find_unused_parameters: true diff --git a/options/CodeFormer_stage3.yml b/options/CodeFormer_stage3.yml deleted file mode 100755 index fbca3f2a..00000000 --- a/options/CodeFormer_stage3.yml +++ /dev/null @@ -1,171 +0,0 @@ -# general settings -name: CodeFormer_stage3 -model_type: CodeFormerJointModel -num_gpu: 8 -manual_seed: 0 - -# dataset and data loader settings -datasets: - train: - name: FFHQ - type: FFHQBlindJointDataset - dataroot_gt: datasets/ffhq/ffhq_512 - filename_tmpl: '{}' - io_backend: - type: disk - - in_size: 512 - gt_size: 512 - mean: [0.5, 0.5, 0.5] - std: [0.5, 0.5, 0.5] - use_hflip: true - use_corrupt: true - - blur_kernel_size: 41 - use_motion_kernel: false - motion_kernel_prob: 0.001 - kernel_list: ['iso', 'aniso'] - kernel_prob: [0.5, 0.5] - # small degradation in stageIII - blur_sigma: [0.1, 10] - downsample_range: [1, 12] - noise_range: [0, 15] - jpeg_range: [60, 100] - # large degradation in stageII - blur_sigma_large: [1, 15] - downsample_range_large: [4, 30] - noise_range_large: [0, 20] - jpeg_range_large: [30, 80] - - latent_gt_path: ~ # without pre-calculated latent code - # latent_gt_path: './experiments/pretrained_models/VQGAN/latent_gt_code1024.pth' - - # data loader - num_worker_per_gpu: 1 - batch_size_per_gpu: 3 - 
dataset_enlarge_ratio: 100 - prefetch_mode: ~ - - # val: - # name: CelebA-HQ-512 - # type: PairedImageDataset - # dataroot_lq: datasets/faces/validation/lq - # dataroot_gt: datasets/faces/validation/gt - # io_backend: - # type: disk - # mean: [0.5, 0.5, 0.5] - # std: [0.5, 0.5, 0.5] - # scale: 1 - -# network structures -network_g: - type: CodeFormer - dim_embd: 512 - n_head: 8 - n_layers: 9 - codebook_size: 1024 - connect_list: ['32', '64', '128', '256'] - fix_modules: ['quantize','generator'] - -network_vqgan: # this config is needed if no pre-calculated latent - type: VQAutoEncoder - img_size: 512 - nf: 64 - ch_mult: [1, 2, 2, 4, 4, 8] - quantizer: 'nearest' - codebook_size: 1024 - -network_d: - type: VQGANDiscriminator - nc: 3 - ndf: 64 - n_layers: 4 - -# path -path: - pretrain_network_g: './experiments/pretrained_models/CodeFormer_stage2/net_g_latest.pth' # pretrained G model in StageII - param_key_g: params_ema - strict_load_g: false - pretrain_network_d: './experiments/pretrained_models/CodeFormer_stage2/net_d_latest.pth' # pretrained D model in StageII - resume_state: ~ - -# base_lr(4.5e-6)*bach_size(4) -train: - use_hq_feat_loss: true - feat_loss_weight: 1.0 - cross_entropy_loss: true - entropy_loss_weight: 0.5 - scale_adaptive_gan_weight: 0.1 - - optim_g: - type: Adam - lr: !!float 5e-5 - weight_decay: 0 - betas: [0.9, 0.99] - optim_d: - type: Adam - lr: !!float 5e-5 - weight_decay: 0 - betas: [0.9, 0.99] - - scheduler: - type: CosineAnnealingRestartLR - periods: [150000] - restart_weights: [1] - eta_min: !!float 2e-5 - - - total_iter: 150000 - - warmup_iter: -1 # no warm up - ema_decay: 0.997 - - pixel_opt: - type: L1Loss - loss_weight: 1.0 - reduction: mean - - perceptual_opt: - type: LPIPSLoss - loss_weight: 1.0 - use_input_norm: true - range_norm: true - - gan_opt: - type: GANLoss - gan_type: hinge - loss_weight: !!float 1.0 # adaptive_weighting - - use_adaptive_weight: true - - net_g_start_iter: 0 - net_d_iters: 1 - net_d_start_iter: 5001 - manual_seed: 0 - -# validation settings -val: - val_freq: !!float 5e10 # no validation - save_img: true - - metrics: - psnr: # metric name, can be arbitrary - type: calculate_psnr - crop_border: 4 - test_y_channel: false - -# logging settings -logger: - print_freq: 100 - save_checkpoint_freq: !!float 5e3 - use_tb_logger: true - wandb: - project: ~ - resume_id: ~ - -# dist training settings -dist_params: - backend: nccl - port: 29413 - -find_unused_parameters: true diff --git a/options/VQGAN_512_ds32_nearest_stage1.yml b/options/VQGAN_512_ds32_nearest_stage1.yml deleted file mode 100755 index 0753fc36..00000000 --- a/options/VQGAN_512_ds32_nearest_stage1.yml +++ /dev/null @@ -1,136 +0,0 @@ -# general settings -name: VQGAN-512-ds32-nearest-stage1 -model_type: VQGANModel -num_gpu: 8 -manual_seed: 0 - -# dataset and data loader settings -datasets: - train: - name: FFHQ - type: FFHQBlindDataset - dataroot_gt: datasets/ffhq/ffhq_512 - filename_tmpl: '{}' - io_backend: - type: disk - - in_size: 512 - gt_size: 512 - mean: [0.5, 0.5, 0.5] - std: [0.5, 0.5, 0.5] - use_hflip: true - use_corrupt: false # for VQGAN - - # data loader - num_worker_per_gpu: 2 - batch_size_per_gpu: 4 - dataset_enlarge_ratio: 100 - - prefetch_mode: cpu - num_prefetch_queue: 4 - - # val: - # name: CelebA-HQ-512 - # type: PairedImageDataset - # dataroot_lq: datasets/faces/validation/gt - # dataroot_gt: datasets/faces/validation/gt - # io_backend: - # type: disk - # mean: [0.5, 0.5, 0.5] - # std: [0.5, 0.5, 0.5] - # scale: 1 - -# network structures -network_g: - type: 
VQAutoEncoder - img_size: 512 - nf: 64 - ch_mult: [1, 2, 2, 4, 4, 8] - quantizer: 'nearest' - codebook_size: 1024 - -network_d: - type: VQGANDiscriminator - nc: 3 - ndf: 64 - -# path -path: - pretrain_network_g: ~ - param_key_g: params_ema - strict_load_g: true - pretrain_network_d: ~ - strict_load_d: true - resume_state: ~ - -# base_lr(4.5e-6)*bach_size(4) -train: - optim_g: - type: Adam - lr: !!float 7e-5 - weight_decay: 0 - betas: [0.9, 0.99] - optim_d: - type: Adam - lr: !!float 7e-5 - weight_decay: 0 - betas: [0.9, 0.99] - - scheduler: - type: CosineAnnealingRestartLR - periods: [1600000] - restart_weights: [1] - eta_min: !!float 6e-5 # no lr reduce in official vqgan code - - total_iter: 1600000 - - warmup_iter: -1 # no warm up - ema_decay: 0.995 # GFPGAN: 0.5**(32 / (10 * 1000) == 0.998; Unleashing: 0.995 - - pixel_opt: - type: L1Loss - loss_weight: 1.0 - reduction: mean - - perceptual_opt: - type: LPIPSLoss - loss_weight: 1.0 - use_input_norm: true - range_norm: true - - gan_opt: - type: GANLoss - gan_type: hinge - loss_weight: !!float 1.0 # adaptive_weighting - - net_g_start_iter: 0 - net_d_iters: 1 - net_d_start_iter: 30001 - manual_seed: 0 - -# validation settings -val: - val_freq: !!float 5e10 # no validation - save_img: true - - metrics: - psnr: # metric name, can be arbitrary - type: calculate_psnr - crop_border: 4 - test_y_channel: false - -# logging settings -logger: - print_freq: 100 - save_checkpoint_freq: !!float 1e4 - use_tb_logger: true - wandb: - project: ~ - resume_id: ~ - -# dist training settings -dist_params: - backend: nccl - port: 29411 - -find_unused_parameters: true diff --git a/requirements.txt b/requirements.txt old mode 100644 new mode 100755 index 7e1950a0..0232d12a --- a/requirements.txt +++ b/requirements.txt @@ -1,17 +1,31 @@ -addict -future -lmdb -numpy -opencv-python -Pillow -pyyaml -requests -scikit-image -scipy -tb-nightly -torch>=1.7.1 -torchvision -tqdm -yapf +accelerate==0.28.0 +audio-separator==0.17.2 +av==12.1.0 +bitsandbytes==0.43.1 +decord==0.6.0 +diffusers==0.27.2 +einops==0.8.0 +insightface==0.7.3 +librosa==0.10.2.post1 +mediapipe[vision]==0.10.14 +mlflow==2.13.1 +moviepy==1.0.3 +numpy==1.26.4 +omegaconf==2.3.0 +onnx2torch==1.5.14 +onnx==1.16.1 +onnxruntime-gpu==1.18.0 +opencv-contrib-python==4.9.0.80 +opencv-python-headless==4.9.0.80 +opencv-python==4.9.0.80 +pillow==10.3.0 +setuptools==70.0.0 +tqdm==4.66.4 +transformers==4.39.2 +xformers==0.0.25.post1 +isort==5.13.2 +pylint==3.2.2 +pre-commit==3.7.1 +gradio==4.36.1 lpips -gdown # supports downloading the large file from Google Drive \ No newline at end of file +ffmpeg-python==0.2.0 \ No newline at end of file diff --git a/scripts/crop_align_face.py b/scripts/crop_align_face.py deleted file mode 100755 index c44d6e8f..00000000 --- a/scripts/crop_align_face.py +++ /dev/null @@ -1,205 +0,0 @@ -""" -brief: face alignment with FFHQ method (https://github.com/NVlabs/ffhq-dataset) -author: lzhbrian (https://lzhbrian.me) -link: https://gist.github.com/lzhbrian/bde87ab23b499dd02ba4f588258f57d5 -date: 2020.1.5 -note: code is heavily borrowed from - https://github.com/NVlabs/ffhq-dataset - http://dlib.net/face_landmark_detection.py.html -requirements: - conda install Pillow numpy scipy - conda install -c conda-forge dlib - # download face landmark model from: - # http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2 -""" - -import os -import glob -import numpy as np -import PIL -import PIL.Image -import scipy -import scipy.ndimage -import argparse -from basicsr.utils.download_util import 
load_file_from_url - -try: - import dlib -except ImportError: - print('Please install dlib by running:' 'conda install -c conda-forge dlib') - -# download model from: http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2 -shape_predictor_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/shape_predictor_68_face_landmarks-fbdc2cb8.dat' -ckpt_path = load_file_from_url(url=shape_predictor_url, - model_dir='weights/dlib', progress=True, file_name=None) -predictor = dlib.shape_predictor('weights/dlib/shape_predictor_68_face_landmarks-fbdc2cb8.dat') - - -def get_landmark(filepath, only_keep_largest=True): - """get landmark with dlib - :return: np.array shape=(68, 2) - """ - detector = dlib.get_frontal_face_detector() - - img = dlib.load_rgb_image(filepath) - dets = detector(img, 1) - - # Shangchen modified - print("\tNumber of faces detected: {}".format(len(dets))) - if only_keep_largest: - print('\tOnly keep the largest.') - face_areas = [] - for k, d in enumerate(dets): - face_area = (d.right() - d.left()) * (d.bottom() - d.top()) - face_areas.append(face_area) - - largest_idx = face_areas.index(max(face_areas)) - d = dets[largest_idx] - shape = predictor(img, d) - # print("Part 0: {}, Part 1: {} ...".format( - # shape.part(0), shape.part(1))) - else: - for k, d in enumerate(dets): - # print("Detection {}: Left: {} Top: {} Right: {} Bottom: {}".format( - # k, d.left(), d.top(), d.right(), d.bottom())) - # Get the landmarks/parts for the face in box d. - shape = predictor(img, d) - # print("Part 0: {}, Part 1: {} ...".format( - # shape.part(0), shape.part(1))) - - t = list(shape.parts()) - a = [] - for tt in t: - a.append([tt.x, tt.y]) - lm = np.array(a) - # lm is a shape=(68,2) np.array - return lm - -def align_face(filepath, out_path): - """ - :param filepath: str - :return: PIL Image - """ - try: - lm = get_landmark(filepath) - except: - print('No landmark ...') - return - - lm_chin = lm[0:17] # left-right - lm_eyebrow_left = lm[17:22] # left-right - lm_eyebrow_right = lm[22:27] # left-right - lm_nose = lm[27:31] # top-down - lm_nostrils = lm[31:36] # top-down - lm_eye_left = lm[36:42] # left-clockwise - lm_eye_right = lm[42:48] # left-clockwise - lm_mouth_outer = lm[48:60] # left-clockwise - lm_mouth_inner = lm[60:68] # left-clockwise - - # Calculate auxiliary vectors. - eye_left = np.mean(lm_eye_left, axis=0) - eye_right = np.mean(lm_eye_right, axis=0) - eye_avg = (eye_left + eye_right) * 0.5 - eye_to_eye = eye_right - eye_left - mouth_left = lm_mouth_outer[0] - mouth_right = lm_mouth_outer[6] - mouth_avg = (mouth_left + mouth_right) * 0.5 - eye_to_mouth = mouth_avg - eye_avg - - # Choose oriented crop rectangle. - x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1] - x /= np.hypot(*x) - x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8) - y = np.flipud(x) * [-1, 1] - c = eye_avg + eye_to_mouth * 0.1 - quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y]) - qsize = np.hypot(*x) * 2 - - # read image - img = PIL.Image.open(filepath) - - output_size = 512 - transform_size = 4096 - enable_padding = False - - # Shrink. - shrink = int(np.floor(qsize / output_size * 0.5)) - if shrink > 1: - rsize = (int(np.rint(float(img.size[0]) / shrink)), - int(np.rint(float(img.size[1]) / shrink))) - img = img.resize(rsize, PIL.Image.ANTIALIAS) - quad /= shrink - qsize /= shrink - - # Crop. 
- border = max(int(np.rint(qsize * 0.1)), 3) - crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), - int(np.ceil(max(quad[:, 0]))), int(np.ceil(max(quad[:, 1])))) - crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), - min(crop[2] + border, - img.size[0]), min(crop[3] + border, img.size[1])) - if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]: - img = img.crop(crop) - quad -= crop[0:2] - - # Pad. - pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), - int(np.ceil(max(quad[:, 0]))), int(np.ceil(max(quad[:, 1])))) - pad = (max(-pad[0] + border, - 0), max(-pad[1] + border, - 0), max(pad[2] - img.size[0] + border, - 0), max(pad[3] - img.size[1] + border, 0)) - if enable_padding and max(pad) > border - 4: - pad = np.maximum(pad, int(np.rint(qsize * 0.3))) - img = np.pad( - np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), - 'reflect') - h, w, _ = img.shape - y, x, _ = np.ogrid[:h, :w, :1] - mask = np.maximum( - 1.0 - - np.minimum(np.float32(x) / pad[0], - np.float32(w - 1 - x) / pad[2]), 1.0 - - np.minimum(np.float32(y) / pad[1], - np.float32(h - 1 - y) / pad[3])) - blur = qsize * 0.02 - img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0) - img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0) - img = PIL.Image.fromarray( - np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB') - quad += pad[:2] - - img = img.transform((transform_size, transform_size), PIL.Image.QUAD, - (quad + 0.5).flatten(), PIL.Image.BILINEAR) - - if output_size < transform_size: - img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS) - - # Save aligned image. - # print('saveing: ', out_path) - img.save(out_path) - - return img, np.max(quad[:, 0]) - np.min(quad[:, 0]) - - -if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument('-i', '--in_dir', type=str, default='./inputs/whole_imgs') - parser.add_argument('-o', '--out_dir', type=str, default='./inputs/cropped_faces') - args = parser.parse_args() - - if args.out_dir.endswith('/'): # solve when path ends with / - args.out_dir = args.out_dir[:-1] - dir_name = os.path.abspath(args.out_dir) - os.makedirs(dir_name, exist_ok=True) - - img_list = sorted(glob.glob(os.path.join(args.in_dir, '*.[jpJP][pnPN]*[gG]'))) - test_img_num = len(img_list) - - for i, in_path in enumerate(img_list): - img_name = os.path.basename(in_path) - print(f'[{i+1}/{test_img_num}] Processing: {img_name}') - out_path = os.path.join(args.out_dir, in_path.split("/")[-1]) - out_path = out_path.replace('.jpg', '.png') - size_ = align_face(in_path, out_path) \ No newline at end of file diff --git a/scripts/download_pretrained_models.py b/scripts/download_pretrained_models.py deleted file mode 100644 index 70737833..00000000 --- a/scripts/download_pretrained_models.py +++ /dev/null @@ -1,52 +0,0 @@ -import argparse -import os -from os import path as osp - -from basicsr.utils.download_util import load_file_from_url - - -def download_pretrained_models(method, file_urls): - if method == 'CodeFormer_train': - method = 'CodeFormer' - save_path_root = f'./weights/{method}' - os.makedirs(save_path_root, exist_ok=True) - - for file_name, file_url in file_urls.items(): - save_path = load_file_from_url(url=file_url, model_dir=save_path_root, progress=True, file_name=file_name) - - -if __name__ == '__main__': - parser = argparse.ArgumentParser() - - parser.add_argument( - 'method', - type=str, - help=("Options: 'CodeFormer' 'facelib' 'dlib'. 
Set to 'all' to download all the models.")) - args = parser.parse_args() - - file_urls = { - 'CodeFormer': { - 'codeformer.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth' - }, - 'CodeFormer_train': { - 'vqgan_code1024.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/vqgan_code1024.pth', - 'latent_gt_code1024.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/latent_gt_code1024.pth', - 'codeformer_stage2.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer_stage2.pth', - 'codeformer.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth' - }, - 'facelib': { - # 'yolov5l-face.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5l-face.pth', - 'detection_Resnet50_Final.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_Resnet50_Final.pth', - 'parsing_parsenet.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth' - }, - 'dlib': { - 'mmod_human_face_detector-4cb19393.dat': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/mmod_human_face_detector-4cb19393.dat', - 'shape_predictor_5_face_landmarks-c4b1e980.dat': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/shape_predictor_5_face_landmarks-c4b1e980.dat' - } - } - - if args.method == 'all': - for method in file_urls.keys(): - download_pretrained_models(method, file_urls[method]) - else: - download_pretrained_models(args.method, file_urls[args.method]) \ No newline at end of file diff --git a/scripts/download_pretrained_models_from_gdrive.py b/scripts/download_pretrained_models_from_gdrive.py deleted file mode 100644 index 7df5be6f..00000000 --- a/scripts/download_pretrained_models_from_gdrive.py +++ /dev/null @@ -1,60 +0,0 @@ -import argparse -import os -from os import path as osp - -# from basicsr.utils.download_util import download_file_from_google_drive -import gdown - - -def download_pretrained_models(method, file_ids): - save_path_root = f'./weights/{method}' - os.makedirs(save_path_root, exist_ok=True) - - for file_name, file_id in file_ids.items(): - file_url = 'https://drive.google.com/uc?id='+file_id - save_path = osp.abspath(osp.join(save_path_root, file_name)) - if osp.exists(save_path): - user_response = input(f'{file_name} already exist. Do you want to cover it? Y/N\n') - if user_response.lower() == 'y': - print(f'Covering {file_name} to {save_path}') - gdown.download(file_url, save_path, quiet=False) - # download_file_from_google_drive(file_id, save_path) - elif user_response.lower() == 'n': - print(f'Skipping {file_name}') - else: - raise ValueError('Wrong input. Only accepts Y/N.') - else: - print(f'Downloading {file_name} to {save_path}') - gdown.download(file_url, save_path, quiet=False) - # download_file_from_google_drive(file_id, save_path) - -if __name__ == '__main__': - parser = argparse.ArgumentParser() - - parser.add_argument( - 'method', - type=str, - help=("Options: 'CodeFormer' 'facelib'. 
Set to 'all' to download all the models.")) - args = parser.parse_args() - - # file name: file id - # 'dlib': { - # 'mmod_human_face_detector-4cb19393.dat': '1qD-OqY8M6j4PWUP_FtqfwUPFPRMu6ubX', - # 'shape_predictor_5_face_landmarks-c4b1e980.dat': '1vF3WBUApw4662v9Pw6wke3uk1qxnmLdg', - # 'shape_predictor_68_face_landmarks-fbdc2cb8.dat': '1tJyIVdCHaU6IDMDx86BZCxLGZfsWB8yq' - # } - file_ids = { - 'CodeFormer': { - 'codeformer.pth': '1v_E_vZvP-dQPF55Kc5SRCjaKTQXDz-JB' - }, - 'facelib': { - 'yolov5l-face.pth': '131578zMA6B2x8VQHyHfa6GEPtulMCNzV', - 'parsing_parsenet.pth': '16pkohyZZ8ViHGBk3QtVqxLZKzdo466bK' - } - } - - if args.method == 'all': - for method in file_ids.keys(): - download_pretrained_models(method, file_ids[method]) - else: - download_pretrained_models(args.method, file_ids[args.method]) \ No newline at end of file diff --git a/scripts/generate_latent_gt.py b/scripts/generate_latent_gt.py deleted file mode 100644 index 3f1f17b0..00000000 --- a/scripts/generate_latent_gt.py +++ /dev/null @@ -1,67 +0,0 @@ -import argparse -import glob -import numpy as np -import os -import cv2 -import torch -from torchvision.transforms.functional import normalize -from basicsr.utils import imwrite, img2tensor, tensor2img - -from basicsr.utils.registry import ARCH_REGISTRY - -if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument('-i', '--test_path', type=str, default='datasets/ffhq/ffhq_512') - parser.add_argument('-o', '--save_root', type=str, default='./experiments/pretrained_models/vqgan') - parser.add_argument('--codebook_size', type=int, default=1024) - parser.add_argument('--ckpt_path', type=str, default='./experiments/pretrained_models/vqgan/net_g.pth') - args = parser.parse_args() - - if args.save_root.endswith('/'): # solve when path ends with / - args.save_root = args.save_root[:-1] - dir_name = os.path.abspath(args.save_root) - os.makedirs(dir_name, exist_ok=True) - - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - test_path = args.test_path - save_root = args.save_root - ckpt_path = args.ckpt_path - codebook_size = args.codebook_size - - vqgan = ARCH_REGISTRY.get('VQAutoEncoder')(512, 64, [1, 2, 2, 4, 4, 8], 'nearest', - codebook_size=codebook_size).to(device) - checkpoint = torch.load(ckpt_path)['params_ema'] - - vqgan.load_state_dict(checkpoint) - vqgan.eval() - - sum_latent = np.zeros((codebook_size)).astype('float64') - size_latent = 16 - latent = {} - latent['orig'] = {} - latent['hflip'] = {} - for i in ['orig', 'hflip']: - # for i in ['hflip']: - for img_path in sorted(glob.glob(os.path.join(test_path, '*.[jp][pn]g'))): - img_name = os.path.basename(img_path) - img = cv2.imread(img_path) - if i == 'hflip': - cv2.flip(img, 1, img) - img = img2tensor(img / 255., bgr2rgb=True, float32=True) - normalize(img, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) - img = img.unsqueeze(0).to(device) - with torch.no_grad(): - # output = net(img)[0] - x, feat_dict = vqgan.encoder(img, True) - x, _, log = vqgan.quantize(x) - # del output - torch.cuda.empty_cache() - - min_encoding_indices = log['min_encoding_indices'] - min_encoding_indices = min_encoding_indices.view(size_latent,size_latent) - latent[i][img_name[:-4]] = min_encoding_indices.cpu().numpy() - print(img_name, latent[i][img_name[:-4]].shape) - - latent_save_path = os.path.join(save_root, f'latent_gt_code{codebook_size}.pth') - torch.save(latent, latent_save_path) - print(f'\nLatent GT code are saved in {save_root}') diff --git a/scripts/inference_vqgan.py 
b/scripts/inference_vqgan.py deleted file mode 100644 index 62644bba..00000000 --- a/scripts/inference_vqgan.py +++ /dev/null @@ -1,59 +0,0 @@ -import argparse -import glob -import numpy as np -import os -import cv2 -import torch -from torchvision.transforms.functional import normalize -from basicsr.utils import imwrite, img2tensor, tensor2img - -from basicsr.utils.registry import ARCH_REGISTRY - -if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument('-i', '--test_path', type=str, default='datasets/ffhq/ffhq_512') - parser.add_argument('-o', '--save_root', type=str, default='./results/vqgan_rec') - parser.add_argument('--codebook_size', type=int, default=1024) - parser.add_argument('--ckpt_path', type=str, default='./experiments/pretrained_models/vqgan/net_g.pth') - args = parser.parse_args() - - if args.save_root.endswith('/'): # solve when path ends with / - args.save_root = args.save_root[:-1] - dir_name = os.path.abspath(args.save_root) - os.makedirs(dir_name, exist_ok=True) - - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - test_path = args.test_path - save_root = args.save_root - ckpt_path = args.ckpt_path - codebook_size = args.codebook_size - - vqgan = ARCH_REGISTRY.get('VQAutoEncoder')(512, 64, [1, 2, 2, 4, 4, 8], 'nearest', - codebook_size=codebook_size).to(device) - checkpoint = torch.load(ckpt_path)['params_ema'] - - vqgan.load_state_dict(checkpoint) - vqgan.eval() - - for img_path in sorted(glob.glob(os.path.join(test_path, '*.[jp][pn]g'))): - img_name = os.path.basename(img_path) - print(img_name) - img = cv2.imread(img_path) - img = img2tensor(img / 255., bgr2rgb=True, float32=True) - normalize(img, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) - img = img.unsqueeze(0).to(device) - with torch.no_grad(): - output = vqgan(img)[0] - output = tensor2img(output, min_max=[-1,1]) - img = tensor2img(img, min_max=[-1,1]) - restored_img = np.concatenate([img, output], axis=1) - restored_img = output - del output - torch.cuda.empty_cache() - - path = os.path.splitext(os.path.join(save_root, img_name))[0] - save_path = f'{path}.png' - imwrite(restored_img, save_path) - - print(f'\nAll results are saved in {save_root}') - diff --git a/inference_codeformer.py b/scripts/video_sr.py old mode 100644 new mode 100755 similarity index 58% rename from inference_codeformer.py rename to scripts/video_sr.py index 1a38cc95..8c885995 --- a/inference_codeformer.py +++ b/scripts/video_sr.py @@ -2,8 +2,13 @@ import cv2 import argparse import glob +import sys + import torch from torchvision.transforms.functional import normalize + +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + from basicsr.utils import imwrite, img2tensor, tensor2img from basicsr.utils.download_util import load_file_from_url from basicsr.utils.misc import gpu_is_available, get_device @@ -12,9 +17,6 @@ from basicsr.utils.registry import ARCH_REGISTRY -pretrain_model_url = { - 'restoration': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth', -} def set_realesrgan(): from basicsr.archs.rrdbnet_arch import RRDBNet @@ -36,7 +38,7 @@ def set_realesrgan(): ) upsampler = RealESRGANer( scale=2, - model_path="https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/RealESRGAN_x2plus.pth", + model_path="./pretrained_models/realesrgan/RealESRGAN_x2plus.pth", model=model, tile=args.bg_tile, tile_pad=40, @@ -52,15 +54,17 @@ def set_realesrgan(): category=RuntimeWarning) return upsampler + + + if __name__ == '__main__': # device 
= torch.device('cuda' if torch.cuda.is_available() else 'cpu') device = get_device() parser = argparse.ArgumentParser() - parser.add_argument('-i', '--input_path', type=str, default='./inputs/whole_imgs', - help='Input image, video or folder. Default: inputs/whole_imgs') + parser.add_argument('-i', '--input_path', type=str, help='Input video') parser.add_argument('-o', '--output_path', type=str, default=None, - help='Output folder. Default: results/_') + help='Output folder') parser.add_argument('-w', '--fidelity_weight', type=float, default=0.5, help='Balance the quality and fidelity. Default: 0.5') parser.add_argument('-s', '--upscale', type=int, default=2, @@ -71,23 +75,19 @@ def set_realesrgan(): # large det_model: 'YOLOv5l', 'retinaface_resnet50' # small det_model: 'YOLOv5n', 'retinaface_mobile0.25' parser.add_argument('--detection_model', type=str, default='retinaface_resnet50', - help='Face detector. Optional: retinaface_resnet50, retinaface_mobile0.25, YOLOv5l, YOLOv5n, dlib. \ + help='Face detector. Optional: retinaface_resnet50, retinaface_mobile0.25, YOLOv5l, YOLOv5n. \ Default: retinaface_resnet50') parser.add_argument('--bg_upsampler', type=str, default='None', help='Background upsampler. Optional: realesrgan') parser.add_argument('--face_upsample', action='store_true', help='Face upsampler after enhancement. Default: False') parser.add_argument('--bg_tile', type=int, default=400, help='Tile size for background sampler. Default: 400') parser.add_argument('--suffix', type=str, default=None, help='Suffix of the restored faces. Default: None') - parser.add_argument('--save_video_fps', type=float, default=None, help='Frame rate for saving video. Default: None') - + args = parser.parse_args() # ------------------------ input & output ------------------------ w = args.fidelity_weight input_video = False - if args.input_path.endswith(('jpg', 'jpeg', 'png', 'JPG', 'JPEG', 'PNG')): # input single img path - input_img_list = [args.input_path] - result_root = f'results/test_img_{w}' - elif args.input_path.endswith(('mp4', 'mov', 'avi', 'MP4', 'MOV', 'AVI')): # input video path + if args.input_path.endswith(('mp4', 'mov', 'avi', 'MP4', 'MOV', 'AVI')): # input video path from basicsr.utils.video_util import VideoReader, VideoWriter input_img_list = [] vidreader = VideoReader(args.input_path) @@ -96,17 +96,13 @@ def set_realesrgan(): input_img_list.append(image) image = vidreader.get_frame() audio = vidreader.get_audio() - fps = vidreader.get_fps() if args.save_video_fps is None else args.save_video_fps + fps = vidreader.get_fps() video_name = os.path.basename(args.input_path)[:-4] - result_root = f'results/{video_name}_{w}' + result_root = f'./hq_results/{video_name}_{w}_{args.upscale}' input_video = True vidreader.close() - else: # input img folder - if args.input_path.endswith('/'): # solve when path ends with / - args.input_path = args.input_path[:-1] - # scan all the jpg and png images - input_img_list = sorted(glob.glob(os.path.join(args.input_path, '*.[jpJP][pnPN]*[gG]'))) - result_root = f'results/{os.path.basename(args.input_path)}_{w}' + else: + raise RuntimeError("input should be mp4 file") if not args.output_path is None: # set output path result_root = args.output_path @@ -135,11 +131,12 @@ def set_realesrgan(): net = ARCH_REGISTRY.get('CodeFormer')(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(device) - # ckpt_path = 'weights/CodeFormer/codeformer.pth' - ckpt_path = load_file_from_url(url=pretrain_model_url['restoration'], 
- model_dir='weights/CodeFormer', progress=True, file_name=None) + ckpt_path = './pretrained_models/hallo2/net_g.pth' + checkpoint = torch.load(ckpt_path)['params_ema'] - net.load_state_dict(checkpoint) + m, n = net.load_state_dict(checkpoint, strict=False) + print("missing key: ", m) + assert len(n)==0 net.eval() # ------------------ set up FaceRestoreHelper ------------------- @@ -160,96 +157,124 @@ def set_realesrgan(): save_ext='png', use_parse=True, device=device) + + n = -1 + input_img_list = input_img_list[:n] + length = len(input_img_list) + + overlay = 4 + chunk = 16 + idx_list = [] + + i=0 + j=0 + while i < length and j < length: + j = min(i+chunk, length) + idx_list.append([i, j]) + i = j-overlay + + + id_list = [] # -------------------- start to processing --------------------- - for i, img_path in enumerate(input_img_list): + for i, idx in enumerate(idx_list): # clean all the intermediate results to process the next image face_helper.clean_all() + + start = idx[0] + end = idx[1] + + img_list = input_img_list[start:end] + + for j, img_path in enumerate(img_list): - if isinstance(img_path, str): - img_name = os.path.basename(img_path) - basename, ext = os.path.splitext(img_name) - print(f'[{i+1}/{test_img_num}] Processing: {img_name}') - img = cv2.imread(img_path, cv2.IMREAD_COLOR) - else: # for video processing - basename = str(i).zfill(6) - img_name = f'{video_name}_{basename}' if input_video else basename - print(f'[{i+1}/{test_img_num}] Processing: {img_name}') - img = img_path - - if args.has_aligned: - # the input faces are already cropped and aligned - img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR) - face_helper.is_gray = is_gray(img, threshold=10) - if face_helper.is_gray: - print('Grayscale input: True') - face_helper.cropped_faces = [img] - else: - face_helper.read_image(img) - # get face landmarks for each face - num_det_faces = face_helper.get_face_landmarks_5( - only_center_face=args.only_center_face, resize=640, eye_dist_threshold=5) - print(f'\tdetect {num_det_faces} faces') - # align and warp each face - face_helper.align_warp_face() + if isinstance(img_path, str): + img_name = os.path.basename(img_path) + basename, ext = os.path.splitext(img_name) + print(f'[{j+1}/{chunk}] Processing: {img_name}') + img = cv2.imread(img_path, cv2.IMREAD_COLOR) + else: # for video processing + basename = str(i).zfill(4) + img_name = f'{video_name}_{basename}_{j}' if input_video else basename + print(f'[{j+1}/{chunk}] Processing: {img_name}') + img = img_path + if args.has_aligned: + # the input faces are already cropped and aligned + img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR) + face_helper.is_gray = is_gray(img, threshold=10) + if face_helper.is_gray: + print('Grayscale input: True') + face_helper.cropped_faces = [img] + else: + face_helper.read_image(img) + # get face landmarks for each face + num_det_faces = face_helper.get_face_landmarks_5( + only_center_face=args.only_center_face, resize=640, eye_dist_threshold=5) + print(f'\tdetect {num_det_faces} faces') + # align and warp each face + face_helper.align_warp_face() + + crop_image = [] # face restoration for each cropped face for idx, cropped_face in enumerate(face_helper.cropped_faces): # prepare data cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True) normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) - cropped_face_t = cropped_face_t.unsqueeze(0).to(device) - - try: - with torch.no_grad(): - output = net(cropped_face_t, w=w, 
adain=True)[0] - restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1)) - del output - torch.cuda.empty_cache() - except Exception as error: - print(f'\tFailed inference for CodeFormer: {error}') - restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1)) + cropped_face_t = cropped_face_t.unsqueeze(0) + + crop_image.append(cropped_face_t) + + assert len(crop_image)==len(img_list) + + crop_image = torch.cat(crop_image, dim=0).to(device) + crop_image = crop_image.unsqueeze(0) + + output, top_idx = net.inference(crop_image, w=w, adain=True) + assert output.shape==crop_image.shape + + for k in range(output.shape[1]): + face_output = output[:, k:k+1] + restored_face = tensor2img(face_output.squeeze_(1), rgb2bgr=True, min_max=(-1, 1)) restored_face = restored_face.astype('uint8') + cropped_face = face_helper.cropped_faces[k] face_helper.add_restored_face(restored_face, cropped_face) + bg_img_list = [] # paste_back if not args.has_aligned: - # upsample the background - if bg_upsampler is not None: - # Now only support RealESRGAN for upsampling background - bg_img = bg_upsampler.enhance(img, outscale=args.upscale)[0] - else: - bg_img = None + for img in img_list: + # upsample the background + if bg_upsampler is not None: + # Now only support RealESRGAN for upsampling background + bg_img = bg_upsampler.enhance(img, outscale=args.upscale)[0] + else: + bg_img = None + bg_img_list.append(bg_img) + + face_helper.get_inverse_affine(None) # paste each restored face to the input image if args.face_upsample and face_upsampler is not None: - restored_img = face_helper.paste_faces_to_input_image(upsample_img=bg_img, draw_box=args.draw_box, face_upsampler=face_upsampler) - else: - restored_img = face_helper.paste_faces_to_input_image(upsample_img=bg_img, draw_box=args.draw_box) - - # save faces - for idx, (cropped_face, restored_face) in enumerate(zip(face_helper.cropped_faces, face_helper.restored_faces)): - # save cropped face - if not args.has_aligned: - save_crop_path = os.path.join(result_root, 'cropped_faces', f'{basename}_{idx:02d}.png') - imwrite(cropped_face, save_crop_path) - # save restored face - if args.has_aligned: - save_face_name = f'{basename}.png' + restored_img_list = face_helper.paste_faces_to_input_image(upsample_img_list=bg_img_list, draw_box=args.draw_box, face_upsampler=face_upsampler) else: - save_face_name = f'{basename}_{idx:02d}.png' - if args.suffix is not None: - save_face_name = f'{save_face_name[:-4]}_{args.suffix}.png' - save_restore_path = os.path.join(result_root, 'restored_faces', save_face_name) - imwrite(restored_face, save_restore_path) + restored_img_list = face_helper.paste_faces_to_input_image(upsample_img_list=bg_img_list, draw_box=args.draw_box) + + torch.cuda.empty_cache() + + if i!=0: + restored_img_list = restored_img_list[overlay:] + # save restored img - if not args.has_aligned and restored_img is not None: + if not args.has_aligned and len(restored_img_list)!=0: if args.suffix is not None: - basename = f'{basename}_{args.suffix}' - save_restore_path = os.path.join(result_root, 'final_results', f'{basename}.png') - imwrite(restored_img, save_restore_path) + basename = f'{video_name}_{args.suffix}_{i}' + for k, restored_img in enumerate(restored_img_list): + kk = str(k).zfill(3) + save_restore_path = os.path.join(result_root, 'final_results', f'{basename}_{kk}.png') + imwrite(restored_img, save_restore_path) # save enhanced video if input_video: @@ -257,18 +282,24 @@ def set_realesrgan(): # load images video_frames = [] img_list = 
sorted(glob.glob(os.path.join(result_root, 'final_results', '*.[jp][pn]g'))) - for img_path in img_list: - img = cv2.imread(img_path) - video_frames.append(img) + + assert len(img_list)==length, print(len(img_list), length) + # write images to video - height, width = video_frames[0].shape[:2] + sample_img = cv2.imread(img_list[0]) + height, width = sample_img.shape[:2] + if args.suffix is not None: video_name = f'{video_name}_{args.suffix}.png' save_restore_path = os.path.join(result_root, f'{video_name}.mp4') + vidwriter = VideoWriter(save_restore_path, height, width, fps, audio) - for f in video_frames: - vidwriter.write_frame(f) + for img_path in img_list: + print(img_path) + img = cv2.imread(img_path) + vidwriter.write_frame(img) + vidwriter.close() print(f'\nAll results are saved in {result_root}') diff --git a/web-demos/hugging_face/app.py b/web-demos/hugging_face/app.py deleted file mode 100644 index c614e7c8..00000000 --- a/web-demos/hugging_face/app.py +++ /dev/null @@ -1,283 +0,0 @@ -""" -This file is used for deploying hugging face demo: -https://huggingface.co/spaces/sczhou/CodeFormer -""" - -import sys -sys.path.append('CodeFormer') -import os -import cv2 -import torch -import torch.nn.functional as F -import gradio as gr - -from torchvision.transforms.functional import normalize - -from basicsr.archs.rrdbnet_arch import RRDBNet -from basicsr.utils import imwrite, img2tensor, tensor2img -from basicsr.utils.download_util import load_file_from_url -from basicsr.utils.misc import gpu_is_available, get_device -from basicsr.utils.realesrgan_utils import RealESRGANer -from basicsr.utils.registry import ARCH_REGISTRY - -from facelib.utils.face_restoration_helper import FaceRestoreHelper -from facelib.utils.misc import is_gray - - -os.system("pip freeze") - -pretrain_model_url = { - 'codeformer': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth', - 'detection': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_Resnet50_Final.pth', - 'parsing': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth', - 'realesrgan': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/RealESRGAN_x2plus.pth' -} -# download weights -if not os.path.exists('CodeFormer/weights/CodeFormer/codeformer.pth'): - load_file_from_url(url=pretrain_model_url['codeformer'], model_dir='CodeFormer/weights/CodeFormer', progress=True, file_name=None) -if not os.path.exists('CodeFormer/weights/facelib/detection_Resnet50_Final.pth'): - load_file_from_url(url=pretrain_model_url['detection'], model_dir='CodeFormer/weights/facelib', progress=True, file_name=None) -if not os.path.exists('CodeFormer/weights/facelib/parsing_parsenet.pth'): - load_file_from_url(url=pretrain_model_url['parsing'], model_dir='CodeFormer/weights/facelib', progress=True, file_name=None) -if not os.path.exists('CodeFormer/weights/realesrgan/RealESRGAN_x2plus.pth'): - load_file_from_url(url=pretrain_model_url['realesrgan'], model_dir='CodeFormer/weights/realesrgan', progress=True, file_name=None) - -# download images -torch.hub.download_url_to_file( - 'https://replicate.com/api/models/sczhou/codeformer/files/fa3fe3d1-76b0-4ca8-ac0d-0a925cb0ff54/06.png', - '01.png') -torch.hub.download_url_to_file( - 'https://replicate.com/api/models/sczhou/codeformer/files/a1daba8e-af14-4b00-86a4-69cec9619b53/04.jpg', - '02.jpg') -torch.hub.download_url_to_file( - 'https://replicate.com/api/models/sczhou/codeformer/files/542d64f9-1712-4de7-85f7-3863009a7c3d/03.jpg', - '03.jpg') 
-torch.hub.download_url_to_file( - 'https://replicate.com/api/models/sczhou/codeformer/files/a11098b0-a18a-4c02-a19a-9a7045d68426/010.jpg', - '04.jpg') -torch.hub.download_url_to_file( - 'https://replicate.com/api/models/sczhou/codeformer/files/7cf19c2c-e0cf-4712-9af8-cf5bdbb8d0ee/012.jpg', - '05.jpg') - -def imread(img_path): - img = cv2.imread(img_path) - img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) - return img - -# set enhancer with RealESRGAN -def set_realesrgan(): - # half = True if torch.cuda.is_available() else False - half = True if gpu_is_available() else False - model = RRDBNet( - num_in_ch=3, - num_out_ch=3, - num_feat=64, - num_block=23, - num_grow_ch=32, - scale=2, - ) - upsampler = RealESRGANer( - scale=2, - model_path="CodeFormer/weights/realesrgan/RealESRGAN_x2plus.pth", - model=model, - tile=400, - tile_pad=40, - pre_pad=0, - half=half, - ) - return upsampler - -upsampler = set_realesrgan() -# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') -device = get_device() -codeformer_net = ARCH_REGISTRY.get("CodeFormer")( - dim_embd=512, - codebook_size=1024, - n_head=8, - n_layers=9, - connect_list=["32", "64", "128", "256"], -).to(device) -ckpt_path = "CodeFormer/weights/CodeFormer/codeformer.pth" -checkpoint = torch.load(ckpt_path)["params_ema"] -codeformer_net.load_state_dict(checkpoint) -codeformer_net.eval() - -os.makedirs('output', exist_ok=True) - -def inference(image, background_enhance, face_upsample, upscale, codeformer_fidelity): - """Run a single prediction on the model""" - try: # global try - # take the default setting for the demo - has_aligned = False - only_center_face = False - draw_box = False - detection_model = "retinaface_resnet50" - print('Inp:', image, background_enhance, face_upsample, upscale, codeformer_fidelity) - - img = cv2.imread(str(image), cv2.IMREAD_COLOR) - print('\timage size:', img.shape) - - upscale = int(upscale) # convert type to int - if upscale > 4: # avoid memory exceeded due to too large upscale - upscale = 4 - if upscale > 2 and max(img.shape[:2])>1000: # avoid memory exceeded due to too large img resolution - upscale = 2 - if max(img.shape[:2]) > 1500: # avoid memory exceeded due to too large img resolution - upscale = 1 - background_enhance = False - face_upsample = False - - face_helper = FaceRestoreHelper( - upscale, - face_size=512, - crop_ratio=(1, 1), - det_model=detection_model, - save_ext="png", - use_parse=True, - device=device, - ) - bg_upsampler = upsampler if background_enhance else None - face_upsampler = upsampler if face_upsample else None - - if has_aligned: - # the input faces are already cropped and aligned - img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR) - face_helper.is_gray = is_gray(img, threshold=5) - if face_helper.is_gray: - print('\tgrayscale input: True') - face_helper.cropped_faces = [img] - else: - face_helper.read_image(img) - # get face landmarks for each face - num_det_faces = face_helper.get_face_landmarks_5( - only_center_face=only_center_face, resize=640, eye_dist_threshold=5 - ) - print(f'\tdetect {num_det_faces} faces') - # align and warp each face - face_helper.align_warp_face() - - # face restoration for each cropped face - for idx, cropped_face in enumerate(face_helper.cropped_faces): - # prepare data - cropped_face_t = img2tensor( - cropped_face / 255.0, bgr2rgb=True, float32=True - ) - normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) - cropped_face_t = cropped_face_t.unsqueeze(0).to(device) - - try: - with torch.no_grad(): - 
output = codeformer_net( - cropped_face_t, w=codeformer_fidelity, adain=True - )[0] - restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1)) - del output - torch.cuda.empty_cache() - except RuntimeError as error: - print(f"Failed inference for CodeFormer: {error}") - restored_face = tensor2img( - cropped_face_t, rgb2bgr=True, min_max=(-1, 1) - ) - - restored_face = restored_face.astype("uint8") - face_helper.add_restored_face(restored_face) - - # paste_back - if not has_aligned: - # upsample the background - if bg_upsampler is not None: - # Now only support RealESRGAN for upsampling background - bg_img = bg_upsampler.enhance(img, outscale=upscale)[0] - else: - bg_img = None - face_helper.get_inverse_affine(None) - # paste each restored face to the input image - if face_upsample and face_upsampler is not None: - restored_img = face_helper.paste_faces_to_input_image( - upsample_img=bg_img, - draw_box=draw_box, - face_upsampler=face_upsampler, - ) - else: - restored_img = face_helper.paste_faces_to_input_image( - upsample_img=bg_img, draw_box=draw_box - ) - - # save restored img - save_path = f'output/out.png' - imwrite(restored_img, str(save_path)) - - restored_img = cv2.cvtColor(restored_img, cv2.COLOR_BGR2RGB) - return restored_img, save_path - except Exception as error: - print('Global exception', error) - return None, None - - -title = "CodeFormer: Robust Face Restoration and Enhancement Network" -description = r"""
CodeFormer logo
-Official Gradio demo for Towards Robust Blind Face Restoration with Codebook Lookup Transformer (NeurIPS 2022).
-đŸ”Ĩ CodeFormer is a robust face restoration algorithm for old photos or AI-generated faces.
-🤗 Try CodeFormer for improved stable-diffusion generation!
-""" -article = r""" -If CodeFormer is helpful, please help to ⭐ the Github Repo. Thanks! -[![GitHub Stars](https://img.shields.io/github/stars/sczhou/CodeFormer?style=social)](https://github.com/sczhou/CodeFormer) - ---- - -📝 **Citation** - -If our work is useful for your research, please consider citing: -```bibtex -@inproceedings{zhou2022codeformer, - author = {Zhou, Shangchen and Chan, Kelvin C.K. and Li, Chongyi and Loy, Chen Change}, - title = {Towards Robust Blind Face Restoration with Codebook Lookup TransFormer}, - booktitle = {NeurIPS}, - year = {2022} -} -``` - -📋 **License** - -This project is licensed under S-Lab License 1.0. -Redistribution and use for non-commercial purposes should follow this license. - -📧 **Contact** - -If you have any questions, please feel free to reach me out at shangchenzhou@gmail.com. - -
- 🤗 Find Me: - Twitter Follow - Github Follow -
- -
visitors
-""" - -demo = gr.Interface( - inference, [ - gr.inputs.Image(type="filepath", label="Input"), - gr.inputs.Checkbox(default=True, label="Background_Enhance"), - gr.inputs.Checkbox(default=True, label="Face_Upsample"), - gr.inputs.Number(default=2, label="Rescaling_Factor (up to 4)"), - gr.Slider(0, 1, value=0.5, step=0.01, label='Codeformer_Fidelity (0 for better quality, 1 for better identity)') - ], [ - gr.outputs.Image(type="numpy", label="Output"), - gr.outputs.File(label="Download the output") - ], - title=title, - description=description, - article=article, - examples=[ - ['01.png', True, True, 2, 0.7], - ['02.jpg', True, True, 2, 0.7], - ['03.jpg', True, True, 2, 0.7], - ['04.jpg', True, True, 2, 0.1], - ['05.jpg', True, True, 2, 0.1] - ] - ) - -demo.queue(concurrency_count=2) -demo.launch() \ No newline at end of file diff --git a/web-demos/replicate/cog.yaml b/web-demos/replicate/cog.yaml deleted file mode 100644 index 3f458969..00000000 --- a/web-demos/replicate/cog.yaml +++ /dev/null @@ -1,30 +0,0 @@ -""" -This file is used for deploying replicate demo: -https://replicate.com/sczhou/codeformer -""" - -build: - gpu: true - cuda: "11.3" - python_version: "3.8" - system_packages: - - "libgl1-mesa-glx" - - "libglib2.0-0" - python_packages: - - "ipython==8.4.0" - - "future==0.18.2" - - "lmdb==1.3.0" - - "scikit-image==0.19.3" - - "torch==1.11.0 --extra-index-url=https://download.pytorch.org/whl/cu113" - - "torchvision==0.12.0 --extra-index-url=https://download.pytorch.org/whl/cu113" - - "scipy==1.9.0" - - "gdown==4.5.1" - - "pyyaml==6.0" - - "tb-nightly==2.11.0a20220906" - - "tqdm==4.64.1" - - "yapf==0.32.0" - - "lpips==0.1.4" - - "Pillow==9.2.0" - - "opencv-python==4.6.0.66" - -predict: "predict.py:Predictor" diff --git a/web-demos/replicate/predict.py b/web-demos/replicate/predict.py deleted file mode 100644 index 1b73cbcd..00000000 --- a/web-demos/replicate/predict.py +++ /dev/null @@ -1,191 +0,0 @@ -""" -This file is used for deploying replicate demo: -https://replicate.com/sczhou/codeformer -running: cog predict -i image=@inputs/whole_imgs/04.jpg -i codeformer_fidelity=0.5 -i upscale=2 -push: cog push r8.im/sczhou/codeformer -""" - -import tempfile -import cv2 -import torch -from torchvision.transforms.functional import normalize -try: - from cog import BasePredictor, Input, Path -except Exception: - print('please install cog package') - -from basicsr.archs.rrdbnet_arch import RRDBNet -from basicsr.utils import imwrite, img2tensor, tensor2img -from basicsr.utils.realesrgan_utils import RealESRGANer -from basicsr.utils.misc import gpu_is_available -from basicsr.utils.registry import ARCH_REGISTRY - -from facelib.utils.face_restoration_helper import FaceRestoreHelper - -class Predictor(BasePredictor): - def setup(self): - """Load the model into memory to make running multiple predictions efficient""" - self.device = "cuda:0" - self.upsampler = set_realesrgan() - self.net = ARCH_REGISTRY.get("CodeFormer")( - dim_embd=512, - codebook_size=1024, - n_head=8, - n_layers=9, - connect_list=["32", "64", "128", "256"], - ).to(self.device) - ckpt_path = "weights/CodeFormer/codeformer.pth" - checkpoint = torch.load(ckpt_path)[ - "params_ema" - ] # update file permission if cannot load - self.net.load_state_dict(checkpoint) - self.net.eval() - - def predict( - self, - image: Path = Input(description="Input image"), - codeformer_fidelity: float = Input( - default=0.5, - ge=0, - le=1, - description="Balance the quality (lower number) and fidelity (higher number).", - ), - background_enhance: 
bool = Input( - description="Enhance background image with Real-ESRGAN", default=True - ), - face_upsample: bool = Input( - description="Upsample restored faces for high-resolution AI-created images", - default=True, - ), - upscale: int = Input( - description="The final upsampling scale of the image", - default=2, - ), - ) -> Path: - """Run a single prediction on the model""" - - # take the default setting for the demo - has_aligned = False - only_center_face = False - draw_box = False - detection_model = "retinaface_resnet50" - - self.face_helper = FaceRestoreHelper( - upscale, - face_size=512, - crop_ratio=(1, 1), - det_model=detection_model, - save_ext="png", - use_parse=True, - device=self.device, - ) - - bg_upsampler = self.upsampler if background_enhance else None - face_upsampler = self.upsampler if face_upsample else None - - img = cv2.imread(str(image), cv2.IMREAD_COLOR) - - if has_aligned: - # the input faces are already cropped and aligned - img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR) - self.face_helper.cropped_faces = [img] - else: - self.face_helper.read_image(img) - # get face landmarks for each face - num_det_faces = self.face_helper.get_face_landmarks_5( - only_center_face=only_center_face, resize=640, eye_dist_threshold=5 - ) - print(f"\tdetect {num_det_faces} faces") - # align and warp each face - self.face_helper.align_warp_face() - - # face restoration for each cropped face - for idx, cropped_face in enumerate(self.face_helper.cropped_faces): - # prepare data - cropped_face_t = img2tensor( - cropped_face / 255.0, bgr2rgb=True, float32=True - ) - normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) - cropped_face_t = cropped_face_t.unsqueeze(0).to(self.device) - - try: - with torch.no_grad(): - output = self.net( - cropped_face_t, w=codeformer_fidelity, adain=True - )[0] - restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1)) - del output - torch.cuda.empty_cache() - except Exception as error: - print(f"\tFailed inference for CodeFormer: {error}") - restored_face = tensor2img( - cropped_face_t, rgb2bgr=True, min_max=(-1, 1) - ) - - restored_face = restored_face.astype("uint8") - self.face_helper.add_restored_face(restored_face) - - # paste_back - if not has_aligned: - # upsample the background - if bg_upsampler is not None: - # Now only support RealESRGAN for upsampling background - bg_img = bg_upsampler.enhance(img, outscale=upscale)[0] - else: - bg_img = None - self.face_helper.get_inverse_affine(None) - # paste each restored face to the input image - if face_upsample and face_upsampler is not None: - restored_img = self.face_helper.paste_faces_to_input_image( - upsample_img=bg_img, - draw_box=draw_box, - face_upsampler=face_upsampler, - ) - else: - restored_img = self.face_helper.paste_faces_to_input_image( - upsample_img=bg_img, draw_box=draw_box - ) - - # save restored img - out_path = Path(tempfile.mkdtemp()) / 'output.png' - imwrite(restored_img, str(out_path)) - - return out_path - - -def imread(img_path): - img = cv2.imread(img_path) - img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) - return img - - -def set_realesrgan(): - # if not torch.cuda.is_available(): # CPU - if not gpu_is_available(): # CPU - import warnings - - warnings.warn( - "The unoptimized RealESRGAN is slow on CPU. We do not use it. 
" - "If you really want to use it, please modify the corresponding codes.", - category=RuntimeWarning, - ) - upsampler = None - else: - model = RRDBNet( - num_in_ch=3, - num_out_ch=3, - num_feat=64, - num_block=23, - num_grow_ch=32, - scale=2, - ) - upsampler = RealESRGANer( - scale=2, - model_path="./weights/realesrgan/RealESRGAN_x2plus.pth", - model=model, - tile=400, - tile_pad=40, - pre_pad=0, - half=True, - ) - return upsampler diff --git a/weights/CodeFormer/.gitkeep b/weights/CodeFormer/.gitkeep deleted file mode 100644 index e69de29b..00000000 diff --git a/weights/README.md b/weights/README.md deleted file mode 100644 index 67ad334b..00000000 --- a/weights/README.md +++ /dev/null @@ -1,3 +0,0 @@ -# Weights - -Put the downloaded pre-trained models to this folder. \ No newline at end of file diff --git a/weights/facelib/.gitkeep b/weights/facelib/.gitkeep deleted file mode 100644 index e69de29b..00000000