# coding: utf-8
"""
Synthesize waveforms using HiFi-GAN.

usage: inference_hifigan.py [options] <dst_dir>

options:
    --hparams=<params>             Hyperparameters [default: ].
    -i, --input-file=<p>           Input txt file path (one sentence per line).
    -t, --tacotron-checkpoint=<p>  Tacotron2 checkpoint path.
    -w, --hifigan-checkpoint=<p>   HiFi-GAN checkpoint path.
    -h, --help                     Show this help message.
"""
import json
import sys
from os import makedirs
from os.path import basename, join, split, splitext

import numpy as np
import soundfile as sf
import torch
from docopt import docopt

# Tacotron2 modules
from hparams import create_hparams
from text import text_to_sequence
from train import load_model

sys.path.append('hifigan/')
# HiFi-GAN modules
from hifigan.env import AttrDict
from hifigan.models import Generator

# Peak amplitude used to scale the generator's [-1, 1] output to 16-bit PCM.
MAX_WAV_VALUE = 32767.5

hparams = create_hparams()
hparams.sampling_rate = 22050
hparams.filter_length = 1024
hparams.hop_length = 256
hparams.win_length = 1024


def load_hifigan(checkpoint_file):
    # The generator config (config.json) is expected next to the checkpoint.
    config_file = join(split(checkpoint_file)[0], 'config.json')
    with open(config_file) as f:
        json_config = json.loads(f.read())
    h = AttrDict(json_config)
    torch.manual_seed(h.seed)
    # Older configs may not define this flag; default it to off.
    if not hasattr(h, 'use_speaker_embedding'):
        h.use_speaker_embedding = False
    if torch.cuda.is_available():
        torch.cuda.manual_seed(h.seed)
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    generator = Generator(h).to(device)
    state_dict_g = torch.load(checkpoint_file, map_location=device)
    generator.load_state_dict(state_dict_g['generator'])
    # Inference mode: freeze dropout/norm statistics and strip weight norm.
    generator.eval()
    generator.remove_weight_norm()
    return generator
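
# Quick shape sanity check for load_hifigan (a sketch, assuming the usual
# 80-band mel config and a hypothetical checkpoint path):
#
#   g = load_hifigan('hifigan_checkpoints/hifi_g_02500000')
#   mel = torch.randn(1, 80, 100, device=next(g.parameters()).device)
#   with torch.no_grad():
#       wav = g(mel)  # -> (1, 1, 100 * hop_length) waveform in [-1, 1]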


def load_tacotron2(tacotron_checkpoint_path):
    global hparams
    model = load_model(hparams)
    model.load_state_dict(torch.load(tacotron_checkpoint_path)['state_dict'])
    model.eval()
    return model


def hifigan_prediction(generator, spectrogram):
    # Run the generator on a (batch, n_mels, frames) mel spectrogram and
    # convert the [-1, 1] float waveform to 16-bit PCM samples.
    y_g_hat = generator(spectrogram)
    audio = y_g_hat.squeeze()
    audio = audio * MAX_WAV_VALUE
    audio = audio.detach().cpu().numpy().astype('int16')
    return audio


def waveform_generation(text, tacotron_model, hifigan_model):
    # Prepare text input (torch.autograd.Variable is deprecated; a plain tensor works).
    sequence = np.array(text_to_sequence(text, ['basic_cleaners']))[None, :]
    sequence = torch.from_numpy(sequence).cuda().long()
    with torch.no_grad():
        # Tacotron2 inference: text -> mel spectrogram
        mel_outputs, mel_outputs_postnet, _, alignments = tacotron_model.inference(sequence)
        # HiFi-GAN inference: mel spectrogram -> waveform
        waveform = hifigan_prediction(hifigan_model, mel_outputs_postnet)
    return waveform
if __name__ == "__main__":
args = docopt(__doc__)
print("Command line args:\n", args)
tacotron_checkpoint = args["--tacotron-checkpoint"]
hifigan_checkpoint = args["--hifigan-checkpoint"]
input_file_path = args["--input-file"]
dst_dir = args["<dst_dir>"]
checkpoint_wav_name = splitext(basename(hifigan_checkpoint))[0].replace('hifi_', '')
checkpoint_taco_name = splitext(basename(tacotron_checkpoint))[0].replace('checkpoint_', '')
tacotron_model = load_tacotron2(tacotron_checkpoint)
hifigan_model = load_hifigan(hifigan_checkpoint)
try:
with open(input_file_path) as f:
content = f.read().splitlines()
except FileNotFoundError:
print("File {} was not found.".format(input_file_path))
# Create output directory
subdir = 'samples_hifigan_' + checkpoint_wav_name + '_taco_' + checkpoint_taco_name
makedirs(join(dst_dir, subdir), exist_ok=True)
for i, text in enumerate(content):
print("Generating Waveform " + str(i))
waveform = waveform_generation(text, tacotron_model, hifigan_model)
# save
output_filepath = join(dst_dir, subdir, "{}.wav".format(i))
sf.write(output_filepath, waveform, samplerate=hparams.sampling_rate)
print("Waveform {} OK".format(i))
print("Finished! Check out {} for generated audio samples.".format(dst_dir))
sys.exit(0)