From 4e20829399920e634ebd189c3af48479eae8f1f2 Mon Sep 17 00:00:00 2001
From: Ryuichi Yamamoto
Date: Sat, 6 Jan 2018 21:38:42 +0900
Subject: [PATCH 01/12] Add script to generate training data for wavenet vocoder

---
 generate_aligned_predictions.py | 178 ++++++++++++++++++++++++++++++++
 1 file changed, 178 insertions(+)
 create mode 100644 generate_aligned_predictions.py

diff --git a/generate_aligned_predictions.py b/generate_aligned_predictions.py
new file mode 100644
index 00000000..7b836421
--- /dev/null
+++ b/generate_aligned_predictions.py
@@ -0,0 +1,178 @@
+# coding: utf-8
+"""
+Generate ground truth-aligned predictions
+
+usage: generate_aligned_predictions.py [options]
+
+options:
+    --hparams=       Hyper parameters [default: ].
+    --overwrite      Overwrite audio and mel outputs.
+    -h, --help       Show help message.
+"""
+from docopt import docopt
+import os
+from tqdm import tqdm
+import importlib
+from os.path import join
+from warnings import warn
+import sys
+
+import numpy as np
+import torch
+from torch.autograd import Variable
+from torch import nn
+from torch.nn import functional as F
+
+# The deepvoice3 model
+from deepvoice3_pytorch import frontend
+from hparams import hparams
+
+use_cuda = torch.cuda.is_available()
+_frontend = None  # to be set later
+
+
+def preprocess(model, in_dir, out_dir, text, audio_filename, mel_filename,
+               p=0, speaker_id=None,
+               fast=False):
+    """Generate ground truth-aligned prediction
+
+    The output of the network and corresponding audio are saved after time
+    resolution adjastment if overwrite flag is specified.
+    """
+    r = hparams.outputs_per_step
+    downsample_step = hparams.downsample_step
+
+    if use_cuda:
+        model = model.cuda()
+    model.eval()
+    if fast:
+        model.make_generation_fast_()
+
+    mel_org = np.load(join(in_dir, mel_filename))
+    mel = Variable(torch.from_numpy(mel_org)).unsqueeze(0).contiguous()
+
+    # Downsample mel spectrogram
+    if downsample_step > 1:
+        mel = mel[:, 0::downsample_step, :].contiguous()
+
+    decoder_target_len = mel.shape[1] // r
+    s, e = 1, decoder_target_len + 1
+    frame_positions = torch.arange(s, e).long().unsqueeze(0)
+    frame_positions = Variable(frame_positions)
+
+    sequence = np.array(_frontend.text_to_sequence(text, p=p))
+    sequence = Variable(torch.from_numpy(sequence)).unsqueeze(0)
+    text_positions = torch.arange(1, sequence.size(-1) + 1).unsqueeze(0).long()
+    text_positions = Variable(text_positions)
+    speaker_ids = None if speaker_id is None else Variable(torch.LongTensor([speaker_id]))
+    if use_cuda:
+        sequence = sequence.cuda()
+        text_positions = text_positions.cuda()
+        speaker_ids = None if speaker_ids is None else speaker_ids.cuda()
+        mel = mel.cuda()
+        frame_positions = frame_positions.cuda()
+
+    # **Teacher forcing** decoding
+    mel_outputs, _, _, _ = model(
+        sequence, mel, text_positions=text_positions,
+        frame_positions=frame_positions, speaker_ids=speaker_ids)
+
+    mel_output = mel_outputs[0].data.cpu().numpy()
+
+    # **Time resolution adjastment**
+    # remove begenning audio used for first mel prediction
+    wav = np.load(join(in_dir, audio_filename))[hparams.hop_size * downsample_step:]
+    assert len(wav) % hparams.hop_size == 0
+
+    # Coarse upsample just for convenience
+    # so that we can upsample conditional features by hop_size in wavenet
+    if downsample_step > 0:
+        mel_output = np.repeat(mel_output, downsample_step, axis=0)
+    # downsampling -> upsampling, then we should have length equal to or larger than
+    # the original mel length
+    assert mel_output.shape[0] >= mel_org.shape[0]
+
+    # Trim mel output
expected_frames = len(wav) // hparams.hop_size + mel_output = mel_output[:expected_frames] + + # Make sure we have correct lengths + assert mel_output.shape[0] * hparams.hop_size == len(wav) + + timesteps = len(wav) + + # save + np.save(join(out_dir, audio_filename), wav.astype(np.int16), + allow_pickle=False) + np.save(join(out_dir, mel_filename), mel_output.astype(np.float32), + allow_pickle=False) + + if speaker_id is None: + return (audio_filename, mel_filename, timesteps, text) + else: + return (audio_filename, mel_filename, timesteps, text, speaker_id) + + +def write_metadata(metadata, out_dir): + with open(os.path.join(out_dir, 'train.txt'), 'w', encoding='utf-8') as f: + for m in metadata: + f.write('|'.join([str(x) for x in m]) + '\n') + frames = sum([m[2] for m in metadata]) + sr = hparams.sample_rate + hours = frames / sr / 3600 + print('Wrote %d utterances, %d time steps (%.2f hours)' % (len(metadata), frames, hours)) + print('Max input length: %d' % max(len(m[3]) for m in metadata)) + print('Max output length: %d' % max(m[2] for m in metadata)) + + +if __name__ == "__main__": + args = docopt(__doc__) + checkpoint_path = args[""] + in_dir = args[""] + out_dir = args[""] + + # Override hyper parameters + hparams.parse(args["--hparams"]) + assert hparams.name == "deepvoice3" + + # Presets + if hparams.preset is not None and hparams.preset != "": + preset = hparams.presets[hparams.preset] + import json + hparams.parse_json(json.dumps(preset)) + print("Override hyper parameters with preset \"{}\": {}".format( + hparams.preset, json.dumps(preset, indent=4))) + + _frontend = getattr(frontend, hparams.frontend) + import train + train._frontend = _frontend + from train import build_model + + model = build_model() + + # Load checkpoint + print("Load checkpoint from {}".format(checkpoint_path)) + checkpoint = torch.load(checkpoint_path) + model.load_state_dict(checkpoint["state_dict"]) + + os.makedirs(out_dir, exist_ok=True) + results = [] + with open(os.path.join(in_dir, "train.txt")) as f: + lines = f.readlines() + + for idx in tqdm(range(len(lines))): + l = lines[idx] + l = l[:-1].split("|") + audio_filename, mel_filename, _, text = l[:4] + speaker_id = int(l[4]) if len(l) > 4 else None + if text == "N/A": + raise RuntimeError("No transcription available") + + result = preprocess(model, in_dir, out_dir, text, audio_filename, + mel_filename, p=0, + speaker_id=speaker_id, fast=True) + results.append(result) + + write_metadata(results, out_dir) + + sys.exit(0) From ba1ecfd2d91325c357331e28b7930b6a1f95053e Mon Sep 17 00:00:00 2001 From: Ryuichi Yamamoto Date: Sat, 6 Jan 2018 21:45:06 +0900 Subject: [PATCH 02/12] fix typo --- generate_aligned_predictions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/generate_aligned_predictions.py b/generate_aligned_predictions.py index 7b836421..776fdbe6 100644 --- a/generate_aligned_predictions.py +++ b/generate_aligned_predictions.py @@ -37,7 +37,7 @@ def preprocess(model, in_dir, out_dir, text, audio_filename, mel_filename, """Generate ground truth-aligned prediction The output of the network and corresponding audio are saved after time - resolution adjastment if overwrite flag is specified. + resolution adjustment. 
""" r = hparams.outputs_per_step downsample_step = hparams.downsample_step @@ -79,7 +79,7 @@ def preprocess(model, in_dir, out_dir, text, audio_filename, mel_filename, mel_output = mel_outputs[0].data.cpu().numpy() - # **Time resolution adjastment** + # **Time resolution adjustment** # remove begenning audio used for first mel prediction wav = np.load(join(in_dir, audio_filename))[hparams.hop_size * downsample_step:] assert len(wav) % hparams.hop_size == 0 From b458377c3f1e559be7a919f4fc677eb2d844c8b1 Mon Sep 17 00:00:00 2001 From: Ryuichi Yamamoto Date: Mon, 8 Jan 2018 23:15:00 +0900 Subject: [PATCH 03/12] hparams option for preprocess --- preprocess.py | 1 - 1 file changed, 1 deletion(-) diff --git a/preprocess.py b/preprocess.py index d76de83f..9a4eeac3 100644 --- a/preprocess.py +++ b/preprocess.py @@ -52,7 +52,6 @@ def write_metadata(metadata, out_dir): # Override hyper parameters hparams.parse(args["--hparams"]) assert hparams.name == "deepvoice3" - print(hparams_debug_string()) assert name in ["jsut", "ljspeech", "vctk", "nikl_m", "nikl_s", "json_meta"] mod = importlib.import_module(name) From 3b4fa649a509007b1626b8b8f16e9ca2c12d67e7 Mon Sep 17 00:00:00 2001 From: Ryuichi Yamamoto Date: Mon, 8 Jan 2018 23:15:55 +0900 Subject: [PATCH 04/12] ignore my local files --- .gitignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index feb4ef97..9ba52f53 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,9 @@ notebooks foobar* run.sh README.rst +legacy +notebooks +run.sh pretrained_models deepvoice3_pytorch/version.py checkpoints* From 8360683cd385a894b0b28e38cbdc371836b41a75 Mon Sep 17 00:00:00 2001 From: Ryuichi Yamamoto Date: Wed, 10 Jan 2018 19:56:46 +0900 Subject: [PATCH 05/12] Fixtypo --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 2971dd1d..6536aaa2 100644 --- a/README.md +++ b/README.md @@ -197,7 +197,7 @@ python preprocess.py nikl_s ${your_nikl_root_path} data/nikl_s --preset=presets/ python train.py --data-root=./data/nikl_s --checkpoint-dir checkpoint_nikl_s --preset=presets/deepvoice3_nikls.json ``` -### 4. Monitor with Tensorboard +### 3. Monitor with Tensorboard Logs are dumped in `./log` directory by default. You can monitor logs by tensorboard: @@ -205,7 +205,7 @@ Logs are dumped in `./log` directory by default. You can monitor logs by tensorb tensorboard --logdir=log ``` -### 5. Synthesize from a checkpoint +### 4. Synthesize from a checkpoint Given a list of text, `synthesis.py` synthesize audio signals from trained model. 
 Usage is:

From b54c2793f1d73abbe85c30e3dbe30f6ff90a5ff4 Mon Sep 17 00:00:00 2001
From: Ryuichi Yamamoto
Date: Sat, 3 Mar 2018 13:24:30 +0900
Subject: [PATCH 06/12] fix for RuntimeError

---
 train.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/train.py b/train.py
index b7918066..ce650257 100644
--- a/train.py
+++ b/train.py
@@ -993,6 +993,9 @@ def restore_parts(path, model):
               clip_thresh=hparams.clip_thresh,
               train_seq2seq=train_seq2seq, train_postnet=train_postnet)
     except KeyboardInterrupt:
+        print("Interrupted!")
+        pass
+    finally:
         save_checkpoint(
             model, optimizer, global_step, checkpoint_dir, global_epoch,
             train_seq2seq, train_postnet)

From a04941babbf20fa9e9d6ba30da42b1db8bdf4942 Mon Sep 17 00:00:00 2001
From: Ryuichi Yamamoto
Date: Sat, 3 Mar 2018 13:25:01 +0900
Subject: [PATCH 07/12] Cleanup generate aligned features

fixes a bug for r > 1
---
 generate_aligned_predictions.py | 21 +++++++++++----------
 1 file changed, 11 insertions(+), 10 deletions(-)

diff --git a/generate_aligned_predictions.py b/generate_aligned_predictions.py
index 776fdbe6..eda182f7 100644
--- a/generate_aligned_predictions.py
+++ b/generate_aligned_predictions.py
@@ -49,7 +49,13 @@ def preprocess(model, in_dir, out_dir, text, audio_filename, mel_filename,
         model.make_generation_fast_()
 
     mel_org = np.load(join(in_dir, mel_filename))
-    mel = Variable(torch.from_numpy(mel_org)).unsqueeze(0).contiguous()
+    # zero pad
+    b_pad = r  # imitates initial state
+    e_pad = r - len(mel_org) % r if len(mel_org) % r > 0 else 0
+    mel = np.pad(mel_org, [(b_pad, e_pad), (0, 0)],
+                 mode="constant", constant_values=0)
+
+    mel = Variable(torch.from_numpy(mel)).unsqueeze(0).contiguous()
 
     # Downsample mel spectrogram
     if downsample_step > 1:
@@ -78,10 +84,10 @@ def preprocess(model, in_dir, out_dir, text, audio_filename, mel_filename,
         frame_positions=frame_positions, speaker_ids=speaker_ids)
 
     mel_output = mel_outputs[0].data.cpu().numpy()
 
-    # **Time resolution adjustment**
-    # remove begenning audio used for first mel prediction
-    wav = np.load(join(in_dir, audio_filename))[hparams.hop_size * downsample_step:]
+    mel_output = mel_output[:-(b_pad + e_pad)]
+
+    wav = np.load(join(in_dir, audio_filename))
     assert len(wav) % hparams.hop_size == 0
 
     # Coarse upsample just for convenience
@@ -92,18 +98,13 @@ def preprocess(model, in_dir, out_dir, text, audio_filename, mel_filename,
     # the original mel length
     assert mel_output.shape[0] >= mel_org.shape[0]
 
-    # Trim mel output
-    expected_frames = len(wav) // hparams.hop_size
-    mel_output = mel_output[:expected_frames]
-
     # Make sure we have correct lengths
     assert mel_output.shape[0] * hparams.hop_size == len(wav)
 
     timesteps = len(wav)
 
     # save
-    np.save(join(out_dir, audio_filename), wav.astype(np.int16),
-            allow_pickle=False)
+    np.save(join(out_dir, audio_filename), wav, allow_pickle=False)
     np.save(join(out_dir, mel_filename), mel_output.astype(np.float32),
             allow_pickle=False)
 

From d16e345e02ef97a860ab1f6cff75aebd499a8a41 Mon Sep 17 00:00:00 2001
From: Ryuichi Yamamoto
Date: Sat, 3 Mar 2018 22:57:13 +0900
Subject: [PATCH 08/12] try this

---
 hparams.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/hparams.py b/hparams.py
index 2373a050..950755d6 100644
--- a/hparams.py
+++ b/hparams.py
@@ -43,7 +43,7 @@
     # whether to rescale waveform or not.
# Let x is an input waveform, rescaled waveform y is given by: # y = x / np.abs(x).max() * rescaling_max - rescaling=False, + rescaling=True, rescaling_max=0.999, # mel-spectrogram is normalized to [0, 1] for each utterance and clipping may # happen depends on min_level_db and ref_level_db, causing clipping noise. From e40d3cbbe3db026df57f4dcead06b40a547443ff Mon Sep 17 00:00:00 2001 From: Ryuichi Yamamoto Date: Sun, 4 Mar 2018 00:57:11 +0900 Subject: [PATCH 09/12] Fix for master --- generate_aligned_predictions.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/generate_aligned_predictions.py b/generate_aligned_predictions.py index eda182f7..a72053a5 100644 --- a/generate_aligned_predictions.py +++ b/generate_aligned_predictions.py @@ -6,6 +6,7 @@ options: --hparams= Hyper parameters [default: ]. + --preset= Path of preset parameters (json). --overwrite Overwrite audio and mel outputs. -h, --help Show help message. """ @@ -131,19 +132,16 @@ def write_metadata(metadata, out_dir): checkpoint_path = args[""] in_dir = args[""] out_dir = args[""] + preset = args["--preset"] + # Load preset if specified + if preset is not None: + with open(preset) as f: + hparams.parse_json(f.read()) # Override hyper parameters hparams.parse(args["--hparams"]) assert hparams.name == "deepvoice3" - # Presets - if hparams.preset is not None and hparams.preset != "": - preset = hparams.presets[hparams.preset] - import json - hparams.parse_json(json.dumps(preset)) - print("Override hyper parameters with preset \"{}\": {}".format( - hparams.preset, json.dumps(preset, indent=4))) - _frontend = getattr(frontend, hparams.frontend) import train train._frontend = _frontend From ae3e8ae54d7f93e864d6e0a9ffc373da05a8a891 Mon Sep 17 00:00:00 2001 From: Ryuichi Yamamoto Date: Sun, 4 Mar 2018 01:02:54 +0900 Subject: [PATCH 10/12] preset parameters for DeepVoice3 + WaveNet --- presets/deepvoice3_ljspeech_wavenet.json | 65 ++++++++++++++++++++++++ 1 file changed, 65 insertions(+) create mode 100644 presets/deepvoice3_ljspeech_wavenet.json diff --git a/presets/deepvoice3_ljspeech_wavenet.json b/presets/deepvoice3_ljspeech_wavenet.json new file mode 100644 index 00000000..637546e1 --- /dev/null +++ b/presets/deepvoice3_ljspeech_wavenet.json @@ -0,0 +1,65 @@ +{ + "name": "deepvoice3", + "frontend": "en", + "replace_pronunciation_prob": 0.5, + "builder": "deepvoice3", + "n_speakers": 1, + "speaker_embed_dim": 16, + "num_mels": 80, + "fmin": 125, + "fmax": 7600, + "fft_size": 1024, + "hop_size": 256, + "sample_rate": 22050, + "preemphasis": 0.0, + "min_level_db": -100, + "ref_level_db": 20, + "rescaling": true, + "rescaling_max": 0.999, + "allow_clipping_in_normalization": true, + "downsample_step": 1, + "outputs_per_step": 4, + "embedding_weight_std": 0.1, + "speaker_embedding_weight_std": 0.01, + "padding_idx": 0, + "max_positions": 2048, + "dropout": 0.050000000000000044, + "kernel_size": 3, + "text_embed_dim": 256, + "encoder_channels": 512, + "decoder_channels": 256, + "converter_channels": 256, + "query_position_rate": 1.0, + "key_position_rate": 1.385, + "key_projection": true, + "value_projection": false, + "use_memory_mask": true, + "trainable_positional_encodings": false, + "freeze_embedding": false, + "use_decoder_state_for_postnet_input": true, + "pin_memory": true, + "num_workers": 2, + "masked_loss_weight": 0.5, + "priority_freq": 3000, + "priority_freq_weight": 0.0, + "binary_divergence_weight": 0.1, + "use_guided_attention": true, + "guided_attention_sigma": 0.2, + "batch_size": 16, + 
"adam_beta1": 0.5, + "adam_beta2": 0.9, + "adam_eps": 1e-06, + "initial_learning_rate": 0.0005, + "lr_schedule": "noam_learning_rate_decay", + "lr_schedule_kwargs": {}, + "nepochs": 2000, + "weight_decay": 0.0, + "clip_thresh": 0.1, + "checkpoint_interval": 10000, + "eval_interval": 10000, + "save_optimizer_state": true, + "force_monotonic_attention": true, + "window_ahead": 3, + "window_backward": 1, + "power": 1.4 +} From 672e2492c50b9baaf67b4765d287bc8e728a9ba2 Mon Sep 17 00:00:00 2001 From: Ryuichi Yamamoto Date: Sun, 4 Mar 2018 22:54:17 +0900 Subject: [PATCH 11/12] cleanup --- dump_hparams_to_json.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/dump_hparams_to_json.py b/dump_hparams_to_json.py index d67e88d3..f0554605 100644 --- a/dump_hparams_to_json.py +++ b/dump_hparams_to_json.py @@ -12,13 +12,9 @@ import sys import os from os.path import dirname, join, basename, splitext +import json -import audio - -# The deepvoice3 model -from deepvoice3_pytorch import frontend from hparams import hparams -import json if __name__ == "__main__": args = docopt(__doc__) From 260d50d92dcbf9a7f63605d83c1909da213e7c5d Mon Sep 17 00:00:00 2001 From: Ryuichi Yamamoto Date: Tue, 1 May 2018 13:40:12 +0900 Subject: [PATCH 12/12] WIP --- hparams.py | 20 +++++------ presets/deepvoice3_ljspeech_wavenet.json | 12 +++---- synthesis.py | 45 ++++++++++++++++++++++-- 3 files changed, 58 insertions(+), 19 deletions(-) diff --git a/hparams.py b/hparams.py index 950755d6..729c2605 100644 --- a/hparams.py +++ b/hparams.py @@ -51,25 +51,25 @@ allow_clipping_in_normalization=True, # Model: - downsample_step=4, # must be 4 when builder="nyanko" - outputs_per_step=1, # must be 1 when builder="nyanko" + downsample_step=1, # must be 4 when builder="nyanko" + outputs_per_step=4, # must be 1 when builder="nyanko" embedding_weight_std=0.1, speaker_embedding_weight_std=0.01, padding_idx=0, # Maximum number of input text length # try setting larger value if you want to give very long text input - max_positions=512, - dropout=1 - 0.95, - kernel_size=3, - text_embed_dim=128, - encoder_channels=256, - decoder_channels=256, + max_positions=2048, + dropout=1 - 0.90, + kernel_size=5, + text_embed_dim=256, + encoder_channels=512, + decoder_channels=512, # Note: large converter channels requires significant computational cost converter_channels=256, query_position_rate=1.0, # can be computed by `compute_timestamp_ratio.py`. 
key_position_rate=1.385, # 2.37 for jsut - key_projection=False, + key_projection=True, value_projection=False, use_memory_mask=True, trainable_positional_encodings=False, @@ -99,7 +99,7 @@ adam_beta1=0.5, adam_beta2=0.9, adam_eps=1e-6, - initial_learning_rate=5e-4, # 0.001, + initial_learning_rate=1e-3, # 0.001, lr_schedule="noam_learning_rate_decay", lr_schedule_kwargs={}, nepochs=2000, diff --git a/presets/deepvoice3_ljspeech_wavenet.json b/presets/deepvoice3_ljspeech_wavenet.json index 637546e1..38c757cf 100644 --- a/presets/deepvoice3_ljspeech_wavenet.json +++ b/presets/deepvoice3_ljspeech_wavenet.json @@ -11,7 +11,7 @@ "fft_size": 1024, "hop_size": 256, "sample_rate": 22050, - "preemphasis": 0.0, + "preemphasis": 0.97, "min_level_db": -100, "ref_level_db": 20, "rescaling": true, @@ -23,11 +23,11 @@ "speaker_embedding_weight_std": 0.01, "padding_idx": 0, "max_positions": 2048, - "dropout": 0.050000000000000044, - "kernel_size": 3, + "dropout": 0.09999999999999998, + "kernel_size": 5, "text_embed_dim": 256, "encoder_channels": 512, - "decoder_channels": 256, + "decoder_channels": 512, "converter_channels": 256, "query_position_rate": 1.0, "key_position_rate": 1.385, @@ -49,7 +49,7 @@ "adam_beta1": 0.5, "adam_beta2": 0.9, "adam_eps": 1e-06, - "initial_learning_rate": 0.0005, + "initial_learning_rate": 0.001, "lr_schedule": "noam_learning_rate_decay", "lr_schedule_kwargs": {}, "nepochs": 2000, @@ -62,4 +62,4 @@ "window_ahead": 3, "window_backward": 1, "power": 1.4 -} +} \ No newline at end of file diff --git a/synthesis.py b/synthesis.py index fbecdf2d..9a21629b 100644 --- a/synthesis.py +++ b/synthesis.py @@ -9,6 +9,7 @@ --preset= Path of preset parameters (json). --checkpoint-seq2seq= Load seq2seq model from checkpoint path. --checkpoint-postnet= Load postnet model from checkpoint path. + --checkpoint-wavenet= Load WaveNet vocoder. --file-name-suffix= File name suffix [default: ]. --max-decoder-steps= Max decoder steps [default: 500]. --replace_pronunciation_prob= Prob [default: 0.0]. @@ -39,7 +40,7 @@ _frontend = None # to be set later -def tts(model, text, p=0, speaker_id=None, fast=False): +def tts(model, text, p=0, speaker_id=None, fast=False, wavenet=None): """Convert text to speech waveform given a deepvoice3 model. 
Args: @@ -73,7 +74,30 @@ def tts(model, text, p=0, speaker_id=None, fast=False): mel = audio._denormalize(mel) # Predicted audio signal - waveform = audio.inv_spectrogram(linear_output.T) + if wavenet is not None: + if use_cuda: + wavenet = wavenet.cuda() + wavenet.eval() + if fast: + wavenet.make_generation_fast_() + + # TODO: assuming scalar input + initial_value = 0.0 + initial_input = Variable(torch.zeros(1, 1, 1)).fill_(initial_value) + # (B, T, C) -> (B, C, T) + c = mel_outputs.transpose(1, 2).contiguous() + g = None + Tc = c.size(-1) + length = Tc * 256 + if use_cuda: + initial_input = initial_input.cuda() + c = c.cuda() + waveform = wavenet.incremental_forward( + initial_input, c=c, g=g, T=length, tqdm=tqdm, softmax=True, quantize=True, + log_scale_min=float(np.log(1e-14))) + waveform = waveform.view(-1).cpu().data.numpy() + else: + waveform = audio.inv_spectrogram(linear_output.T) return waveform, alignment, spectrogram, mel @@ -95,6 +119,7 @@ def _load(checkpoint_path): dst_dir = args[""] checkpoint_seq2seq_path = args["--checkpoint-seq2seq"] checkpoint_postnet_path = args["--checkpoint-postnet"] + checkpoint_wavenet_path = args["--checkpoint-wavenet"] max_decoder_steps = int(args["--max-decoder-steps"]) file_name_suffix = args["--file-name-suffix"] replace_pronunciation_prob = float(args["--replace_pronunciation_prob"]) @@ -132,6 +157,19 @@ def _load(checkpoint_path): model.load_state_dict(checkpoint["state_dict"]) checkpoint_name = splitext(basename(checkpoint_path))[0] + # Load WaveNet vocoder + if checkpoint_wavenet_path is not None: + from wavenet_vocoder import builder + wavenet = builder.wavenet(out_channels=3 * 10, layers=24, stacks=4, residual_channels=512, + gate_channels=512, skip_out_channels=256, dropout=1 - 0.95, + kernel_size=3, weight_normalization=True, cin_channels=80, + upsample_conditional_features=True, upsample_scales=[4, 4, 4, 4], + freq_axis_kernel_size=3, gin_channels=-1, scalar_input=True) + checkpoint = torch.load(checkpoint_wavenet_path) + wavenet.load_state_dict(checkpoint["state_dict"]) + else: + wavenet = None + model.seq2seq.decoder.max_decoder_steps = max_decoder_steps os.makedirs(dst_dir, exist_ok=True) @@ -141,7 +179,8 @@ def _load(checkpoint_path): text = line.decode("utf-8")[:-1] words = nltk.word_tokenize(text) waveform, alignment, _, _ = tts( - model, text, p=replace_pronunciation_prob, speaker_id=speaker_id, fast=True) + model, text, p=replace_pronunciation_prob, speaker_id=speaker_id, fast=True, + wavenet=wavenet) dst_wav_path = join(dst_dir, "{}_{}{}.wav".format( idx, checkpoint_name, file_name_suffix)) dst_alignment_path = join(