forked from ScaramuzzinoGiovanna/Watermark-DnCNN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
AuxVisualizer_train.py
137 lines (103 loc) · 5.53 KB
/
AuxVisualizer_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
# -*- coding: utf-8 -*-
import os, cv2
import numpy as np
import DnCNNModel
import AuxVisualizerModel
import tensorflow as tf
# Fix NumPy's global RNG state so any NumPy-side randomness is reproducible.
np.random.seed(0)

# Source DnCNN checkpoint directory: './DnCNN_weight/' holds the original
# model, './overwriting/' holds the watermark-trained model. To switch, pass a
# different `org_model_path` to train(), e.g.:
# org_model_path = './DnCNN_weight/'

# Half-width of the channel dimension: the copyright-image placeholder's last
# dimension is 2 * image_mod + 1.
image_mod = 0
# Base name of the copyright image (<test_img_dir>/<type>.png) and part of the
# DIP checkpoint prefix.
# NOTE(review): this shadows the builtin `type`; the name is kept because
# train() in this module reads it under this name.
type = 'sign'
# Spatial size (320 x 320) and channel count of the copyright-image input.
daub_size = [320, 320, image_mod * 2 + 1]
# Filename prefix of the auxiliary visualizer (DIP) checkpoints.
DIP_model_name = 'Black_DIP_{}_weight_'.format(type)
def post_process(input, step, out_dir='./temp'):
    """Convert a [0, 1]-scaled single image batch to 8-bit and save it as a PNG.

    Args:
        input: array that is 2-D after squeezing axis 0 and axis 2, i.e.
            shaped (1, H, W, 1). Values are scaled by 255 and clipped to
            [0, 255] before the uint8 cast (fractional parts are truncated).
        step: value embedded in the output filename ``gt_<step>.png``.
        out_dir: directory to write into. Defaults to the previously
            hard-coded './temp' and is created if missing.

    Returns:
        bool: True if ``cv2.imwrite`` reported success, False otherwise.
    """
    # `input` shadows the builtin but is kept so keyword callers keep working.
    image = np.squeeze(input, axis=0)
    image = np.clip(image * 255, 0, 255).astype(np.uint8)
    image = np.squeeze(image, axis=2)
    # cv2.imwrite does not raise when the target directory is missing; it just
    # returns False, so make sure the directory exists first.
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, 'gt_' + str(step) + '.png')
    ok = bool(cv2.imwrite(out_path, image))
    if not ok:
        print('post_process: failed to write ' + out_path)
    return ok
def transition(w):
    """Map DnCNN's second output to the auxiliary visualizer's input.

    Currently an identity: the argument object is returned unchanged.
    """
    passthrough = w
    return passthrough
def ft_DIP_optimizer(loss, lr):
    """Build an Adam training op that updates only variables under the 'DIP' scope.

    Args:
        loss: scalar loss tensor to minimise.
        lr: learning rate, either a Python float or a scalar tensor (train()
            passes a tf.float32 placeholder). It was previously ignored in
            favour of a hard-coded 0.001.

    Returns:
        The op produced by ``optimizer.apply_gradients``.
    """
    optimizer = tf.train.AdamOptimizer(lr, name='AdamOptimizer_DIP')
    # Pending UPDATE_OPS belonging to the DIP sub-graph (presumably batch-norm
    # moving statistics) must run before the gradient step.
    dip_update_ops = [op for op in tf.get_collection(tf.GraphKeys.UPDATE_OPS)
                      if op.name.startswith('DIP')]
    # Restricting var_list keeps the (frozen) DnCNN weights untouched.
    dip_vars = [v for v in tf.global_variables() if v.name.startswith('DIP')]
    with tf.control_dependencies(dip_update_ops):
        grads_and_vars = optimizer.compute_gradients(loss, var_list=dip_vars)
        train_op = optimizer.apply_gradients(grads_and_vars)
    return train_op
def _read_gray_uint8(path):
    """Load ``path`` as an 8-bit grayscale image.

    cv2.imread signals failure by returning None instead of raising, which
    would otherwise surface later as an unrelated error, so fail loudly here.

    Raises:
        IOError: if OpenCV cannot read the file.
    """
    image = cv2.imread(path, 0)  # flag 0 == cv2.IMREAD_GRAYSCALE
    if image is None:
        raise IOError('cannot read image: ' + str(path))
    return image


def _to_single_batch(image):
    """Scale an (H, W) uint8 image to float32 [0, 1] shaped (1, H, W, 1)."""
    scaled = image.astype(np.float32) / 255.0
    return scaled[np.newaxis, :, :, np.newaxis]


def train(train_data='./data/img_clean_pats.npy', org_model_path='./overwriting/', comb_model_path='./combine_weight/',
          test_img_dir='./test_img', trigger_img='key_imgs/trigger_image.png', epochs=8, batch_size=128,
          learn_rate=0.001, sigma=25):
    """Train the auxiliary visualizer (DIP) on top of a restored DnCNN.

    DnCNN weights (variables named 'block*') are restored from
    ``org_model_path``. The DIP train op comes from ft_DIP_optimizer, which
    only updates variables named 'DIP*'. Every step feeds the trigger image
    through DnCNN with noise disabled (tag = 0.0). The DIP output is then
    scored by AuxVisualizerModel.lossing against the copyright image
    ``<test_img_dir>/<type>.png``. A DIP checkpoint is written after each
    epoch.

    Args:
        train_data: path to a .npy array. Only its first dimension (number of
            examples) is used, to size an epoch as ``num // batch_size`` steps.
        org_model_path: directory containing the DnCNN checkpoint to restore.
        comb_model_path: directory for the DIP checkpoints; created if missing.
        test_img_dir: directory holding the copyright image.
        trigger_img: path of the trigger image fed to DnCNN.
        epochs: number of epochs; one checkpoint is saved per epoch.
        batch_size: sizes the epoch and is passed to DnCNNModel.lossing. The
            DIP itself sees a single image per step.
        learn_rate: value fed to the ``lr`` placeholder handed to
            ft_DIP_optimizer.
        sigma: std (0-255 units) of the DnCNN input noise. It has no effect
            here because ``tag`` is always fed 0.0.

    Raises:
        IOError: if an image cannot be read or no DnCNN checkpoint is found.
    """
    degraded_image = os.path.join(test_img_dir, type + '.png')  # copyright image
    with tf.Graph().as_default():
        lr = tf.placeholder(tf.float32, shape=[], name='learning_rate')
        tag = tf.placeholder(tf.float32, shape=[], name='tag')
        training = tf.placeholder(tf.bool, name='is_training')
        img_clean = tf.placeholder(tf.float32, [None, None, None, 1], name='clean_image')
        images_daub = tf.placeholder(tf.float32, [None, daub_size[0], daub_size[1], daub_size[2]])

        # DnCNN branch: `tag` gates the additive Gaussian noise (img_clean is the trigger image).
        img_noise = img_clean + tag * tf.random_normal(shape=tf.shape(img_clean),
                                                       stddev=sigma / 255.0)
        Y, N = DnCNNModel.dncnn(img_noise, is_training=training)
        # The DnCNN loss op is still built so the graph matches earlier
        # versions, but it is never evaluated here.
        DnCNNModel.lossing(Y, img_clean, batch_size)

        # Auxiliary visualizer (DIP) decodes DnCNN's second output (the verification image).
        dncnn_s_out = transition(N)
        ldr = AuxVisualizerModel.Encoder_decoder(dncnn_s_out, is_training=True)
        dip_loss = AuxVisualizerModel.lossing(ldr, images_daub)
        dip_opt = ft_DIP_optimizer(dip_loss, lr)

        init = tf.global_variables_initializer()
        dncnn_var_list = [v for v in tf.global_variables() if v.name.startswith('block')]
        DnCNN_saver = tf.train.Saver(dncnn_var_list)
        dip_var_list = [v for v in tf.global_variables() if v.name.startswith('DIP')]
        DIP_saver = tf.train.Saver(dip_var_list, max_to_keep=50)

        with tf.Session() as sess:
            # Only the dataset's length is used (the DIP is always trained on
            # the single trigger image), so memory-map it to read the shape
            # instead of loading, normalising and reshuffling every patch.
            num_example = np.load(train_data, mmap_mode='r').shape[0]
            num_batch = num_example // batch_size
            if num_batch == 0:
                print("Warning: %d examples < batch_size %d; no DIP steps will run"
                      % (num_example, batch_size))

            # Copyright image: resize while still uint8 (cv2 dsize is
            # (width, height)), then scale to [0, 1].
            daub_images = cv2.resize(_read_gray_uint8(degraded_image),
                                     (daub_size[0], daub_size[1]))
            daub_images = _to_single_batch(daub_images)
            # NOTE(review): the trigger image is not resized; presumably its
            # size must make the DIP output match images_daub - confirm.
            special_input = _to_single_batch(_read_gray_uint8(trigger_img))

            sess.run(init)
            ckpt = tf.train.get_checkpoint_state(org_model_path)
            if not (ckpt and ckpt.model_checkpoint_path):
                print("DnCNN weight must be exist")
                # Explicit raise: an `assert` is stripped under -O and also let
                # training proceed with an uninitialised DnCNN when the
                # checkpoint state existed but named no checkpoint.
                raise IOError('weights not exist: no DnCNN checkpoint in ' + str(org_model_path))
            full_path = tf.train.latest_checkpoint(org_model_path)
            print(full_path)
            DnCNN_saver.restore(sess, full_path)
            print("Loading " + os.path.basename(full_path) + " to the model")

            # Create the output directory up front rather than failing on the
            # first save after a whole epoch of work.
            if comb_model_path:
                os.makedirs(comb_model_path, exist_ok=True)

            # Every step uses the same inputs, so build the feed dict once.
            feed = {img_clean: special_input, lr: learn_rate,
                    images_daub: daub_images, tag: 0.0,
                    training: False}
            step = 0
            for epoch in range(epochs):
                for batch_id in range(num_batch):
                    sess.run(dip_opt, feed_dict=feed)
                    step += 1
                    if batch_id % 100 == 0:
                        dip_lost = sess.run(dip_loss, feed_dict=feed)
                        # DnCNN is not optimised here; its loss is logged as 0.
                        print("step = %d, dncnn_loss = %f, dip_loss = %f" % (step, 0, dip_lost))
                DIP_saver.save(sess, os.path.join(comb_model_path,
                                                  DIP_model_name + str(epoch + 1) + ".ckpt"))
                print("+++++ epoch " + str(epoch + 1) + " is saved successfully +++++")
if __name__ == "__main__":
    # Script entry point: run DIP training with train()'s default arguments.
    train()