forked from yang-song/score_sde
-
Notifications
You must be signed in to change notification settings - Fork 0
/
evaluation.py
146 lines (122 loc) · 4.84 KB
/
evaluation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions for computing FID/Inception scores."""
import io

import jax
import numpy as np
import six
import tensorflow as tf
import tensorflow_gan as tfgan
import tensorflow_hub as tfhub
# TF-Hub handle of the default (TF-GAN evaluation) Inception network.
INCEPTION_TFHUB = 'https://tfhub.dev/tensorflow/tfgan/eval/inception/1'
# Output fields read from that network's result dictionary.
INCEPTION_OUTPUT = 'logits'
INCEPTION_FINAL_POOL = 'pool_3'
# Per-field output dtypes, passed as `dtypes` to
# `tfgan.eval.run_classifier_fn`.
_DEFAULT_DTYPES = {
    field: tf.float32 for field in (INCEPTION_OUTPUT, INCEPTION_FINAL_POOL)
}
# Presumably the input resolution expected by the Inception network; it is
# not referenced in the visible part of this file.
INCEPTION_DEFAULT_IMAGE_SIZE = 299
def get_inception_model(inceptionv3=False):
  """Load the Inception network used for feature extraction from TF-Hub.

  Args:
    inceptionv3: If `True`, load the ImageNet InceptionV3 feature-vector
      module. Otherwise load the module at `INCEPTION_TFHUB`.

  Returns:
    Whatever `tensorflow_hub.load` returns for the selected handle.
  """
  handle = (
      'https://tfhub.dev/google/imagenet/inception_v3/feature_vector/4'
      if inceptionv3 else INCEPTION_TFHUB)
  return tfhub.load(handle)
def load_dataset_stats(config):
  """Load the pre-computed dataset statistics.

  Args:
    config: Experiment config. Reads `config.data.dataset` and, for LSUN,
      also `config.data.category` and `config.data.image_size`.

  Returns:
    The object produced by `np.load` for the matching `.npz` file under
    `assets/stats/`. It is backed by an in-memory copy of the file.

  Raises:
    ValueError: If no statistics file is known for `config.data.dataset`.
  """
  if config.data.dataset == 'CIFAR10':
    filename = 'assets/stats/cifar10_stats.npz'
  elif config.data.dataset == 'CELEBA':
    filename = 'assets/stats/celeba_stats.npz'
  elif config.data.dataset == 'LSUN':
    filename = f'assets/stats/lsun_{config.data.category}_{config.data.image_size}_stats.npz'
  else:
    raise ValueError(f'Dataset {config.data.dataset} stats not found.')
  # `np.load` on an `.npz` returns a lazy NpzFile that reads arrays from the
  # file object only when a key is accessed. The GFile handle is closed when
  # the `with` block exits, so buffer the bytes in memory first. That way the
  # returned object never reads through a closed handle.
  with tf.io.gfile.GFile(filename, 'rb') as fin:
    buffer = io.BytesIO(fin.read())
  return np.load(buffer)
def classifier_fn_from_tfhub(output_fields, inception_model,
                             return_tensor=False):
  """Returns a function that can be used as a classifier function.

  Copied from tfgan, but takes an already-loaded model so the model is not
  reloaded each time `_classifier_fn` is called.

  Args:
    output_fields: A string, list, or `None`. If present, assume the module
      outputs a dictionary, and select this field.
    inception_model: A model loaded from TFHub.
    return_tensor: If `True`, return a single tensor instead of a dictionary.

  Returns:
    A one-argument function that takes an image Tensor and returns outputs.
    Each output is flattened with `tf.compat.v1.layers.flatten`.

  The returned function raises `ValueError` if `return_tensor` is `True` but
  the selected output does not contain exactly one field.
  """
  if isinstance(output_fields, six.string_types):
    output_fields = [output_fields]

  def _classifier_fn(images):
    output = inception_model(images)
    if output_fields is not None:
      output = {x: output[x] for x in output_fields}
    if return_tensor:
      # Explicit check rather than `assert`. Under `python -O` the assert is
      # stripped, and the code below would silently keep only the first field.
      if len(output) != 1:
        raise ValueError(
            'return_tensor=True requires exactly one output field, got '
            f'{len(output)}.')
      output = list(output.values())[0]
    return tf.nest.map_structure(tf.compat.v1.layers.flatten, output)

  return _classifier_fn
@tf.function
def run_inception_jit(inputs,
                      inception_model,
                      num_batches=1,
                      inceptionv3=False):
  """Running the inception network. Assuming input is within [0, 255].

  Args:
    inputs: Image batch with pixel values assumed to be in [0, 255]. It is
      cast to float32.
    inception_model: The network returned by `get_inception_model`.
    num_batches: Number of batches `tfgan.eval.run_classifier_fn` splits the
      input into.
    inceptionv3: If `True`, scale inputs to [0, 1] and expect a single output
      tensor. Otherwise scale inputs to [-1, 1] and expect a dictionary with
      `logits` and `pool_3`.

  Returns:
    The flattened classifier outputs: a tensor for InceptionV3, else a dict.
  """
  if inceptionv3:
    inputs = tf.cast(inputs, tf.float32) / 255.
    # The InceptionV3 feature-vector model yields a single tensor, so the
    # output dtype passed to tfgan must be a single dtype. A per-field dict
    # would not match that structure. Presumably tfgan only consults
    # `dtypes` on its batched (`num_batches > 1`) path.
    dtypes = tf.float32
  else:
    inputs = (tf.cast(inputs, tf.float32) - 127.5) / 127.5
    dtypes = _DEFAULT_DTYPES
  return tfgan.eval.run_classifier_fn(
      inputs,
      num_batches=num_batches,
      classifier_fn=classifier_fn_from_tfhub(None, inception_model),
      dtypes=dtypes)
@tf.function
def run_inception_distributed(input_tensor,
                              inception_model,
                              num_batches=1,
                              inceptionv3=False):
  """Distribute the inception network computation across local accelerators.

  The input is split evenly across `jax.local_device_count()` shards. Each
  shard runs on the matching TPU core (or GPU), and the results are
  concatenated on the CPU.

  Args:
    input_tensor: The input images. Assumed to be within [0, 255]. Its
      leading dimension must be divisible by the number of local devices,
      because `tf.split` requires an even split.
    inception_model: The inception network model obtained from `tfhub`.
    num_batches: The number of batches used for dividing each shard.
    inceptionv3: If `True`, use InceptionV3, otherwise use InceptionV1.

  Returns:
    A dictionary with keys `pool_3` and `logits`.
    - InceptionV1: the pool_3 features and the logits of the inception
      network.
    - InceptionV3: `pool_3` holds the feature-vector output and `logits`
      is `None`.
  """
  num_devices = jax.local_device_count()
  shards = tf.split(input_tensor, num_devices, axis=0)
  # Pick the TF device type from JAX's documented `platform` attribute rather
  # than substring-matching the device's string form. Non-TPU platforms keep
  # the original '/GPU' fallback.
  if jax.devices()[0].platform == 'tpu':
    device_format = '/TPU:{}'
  else:
    device_format = '/GPU:{}'

  pool3 = []
  logits = []
  for i, shard in enumerate(shards):
    with tf.device(device_format.format(i)):
      shard_on_device = tf.identity(shard)
      res = run_inception_jit(
          shard_on_device, inception_model, num_batches=num_batches,
          inceptionv3=inceptionv3)
      if inceptionv3:
        pool3.append(res)
      else:
        pool3.append(res['pool_3'])
        logits.append(res['logits'])

  with tf.device('/CPU'):
    return {
        'pool_3': tf.concat(pool3, axis=0),
        'logits': None if inceptionv3 else tf.concat(logits, axis=0),
    }