-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_srgan.py
73 lines (56 loc) · 2.18 KB
/
test_srgan.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
'''
@File    : test_srgan.py
@Desc    : Run a test on a single sample image
@Time    : 2020/03/02 12:14:50
@Author  : 徐通
@Version : 1.0
'''
from utils import *
from torch import nn
from model_SRRGAN import SRResNet,Generator
import time
from PIL import Image
# Path of the image used for the test.
imgPath = './results/test.bmp'

# Model hyper-parameters (handed to the Generator constructor by the script).
large_kernel_size = 9  # kernel size of the first and the last convolution
small_kernel_size = 3  # kernel size of the intermediate convolutions
n_channels = 64        # number of channels in the intermediate layers
n_blocks = 16          # number of residual blocks
scaling_factor = 4     # upscaling factor

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
if __name__ == '__main__':
    # Single-image SRGAN test: load the pretrained generator, write a bicubic
    # baseline of the test image, run the generator on the image, save the
    # super-resolved result and print the elapsed time.

    # Pretrained checkpoint.
    srgan_checkpoint = "./results/checkpoint_srgan.pth"

    # Load the SRGAN generator.
    # map_location remaps the stored tensors onto `device`; without it a
    # checkpoint saved from a CUDA run cannot be loaded on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file -- only load checkpoints
    # that come from a trusted source.
    checkpoint = torch.load(srgan_checkpoint, map_location=device)
    generator = Generator(large_kernel_size=large_kernel_size,
                          small_kernel_size=small_kernel_size,
                          n_channels=n_channels,
                          n_blocks=n_blocks,
                          scaling_factor=scaling_factor)
    generator = generator.to(device)
    generator.load_state_dict(checkpoint['generator'])
    generator.eval()
    model = generator

    # Load the image. convert() returns a new, fully loaded image, so the
    # underlying file handle can be released as soon as the block exits.
    with Image.open(imgPath, mode='r') as src:
        img = src.convert('RGB')

    # Bicubic upsampling, saved next to the SRGAN output as a baseline.
    Bicubic_img = img.resize(
        (int(img.width * scaling_factor), int(img.height * scaling_factor)),
        Image.BICUBIC,
    )
    Bicubic_img.save('./results/test_bicubic.bmp')

    # Pre-process the image and add a leading batch dimension.
    lr_img = convert_image(img, source='pil', target='imagenet-norm')
    lr_img.unsqueeze_(0)

    # Start timing (covers device transfer, inference, post-processing, save).
    start = time.time()

    # Move the input to the target device.
    # presumably (1, 3, H, W), ImageNet-normalised -- confirm in utils.convert_image
    lr_img = lr_img.to(device)

    # Inference. no_grad() already disables autograd tracking, so no extra
    # detach() is needed on the result.
    with torch.no_grad():
        # presumably (3, H * scale, W * scale) after squeeze(0), in [-1, 1]
        sr_img = model(lr_img).squeeze(0).cpu()

    sr_img = convert_image(sr_img, source='[-1, 1]', target='pil')
    sr_img.save('./results/test_srgan.bmp')

    print('用时 {:.3f} 秒'.format(time.time() - start))