"""Test script.
Usage:
test.py <hparams> <checkpoint_dir> <dataset_root> [--cuda=<id>]
test.py -h | --help
Options:
-h --help Show this screen.
--cuda=<id> id of the cuda device [default: 0].
"""
import os
import json
import torch
import numpy as np
from tqdm import tqdm
from docopt import docopt
from os.path import join
import cv2 as cv
from irl_dcb.config import JsonConfig
from irl_dcb.data import NEW_LHF_IRL
from irl_dcb.models import LHF_Policy_Cond_Small
from irl_dcb.environment import NEW_IRL_Env4LHF
from irl_dcb import utils
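# Fix random seeds so sampled scanpaths are reproducible across runs.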
torch.manual_seed(42620)
np.random.seed(42620)
def gen_scanpaths(generator,
                  env_test,
                  test_img_loader,
                  patch_num,
                  max_traj_len,
                  im_w,
                  im_h,
                  num_sample=10):
    """Sample scanpaths with the trained policy.

    Runs the generator in the test environment num_sample times over the
    whole loader and converts the collected action sequences into image
    coordinates.
    """
    all_actions = []
    for i_sample in range(num_sample):
        progress = tqdm(test_img_loader,
                        desc='trial ({}/{})'.format(i_sample + 1, num_sample))
        for i_batch, batch in enumerate(progress):
            env_test.set_data(batch)
            img_names_batch = batch['img_name']
            cat_names_batch = batch['cat_name']
            with torch.no_grad():
                env_test.reset()
                trajs = utils.collect_trajs(env_test,
                                            generator,
                                            patch_num,
                                            max_traj_len,
                                            is_eval=True,
                                            sample_action=True)
                all_actions.extend([(cat_names_batch[i], img_names_batch[i],
                                     'present', trajs['actions'][:, i])
                                    for i in range(env_test.batch_size)])
    scanpaths = utils.actions2scanpaths(all_actions, patch_num, im_w, im_h)
    # Ali: I removed the following line because I removed bbox_annos from the
    # whole codebase, but maybe I should add it back because it truncates the
    # scanpath once it enters the target bounding box (see the sketch near the
    # bbox_annos placeholder in __main__):
    # utils.cutFixOnTarget(scanpaths, bbox_annos)
    return scanpaths
def plot_scanpaths_on_images(preds, hyperparams, image_source_location='images/', save_dir='results/'):
    """Draw each predicted scanpath on its source image and save the result."""
    for index, elem in enumerate(preds):
        filename = elem['task'] + "/" + elem['name']
        print(str(index) + ". " + filename)
        image = cv.imread(image_source_location + filename)
        X = elem['X']
        Y = elem['Y']
        image = cv.resize(image, (hyperparams.Data.im_w, hyperparams.Data.im_h))
        # Draw each fixation (index label + circle) and connect consecutive
        # fixations with a line.
        for i in range(len(X)):
            x = int(X[i])
            y = int(Y[i])
            cv.putText(image, str(i), (x, y), cv.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255))
            cv.circle(image, (x, y), 2, (255, 255, 255))
            if i > 0:
                xprec = int(X[i - 1])
                yprec = int(Y[i - 1])
                cv.line(image, (xprec, yprec), (x, y), (255, 255, 255))
        os.makedirs(save_dir + elem['task'] + "/", exist_ok=True)
        cv.imwrite(save_dir + elem['task'] + "/" + elem['name'], image)
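
# Each element of preds is expected (inferred from the plotting code above) to
# be a dict with at least the keys 'task', 'name', 'X', and 'Y', e.g.
# (hypothetical values):
# {'task': 'bottle', 'name': '111bottle.jpg', 'X': [12.0, 87.5], 'Y': [40.0, 66.2]}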
if __name__ == '__main__':
    # args = docopt(__doc__)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    hparams = "hparams/coco_search18.json"
    dataset_root = "new_dataset_root"
    checkpoint = "trained_models"
    hparams = JsonConfig(hparams)

    # dir of pre-computed DCBs
    DCB_dir_HR = join(dataset_root, 'DCBs/HR/')
    DCB_dir_LR = join(dataset_root, 'DCBs/LR/')

    # placeholder for bounding box annotations, keyed as '<task>_<image name>':
    # bbox_annos['bottle_111bottle.jpg'] = [43, 195, 34, 105]
    # bbox_annos['bottle_222bottle.jpg'] = [54, 205, 32, 114]
    # bbox_annos['bottle_333bottle.jpg'] = [54, 205, 32, 114]
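    # A minimal sketch of how bbox_annos could be rebuilt to restore the
    # utils.cutFixOnTarget truncation in gen_scanpaths (an assumption based on
    # the placeholder entries above; the real annotations would come from the
    # dataset, and the box format is assumed to be [x, y, w, h] in pixels):
    # bbox_annos = {
    #     'bottle_111bottle.jpg': [43, 195, 34, 105],
    # }
    # utils.cutFixOnTarget(scanpaths, bbox_annos)  # would run after gen_scanpaths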
    with open(join(dataset_root,
                   'human_scanpaths_TP_trainval_test.json'), encoding='utf-8') as json_file:
        human_scanpaths_test = json.load(json_file)

    # cat_names = list(np.unique([x['task'] for x in human_scanpaths_train]))
    cat_names = ['bottle', 'bowl', 'car', 'chair', 'clock', 'cup', 'fork',
                 'keyboard', 'knife', 'laptop', 'microwave', 'mouse', 'oven',
                 'potted plant', 'sink', 'stop sign', 'toilet', 'tv']
    catIds = dict(zip(cat_names, list(range(len(cat_names)))))

    # dataset = process_data(human_scanpaths_train, human_scanpaths_valid,
    #                        DCB_dir_HR, DCB_dir_LR, bbox_annos, hparams)
    # train_task_img_pair = np.unique(
    #     [traj['task'] + '_' + traj['name'] for traj in human_scanpaths_train])
    test_task_img_pair = np.unique(
        [traj['task'] + '_' + traj['name'] for traj in human_scanpaths_test])
    test_dataset = NEW_LHF_IRL(DCB_dir_HR, DCB_dir_LR, test_task_img_pair,
                               hparams.Data, catIds)
    dataloader = torch.utils.data.DataLoader(test_dataset,
                                             batch_size=16,
                                             shuffle=False,
                                             num_workers=2)
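
    # Note (inferred from gen_scanpaths): each batch yielded by this loader is
    # expected to carry 'img_name' and 'cat_name' fields alongside the DCB
    # tensors, since env_test.set_data(batch) and the trajectory collection
    # rely on them.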
    # load trained model
    input_size = 134  # number of belief maps
    task_eye = torch.eye(len(catIds)).to(device)
    generator = LHF_Policy_Cond_Small(hparams.Data.patch_count,
                                      len(catIds), task_eye,
                                      input_size).to(device)
    state = torch.load(join(checkpoint, 'trained_generator.pkg'), map_location=device)
    generator.load_state_dict(state["model"])
    generator.eval()

    # build environment
    env_test = NEW_IRL_Env4LHF(hparams.Data,
                               max_step=hparams.Data.max_traj_length,
                               mask_size=hparams.Data.IOR_size,
                               status_update_mtd=hparams.Train.stop_criteria,
                               device=device,
                               inhibit_return=True)
    # generate scanpaths
    print('sampling scanpaths (1 per testing image)...')
    predictions = gen_scanpaths(generator,
                                env_test,
                                dataloader,
                                hparams.Data.patch_num,
                                hparams.Data.max_traj_length,
                                hparams.Data.im_w,
                                hparams.Data.im_h,
                                num_sample=1)

    plot_scanpaths_on_images(preds=predictions,
                             hyperparams=hparams,
                             image_source_location='images/',
                             save_dir='results/')
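
    # Usage note: with the hardcoded paths above, this script assumes that
    # hparams/coco_search18.json, trained_models/trained_generator.pkg, the
    # new_dataset_root/ DCBs and scanpath JSON, and an images/<task>/<name>
    # layout all exist relative to the working directory; it can then be run
    # simply as:
    #   python new_customized_test_scanpath_gen.py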