generated from roberttwomey/ml-art-final
-
Notifications
You must be signed in to change notification settings - Fork 2
/
perceptual_model.py
98 lines (75 loc) · 4.46 KB
/
perceptual_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.applications.vgg16 import VGG16, preprocess_input
from tensorflow.keras.preprocessing import image
import tensorflow.keras.backend as K
import PIL.Image
import os
def load_images(images_list, img_size):
    """Load image files and preprocess them as one VGG16 input batch.

    Each file is loaded and resized to ``img_size`` x ``img_size``, converted
    to a float32 array, stacked along a new leading batch axis and passed
    through VGG16's ``preprocess_input``.

    Args:
        images_list: non-empty sequence of image file paths.
        img_size: target side length in pixels.

    Returns:
        numpy array of shape ``(len(images_list), img_size, img_size, 3)``
        (``load_img`` loads RGB by default) after ``preprocess_input``.

    Raises:
        ValueError: if ``images_list`` is empty.
    """
    if len(images_list) == 0:
        raise ValueError("load_images() needs at least one image path")
    # img_to_array yields float32, so preprocess_input never has to do its
    # mean subtraction on uint8 data (older keras versions did it in place).
    arrays = [
        image.img_to_array(image.load_img(path, target_size=(img_size, img_size)))
        for path in images_list
    ]
    return preprocess_input(np.stack(arrays))
class PerceptualModel:
    """VGG16 feature-space (perceptual) loss between generated and reference images.

    Typical use: call ``build_perceptual_model`` once with the generator's
    output tensor, then ``set_reference_images`` for each batch of target
    images, then iterate ``optimize`` to minimise the loss with respect to
    the given variables.
    """

    def __init__(self, img_size, layer=9, batch_size=1, sess=None, generator=None):
        """Store configuration; no graph is built here.

        Args:
            img_size: side length (pixels) images are resized to before VGG16.
            layer: index into ``VGG16.layers`` whose output is used as features.
            batch_size: maximum number of reference images per batch; reference
                features are zero-padded up to this size.
            sess: TF1 session to use; defaults to ``tf.get_default_session()``.
                It is also installed as the Keras backend session.
            generator: object exposing ``generate_images()``; only used by
                ``optimize`` to save intermediate snapshots (may be None).
        """
        self.generator = generator
        self.sess = tf.get_default_session() if sess is None else sess
        K.set_session(self.sess)
        self.img_size = img_size
        self.layer = layer
        self.batch_size = batch_size
        self.perceptual_model = None
        self.ref_img_features = None
        self.features_weight = None
        self.loss = None
        # File-name stems for snapshots; set by set_reference_images().
        self.names = None
        # Placeholders and assign ops created once in build_perceptual_model().
        self._ref_img_features_in = None
        self._features_weight_in = None
        self._assign_reference_ops = None

    def build_perceptual_model(self, generated_image_tensor):
        """Build the VGG16 feature extractor, reference variables and loss op.

        Args:
            generated_image_tensor: generator output tensor; it is resized to
                ``img_size`` x ``img_size`` and passed through
                ``preprocess_input`` before feature extraction.
                NOTE(review): presumably NHWC RGB in the [0, 255] range that
                ``preprocess_input`` expects -- confirm against the generator.
        """
        vgg16 = VGG16(include_top=False, input_shape=(self.img_size, self.img_size, 3))
        self.perceptual_model = Model(vgg16.input, vgg16.layers[self.layer].output)
        # method=1 is ResizeMethod.NEAREST_NEIGHBOR in TF1.
        generated_image = preprocess_input(
            tf.image.resize_images(generated_image_tensor,
                                   (self.img_size, self.img_size), method=1))
        generated_img_features = self.perceptual_model(generated_image)

        features_shape = generated_img_features.shape
        self.ref_img_features = tf.get_variable(
            'ref_img_features', shape=features_shape, dtype='float32',
            initializer=tf.initializers.zeros())
        self.features_weight = tf.get_variable(
            'features_weight', shape=features_shape, dtype='float32',
            initializer=tf.initializers.zeros())
        # Initialise both variables so the loss is evaluable immediately.
        self.sess.run([self.features_weight.initializer,
                       self.ref_img_features.initializer])

        # Build the assign ops once; set_reference_images() feeds new values
        # through these placeholders instead of creating fresh tf.assign ops
        # (and constants embedding each batch's features) on every call,
        # which would grow the graph without bound.
        self._ref_img_features_in = tf.placeholder(tf.float32, shape=features_shape)
        self._features_weight_in = tf.placeholder(tf.float32, shape=features_shape)
        self._assign_reference_ops = [
            tf.assign(self.features_weight, self._features_weight_in),
            tf.assign(self.ref_img_features, self._ref_img_features_in),
        ]

        # Masked MSE in feature space. 82890.0 is a fixed scale factor whose
        # origin is not documented here; kept unchanged.
        self.loss = tf.losses.mean_squared_error(
            self.features_weight * self.ref_img_features,
            self.features_weight * generated_img_features) / 82890.0

    def set_reference_images(self, images_list, names=None):
        """Load reference images and store their features as the loss target.

        Args:
            images_list: 1..batch_size image file paths.
            names: file-name stems, one per image, used by ``optimize`` for
                snapshot files; if None, ``optimize`` saves no snapshots.

        Raises:
            ValueError: if ``images_list`` is empty or longer than ``batch_size``.
        """
        num_images = len(images_list)
        if num_images == 0 or num_images > self.batch_size:
            raise ValueError(
                f"expected between 1 and {self.batch_size} reference images, "
                f"got {num_images}")
        self.names = names
        loaded_image = load_images(images_list, self.img_size)
        image_features = self.perceptual_model.predict_on_batch(loaded_image)

        # Pad to the full batch: rows for absent images get zero features and
        # a zero weight, so their weighted feature difference is always zero.
        full_shape = self.features_weight.shape.as_list()
        weight_mask = np.zeros(full_shape, dtype=np.float32)
        weight_mask[:num_images] = 1.0
        padded_features = np.zeros(full_shape, dtype=np.float32)
        padded_features[:num_images] = image_features
        self.sess.run(self._assign_reference_ops, feed_dict={
            self._features_weight_in: weight_mask,
            self._ref_img_features_in: padded_features,
        })

    def optimize(self, vars_to_optimize, iterations=500, learning_rate=1.,
                 generate_every=10, output_dir="images/generated_images/1440ep/"):
        """Run gradient descent on the perceptual loss, yielding the loss each step.

        Every ``generate_every`` iterations (starting at iteration 0) the
        generator's current images are saved as PNG files named
        ``{name}_{iteration:04d}ep.png`` in ``output_dir`` (created if needed),
        one per name passed to ``set_reference_images``. Snapshots are skipped
        when there is no generator, no names, or ``generate_every`` is not a
        positive number.

        Args:
            vars_to_optimize: a variable or a list/tuple of variables to update.
            iterations: number of optimisation steps.
            learning_rate: gradient-descent step size.
            generate_every: snapshot interval in iterations.
            output_dir: directory snapshot images are written to.

        Yields:
            The loss value fetched by ``sess.run`` in each step.
        """
        if not isinstance(vars_to_optimize, (list, tuple)):
            vars_to_optimize = [vars_to_optimize]
        optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
        min_op = optimizer.minimize(self.loss, var_list=list(vars_to_optimize))

        save_snapshots = bool(self.generator is not None and self.names is not None
                              and generate_every and generate_every > 0)
        if save_snapshots:
            os.makedirs(output_dir, exist_ok=True)

        for i in range(iterations):
            _, loss = self.sess.run([min_op, self.loss])
            if save_snapshots and i % generate_every == 0:
                self._save_snapshots(i, output_dir)
            yield loss

    def _save_snapshots(self, iteration, output_dir):
        """Save the generator's current images as PNGs, one per reference name."""
        generated_images = self.generator.generate_images()
        # zip stops at the shorter sequence, so images generated for padded
        # (absent) batch slots are not saved.
        for img_array, img_name in zip(generated_images, self.names):
            img = PIL.Image.fromarray(img_array, 'RGB')
            img.save(os.path.join(output_dir, f'{img_name}_{iteration:04d}ep.png'), 'PNG')