Skip to content

Commit

Permalink
Support downloading SwinIR pretrained models
Browse files Browse the repository at this point in the history
  • Loading branch information
cszn authored Sep 9, 2021
1 parent e28b38f commit 6c38f59
Showing 1 changed file with 26 additions and 7 deletions.
33 changes: 26 additions & 7 deletions main_download_pretrained_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
python main_download_pretrained_models.py --models "DPSR" --model_dir "model_zoo"
download SwinIR models:
python main_download_pretrained_models.py --models "SwinIR" --model_dir "model_zoo/swinir"
python main_download_pretrained_models.py --models "SwinIR" --model_dir "model_zoo"
download other models:
python main_download_pretrained_models.py --models "others" --model_dir "model_zoo"
Expand Down Expand Up @@ -76,8 +76,21 @@ def download_pretrained_model(model_dir='model_zoo', model_name='dncnn3.pth'):
'USRNet': ['usrgan.pth', 'usrgan_tiny.pth', 'usrnet.pth', 'usrnet_tiny.pth'],
'DPIR': ['drunet_gray.pth', 'drunet_color.pth', 'drunet_deblocking_color.pth', 'drunet_deblocking_grayscale.pth'],
'BSRGAN': ['BSRGAN.pth', 'BSRNet.pth', 'BSRGANx2.pth'],
'SwinIR': ['001_classicalSR_DF2K_s64w8_SwinIR-M_x2.pth', '001_classicalSR_DF2K_s64w8_SwinIR-M_x3.pth', '001_classicalSR_DF2K_s64w8_SwinIR-M_x4.pth', '001_classicalSR_DF2K_s64w8_SwinIR-M_x8.pth', '001_classicalSR_DIV2K_s48w8_SwinIR-M_x2.pth', '001_classicalSR_DIV2K_s48w8_SwinIR-M_x3.pth', '001_classicalSR_DIV2K_s48w8_SwinIR-M_x4.pth', '001_classicalSR_DIV2K_s48w8_SwinIR-M_x8.pth', '002_lightweightSR_DIV2K_s64w8_SwinIR-S_x2.pth', '002_lightweightSR_DIV2K_s64w8_SwinIR-S_x3.pth', '002_lightweightSR_DIV2K_s64w8_SwinIR-S_x4.pth', '003_realSR_BSRGAN_DFO_s64w8_SwinIR-M_x4_GAN.pth', '003_realSR_BSRGAN_DFO_s64w8_SwinIR-M_x4_PSNR.pth', '004_grayDN_DFWB_s128w8_SwinIR-M_noise15.pth', '004_grayDN_DFWB_s128w8_SwinIR-M_noise25.pth', '004_grayDN_DFWB_s128w8_SwinIR-M_noise50.pth', '005_colorDN_DFWB_s128w8_SwinIR-M_noise15.pth', '005_colorDN_DFWB_s128w8_SwinIR-M_noise25.pth', '005_colorDN_DFWB_s128w8_SwinIR-M_noise50.pth', '006_CAR_DFWB_s126w7_SwinIR-M_jpeg10.pth', '006_CAR_DFWB_s126w7_SwinIR-M_jpeg20.pth', '006_CAR_DFWB_s126w7_SwinIR-M_jpeg30.pth', '006_CAR_DFWB_s126w7_SwinIR-M_jpeg40.pth'],
'others': ['RRDB.pth', 'ESRGAN.pth', 'FSSR_DPED.pth', 'FSSR_JPEG.pth', 'RealSR_DPED.pth', 'RealSR_JPEG.pth']
'IRCNN': ['ircnn_color.pth', 'ircnn_gray.pth'],
'SwinIR': ['001_classicalSR_DF2K_s64w8_SwinIR-M_x2.pth', '001_classicalSR_DF2K_s64w8_SwinIR-M_x3.pth',
'001_classicalSR_DF2K_s64w8_SwinIR-M_x4.pth', '001_classicalSR_DF2K_s64w8_SwinIR-M_x8.pth',
'001_classicalSR_DIV2K_s48w8_SwinIR-M_x2.pth', '001_classicalSR_DIV2K_s48w8_SwinIR-M_x3.pth',
'001_classicalSR_DIV2K_s48w8_SwinIR-M_x4.pth', '001_classicalSR_DIV2K_s48w8_SwinIR-M_x8.pth',
'002_lightweightSR_DIV2K_s64w8_SwinIR-S_x2.pth', '002_lightweightSR_DIV2K_s64w8_SwinIR-S_x3.pth',
'002_lightweightSR_DIV2K_s64w8_SwinIR-S_x4.pth', '003_realSR_BSRGAN_DFO_s64w8_SwinIR-M_x4_GAN.pth',
'003_realSR_BSRGAN_DFO_s64w8_SwinIR-M_x4_PSNR.pth', '004_grayDN_DFWB_s128w8_SwinIR-M_noise15.pth',
'004_grayDN_DFWB_s128w8_SwinIR-M_noise25.pth', '004_grayDN_DFWB_s128w8_SwinIR-M_noise50.pth',
'005_colorDN_DFWB_s128w8_SwinIR-M_noise15.pth', '005_colorDN_DFWB_s128w8_SwinIR-M_noise25.pth',
'005_colorDN_DFWB_s128w8_SwinIR-M_noise50.pth', '006_CAR_DFWB_s126w7_SwinIR-M_jpeg10.pth',
'006_CAR_DFWB_s126w7_SwinIR-M_jpeg20.pth', '006_CAR_DFWB_s126w7_SwinIR-M_jpeg30.pth',
'006_CAR_DFWB_s126w7_SwinIR-M_jpeg40.pth'],
'others': ['msrresnet_x4_psnr.pth', 'msrresnet_x4_gan.pth', 'imdn_x4.pth', 'RRDB.pth', 'ESRGAN.pth',
'FSSR_DPED.pth', 'FSSR_JPEG.pth', 'RealSR_DPED.pth', 'RealSR_JPEG.pth']
}

method_zoo = list(method_model_zoo.keys())
Expand All @@ -91,11 +104,17 @@ def download_pretrained_model(model_dir='model_zoo', model_name='dncnn3.pth'):
download_pretrained_model(args.model_dir, model_name)
else:
for method_model in args.models:
if method_model in method_zoo:
if method_model in method_zoo: # method, need for loop
for model_name in method_model_zoo[method_model]:
download_pretrained_model(args.model_dir, model_name)
elif method_model in model_zoo:
download_pretrained_model(args.model_dir, method_model)
if 'SwinIR' in model_name:
download_pretrained_model(os.path.join(args.model_dir, 'swinir'), model_name)
else:
download_pretrained_model(args.model_dir, model_name)
elif method_model in model_zoo: # model, do not need for loop
if 'SwinIR' in model_name:
download_pretrained_model(os.path.join(args.model_dir, 'swinir'), method_model)
else:
download_pretrained_model(args.model_dir, method_model)
else:
print(f'Do not find {method_model} from the pre-trained model zoo!')

Expand Down

0 comments on commit 6c38f59

Please sign in to comment.