forked from kensakurada/sscdnet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
infer.py
170 lines (127 loc) · 6.15 KB
/
infer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
from argparse import ArgumentParser
import cv2
import numpy as np
import torch
from torch.autograd import Variable
#import torch.nn.functional as F
import sys
sys.path.append("./correlation_package/build/lib.linux-x86_64-3.6")
import os
import cscdnet
def colormap():
    """Return the 2x3 uint8 label palette: label 0 -> black, label 1 -> white."""
    return np.array([[0, 0, 0],
                     [255, 255, 255]], dtype=np.uint8)
class Colorization:
    """Turn a (1, H, W) integer label map into a (3, H, W) uint8 RGB image.

    Label 0 maps to black, label 1 to white (the binary change-mask palette).
    """

    def __init__(self, n=2):
        # First n entries of the black/white palette, as a uint8 tensor.
        palette = np.array([[0, 0, 0], [255, 255, 255]], dtype=np.uint8)
        self.cmap = torch.from_numpy(np.array(palette[:n]))

    def __call__(self, gray_image):
        # gray_image is expected to be (1, H, W); paint each label's pixels
        # with its palette color, one channel at a time.
        _, height, width = gray_image.size()
        color_image = torch.zeros(3, height, width, dtype=torch.uint8)
        for label, color in enumerate(self.cmap):
            mask = gray_image[0] == label
            for channel in range(3):
                color_image[channel][mask] = color[channel]
        return color_image
class ChangeDetect:
    """Run CSCDNet / Siamese CDResNet change detection on one image pair.

    Loads two co-registered images, crops a 256x256 chip whose upper-left
    corner is at (img_row, img_col), runs the trained model on the
    channel-concatenated pair, and writes a side-by-side visualization
    (image0 | image1 | predicted change mask) into out_dir.
    """

    def __init__(self, img0_path, img1_path, img_row, img_col, out_dir, model_path, use_corr=True):
        self.img0_path = img0_path      # path to the "before" image
        self.img1_path = img1_path      # path to the "after" image
        self.img_row = img_row          # chip upper-left row
        self.img_col = img_col          # chip upper-left column
        self.out_dir = out_dir          # directory for the output visualization
        self.model_path = model_path    # trained .pth checkpoint
        self.use_corr = use_corr        # True -> CSCDNet, False -> Siamese CDResNet

    def _load_image(self, path):
        """Read one image as BGR uint8, exiting with a clear message on failure."""
        if not os.path.isfile(path):
            print('Error: File Not Found: ' + path)
            sys.exit(-1)
        img = cv2.imread(path, cv2.IMREAD_COLOR)
        # cv2.imread returns None (no exception) when the file exists but
        # cannot be decoded; fail loudly here instead of crashing later with
        # an opaque numpy error.
        if img is None:
            print('Error: Cannot read image: ' + path)
            sys.exit(-1)
        return img

    def preprocess_image(self):
        """Load, normalize, crop, and batch both images.

        Returns a (1, 6, 256, 256) float CUDA tensor: channels 0-2 are
        image0 and channels 3-5 are image1, each scaled to roughly [-1, 1].
        """
        img0 = self._load_image(self.img0_path)
        img1 = self._load_image(self.img1_path)
        # uint8 [0, 255] -> float, channels-first, scaled to ~[-1, 1].
        img0_ = np.asarray(img0).astype("f").transpose(2, 0, 1) / 128.0 - 1.0
        img1_ = np.asarray(img1).astype("f").transpose(2, 0, 1) / 128.0 - 1.0
        # Cut out the 256 x 256 chip at (img_row, img_col).
        img0_ = img0_[:, self.img_row:self.img_row + 256, self.img_col:self.img_col + 256]
        img1_ = img1_[:, self.img_row:self.img_row + 256, self.img_col:self.img_col + 256]
        input_ = torch.from_numpy(np.concatenate((img0_, img1_), axis=0))
        print(input_.shape)  # torch.Size([6, 256, 256])
        input_ = input_[np.newaxis, :, :]  # add batch dimension
        input_ = input_.cuda()
        input_ = Variable(input_)
        return input_

    def run_inference(self, input_):
        """Load the trained model, run it on input_, and return (mask_pred, img_out).

        mask_pred is a (256, 256, 3) uint8 RGB mask; img_out is the composed
        (256, 768, 3) visualization that is also written to out_dir.
        """
        self.color_transform = Colorization(2)
        # 'pretrained' refers to the ResNet feature extractor weights.
        if self.use_corr:
            print('Correlated Siamese Change Detection Network (CSCDNet)')
            self.model = cscdnet.Model(inc=6, outc=2, corr=True, pretrained=True)
        else:
            print('Siamese Change Detection Network (Siamese CDResNet)')
            self.model = cscdnet.Model(inc=6, outc=2, corr=False, pretrained=True)

        if not os.path.isfile(self.model_path):
            print("Error: Cannot read file ... " + self.model_path)
            sys.exit(-1)
        print("Reading model ... " + self.model_path)

        # Checkpoints saved from nn.DataParallel prefix every key with
        # "module."; strip the prefix so weights load into a bare model.
        state_dict = torch.load(self.model_path)
        first_key = next(iter(state_dict))
        if first_key.startswith("module."):
            from collections import OrderedDict
            state_dict = OrderedDict((k[7:], v) for k, v in state_dict.items())
        self.model.load_state_dict(state_dict)

        self.model = self.model.cuda()
        # Without eval() dropout stays active and BatchNorm uses batch
        # statistics, so inference results would differ from training-time
        # validation behavior.
        self.model.eval()
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            output_ = self.model(input_)

        inputs = input_[0].cpu().data
        img0 = inputs[0:3, :, :]
        img1 = inputs[3:6, :, :]
        # Undo the [-1, 1] normalization back to pixel range.
        img0 = (img0 + 1.0) * 128
        img1 = (img1 + 1.0) * 128

        output = output_[0][np.newaxis, :, :, :]
        output = output[:, 0:2, :, :]
        # argmax over the 2 class channels -> binary label map -> RGB mask.
        mask_pred = np.transpose(
            self.color_transform(output[0].cpu().max(0)[1][np.newaxis, :, :].data).numpy(),
            (1, 2, 0)).astype(np.uint8)
        img_out = self.display_result(img0, img1, mask_pred)
        return mask_pred, img_out

    def display_result(self, img0, img1, mask_pred):
        """Compose img0 | img1 | mask side by side, save it under out_dir, return it."""
        rows = cols = 256
        img_out = np.zeros((rows, cols * 3, 3), dtype=np.uint8)
        img_out[0:rows, 0:cols, :] = np.transpose(img0.numpy(), (1, 2, 0)).astype(np.uint8)
        img_out[0:rows, cols:cols * 2, :] = np.transpose(img1.numpy(), (1, 2, 0)).astype(np.uint8)
        img_out[0:rows, cols * 2:cols * 3, :] = mask_pred
        img_filename, _ = os.path.splitext(os.path.basename(self.img0_path))
        img_save_path = os.path.join(
            self.out_dir,
            '{}.png'.format(img_filename + '_' + str(self.img_row) + '_' + str(self.img_col)))
        if not os.path.exists(self.out_dir):
            os.makedirs(self.out_dir)
        print('Writing ... ' + img_save_path)
        cv2.imwrite(img_save_path, img_out)
        return img_out
if __name__ == "__main__":
parser = ArgumentParser(description = 'Class to preprocess images and perform change detection')
parser.add_argument('--img0_path', type=str, help='path to first image')
parser.add_argument('--img1_path', type=str, help='path to second image')
parser.add_argument('--img_row', type=int, default=0, help='row index of upper left corner of image chip')
parser.add_argument('--img_col', type=int, default=0, help='column index of upper left corner of image chip')
parser.add_argument('--out_dir' , type=str, help='path to output path')
parser.add_argument('--model_path', type=str, help='path to trained .pth model')
parser.add_argument('--use_corr', type=bool, default=True, help='use correlation?')
opt = parser.parse_args()
change_det = ChangeDetect(opt.img0_path, opt.img1_path, opt.img_row,
opt.img_col, opt.out_dir, opt.model_path, opt.use_corr)
input_ = change_det.preprocess_image()
mask_pred, img_out = change_det.run_inference(input_)