-
Notifications
You must be signed in to change notification settings - Fork 14
/
trainer.py
142 lines (109 loc) · 5.15 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import os
import torch
from tqdm import tqdm
import numpy as np
from scipy.stats import spearmanr, pearsonr
""" train model """
def train_epoch(config, epoch, model_transformer, model_backbone, criterion, optimizer, scheduler, train_loader):
    """Run one training epoch and return (mean loss, SROCC, PLCC).

    Args:
        config: namespace providing batch_size, n_enc_seq, device,
            save_freq and snap_path.
        epoch: zero-based epoch index (logging/checkpoint names use epoch+1).
        model_transformer: transformer head, called as
            model_transformer(mask_inputs, feat_org, feat_scale_1, feat_scale_2).
        model_backbone: CNN feature extractor applied to each image scale.
        criterion: loss comparing squeezed predictions with labels.
        optimizer: optimizer stepping both models' parameters.
        scheduler: LR scheduler stepped once per batch.
        train_loader: yields dicts with 'd_img_org', 'd_img_scale_1',
            'd_img_scale_2' image tensors and a 'score' label tensor.

    Returns:
        Tuple of (mean epoch loss, Spearman rank correlation, Pearson
        correlation) over all predictions of the epoch.
    """
    losses = []
    model_transformer.train()
    model_backbone.train()

    # accumulate every batch's predictions/labels to compute SROCC/PLCC
    # over the whole epoch
    pred_epoch = []
    labels_epoch = []

    for data in tqdm(train_loader):
        # d_img_org: batch x 3 x 768 x 1024
        # d_img_scale_1: batch x 3 x 288 x 384
        # d_img_scale_2: batch x 3 x 160 x 224
        d_img_org = data['d_img_org'].to(config.device)
        d_img_scale_1 = data['d_img_scale_1'].to(config.device)
        d_img_scale_2 = data['d_img_scale_2'].to(config.device)
        labels = data['score']
        # NOTE(review): squeeze() yields a 0-d tensor when batch size is 1 —
        # assumes criterion/pred squeeze the same way; confirm with loader.
        labels = torch.squeeze(labels.type(torch.FloatTensor)).to(config.device)

        # BUGFIX: size the input mask (batch x len_seq+1) from the ACTUAL
        # batch — the last batch of an epoch can be smaller than
        # config.batch_size, which previously caused a size mismatch.
        n_batch = d_img_org.size(0)
        mask_inputs = torch.ones(n_batch, config.n_enc_seq + 1).to(config.device)

        # backbone feature maps for each scale:
        # feat_dis_org: batch x 2048 x 24 x 32
        # feat_dis_scale_1: batch x 2048 x 9 x 12
        # feat_dis_scale_2: batch x 2048 x 5 x 7
        feat_dis_org = model_backbone(d_img_org)
        feat_dis_scale_1 = model_backbone(d_img_scale_1)
        feat_dis_scale_2 = model_backbone(d_img_scale_2)

        # weight update
        optimizer.zero_grad()
        pred = model_transformer(mask_inputs, feat_dis_org, feat_dis_scale_1, feat_dis_scale_2)
        loss = criterion(torch.squeeze(pred), labels)
        losses.append(loss.item())
        loss.backward()
        optimizer.step()
        scheduler.step()

        # save results of this batch
        pred_epoch = np.append(pred_epoch, pred.data.cpu().numpy())
        labels_epoch = np.append(labels_epoch, labels.data.cpu().numpy())

    # correlation coefficients over the full epoch
    rho_s, _ = spearmanr(np.squeeze(pred_epoch), np.squeeze(labels_epoch))
    rho_p, _ = pearsonr(np.squeeze(pred_epoch), np.squeeze(labels_epoch))
    # BUGFIX: report the mean epoch loss (the value this function returns),
    # not just the loss of the final batch.
    loss_mean = np.mean(losses)
    print('[train] epoch:%d / loss:%f / SROCC:%4f / PLCC:%4f' % (epoch + 1, loss_mean, rho_s, rho_p))

    # periodic checkpoint
    if (epoch + 1) % config.save_freq == 0:
        weights_file_name = "epoch%d.pth" % (epoch + 1)
        weights_file = os.path.join(config.snap_path, weights_file_name)
        torch.save({
            'epoch': epoch,
            'model_backbone_state_dict': model_backbone.state_dict(),
            'model_transformer_state_dict': model_transformer.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            # BUGFIX: store a plain float instead of a tensor still attached
            # to the autograd graph
            'loss': loss.item()
        }, weights_file)
        print('save weights of epoch %d' % (epoch + 1))

    return loss_mean, rho_s, rho_p
""" validation """
def eval_epoch(config, epoch, model_transformer, model_backbone, criterion, test_loader):
    """Evaluate the models on test_loader and return (mean loss, SROCC, PLCC).

    Runs entirely under torch.no_grad() with both models in eval mode.

    Args:
        config: namespace providing batch_size, n_enc_seq and device.
        epoch: zero-based epoch index (logging uses epoch+1).
        model_transformer: transformer head, called as
            model_transformer(mask_inputs, feat_org, feat_scale_1, feat_scale_2).
        model_backbone: CNN feature extractor applied to each image scale.
        criterion: loss comparing squeezed predictions with labels.
        test_loader: yields dicts with 'd_img_org', 'd_img_scale_1',
            'd_img_scale_2' image tensors and a 'score' label tensor.

    Returns:
        Tuple of (mean test loss, Spearman rank correlation, Pearson
        correlation) over all predictions.
    """
    with torch.no_grad():
        losses = []
        model_transformer.eval()
        model_backbone.eval()

        # accumulate predictions/labels across all batches for SROCC/PLCC
        pred_epoch = []
        labels_epoch = []

        for data in tqdm(test_loader):
            # d_img_org: batch x 3 x 768 x 1024
            # d_img_scale_1: batch x 3 x 288 x 384
            # d_img_scale_2: batch x 3 x 160 x 224
            d_img_org = data['d_img_org'].to(config.device)
            d_img_scale_1 = data['d_img_scale_1'].to(config.device)
            d_img_scale_2 = data['d_img_scale_2'].to(config.device)
            labels = data['score']
            labels = torch.squeeze(labels.type(torch.FloatTensor)).to(config.device)

            # BUGFIX: size the input mask (batch x len_seq+1) from the ACTUAL
            # batch — the last batch can be smaller than config.batch_size,
            # which previously caused a size mismatch.
            n_batch = d_img_org.size(0)
            mask_inputs = torch.ones(n_batch, config.n_enc_seq + 1).to(config.device)

            # backbone feature maps for each scale:
            # feat_dis_org: batch x 2048 x 24 x 32
            # feat_dis_scale_1: batch x 2048 x 9 x 12
            # feat_dis_scale_2: batch x 2048 x 5 x 7  (was mislabeled 5 x 12)
            feat_dis_org = model_backbone(d_img_org)
            feat_dis_scale_1 = model_backbone(d_img_scale_1)
            feat_dis_scale_2 = model_backbone(d_img_scale_2)

            pred = model_transformer(mask_inputs, feat_dis_org, feat_dis_scale_1, feat_dis_scale_2)

            # compute loss
            loss = criterion(torch.squeeze(pred), labels)
            losses.append(loss.item())

            # save results of this batch
            pred_epoch = np.append(pred_epoch, pred.data.cpu().numpy())
            labels_epoch = np.append(labels_epoch, labels.data.cpu().numpy())

        # correlation coefficients over the full test set
        rho_s, _ = spearmanr(np.squeeze(pred_epoch), np.squeeze(labels_epoch))
        rho_p, _ = pearsonr(np.squeeze(pred_epoch), np.squeeze(labels_epoch))
        # BUGFIX: report the mean test loss (the value this function
        # returns), not just the loss of the final batch.
        loss_mean = np.mean(losses)
        print('test epoch:%d / loss:%f /SROCC:%4f / PLCC:%4f' % (epoch + 1, loss_mean, rho_s, rho_p))

        return loss_mean, rho_s, rho_p