forked from keithito/tacotron
-
Notifications
You must be signed in to change notification settings - Fork 103
/
export.py
56 lines (47 loc) · 1.7 KB
/
export.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import argparse
import tensorflow as tf
from synthesizer import Synthesizer
from models import create_model
from hparams import hparams, hparams_debug_string
from util import audio
def parse_args(argv=None):
    """Parse the command-line arguments of the SavedModel export script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, matching a normal CLI invocation.

    Returns:
        An ``argparse.Namespace`` with ``checkpoint_path`` and ``export_path``.

    Raises:
        SystemExit: if a required flag is missing (argparse behavior).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--checkpoint_path', required=True, help='path to model checkpoint'
    )
    parser.add_argument(
        '--export_path', required=True, help='path to export model'
    )
    return parser.parse_args(argv)


def main(argv=None):
    """Load a model checkpoint and export it as a TensorFlow SavedModel.

    The export is tagged ``SERVING``. Its signature map contains a single
    ``'predict'`` signature:
      inputs:  ``inputs``, ``input_lengths`` (the model's input tensors)
      outputs: ``wav_output`` (waveform computed in-graph by
               ``audio.inv_spectrogram_tensorflow`` from the model's linear
               outputs) and ``alignment`` (the model's alignments tensor).

    Args:
        argv: Optional argument list forwarded to :func:`parse_args`.
    """
    args = parse_args(argv)

    synth = Synthesizer()
    synth.load(args.checkpoint_path)

    # Add an in-graph spectrogram -> waveform conversion so serving clients
    # receive audio directly instead of a linear spectrogram.
    # NOTE(review): linear_outputs is passed with its batch dimension intact;
    # confirm the resulting wav_output shape is what clients expect.
    wav = audio.inv_spectrogram_tensorflow(synth.model.linear_outputs)

    tensor_info = tf.saved_model.utils.build_tensor_info
    prediction_signature = tf.saved_model.signature_def_utils.build_signature_def(
        inputs={
            'inputs': tensor_info(synth.model.inputs),
            "input_lengths": tensor_info(synth.model.input_lengths),
        },
        outputs={
            'wav_output': tensor_info(wav),
            'alignment': tensor_info(synth.model.alignments),
        },
        method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME,
    )

    # Create the builder only after the model has loaded successfully. In
    # TF 1.x the constructor validates/creates export_path. Creating it up front
    # could leave an empty directory behind on a failed load and make the next
    # attempt fail with "export directory already exists".
    builder = tf.saved_model.builder.SavedModelBuilder(args.export_path)
    with synth.session as sess:
        builder.add_meta_graph_and_variables(
            sess=sess,
            tags=[tf.saved_model.tag_constants.SERVING],
            signature_def_map={'predict': prediction_signature},
        )
    builder.save()
    print("exported .pb to", args.export_path)


if __name__ == "__main__":
    main()