Zheng Chen, Yulun Zhang, Jinjin Gu, Linghe Kong, and Xiaokang Yang, "Recursive Generalization Transformer for Image Super-Resolution", ICLR, 2024
[paper] [arXiv] [supplementary material] [visual results] [pretrained models]
- 2024-02-04: Code and pre-trained models are released. 🎊🎊🎊
- 2023-09-29: This repo is released.
Abstract: Transformer architectures have exhibited remarkable performance in image super-resolution (SR). Because of the quadratic computational complexity of self-attention (SA) in Transformer, existing methods tend to apply SA within local regions to reduce overhead. However, this local design restricts the exploitation of global context, which is crucial for accurate image reconstruction. In this work, we propose the Recursive Generalization Transformer (RGT) for image SR, which can capture global spatial information and is suitable for high-resolution images. Specifically, we propose recursive-generalization self-attention (RG-SA). It recursively aggregates input features into representative feature maps and then uses cross-attention to extract global information. Meanwhile, the channel dimensions of the attention matrices (*query*, *key*, and *value*) are further scaled to mitigate redundancy in the channel domain. Furthermore, we combine RG-SA with local self-attention to enhance the exploitation of the global context, and propose hybrid adaptive integration (HAI) for module integration. HAI allows direct and effective fusion between features at different levels (local or global). Extensive experiments demonstrate that RGT outperforms recent state-of-the-art methods both quantitatively and qualitatively.
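The RG-SA idea in the abstract can be illustrated with a small NumPy sketch. This is not the repository's implementation: 2×2 average pooling stands in for RGT's learned recursive aggregation, there is a single attention head, and normalization and positional terms are omitted. It assumes queries come from the full-resolution features while keys and values come from the aggregated representative map, and it shrinks the channel dimension with a hypothetical ratio `cr`. The names `recursive_aggregate`, `rg_sa`, and `min_size` are illustrative only.

```python
import numpy as np


def softmax(x, axis=-1):
    # Numerically stable softmax along `axis`.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def recursive_aggregate(feat, min_size=4):
    # Hypothetical stand-in for the learned aggregation: repeatedly apply
    # 2x2 average pooling (stride 2) while the result stays >= min_size.
    while min(feat.shape[0], feat.shape[1]) >= 2 * min_size:
        h, w, c = feat.shape
        h2, w2 = h // 2, w // 2
        feat = feat[: 2 * h2, : 2 * w2].reshape(h2, 2, w2, 2, c).mean(axis=(1, 3))
    return feat


def rg_sa(x, wq, wk, wv, wo, min_size=4):
    # x: (H, W, C) features; wq/wk/wv: (C, C') with C' < C; wo: (C', C).
    h, w, c = x.shape
    tokens = x.reshape(h * w, c)                            # N = H*W query tokens
    rep = recursive_aggregate(x, min_size).reshape(-1, c)   # M << N representative tokens
    q, k, v = tokens @ wq, rep @ wk, rep @ wv               # channel-reduced projections
    attn = softmax(q @ k.T / np.sqrt(q.shape[1]))           # (N, M) cross-attention map
    return ((attn @ v) @ wo).reshape(h, w, c)


rng = np.random.default_rng(0)
H, W, C, cr = 32, 32, 16, 0.5                               # cr: hypothetical channel ratio
cp = int(C * cr)
x = rng.standard_normal((H, W, C))
wq, wk, wv = (rng.standard_normal((C, cp)) * 0.1 for _ in range(3))
wo = rng.standard_normal((cp, C)) * 0.1
y = rg_sa(x, wq, wk, wv, wo)
print(y.shape, recursive_aggregate(x).shape)                # (32, 32, 16) (4, 4, 16)
```

In this toy setting the attention map is 1024 × 16 rather than the 1024 × 1024 map that global self-attention over all 32×32 tokens would require, which is the kind of saving that makes global context affordable on large inputs.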
Figure: visual comparison of HR, LR, SwinIR, CAT, and RGT (ours).
- Python 3.8
- PyTorch 1.9.0
- NVIDIA GPU + CUDA
```shell
# Clone the GitHub repo and go to the default directory 'RGT'.
git clone https://github.com/zhengchen1999/RGT.git
cd RGT
conda create -n RGT python=3.8
conda activate RGT
pip install -r requirements.txt -f https://download.pytorch.org/whl/torch_stable.html
python setup.py develop
```
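After installation, a quick sanity check can confirm the interpreter version and whether PyTorch, CUDA, and the `basicsr` package are visible. This is a hypothetical helper, not part of the repository; it assumes the package installed by `python setup.py develop` is importable as `basicsr`.

```python
import importlib.util
import sys


def check_environment():
    """Collect version/availability info relevant to the RGT setup."""
    report = {
        "python": "%d.%d" % sys.version_info[:2],
        "torch": None,
        "cuda": False,
        # Assumption: `python setup.py develop` makes `basicsr` importable.
        "basicsr": importlib.util.find_spec("basicsr") is not None,
    }
    if importlib.util.find_spec("torch") is not None:
        import torch  # imported lazily so the check also runs without PyTorch
        report["torch"] = torch.__version__
        report["cuda"] = torch.cuda.is_available()
    return report


if __name__ == "__main__":
    info = check_environment()
    print(info)
    if info["python"] != "3.8":
        print("Note: the README lists Python 3.8.")
```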
- Release code and pretrained models
The training and testing sets used can be downloaded as follows:
Training Set | Testing Set | Visual Results |
---|---|---|
DIV2K (800 training images, 100 validation images) + Flickr2K (2650 images) [complete training dataset DF2K: Google Drive / Baidu Disk] | Set5 + Set14 + BSD100 + Urban100 + Manga109 [complete testing dataset: Google Drive / Baidu Disk] | Google Drive / Baidu Disk |
Download the training and testing datasets and put them into the corresponding folders of `datasets/`. See `datasets/` for details of the directory structure.
Method | Params (M) | FLOPs (G) | PSNR (dB) | SSIM | Model Zoo | Visual Results |
---|---|---|---|---|---|---|
RGT-S | 10.20 | 193.08 | 27.89 | 0.8347 | Google Drive / Baidu Disk | Google Drive / Baidu Disk |
RGT | 13.37 | 251.07 | 27.98 | 0.8369 | Google Drive / Baidu Disk | Google Drive / Baidu Disk |
The performance is reported on Urban100 (×4). FLOPs are computed for an output size of 3×512×512.
- Download the training (DF2K, already processed) and testing (Set5, Set14, BSD100, Urban100, Manga109, already processed) datasets, and place them in `datasets/`.
- Run the following scripts. The training configuration is in `options/train/`.

```shell
# RGT-S, input=64x64, 4 GPUs
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/train/train_RGT_S_x2.yml --launcher pytorch
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/train/train_RGT_S_x3.yml --launcher pytorch
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/train/train_RGT_S_x4.yml --launcher pytorch

# RGT, input=64x64, 4 GPUs
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/train/train_RGT_x2.yml --launcher pytorch
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/train/train_RGT_x3.yml --launcher pytorch
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/train/train_RGT_x4.yml --launcher pytorch
```
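The six training commands differ only in the model name and scale, so they can also be generated programmatically. The sketch below is a hypothetical convenience script (not shipped with the repository) that prints the same commands; to launch them one after another, import `subprocess` and call `subprocess.run(cmd, check=True)` where indicated.

```python
import shlex


def train_commands(nproc=4, port=4321):
    """Build the distributed training commands for RGT-S and RGT at x2/x3/x4."""
    cmds = []
    for model in ("RGT_S", "RGT"):
        for scale in (2, 3, 4):
            cmds.append([
                "python", "-m", "torch.distributed.launch",
                f"--nproc_per_node={nproc}", f"--master_port={port}",
                "basicsr/train.py",
                "-opt", f"options/train/train_{model}_x{scale}.yml",
                "--launcher", "pytorch",
            ])
    return cmds


if __name__ == "__main__":
    for cmd in train_commands():
        print(shlex.join(cmd))
        # To launch sequentially instead: subprocess.run(cmd, check=True)
```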
- The training experiment is in `experiments/`.
- Download the pre-trained models and place them in `experiments/pretrained_models/`. We provide pre-trained models for image SR: RGT-S and RGT (x2, x3, x4).
- Download the testing datasets (Set5, Set14, BSD100, Urban100, Manga109) and place them in `datasets/`.
- Run the following scripts. The testing configuration is in `options/test/` (e.g., `test_RGT_x2.yml`).

  Note: You can set `use_chop: True` (default: `False`) in the YML file to chop the image for testing.

```shell
# No self-ensemble
# RGT-S, reproduces results in Table 2 of the main paper
python basicsr/test.py -opt options/test/test_RGT_S_x2.yml
python basicsr/test.py -opt options/test/test_RGT_S_x3.yml
python basicsr/test.py -opt options/test/test_RGT_S_x4.yml

# RGT, reproduces results in Table 2 of the main paper
python basicsr/test.py -opt options/test/test_RGT_x2.yml
python basicsr/test.py -opt options/test/test_RGT_x3.yml
python basicsr/test.py -opt options/test/test_RGT_x4.yml
```
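The testing commands report results without self-ensemble. For reference, the common ×8 geometric self-ensemble used in SR evaluation runs the model on the four rotations of the input, each with and without a flip, undoes each transform on the output, and averages the eight predictions. The NumPy sketch below illustrates that general technique; it is not the repository's code, and `nearest_x2` is a hypothetical stand-in for a trained model.

```python
import numpy as np


def self_ensemble(model, lr):
    """x8 geometric self-ensemble for an (H, W, C) image array."""
    outputs = []
    for k in range(4):                     # rotate by 0, 90, 180, 270 degrees
        for flip in (False, True):         # with and without a horizontal flip
            t = np.rot90(lr, k, axes=(0, 1))
            if flip:
                t = np.flip(t, axis=1)
            sr = model(t)
            # Undo the transforms in reverse order so all outputs are aligned.
            if flip:
                sr = np.flip(sr, axis=1)
            outputs.append(np.rot90(sr, -k, axes=(0, 1)))
    return np.mean(outputs, axis=0)


def nearest_x2(img):
    # Hypothetical stand-in for a trained x2 SR model: nearest-neighbour upscaling.
    return img.repeat(2, axis=0).repeat(2, axis=1)


lr = np.random.default_rng(0).random((5, 7, 3))
sr = self_ensemble(nearest_x2, lr)
print(sr.shape)  # (10, 14, 3)
```

Because the stand-in model commutes with flips and rotations, the ensemble here returns exactly its single-pass output. With a trained network the eight predictions differ slightly, and averaging them usually trades about 8× the inference cost for a small accuracy gain.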
- The output is in `results/`.
- Download the pre-trained models and place them in `experiments/pretrained_models/`. We provide pre-trained models for image SR: RGT-S and RGT (x2, x3, x4).
- Put your dataset (single LR images) in `datasets/single`. Some test images are already in this folder.
- Run the following scripts. The testing configuration is in `options/test/` (e.g., `test_single_x2.yml`).

  Note 1: The default model is RGT. You can use other models, such as RGT-S, by modifying the YML file.

  Note 2: You can set `use_chop: True` (default: `False`) in the YML file to chop the image for testing.

```shell
# Test on your dataset
python basicsr/test.py -opt options/test/test_single_x2.yml
python basicsr/test.py -opt options/test/test_single_x3.yml
python basicsr/test.py -opt options/test/test_single_x4.yml
```
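Setting `use_chop: True` trades speed for memory by processing the image in pieces. The repository's exact chopping logic is not reproduced here; the sketch below only illustrates the general idea of tiled inference under assumed parameters. Overlapping LR tiles (hypothetical defaults `tile=48`, `overlap=8`) are super-resolved independently, pasted into an output canvas enlarged by the upscaling factor, and averaged where they overlap. `nearest_x2` is again a hypothetical stand-in for a trained model.

```python
import numpy as np


def tile_starts(length, tile, overlap):
    """Start offsets of size-`tile` windows covering [0, length); requires overlap < tile."""
    if length <= tile:
        return [0]
    starts = list(range(0, length - tile, tile - overlap))
    starts.append(length - tile)  # final tile is flush with the border
    return starts


def chop_forward(model, lr, scale, tile=48, overlap=8):
    """Super-resolve an (H, W, C) image tile by tile and average overlapping outputs."""
    h, w, c = lr.shape
    out = np.zeros((h * scale, w * scale, c))
    weight = np.zeros((h * scale, w * scale, 1))
    for y in tile_starts(h, tile, overlap):
        for x in tile_starts(w, tile, overlap):
            sr = model(lr[y:y + tile, x:x + tile])
            ys, xs = y * scale, x * scale
            out[ys:ys + sr.shape[0], xs:xs + sr.shape[1]] += sr
            weight[ys:ys + sr.shape[0], xs:xs + sr.shape[1]] += 1
    return out / weight  # every output pixel is covered by at least one tile


def nearest_x2(img):
    # Hypothetical stand-in for a trained x2 SR model.
    return img.repeat(2, axis=0).repeat(2, axis=1)


lr = np.random.default_rng(0).random((100, 70, 3))
sr = chop_forward(nearest_x2, lr, scale=2)
print(sr.shape)  # (200, 140, 3)
```

With a real network, the overlap between neighbouring tiles helps hide seams at tile borders; a larger overlap costs more computation.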
- The output is in `results/`.
We achieve state-of-the-art performance. Detailed results can be found in the paper.
Visual Comparison
- results in Figure 6 of the main paper
- results in Figure 4 of the supplementary material
- results in Figure 5 of the supplementary material
If you find the code helpful in your research or work, please cite the following paper(s).
@inproceedings{chen2024recursive,
title={Recursive Generalization Transformer for Image Super-Resolution},
author={Chen, Zheng and Zhang, Yulun and Gu, Jinjin and Kong, Linghe and Yang, Xiaokang},
booktitle={ICLR},
year={2024}
}
This code is built on BasicSR.